Skip to content

Fix per-rank training state resume#587

Open
gq112 wants to merge 1 commit into
sgl-project:mainfrom
gq112:fix/resume-pr
Open

Fix per-rank training state resume#587
gq112 wants to merge 1 commit into
sgl-project:mainfrom
gq112:fix/resume-pr

Conversation

@gq112

@gq112 gq112 commented Jun 20, 2026

Copy link
Copy Markdown

Summary

Fix resume behavior by saving and loading optimizer/scheduler training state per rank.

Changes

  • Save training state as training_state_rank_{rank}.pt
  • Load the matching per-rank state on resume
  • Stop using the shared training_state.pt
  • Restore full optimizer state for Domino resume
  • Include current lr in optimizer state

Validation

  • python -m py_compile specforge/optimizer.py specforge/utils.py scripts/train_eagle3.py scripts/train_dflash.py scripts/train_domino.py
  • git diff --check

Checklist

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request updates the training scripts (train_dflash.py, train_domino.py, and train_eagle3.py) to save and load rank-specific training states using a new helper function get_training_state_path in specforge/utils.py. It also updates the optimizer state dictionary to include the learning rate and adjusts how the optimizer state is loaded in train_domino.py. The review feedback highlights that removing the os.path.exists check when loading the training state across all three training scripts makes the resume process fragile, as it will crash with a FileNotFoundError if the training state file is missing. The reviewer suggests restoring the existence check to ensure a graceful fallback.

Important

The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.

Comment thread scripts/train_dflash.py
Comment on lines +414 to 421
training_state_path = get_training_state_path(draft_model_last_checkpoint)
resume_state = torch.load(
training_state_path, map_location="cpu", weights_only=False
)
print(
f"Will resume from epoch {resume_state['epoch']}, "
f"step {resume_state['global_step']}"
)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Removing the os.path.exists check for the training state file makes the resume process fragile. If a user attempts to resume training from a checkpoint that only contains model weights (or if the training state files were deleted/not saved), the script will crash with a FileNotFoundError instead of gracefully falling back to starting with a fresh optimizer state. Restoring the existence check ensures robust fallback behavior.

Suggested change
training_state_path = get_training_state_path(draft_model_last_checkpoint)
resume_state = torch.load(
training_state_path, map_location="cpu", weights_only=False
)
print(
f"Will resume from epoch {resume_state['epoch']}, "
f"step {resume_state['global_step']}"
)
training_state_path = get_training_state_path(draft_model_last_checkpoint)
if os.path.exists(training_state_path):
resume_state = torch.load(
training_state_path, map_location="cpu", weights_only=False
)
print(
f"Will resume from epoch {resume_state['epoch']}, "
f"step {resume_state['global_step']}"
)

Comment thread scripts/train_domino.py
Comment on lines +483 to 490
training_state_path = get_training_state_path(draft_model_last_checkpoint)
resume_state = torch.load(
training_state_path, map_location="cpu", weights_only=False
)
print(
f"Will resume from epoch {resume_state['epoch']}, "
f"step {resume_state['global_step']}"
)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Removing the os.path.exists check for the training state file makes the resume process fragile. If a user attempts to resume training from a checkpoint that only contains model weights (or if the training state files were deleted/not saved), the script will crash with a FileNotFoundError instead of gracefully falling back to starting with a fresh optimizer state. Restoring the existence check ensures robust fallback behavior.

Suggested change
training_state_path = get_training_state_path(draft_model_last_checkpoint)
resume_state = torch.load(
training_state_path, map_location="cpu", weights_only=False
)
print(
f"Will resume from epoch {resume_state['epoch']}, "
f"step {resume_state['global_step']}"
)
training_state_path = get_training_state_path(draft_model_last_checkpoint)
if os.path.exists(training_state_path):
resume_state = torch.load(
training_state_path, map_location="cpu", weights_only=False
)
print(
f"Will resume from epoch {resume_state['epoch']}, "
f"step {resume_state['global_step']}"
)

Comment thread scripts/train_eagle3.py
Comment on lines +536 to 543
training_state_path = get_training_state_path(draft_model_last_checkpoint)
resume_state = torch.load(
training_state_path, map_location="cpu", weights_only=False
)
print_on_rank0(
f"Loaded training state from {training_state_path}: "
f"epoch={resume_state['epoch']}, step={resume_state['global_step']}"
)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Removing the os.path.exists check for the training state file makes the resume process fragile. If a user attempts to resume training from a checkpoint that only contains model weights (or if the training state files were deleted/not saved), the script will crash with a FileNotFoundError instead of gracefully falling back to starting with a fresh optimizer state. Restoring the existence check ensures robust fallback behavior.

Suggested change
training_state_path = get_training_state_path(draft_model_last_checkpoint)
resume_state = torch.load(
training_state_path, map_location="cpu", weights_only=False
)
print_on_rank0(
f"Loaded training state from {training_state_path}: "
f"epoch={resume_state['epoch']}, step={resume_state['global_step']}"
)
training_state_path = get_training_state_path(draft_model_last_checkpoint)
if os.path.exists(training_state_path):
resume_state = torch.load(
training_state_path, map_location="cpu", weights_only=False
)
print_on_rank0(
f"Loaded training state from {training_state_path}: "
f"epoch={resume_state['epoch']}, step={resume_state['global_step']}"
)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants