Skip to content
33 changes: 22 additions & 11 deletions openfold3/core/data/framework/data_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,8 +155,11 @@ class DataModuleConfig(BaseModel):
datasets: list[SerializeAsAny[BaseModel]]
batch_size: int = 1
num_workers: int = 0
prefetch_factor: int | None = None
num_workers_validation: int = 0
multiprocessing_context: str = "openfold-default"
prefetch_factor_validation: int | None = None
persistent_workers: bool = False
multiprocessing_context: str | None = "openfold-default"
data_seed: int = 42
epoch_len: int = 1

Expand Down Expand Up @@ -241,9 +244,14 @@ def __init__(self, data_module_config: DataModuleConfig) -> None:

# Possibly initialize directly from DataModuleConfig
self.batch_size = data_module_config.batch_size

self.num_workers = data_module_config.num_workers
self.prefetch_factor = data_module_config.prefetch_factor
self.num_workers_validation = data_module_config.num_workers_validation
self.prefetch_factor_validation = data_module_config.prefetch_factor_validation
self.persistent_workers = data_module_config.persistent_workers
self.multiprocessing_context = data_module_config.multiprocessing_context

self.data_seed = data_module_config.data_seed
self.next_data_seed = data_module_config.data_seed
self.epoch_len = data_module_config.epoch_len
Expand Down Expand Up @@ -485,17 +493,24 @@ def generate_dataloader(self, mode: DatasetMode, sampler: Sampler | None = None)
Returns:
DataLoader: DataLoader object.
"""

# TODO: Val does not need this many workers. Due to memory leak issue,
# reduce workers here to run with more workers overall in training
# as temporary quick fix.
if (
mode == DatasetMode.validation
and DatasetMode.train in self.multi_dataset_config.modes
):
num_workers = self.num_workers_validation
prefetch_factor = self.prefetch_factor_validation
else:
num_workers = self.num_workers
prefetch_factor = self.prefetch_factor

persistent_workers = self.persistent_workers and num_workers > 0
prefetch_factor = prefetch_factor if num_workers > 0 else None

# Set a sensible default for multiprocesssing start method
# depending on platform and python version.
multiprocessing_context = DataModuleConfig.safe_multiprocessing_context(
self.multiprocessing_context, num_workers
)

generator = self.generators.get(mode)
if generator is None:
Expand All @@ -511,12 +526,6 @@ def generate_dataloader(self, mode: DatasetMode, sampler: Sampler | None = None)
# passed explicitly here.
worker_init_fn = partial(pl_worker_init_function, rank=self.global_rank)

# Set a sensible default for multiprocesssing start method
# depending on platform and python version.
multiprocessing_context = DataModuleConfig.safe_multiprocessing_context(
self.multiprocessing_context, num_workers
)

logger.debug(
f"Creating {mode} dataloader: "
f"num_workers={num_workers}, "
Expand All @@ -531,6 +540,8 @@ def generate_dataloader(self, mode: DatasetMode, sampler: Sampler | None = None)
collate_fn=openfold_batch_collator,
generator=self.generators[mode],
worker_init_fn=worker_init_fn,
persistent_workers=persistent_workers,
prefetch_factor=prefetch_factor,
multiprocessing_context=multiprocessing_context,
)

Expand Down
24 changes: 17 additions & 7 deletions openfold3/core/data/framework/single_datasets/base_of3.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,15 +153,19 @@ def __init__(self, dataset_config) -> None:
# TODO: rename dataset_cache_file to dataset_cache_path to signal that it can be
# a directory or a file
# TODO: potentially expose the LMDB database encoding types
self.dataset_cache = read_datacache(
dataset_config.dataset_paths.dataset_cache_file
)
self._dataset_cache_file = dataset_config.dataset_paths.dataset_cache_file
self.dataset_cache = read_datacache(self._dataset_cache_file)

self.datapoint_cache = {}

if dataset_config.dataset_paths.template_structures_directory is not None:
self.ccd = pdbx.CIFFile.read(dataset_config.dataset_paths.ccd_file)
else:
self.ccd = None
# Only used if template structures are not preprocessed
# Lazy-loaded so the dataset is picklable (forkserver)
self._ccd = None
self._ccd_file = (
dataset_config.dataset_paths.ccd_file
if dataset_config.dataset_paths.template_structures_directory is not None
else None
)

# Dataset configuration
# n_tokens can be set in the getitem method separately for each sample using
Expand All @@ -174,6 +178,12 @@ def __init__(self, dataset_config) -> None:
self.single_moltype = None
self.debug_mode = dataset_config.debug_mode

@property
def ccd(self):
if self._ccd is None and self._ccd_file is not None:
self._ccd = pdbx.CIFFile.read(self._ccd_file)
return self._ccd

@log_runtime_memory(runtime_dict_key="runtime-create-structure-features")
def create_structure_features(
self,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
naive_alignment,
)

logger = logging.getLogger(__name__)
worker_seed_log = logging.getLogger(f"{__name__}.worker_seed")


Expand Down
3 changes: 3 additions & 0 deletions openfold3/core/data/framework/single_datasets/monomer.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,9 @@ def __init__(self, dataset_config: dict) -> None:
# Datapoint cache
self.create_datapoint_cache()

# Release so fork/forkserver inherits clean state for the first epoch
self.dataset_cache.release_connections()

def create_datapoint_cache(self):
"""Creates the datapoint_cache for uniform sampling.
Expand Down
3 changes: 3 additions & 0 deletions openfold3/core/data/framework/single_datasets/pdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,9 @@ def __init__(self, dataset_config: dict) -> None:
# Datapoint cache
self.create_datapoint_cache()

