Fix per-rank training state resume #587
Code Review
This pull request updates the training scripts (train_dflash.py, train_domino.py, and train_eagle3.py) to save and load rank-specific training states using a new helper function get_training_state_path in specforge/utils.py. It also updates the optimizer state dictionary to include the learning rate and adjusts how the optimizer state is loaded in train_domino.py. The review feedback highlights that removing the os.path.exists check when loading the training state across all three training scripts makes the resume process fragile, as it will crash with a FileNotFoundError if the training state file is missing. The reviewer suggests restoring the existence check to ensure a graceful fallback.
    training_state_path = get_training_state_path(draft_model_last_checkpoint)
    resume_state = torch.load(
        training_state_path, map_location="cpu", weights_only=False
    )
    print(
        f"Will resume from epoch {resume_state['epoch']}, "
        f"step {resume_state['global_step']}"
    )
Removing the os.path.exists check for the training state file makes the resume process fragile. If a user attempts to resume training from a checkpoint that only contains model weights (or if the training state files were deleted/not saved), the script will crash with a FileNotFoundError instead of gracefully falling back to starting with a fresh optimizer state. Restoring the existence check ensures robust fallback behavior.
Suggested change:

    -training_state_path = get_training_state_path(draft_model_last_checkpoint)
    -resume_state = torch.load(
    -    training_state_path, map_location="cpu", weights_only=False
    -)
    -print(
    -    f"Will resume from epoch {resume_state['epoch']}, "
    -    f"step {resume_state['global_step']}"
    -)
    +training_state_path = get_training_state_path(draft_model_last_checkpoint)
    +if os.path.exists(training_state_path):
    +    resume_state = torch.load(
    +        training_state_path, map_location="cpu", weights_only=False
    +    )
    +    print(
    +        f"Will resume from epoch {resume_state['epoch']}, "
    +        f"step {resume_state['global_step']}"
    +    )
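The fallback the reviewer describes can be sketched as a small helper. This is only an illustration, not code from the PR: `load_resume_state` and its `loader` parameter are hypothetical names, and the loader is injected so the sketch runs without PyTorch installed.

```python
import os
from typing import Any, Callable, Optional


def load_resume_state(path: str, loader: Callable[[str], Any]) -> Optional[Any]:
    """Return the loaded training state, or None when the file is missing.

    Hypothetical helper: `loader` stands in for whatever deserializes the
    file (in the training scripts, a call to torch.load).
    """
    if not os.path.exists(path):
        # Weights-only checkpoint: the caller starts with fresh optimizer state.
        print(f"No training state found at {path}; starting fresh")
        return None
    return loader(path)
```

In the training scripts the loader would presumably be something like `lambda p: torch.load(p, map_location="cpu", weights_only=False)`, and a `None` result would mean skipping the optimizer and scheduler restore.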
    training_state_path = get_training_state_path(draft_model_last_checkpoint)
    resume_state = torch.load(
        training_state_path, map_location="cpu", weights_only=False
    )
    print_on_rank0(
        f"Loaded training state from {training_state_path}: "
        f"epoch={resume_state['epoch']}, step={resume_state['global_step']}"
    )
Removing the os.path.exists check for the training state file makes the resume process fragile. If a user attempts to resume training from a checkpoint that only contains model weights (or if the training state files were deleted/not saved), the script will crash with a FileNotFoundError instead of gracefully falling back to starting with a fresh optimizer state. Restoring the existence check ensures robust fallback behavior.
Suggested change:

    -training_state_path = get_training_state_path(draft_model_last_checkpoint)
    -resume_state = torch.load(
    -    training_state_path, map_location="cpu", weights_only=False
    -)
    -print_on_rank0(
    -    f"Loaded training state from {training_state_path}: "
    -    f"epoch={resume_state['epoch']}, step={resume_state['global_step']}"
    -)
    +training_state_path = get_training_state_path(draft_model_last_checkpoint)
    +if os.path.exists(training_state_path):
    +    resume_state = torch.load(
    +        training_state_path, map_location="cpu", weights_only=False
    +    )
    +    print_on_rank0(
    +        f"Loaded training state from {training_state_path}: "
    +        f"epoch={resume_state['epoch']}, step={resume_state['global_step']}"
    +    )
Summary
Fix resume behavior by saving and loading optimizer/scheduler training state per rank.
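As a rough illustration of per-rank naming, the sketch below builds one state-file path per rank. It is not the PR's `get_training_state_path`, whose signature is not shown here: `rank_state_path` is a hypothetical name, and taking the rank as an explicit argument is an assumption. Only the `training_state_rank_{rank}.pt` file name comes from the PR.

```python
import os


def rank_state_path(checkpoint_dir: str, rank: int) -> str:
    """Hypothetical helper: path of the training-state file for one rank."""
    if rank < 0:
        raise ValueError("rank must be non-negative")
    # File name pattern taken from the PR description.
    return os.path.join(checkpoint_dir, f"training_state_rank_{rank}.pt")
```

Each process would call this with its own rank (for example the value returned by `torch.distributed.get_rank()`), so no two ranks write to or read from the same file.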
Changes
training_state_rank_{rank}.pt
training_state.pt
lr in optimizer state
Validation
python -m py_compile specforge/optimizer.py specforge/utils.py scripts/train_eagle3.py scripts/train_dflash.py scripts/train_domino.py
git diff --check
Checklist