-
Notifications
You must be signed in to change notification settings - Fork 7
refactor(checkpoint-object-manager): deprecate C++ object_manager layer for lightweight Python subprocess async deletion #102
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
7a3c1fe
a18aaea
7d42f32
3f2d840
a51e25b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
This file was deleted.
This file was deleted.
This file was deleted.
This file was deleted.
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -20,8 +20,9 @@ | |
| import os | ||
| import pickle | ||
| import queue | ||
| import subprocess | ||
| import threading | ||
| from typing import Callable, Optional, Protocol, Union | ||
| from typing import Callable, Protocol, Union | ||
|
|
||
| import torch | ||
| from torch.distributed.checkpoint import metadata as torchdistmeta | ||
|
|
@@ -31,7 +32,6 @@ | |
| from typing_extensions import override | ||
|
|
||
| from ml_flashpoint.checkpoint_object_manager.checkpoint_object_manager import CheckpointObjectManager | ||
| from ml_flashpoint.checkpoint_object_manager.object_manager import object_manager_ext | ||
| from ml_flashpoint.core.checkpoint_id_types import CheckpointContainerId, CheckpointObjectId | ||
| from ml_flashpoint.core.defaults import DIRTY_MARKER_SUFFIX, CheckpointFormat, default_metadata_object_name | ||
| from ml_flashpoint.core.mlf_logging import get_logger | ||
|
|
@@ -269,7 +269,7 @@ def write_metadata( | |
| pass | ||
|
|
||
| @abc.abstractmethod | ||
| def finalize_checkpoint(self, checkpoint_id: CheckpointContainerId) -> Optional[object_manager_ext.BasicFutureVoid]: | ||
| def finalize_checkpoint(self, checkpoint_id: CheckpointContainerId) -> Union[subprocess.Popen, None]: | ||
| """Finalize the checkpoint for checkpoint_id, indicating it is complete and safe to recover from. | ||
| This specifically does the following: | ||
| 1. Cleans up the unfinished marker created by initialize_checkpoint(). | ||
|
|
@@ -281,7 +281,7 @@ def finalize_checkpoint(self, checkpoint_id: CheckpointContainerId) -> Optional[ | |
| checkpoint_id: The CheckpointContainerId to mark as finalized. | ||
|
|
||
| Returns: | ||
| A future that completes when deletion of older checkpoints is done, or None if no deletion was started. | ||
| The subprocess.Popen object that handles deletion of older checkpoints, or None if no deletion was started. | ||
| """ | ||
| pass | ||
|
|
||
|
|
@@ -549,7 +549,7 @@ def write_metadata( | |
|
|
||
| @override | ||
| @log_execution_time(logger=_LOGGER, name="finalize_checkpoint") | ||
| def finalize_checkpoint(self, checkpoint_id: CheckpointContainerId) -> Optional[object_manager_ext.BasicFutureVoid]: | ||
| def finalize_checkpoint(self, checkpoint_id: CheckpointContainerId) -> Union[subprocess.Popen, None]: | ||
| self._remove_dirty_checkpoint_marker(checkpoint_id) | ||
| # synchronize across ranks to guarantee they all completed checkpointing before proceeding | ||
| with log_execution_time(logger=_LOGGER, name="finalize_checkpoint__barrier_func", level=logging.DEBUG): | ||
|
|
@@ -707,9 +707,7 @@ def _write_to_buffer_from_queue_worker( | |
| object_write_bucket_queue.task_done() | ||
|
|
||
| @log_execution_time(logger=_LOGGER, name="_remove_older_checkpoints") | ||
| def _remove_older_checkpoints( | ||
| self, older_than: CheckpointContainerId | ||
| ) -> Optional[object_manager_ext.BasicFutureVoid]: | ||
| def _remove_older_checkpoints(self, older_than: CheckpointContainerId) -> subprocess.Popen | None: | ||
| """Scans for sibling checkpoint containers to `older_than`, by listing the children of its parent and filtering | ||
| for those that match the expected format as a safety check, and then deletes all those that are considered | ||
| older _async_. | ||
|
|
@@ -721,7 +719,7 @@ def _remove_older_checkpoints( | |
| older than this will be removed. | ||
|
|
||
| Returns: | ||
| A future that completes when deletion is done, or None if no deletion was started. | ||
| The subprocess.Popen object that handles deletion of older checkpoints, or None if no deletion was started. | ||
| """ | ||
| parent_dir = os.path.dirname(older_than.data) | ||
| older_than_step = CheckpointContainerId.parse_version_container_step(os.path.basename(older_than.data)) | ||
|
|
@@ -741,7 +739,27 @@ def _remove_older_checkpoints( | |
| if step is not None and step < older_than_step: | ||
| siblings_to_delete.add(full_path) | ||
|
|
||
| return object_manager_ext.delete_directories_async(list(siblings_to_delete)) | ||
| if siblings_to_delete: | ||
| try: | ||
| # We use a background subprocess (rm -rf) instead of Python's shutil.rmtree | ||
| # to avoid blocking the main Python thread (GIL) during large directory deletions. | ||
| # This allows the training process to continue immediately while the OS handles | ||
| # the deletion asynchronously. | ||
| # | ||
| # start_new_session=True is used to ensure the deletion process is decoupled | ||
| # from the parent process group, preventing it from being interrupted by signals | ||
| # (like SIGINT during Ctrl+C) sent to the main training job. | ||
| p = subprocess.Popen( | ||
| ["rm", "-rf"] + list(siblings_to_delete), | ||
| stdout=subprocess.DEVNULL, | ||
| stderr=subprocess.DEVNULL, | ||
| start_new_session=True, | ||
| ) | ||
| return p | ||
| except Exception as e: | ||
| _LOGGER.exception("Background deletion of old checkpoints failed: %s", e) | ||
| return None | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. hm ideally we'd at least return something that informs the caller there was an error. Otherwise they can't distinguish between "nothing to do" and "failed to delete". Maybe bubble up the exception? Or can return a wrapper type encapsulating an Optional proc and an optional error e.g. @dataclass
class DeletionResult:
proc: Optional[subprocess.Popen] = None
error: Optional[Exception] = None
    skipped: bool = False
Not a blocker, but will make the API more explicit. Then in the failure test, we can confirm there was an error by asserting on the result, as opposed to just expecting proc to be None, which is the same expectation when there is no error and nothing to do. |
||
| return None | ||
|
|
||
| def _save_tensor_optimized(self, tensor: torch.Tensor, buffer_io_writer): | ||
| """Saves a tensor to the buffer using a zero-copy approach where possible. | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.