diff --git a/src/ml_flashpoint/adapter/megatron/save_strategies.py b/src/ml_flashpoint/adapter/megatron/save_strategies.py index bc7452b..5ea8ef6 100644 --- a/src/ml_flashpoint/adapter/megatron/save_strategies.py +++ b/src/ml_flashpoint/adapter/megatron/save_strategies.py @@ -141,12 +141,16 @@ def async_save(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Union # 1a. First, initialize the checkpoint. This marks this checkpoint container as "dirty". # This must always be the very first operation. self._checkpoint_saver.initialize_checkpoint(checkpoint_id) + old_storage_writer = self._storage_writer # 1b. Re-initialize the StorageWriter to use a new instance per save to avoid hangs from shared state. self._storage_writer = MemoryStorageWriter( checkpoint_saver=self._checkpoint_saver, mp_manager_future=self._storage_writer._main_process_torchmp_manager_future, thread_count=self._storage_writer._thread_count, ) + # Reuse existing proxy objects from the manager instead of creating new ones + self._storage_writer._write_events_per_checkpoint_id = old_storage_writer._write_events_per_checkpoint_id + self._storage_writer._write_results_per_checkpoint_id = old_storage_writer._write_results_per_checkpoint_id # 1c. Reset the StorageWriter for this checkpoint version. self._storage_writer.reset(checkpoint_id.data) @@ -259,3 +263,12 @@ def async_save(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Union }, finalize_fns=finalize_fns, ) + + def teardown(self) -> None: + """Tears down resources used by this strategy, including the StorageWriter and its mp_manager.""" + if ( + hasattr(self, "_storage_writer") + and self._storage_writer is not None + and hasattr(self._storage_writer, "teardown") + ): + self._storage_writer.teardown() diff --git a/src/ml_flashpoint/adapter/nemo/checkpoint_io.py b/src/ml_flashpoint/adapter/nemo/checkpoint_io.py index 9093d73..26b0b5a 100644 --- a/src/ml_flashpoint/adapter/nemo/checkpoint_io.py +++ b/src/ml_flashpoint/adapter/nemo/checkpoint_io.py @@ -265,6 +265,21 @@ def remove_checkpoint(self, path: _PATH) -> None: else: self.fallback_checkpoint_io.remove_checkpoint(path) + @log_execution_time(logger=_LOGGER, name="MLFlashpointCheckpointIO.teardown") + def teardown(self) -> None: + """Tears down the CheckpointIO instance and its strategies/fallbacks.""" + if hasattr(super(), "teardown"): + super().teardown() + + if hasattr(self, "save_strategy") and self.save_strategy and hasattr(self.save_strategy, "teardown"): + self.save_strategy.teardown() + if ( + hasattr(self, "fallback_checkpoint_io") + and self.fallback_checkpoint_io + and hasattr(self.fallback_checkpoint_io, "teardown") + ): + self.fallback_checkpoint_io.teardown() + class MLFlashpointAsyncFinalizableCheckpointIO(AsyncFinalizableCheckpointIO): """CheckpointIO wrapper for async checkpoint saving and synchronous finalization @@ -391,8 +406,16 @@ def maybe_finalize_save_checkpoint(self, blocking: bool = False) -> bool: @override @log_execution_time(logger=_LOGGER, name="MLFlashpointAsyncFinalizeCheckpointIO.teardown") def teardown(self) -> None: - """Warns if there are any pending checkpoint saves.""" + """Warns if there are any pending checkpoint saves and cleans up resources.""" super().teardown() + + if ( + hasattr(self, "mlf_checkpoint_io") + and self.mlf_checkpoint_io + and hasattr(self.mlf_checkpoint_io, "teardown") + ): + self.mlf_checkpoint_io.teardown() + if ( self._mlf_async_calls_queue.get_num_unfinalized_calls() + self._alt_async_calls_queue.get_num_unfinalized_calls() diff --git a/src/ml_flashpoint/adapter/pytorch/memory_storage_writer.py b/src/ml_flashpoint/adapter/pytorch/memory_storage_writer.py index e527cba..8e031d3 100644 --- a/src/ml_flashpoint/adapter/pytorch/memory_storage_writer.py +++ b/src/ml_flashpoint/adapter/pytorch/memory_storage_writer.py @@ -388,6 +388,21 @@ def finish_checkpoint( checkpoint_id, ) self._write_results_per_checkpoint_id.pop(checkpoint_id, None) + self._write_events_per_checkpoint_id.pop(checkpoint_id, None) + + @log_execution_time(logger=_LOGGER, name="teardown", level=logging.INFO) + def teardown(self) -> None: + """Tears down the StorageWriter, including shutting down the torch_mp Manager.""" + if self._main_process_torchmp_manager_future is not None: + try: + manager = self._main_process_torchmp_manager_future.result(timeout=1.0) + _LOGGER.info("Shutting down torch_mp Manager...") + manager.shutdown() + _LOGGER.info("Successfully shut down torch_mp Manager.") + except Exception as e: + _LOGGER.warning("Failed to shutdown torch_mp Manager: %s", e) + finally: + self._main_process_torchmp_manager_future = None @classmethod @override diff --git a/tests/adapter/megatron/test_save_strategies.py b/tests/adapter/megatron/test_save_strategies.py index 00174a0..d136372 100644 --- a/tests/adapter/megatron/test_save_strategies.py +++ b/tests/adapter/megatron/test_save_strategies.py @@ -72,6 +72,17 @@ def test_can_handle_sharded_objects(self, storage_writer): # When/Then assert strategy.can_handle_sharded_objects() is True + def test_teardown(self, mocker, storage_writer): + # Given + strategy = MLFlashpointMegatronAsyncSaveStrategy(storage_writer=storage_writer) + mocker.spy(storage_writer, "teardown") + + # When + strategy.teardown() + + # Then + storage_writer.teardown.assert_called_once() + class TestAsyncSave: @pytest.fixture(autouse=True) def mock_dist(self, mocker): @@ -160,7 +171,8 @@ def async_save_setup(self, mocker, monkeypatch, storage_writer, checkpoint_id): def test_async_save_initialization_calls_success( self, mocker, async_save_setup, storage_writer, checkpoint_saver, dummy_write_buckets ): - """Tests the initialization calls within async_save, including StorageWriter re-initialization.""" + """Tests the initialization calls within async_save, including + StorageWriter re-initialization and proxy reuse.""" # Given mock_statedictsaver = mocker.patch("ml_flashpoint.adapter.megatron.save_strategies.statedictsaver") ( @@ -182,6 +194,10 @@ def test_async_save_initialization_calls_success( ) mock_new_storage_writer_instance = mock_memory_storage_writer_cls.return_value + # Setup the old storage writer to have some dummy proxies + storage_writer._write_events_per_checkpoint_id = "dummy_events_proxy" + storage_writer._write_results_per_checkpoint_id = "dummy_results_proxy" + initialize_checkpoint_spy = mocker.spy(checkpoint_saver, "initialize_checkpoint") # When @@ -195,6 +211,11 @@ def test_async_save_initialization_calls_success( mp_manager_future=storage_writer._main_process_torchmp_manager_future, thread_count=storage_writer._thread_count, ) + + # Verify proxy reuse + assert mock_new_storage_writer_instance._write_events_per_checkpoint_id == "dummy_events_proxy" + assert mock_new_storage_writer_instance._write_results_per_checkpoint_id == "dummy_results_proxy" + mock_new_storage_writer_instance.reset.assert_called_once_with(checkpoint_id.data) mock_new_storage_writer_instance.stage_write_data_buckets.assert_called_once_with( checkpoint_id, dummy_write_buckets, non_blocking=True diff --git a/tests/adapter/nemo/test_checkpoint_io.py b/tests/adapter/nemo/test_checkpoint_io.py index 3b9e050..a79233a 100644 --- a/tests/adapter/nemo/test_checkpoint_io.py +++ b/tests/adapter/nemo/test_checkpoint_io.py @@ -784,6 +784,23 @@ def simulate_io_dump(path, **kwargs): # Check that NO thread was spawned mock_thread_cls.assert_not_called() + def test_teardown(self, checkpoint_io_components, mocker): + """Tests that teardown correctly propagates to the save strategy.""" + # Given + checkpoint_io = checkpoint_io_components["checkpoint_io"] + save_strategy = checkpoint_io_components["save_strategy"] + fallback_checkpoint_io = checkpoint_io.fallback_checkpoint_io + + mocker.spy(save_strategy, "teardown") + mocker.spy(fallback_checkpoint_io, "teardown") + + # When + checkpoint_io.teardown() + + # Then + save_strategy.teardown.assert_called_once() + fallback_checkpoint_io.teardown.assert_called_once() + class TestMLFlashpointAsyncFinalizableCheckpointIO: """Parent test class for MLFlashpointAsyncFinalizableCheckpointIO.""" @@ -1109,6 +1126,34 @@ def test_teardown_with_no_pending_saves(self, mocker): # Then self.mock_logger.warning.assert_not_called() + def test_teardown_propagates_to_mlf_checkpoint_io(self, mocker): + """Tests that teardown propagates the teardown call to the underlying MLFlashpointCheckpointIO.""" + # Given + mock_checkpoint_io = mocker.Mock( + spec=MLFlashpointCheckpointIO, + trainer=mocker.MagicMock(), + save_strategy=mocker.MagicMock(), + load_strategy=mocker.MagicMock(), + chkpt_obj_manager=mocker.MagicMock(), + fallback_checkpoint_io=mocker.MagicMock(), + async_save=True, + flashpoint_base_dir="/mlf/checkpoints", + ) + mock_checkpoint_io.trainer.global_rank = 0 + mock_checkpoint_io.save_strategy.thread_count = 1 + mock_mlf_queue = mocker.MagicMock() + mock_alt_queue = mocker.MagicMock() + self.mock_async_calls_queue_cls.side_effect = [mock_mlf_queue, mock_alt_queue] + instance = MLFlashpointAsyncFinalizableCheckpointIO(mock_checkpoint_io) + mock_mlf_queue.get_num_unfinalized_calls.return_value = 0 + mock_alt_queue.get_num_unfinalized_calls.return_value = 0 + + # When + instance.teardown() + + # Then + mock_checkpoint_io.teardown.assert_called_once() + def test_teardown_with_pending_mlf_saves(self, mocker): """Tests that a warning is logged when there are pending MLF saves.""" # Given diff --git a/tests/adapter/pytorch/test_memory_storage_writer.py b/tests/adapter/pytorch/test_memory_storage_writer.py index bb11cd8..04f0103 100644 --- a/tests/adapter/pytorch/test_memory_storage_writer.py +++ b/tests/adapter/pytorch/test_memory_storage_writer.py @@ -100,6 +100,29 @@ def test_init_thread_count(self, mocker, mp_manager_future, thread_count, expect # Then assert writer._thread_count == expected_thread_count + def test_teardown(self, mocker, mp_manager_future): + """Tests that teardown gracefully shuts down the torch_mp Manager.""" + # Given + import concurrent.futures + + mock_saver = mocker.MagicMock(spec=MLFlashpointCheckpointSaver) + writer = MemoryStorageWriter(checkpoint_saver=mock_saver, mp_manager_future=mp_manager_future) + + mock_manager = mocker.MagicMock() + mock_future = concurrent.futures.Future() + mock_future.set_result(mock_manager) + writer._main_process_torchmp_manager_future = mock_future + + # When + writer.teardown() + + # Then + mock_manager.shutdown.assert_called_once() + assert writer._main_process_torchmp_manager_future is None + + # Call again to test idempotency + writer.teardown() + def test_validate_checkpoint_id(self): """Tests the validate_checkpoint_id class method.""" # Valid cases @@ -1016,6 +1039,7 @@ def test_finish_checkpoint_success(self, mocker, mp_manager_future): assert metadata.storage_meta == expected_storage_meta mock_saver.write_metadata.assert_called_once_with(checkpoint_id, metadata) assert checkpoint_id not in writer._write_results_per_checkpoint_id + assert checkpoint_id not in writer._write_events_per_checkpoint_id def test_finish_checkpoint_empty_results(self, mocker, mp_manager_future): """Tests finish_checkpoint with an empty results list.""" @@ -1408,7 +1432,8 @@ def test_finish_clears_results(self, writer): writer.finish(metadata, [[wr1]]) # Then - assert writer.get_write_results(checkpoint_id) is None + assert checkpoint_id not in writer._write_results_per_checkpoint_id + assert checkpoint_id not in writer._write_events_per_checkpoint_id def test_write_data_in_separate_process(self, writer, mocker): """Tests that write_data in a separate process correctly updates the shared results."""