# Release so fork/forkserver inherits clean state for the first epoch
self.dataset_cache.release_connections()

def create_datapoint_cache(self) -> None:
"""Creates the datapoint_cache with chain/interface probabilities.
Expand Down
3 changes: 3 additions & 0 deletions openfold3/core/data/framework/single_datasets/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,9 @@ def __init__(self, dataset_config: dict, world_size: int | None = None) -> None:
# Dataset/datapoint cache
self.create_datapoint_cache()

# Release so fork/forkserver inherits clean state for the first epoch
self.dataset_cache.release_connections()

# Cropping should be disabled for validation datasets
if self.crop["token_crop"]["enabled"]:
logger.warning(
Expand Down
22 changes: 8 additions & 14 deletions openfold3/core/data/io/dataset_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,13 +142,10 @@ def _read_datacache_file(datacache_path: Path) -> "DataCacheType":
else:
raise ValueError("Could not determine the type of the dataset cache.")

try:
# Infer which class to build
dataset_cache_class = DATASET_CACHE_CLASS_REGISTRY.get(dataset_cache_type)
except KeyError as exc:
raise ValueError(
f"Unknown dataset cache type: {dataset_cache_type}"
) from exc
# Infer which class to build
dataset_cache_class = DATASET_CACHE_CLASS_REGISTRY.get(dataset_cache_type)
if dataset_cache_class is None:
raise ValueError(f"Unknown dataset cache type: {dataset_cache_type}")

return dataset_cache_class.from_json(datacache_path)

Expand Down Expand Up @@ -200,13 +197,10 @@ def read_datacache(
if not dataset_cache_type:
raise ValueError("No type found for this directory.")

try:
# Infer which class to build
dataset_cache_class = DATASET_CACHE_CLASS_REGISTRY.get(dataset_cache_type)
except KeyError as exc:
raise ValueError(
f"Unknown dataset cache type: {dataset_cache_type}"
) from exc
# Infer which class to build
dataset_cache_class = DATASET_CACHE_CLASS_REGISTRY.get(dataset_cache_type)
if dataset_cache_class is None:
raise ValueError(f"Unknown dataset cache type: {dataset_cache_type}")

dataset_cache = dataset_cache_class.from_lmdb(
datacache_path,
Expand Down
62 changes: 34 additions & 28 deletions openfold3/core/data/primitives/caches/format.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

import lmdb

from openfold3.core.data.primitives.caches.lmdb import LMDBDict
from openfold3.core.data.primitives.caches.lmdb import LMDBDict, LMDBEnv
from openfold3.core.data.resources.residues import MoleculeType

K = TypeVar("K")
Expand Down Expand Up @@ -135,6 +135,7 @@ def from_json(cls, file: Path) -> PreprocessingDataCache:
# rerunning preprocessing
elif status == "failed":
release_date = None
experimental_method = None
resolution = None
chains = None
interfaces = None
Expand Down Expand Up @@ -410,12 +411,11 @@ class DatasetCache:

_registered = False
_format_validated: bool = False
_lmdb_env = None # set by from_lmdb; LMDB forbids multiple opens per directory

# TODO: update parsers for this base class
@classmethod
def from_json(cls, file: Path) -> DatasetCache:
"""Costructs a datacache from a json.
"""Constructs a datacache from a json.

Args:
file (Path):
Expand All @@ -435,13 +435,15 @@ def from_json(cls, file: Path) -> DatasetCache:
reference_molecule_data=cls._parse_ref_mol_data_json(data),
)

@staticmethod
def _parse_type_json(data: dict) -> None:
# Remove _type field (already an internal private attribute so shouldn't be
# defined as an explicit field)
if "_type" in data:
# This is conditional for legacy compatibility, should be removed after
del data["_type"]

@staticmethod
def _parse_name_json(data: dict) -> str:
return data["name"]

Expand Down Expand Up @@ -480,6 +482,15 @@ def _parse_ref_mol_data_json(cls, data: dict) -> dict:
ref_mol_data[ref_mol_id] = per_ref_mol_data_fmt
return ref_mol_data

def release_connections(self) -> None:
"""
Close any open backend connections so fork inherits clean state.
Each backend reopens lazily on next access. No-op for plain dicts.
"""
for attr in (self.structure_data, self.reference_molecule_data):
if hasattr(attr, "close"):
attr.close()

def to_json(self, file: Path) -> None:
"""Write the dataset cache to a JSON file.

