Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,5 @@ endif()
# Add your child projects. They will inherit all settings and
# dependencies (absl::*, gtest*, pybind11::*) from this file.
add_subdirectory(src/ml_flashpoint/checkpoint_object_manager/buffer_object)
add_subdirectory(src/ml_flashpoint/checkpoint_object_manager/object_manager)
add_subdirectory(src/ml_flashpoint/replication/transfer_service)

This file was deleted.

This file was deleted.

This file was deleted.

This file was deleted.

38 changes: 28 additions & 10 deletions src/ml_flashpoint/core/checkpoint_saver.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,9 @@
import os
import pickle
import queue
import subprocess
import threading
from typing import Callable, Optional, Protocol, Union
from typing import Callable, Protocol, Union

import torch
from torch.distributed.checkpoint import metadata as torchdistmeta
Expand All @@ -31,7 +32,6 @@
from typing_extensions import override

from ml_flashpoint.checkpoint_object_manager.checkpoint_object_manager import CheckpointObjectManager
from ml_flashpoint.checkpoint_object_manager.object_manager import object_manager_ext
from ml_flashpoint.core.checkpoint_id_types import CheckpointContainerId, CheckpointObjectId
from ml_flashpoint.core.defaults import DIRTY_MARKER_SUFFIX, CheckpointFormat, default_metadata_object_name
from ml_flashpoint.core.mlf_logging import get_logger
Expand Down Expand Up @@ -269,7 +269,7 @@ def write_metadata(
pass

@abc.abstractmethod
def finalize_checkpoint(self, checkpoint_id: CheckpointContainerId) -> Optional[object_manager_ext.BasicFutureVoid]:
def finalize_checkpoint(self, checkpoint_id: CheckpointContainerId) -> Union[subprocess.Popen, None]:
"""Finalize the checkpoint for checkpoint_id, indicating it is complete and safe to recover from.
This specifically does the following:
1. Cleans up the unfinished marker created by initialize_checkpoint().
Expand All @@ -281,7 +281,7 @@ def finalize_checkpoint(self, checkpoint_id: CheckpointContainerId) -> Optional[
checkpoint_id: The CheckpointContainerId to mark as finalized.

Returns:
A future that completes when deletion of older checkpoints is done, or None if no deletion was started.
The subprocess.Popen object that handles deletion of older checkpoints, or None if no deletion was started.
"""
pass

Expand Down Expand Up @@ -549,7 +549,7 @@ def write_metadata(

@override
@log_execution_time(logger=_LOGGER, name="finalize_checkpoint")
def finalize_checkpoint(self, checkpoint_id: CheckpointContainerId) -> Optional[object_manager_ext.BasicFutureVoid]:
def finalize_checkpoint(self, checkpoint_id: CheckpointContainerId) -> Union[subprocess.Popen, None]:
self._remove_dirty_checkpoint_marker(checkpoint_id)
# synchronize across ranks to guarantee they all completed checkpointing before proceeding
with log_execution_time(logger=_LOGGER, name="finalize_checkpoint__barrier_func", level=logging.DEBUG):
Expand Down Expand Up @@ -707,9 +707,7 @@ def _write_to_buffer_from_queue_worker(
object_write_bucket_queue.task_done()

@log_execution_time(logger=_LOGGER, name="_remove_older_checkpoints")
def _remove_older_checkpoints(
self, older_than: CheckpointContainerId
) -> Optional[object_manager_ext.BasicFutureVoid]:
def _remove_older_checkpoints(self, older_than: CheckpointContainerId) -> subprocess.Popen | None:
"""Scans for sibling checkpoint containers to `older_than`, by listing the children of its parent and filtering
for those that match the expected format as a safety check, and then deletes all those that are considered
older _async_.
Expand All @@ -721,7 +719,7 @@ def _remove_older_checkpoints(
older than this will be removed.

Returns:
A future that completes when deletion is done, or None if no deletion was started.
The subprocess.Popen object that handles deletion of older checkpoints, or None if no deletion was started.
"""
parent_dir = os.path.dirname(older_than.data)
older_than_step = CheckpointContainerId.parse_version_container_step(os.path.basename(older_than.data))
Expand All @@ -741,7 +739,27 @@ def _remove_older_checkpoints(
if step is not None and step < older_than_step:
siblings_to_delete.add(full_path)

return object_manager_ext.delete_directories_async(list(siblings_to_delete))
if siblings_to_delete:
try:
# We use a background subprocess (rm -rf) instead of Python's shutil.rmtree
# to avoid blocking the main Python thread (GIL) during large directory deletions.
# This allows the training process to continue immediately while the OS handles
# the deletion asynchronously.
#
# start_new_session=True is used to ensure the deletion process is decoupled
# from the parent process group, preventing it from being interrupted by signals
# (like SIGINT during Ctrl+C) sent to the main training job.
p = subprocess.Popen(
Comment thread
Leahlijuan marked this conversation as resolved.
["rm", "-rf"] + list(siblings_to_delete),
stdout=subprocess.DEVNULL,
stderr=subprocess.DEVNULL,
start_new_session=True,
)
return p
except Exception as e:
_LOGGER.exception("Background deletion of old checkpoints failed: %s", e)
return None
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hm ideally we'd at least return something that informs the caller there was an error. Otherwise they can't distinguish between "nothing to do" and "failed to delete".

Maybe bubble up the exception? Or can return a wrapper type encapsulating an Optional proc and an optional error e.g.

@dataclass
class DeletionResult:
    proc: Optional[subprocess.Popen] = None
    error: Optional[Exception] = None
    skipped: bool = False

Not a blocker, but will make the API more explicit.

Then in the failure test, we can confirm there was an error by asserting on the result, as opposed to just expecting proc to be None, which is the same expectation when there is no error and nothing to do.

return None

def _save_tensor_optimized(self, tensor: torch.Tensor, buffer_io_writer):
"""Saves a tensor to the buffer using a zero-copy approach where possible.
Expand Down
Loading
Loading