-
Notifications
You must be signed in to change notification settings - Fork 271
Fix per-rank training state resume #587
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -33,7 +33,12 @@ | |||||||||||||||||||||||||||||||||||
| from specforge.modeling.target.target_utils import TargetEmbeddingsAndHead | ||||||||||||||||||||||||||||||||||||
| from specforge.optimizer import BF16Optimizer | ||||||||||||||||||||||||||||||||||||
| from specforge.tracker import create_tracker | ||||||||||||||||||||||||||||||||||||
| from specforge.utils import get_last_checkpoint, print_on_rank0, print_with_rank | ||||||||||||||||||||||||||||||||||||
| from specforge.utils import ( | ||||||||||||||||||||||||||||||||||||
| get_last_checkpoint, | ||||||||||||||||||||||||||||||||||||
| get_training_state_path, | ||||||||||||||||||||||||||||||||||||
| print_on_rank0, | ||||||||||||||||||||||||||||||||||||
| print_with_rank, | ||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| def parse_args(): | ||||||||||||||||||||||||||||||||||||
|
|
@@ -324,17 +329,17 @@ def save_checkpoint(args, epoch, step, domino_model, draft_model, optimizer): | |||||||||||||||||||||||||||||||||||
| if "draft_model." in k | ||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| if dist.get_rank() == 0: | ||||||||||||||||||||||||||||||||||||
| torch.save( | ||||||||||||||||||||||||||||||||||||
| { | ||||||||||||||||||||||||||||||||||||
| "epoch": epoch, | ||||||||||||||||||||||||||||||||||||
| "global_step": step, | ||||||||||||||||||||||||||||||||||||
| "args": args, | ||||||||||||||||||||||||||||||||||||
| **optimizer.state_dict(), | ||||||||||||||||||||||||||||||||||||
| }, | ||||||||||||||||||||||||||||||||||||
| os.path.join(save_dir, "training_state.pt"), | ||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||
| torch.save( | ||||||||||||||||||||||||||||||||||||
| { | ||||||||||||||||||||||||||||||||||||
| "epoch": epoch, | ||||||||||||||||||||||||||||||||||||
| "global_step": step, | ||||||||||||||||||||||||||||||||||||
| "args": args, | ||||||||||||||||||||||||||||||||||||
| **optimizer.state_dict(), | ||||||||||||||||||||||||||||||||||||
| }, | ||||||||||||||||||||||||||||||||||||
| get_training_state_path(save_dir), | ||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| if dist.get_rank() == 0: | ||||||||||||||||||||||||||||||||||||
| draft_model.save_pretrained(save_dir, state_dict=draft_state_dict) | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| modeling_src = os.path.join( | ||||||||||||||||||||||||||||||||||||
|
|
@@ -475,17 +480,14 @@ def main(): | |||||||||||||||||||||||||||||||||||
| del loaded_model | ||||||||||||||||||||||||||||||||||||
| print("Loaded draft model weights from checkpoint") | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| training_state_path = os.path.join( | ||||||||||||||||||||||||||||||||||||
| draft_model_last_checkpoint, "training_state.pt" | ||||||||||||||||||||||||||||||||||||
| training_state_path = get_training_state_path(draft_model_last_checkpoint) | ||||||||||||||||||||||||||||||||||||
| resume_state = torch.load( | ||||||||||||||||||||||||||||||||||||
| training_state_path, map_location="cpu", weights_only=False | ||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||
| print( | ||||||||||||||||||||||||||||||||||||
| f"Will resume from epoch {resume_state['epoch']}, " | ||||||||||||||||||||||||||||||||||||
| f"step {resume_state['global_step']}" | ||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||
|
Comment on lines
+483
to
490
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Removing the
Suggested change
|
||||||||||||||||||||||||||||||||||||
| if os.path.exists(training_state_path): | ||||||||||||||||||||||||||||||||||||
| resume_state = torch.load( | ||||||||||||||||||||||||||||||||||||
| training_state_path, map_location="cpu", weights_only=False | ||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||
| print( | ||||||||||||||||||||||||||||||||||||
| f"Will resume from epoch {resume_state['epoch']}, " | ||||||||||||||||||||||||||||||||||||
| f"step {resume_state['global_step']}" | ||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| tokenizer = AutoTokenizer.from_pretrained(args.target_model_path) | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
|
|
@@ -553,7 +555,7 @@ def main(): | |||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| if resume_state is not None: | ||||||||||||||||||||||||||||||||||||
| optimizer.scheduler.load_state_dict(resume_state["scheduler_state_dict"]) | ||||||||||||||||||||||||||||||||||||
| optimizer.load_state_dict(resume_state) | ||||||||||||||||||||||||||||||||||||
| start_epoch = resume_state["epoch"] | ||||||||||||||||||||||||||||||||||||
| global_step = resume_state["global_step"] | ||||||||||||||||||||||||||||||||||||
| del resume_state | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -48,6 +48,7 @@ | |||||||||||||||||||||||||||||||||||
| from specforge.utils import ( | ||||||||||||||||||||||||||||||||||||
| create_draft_config_from_target, | ||||||||||||||||||||||||||||||||||||
| get_last_checkpoint, | ||||||||||||||||||||||||||||||||||||
| get_training_state_path, | ||||||||||||||||||||||||||||||||||||
| print_args_with_dots, | ||||||||||||||||||||||||||||||||||||
| print_on_rank0, | ||||||||||||||||||||||||||||||||||||
| print_with_rank, | ||||||||||||||||||||||||||||||||||||
|
|
@@ -532,17 +533,14 @@ def build_draft_model(args: Namespace) -> Tuple[AutoDraftModelConfig, nn.Module] | |||||||||||||||||||||||||||||||||||
| # Load training state (optimizer, scheduler, epoch, step) for true resume | ||||||||||||||||||||||||||||||||||||
| resume_state = None | ||||||||||||||||||||||||||||||||||||
| if is_resume_checkpoint and draft_model_last_checkpoint: | ||||||||||||||||||||||||||||||||||||
| training_state_path = os.path.join( | ||||||||||||||||||||||||||||||||||||
| draft_model_last_checkpoint, "training_state.pt" | ||||||||||||||||||||||||||||||||||||
| training_state_path = get_training_state_path(draft_model_last_checkpoint) | ||||||||||||||||||||||||||||||||||||
| resume_state = torch.load( | ||||||||||||||||||||||||||||||||||||
| training_state_path, map_location="cpu", weights_only=False | ||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||
| print_on_rank0( | ||||||||||||||||||||||||||||||||||||
| f"Loaded training state from {training_state_path}: " | ||||||||||||||||||||||||||||||||||||
| f"epoch={resume_state['epoch']}, step={resume_state['global_step']}" | ||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||
|
Comment on lines
+536
to
543
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Removing the
Suggested change
|
||||||||||||||||||||||||||||||||||||
| if os.path.exists(training_state_path): | ||||||||||||||||||||||||||||||||||||
| resume_state = torch.load( | ||||||||||||||||||||||||||||||||||||
| training_state_path, map_location="cpu", weights_only=False | ||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||
| print_on_rank0( | ||||||||||||||||||||||||||||||||||||
| f"Loaded training state from {training_state_path}: " | ||||||||||||||||||||||||||||||||||||
| f"epoch={resume_state['epoch']}, step={resume_state['global_step']}" | ||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| draft_model.load_embedding(args.target_model_path, embedding_key=args.embedding_key) | ||||||||||||||||||||||||||||||||||||
| draft_model.freeze_embedding() | ||||||||||||||||||||||||||||||||||||
|
|
@@ -699,13 +697,10 @@ def save_checkpoints( | |||||||||||||||||||||||||||||||||||
| state_to_save.update(optimizer.state_dict()) | ||||||||||||||||||||||||||||||||||||
| draft_model_state_dict = filter_draft_state_dict(model_state_dict) | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| torch.save(state_to_save, get_training_state_path(epoch_output_dir)) | ||||||||||||||||||||||||||||||||||||
| if dist.get_rank() == 0: | ||||||||||||||||||||||||||||||||||||
| torch.save( | ||||||||||||||||||||||||||||||||||||
| state_to_save, | ||||||||||||||||||||||||||||||||||||
| os.path.join(epoch_output_dir, "training_state.pt"), | ||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||
| print_on_rank0( | ||||||||||||||||||||||||||||||||||||
| f"Saved full training state to {epoch_output_dir}/training_state.pt" | ||||||||||||||||||||||||||||||||||||
| f"Saved full training state to {epoch_output_dir}/training_state_rank_*.pt" | ||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||
| eagle3_model.draft_model.save_pretrained( | ||||||||||||||||||||||||||||||||||||
| epoch_output_dir, | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Removing the
os.path.existscheck for the training state file makes the resume process fragile. If a user attempts to resume training from a checkpoint that only contains model weights (or if the training state files were deleted/not saved), the script will crash with aFileNotFoundErrorinstead of gracefully falling back to starting with a fresh optimizer state. Restoring the existence check ensures robust fallback behavior.