Expand Down Expand Up @@ -524,29 +535,29 @@ def from_lmdb(
DatasetCache:
The constructed datacache.
"""

lmdb_env = lmdb.open(
str(lmdb_directory), readonly=True, lock=False, subdir=True
)

with lmdb_env.begin() as transaction:
lmdb_env = LMDBEnv(str(lmdb_directory))
with lmdb_env.get().begin() as transaction:
_ = cls._parse_type_lmdb(transaction, str_encoding)
name = cls._parse_name_lmdb(transaction, str_encoding)
structure_data = cls._parse_structure_data_lmdb(
lmdb_env, str_encoding, structure_data_encoding
)
reference_molecule_data = cls._parse_ref_mol_data_lmdb(
lmdb_env, str_encoding, reference_molecule_data_encoding
)

instance = cls(
structure_data = cls._parse_structure_data_lmdb(
lmdb_env=lmdb_env,
str_encoding=str_encoding,
structure_data_encoding=structure_data_encoding,
)
reference_molecule_data = cls._parse_ref_mol_data_lmdb(
lmdb_env=lmdb_env,
str_encoding=str_encoding,
reference_molecule_data_encoding=reference_molecule_data_encoding,
)

return cls(
name=name,
structure_data=structure_data,
reference_molecule_data=reference_molecule_data,
)
instance._lmdb_env = lmdb_env
return instance

@staticmethod
def _parse_type_lmdb(
transaction: lmdb.Transaction, str_encoding: Literal["utf-8", "pkl"]
) -> str:
Expand All @@ -558,6 +569,7 @@ def _parse_type_lmdb(

return _type

@staticmethod
def _parse_name_lmdb(
transaction: lmdb.Transaction, str_encoding: Literal["utf-8", "pkl"]
) -> str:
Expand All @@ -569,31 +581,25 @@ def _parse_name_lmdb(

return name

@staticmethod
def _parse_structure_data_lmdb(
lmdb_env: lmdb.Environment,
lmdb_env: LMDBEnv,
str_encoding: Literal["utf-8", "pkl"],
structure_data_encoding: Literal["utf-8", "pkl"],
) -> LMDBDict:
from openfold3.core.data.primitives.caches.lmdb import (
LMDBDict,
)

return LMDBDict(
lmdb_env=lmdb_env,
prefix="structure_data",
key_encoding=str_encoding,
value_encoding=structure_data_encoding,
)

@staticmethod
def _parse_ref_mol_data_lmdb(
lmdb_env: lmdb.Environment,
lmdb_env: LMDBEnv,
str_encoding: Literal["utf-8", "pkl"],
reference_molecule_data_encoding: Literal["utf-8", "pkl"],
) -> LMDBDict:
from openfold3.core.data.primitives.caches.lmdb import (
LMDBDict,
)

return LMDBDict(
lmdb_env=lmdb_env,
prefix="reference_molecule_data",
Expand Down
Loading
Loading