
Commit 8eeb937

Merge branch 'main' into rename-thread-count-to-files-per-rank

2 parents 353a891 + 984b718

9 files changed: 337 additions & 171 deletions

docs/user-guide.md

Lines changed: 18 additions & 2 deletions

@@ -92,6 +92,8 @@ auto_resume = wrap_trainer_and_auto_resume_with_mlflashpoint(
     # always_save_context=False, # Optional, defaults to False
     # write_files_per_rank=1, # Optional, defaults to 1
     # initial_write_buffer_size_bytes=DESIRED_NUM_BYTES, # Optional, defaults to 16 GB
+    # use_optimized_save=True, # Optional, defaults to True. Uses the optimized save method to reduce write time.
+    # use_cached_ckpt_structure=True, # Optional, defaults to False. Caches the checkpoint structure after identifying 2 consecutive save plan structures that are equal.
 )
 ```

@@ -126,6 +128,7 @@ from ml_flashpoint.adapter.megatron.save_strategies import (
 )

 # Loading
+import torch.distributed as dist
 from ml_flashpoint.adapter.megatron.load_strategies import MLFlashpointMegatronLoadStrategy
 from ml_flashpoint.checkpoint_object_manager.checkpoint_object_manager import CheckpointObjectManager
 from ml_flashpoint.core.checkpoint_loader import DefaultMLFlashpointCheckpointLoader

@@ -148,6 +151,7 @@ memory_storage_writer = MemoryStorageWriter(...)
 # Use it to instantiate the Save Strategy
 megatron_save_strategy = MLFlashpointMegatronAsyncSaveStrategy(
     storage_writer=memory_storage_writer,
+    # use_cached_ckpt_structure=True, # Optional, defaults to False. Caches the checkpoint structure after identifying 2 consecutive save plan structures that are equal.
 )
 ```

@@ -167,7 +171,7 @@ async_request = save_local_aware_megatron_checkpoint(

 !!! note

-Make sure to specify the checkpoint ID/path when saving based on the current step using:
+    Make sure to specify the checkpoint ID/path when saving based on the current step using:
     `CheckpointContainerId.create_child(base_container, CheckpointContainerId.format_version_container(current_step))`
     where `base_container` is the base path CheckpointContainerId used for all checkpoints for the current job, e.g. `"/tmp/mlf-checkpoints/job123"`.

@@ -188,6 +192,11 @@ replication_manager.initialize(checkpoint_object_manager)
 checkpoint_loader = DefaultMLFlashpointCheckpointLoader(
     checkpoint_object_manager=checkpoint_object_manager,
     replication_manager=replication_manager,
+    global_rank_getter=dist.get_rank,
+    local_rank_getter=torch.distributed.get_node_local_rank,
+    broadcast_object_list_func=dist.broadcast_object_list,
+    all_gather_object_func=dist.all_gather_object,
+    world_size_getter=dist.get_world_size,
 )

 # Instantiate the Load Strategy with the dependencies

@@ -229,11 +238,12 @@ Code: See the [`ml_flashpoint.adapter.pytorch`](https://github.com/google/ml-fla
 To use directly with PyTorch DCP, use the provided `StorageWriter` and `StorageReader` implementations.
 You can use whatever `Planner` implementations work for your use case, or resort to the defaults.

-If your per-rank checkpoint data exceeds the default buffer size (16 GB as of this writing), you can increase it using the optional `initial_buffer_size_bytes` parameter.
+If your per-rank checkpoint data exceeds the default buffer size (16 GB as of this writing), you can increase it using the optional `initial_buffer_size_bytes` parameter.

 #### Imports
 ```python
 import torch
+import torch.distributed as dist
 from torch import multiprocessing as torch_mp
 import torch.distributed.checkpoint as dcp

@@ -262,6 +272,7 @@ memory_storage_writer = MemoryStorageWriter(
         ckpt_obj_manager=checkpoint_object_manager,
         replication_manager=replication_manager,
         # initial_buffer_size_bytes=initial_write_buffer_size_bytes, # Optional - increase for larger checkpoint sizes per rank
+        # use_optimized_save=True, # Optional, defaults to True. Uses the optimized save method to reduce write time.
     ),
     mp_manager=torch_mp.Manager(),
 )

@@ -270,6 +281,11 @@ memory_storage_writer = MemoryStorageWriter(
 checkpoint_loader = DefaultMLFlashpointCheckpointLoader(
     checkpoint_object_manager=checkpoint_object_manager,
     replication_manager=replication_manager,
+    global_rank_getter=dist.get_rank,
+    local_rank_getter=torch.distributed.get_node_local_rank,
+    broadcast_object_list_func=dist.broadcast_object_list,
+    all_gather_object_func=dist.all_gather_object,
+    world_size_getter=dist.get_world_size,
 )
 memory_storage_reader = MemoryStorageReader(
     path=checkpoint_dir,

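The `use_cached_ckpt_structure` flag documented above is described as caching the checkpoint structure once two consecutive save plan structures are equal. A minimal sketch of that rule, purely illustrative: the class and method names here are hypothetical and not part of ml_flashpoint.

```python
# Hypothetical illustration of the documented caching rule: reuse the
# checkpoint structure only after two consecutive, equal save plan structures.
class SavePlanStructureCache:
    def __init__(self) -> None:
        self._previous_structure = None
        self._cached_structure = None

    def observe(self, structure):
        """Return the cached structure once it is stable, else record and return None."""
        if self._cached_structure is not None:
            return self._cached_structure
        if self._previous_structure is not None and structure == self._previous_structure:
            # Two consecutive equal structures: safe to start reusing.
            self._cached_structure = structure
            return self._cached_structure
        self._previous_structure = structure
        return None
```
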
src/ml_flashpoint/adapter/nemo/nemo_checkpoint_loader.py

Lines changed: 23 additions & 2 deletions

@@ -14,7 +14,7 @@

 import os
 from pathlib import Path
-from typing import List, Set
+from typing import Callable, List, Set

 from typing_extensions import override

@@ -33,6 +33,12 @@ def __init__(
         self,
         checkpoint_object_manager: CheckpointObjectManager,
         replication_manager: ReplicationManager,
+        *,
+        global_rank_getter: Callable[[], int],
+        local_rank_getter: Callable[[], int],
+        broadcast_object_list_func: Callable[..., None],
+        all_gather_object_func: Callable[..., None],
+        world_size_getter: Callable[[], int],
         recover_context: bool = False,
     ):
         """Initializes the NeMoMLFlashpointCheckpointLoader.

@@ -42,9 +48,24 @@ def __init__(
                 reading data.
             replication_manager: The replication manager to use for retrieving
                 missing checkpoint objects from peer nodes.
+            global_rank_getter: A callable that returns the global rank.
+            local_rank_getter: A callable that returns the node-local rank.
+            broadcast_object_list_func: A callable with the same signature as
+                ``torch.distributed.broadcast_object_list``.
+            all_gather_object_func: A callable with the same signature as
+                ``torch.distributed.all_gather_object``.
+            world_size_getter: A callable that returns the world size.
             recover_context: Whether to recover the context directory if missing.
         """
-        super().__init__(checkpoint_object_manager, replication_manager)
+        super().__init__(
+            checkpoint_object_manager,
+            replication_manager,
+            global_rank_getter=global_rank_getter,
+            local_rank_getter=local_rank_getter,
+            broadcast_object_list_func=broadcast_object_list_func,
+            all_gather_object_func=all_gather_object_func,
+            world_size_getter=world_size_getter,
+        )
         self._recover_context = recover_context

     @override
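
A side note on the `*,` in the new signature: the five distributed-collective dependencies can only be passed by keyword, which keeps call sites self-describing and lets the subclass forward them to the base class by name. A contrived, standalone illustration of the pattern (all names here are invented for the example, not ml_flashpoint APIs):

```python
from typing import Callable

class BaseLoader:
    def __init__(self, manager: object, *, world_size_getter: Callable[[], int]):
        self._manager = manager
        self._world_size_getter = world_size_getter

class FrameworkLoader(BaseLoader):
    def __init__(self, manager: object, *, world_size_getter: Callable[[], int], recover: bool = False):
        # Positional args stay positional; keyword-only deps are forwarded by name.
        super().__init__(manager, world_size_getter=world_size_getter)
        self._recover = recover

# FrameworkLoader(object(), lambda: 1)  # TypeError: takes 2 positional arguments
loader = FrameworkLoader(object(), world_size_getter=lambda: 1)  # OK
```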

src/ml_flashpoint/adapter/nemo/wrapper_util.py

Lines changed: 6 additions & 0 deletions

@@ -15,6 +15,7 @@
 from typing import Union

 import torch
+import torch.distributed as dist
 from nemo import lightning as nl
 from nemo.lightning.io.pl import MegatronCheckpointIO
 from nemo.lightning.pytorch import strategies as nl_strategies

@@ -79,6 +80,11 @@ def wrap_trainer_and_auto_resume_with_mlflashpoint(
     ckpt_loader = NeMoMLFlashpointCheckpointLoader(
         checkpoint_object_manager=ckpt_obj_manager,
         replication_manager=replication_manager,
+        global_rank_getter=dist.get_rank,
+        local_rank_getter=dist.get_node_local_rank,
+        broadcast_object_list_func=dist.broadcast_object_list,
+        all_gather_object_func=dist.all_gather_object,
+        world_size_getter=dist.get_world_size,
         recover_context=always_save_context,
     )

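This wiring stays internal to the NeMo wrapper; as the user-guide hunk above shows, callers only see the optional flags. A condensed usage sketch, importing from the module in this diff and assuming the required trainer/auto-resume arguments (which are outside this diff) are elided:

```python
from ml_flashpoint.adapter.nemo.wrapper_util import wrap_trainer_and_auto_resume_with_mlflashpoint

auto_resume = wrap_trainer_and_auto_resume_with_mlflashpoint(
    # ... required trainer/auto-resume arguments elided, see docs/user-guide.md ...
    # always_save_context=False,       # Optional, defaults to False
    # write_files_per_rank=1,          # Optional, defaults to 1
    # use_optimized_save=True,         # Optional, defaults to True
    # use_cached_ckpt_structure=True,  # Optional, defaults to False
)
```
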
src/ml_flashpoint/core/checkpoint_loader.py

Lines changed: 31 additions & 15 deletions

@@ -22,10 +22,9 @@
 import struct
 from collections import defaultdict
 from pathlib import Path
-from typing import IO, List, Optional, Set, Tuple, TypeVar, cast
+from typing import IO, Callable, List, Optional, Set, Tuple, TypeVar, cast

 import torch
-import torch.distributed as dist
 from torch.distributed._shard._utils import narrow_tensor_by_index
 from torch.distributed.checkpoint import Metadata
 from torch.distributed.checkpoint.filesystem import _StorageInfo

@@ -128,6 +127,12 @@ def __init__(
         self,
         checkpoint_object_manager: CheckpointObjectManager,
         replication_manager: ReplicationManager,
+        *,
+        global_rank_getter: Callable[[], int],
+        local_rank_getter: Callable[[], int],
+        broadcast_object_list_func: Callable[..., None],
+        all_gather_object_func: Callable[..., None],
+        world_size_getter: Callable[[], int],
     ):
         """Initializes the DefaultMLFlashpointCheckpointLoader.

@@ -136,9 +141,21 @@ def __init__(
                 reading data.
             replication_manager: The replication manager to use for retrieving
                 missing checkpoint objects from peer nodes.
+            global_rank_getter: A callable that returns the global rank.
+            local_rank_getter: A callable that returns the node-local rank.
+            broadcast_object_list_func: A callable with the same signature as
+                ``torch.distributed.broadcast_object_list``.
+            all_gather_object_func: A callable with the same signature as
+                ``torch.distributed.all_gather_object``.
+            world_size_getter: A callable that returns the world size.
         """
         self._checkpoint_object_manager = checkpoint_object_manager
         self._replication_manager = replication_manager
+        self._global_rank_getter = global_rank_getter
+        self._local_rank_getter = local_rank_getter
+        self._broadcast_object_list_func = broadcast_object_list_func
+        self._all_gather_object_func = all_gather_object_func
+        self._world_size_getter = world_size_getter
         # Cache for available objects: CheckpointContainerId -> dict[object_path, list[rank]]
         self._available_objects_cache: dict[CheckpointContainerId, dict[str, List[int]]] = {}

@@ -337,8 +354,7 @@ def get_latest_complete_checkpoint(
               else continue to the next candidate checkpoint
         - return the checkpoint container id of the latest complete checkpoint
         """
-        # TODO: use global_rank_getter and local_rank_getter.
-        rank = dist.get_rank()
+        rank = self._global_rank_getter()
         _LOGGER.debug(
             "Rank %s: Getting latest complete checkpoint for '%s'",
             rank,

@@ -382,7 +398,7 @@ def get_latest_complete_checkpoint(
             retrieval_plan = self._compute_retrieval_plan(checkpoint, available_objects_by_rank)
             # Broadcast the retrieval plan to all ranks.
             plan_container = [retrieval_plan]
-            dist.broadcast_object_list(plan_container, src=planner_rank)
+            self._broadcast_object_list_func(plan_container, src=planner_rank)
             retrieval_plan = plan_container[0]

             if retrieval_plan is None:

@@ -451,7 +467,7 @@ def _compute_retrieval_plan(

         objects_needed_by_local_rank_0.update(self._get_extra_needed_objects(checkpoint, available_objects_by_rank))

-        world_size = dist.get_world_size()
+        world_size = self._world_size_getter()
         num_nodes = get_num_of_nodes()
         ranks_per_node = world_size // num_nodes

@@ -507,8 +523,8 @@ def get_candidate_checkpoints(

         # Scan locally only on the first rank of each node
         base_path = Path(checkpoint_base_container.data)
-        rank = dist.get_rank()
-        local_rank = dist.get_node_local_rank()
+        rank = self._global_rank_getter()
+        local_rank = self._local_rank_getter()

         local_candidate_ckpt_ids = []

@@ -532,8 +548,8 @@
         else:
             _LOGGER.debug("Rank %s: Base path '%s' is not a directory or does not exist.", rank, base_path)

-        all_checkpoint_container_path_lists = [None for _ in range(dist.get_world_size())]
-        dist.all_gather_object(all_checkpoint_container_path_lists, local_candidate_ckpt_ids)
+        all_checkpoint_container_path_lists = [None for _ in range(self._world_size_getter())]
+        self._all_gather_object_func(all_checkpoint_container_path_lists, local_candidate_ckpt_ids)
         _LOGGER.debug(
             "Rank %s: Gathered checkpoint container paths from all ranks: '%s'",
             rank,

@@ -589,8 +605,8 @@ def get_checkpoint_objects_by_rank(

         local_objects.extend(self._get_extra_local_objects(container_path))

-        all_objects_by_rank_paths = [None for _ in range(dist.get_world_size())]
-        dist.all_gather_object(all_objects_by_rank_paths, local_objects)
+        all_objects_by_rank_paths = [None for _ in range(self._world_size_getter())]
+        self._all_gather_object_func(all_objects_by_rank_paths, local_objects)

         result = {}
         object_locations = defaultdict(list)

@@ -620,7 +636,7 @@ def retrieve_checkpoint(
             If empty for this rank, no retrieval is needed.
         """

-        rank = dist.get_rank()
+        rank = self._global_rank_getter()
         all_success = True

         # Only proceed with retrieval if we have items to retrieve

@@ -656,8 +672,8 @@

         # Gather success status from all ranks
         _LOGGER.debug("Gathering success status from all ranks")
-        all_success_list = [None for _ in range(dist.get_world_size())]
-        dist.all_gather_object(all_success_list, all_success)
+        all_success_list = [None for _ in range(self._world_size_getter())]
+        self._all_gather_object_func(all_success_list, all_success)
         _LOGGER.debug("All success list: '%s'", all_success_list)
         return all(all_success_list)
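
The payoff of injecting these callables shows up in the tests below: the loader's rank-coordination logic now runs without an initialized process group. A minimal single-process sketch, assuming a world size of 1 and eliding construction of the manager objects (the tests use even simpler no-op lambdas):

```python
def single_rank_broadcast_object_list(object_list, src=0, **kwargs):
    # With one rank, the source already holds the objects; nothing to do.
    pass

def single_rank_all_gather_object(output_list, obj, **kwargs):
    # Mimics torch.distributed.all_gather_object for a world of size 1.
    output_list[0] = obj

loader = DefaultMLFlashpointCheckpointLoader(
    checkpoint_object_manager=checkpoint_object_manager,  # construct as in the user guide
    replication_manager=replication_manager,              # construct as in the user guide
    global_rank_getter=lambda: 0,
    local_rank_getter=lambda: 0,
    broadcast_object_list_func=single_rank_broadcast_object_list,
    all_gather_object_func=single_rank_all_gather_object,
    world_size_getter=lambda: 1,
)
```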

tests/adapter/megatron/test_load_strategies.py

Lines changed: 5 additions & 0 deletions

@@ -276,6 +276,11 @@ def test_load_metadata_with_default_loader(checkpoint_directory, mocker):
     loader = DefaultMLFlashpointCheckpointLoader(
         checkpoint_object_manager=CheckpointObjectManager(),
         replication_manager=mock_replication_manager,
+        global_rank_getter=lambda: 0,
+        local_rank_getter=lambda: 0,
+        broadcast_object_list_func=lambda *args, **kwargs: None,
+        all_gather_object_func=lambda *args, **kwargs: None,
+        world_size_getter=lambda: 1,
     )
     strategy = MLFlashpointMegatronLoadStrategy(checkpoint_loader=loader, replication_manager=mock_replication_manager)

tests/adapter/nemo/test_auto_resume.py

Lines changed: 29 additions & 2 deletions

@@ -64,6 +64,11 @@ def test_initializer(self, mocker):
         checkpoint_loader = DefaultMLFlashpointCheckpointLoader(
             checkpoint_object_manager=chkpt_obj_manager,
             replication_manager=replication_manager,
+            global_rank_getter=lambda: 0,
+            local_rank_getter=lambda: 0,
+            broadcast_object_list_func=lambda *args, **kwargs: None,
+            all_gather_object_func=lambda *args, **kwargs: None,
+            world_size_getter=lambda: 1,
         )
         base_container = CheckpointContainerId("/tmp/ml_flashpoint_checkpoints")

@@ -83,6 +88,11 @@ def test_initializer_superclass_properties_are_correct(self, mocker):
         checkpoint_loader = DefaultMLFlashpointCheckpointLoader(
             checkpoint_object_manager=CheckpointObjectManager(),
             replication_manager=ReplicationManager(),
+            global_rank_getter=lambda: 0,
+            local_rank_getter=lambda: 0,
+            broadcast_object_list_func=lambda *args, **kwargs: None,
+            all_gather_object_func=lambda *args, **kwargs: None,
+            world_size_getter=lambda: 1,
         )
         base_container = CheckpointContainerId("/tmp/ml_flashpoint_checkpoints")

@@ -103,6 +113,11 @@ def test_initializer_propagates_true_params(self):
         checkpoint_loader = DefaultMLFlashpointCheckpointLoader(
             checkpoint_object_manager=CheckpointObjectManager(),
             replication_manager=ReplicationManager(),
+            global_rank_getter=lambda: 0,
+            local_rank_getter=lambda: 0,
+            broadcast_object_list_func=lambda *args, **kwargs: None,
+            all_gather_object_func=lambda *args, **kwargs: None,
+            world_size_getter=lambda: 1,
         )
         base_container = CheckpointContainerId("/tmp/ml_flashpoint_checkpoints")

@@ -122,7 +137,13 @@ def test_initializer_respects_params(self):
         """Tests that init respects the passed parameters for resume flags."""
         # Arrange
         checkpoint_loader = DefaultMLFlashpointCheckpointLoader(
-            checkpoint_object_manager=CheckpointObjectManager(), replication_manager=ReplicationManager()
+            checkpoint_object_manager=CheckpointObjectManager(),
+            replication_manager=ReplicationManager(),
+            global_rank_getter=lambda: 0,
+            local_rank_getter=lambda: 0,
+            broadcast_object_list_func=lambda *args, **kwargs: None,
+            all_gather_object_func=lambda *args, **kwargs: None,
+            world_size_getter=lambda: 1,
         )
         base_container = CheckpointContainerId("/tmp/ml_flashpoint_checkpoints")

@@ -142,7 +163,13 @@ def test_initializer_passes_kwargs_to_super(self):
         """Tests that kwargs (like restore_config) are passed to the superclass."""
         # Arrange
         checkpoint_loader = DefaultMLFlashpointCheckpointLoader(
-            checkpoint_object_manager=CheckpointObjectManager(), replication_manager=ReplicationManager()
+            checkpoint_object_manager=CheckpointObjectManager(),
+            replication_manager=ReplicationManager(),
+            global_rank_getter=lambda: 0,
+            local_rank_getter=lambda: 0,
+            broadcast_object_list_func=lambda *args, **kwargs: None,
+            all_gather_object_func=lambda *args, **kwargs: None,
+            world_size_getter=lambda: 1,
         )
         base_container = CheckpointContainerId("/tmp/ml_flashpoint_checkpoints")
         restore_config = RestoreConfig(path="nemo://some-model")
