Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions scripts/train_dflash.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from specforge.optimizer import BF16Optimizer
from specforge.tracker import create_tracker
from specforge.utils import (
cleanup_checkpoints,
get_last_checkpoint,
get_local_device,
print_on_rank0,
Expand Down Expand Up @@ -151,6 +152,13 @@ def parse_args():
output_group.add_argument("--log-interval", type=int, default=50)
output_group.add_argument("--eval-interval", type=int, default=1000)
output_group.add_argument("--save-interval", type=int, default=1000)
output_group.add_argument(
"--save-total-limit",
type=int,
default=None,
help="Maximum number of checkpoints to keep. Older checkpoints are "
"removed when this limit is exceeded. None (default) keeps all.",
)

optimization_group = parser.add_argument_group("optimization")
optimization_group.add_argument(
Expand Down Expand Up @@ -604,6 +612,10 @@ def main():
save_checkpoint(
args, epoch, global_step, dflash_model, draft_model, optimizer
)
cleanup_checkpoints(
args.output_dir, args.save_total_limit
)
dist.barrier()
Comment on lines +615 to +618

save_checkpoint(
args, args.num_epochs, global_step, dflash_model, draft_model, optimizer
Expand Down
18 changes: 17 additions & 1 deletion scripts/train_domino.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,12 @@
from specforge.modeling.target.target_utils import TargetEmbeddingsAndHead
from specforge.optimizer import BF16Optimizer
from specforge.tracker import create_tracker
from specforge.utils import get_last_checkpoint, print_on_rank0, print_with_rank
from specforge.utils import (
cleanup_checkpoints,
get_last_checkpoint,
print_on_rank0,
print_with_rank,
)


def parse_args():
Expand Down Expand Up @@ -135,6 +140,13 @@ def parse_args():
output_group.add_argument("--log-interval", type=int, default=50)
output_group.add_argument("--eval-interval", type=int, default=1000)
output_group.add_argument("--save-interval", type=int, default=1000)
output_group.add_argument(
"--save-total-limit",
type=int,
default=None,
help="Maximum number of checkpoints to keep. Older checkpoints are "
"removed when this limit is exceeded. None (default) keeps all.",
)

optimization_group = parser.add_argument_group("optimization")
optimization_group.add_argument(
Expand Down Expand Up @@ -650,6 +662,10 @@ def main():
save_checkpoint(
args, epoch, global_step, domino_model, draft_model, optimizer
)
cleanup_checkpoints(
args.output_dir, args.save_total_limit
)
dist.barrier()
Comment on lines +665 to +668

save_checkpoint(
args, args.num_epochs, global_step, domino_model, draft_model, optimizer
Expand Down
12 changes: 12 additions & 0 deletions scripts/train_eagle3.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
from specforge.optimizer import BF16Optimizer
from specforge.tracker import Tracker, create_tracker, get_tracker_class
from specforge.utils import (
cleanup_checkpoints,
create_draft_config_from_target,
get_last_checkpoint,
print_args_with_dots,
Expand Down Expand Up @@ -188,6 +189,13 @@ def build_parser() -> ArgumentParser:
)
training_group.add_argument("--eval-interval", type=int, default=5000)
training_group.add_argument("--save-interval", type=int, default=5000)
training_group.add_argument(
"--save-total-limit",
type=int,
default=None,
help="Maximum number of checkpoints to keep. Older checkpoints are "
"removed when this limit is exceeded. None (default) keeps all.",
)
training_group.add_argument(
"--log-interval",
type=int,
Expand Down Expand Up @@ -1345,6 +1353,10 @@ def main():
if global_step % (args.save_interval * args.draft_accumulation_steps) == 0:
# Save the model
save_checkpoints(args, epoch, global_step, eagle3_model, optimizer)
cleanup_checkpoints(
args.output_dir, args.save_total_limit
)
dist.barrier()
Comment on lines +1356 to +1359

if args.max_num_steps is not None and global_step >= args.max_num_steps:
break
Expand Down
54 changes: 54 additions & 0 deletions specforge/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import logging
import os
import re
import shutil
from contextlib import contextmanager

import torch
Expand Down Expand Up @@ -148,6 +149,59 @@ def sort_key(x):
return os.path.join(folder, last_checkpoint), (epoch, step)


def cleanup_checkpoints(output_dir, save_total_limit, prefix="epoch"):
    """
    Remove the oldest checkpoint directories exceeding ``save_total_limit``.

    Checkpoint directories are the immediate sub-directories of
    ``output_dir`` named ``{prefix}_{epoch}`` or
    ``{prefix}_{epoch}_step_{step}``. They are ordered by ``(epoch, step)``
    (a missing step counts as 0) and the oldest ones are deleted until at
    most ``save_total_limit`` remain.

    When ``torch.distributed`` is initialized, only rank 0 performs the
    deletion; every other rank returns 0 immediately. Callers should invoke
    ``dist.barrier()`` after this function in distributed settings so that
    no rank runs ahead while rank 0 is still deleting.

    Args:
        output_dir: Directory containing checkpoint folders.
        save_total_limit: Maximum number of checkpoints to keep.
            ``None`` or non-positive means no cleanup (backward compatible).
        prefix: Checkpoint directory prefix, default is "epoch".

    Returns:
        Number of checkpoints removed by the calling process (always 0 on
        non-zero ranks).
    """
    if save_total_limit is None or save_total_limit <= 0:
        return 0

    if not os.path.isdir(output_dir):
        return 0

    # Local import keeps this helper self-contained; torch is only needed
    # once there is actually a directory to inspect.
    import torch.distributed as dist

    # With a shared filesystem (e.g. NFS, Lustre) every rank sees the same
    # checkpoint directories. If all of them called rmtree concurrently the
    # losers would hit FileNotFoundError/OSError and crash the run, so the
    # deletion is restricted to rank 0.
    if dist.is_available() and dist.is_initialized() and dist.get_rank() != 0:
        return 0

    checkpoint_re = re.compile(rf"^{re.escape(prefix)}_(\d+)(?:_step_(\d+))?$")

    # Collect (epoch, step, name) once per entry so the name is parsed a
    # single time; the name also acts as a deterministic tiebreak when
    # sorting.
    checkpoints = []
    for name in os.listdir(output_dir):
        match = checkpoint_re.search(name)
        if match is None or not os.path.isdir(os.path.join(output_dir, name)):
            continue
        epoch = int(match.group(1))
        step = int(match.group(2)) if match.group(2) else 0
        checkpoints.append((epoch, step, name))

    num_to_remove = len(checkpoints) - save_total_limit
    if num_to_remove <= 0:
        return 0

    checkpoints.sort()
    removed = 0
    for _, _, name in checkpoints[:num_to_remove]:
        ckpt_path = os.path.join(output_dir, name)
        shutil.rmtree(ckpt_path)
        print_on_rank0(f"Removed old checkpoint: {ckpt_path}")
        removed += 1

    return removed


def generate_draft_model_config(
target_model_path: str, template_config_path: str = None, cache_dir: str = None
):
Expand Down
114 changes: 114 additions & 0 deletions tests/test_utils/test_checkpoint_cleanup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
"""Tests for checkpoint cleanup utility."""

import os
import tempfile
import unittest

from specforge.utils import cleanup_checkpoints


def create_checkpoint_dir(output_dir, epoch, step):
    """Build a fake ``epoch_<epoch>_step_<step>`` checkpoint under output_dir.

    The directory is created if missing and receives a ``config.json``
    containing ``{}`` so that it is never empty. Returns the directory path.
    """
    ckpt_dir = os.path.join(output_dir, f"epoch_{epoch}_step_{step}")
    os.makedirs(ckpt_dir, exist_ok=True)
    # Placeholder payload: keeps the checkpoint directory non-empty.
    config_file = os.path.join(ckpt_dir, "config.json")
    with open(config_file, "w") as handle:
        handle.write("{}")
    return ckpt_dir


class TestCleanupCheckpoints(unittest.TestCase):
    """Exercise ``cleanup_checkpoints`` against throw-away directories."""

    def setUp(self):
        # addCleanup runs even when a test fails, so the temporary tree is
        # always removed without needing a separate tearDown.
        tmp = tempfile.TemporaryDirectory()
        self.addCleanup(tmp.cleanup)
        self.tmpdir = tmp.name

    def test_no_cleanup_when_limit_none(self):
        """No checkpoints removed when save_total_limit is None."""
        for i in range(5):
            create_checkpoint_dir(self.tmpdir, 0, (i + 1) * 1000)
        removed = cleanup_checkpoints(self.tmpdir, None)
        self.assertEqual(removed, 0)
        self.assertEqual(len(os.listdir(self.tmpdir)), 5)

    def test_no_cleanup_when_limit_zero(self):
        """No checkpoints removed when save_total_limit is 0."""
        for i in range(5):
            create_checkpoint_dir(self.tmpdir, 0, (i + 1) * 1000)
        removed = cleanup_checkpoints(self.tmpdir, 0)
        self.assertEqual(removed, 0)
        self.assertEqual(len(os.listdir(self.tmpdir)), 5)

    def test_no_cleanup_when_limit_negative(self):
        """No checkpoints removed when save_total_limit is negative."""
        for i in range(5):
            create_checkpoint_dir(self.tmpdir, 0, (i + 1) * 1000)
        removed = cleanup_checkpoints(self.tmpdir, -1)
        self.assertEqual(removed, 0)
        # Check the filesystem too, not just the reported count.
        self.assertEqual(len(os.listdir(self.tmpdir)), 5)

    def test_removes_oldest_checkpoints(self):
        """Older checkpoints removed when limit exceeded."""
        for i in range(5):
            create_checkpoint_dir(self.tmpdir, 0, (i + 1) * 1000)
        removed = cleanup_checkpoints(self.tmpdir, 2)
        self.assertEqual(removed, 3)
        remaining = sorted(os.listdir(self.tmpdir))
        self.assertEqual(remaining, ["epoch_0_step_4000", "epoch_0_step_5000"])

    def test_no_removal_when_under_limit(self):
        """No checkpoints removed when count <= limit."""
        for i in range(3):
            create_checkpoint_dir(self.tmpdir, 0, (i + 1) * 1000)
        removed = cleanup_checkpoints(self.tmpdir, 5)
        self.assertEqual(removed, 0)
        self.assertEqual(len(os.listdir(self.tmpdir)), 3)

    def test_removes_oldest_across_epochs(self):
        """Correct removal ordering across multiple epochs."""
        steps = [
            (0, 1000),
            (0, 2000),
            (1, 3000),
            (1, 4000),
            (2, 5000),
        ]
        for epoch, step in steps:
            create_checkpoint_dir(self.tmpdir, epoch, step)
        removed = cleanup_checkpoints(self.tmpdir, 2)
        self.assertEqual(removed, 3)
        remaining = sorted(os.listdir(self.tmpdir))
        self.assertEqual(remaining, ["epoch_1_step_4000", "epoch_2_step_5000"])

    def test_nonexistent_dir(self):
        """No error when output_dir does not exist."""
        # A path below the fresh temp dir is guaranteed not to exist,
        # independent of what the host filesystem happens to contain.
        missing = os.path.join(self.tmpdir, "does_not_exist")
        removed = cleanup_checkpoints(missing, 3)
        self.assertEqual(removed, 0)

    def test_empty_dir(self):
        """No error when output_dir is empty."""
        removed = cleanup_checkpoints(self.tmpdir, 3)
        self.assertEqual(removed, 0)

    def test_ignores_non_checkpoint_dirs(self):
        """Non-checkpoint directories are not affected."""
        create_checkpoint_dir(self.tmpdir, 0, 1000)
        create_checkpoint_dir(self.tmpdir, 0, 2000)
        create_checkpoint_dir(self.tmpdir, 0, 3000)
        # Create non-checkpoint dirs/files
        os.makedirs(os.path.join(self.tmpdir, "some_other_dir"), exist_ok=True)
        with open(os.path.join(self.tmpdir, "random_file.txt"), "w") as f:
            f.write("test")
        removed = cleanup_checkpoints(self.tmpdir, 1)
        self.assertEqual(removed, 2)
        # Exact comparison: the unrelated entries survive AND the two old
        # checkpoints are really gone.
        remaining = sorted(os.listdir(self.tmpdir))
        self.assertEqual(
            remaining,
            ["epoch_0_step_3000", "random_file.txt", "some_other_dir"],
        )


if __name__ == "__main__":
unittest.main()