Skip to content
Merged
24 changes: 20 additions & 4 deletions openfold3/core/data/framework/data_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,11 @@ class DataModuleConfig(BaseModel):
datasets: list[SerializeAsAny[BaseModel]]
batch_size: int = 1
num_workers: int = 0
prefetch_factor: int | None = None
num_workers_validation: int = 0
prefetch_factor_validation: int | None = None
persistent_workers: bool = False
multiprocessing_context: str | None = None
data_seed: int = 42
epoch_len: int = 1

Expand All @@ -165,8 +169,14 @@ def __init__(self, data_module_config: DataModuleConfig) -> None:

# Possibly initialize directly from DataModuleConfig
self.batch_size = data_module_config.batch_size

self.num_workers = data_module_config.num_workers
self.prefetch_factor = data_module_config.prefetch_factor
self.num_workers_validation = data_module_config.num_workers_validation
self.prefetch_factor_validation = data_module_config.prefetch_factor_validation
self.persistent_workers = data_module_config.persistent_workers
self.multiprocessing_context = data_module_config.multiprocessing_context

self.data_seed = data_module_config.data_seed
self.next_data_seed = data_module_config.data_seed
self.epoch_len = data_module_config.epoch_len
Expand Down Expand Up @@ -408,17 +418,20 @@ def generate_dataloader(self, mode: DatasetMode, sampler: Sampler | None = None)
Returns:
DataLoader: DataLoader object.
"""

# TODO: Val does not need this many workers. Due to memory leak issue,
# reduce workers here to run with more workers overall in training
# as temporary quick fix.
if (
mode == DatasetMode.validation
and DatasetMode.train in self.multi_dataset_config.modes
):
num_workers = self.num_workers_validation
prefetch_factor = self.prefetch_factor_validation
else:
num_workers = self.num_workers
prefetch_factor = self.prefetch_factor

persistent_workers = self.persistent_workers and num_workers > 0
multiprocessing_context = (
self.multiprocessing_context if num_workers > 0 else None
)

generator = self.generators.get(mode)
if generator is None:
Expand All @@ -445,6 +458,9 @@ def generate_dataloader(self, mode: DatasetMode, sampler: Sampler | None = None)
collate_fn=openfold_batch_collator,
generator=self.generators[mode],
worker_init_fn=worker_init_fn,
persistent_workers=persistent_workers,
prefetch_factor=prefetch_factor,
multiprocessing_context=multiprocessing_context,
)

def train_dataloader(self) -> DataLoader:
Expand Down
31 changes: 24 additions & 7 deletions openfold3/core/data/framework/single_datasets/base_of3.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
SingleDataset,
register_dataset,
)
from openfold3.core.data.framework.single_datasets.dataset_utils import warm_lmdb_cache
from openfold3.core.data.io.dataset_cache import read_datacache
from openfold3.core.data.pipelines.featurization.conformer import (
featurize_reference_conformers_of3,
Expand Down Expand Up @@ -153,15 +154,20 @@ def __init__(self, dataset_config) -> None:
# TODO: rename dataset_cache_file to dataset_cache_path to signal that it can be
# a directory or a file
# TODO: potentially expose the LMDB database encoding types
self.dataset_cache = read_datacache(
dataset_config.dataset_paths.dataset_cache_file
)
self._dataset_cache_file = dataset_config.dataset_paths.dataset_cache_file
self.dataset_cache = read_datacache(self._dataset_cache_file)
self.warm_cache()

self.datapoint_cache = {}

if dataset_config.dataset_paths.template_structures_directory is not None:
self.ccd = pdbx.CIFFile.read(dataset_config.dataset_paths.ccd_file)
else:
self.ccd = None
# Only used if template structures are not preprocessed
# Lazy-loaded so the dataset is picklable (forkserver)
self._ccd = None
self._ccd_file = (
dataset_config.dataset_paths.ccd_file
if dataset_config.dataset_paths.template_structures_directory is not None
else None
)

# Dataset configuration
# n_tokens can be set in the getitem method separately for each sample using
Expand All @@ -174,6 +180,17 @@ def __init__(self, dataset_config) -> None:
self.single_moltype = None
self.debug_mode = dataset_config.debug_mode

def warm_cache(self) -> None:
    """Pre-read the dataset cache from disk when it is stored as a directory.

    A directory path is treated as an LMDB cache and handed to
    ``warm_lmdb_cache``; any other path (e.g. a JSON file) is left untouched.
    """
    cache_path = self._dataset_cache_file
    if not cache_path.is_dir():
        # Not an LMDB directory: nothing to warm.
        return
    warm_lmdb_cache(cache_path)

@property
def ccd(self):
    """Return the CCD CIF file, reading it from disk on first access.

    Returns whatever is stored in ``self._ccd`` when it is already populated
    or when no ``self._ccd_file`` is configured (in which case the result is
    ``None``). Otherwise the file is parsed once with ``pdbx.CIFFile.read``
    and memoized on the instance.
    """
    already_loaded = self._ccd is not None
    nothing_to_load = self._ccd_file is None
    if already_loaded or nothing_to_load:
        return self._ccd

    # Deferred read: the parsed CIF object is only created when first needed.
    self._ccd = pdbx.CIFFile.read(self._ccd_file)
    return self._ccd

@log_runtime_memory(runtime_dict_key="runtime-create-structure-features")
def create_structure_features(
self,
Expand Down
21 changes: 21 additions & 0 deletions openfold3/core/data/framework/single_datasets/dataset_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@
import copy
import logging
import os
import time
from itertools import cycle, islice
from pathlib import Path

import pandas as pd
import torch
Expand All @@ -30,6 +32,7 @@
naive_alignment,
)

logger = logging.getLogger(__name__)
worker_seed_log = logging.getLogger(f"{__name__}.worker_seed")


Expand Down Expand Up @@ -153,3 +156,21 @@ def getitem_debug_log(dataset_name: str = "") -> None:
f"pid={os.getpid()} worker_id={worker_id} wi.seed={wi_seed} "
f"wi.base_seed={wi_base_seed} torch.initial_seed={torch_seed}",
)


def warm_file_cache(file_path: Path, chunk_size: int = 8 * 1024 * 1024) -> None:
    """Sequentially read a file to warm the OS page cache.

    The file is read from start to end in fixed-size chunks and the bytes are
    discarded; the only purpose is to make the OS pull the file contents into
    its page cache. The file size and elapsed wall time are logged.

    Args:
        file_path: File to read.
        chunk_size: Number of bytes requested per ``read`` call. Must be
            positive. Defaults to 8 MiB.

    Raises:
        ValueError: If ``chunk_size`` is not positive. ``read(0)`` would end
            the loop immediately without touching the file, and a negative
            size would load the whole file into memory in a single call.
    """
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be a positive integer, got {chunk_size}")

    file_size_gb = file_path.stat().st_size / (1024**3)
    logger.info(f"Warming page cache for {file_path} ({file_size_gb:.1f} GB)...")
    t0 = time.monotonic()
    with open(file_path, "rb") as f:
        # read() returns b"" at EOF, which terminates the loop.
        while f.read(chunk_size):
            pass
    elapsed = time.monotonic() - t0
    logger.info(f"Page cache warm complete in {elapsed:.1f}s")


def warm_lmdb_cache(lmdb_directory: Path) -> None:
    """Warm the OS page cache for the ``data.mdb`` file of an LMDB directory.

    Args:
        lmdb_directory: Directory expected to contain ``data.mdb``.
    """
    data_file = lmdb_directory / "data.mdb"
    warm_file_cache(data_file)
3 changes: 3 additions & 0 deletions openfold3/core/data/framework/single_datasets/monomer.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,9 @@ def __init__(self, dataset_config: dict) -> None:
# Datapoint cache
self.create_datapoint_cache()

# Release so fork/forkserver inherits clean state for the first epoch
self.dataset_cache.release_connections()

def create_datapoint_cache(self):
"""Creates the datapoint_cache for uniform sampling.
Expand Down
3 changes: 3 additions & 0 deletions openfold3/core/data/framework/single_datasets/pdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,9 @@ def __init__(self, dataset_config: dict) -> None:
# Datapoint cache
self.create_datapoint_cache()

