Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
f291c11
refactor: rename thread_count to files_per_rank
ronaldw07 Mar 3, 2026
bf5ffd7
fix(tests): update stale docstring and fix lint errors
ronaldw07 Mar 3, 2026
d7a4e3a
revert(docs): restore historical commit messages in changelog
ronaldw07 Mar 8, 2026
353a891
Merge branch 'main' of https://github.com/google/ml-flashpoint into rename-thread-count-to-files-per-rank
ronaldw07 Mar 8, 2026
bce6888
Merge branch 'main' of https://github.com/google/ml-flashpoint into rename-thread-count-to-files-per-rank
ronaldw07 Mar 10, 2026
84f54df
fix(tests): wrap long function signature to satisfy E501 lint rule
ronaldw07 Mar 10, 2026
814a55e
Merge branch 'main' into rename-thread-count-to-files-per-rank
g-husam Mar 13, 2026
c7d0ee7
Merge branch 'main' into rename-thread-count-to-files-per-rank
ronaldw07 Mar 13, 2026
4dc67a6
docs: clarify write_files_per_rank docstring in wrapper_util
ronaldw07 Mar 13, 2026
fb6e578
Merge remote-tracking branch 'origin/rename-thread-count-to-files-per-rank'
ronaldw07 Mar 13, 2026
0745c9b
Merge branch 'main' into rename-thread-count-to-files-per-rank
ronaldw07 Mar 17, 2026
eb5a166
Merge branch 'main' into rename-thread-count-to-files-per-rank
ronaldw07 Mar 23, 2026
ca50ff8
Merge branch 'main' into rename-thread-count-to-files-per-rank
g-husam Mar 27, 2026
4ee835a
Merge branch 'main' into rename-thread-count-to-files-per-rank
ronaldw07 Apr 1, 2026
9af5d9f
fix: complete rename of thread_count to files_per_rank
ronaldw07 Apr 1, 2026
131c14f
Merge branch 'main' into rename-thread-count-to-files-per-rank
ronaldw07 Apr 9, 2026
8644d99
Merge branch 'main' into rename-thread-count-to-files-per-rank
ronaldw07 Apr 10, 2026
7cb9822
Merge branch 'main' into rename-thread-count-to-files-per-rank
ronaldw07 Apr 15, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions docs/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ _Release Notes: BEGINNING -> 194b781e75807afaba682f9eef2826464fcc120e_
* (9d1316a) replication_manager: Implement sync_bulk_retrieve method.
* (d0b8770) replication/transfer_service: Save received data to tmp object before finalizing.
* (60dd014) replication_manager: Implement async_replicate of replication_manager.
* (3cb4c38) adapter/pytorch: make writer thread_count and buffer size configurable with defaults
* (3cb4c38) adapter/pytorch: make writer files_per_rank and buffer size configurable with defaults
Comment thread
ronaldw07 marked this conversation as resolved.
Outdated
* (af05afb) replication/transfer_service: Implement async_get method.
* (6c4fe99) replication/transfer_service: Implement async_put method.
* (e62cd71) replication/transfer_service: Implement transfer_service initialize and shutdown.
Expand Down Expand Up @@ -106,7 +106,7 @@ _Release Notes: BEGINNING -> 194b781e75807afaba682f9eef2826464fcc120e_
* (6d54996) adapter/nemo: implement MLFlashpointCheckpointCallback; add CheckpointContainerId.from_parent() helper

### :white_check_mark: Bug Fixes
* (b33ddfa) wrapper_util: Expose write_thread_count and initial_write_buffer_size_bytes to user.
* (b33ddfa) wrapper_util: Expose write_files_per_rank and initial_write_buffer_size_bytes to user.
* (21cc19e) adapter/nemo: make CheckpointObjectManager a param to wrapper_util; passthrough kwargs in MLFlashpointAutoResume to parent
* (23daae9) Fix implementation of PairwiseReplicationStrategy and add more tests.
* (4620be7) core/saver: ensure writer can overwrite unfinished checkpoint data after recovery
Expand Down
2 changes: 1 addition & 1 deletion docs/troubleshooting.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
1. Ensure you have sufficient space on your base container mount.
1. If you have enough memory, but are running out of buffer space during writes, you can:
1. Increase the default initial buffer size via `initial_write_buffer_size_bytes` in the `wrap` API you are using (the default is 16 GB).
1. Increase the write thread count, so that each rank writes to multiple buffers, effectively cutting the size of each buffer proportionally, via `write_thread_count` in the `wrap` API you are using (the default is 1).
1. Increase the number of files per rank, so that each rank writes to multiple buffers, effectively cutting the size of each buffer proportionally, via `write_files_per_rank` in the `wrap` API you are using (the default is 1).

