diff --git a/openfold3/core/data/framework/data_module.py b/openfold3/core/data/framework/data_module.py index 57d962ef3..cf663dd84 100644 --- a/openfold3/core/data/framework/data_module.py +++ b/openfold3/core/data/framework/data_module.py @@ -155,8 +155,11 @@ class DataModuleConfig(BaseModel): datasets: list[SerializeAsAny[BaseModel]] batch_size: int = 1 num_workers: int = 0 + prefetch_factor: int | None = None num_workers_validation: int = 0 - multiprocessing_context: str = "openfold-default" + prefetch_factor_validation: int | None = None + persistent_workers: bool = False + multiprocessing_context: str | None = "openfold-default" data_seed: int = 42 epoch_len: int = 1 @@ -241,9 +244,14 @@ def __init__(self, data_module_config: DataModuleConfig) -> None: # Possibly initialize directly from DataModuleConfig self.batch_size = data_module_config.batch_size + self.num_workers = data_module_config.num_workers + self.prefetch_factor = data_module_config.prefetch_factor self.num_workers_validation = data_module_config.num_workers_validation + self.prefetch_factor_validation = data_module_config.prefetch_factor_validation + self.persistent_workers = data_module_config.persistent_workers self.multiprocessing_context = data_module_config.multiprocessing_context + self.data_seed = data_module_config.data_seed self.next_data_seed = data_module_config.data_seed self.epoch_len = data_module_config.epoch_len @@ -485,17 +493,24 @@ def generate_dataloader(self, mode: DatasetMode, sampler: Sampler | None = None) Returns: DataLoader: DataLoader object. """ - - # TODO: Val does not need this many workers. Due to memory leak issue, - # reduce workers here to run with more workers overall in training - # as temporary quick fix. if ( mode == DatasetMode.validation and DatasetMode.train in self.multi_dataset_config.modes ): num_workers = self.num_workers_validation + prefetch_factor = self.prefetch_factor_validation else: num_workers = self.num_workers + prefetch_factor = self.prefetch_factor + + persistent_workers = self.persistent_workers and num_workers > 0 + prefetch_factor = prefetch_factor if num_workers > 0 else None + + # Set a sensible default for multiprocesssing start method + # depending on platform and python version. + multiprocessing_context = DataModuleConfig.safe_multiprocessing_context( + self.multiprocessing_context, num_workers + ) generator = self.generators.get(mode) if generator is None: @@ -511,12 +526,6 @@ def generate_dataloader(self, mode: DatasetMode, sampler: Sampler | None = None) # passed explicitly here. worker_init_fn = partial(pl_worker_init_function, rank=self.global_rank) - # Set a sensible default for multiprocesssing start method - # depending on platform and python version. - multiprocessing_context = DataModuleConfig.safe_multiprocessing_context( - self.multiprocessing_context, num_workers - ) - logger.debug( f"Creating {mode} dataloader: " f"num_workers={num_workers}, " @@ -531,6 +540,8 @@ def generate_dataloader(self, mode: DatasetMode, sampler: Sampler | None = None) collate_fn=openfold_batch_collator, generator=self.generators[mode], worker_init_fn=worker_init_fn, + persistent_workers=persistent_workers, + prefetch_factor=prefetch_factor, multiprocessing_context=multiprocessing_context, ) diff --git a/openfold3/core/data/framework/single_datasets/base_of3.py b/openfold3/core/data/framework/single_datasets/base_of3.py index 3b9e01de9..7b6b432f1 100644 --- a/openfold3/core/data/framework/single_datasets/base_of3.py +++ b/openfold3/core/data/framework/single_datasets/base_of3.py @@ -153,15 +153,19 @@ def __init__(self, dataset_config) -> None: # TODO: rename dataset_cache_file to dataset_cache_path to signal that it can be # a directory or a file # TODO: potentially expose the LMDB database encoding types - self.dataset_cache = read_datacache( - dataset_config.dataset_paths.dataset_cache_file - ) + self._dataset_cache_file = dataset_config.dataset_paths.dataset_cache_file + self.dataset_cache = read_datacache(self._dataset_cache_file) + self.datapoint_cache = {} - if dataset_config.dataset_paths.template_structures_directory is not None: - self.ccd = pdbx.CIFFile.read(dataset_config.dataset_paths.ccd_file) - else: - self.ccd = None + # Only used if template structures are not preprocessed + # Lazy-loaded so the dataset is picklable (forkserver) + self._ccd = None + self._ccd_file = ( + dataset_config.dataset_paths.ccd_file + if dataset_config.dataset_paths.template_structures_directory is not None + else None + ) # Dataset configuration # n_tokens can be set in the getitem method separately for each sample using @@ -174,6 +178,12 @@ def __init__(self, dataset_config) -> None: self.single_moltype = None self.debug_mode = dataset_config.debug_mode + @property + def ccd(self): + if self._ccd is None and self._ccd_file is not None: + self._ccd = pdbx.CIFFile.read(self._ccd_file) + return self._ccd + @log_runtime_memory(runtime_dict_key="runtime-create-structure-features") def create_structure_features( self, diff --git a/openfold3/core/data/framework/single_datasets/dataset_utils.py b/openfold3/core/data/framework/single_datasets/dataset_utils.py index 7fbb2ed04..ebee151a2 100644 --- a/openfold3/core/data/framework/single_datasets/dataset_utils.py +++ b/openfold3/core/data/framework/single_datasets/dataset_utils.py @@ -30,6 +30,7 @@ naive_alignment, ) +logger = logging.getLogger(__name__) worker_seed_log = logging.getLogger(f"{__name__}.worker_seed") diff --git a/openfold3/core/data/framework/single_datasets/monomer.py b/openfold3/core/data/framework/single_datasets/monomer.py index db41cedbc..2180d4ece 100644 --- a/openfold3/core/data/framework/single_datasets/monomer.py +++ b/openfold3/core/data/framework/single_datasets/monomer.py @@ -50,6 +50,9 @@ def __init__(self, dataset_config: dict) -> None: # Datapoint cache self.create_datapoint_cache() + # Release so fork/forkserver inherits clean state for the first epoch + self.dataset_cache.release_connections() + def create_datapoint_cache(self): """Creates the datapoint_cache for uniform sampling. diff --git a/openfold3/core/data/framework/single_datasets/pdb.py b/openfold3/core/data/framework/single_datasets/pdb.py index 505f0eebc..c85d37609 100644 --- a/openfold3/core/data/framework/single_datasets/pdb.py +++ b/openfold3/core/data/framework/single_datasets/pdb.py @@ -189,6 +189,9 @@ def __init__(self, dataset_config: dict) -> None: # Datapoint cache self.create_datapoint_cache() + # Release so fork/forkserver inherits clean state for the first epoch + self.dataset_cache.release_connections() + def create_datapoint_cache(self) -> None: """Creates the datapoint_cache with chain/interface probabilities. diff --git a/openfold3/core/data/framework/single_datasets/validation.py b/openfold3/core/data/framework/single_datasets/validation.py index dc7b78abb..e7e6b6b2d 100644 --- a/openfold3/core/data/framework/single_datasets/validation.py +++ b/openfold3/core/data/framework/single_datasets/validation.py @@ -63,6 +63,9 @@ def __init__(self, dataset_config: dict, world_size: int | None = None) -> None: # Dataset/datapoint cache self.create_datapoint_cache() + # Release so fork/forkserver inherits clean state for the first epoch + self.dataset_cache.release_connections() + # Cropping should be disabled for validation datasets if self.crop["token_crop"]["enabled"]: logger.warning( diff --git a/openfold3/core/data/io/dataset_cache.py b/openfold3/core/data/io/dataset_cache.py index 3cb8ede15..5d8c6f1f2 100644 --- a/openfold3/core/data/io/dataset_cache.py +++ b/openfold3/core/data/io/dataset_cache.py @@ -142,13 +142,10 @@ def _read_datacache_file(datacache_path: Path) -> "DataCacheType": else: raise ValueError("Could not determine the type of the dataset cache.") - try: - # Infer which class to build - dataset_cache_class = DATASET_CACHE_CLASS_REGISTRY.get(dataset_cache_type) - except KeyError as exc: - raise ValueError( - f"Unknown dataset cache type: {dataset_cache_type}" - ) from exc + # Infer which class to build + dataset_cache_class = DATASET_CACHE_CLASS_REGISTRY.get(dataset_cache_type) + if dataset_cache_class is None: + raise ValueError(f"Unknown dataset cache type: {dataset_cache_type}") return dataset_cache_class.from_json(datacache_path) @@ -200,13 +197,10 @@ def read_datacache( if not dataset_cache_type: raise ValueError("No type found for this directory.") - try: - # Infer which class to build - dataset_cache_class = DATASET_CACHE_CLASS_REGISTRY.get(dataset_cache_type) - except KeyError as exc: - raise ValueError( - f"Unknown dataset cache type: {dataset_cache_type}" - ) from exc + # Infer which class to build + dataset_cache_class = DATASET_CACHE_CLASS_REGISTRY.get(dataset_cache_type) + if dataset_cache_class is None: + raise ValueError(f"Unknown dataset cache type: {dataset_cache_type}") dataset_cache = dataset_cache_class.from_lmdb( datacache_path, diff --git a/openfold3/core/data/primitives/caches/format.py b/openfold3/core/data/primitives/caches/format.py index fd07cee18..4156b1912 100755 --- a/openfold3/core/data/primitives/caches/format.py +++ b/openfold3/core/data/primitives/caches/format.py @@ -24,7 +24,7 @@ import lmdb -from openfold3.core.data.primitives.caches.lmdb import LMDBDict +from openfold3.core.data.primitives.caches.lmdb import LMDBDict, LMDBEnv from openfold3.core.data.resources.residues import MoleculeType K = TypeVar("K") @@ -135,6 +135,7 @@ def from_json(cls, file: Path) -> PreprocessingDataCache: # rerunning preprocessing elif status == "failed": release_date = None + experimental_method = None resolution = None chains = None interfaces = None @@ -410,12 +411,11 @@ class DatasetCache: _registered = False _format_validated: bool = False - _lmdb_env = None # set by from_lmdb; LMDB forbids multiple opens per directory # TODO: update parsers for this base class @classmethod def from_json(cls, file: Path) -> DatasetCache: - """Costructs a datacache from a json. + """Constructs a datacache from a json. Args: file (Path): @@ -435,6 +435,7 @@ def from_json(cls, file: Path) -> DatasetCache: reference_molecule_data=cls._parse_ref_mol_data_json(data), ) + @staticmethod def _parse_type_json(data: dict) -> None: # Remove _type field (already an internal private attribute so shouldn't be # defined as an explicit field) @@ -442,6 +443,7 @@ def _parse_type_json(data: dict) -> None: # This is conditional for legacy compatibility, should be removed after del data["_type"] + @staticmethod def _parse_name_json(data: dict) -> str: return data["name"] @@ -480,6 +482,15 @@ def _parse_ref_mol_data_json(cls, data: dict) -> dict: ref_mol_data[ref_mol_id] = per_ref_mol_data_fmt return ref_mol_data + def release_connections(self) -> None: + """ + Close any open backend connections so fork inherits clean state. + Each backend reopens lazily on next access. No-op for plain dicts. + """ + for attr in (self.structure_data, self.reference_molecule_data): + if hasattr(attr, "close"): + attr.close() + def to_json(self, file: Path) -> None: """Write the dataset cache to a JSON file. @@ -524,29 +535,29 @@ def from_lmdb( DatasetCache: The constructed datacache. """ - - lmdb_env = lmdb.open( - str(lmdb_directory), readonly=True, lock=False, subdir=True - ) - - with lmdb_env.begin() as transaction: + lmdb_env = LMDBEnv(str(lmdb_directory)) + with lmdb_env.get().begin() as transaction: _ = cls._parse_type_lmdb(transaction, str_encoding) name = cls._parse_name_lmdb(transaction, str_encoding) - structure_data = cls._parse_structure_data_lmdb( - lmdb_env, str_encoding, structure_data_encoding - ) - reference_molecule_data = cls._parse_ref_mol_data_lmdb( - lmdb_env, str_encoding, reference_molecule_data_encoding - ) - instance = cls( + structure_data = cls._parse_structure_data_lmdb( + lmdb_env=lmdb_env, + str_encoding=str_encoding, + structure_data_encoding=structure_data_encoding, + ) + reference_molecule_data = cls._parse_ref_mol_data_lmdb( + lmdb_env=lmdb_env, + str_encoding=str_encoding, + reference_molecule_data_encoding=reference_molecule_data_encoding, + ) + + return cls( name=name, structure_data=structure_data, reference_molecule_data=reference_molecule_data, ) - instance._lmdb_env = lmdb_env - return instance + @staticmethod def _parse_type_lmdb( transaction: lmdb.Transaction, str_encoding: Literal["utf-8", "pkl"] ) -> str: @@ -558,6 +569,7 @@ def _parse_type_lmdb( return _type + @staticmethod def _parse_name_lmdb( transaction: lmdb.Transaction, str_encoding: Literal["utf-8", "pkl"] ) -> str: @@ -569,15 +581,12 @@ def _parse_name_lmdb( return name + @staticmethod def _parse_structure_data_lmdb( - lmdb_env: lmdb.Environment, + lmdb_env: LMDBEnv, str_encoding: Literal["utf-8", "pkl"], structure_data_encoding: Literal["utf-8", "pkl"], ) -> LMDBDict: - from openfold3.core.data.primitives.caches.lmdb import ( - LMDBDict, - ) - return LMDBDict( lmdb_env=lmdb_env, prefix="structure_data", @@ -585,15 +594,12 @@ def _parse_structure_data_lmdb( value_encoding=structure_data_encoding, ) + @staticmethod def _parse_ref_mol_data_lmdb( - lmdb_env: lmdb.Environment, + lmdb_env: LMDBEnv, str_encoding: Literal["utf-8", "pkl"], reference_molecule_data_encoding: Literal["utf-8", "pkl"], ) -> LMDBDict: - from openfold3.core.data.primitives.caches.lmdb import ( - LMDBDict, - ) - return LMDBDict( lmdb_env=lmdb_env, prefix="reference_molecule_data", diff --git a/openfold3/core/data/primitives/caches/lmdb.py b/openfold3/core/data/primitives/caches/lmdb.py index 0ca615ae0..f7317a9d1 100644 --- a/openfold3/core/data/primitives/caches/lmdb.py +++ b/openfold3/core/data/primitives/caches/lmdb.py @@ -28,6 +28,30 @@ V = TypeVar("V") +class LMDBEnv: + """Lazy-opened LMDB environment shared between LMDBDict instances""" + + def __init__(self, path: str) -> None: + self._path = path + self._env: lmdb.Environment | None = None + + def get(self) -> lmdb.Environment: + if self._env is None: + self._env = lmdb.open(self._path, readonly=True, lock=False, subdir=True) + return self._env + + def close(self) -> None: + if self._env is not None: + self._env.close() + self._env = None + + def __getstate__(self) -> dict: + return {"_path": self._path, "_env": None} + + def __setstate__(self, state: dict) -> None: + self.__dict__.update(state) + + def convert_datacache_to_lmdb( dataset_cache_file_or_obj: Union[Path, "DatasetCache"], lmdb_directory: Path, @@ -130,19 +154,24 @@ def convert_datacache_to_lmdb( class LMDBDict(Mapping[K, V], Generic[K, V]): def __init__( self, - lmdb_env: lmdb.Environment, + lmdb_env: LMDBEnv, prefix: str, - separator: chr = ":", + separator: str = ":", key_encoding: Literal["utf-8", "pkl"] = "utf-8", value_encoding: Literal["utf-8", "pkl"] = "pkl", ): """A dict-like class with an LMDB backend for lazy loading of datacache entries. + Takes a shared LMDBEnv instance. Multiple LMDBDict objects for the same + file should share a single LMDBEnv so only one lmdb.Environment is opened + per file per process. Because pickle deduplicates shared references, this + sharing is preserved across fork/forkserver/spawn. + Args: - lmdb_env (lmdb.Environment): - The LMDB environment object. + lmdb_env (LMDBEnv): + Shared lazy env for this LMDB directory. prefix (str): header for fields used to construct keys in lmdb - separator (chr): Single separator character used to construct key + separator (str): Single separator character used to construct key key_encoding (Literal["utf-8", "pkl"]): Encoding of keys. Defaults to "utf-8". value_encoding (Literal["utf-8", "pkl"]): @@ -156,14 +185,11 @@ def __init__( self._prefix = prefix + separator self._key_encoding = key_encoding self._value_encoding = value_encoding + self._n_keys = None # Computed on first __len__ call - with self._lmdb_env.begin() as transaction, transaction.cursor() as cursor: - # Collect all keys - encoded_prefix = prefix.encode(self._key_encoding) - # Assign the number of keys so don't have to store all keys in memory - self._n_keys = len( - [key for key, _ in cursor if key.startswith(encoded_prefix)] - ) + def close(self) -> None: + """Close the underlying env. Reopens lazily on next access.""" + self._lmdb_env.close() def _decode_key(self, key): encoded_prefix = self._prefix.encode(self._key_encoding) @@ -172,7 +198,7 @@ def _decode_key(self, key): def __iter__(self): "Use an iterative method to not have to store all keys in memory." encoded_prefix = self._prefix.encode(self._key_encoding) - with self._lmdb_env.begin() as txn, txn.cursor() as cursor: + with self._lmdb_env.get().begin() as txn, txn.cursor() as cursor: # Seek to the first key >= prefix if cursor.set_range(encoded_prefix): while True: @@ -186,14 +212,32 @@ def __iter__(self): if not cursor.next(): break + def _count_keys(self): + """Count keys matching the prefix.""" + encoded_prefix = self._prefix.encode(self._key_encoding) + count = 0 + with self._lmdb_env.get().begin() as txn, txn.cursor() as cursor: + # Use set_range to jump to the first prefix occurrence + # and avoid scanning the entire LMDB. + if cursor.set_range(encoded_prefix): + while True: + if not cursor.key().startswith(encoded_prefix): + break + count += 1 + if not cursor.next(): + break + return count + def __len__(self): + if self._n_keys is None: + self._n_keys = self._count_keys() return self._n_keys def __getitem__(self, key): - with self._lmdb_env.begin() as transaction: + with self._lmdb_env.get().begin() as transaction: key_bytes = f"{self._prefix}{key}".encode(self._key_encoding) value_bytes = transaction.get(key_bytes) - if not value_bytes: + if value_bytes is None: raise KeyError(key) else: if self._value_encoding == "pkl": diff --git a/openfold3/entry_points/validator.py b/openfold3/entry_points/validator.py index 5acd7d54a..33f0f03f3 100644 --- a/openfold3/entry_points/validator.py +++ b/openfold3/entry_points/validator.py @@ -110,7 +110,11 @@ class DataModuleArgs(BaseModel): batch_size: int = 1 data_seed: int | None = None num_workers: int = 10 + prefetch_factor: int | None = None num_workers_validation: int = 4 + prefetch_factor_validation: int | None = None + multiprocessing_context: str | None = "openfold-default" + persistent_workers: bool = False epoch_len: int = 4 diff --git a/openfold3/tests/core/data/primitives/caches/test_format.py b/openfold3/tests/core/data/primitives/caches/test_format.py index 165ae7fd7..b82f3ec22 100644 --- a/openfold3/tests/core/data/primitives/caches/test_format.py +++ b/openfold3/tests/core/data/primitives/caches/test_format.py @@ -17,19 +17,21 @@ import pytest from openfold3.core.data.io.dataset_cache import read_datacache -from openfold3.core.data.primitives.caches.lmdb import LMDBDict +from openfold3.core.data.primitives.caches.lmdb import LMDBDict, LMDBEnv class TestDatasetCacheFromLMDB: def test_from_lmdb_sets_lmdb_env(self, lmdb_cache): - """read_datacache(lmdb_dir) should set _lmdb_env to a live lmdb.Environment.""" - assert lmdb_cache._lmdb_env is not None - assert isinstance(lmdb_cache._lmdb_env, lmdb.Environment) + """from_lmdb should attach a live LMDBEnv to each LMDBDict.""" + env = lmdb_cache.structure_data._lmdb_env + assert isinstance(env, LMDBEnv) + assert isinstance(env.get(), lmdb.Environment) - def test_from_json_lmdb_env_is_none(self, json_cache): - """from_json should leave _lmdb_env as None.""" + def test_from_json_produces_plain_dicts(self, json_cache): + """from_json should produce plain dict fields, not LMDBDicts.""" cache = read_datacache(json_cache) - assert cache._lmdb_env is None + assert not isinstance(cache.structure_data, LMDBDict) + assert not isinstance(cache.reference_molecule_data, LMDBDict) @pytest.mark.parametrize( "field", diff --git a/openfold3/tests/core/data/primitives/caches/test_lmdb.py b/openfold3/tests/core/data/primitives/caches/test_lmdb.py index 137039ccb..6fd14dbec 100644 --- a/openfold3/tests/core/data/primitives/caches/test_lmdb.py +++ b/openfold3/tests/core/data/primitives/caches/test_lmdb.py @@ -12,16 +12,62 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Tests for the LMDB dict and convert_datacache_to_lmdb.""" +"""Tests for LMDB dict, multiprocessing safety, and convert_datacache_to_lmdb.""" import json +import pickle +import sys import lmdb import pytest -from conftest import TEST_DATASET_CONFIG +import torch +from torch.utils.data import DataLoader, Dataset from openfold3.core.data.io.dataset_cache import read_datacache -from openfold3.core.data.primitives.caches.lmdb import convert_datacache_to_lmdb +from openfold3.core.data.primitives.caches.lmdb import ( + LMDBDict, + LMDBEnv, + convert_datacache_to_lmdb, +) +from openfold3.tests.core.data.primitives.caches.conftest import TEST_DATASET_CONFIG + + +def create_test_lmdb(lmdb_dir, num_items=10): + """Create a small LMDB with ``item:0`` … ``item:N-1`` keys.""" + env = lmdb.open(str(lmdb_dir), map_size=1024 * 1024, subdir=True) + with env.begin(write=True) as txn: + for i in range(num_items): + key = f"item:{i}".encode() + value = json.dumps({"index": i}).encode("utf-8") + txn.put(key, value) + env.close() + + +class LMDBDataset(Dataset): + """Minimal Dataset backed by LMDBEnv + LMDBDict. + + Defined at module level so spawn/forkserver workers can import it. + """ + + def __init__(self, lmdb_dir: str, num_items: int): + self._lmdb_env = LMDBEnv(lmdb_dir) + self._dict = LMDBDict( + lmdb_env=self._lmdb_env, + prefix="item", + key_encoding="utf-8", + value_encoding="utf-8", + ) + self._n = num_items + + def __len__(self): + return self._n + + def __getitem__(self, idx): + item = self._dict[str(idx)] + return torch.tensor(item["index"]) + + def release_connections(self): + self._lmdb_env.close() class TestLMDBDict: @@ -83,3 +129,83 @@ def test_metadata_keys_written(self, tmp_path, json_cache): assert _type == TEST_DATASET_CONFIG["_type"] assert name == TEST_DATASET_CONFIG["name"] + + +class TestLMDBEnvPickle: + def test_raw_lmdb_env_not_pickleable(self, tmp_path): + """Raw lmdb.Environment cannot be pickled — this is the root cause + of spawn/forkserver failures without the LMDBEnv wrapper.""" + lmdb_dir = tmp_path / "raw" + create_test_lmdb(lmdb_dir) + env = lmdb.open(str(lmdb_dir), readonly=True, lock=False, subdir=True) + with pytest.raises(TypeError, match="cannot pickle"): + pickle.dumps(env) + env.close() + + def test_lmdb_env_pickle_roundtrip(self, tmp_path): + """LMDBEnv can be pickled and reads correctly after unpickling.""" + lmdb_dir = tmp_path / "env_pkl" + create_test_lmdb(lmdb_dir, num_items=3) + + env = LMDBEnv(str(lmdb_dir)) + _ = env.get() # force open + + data = pickle.dumps(env) + env.close() # close original — LMDB forbids two open envs for same path + + env2 = pickle.loads(data) + assert env2._env is None # connection stripped by __getstate__ + with env2.get().begin() as txn: + assert txn.get(b"item:0") is not None + env2.close() + + def test_lmdb_dict_pickle_roundtrip(self, tmp_path): + """LMDBDict survives pickle roundtrip and reads correctly.""" + lmdb_dir = tmp_path / "dict_pkl" + num_items = 5 + create_test_lmdb(lmdb_dir, num_items=num_items) + + env = LMDBEnv(str(lmdb_dir)) + d = LMDBDict( + lmdb_env=env, + prefix="item", + key_encoding="utf-8", + value_encoding="utf-8", + ) + original = d["0"] + + data = pickle.dumps(d) + env.close() # close original — LMDB forbids two open envs for same path + + d2 = pickle.loads(data) + assert d2["0"] == original + assert len(d2) == num_items + + +class TestLMDBMultiprocessingDataLoader: + @pytest.mark.parametrize("mp_context", ["fork", "forkserver", "spawn"]) + def test_dataloader_reads_all_items(self, tmp_path, mp_context): + """DataLoader with num_workers>0 reads all LMDB items correctly + across fork, forkserver, and spawn multiprocessing contexts.""" + if mp_context == "fork" and sys.platform == "darwin": + pytest.skip("fork is unsafe on macOS with Python >= 3.8") + + num_items = 20 + lmdb_dir = tmp_path / "mp_lmdb" + create_test_lmdb(lmdb_dir, num_items=num_items) + + dataset = LMDBDataset(str(lmdb_dir), num_items=num_items) + dataset.release_connections() # mimic real codebase: clean state before fork + + loader = DataLoader( + dataset, + batch_size=1, + num_workers=2, + multiprocessing_context=mp_context, + ) + + results = [] + for batch in loader: + results.extend(batch.tolist()) + + assert sorted(results) == list(range(num_items)) diff --git a/openfold3/tests/core/data/primitives/caches/test_read_datacache.py b/openfold3/tests/core/data/primitives/caches/test_read_datacache.py index 1891909c7..95f33be08 100644 --- a/openfold3/tests/core/data/primitives/caches/test_read_datacache.py +++ b/openfold3/tests/core/data/primitives/caches/test_read_datacache.py @@ -22,13 +22,13 @@ class TestReadDatacacheLMDB: def test_type_peek_env_cleaned_up(self, lmdb_dir): - """read_datacache should return a cache with a live _lmdb_env. + """read_datacache should return a cache with a live LMDBEnv. If the internal type-peek env leaked, lmdb.open in from_lmdb would raise 'already open in this process'. """ cache = read_datacache(lmdb_dir) - assert cache._lmdb_env is not None + assert cache.structure_data._lmdb_env is not None def test_returns_correct_type(self, lmdb_dir): """Should infer the correct DatasetCache subclass from _type.""" @@ -54,5 +54,5 @@ def test_invalid_path_raises(self, tmp_path): def test_lmdb_env_is_readonly(self, lmdb_dir): """The env held by from_lmdb should be opened readonly.""" cache = read_datacache(lmdb_dir) - env_flags = cache._lmdb_env.flags() + env_flags = cache.structure_data._lmdb_env.get().flags() assert env_flags["readonly"] is True