# Release so fork/forkserver inherits clean state for the first epoch
self.dataset_cache.release_connections()

def create_datapoint_cache(self) -> None:
"""Creates the datapoint_cache with chain/interface probabilities.

Expand Down
3 changes: 3 additions & 0 deletions openfold3/core/data/framework/single_datasets/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,9 @@ def __init__(self, dataset_config: dict, world_size: int | None = None) -> None:
# Dataset/datapoint cache
self.create_datapoint_cache()

# Release so fork/forkserver inherits clean state for the first epoch
self.dataset_cache.release_connections()

# Cropping should be disabled for validation datasets
if self.crop["token_crop"]["enabled"]:
logger.warning(
Expand Down
3 changes: 3 additions & 0 deletions openfold3/core/data/io/dataset_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,9 @@ def read_datacache(
with lmdb_env.begin() as txn:
dataset_cache_type = json.loads(txn.get(type_key).decode(str_encoding))

# Only one connection can be open at a time, close before creating LMDBDict
lmdb_env.close()

if not dataset_cache_type:
raise ValueError("No type found for this directory.")

Expand Down
67 changes: 38 additions & 29 deletions openfold3/core/data/primitives/caches/format.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

import lmdb

from openfold3.core.data.primitives.caches.lmdb import LMDBDict
from openfold3.core.data.primitives.caches.lmdb import LMDBDict, LMDBEnv
from openfold3.core.data.resources.residues import MoleculeType

K = TypeVar("K")
Expand Down Expand Up @@ -135,6 +135,7 @@ def from_json(cls, file: Path) -> PreprocessingDataCache:
# rerunning preprocessing
elif status == "failed":
release_date = None
experimental_method = None
resolution = None
chains = None
interfaces = None
Expand Down Expand Up @@ -414,7 +415,7 @@ class DatasetCache:
# TODO: update parsers for this base class
@classmethod
def from_json(cls, file: Path) -> DatasetCache:
"""Costructs a datacache from a json.
"""Constructs a datacache from a json.

Args:
file (Path):
Expand All @@ -434,13 +435,15 @@ def from_json(cls, file: Path) -> DatasetCache:
reference_molecule_data=cls._parse_ref_mol_data_json(data),
)

@staticmethod
def _parse_type_json(data: dict) -> None:
    """Strip the ``_type`` entry from ``data`` in place, if it is present.

    ``_type`` is handled as an internal private attribute rather than an
    explicit field, so it is dropped from the parsed JSON payload. Payloads
    without the key (legacy compatibility) are left unchanged.
    """
    data.pop("_type", None)

@staticmethod
def _parse_name_json(data: dict) -> str:
    """Return the value stored under the ``"name"`` key of ``data``.

    Raises:
        KeyError: If ``data`` has no ``"name"`` entry.
    """
    dataset_name = data["name"]
    return dataset_name

Expand Down Expand Up @@ -479,6 +482,15 @@ def _parse_ref_mol_data_json(cls, data: dict) -> dict:
ref_mol_data[ref_mol_id] = per_ref_mol_data_fmt
return ref_mol_data

def release_connections(self) -> None:
    """Close the storage backends behind this cache, where they support it.

    ``close()`` is called on ``structure_data`` and ``reference_molecule_data``
    when the object exposes such a method; objects without one (e.g. plain
    dicts) are skipped. Intended to be called before worker processes are
    forked so they inherit clean state.

    NOTE(review): backends are presumably reopened lazily on next access —
    confirm against the backend implementation.
    """
    backends = (self.structure_data, self.reference_molecule_data)
    for backend in backends:
        if not hasattr(backend, "close"):
            continue
        backend.close()

def to_json(self, file: Path) -> None:
"""Write the dataset cache to a JSON file.