### How can I clean up ML Flashpoint checkpoints after job completion?

Expand Down
2 changes: 1 addition & 1 deletion docs/user-guide.md
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ auto_resume = wrap_trainer_and_auto_resume_with_mlflashpoint(
async_save=not args.sync_save,
default_auto_resume=auto_resume, # Optional
# always_save_context=False, # Optional, defaults to False
# write_thread_count=1, # Optional, defaults to 1
# write_files_per_rank=1, # Optional, defaults to 1
# initial_write_buffer_size_bytes=DESIRED_NUM_BYTES, # Optional, defaults to 16 GB
)
```
Expand Down
2 changes: 1 addition & 1 deletion src/ml_flashpoint/adapter/megatron/save_strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def async_save(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Union
self._storage_writer = MemoryStorageWriter(
checkpoint_saver=self._checkpoint_saver,
mp_manager=self._storage_writer._main_process_torchmp_manager,
thread_count=self._storage_writer._thread_count,
files_per_rank=self._storage_writer._files_per_rank,
)
# 1c. Reset the StorageWriter for this checkpoint version.
self._storage_writer.reset(checkpoint_id.data)
Expand Down
16 changes: 8 additions & 8 deletions src/ml_flashpoint/adapter/nemo/wrapper_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def wrap_trainer_and_auto_resume_with_mlflashpoint(
async_save: bool,
default_auto_resume: nl.AutoResume = None,
always_save_context: bool = False,
write_thread_count: int = 1,
write_files_per_rank: int = 1,
initial_write_buffer_size_bytes: int = DEFAULT_INITIAL_BUFFER_SIZE_BYTES,
use_optimized_save: bool = True,
) -> MLFlashpointAutoResume:
Expand All @@ -59,7 +59,7 @@ def wrap_trainer_and_auto_resume_with_mlflashpoint(
async_save: Whether to enable asynchronous saving for checkpoints.
default_auto_resume: The default AutoResume configuration to inherit from.
always_save_context: Whether to always save the context. Defaults to `False`.
write_thread_count: Optional. The number of threads to use for writing checkpoint data. Defaults to 1.
write_files_per_rank: Optional. The number of files each rank writes to for checkpoint data. Defaults to 1.
Comment thread
ronaldw07 marked this conversation as resolved.
Outdated
initial_write_buffer_size_bytes: Optional. The initial size of the buffer for writing checkpoint data
in bytes. Defaults to `DEFAULT_INITIAL_BUFFER_SIZE_BYTES`.
Returns:
Expand Down Expand Up @@ -87,7 +87,7 @@ def wrap_trainer_and_auto_resume_with_mlflashpoint(
async_save=async_save,
checkpoint_loader=ckpt_loader,
always_save_context=always_save_context,
write_thread_count=write_thread_count,
write_files_per_rank=write_files_per_rank,
initial_write_buffer_size_bytes=initial_write_buffer_size_bytes,
use_optimized_save=use_optimized_save,
)
Expand All @@ -108,7 +108,7 @@ def wrap_trainer_checkpoint_io_with_mlflashpoint(
async_save: bool,
checkpoint_loader: DefaultMLFlashpointCheckpointLoader,
always_save_context: bool = False,
write_thread_count: int = 1,
write_files_per_rank: int = 1,
initial_write_buffer_size_bytes: int = DEFAULT_INITIAL_BUFFER_SIZE_BYTES,
use_optimized_save: bool = True,
):
Expand All @@ -135,7 +135,7 @@ def wrap_trainer_checkpoint_io_with_mlflashpoint(
async_save: Whether to enable asynchronous saving.
checkpoint_loader: The checkpoint loader to use.
always_save_context: Whether to always save the context. Defaults to `False`.
write_thread_count: Optional. The number of threads to use for writing checkpoint data. Defaults to 1.
write_files_per_rank: Optional. The number of files each rank writes to for checkpoint data. Defaults to 1.
Comment thread
ronaldw07 marked this conversation as resolved.
Outdated
initial_write_buffer_size_bytes: Optional. The initial size of the buffer for writing checkpoint data
in bytes. Defaults to `DEFAULT_INITIAL_BUFFER_SIZE_BYTES`.

Expand All @@ -152,8 +152,8 @@ def wrap_trainer_checkpoint_io_with_mlflashpoint(
raise ValueError("The 'ckpt_obj_manager' argument cannot be None.")
if replication_manager is None:
raise ValueError("The 'replication_manager' argument cannot be None.")
if write_thread_count < 1:
raise ValueError(f"write_thread_count must be >= 1, got {write_thread_count}.")
if write_files_per_rank < 1:
raise ValueError(f"write_files_per_rank must be >= 1, got {write_files_per_rank}.")
if initial_write_buffer_size_bytes <= 0:
raise ValueError(f"initial_write_buffer_size_bytes must be > 0, got {initial_write_buffer_size_bytes}.")

Expand Down Expand Up @@ -217,7 +217,7 @@ def wrap_trainer_checkpoint_io_with_mlflashpoint(
use_optimized_save=use_optimized_save,
),
mp_manager=ctx.Manager(),
thread_count=write_thread_count,
files_per_rank=write_files_per_rank,
)
)
load_strategy = MLFlashpointMegatronLoadStrategy(
Expand Down
16 changes: 8 additions & 8 deletions src/ml_flashpoint/adapter/pytorch/memory_storage_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def __init__(
self,
checkpoint_saver: MLFlashpointCheckpointSaver,
mp_manager: torch_mp.Manager,
thread_count: int = 1,
files_per_rank: int = 1,
):
"""Initializes the MemoryStorageWriter.

Expand All @@ -100,18 +100,18 @@ def __init__(
It is highly recommended to create this manager using a 'spawn'
multiprocessing context to avoid inheriting the parent's CUDA context,
which prevents CUDA OOM errors during failure recoveries
thread_count: Optional. The number of threads to use for writing checkpoint data.
files_per_rank: Optional. The number of files each rank writes to for checkpoint data.
Defaults to 1. If a value less than 1 is provided, it will be reset to 1,
and a warning will be logged.
"""
super().__init__()
self._current_checkpoint_id: CheckpointContainerId | None = None
self._current_save_id: str | None = None
self._checkpoint_saver: MLFlashpointCheckpointSaver = checkpoint_saver
if thread_count < 1:
_LOGGER.warning("thread_count must be >= 1, but was %d. Setting to 1.", thread_count)
thread_count = 1
self._thread_count = thread_count
if files_per_rank < 1:
_LOGGER.warning("files_per_rank must be >= 1, but was %d. Setting to 1.", files_per_rank)
files_per_rank = 1
self._files_per_rank = files_per_rank
# _main_process_torchmp_manager should only be used in the main process, not in the spawned processes.
# This is because mp_manager is not picklable.
self._main_process_torchmp_manager = mp_manager
Expand Down Expand Up @@ -197,7 +197,7 @@ def prepare_write_data_buckets(
self._write_events_per_checkpoint_id[checkpoint_id] = self._main_process_torchmp_manager.Event()

write_buckets = self.checkpoint_saver.prepare_write_data(
checkpoint_id, plan.items, planner, plan.storage_data.prefix, bucket_count=self._thread_count
checkpoint_id, plan.items, planner, plan.storage_data.prefix, bucket_count=self._files_per_rank
)
return write_buckets
# self._write_buckets_per_checkpoint_id[checkpoint_id] = write_buckets
Expand Down Expand Up @@ -237,7 +237,7 @@ def write_staged_data_buckets(
write_results = self._checkpoint_saver.write_data(
checkpoint_id,
write_buckets=staged_write_buckets,
thread_count=self._thread_count,
files_per_rank=self._files_per_rank,
replicate_after_write=replicate_after_write,
)
end_time = time.perf_counter()
Expand Down
22 changes: 11 additions & 11 deletions src/ml_flashpoint/core/checkpoint_saver.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ def write_data(
self,
checkpoint_id: CheckpointContainerId,
write_buckets: list[ObjectWriteBucket],
thread_count: int,
files_per_rank: int,
replicate_after_write: bool,
) -> list[WriteResult]:
"""Performs the core write logic for the given write items and checkpoint_id.
Expand All @@ -225,7 +225,7 @@ def write_data(
checkpoint_id: Unique hierarchical ID representing this checkpoint container.
This typically follows a directory path structure.
write_buckets: A list of `ObjectWriteBucket` objects, each containing resolved data ready for writing.
thread_count: The number of threads to use for writing data.
files_per_rank: The number of files each rank writes to.
replicate_after_write: Whether to trigger async replication of each object after it is written.

Returns:
Expand Down Expand Up @@ -371,7 +371,7 @@ def prepare_write_data(
) -> list[ObjectWriteBucket]:
bucket_count = max(bucket_count, 1)
_LOGGER.debug(
"%s prepare_write_data with prefix: '%s', thread_count: %d",
"%s prepare_write_data with prefix: '%s', files_per_rank: %d",
self.__class__.__name__,
object_name_prefix,
bucket_count,
Expand Down Expand Up @@ -403,7 +403,7 @@ def _clone_if_needed(tensor: torch.Tensor):
# NOTE: There is support for multiple threads, to simplify modifying that setting, but we typically
# only use 1 thread.

# Group items into buckets, one bucket per file, up to thread_count files
# Group items into buckets, one bucket per file, up to files_per_rank files
buckets = _split_by_size_and_type(bucket_count, write_items)
for bucket in buckets:
if not bucket:
Expand Down Expand Up @@ -437,22 +437,22 @@ def write_data(
checkpoint_id: CheckpointContainerId,
write_buckets: list[ObjectWriteBucket],
replicate_after_write: bool,
thread_count: int = 1,
files_per_rank: int = 1,
) -> list[WriteResult]:
thread_count = max(thread_count, 1)
files_per_rank = max(files_per_rank, 1)
num_cpus = os.cpu_count() or 1
num_ranks = max(get_accelerator_count(), 1)
# Use 50% of available CPU cores for PyTorch intra-op threads and evenly distribute them across ranks.
torch_thread_count = max(1, num_cpus // 2 // num_ranks // thread_count)
torch_thread_count = max(1, num_cpus // 2 // num_ranks // files_per_rank)
original_num_threads = torch.get_num_threads()
# Explicitly set PyTorch intra-op threads to optimize for performance.
# This also avoids potential runtime errors in tensor.copy_() with concurrent writers
torch.set_num_threads(torch_thread_count)
_LOGGER.debug(
"%s starting multi-threaded write_data. thread_count: %d, original_num_threads: %d, "
"%s starting multi-threaded write_data. files_per_rank: %d, original_num_threads: %d, "
"num_cpus: %d, num_ranks: %d, torch_thread_count: %d",
self.__class__.__name__,
thread_count,
files_per_rank,
original_num_threads,
num_cpus,
num_ranks,
Expand All @@ -471,8 +471,8 @@ def write_data(
threads = []

# Kick off additional threads to main thread, if any.
_LOGGER.debug("Spawning %d extra writer threads (in addition to the main thread).", thread_count - 1)
for i in range(1, thread_count):
_LOGGER.debug("Spawning %d extra writer threads (in addition to the main thread).", files_per_rank - 1)
for i in range(1, files_per_rank):
thread = threading.Thread(
target=self._write_to_buffer_from_queue_worker,
args=(object_items_queue, results_from_threads, replicate_after_write, self._use_optimized_save),
Expand Down
17 changes: 9 additions & 8 deletions tests/adapter/megatron/test_save_strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,18 +190,19 @@ def test_async_save_initialization_calls_success(
mock_memory_storage_writer_cls.assert_called_once_with(
checkpoint_saver=checkpoint_saver,
mp_manager=storage_writer._main_process_torchmp_manager,
thread_count=storage_writer._thread_count,
files_per_rank=storage_writer._files_per_rank,
)
mock_new_storage_writer_instance.reset.assert_called_once_with(checkpoint_id.data)
mock_new_storage_writer_instance.stage_write_data_buckets.assert_called_once_with(
checkpoint_id, dummy_write_buckets, non_blocking=True
)

@pytest.mark.parametrize("expected_thread_count", [1, 2, 3, 5])
def test_async_save_reinitializes_storage_writer_with_thread_count(
self, mocker, async_save_setup, storage_writer, checkpoint_saver, dummy_write_buckets, expected_thread_count
@pytest.mark.parametrize("expected_files_per_rank", [1, 2, 3, 5])
def test_async_save_reinitializes_storage_writer_with_files_per_rank(
self, mocker, async_save_setup, storage_writer, checkpoint_saver, dummy_write_buckets,
expected_files_per_rank,
):
"""Tests that the StorageWriter is re-initialized with the correct thread_count."""
"""Tests that the StorageWriter is re-initialized with the correct files_per_rank."""
# Given
mock_statedictsaver = mocker.patch("ml_flashpoint.adapter.megatron.save_strategies.statedictsaver")
(
Expand All @@ -216,8 +217,8 @@ def test_async_save_reinitializes_storage_writer_with_thread_count(
mocker.MagicMock(),
)

# Set a specific thread_count on the original storage_writer
storage_writer._thread_count = expected_thread_count
# Set a specific files_per_rank on the original storage_writer
storage_writer._files_per_rank = expected_files_per_rank

mock_memory_storage_writer_cls = mocker.patch(
"ml_flashpoint.adapter.megatron.save_strategies.MemoryStorageWriter"
Expand All @@ -230,7 +231,7 @@ def test_async_save_reinitializes_storage_writer_with_thread_count(
mock_memory_storage_writer_cls.assert_called_once_with(
checkpoint_saver=checkpoint_saver,
mp_manager=storage_writer._main_process_torchmp_manager,
thread_count=expected_thread_count,
files_per_rank=expected_files_per_rank,
)

def test_initialize_checkpoint_failure(self, mocker, async_save_setup, checkpoint_saver):
Expand Down
Loading
Loading