13 changes: 13 additions & 0 deletions src/ml_flashpoint/adapter/megatron/save_strategies.py
@@ -141,12 +141,16 @@ def async_save(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Union
# 1a. First, initialize the checkpoint. This marks this checkpoint container as "dirty".
# This must always be the very first operation.
self._checkpoint_saver.initialize_checkpoint(checkpoint_id)
old_storage_writer = self._storage_writer
# 1b. Re-initialize the StorageWriter to use a new instance per save to avoid hangs from shared state.
self._storage_writer = MemoryStorageWriter(
checkpoint_saver=self._checkpoint_saver,
mp_manager_future=self._storage_writer._main_process_torchmp_manager_future,
thread_count=self._storage_writer._thread_count,
)
Comment on lines 146 to 150

Contributor (severity: medium):
For improved readability, consider using the old_storage_writer variable when creating the new MemoryStorageWriter instance. While the current code is functionally correct because self._storage_writer still refers to the old instance at that point, explicitly using old_storage_writer makes the intent clearer and less prone to misinterpretation during future maintenance.

Suggested change
self._storage_writer = MemoryStorageWriter(
checkpoint_saver=self._checkpoint_saver,
mp_manager_future=self._storage_writer._main_process_torchmp_manager_future,
thread_count=self._storage_writer._thread_count,
)
self._storage_writer = MemoryStorageWriter(
checkpoint_saver=self._checkpoint_saver,
mp_manager_future=old_storage_writer._main_process_torchmp_manager_future,
thread_count=old_storage_writer._thread_count,
)
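
For illustration, a minimal standalone sketch of the evaluation-order point above (toy classes, not this PR's types): the constructor's arguments are evaluated before self._storage_writer is rebound, which is why the original version still reads the old instance.

# Toy demonstration: the right-hand side of an assignment is fully evaluated
# before the attribute is rebound, so reads inside the constructor call still
# see the old object.
class Writer:
    def __init__(self, thread_count):
        self.thread_count = thread_count

class Strategy:
    def __init__(self):
        self._storage_writer = Writer(thread_count=4)

    def reinitialize(self):
        # Reads thread_count from the OLD writer, then rebinds the attribute.
        self._storage_writer = Writer(thread_count=self._storage_writer.thread_count)

s = Strategy()
old = s._storage_writer
s.reinitialize()
assert s._storage_writer is not old
assert s._storage_writer.thread_count == 4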

# Reuse existing proxy objects from the manager instead of creating new ones
self._storage_writer._write_events_per_checkpoint_id = old_storage_writer._write_events_per_checkpoint_id
self._storage_writer._write_results_per_checkpoint_id = old_storage_writer._write_results_per_checkpoint_id
# 1c. Reset the StorageWriter for this checkpoint version.
self._storage_writer.reset(checkpoint_id.data)

@@ -259,3 +263,12 @@ def async_save(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Union
},
finalize_fns=finalize_fns,
)

def teardown(self) -> None:
"""Tears down resources used by this strategy, including the StorageWriter and its mp_manager."""
if (
hasattr(self, "_storage_writer")
and self._storage_writer is not None
and hasattr(self._storage_writer, "teardown")
):
self._storage_writer.teardown()
25 changes: 24 additions & 1 deletion src/ml_flashpoint/adapter/nemo/checkpoint_io.py
@@ -265,6 +265,21 @@ def remove_checkpoint(self, path: _PATH) -> None:
else:
self.fallback_checkpoint_io.remove_checkpoint(path)

@log_execution_time(logger=_LOGGER, name="MLFlashpointCheckpointIO.teardown")
def teardown(self) -> None:
Collaborator: where is this invoked btw?

Collaborator (Author): inside the teardown of MLFlashpointAsyncFinalizableCheckpointIO
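
For context, a runnable toy sketch of the delegation chain being described; the class names below are stand-ins, and only the teardown call order mirrors this PR:

# Stand-in classes: only the teardown delegation order mirrors the PR.
class ToyStorageWriter:
    def teardown(self):
        print("storage writer: shut down torch_mp Manager")

class ToySaveStrategy:
    def __init__(self):
        self._storage_writer = ToyStorageWriter()

    def teardown(self):
        self._storage_writer.teardown()

class ToyCheckpointIO:
    def __init__(self):
        self.save_strategy = ToySaveStrategy()

    def teardown(self):
        self.save_strategy.teardown()

class ToyAsyncFinalizableCheckpointIO:
    def __init__(self, checkpoint_io):
        self.mlf_checkpoint_io = checkpoint_io

    def teardown(self):
        # This is the invocation referred to above.
        self.mlf_checkpoint_io.teardown()

ToyAsyncFinalizableCheckpointIO(ToyCheckpointIO()).teardown()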

"""Tears down the CheckpointIO instance and its strategies/fallbacks."""
if hasattr(super(), "teardown"):
super().teardown()

if hasattr(self, "save_strategy") and self.save_strategy and hasattr(self.save_strategy, "teardown"):
self.save_strategy.teardown()
if (
hasattr(self, "fallback_checkpoint_io")
and self.fallback_checkpoint_io
and hasattr(self.fallback_checkpoint_io, "teardown")
):
self.fallback_checkpoint_io.teardown()


class MLFlashpointAsyncFinalizableCheckpointIO(AsyncFinalizableCheckpointIO):
"""CheckpointIO wrapper for async checkpoint saving and synchronous finalization
@@ -391,8 +406,16 @@ def maybe_finalize_save_checkpoint(self, blocking: bool = False) -> bool:
@override
@log_execution_time(logger=_LOGGER, name="MLFlashpointAsyncFinalizeCheckpointIO.teardown")
def teardown(self) -> None:
"""Warns if there are any pending checkpoint saves."""
"""Warns if there are any pending checkpoint saves and cleans up resources."""
super().teardown()

if (
hasattr(self, "mlf_checkpoint_io")
and self.mlf_checkpoint_io
and hasattr(self.mlf_checkpoint_io, "teardown")
):
self.mlf_checkpoint_io.teardown()

if (
self._mlf_async_calls_queue.get_num_unfinalized_calls()
+ self._alt_async_calls_queue.get_num_unfinalized_calls()
15 changes: 15 additions & 0 deletions src/ml_flashpoint/adapter/pytorch/memory_storage_writer.py
@@ -388,6 +388,21 @@ def finish_checkpoint(
checkpoint_id,
)
self._write_results_per_checkpoint_id.pop(checkpoint_id, None)
self._write_events_per_checkpoint_id.pop(checkpoint_id, None)

@log_execution_time(logger=_LOGGER, name="teardown", level=logging.INFO)
def teardown(self) -> None:
"""Tears down the StorageWriter, including shutting down the torch_mp Manager."""
if self._main_process_torchmp_manager_future is not None:
try:
manager = self._main_process_torchmp_manager_future.result(timeout=1.0)
_LOGGER.info("Shutting down torch_mp Manager...")
manager.shutdown()
Collaborator: what if the call above times out, this will fail right? what value will manager hold?

Collaborator (Author): if it times out, result() raises a timeout exception that is caught by the except block below (line 403), so manager is never assigned and manager.shutdown() is skipped.
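
A minimal runnable illustration of that behavior, using only the standard library; the unresolved future below is a stand-in for the real manager future:

import concurrent.futures

future = concurrent.futures.Future()  # never completed, like a manager that never started

try:
    manager = future.result(timeout=1.0)  # raises TimeoutError after 1 second
    manager.shutdown()                    # skipped: the assignment above never happened
except Exception as e:
    print(f"Failed to shutdown torch_mp Manager: {e}")  # the TimeoutError lands here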

_LOGGER.info("Successfully shut down torch_mp Manager.")
except Exception as e:
_LOGGER.warning("Failed to shutdown torch_mp Manager: %s", e)
finally:
self._main_process_torchmp_manager_future = None

@classmethod
@override
23 changes: 22 additions & 1 deletion tests/adapter/megatron/test_save_strategies.py
@@ -72,6 +72,17 @@ def test_can_handle_sharded_objects(self, storage_writer):
# When/Then
assert strategy.can_handle_sharded_objects() is True

def test_teardown(self, mocker, storage_writer):
# Given
strategy = MLFlashpointMegatronAsyncSaveStrategy(storage_writer=storage_writer)
mocker.spy(storage_writer, "teardown")

# When
strategy.teardown()

# Then
storage_writer.teardown.assert_called_once()

class TestAsyncSave:
@pytest.fixture(autouse=True)
def mock_dist(self, mocker):
@@ -160,7 +171,8 @@ def async_save_setup(self, mocker, monkeypatch, storage_writer, checkpoint_id):
def test_async_save_initialization_calls_success(
self, mocker, async_save_setup, storage_writer, checkpoint_saver, dummy_write_buckets
):
"""Tests the initialization calls within async_save, including StorageWriter re-initialization."""
"""Tests the initialization calls within async_save, including
StorageWriter re-initialization and proxy reuse."""
# Given
mock_statedictsaver = mocker.patch("ml_flashpoint.adapter.megatron.save_strategies.statedictsaver")
(
@@ -182,6 +194,10 @@
)
mock_new_storage_writer_instance = mock_memory_storage_writer_cls.return_value

# Setup the old storage writer to have some dummy proxies
storage_writer._write_events_per_checkpoint_id = "dummy_events_proxy"
storage_writer._write_results_per_checkpoint_id = "dummy_results_proxy"

initialize_checkpoint_spy = mocker.spy(checkpoint_saver, "initialize_checkpoint")

# When
@@ -195,6 +211,11 @@
mp_manager_future=storage_writer._main_process_torchmp_manager_future,
thread_count=storage_writer._thread_count,
)

# Verify proxy reuse
assert mock_new_storage_writer_instance._write_events_per_checkpoint_id == "dummy_events_proxy"
assert mock_new_storage_writer_instance._write_results_per_checkpoint_id == "dummy_results_proxy"

mock_new_storage_writer_instance.reset.assert_called_once_with(checkpoint_id.data)
mock_new_storage_writer_instance.stage_write_data_buckets.assert_called_once_with(
checkpoint_id, dummy_write_buckets, non_blocking=True
45 changes: 45 additions & 0 deletions tests/adapter/nemo/test_checkpoint_io.py
@@ -784,6 +784,23 @@ def simulate_io_dump(path, **kwargs):
# Check that NO thread was spawned
mock_thread_cls.assert_not_called()

def test_teardown(self, checkpoint_io_components, mocker):
"""Tests that teardown correctly propagates to the save strategy."""
# Given
checkpoint_io = checkpoint_io_components["checkpoint_io"]
save_strategy = checkpoint_io_components["save_strategy"]
fallback_checkpoint_io = checkpoint_io.fallback_checkpoint_io

mocker.spy(save_strategy, "teardown")
mocker.spy(fallback_checkpoint_io, "teardown")

# When
checkpoint_io.teardown()

# Then
save_strategy.teardown.assert_called_once()
fallback_checkpoint_io.teardown.assert_called_once()


class TestMLFlashpointAsyncFinalizableCheckpointIO:
"""Parent test class for MLFlashpointAsyncFinalizableCheckpointIO."""
@@ -1109,6 +1126,34 @@ def test_teardown_with_no_pending_saves(self, mocker):
# Then
self.mock_logger.warning.assert_not_called()

def test_teardown_propagates_to_mlf_checkpoint_io(self, mocker):
"""Tests that teardown propagates the teardown call to the underlying MLFlashpointCheckpointIO."""
# Given
mock_checkpoint_io = mocker.Mock(
spec=MLFlashpointCheckpointIO,
trainer=mocker.MagicMock(),
save_strategy=mocker.MagicMock(),
load_strategy=mocker.MagicMock(),
chkpt_obj_manager=mocker.MagicMock(),
fallback_checkpoint_io=mocker.MagicMock(),
async_save=True,
flashpoint_base_dir="/mlf/checkpoints",
)
mock_checkpoint_io.trainer.global_rank = 0
mock_checkpoint_io.save_strategy.thread_count = 1
mock_mlf_queue = mocker.MagicMock()
mock_alt_queue = mocker.MagicMock()
self.mock_async_calls_queue_cls.side_effect = [mock_mlf_queue, mock_alt_queue]
instance = MLFlashpointAsyncFinalizableCheckpointIO(mock_checkpoint_io)
mock_mlf_queue.get_num_unfinalized_calls.return_value = 0
mock_alt_queue.get_num_unfinalized_calls.return_value = 0

# When
instance.teardown()

# Then
mock_checkpoint_io.teardown.assert_called_once()

def test_teardown_with_pending_mlf_saves(self, mocker):
"""Tests that a warning is logged when there are pending MLF saves."""
# Given
27 changes: 26 additions & 1 deletion tests/adapter/pytorch/test_memory_storage_writer.py
@@ -100,6 +100,29 @@ def test_init_thread_count(self, mocker, mp_manager_future, thread_count, expected_thread_count):
# Then
assert writer._thread_count == expected_thread_count

def test_teardown(self, mocker, mp_manager_future):
"""Tests that teardown gracefully shuts down the torch_mp Manager."""
# Given
import concurrent.futures

mock_saver = mocker.MagicMock(spec=MLFlashpointCheckpointSaver)
writer = MemoryStorageWriter(checkpoint_saver=mock_saver, mp_manager_future=mp_manager_future)

mock_manager = mocker.MagicMock()
mock_future = concurrent.futures.Future()
mock_future.set_result(mock_manager)
writer._main_process_torchmp_manager_future = mock_future

# When
writer.teardown()

# Then
mock_manager.shutdown.assert_called_once()
assert writer._main_process_torchmp_manager_future is None

# Call again to test idempotency
writer.teardown()

def test_validate_checkpoint_id(self):
"""Tests the validate_checkpoint_id class method."""
# Valid cases
@@ -1016,6 +1039,7 @@ def test_finish_checkpoint_success(self, mocker, mp_manager_future):
assert metadata.storage_meta == expected_storage_meta
mock_saver.write_metadata.assert_called_once_with(checkpoint_id, metadata)
assert checkpoint_id not in writer._write_results_per_checkpoint_id
assert checkpoint_id not in writer._write_events_per_checkpoint_id

def test_finish_checkpoint_empty_results(self, mocker, mp_manager_future):
"""Tests finish_checkpoint with an empty results list."""
@@ -1408,7 +1432,8 @@ def test_finish_clears_results(self, writer):
writer.finish(metadata, [[wr1]])

# Then
assert writer.get_write_results(checkpoint_id) is None
assert checkpoint_id not in writer._write_results_per_checkpoint_id
assert checkpoint_id not in writer._write_events_per_checkpoint_id

def test_write_data_in_separate_process(self, writer, mocker):
"""Tests that write_data in a separate process correctly updates the shared results."""