Expand Down Expand Up @@ -523,27 +535,29 @@ def from_lmdb(
DatasetCache:
The constructed datacache.
"""

lmdb_env = lmdb.open(
str(lmdb_directory), readonly=True, lock=False, subdir=True
)

with lmdb_env.begin() as transaction:
lmdb_env = LMDBEnv(str(lmdb_directory))
with lmdb_env.get().begin() as transaction:
_ = cls._parse_type_lmdb(transaction, str_encoding)
name = cls._parse_name_lmdb(transaction, str_encoding)
structure_data = cls._parse_structure_data_lmdb(
lmdb_env, str_encoding, structure_data_encoding
)
reference_molecule_data = cls._parse_ref_mol_data_lmdb(
lmdb_env, str_encoding, reference_molecule_data_encoding
)

return cls(
name=name,
structure_data=structure_data,
reference_molecule_data=reference_molecule_data,
)
structure_data = cls._parse_structure_data_lmdb(
lmdb_env=lmdb_env,
str_encoding=str_encoding,
structure_data_encoding=structure_data_encoding,
)
reference_molecule_data = cls._parse_ref_mol_data_lmdb(
lmdb_env=lmdb_env,
str_encoding=str_encoding,
reference_molecule_data_encoding=reference_molecule_data_encoding,
)

return cls(
name=name,
structure_data=structure_data,
reference_molecule_data=reference_molecule_data,
)

@staticmethod
def _parse_type_lmdb(
transaction: lmdb.Transaction, str_encoding: Literal["utf-8", "pkl"]
) -> str:
Expand All @@ -555,6 +569,7 @@ def _parse_type_lmdb(

return _type

@staticmethod
def _parse_name_lmdb(
transaction: lmdb.Transaction, str_encoding: Literal["utf-8", "pkl"]
) -> str:
Expand All @@ -566,31 +581,25 @@ def _parse_name_lmdb(

return name

@staticmethod
def _parse_structure_data_lmdb(
lmdb_env: lmdb.Environment,
lmdb_env: LMDBEnv,
str_encoding: Literal["utf-8", "pkl"],
structure_data_encoding: Literal["utf-8", "pkl"],
) -> LMDBDict:
from openfold3.core.data.primitives.caches.lmdb import (
LMDBDict,
)

return LMDBDict(
lmdb_env=lmdb_env,
prefix="structure_data",
key_encoding=str_encoding,
value_encoding=structure_data_encoding,
)

@staticmethod
def _parse_ref_mol_data_lmdb(
lmdb_env: lmdb.Environment,
lmdb_env: LMDBEnv,
str_encoding: Literal["utf-8", "pkl"],
reference_molecule_data_encoding: Literal["utf-8", "pkl"],
) -> LMDBDict:
from openfold3.core.data.primitives.caches.lmdb import (
LMDBDict,
)

return LMDBDict(
lmdb_env=lmdb_env,
prefix="reference_molecule_data",
Expand Down
Loading
Loading