diff --git a/docs/troubleshooting.md b/docs/troubleshooting.md index 816aeab..4c5af9f 100644 --- a/docs/troubleshooting.md +++ b/docs/troubleshooting.md @@ -7,7 +7,7 @@ 1. Ensure you have sufficient space on your base container mount. 1. If you have enough memory, but are running out of buffer space during writes, you can: 1. Increase the default initial buffer size via `initial_write_buffer_size_bytes` in the `wrap` API you are using (the default is 16 GB). - 1. Increase the write thread count, so that each rank writes to multiple buffers, effectively cutting the size of each buffer proportionally, via `write_thread_count` in the `wrap` API you are using (the default is 1). + 1. Increase the number of files per rank, so that each rank writes to multiple buffers, effectively cutting the size of each buffer proportionally, via `write_files_per_rank` in the `wrap` API you are using (the default is 1). ### How can I clean up ML Flashpoint checkpoints after job completion? diff --git a/docs/user-guide.md b/docs/user-guide.md index 3d50b2d..f93cded 100644 --- a/docs/user-guide.md +++ b/docs/user-guide.md @@ -90,7 +90,7 @@ auto_resume = wrap_trainer_and_auto_resume_with_mlflashpoint( async_save=not args.sync_save, default_auto_resume=auto_resume, # Optional # always_save_context=False, # Optional, defaults to False - # write_thread_count=1, # Optional, defaults to 1 + # write_files_per_rank=1, # Optional, defaults to 1 # initial_write_buffer_size_bytes=DESIRED_NUM_BYTES, # Optional, defaults to 16 GB # use_optimized_save=True, # Optional, defaults to True. Uses the optimized save method to reduce write time. # use_cached_ckpt_structure=False, # Optional, defaults to False. Caches the checkpoint structure after identifying 2 consecutive save plan structures that are equal. 
diff --git a/src/ml_flashpoint/adapter/megatron/save_strategies.py b/src/ml_flashpoint/adapter/megatron/save_strategies.py index bc7452b..f26b5d8 100644 --- a/src/ml_flashpoint/adapter/megatron/save_strategies.py +++ b/src/ml_flashpoint/adapter/megatron/save_strategies.py @@ -108,9 +108,9 @@ def __init__( self._use_cached_ckpt_structure: bool = use_cached_ckpt_structure @property - def thread_count(self) -> int: - """Returns the number of threads used by the storage writer.""" - return self._storage_writer._thread_count + def files_per_rank(self) -> int: + """Returns the number of files per rank used by the storage writer.""" + return self._storage_writer._files_per_rank @override def can_handle_sharded_objects(self) -> bool: @@ -145,7 +145,7 @@ def async_save(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Union self._storage_writer = MemoryStorageWriter( checkpoint_saver=self._checkpoint_saver, mp_manager_future=self._storage_writer._main_process_torchmp_manager_future, - thread_count=self._storage_writer._thread_count, + files_per_rank=self._storage_writer._files_per_rank, ) # 1c. Reset the StorageWriter for this checkpoint version. self._storage_writer.reset(checkpoint_id.data) diff --git a/src/ml_flashpoint/adapter/nemo/wrapper_util.py b/src/ml_flashpoint/adapter/nemo/wrapper_util.py index 564c70a..f86367a 100644 --- a/src/ml_flashpoint/adapter/nemo/wrapper_util.py +++ b/src/ml_flashpoint/adapter/nemo/wrapper_util.py @@ -53,7 +53,7 @@ def wrap_trainer_and_auto_resume_with_mlflashpoint( async_save: bool, default_auto_resume: nl.AutoResume = None, always_save_context: bool = False, - write_thread_count: int = 1, + write_files_per_rank: int = 1, initial_write_buffer_size_bytes: Optional[int] = DEFAULT_INITIAL_BUFFER_SIZE_BYTES, use_optimized_save: bool = True, use_cached_ckpt_structure: bool = False, @@ -72,7 +72,8 @@ def wrap_trainer_and_auto_resume_with_mlflashpoint( async_save: Whether to enable asynchronous saving for checkpoints. 
default_auto_resume: The default AutoResume configuration to inherit from. always_save_context: Whether to always save the context. Defaults to `False`. - write_thread_count: Optional. The number of threads to use for writing checkpoint data. Defaults to 1. + write_files_per_rank: Optional. The number of files each rank writes to for checkpoint data. + Checkpoint data will be split roughly evenly among the files (per rank). Defaults to 1. initial_write_buffer_size_bytes: Optional. The initial size of the buffer for writing checkpoint data in bytes. Defaults to `DEFAULT_INITIAL_BUFFER_SIZE_BYTES`, even if set to None explicitly. use_cached_ckpt_structure: Whether to reuse the checkpoint structure (plan) from the previous save. @@ -92,7 +93,7 @@ def wrap_trainer_and_auto_resume_with_mlflashpoint( pool_config = BufferPoolConfig( pool_dir_path=os.path.join(str(flashpoint_base_container), "buffer_pool"), rank=trainer.global_rank, - num_buffers=write_thread_count * NUM_OF_BUFFERS_PER_OBJECT, + num_buffers=write_files_per_rank * NUM_OF_BUFFERS_PER_OBJECT, buffer_size=initial_write_buffer_size_bytes or DEFAULT_INITIAL_BUFFER_SIZE_BYTES, ) @@ -119,7 +120,7 @@ def wrap_trainer_and_auto_resume_with_mlflashpoint( async_save=async_save, checkpoint_loader=ckpt_loader, always_save_context=always_save_context, - write_thread_count=write_thread_count, + write_files_per_rank=write_files_per_rank, initial_write_buffer_size_bytes=initial_write_buffer_size_bytes, use_optimized_save=use_optimized_save, use_cached_ckpt_structure=use_cached_ckpt_structure, @@ -142,7 +143,7 @@ def wrap_trainer_checkpoint_io_with_mlflashpoint( async_save: bool, checkpoint_loader: DefaultMLFlashpointCheckpointLoader, always_save_context: bool = False, - write_thread_count: int = 1, + write_files_per_rank: int = 1, initial_write_buffer_size_bytes: Optional[int] = DEFAULT_INITIAL_BUFFER_SIZE_BYTES, use_optimized_save: bool = True, use_cached_ckpt_structure: bool = False, @@ -171,7 +172,8 @@ def 
wrap_trainer_checkpoint_io_with_mlflashpoint( async_save: Whether to enable asynchronous saving. checkpoint_loader: The checkpoint loader to use. always_save_context: Whether to always save the context. Defaults to `False`. - write_thread_count: Optional. The number of threads to use for writing checkpoint data. Defaults to 1. + write_files_per_rank: Optional. The number of files each rank writes to for checkpoint data. + Checkpoint data will be split roughly evenly among the files (per rank). Defaults to 1. initial_write_buffer_size_bytes: Optional. The initial size of the buffer for writing checkpoint data in bytes. Defaults to `DEFAULT_INITIAL_BUFFER_SIZE_BYTES`, even if set to None explicitly. use_cached_ckpt_structure: Whether to reuse the checkpoint structure (plan) from the previous save. @@ -193,8 +195,8 @@ def wrap_trainer_checkpoint_io_with_mlflashpoint( raise ValueError("The 'ckpt_obj_manager' argument cannot be None.") if replication_manager is None: raise ValueError("The 'replication_manager' argument cannot be None.") - if write_thread_count < 1: - raise ValueError(f"write_thread_count must be >= 1, got {write_thread_count}.") + if write_files_per_rank < 1: + raise ValueError(f"write_files_per_rank must be >= 1, got {write_files_per_rank}.") if initial_write_buffer_size_bytes is None: initial_write_buffer_size_bytes = DEFAULT_INITIAL_BUFFER_SIZE_BYTES if initial_write_buffer_size_bytes <= 0: @@ -271,7 +273,7 @@ def start_manager(): use_optimized_save=use_optimized_save, ), mp_manager_future=mp_manager_future, - thread_count=write_thread_count, + files_per_rank=write_files_per_rank, ), use_cached_ckpt_structure=use_cached_ckpt_structure, ) diff --git a/src/ml_flashpoint/adapter/pytorch/memory_storage_writer.py b/src/ml_flashpoint/adapter/pytorch/memory_storage_writer.py index e527cba..5bbed96 100644 --- a/src/ml_flashpoint/adapter/pytorch/memory_storage_writer.py +++ b/src/ml_flashpoint/adapter/pytorch/memory_storage_writer.py @@ -88,19 +88,20 @@ def 
__init__( self, checkpoint_saver: MLFlashpointCheckpointSaver, mp_manager_future: concurrent.futures.Future, - thread_count: int = 1, + files_per_rank: int = 1, ): """Initializes the MemoryStorageWriter. Args: checkpoint_saver: An instance of `MLFlashpointCheckpointSaver` used for handling the actual checkpoint saving logic. - mp_manager: A `torch.multiprocessing.Manager` instance for managing - shared state across processes, particularly for write results and events. + mp_manager_future: A `concurrent.futures.Future` that resolves to a + `torch.multiprocessing.Manager` instance for managing shared state + across processes, particularly for write results and events. It is highly recommended to create this manager using a 'spawn' multiprocessing context to avoid inheriting the parent's CUDA context, which prevents CUDA OOM errors during failure recoveries - thread_count: Optional. The number of threads to use for writing checkpoint data. + files_per_rank: Optional. The number of files each rank writes to for checkpoint data. Defaults to 1. If a value less than 1 is provided, it will be reset to 1, and a warning will be logged. """ @@ -108,24 +109,24 @@ def __init__( self._current_checkpoint_id: CheckpointContainerId | None = None self._current_save_id: str | None = None self._checkpoint_saver: MLFlashpointCheckpointSaver = checkpoint_saver - if thread_count < 1: - _LOGGER.warning("thread_count must be >= 1, but was %d. Setting to 1.", thread_count) - thread_count = 1 - self._thread_count = thread_count - # _main_process_torchmp_manager should only be used in the main process, not in the spawned processes. - # This is because mp_manager is not picklable. + if files_per_rank < 1: + _LOGGER.warning("files_per_rank must be >= 1, but was %d. Setting to 1.", files_per_rank) + files_per_rank = 1 + self._files_per_rank = files_per_rank + # _main_process_torchmp_manager_future should only be used in the main process, not in the spawned processes. 
+ # This is because the mp_manager it resolves to is not picklable. self._main_process_torchmp_manager_future = mp_manager_future self._write_events_per_checkpoint_id: Optional[dict[CheckpointContainerId, torch_mp.Event]] = None self._write_results_per_checkpoint_id: Optional[dict[CheckpointContainerId, list[WriteResult]]] = None def __getstate__(self): - """Custom pickling to exclude unpicklable mp_manager.""" + """Custom pickling to exclude unpicklable mp_manager_future.""" state = self.__dict__.copy() state.pop("_main_process_torchmp_manager_future", None) return state def __setstate__(self, state): - """Custom unpickling to restore state and set mp_manager to None.""" + """Custom unpickling to restore state and set mp_manager_future to None.""" self.__dict__.update(state) self._main_process_torchmp_manager_future = None @@ -203,7 +204,7 @@ def prepare_write_data_buckets( ) write_buckets = self.checkpoint_saver.prepare_write_data( - checkpoint_id, plan.items, planner, plan.storage_data.prefix, bucket_count=self._thread_count + checkpoint_id, plan.items, planner, plan.storage_data.prefix, bucket_count=self._files_per_rank ) return write_buckets # self._write_buckets_per_checkpoint_id[checkpoint_id] = write_buckets @@ -243,7 +244,7 @@ def write_staged_data_buckets( write_results = self._checkpoint_saver.write_data( checkpoint_id, write_buckets=staged_write_buckets, - thread_count=self._thread_count, + files_per_rank=self._files_per_rank, replicate_after_write=replicate_after_write, ) end_time = time.perf_counter() diff --git a/src/ml_flashpoint/core/checkpoint_saver.py b/src/ml_flashpoint/core/checkpoint_saver.py index c2126cf..bfe7fac 100644 --- a/src/ml_flashpoint/core/checkpoint_saver.py +++ b/src/ml_flashpoint/core/checkpoint_saver.py @@ -210,7 +210,7 @@ def write_data( self, checkpoint_id: CheckpointContainerId, write_buckets: list[ObjectWriteBucket], - thread_count: int, + files_per_rank: int, replicate_after_write: bool, ) -> list[WriteResult]: """Performs 
the core write logic for the given write items and checkpoint_id. @@ -225,7 +225,7 @@ def write_data( checkpoint_id: Unique hierarchical ID representing this checkpoint container. This typically follows a directory path structure. write_buckets: A list of `ObjectWriteBucket` objects, each containing resolved data ready for writing. - thread_count: The number of threads to use for writing data. + files_per_rank: The number of files each rank writes to. replicate_after_write: Whether to trigger async replication of each object after it is written. Returns: @@ -371,7 +371,7 @@ def prepare_write_data( ) -> list[ObjectWriteBucket]: bucket_count = max(bucket_count, 1) _LOGGER.debug( - "%s prepare_write_data with prefix: '%s', thread_count: %d", + "%s prepare_write_data with prefix: '%s', files_per_rank: %d", self.__class__.__name__, object_name_prefix, bucket_count, @@ -403,7 +403,7 @@ def _clone_if_needed(tensor: torch.Tensor): - # NOTE: There is support for multiple threads, to simplify modifying that setting, but we typically - # only use 1 thread. + # NOTE: There is support for multiple files per rank (each written by its own thread), to simplify + # modifying that setting, but we typically only use 1 file per rank. - # Group items into buckets, one bucket per file, up to thread_count files + # Group items into buckets, one bucket per file, up to files_per_rank files buckets = _split_by_size_and_type(bucket_count, write_items) for bucket in buckets: if not bucket: @@ -437,22 +437,22 @@ def write_data( checkpoint_id: CheckpointContainerId, write_buckets: list[ObjectWriteBucket], replicate_after_write: bool, - thread_count: int = 1, + files_per_rank: int = 1, ) -> list[WriteResult]: - thread_count = max(thread_count, 1) + files_per_rank = max(files_per_rank, 1) num_cpus = os.cpu_count() or 1 num_ranks = max(get_accelerator_count(), 1) # Use 50% of available CPU cores for PyTorch intra-op threads and evenly distribute them across ranks.
- torch_thread_count = max(1, num_cpus // 2 // num_ranks // thread_count) + torch_thread_count = max(1, num_cpus // 2 // num_ranks // files_per_rank) original_num_threads = torch.get_num_threads() # Explicitly set PyTorch intra-op threads to optimize for performance. # This also avoids potential runtime errors in tensor.copy_() with concurrent writers torch.set_num_threads(torch_thread_count) _LOGGER.debug( - "%s starting multi-threaded write_data. thread_count: %d, original_num_threads: %d, " + "%s starting multi-threaded write_data. files_per_rank: %d, original_num_threads: %d, " "num_cpus: %d, num_ranks: %d, torch_thread_count: %d", self.__class__.__name__, - thread_count, + files_per_rank, original_num_threads, num_cpus, num_ranks, @@ -471,8 +471,8 @@ def write_data( threads = [] # Kick off additional threads to main thread, if any. - _LOGGER.debug("Spawning %d extra writer threads (in addition to the main thread).", thread_count - 1) - for i in range(1, thread_count): + _LOGGER.debug("Spawning %d extra writer threads (in addition to the main thread).", files_per_rank - 1) + for i in range(1, files_per_rank): thread = threading.Thread( target=self._write_to_buffer_from_queue_worker, args=(object_items_queue, results_from_threads, replicate_after_write, self._use_optimized_save), diff --git a/tests/adapter/megatron/test_save_strategies.py b/tests/adapter/megatron/test_save_strategies.py index 00174a0..55caaf0 100644 --- a/tests/adapter/megatron/test_save_strategies.py +++ b/tests/adapter/megatron/test_save_strategies.py @@ -48,7 +48,7 @@ def checkpoint_saver() -> MLFlashpointCheckpointSaver: def storage_writer(mocker, checkpoint_saver) -> MemoryStorageWriter: # Using a real MemoryStorageWriter instance instead of a mock. # We can still spy on its methods if needed. - # The mp_manager is mocked as it's not relevant to these tests. + # The mp_manager_future is mocked as it's not relevant to these tests. 
return MemoryStorageWriter( checkpoint_saver=checkpoint_saver, mp_manager_future=mocker.MagicMock(), @@ -193,18 +193,19 @@ def test_async_save_initialization_calls_success( mock_memory_storage_writer_cls.assert_called_once_with( checkpoint_saver=checkpoint_saver, mp_manager_future=storage_writer._main_process_torchmp_manager_future, - thread_count=storage_writer._thread_count, + files_per_rank=storage_writer._files_per_rank, ) mock_new_storage_writer_instance.reset.assert_called_once_with(checkpoint_id.data) mock_new_storage_writer_instance.stage_write_data_buckets.assert_called_once_with( checkpoint_id, dummy_write_buckets, non_blocking=True ) - @pytest.mark.parametrize("expected_thread_count", [1, 2, 3, 5]) - def test_async_save_reinitializes_storage_writer_with_thread_count( - self, mocker, async_save_setup, storage_writer, checkpoint_saver, dummy_write_buckets, expected_thread_count + @pytest.mark.parametrize("expected_files_per_rank", [1, 2, 3, 5]) + def test_async_save_reinitializes_storage_writer_with_files_per_rank( + self, mocker, async_save_setup, storage_writer, checkpoint_saver, dummy_write_buckets, + expected_files_per_rank, ): - """Tests that the StorageWriter is re-initialized with the correct thread_count.""" + """Tests that the StorageWriter is re-initialized with the correct files_per_rank.""" # Given mock_statedictsaver = mocker.patch("ml_flashpoint.adapter.megatron.save_strategies.statedictsaver") ( @@ -221,8 +222,8 @@ def test_async_save_reinitializes_storage_writer_with_thread_count( False, ) - # Set a specific thread_count on the original storage_writer - storage_writer._thread_count = expected_thread_count + # Set a specific files_per_rank on the original storage_writer + storage_writer._files_per_rank = expected_files_per_rank mock_memory_storage_writer_cls = mocker.patch( "ml_flashpoint.adapter.megatron.save_strategies.MemoryStorageWriter" @@ -235,7 +236,7 @@ def test_async_save_reinitializes_storage_writer_with_thread_count( 
mock_memory_storage_writer_cls.assert_called_once_with( checkpoint_saver=checkpoint_saver, mp_manager_future=storage_writer._main_process_torchmp_manager_future, - thread_count=expected_thread_count, + files_per_rank=expected_files_per_rank, ) def test_initialize_checkpoint_failure(self, mocker, async_save_setup, checkpoint_saver): diff --git a/tests/adapter/nemo/test_checkpoint_io.py b/tests/adapter/nemo/test_checkpoint_io.py index 3b9e050..c4f6495 100644 --- a/tests/adapter/nemo/test_checkpoint_io.py +++ b/tests/adapter/nemo/test_checkpoint_io.py @@ -808,8 +808,8 @@ def test_successful_initialization(self, mocker): async_save=True, flashpoint_base_dir="/mlf/checkpoints", ) - # Mock the thread count needed for buffer pool init - mock_checkpoint_io.save_strategy.thread_count = 1 + # Mock the files_per_rank needed for buffer pool init + mock_checkpoint_io.save_strategy.files_per_rank = 1 mock_mlf_queue = mocker.MagicMock() mock_alt_queue = mocker.MagicMock() @@ -854,7 +854,7 @@ def test_save_ml_flashpoint_checkpoint(self, mocker): flashpoint_base_dir="/mlf/checkpoints", ) mock_checkpoint_io.trainer.global_rank = 0 - mock_checkpoint_io.save_strategy.thread_count = 1 + mock_checkpoint_io.save_strategy.files_per_rank = 1 mock_checkpoint_io.flashpoint_base_dir = "/mlf/checkpoints" mock_mlf_queue = MagicMock() mock_alt_queue = MagicMock() @@ -890,7 +890,7 @@ def test_save_alternative_checkpoint(self, mocker): flashpoint_base_dir="/mlf/checkpoints", ) mock_checkpoint_io.trainer.global_rank = 0 - mock_checkpoint_io.save_strategy.thread_count = 1 + mock_checkpoint_io.save_strategy.files_per_rank = 1 mock_checkpoint_io.flashpoint_base_dir = "/mlf/checkpoints" mock_mlf_queue = mocker.MagicMock() mock_alt_queue = mocker.MagicMock() @@ -923,7 +923,7 @@ def test_save_with_external_finalize_fn(self, mocker): flashpoint_base_dir="/mlf/checkpoints", ) mock_checkpoint_io.trainer.global_rank = 0 - mock_checkpoint_io.save_strategy.thread_count = 1 + 
mock_checkpoint_io.save_strategy.files_per_rank = 1 mock_checkpoint_io.flashpoint_base_dir = "/mlf/checkpoints" self.mock_async_calls_queue_cls.side_effect = [mocker.MagicMock(), mocker.MagicMock()] instance = MLFlashpointAsyncFinalizableCheckpointIO(mock_checkpoint_io) @@ -961,7 +961,7 @@ def test_no_unfinalized_calls(self, mocker): flashpoint_base_dir="/mlf/checkpoints", ) mock_checkpoint_io.trainer.global_rank = 0 - mock_checkpoint_io.save_strategy.thread_count = 1 + mock_checkpoint_io.save_strategy.files_per_rank = 1 mock_mlf_queue = mocker.MagicMock() mock_alt_queue = mocker.MagicMock() self.mock_async_calls_queue_cls.side_effect = [mock_mlf_queue, mock_alt_queue] @@ -991,7 +991,7 @@ def test_finalize_mlf_calls_only(self, mocker): flashpoint_base_dir="/mlf/checkpoints", ) mock_checkpoint_io.trainer.global_rank = 0 - mock_checkpoint_io.save_strategy.thread_count = 1 + mock_checkpoint_io.save_strategy.files_per_rank = 1 mock_mlf_queue = mocker.MagicMock() mock_alt_queue = mocker.MagicMock() self.mock_async_calls_queue_cls.side_effect = [mock_mlf_queue, mock_alt_queue] @@ -1023,7 +1023,7 @@ def test_finalize_alt_calls_only(self, mocker): flashpoint_base_dir="/mlf/checkpoints", ) mock_checkpoint_io.trainer.global_rank = 0 - mock_checkpoint_io.save_strategy.thread_count = 1 + mock_checkpoint_io.save_strategy.files_per_rank = 1 mock_mlf_queue = mocker.MagicMock() mock_alt_queue = mocker.MagicMock() self.mock_async_calls_queue_cls.side_effect = [mock_mlf_queue, mock_alt_queue] @@ -1055,7 +1055,7 @@ def test_finalize_both_queues(self, mocker): flashpoint_base_dir="/mlf/checkpoints", ) mock_checkpoint_io.trainer.global_rank = 0 - mock_checkpoint_io.save_strategy.thread_count = 1 + mock_checkpoint_io.save_strategy.files_per_rank = 1 mock_mlf_queue = mocker.MagicMock() mock_alt_queue = mocker.MagicMock() self.mock_async_calls_queue_cls.side_effect = [mock_mlf_queue, mock_alt_queue] @@ -1095,7 +1095,7 @@ def test_teardown_with_no_pending_saves(self, mocker): 
flashpoint_base_dir="/mlf/checkpoints", ) mock_checkpoint_io.trainer.global_rank = 0 - mock_checkpoint_io.save_strategy.thread_count = 1 + mock_checkpoint_io.save_strategy.files_per_rank = 1 mock_mlf_queue = MagicMock() mock_alt_queue = MagicMock() self.mock_async_calls_queue_cls.side_effect = [mock_mlf_queue, mock_alt_queue] @@ -1123,7 +1123,7 @@ def test_teardown_with_pending_mlf_saves(self, mocker): flashpoint_base_dir="/mlf/checkpoints", ) mock_checkpoint_io.trainer.global_rank = 0 - mock_checkpoint_io.save_strategy.thread_count = 1 + mock_checkpoint_io.save_strategy.files_per_rank = 1 mock_mlf_queue = mocker.MagicMock() mock_alt_queue = mocker.MagicMock() self.mock_async_calls_queue_cls.side_effect = [mock_mlf_queue, mock_alt_queue] @@ -1151,7 +1151,7 @@ def test_teardown_with_pending_alt_saves(self, mocker): flashpoint_base_dir="/mlf/checkpoints", ) mock_checkpoint_io.trainer.global_rank = 0 - mock_checkpoint_io.save_strategy.thread_count = 1 + mock_checkpoint_io.save_strategy.files_per_rank = 1 mock_mlf_queue = mocker.MagicMock() mock_alt_queue = mocker.MagicMock() self.mock_async_calls_queue_cls.side_effect = [mock_mlf_queue, mock_alt_queue] @@ -1179,7 +1179,7 @@ def test_buffer_pool_teardown_scheduled(self, mocker): flashpoint_base_dir="/mlf/checkpoints", ) mock_checkpoint_io.trainer.global_rank = 0 - mock_checkpoint_io.save_strategy.thread_count = 1 + mock_checkpoint_io.save_strategy.files_per_rank = 1 mock_mlf_queue = mocker.MagicMock() mock_alt_queue = mocker.MagicMock() self.mock_async_calls_queue_cls.side_effect = [mock_mlf_queue, mock_alt_queue] @@ -1217,7 +1217,7 @@ def test_teardown_handles_closed_queue(self, mocker): flashpoint_base_dir="/mlf/checkpoints", ) mock_checkpoint_io.trainer.global_rank = 0 - mock_checkpoint_io.save_strategy.thread_count = 1 + mock_checkpoint_io.save_strategy.files_per_rank = 1 mock_mlf_queue = mocker.MagicMock() mock_alt_queue = mocker.MagicMock() self.mock_async_calls_queue_cls.side_effect = [mock_mlf_queue, 
mock_alt_queue] @@ -1258,7 +1258,7 @@ def test_full_lifecycle_mlf_checkpoint(self, mocker): flashpoint_base_dir="/mlf/checkpoints", ) mock_checkpoint_io.trainer.global_rank = 0 - mock_checkpoint_io.save_strategy.thread_count = 1 + mock_checkpoint_io.save_strategy.files_per_rank = 1 mock_checkpoint_io.flashpoint_base_dir = "/mlf/checkpoints" mock_mlf_queue = MagicMock() mock_alt_queue = MagicMock() @@ -1295,7 +1295,7 @@ def test_full_lifecycle_alt_checkpoint(self, mocker): flashpoint_base_dir="/mlf/checkpoints", ) mock_checkpoint_io.trainer.global_rank = 0 - mock_checkpoint_io.save_strategy.thread_count = 1 + mock_checkpoint_io.save_strategy.files_per_rank = 1 mock_checkpoint_io.flashpoint_base_dir = "/mlf/checkpoints" mock_mlf_queue = mocker.MagicMock() mock_alt_queue = mocker.MagicMock() @@ -1329,7 +1329,7 @@ def test_interleaved_checkpoints_are_finalized_independently(self, mocker): flashpoint_base_dir="/mlf/checkpoints", ) mock_checkpoint_io.trainer.global_rank = 0 - mock_checkpoint_io.save_strategy.thread_count = 1 + mock_checkpoint_io.save_strategy.files_per_rank = 1 mock_checkpoint_io.flashpoint_base_dir = "/mlf/checkpoints" mock_mlf_queue = mocker.MagicMock() mock_alt_queue = mocker.MagicMock() diff --git a/tests/adapter/nemo/test_wrapper_util.py b/tests/adapter/nemo/test_wrapper_util.py index 1a16b17..ee6a5a0 100644 --- a/tests/adapter/nemo/test_wrapper_util.py +++ b/tests/adapter/nemo/test_wrapper_util.py @@ -128,7 +128,7 @@ def test_successful_wrap_and_resume_creation(self, mocker, mock_ckpt_obj_manager async_save=async_save, checkpoint_loader=actual_auto_resume.checkpoint_loader, always_save_context=False, - write_thread_count=1, + write_files_per_rank=1, initial_write_buffer_size_bytes=DEFAULT_INITIAL_BUFFER_SIZE_BYTES, use_optimized_save=True, use_cached_ckpt_structure=False, @@ -252,16 +252,16 @@ def test_initial_save_buffer_size_forwarding( assert kwargs["initial_buffer_size_bytes"] == expected_buffer_size @pytest.mark.parametrize( - 
"thread_count_kwarg, expected_thread_count", + "files_per_rank_kwarg, expected_files_per_rank", [ ({}, 1), - ({"write_thread_count": 4}, 4), + ({"write_files_per_rank": 4}, 4), ], ) - def test_write_thread_count_forwarding( - self, mocker, mock_ckpt_obj_manager, mock_replication_manager, thread_count_kwarg, expected_thread_count + def test_write_files_per_rank_forwarding( + self, mocker, mock_ckpt_obj_manager, mock_replication_manager, files_per_rank_kwarg, expected_files_per_rank ): - """Tests that the write_thread_count is forwarded correctly.""" + """Tests that the write_files_per_rank is forwarded correctly.""" # Given trainer = mocker.MagicMock(spec=nl_trainer.Trainer) trainer.global_rank = 0 @@ -285,14 +285,14 @@ def test_write_thread_count_forwarding( flashpoint_base_container, async_save, default_auto_resume, - **thread_count_kwarg, + **files_per_rank_kwarg, ) # Then - # Verify that MemoryStorageWriter was initialized with the correct thread count + # Verify that MemoryStorageWriter was initialized with the correct files_per_rank spy_memory_storage_writer_init.assert_called_once() _, kwargs = spy_memory_storage_writer_init.call_args # Capture kwargs - assert kwargs["thread_count"] == expected_thread_count + assert kwargs["files_per_rank"] == expected_files_per_rank @pytest.mark.parametrize("always_save_context", [True, False]) def test_loader_initialization_arguments(self, mocker, always_save_context): @@ -469,18 +469,18 @@ def test_validation_missing_replication_manager(self, mocker, mock_ckpt_obj_mana checkpoint_loader=mocker.MagicMock(spec=DefaultMLFlashpointCheckpointLoader), ) - def test_validation_invalid_write_thread_count(self, mocker, mock_ckpt_obj_manager, mock_replication_manager): - """Tests validation check for invalid write thread count.""" + def test_validation_invalid_write_files_per_rank(self, mocker, mock_ckpt_obj_manager, mock_replication_manager): + """Tests validation check for invalid write_files_per_rank.""" trainer = mocker.MagicMock(spec=nl_trainer.Trainer) base_container = "/test_base_container" - with 
pytest.raises(ValueError, match="write_thread_count must be >= 1"): + with pytest.raises(ValueError, match="write_files_per_rank must be >= 1"): wrap_trainer_checkpoint_io_with_mlflashpoint( trainer, base_container, mock_ckpt_obj_manager, replication_manager=mock_replication_manager, async_save=True, - write_thread_count=0, + write_files_per_rank=0, checkpoint_loader=mocker.MagicMock(spec=DefaultMLFlashpointCheckpointLoader), ) @@ -755,7 +755,7 @@ def test_idempotency_check_with_mlf_async_wrapper_and_async_save_true( mlf_io.trainer = mocker.MagicMock() mlf_io.trainer.global_rank = 0 mlf_io.save_strategy = mocker.MagicMock() - mlf_io.save_strategy._storage_writer._thread_count = 1 + mlf_io.save_strategy._storage_writer._files_per_rank = 1 mlf_io.chkpt_obj_manager = mock_ckpt_obj_manager original_async_wrapped_mlf_io = MLFlashpointAsyncFinalizableCheckpointIO(mlf_io) @@ -924,7 +924,7 @@ def test_invalid_config_with_mlf_async_wrapper_and_async_save_false( mlf_io.trainer = mocker.MagicMock() mlf_io.trainer.global_rank = 0 mlf_io.save_strategy = mocker.MagicMock() - mlf_io.save_strategy._storage_writer._thread_count = 1 + mlf_io.save_strategy._storage_writer._files_per_rank = 1 mlf_io.chkpt_obj_manager = mock_ckpt_obj_manager original_async_wrapped_mlf_io = MLFlashpointAsyncFinalizableCheckpointIO(mlf_io) trainer.strategy.checkpoint_io = original_async_wrapped_mlf_io @@ -985,16 +985,16 @@ def test_initial_save_buffer_size_forwarding( assert kwargs["initial_buffer_size_bytes"] == expected_buffer_size @pytest.mark.parametrize( - "thread_count_kwarg, expected_thread_count", + "files_per_rank_kwarg, expected_files_per_rank", [ ({}, 1), - ({"write_thread_count": 4}, 4), + ({"write_files_per_rank": 4}, 4), ], ) - def test_write_thread_count_forwarding( - self, mocker, mock_ckpt_obj_manager, mock_replication_manager, thread_count_kwarg, expected_thread_count + def test_write_files_per_rank_forwarding( + self, mocker, mock_ckpt_obj_manager, mock_replication_manager, 
files_per_rank_kwarg, expected_files_per_rank ): - """Tests that the write_thread_count is forwarded correctly.""" + """Tests that the write_files_per_rank is forwarded correctly.""" # Given trainer = mocker.MagicMock(spec=nl_trainer.Trainer) trainer.callbacks = [mocker.MagicMock(spec=MLFlashpointCheckpointCallback)] @@ -1014,14 +1014,14 @@ def test_write_thread_count_forwarding( mock_replication_manager, async_save=True, checkpoint_loader=mocker.MagicMock(spec=DefaultMLFlashpointCheckpointLoader), - **thread_count_kwarg, + **files_per_rank_kwarg, ) # Then - # Verify that MemoryStorageWriter was initialized with the correct thread count + # Verify that MemoryStorageWriter was initialized with the correct files_per_rank spy_memory_storage_writer_init.assert_called_once() _, kwargs = spy_memory_storage_writer_init.call_args - assert kwargs["thread_count"] == expected_thread_count + assert kwargs["files_per_rank"] == expected_files_per_rank @pytest.mark.parametrize("use_cached_ckpt_structure", [True, False]) def test_cached_ckpt_structure_forwarding( diff --git a/tests/adapter/pytorch/test_memory_storage_writer.py b/tests/adapter/pytorch/test_memory_storage_writer.py index bb11cd8..b493203 100644 --- a/tests/adapter/pytorch/test_memory_storage_writer.py +++ b/tests/adapter/pytorch/test_memory_storage_writer.py @@ -77,10 +77,10 @@ def test_init(self, mocker, mp_manager_future): assert writer._main_process_torchmp_manager_future is mp_manager_future assert writer._write_events_per_checkpoint_id is None assert writer._write_results_per_checkpoint_id is None - assert writer._thread_count == 1 + assert writer._files_per_rank == 1 @pytest.mark.parametrize( - "thread_count, expected_thread_count", + "files_per_rank, expected_files_per_rank", [ (5, 5), (1, 1), @@ -89,16 +89,16 @@ def test_init(self, mocker, mp_manager_future): (-10, 1), ], ) - def test_init_thread_count(self, mocker, mp_manager_future, thread_count, expected_thread_count): - """Tests that the __init__ method sets the _thread_count attribute correctly.""" + def test_init_files_per_rank(self, 
mocker, mp_manager_future, files_per_rank, expected_files_per_rank): + """Tests that the __init__ method sets the _files_per_rank attribute correctly.""" # Given mock_saver = mocker.MagicMock(spec=MLFlashpointCheckpointSaver) # When writer = MemoryStorageWriter( - checkpoint_saver=mock_saver, mp_manager_future=mp_manager_future, thread_count=thread_count + checkpoint_saver=mock_saver, mp_manager_future=mp_manager_future, files_per_rank=files_per_rank ) # Then - assert writer._thread_count == expected_thread_count + assert writer._files_per_rank == expected_files_per_rank def test_validate_checkpoint_id(self): """Tests the validate_checkpoint_id class method.""" @@ -515,15 +515,15 @@ def test_prepare_write_data_buckets(self, mocker, mp_manager_future): ) assert actual_buckets == expected_buckets - @pytest.mark.parametrize("thread_count", [1, 4, 8]) - def test_prepare_write_data_buckets_with_thread_count(self, mocker, mp_manager_future, thread_count): - """Tests that prepare_write_data_buckets calls the saver with the specified thread_count.""" + @pytest.mark.parametrize("files_per_rank", [1, 4, 8]) + def test_prepare_write_data_buckets_with_files_per_rank(self, mocker, mp_manager_future, files_per_rank): + """Tests that prepare_write_data_buckets calls the saver with the specified files_per_rank.""" # Given mock_saver = mocker.MagicMock(spec=MLFlashpointCheckpointSaver) writer = MemoryStorageWriter( - checkpoint_saver=mock_saver, mp_manager_future=mp_manager_future, thread_count=thread_count + checkpoint_saver=mock_saver, mp_manager_future=mp_manager_future, files_per_rank=files_per_rank ) - checkpoint_id = CheckpointContainerId("/test_checkpoint_with_thread_count") + checkpoint_id = CheckpointContainerId("/test_checkpoint_with_files_per_rank") plan = SavePlan(items=[], storage_data=_StorageDataContext(prefix="__0_")) planner = mocker.MagicMock() expected_buckets = _create_rich_object_write_buckets(checkpoint_id) @@ -535,7 +535,7 @@ def 
test_prepare_write_data_buckets_with_thread_count(self, mocker, mp_manager_f # Then mock_saver.prepare_write_data.assert_called_once_with( - checkpoint_id, plan.items, planner, plan.storage_data.prefix, bucket_count=thread_count + checkpoint_id, plan.items, planner, plan.storage_data.prefix, bucket_count=files_per_rank ) assert actual_buckets == expected_buckets @@ -637,21 +637,23 @@ def test_write_staged_data_buckets(self, mocker, mp_manager_future): # Then mock_saver.write_data.assert_called_once_with( - checkpoint_id, write_buckets=staged_write_buckets, thread_count=1, replicate_after_write=False + checkpoint_id, write_buckets=staged_write_buckets, files_per_rank=1, replicate_after_write=False ) assert writer._write_events_per_checkpoint_id[checkpoint_id].is_set() assert result_future.wait() == expected_write_results - @pytest.mark.parametrize("thread_count", [1, 4, 8]) - def test_write_staged_data_buckets_with_explicit_thread_count(self, mocker, mp_manager_future, thread_count): + @pytest.mark.parametrize("files_per_rank", [1, 4, 8]) + def test_write_staged_data_buckets_with_explicit_files_per_rank( + self, mocker, mp_manager_future, files_per_rank + ): """Tests that write_staged_data_buckets calls checkpoint_saver.write_data with the specified - thread_count.""" + files_per_rank.""" # Given mock_saver = mocker.MagicMock(spec=MLFlashpointCheckpointSaver) writer = MemoryStorageWriter( - checkpoint_saver=mock_saver, mp_manager_future=mp_manager_future, thread_count=thread_count + checkpoint_saver=mock_saver, mp_manager_future=mp_manager_future, files_per_rank=files_per_rank ) - checkpoint_id = CheckpointContainerId("/test_checkpoint_explicit_thread_count") + checkpoint_id = CheckpointContainerId("/test_checkpoint_explicit_files_per_rank") writer.reset(checkpoint_id.data) staged_write_buckets = _create_rich_object_write_buckets(checkpoint_id) expected_write_results = [ @@ -673,7 +675,7 @@ def test_write_staged_data_buckets_with_explicit_thread_count(self, mocker, 
mp_m mock_saver.write_data.assert_called_once_with( checkpoint_id, write_buckets=staged_write_buckets, - thread_count=thread_count, + files_per_rank=files_per_rank, replicate_after_write=False, ) assert writer._write_events_per_checkpoint_id[checkpoint_id].is_set() @@ -699,7 +701,7 @@ def test_write_staged_data_buckets_saver_exception(self, mocker, mp_manager_futu # Then mock_saver.write_data.assert_called_once_with( - checkpoint_id, write_buckets=staged_write_buckets, thread_count=1, replicate_after_write=False + checkpoint_id, write_buckets=staged_write_buckets, files_per_rank=1, replicate_after_write=False ) assert not writer._write_events_per_checkpoint_id[ checkpoint_id @@ -1157,7 +1159,7 @@ def test_write_data_item_types(self, writer, mocker, item_type): writer._checkpoint_saver.write_data.assert_called_once_with( checkpoint_id, write_buckets=writer._checkpoint_saver.prepare_write_data.return_value, - thread_count=1, + files_per_rank=1, replicate_after_write=True, ) diff --git a/tests/core/test_checkpoint_saver.py b/tests/core/test_checkpoint_saver.py index 1659b2d..ab35b05 100644 --- a/tests/core/test_checkpoint_saver.py +++ b/tests/core/test_checkpoint_saver.py @@ -264,9 +264,9 @@ def test_write_data_resets_num_threads( # When if exception_in_worker: with pytest.raises(RuntimeError, match="Worker failed"): - saver.write_data(checkpoint_id, [], replicate_after_write=False, thread_count=1) + saver.write_data(checkpoint_id, [], replicate_after_write=False, files_per_rank=1) else: - saver.write_data(checkpoint_id, [], replicate_after_write=False, thread_count=1) + saver.write_data(checkpoint_id, [], replicate_after_write=False, files_per_rank=1) # Then # Verify it was reset to original_num_threads in finally block @@ -310,7 +310,7 @@ def test_write_data_multithreaded(self, chkpt_object_manager, replication_manage ) # When - results = saver.write_data(checkpoint_id, buckets, replicate_after_write=False, thread_count=4) + results = saver.write_data(checkpoint_id, 
buckets, replicate_after_write=False, files_per_rank=4) # Then assert len(results) == num_items @@ -1149,9 +1149,9 @@ def saver(self, chkpt_object_manager, replication_manager): replication_manager=replication_manager, ) - @pytest.mark.parametrize("thread_count", [1, 2, 3]) - def test_write_data_single_tensor(self, saver, temp_dir_path, chkpt_object_manager, thread_count): - checkpoint_id = CheckpointContainerId(f"{temp_dir_path}/checkpoint_write_data_tensor_th{thread_count}") + @pytest.mark.parametrize("files_per_rank", [1, 2, 3]) + def test_write_data_single_tensor(self, saver, temp_dir_path, chkpt_object_manager, files_per_rank): + checkpoint_id = CheckpointContainerId(f"{temp_dir_path}/checkpoint_write_data_tensor_th{files_per_rank}") os.makedirs(checkpoint_id.data, exist_ok=True) tensor_data = torch.tensor([1, 2, 3]) @@ -1165,13 +1165,13 @@ def test_write_data_single_tensor(self, saver, temp_dir_path, chkpt_object_manag ), ] write_buckets = saver.prepare_write_data( - checkpoint_id, write_items, resolver, "data", bucket_count=thread_count + checkpoint_id, write_items, resolver, "data", bucket_count=files_per_rank ) actual_results = saver.write_data( checkpoint_id, write_buckets, - thread_count=thread_count, + files_per_rank=files_per_rank, replicate_after_write=False, ) @@ -1188,9 +1188,9 @@ def test_write_data_single_tensor(self, saver, temp_dir_path, chkpt_object_manag chkpt_object_manager, ) - @pytest.mark.parametrize("thread_count", [1, 2, 3]) - def test_write_data_single_byteio(self, saver, temp_dir_path, chkpt_object_manager, thread_count): - checkpoint_id = CheckpointContainerId(f"{temp_dir_path}/checkpoint_write_data_byteio_th{thread_count}") + @pytest.mark.parametrize("files_per_rank", [1, 2, 3]) + def test_write_data_single_byteio(self, saver, temp_dir_path, chkpt_object_manager, files_per_rank): + checkpoint_id = CheckpointContainerId(f"{temp_dir_path}/checkpoint_write_data_byteio_th{files_per_rank}") os.makedirs(checkpoint_id.data, exist_ok=True) 
data_binary = b"test byte data" @@ -1205,13 +1205,13 @@ def test_write_data_single_byteio(self, saver, temp_dir_path, chkpt_object_manag ), ] write_buckets = saver.prepare_write_data( - checkpoint_id, write_items, resolver, "data", bucket_count=thread_count + checkpoint_id, write_items, resolver, "data", bucket_count=files_per_rank ) actual_results = saver.write_data( checkpoint_id, write_buckets, - thread_count=thread_count, + files_per_rank=files_per_rank, replicate_after_write=False, ) @@ -1230,10 +1230,10 @@ def test_write_data_single_byteio(self, saver, temp_dir_path, chkpt_object_manag assert actual_results[0].size_in_bytes == len(data_binary) assert actual_results[0].storage_data.offset == 0 - @pytest.mark.parametrize("thread_count", [1, 2, 3]) - def test_write_data_multiple_tensors(self, saver, temp_dir_path, chkpt_object_manager, thread_count): + @pytest.mark.parametrize("files_per_rank", [1, 2, 3]) + def test_write_data_multiple_tensors(self, saver, temp_dir_path, chkpt_object_manager, files_per_rank): checkpoint_id = CheckpointContainerId( - f"{temp_dir_path}/checkpoint_write_data_multi_tensor_th{thread_count}" + f"{temp_dir_path}/checkpoint_write_data_multi_tensor_th{files_per_rank}" ) os.makedirs(checkpoint_id.data, exist_ok=True) @@ -1257,13 +1257,13 @@ def test_write_data_multiple_tensors(self, saver, temp_dir_path, chkpt_object_ma ), ] write_buckets = saver.prepare_write_data( - checkpoint_id, write_items, resolver, "data", bucket_count=thread_count + checkpoint_id, write_items, resolver, "data", bucket_count=files_per_rank ) actual_results = saver.write_data( checkpoint_id, write_buckets, - thread_count=thread_count, + files_per_rank=files_per_rank, replicate_after_write=False, ) @@ -1308,9 +1308,9 @@ def test_write_data_multiple_tensors(self, saver, temp_dir_path, chkpt_object_ma chkpt_object_manager, ) - @pytest.mark.parametrize("thread_count", [1, 2, 3]) - def test_write_data_mixed_types(self, saver, temp_dir_path, chkpt_object_manager, 
thread_count): - checkpoint_id = CheckpointContainerId(f"{temp_dir_path}/checkpoint_write_data_mixed_th{thread_count}") + @pytest.mark.parametrize("files_per_rank", [1, 2, 3]) + def test_write_data_mixed_types(self, saver, temp_dir_path, chkpt_object_manager, files_per_rank): + checkpoint_id = CheckpointContainerId(f"{temp_dir_path}/checkpoint_write_data_mixed_th{files_per_rank}") os.makedirs(checkpoint_id.data, exist_ok=True) expected_tensor0 = torch.tensor([1, 2, 3]) @@ -1345,13 +1345,13 @@ def test_write_data_mixed_types(self, saver, temp_dir_path, chkpt_object_manager ), ] write_buckets = saver.prepare_write_data( - checkpoint_id, write_items, resolver, "data", bucket_count=thread_count + checkpoint_id, write_items, resolver, "data", bucket_count=files_per_rank ) actual_results = saver.write_data( checkpoint_id, write_buckets, - thread_count=thread_count, + files_per_rank=files_per_rank, replicate_after_write=False, ) @@ -1432,9 +1432,9 @@ def test_write_data_mixed_types(self, saver, temp_dir_path, chkpt_object_manager assert torch.equal(loaded_tensor, expected_tensor2) current_offset += sinfo.length - @pytest.mark.parametrize("thread_count", [1, 2, 3]) - def test_write_data_io_error(self, saver, temp_dir_path, mocker, thread_count): - checkpoint_id = CheckpointContainerId(f"{temp_dir_path}/checkpoint_write_data_io_error_th{thread_count}") + @pytest.mark.parametrize("files_per_rank", [1, 2, 3]) + def test_write_data_io_error(self, saver, temp_dir_path, mocker, files_per_rank): + checkpoint_id = CheckpointContainerId(f"{temp_dir_path}/checkpoint_write_data_io_error_th{files_per_rank}") os.makedirs(checkpoint_id.data, exist_ok=True) tensor_data = torch.tensor([1, 2, 3]) @@ -1448,7 +1448,7 @@ def test_write_data_io_error(self, saver, temp_dir_path, mocker, thread_count): ), ] write_buckets = saver.prepare_write_data( - checkpoint_id, write_items, resolver, "data", bucket_count=thread_count + checkpoint_id, write_items, resolver, "data", bucket_count=files_per_rank 
) # Mock the checkpoint object manager to raise an IOError during buffer creation @@ -1457,19 +1457,19 @@ def test_write_data_io_error(self, saver, temp_dir_path, mocker, thread_count): saver.write_data( checkpoint_id, write_buckets, - thread_count=thread_count, + files_per_rank=files_per_rank, replicate_after_write=False, ) - @pytest.mark.parametrize("thread_count", [1, 2, 3]) - def test_write_data_empty_buckets(self, saver, temp_dir_path, thread_count): - checkpoint_id = CheckpointContainerId(f"{temp_dir_path}/checkpoint_write_data_empty_th{thread_count}") + @pytest.mark.parametrize("files_per_rank", [1, 2, 3]) + def test_write_data_empty_buckets(self, saver, temp_dir_path, files_per_rank): + checkpoint_id = CheckpointContainerId(f"{temp_dir_path}/checkpoint_write_data_empty_th{files_per_rank}") os.makedirs(checkpoint_id.data, exist_ok=True) actual_results = saver.write_data( checkpoint_id, write_buckets=[], - thread_count=thread_count, + files_per_rank=files_per_rank, replicate_after_write=False, ) @@ -1481,7 +1481,7 @@ def test_write_data_empty_buckets(self, saver, temp_dir_path, thread_count): @pytest.mark.parametrize("preexisting_content", [b"", b"some old data"]) def test_write_data_overwrite(self, saver, temp_dir_path, chkpt_object_manager, preexisting_content): # Given - thread_count = 1 + files_per_rank = 1 checkpoint_id = CheckpointContainerId(f"{temp_dir_path}/checkpoint_write_data_overwrite") os.makedirs(checkpoint_id.data, exist_ok=True) @@ -1496,7 +1496,7 @@ def test_write_data_overwrite(self, saver, temp_dir_path, chkpt_object_manager, ), ] write_buckets = saver.prepare_write_data( - checkpoint_id, write_items, resolver, "data", bucket_count=thread_count + checkpoint_id, write_items, resolver, "data", bucket_count=files_per_rank ) # Create a file with the same name that will be written to @@ -1509,7 +1509,7 @@ def test_write_data_overwrite(self, saver, temp_dir_path, chkpt_object_manager, actual_results = saver.write_data( checkpoint_id, 
write_buckets, - thread_count=thread_count, + files_per_rank=files_per_rank, replicate_after_write=False, ) @@ -1547,10 +1547,12 @@ def test_write_data_overwrite(self, saver, temp_dir_path, chkpt_object_manager, loaded_tensor = _load_tensor_maybe_optimized(io.BytesIO(data_from_file), header=header) assert torch.equal(loaded_tensor, tensor_data) - @pytest.mark.parametrize("thread_count", [0, -1, -5]) - def test_write_data_thread_count_less_than_1(self, saver, temp_dir_path, chkpt_object_manager, thread_count): + @pytest.mark.parametrize("files_per_rank", [0, -1, -5]) + def test_write_data_files_per_rank_less_than_1( + self, saver, temp_dir_path, chkpt_object_manager, files_per_rank + ): # Given - checkpoint_id = CheckpointContainerId(f"{temp_dir_path}/checkpoint_write_data_tensor_th{thread_count}") + checkpoint_id = CheckpointContainerId(f"{temp_dir_path}/checkpoint_write_data_tensor_th{files_per_rank}") os.makedirs(checkpoint_id.data, exist_ok=True) tensor_data1 = torch.tensor([1, 2, 3]) @@ -1574,7 +1576,7 @@ def test_write_data_thread_count_less_than_1(self, saver, temp_dir_path, chkpt_o actual_results = saver.write_data( checkpoint_id, write_buckets, - thread_count=thread_count, + files_per_rank=files_per_rank, replicate_after_write=False, ) @@ -1602,7 +1604,7 @@ def test_write_data_thread_count_less_than_1(self, saver, temp_dir_path, chkpt_o def test_write_data_triggers_replication(self, saver, temp_dir_path, replication_manager): # Given - thread_count = 1 + files_per_rank = 1 checkpoint_id = CheckpointContainerId(f"{temp_dir_path}/checkpoint_write_data_replication") os.makedirs(checkpoint_id.data, exist_ok=True) @@ -1617,14 +1619,14 @@ def test_write_data_triggers_replication(self, saver, temp_dir_path, replication ), ] write_buckets = saver.prepare_write_data( - checkpoint_id, write_items, resolver, "data", bucket_count=thread_count + checkpoint_id, write_items, resolver, "data", bucket_count=files_per_rank ) # When saver.write_data( checkpoint_id, 
write_buckets, - thread_count=thread_count, + files_per_rank=files_per_rank, replicate_after_write=True, ) @@ -1643,7 +1645,7 @@ def test_write_data_triggers_replication(self, saver, temp_dir_path, replication def test_write_data_no_replication(self, saver, temp_dir_path, replication_manager): # Given - thread_count = 1 + files_per_rank = 1 checkpoint_id = CheckpointContainerId(f"{temp_dir_path}/checkpoint_write_data_no_replication") os.makedirs(checkpoint_id.data, exist_ok=True) @@ -1658,14 +1660,14 @@ def test_write_data_no_replication(self, saver, temp_dir_path, replication_manag ), ] write_buckets = saver.prepare_write_data( - checkpoint_id, write_items, resolver, "data", bucket_count=thread_count + checkpoint_id, write_items, resolver, "data", bucket_count=files_per_rank ) # When saver.write_data( checkpoint_id, write_buckets, - thread_count=thread_count, + files_per_rank=files_per_rank, replicate_after_write=False, )