Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 23 additions & 21 deletions scripts/train_dflash.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,12 @@
from specforge.modeling.target.target_utils import TargetEmbeddingsAndHead
from specforge.optimizer import BF16Optimizer
from specforge.tracker import create_tracker
from specforge.utils import get_last_checkpoint, print_on_rank0, print_with_rank
from specforge.utils import (
get_last_checkpoint,
get_training_state_path,
print_on_rank0,
print_with_rank,
)


def parse_args():
Expand Down Expand Up @@ -306,17 +311,17 @@ def save_checkpoint(args, epoch, step, dflash_model, draft_model, optimizer):
if "draft_model." in k
}

if dist.get_rank() == 0:
torch.save(
{
"epoch": epoch,
"global_step": step,
"args": args,
**optimizer.state_dict(),
},
os.path.join(save_dir, "training_state.pt"),
)
torch.save(
{
"epoch": epoch,
"global_step": step,
"args": args,
**optimizer.state_dict(),
},
get_training_state_path(save_dir),
)

if dist.get_rank() == 0:
draft_model.save_pretrained(save_dir, state_dict=draft_state_dict)

modeling_src = os.path.join(
Expand Down Expand Up @@ -406,17 +411,14 @@ def main():
del loaded_model
print("Loaded draft model weights from checkpoint")

training_state_path = os.path.join(
draft_model_last_checkpoint, "training_state.pt"
training_state_path = get_training_state_path(draft_model_last_checkpoint)
resume_state = torch.load(
training_state_path, map_location="cpu", weights_only=False
)
print(
f"Will resume from epoch {resume_state['epoch']}, "
f"step {resume_state['global_step']}"
)
Comment on lines +414 to 421

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Removing the os.path.exists check for the training state file makes the resume process fragile. If a user attempts to resume training from a checkpoint that only contains model weights (or if the training state files were deleted/not saved), the script will crash with a FileNotFoundError instead of gracefully falling back to starting with a fresh optimizer state. Restoring the existence check ensures robust fallback behavior.

Suggested change
training_state_path = get_training_state_path(draft_model_last_checkpoint)
resume_state = torch.load(
training_state_path, map_location="cpu", weights_only=False
)
print(
f"Will resume from epoch {resume_state['epoch']}, "
f"step {resume_state['global_step']}"
)
training_state_path = get_training_state_path(draft_model_last_checkpoint)
if os.path.exists(training_state_path):
resume_state = torch.load(
training_state_path, map_location="cpu", weights_only=False
)
print(
f"Will resume from epoch {resume_state['epoch']}, "
f"step {resume_state['global_step']}"
)

if os.path.exists(training_state_path):
resume_state = torch.load(
training_state_path, map_location="cpu", weights_only=False
)
print(
f"Will resume from epoch {resume_state['epoch']}, "
f"step {resume_state['global_step']}"
)

tokenizer = AutoTokenizer.from_pretrained(args.target_model_path)

Expand Down
46 changes: 24 additions & 22 deletions scripts/train_domino.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,12 @@
from specforge.modeling.target.target_utils import TargetEmbeddingsAndHead
from specforge.optimizer import BF16Optimizer
from specforge.tracker import create_tracker
from specforge.utils import get_last_checkpoint, print_on_rank0, print_with_rank
from specforge.utils import (
get_last_checkpoint,
get_training_state_path,
print_on_rank0,
print_with_rank,
)


def parse_args():
Expand Down Expand Up @@ -324,17 +329,17 @@ def save_checkpoint(args, epoch, step, domino_model, draft_model, optimizer):
if "draft_model." in k
}

if dist.get_rank() == 0:
torch.save(
{
"epoch": epoch,
"global_step": step,
"args": args,
**optimizer.state_dict(),
},
os.path.join(save_dir, "training_state.pt"),
)
torch.save(
{
"epoch": epoch,
"global_step": step,
"args": args,
**optimizer.state_dict(),
},
get_training_state_path(save_dir),
)

if dist.get_rank() == 0:
draft_model.save_pretrained(save_dir, state_dict=draft_state_dict)

modeling_src = os.path.join(
Expand Down Expand Up @@ -475,17 +480,14 @@ def main():
del loaded_model
print("Loaded draft model weights from checkpoint")

training_state_path = os.path.join(
draft_model_last_checkpoint, "training_state.pt"
training_state_path = get_training_state_path(draft_model_last_checkpoint)
resume_state = torch.load(
training_state_path, map_location="cpu", weights_only=False
)
print(
f"Will resume from epoch {resume_state['epoch']}, "
f"step {resume_state['global_step']}"
)
Comment on lines +483 to 490

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Removing the os.path.exists check for the training state file makes the resume process fragile. If a user attempts to resume training from a checkpoint that only contains model weights (or if the training state files were deleted/not saved), the script will crash with a FileNotFoundError instead of gracefully falling back to starting with a fresh optimizer state. Restoring the existence check ensures robust fallback behavior.

Suggested change
training_state_path = get_training_state_path(draft_model_last_checkpoint)
resume_state = torch.load(
training_state_path, map_location="cpu", weights_only=False
)
print(
f"Will resume from epoch {resume_state['epoch']}, "
f"step {resume_state['global_step']}"
)
training_state_path = get_training_state_path(draft_model_last_checkpoint)
if os.path.exists(training_state_path):
resume_state = torch.load(
training_state_path, map_location="cpu", weights_only=False
)
print(
f"Will resume from epoch {resume_state['epoch']}, "
f"step {resume_state['global_step']}"
)

if os.path.exists(training_state_path):
resume_state = torch.load(
training_state_path, map_location="cpu", weights_only=False
)
print(
f"Will resume from epoch {resume_state['epoch']}, "
f"step {resume_state['global_step']}"
)

tokenizer = AutoTokenizer.from_pretrained(args.target_model_path)

Expand Down Expand Up @@ -553,7 +555,7 @@ def main():
)

if resume_state is not None:
optimizer.scheduler.load_state_dict(resume_state["scheduler_state_dict"])
optimizer.load_state_dict(resume_state)
start_epoch = resume_state["epoch"]
global_step = resume_state["global_step"]
del resume_state
Expand Down
25 changes: 10 additions & 15 deletions scripts/train_eagle3.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
from specforge.utils import (
create_draft_config_from_target,
get_last_checkpoint,
get_training_state_path,
print_args_with_dots,
print_on_rank0,
print_with_rank,
Expand Down Expand Up @@ -532,17 +533,14 @@ def build_draft_model(args: Namespace) -> Tuple[AutoDraftModelConfig, nn.Module]
# Load training state (optimizer, scheduler, epoch, step) for true resume
resume_state = None
if is_resume_checkpoint and draft_model_last_checkpoint:
training_state_path = os.path.join(
draft_model_last_checkpoint, "training_state.pt"
training_state_path = get_training_state_path(draft_model_last_checkpoint)
resume_state = torch.load(
training_state_path, map_location="cpu", weights_only=False
)
print_on_rank0(
f"Loaded training state from {training_state_path}: "
f"epoch={resume_state['epoch']}, step={resume_state['global_step']}"
)
Comment on lines +536 to 543

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Removing the os.path.exists check for the training state file makes the resume process fragile. If a user attempts to resume training from a checkpoint that only contains model weights (or if the training state files were deleted/not saved), the script will crash with a FileNotFoundError instead of gracefully falling back to starting with a fresh optimizer state. Restoring the existence check ensures robust fallback behavior.

Suggested change
training_state_path = get_training_state_path(draft_model_last_checkpoint)
resume_state = torch.load(
training_state_path, map_location="cpu", weights_only=False
)
print_on_rank0(
f"Loaded training state from {training_state_path}: "
f"epoch={resume_state['epoch']}, step={resume_state['global_step']}"
)
training_state_path = get_training_state_path(draft_model_last_checkpoint)
if os.path.exists(training_state_path):
resume_state = torch.load(
training_state_path, map_location="cpu", weights_only=False
)
print_on_rank0(
f"Loaded training state from {training_state_path}: "
f"epoch={resume_state['epoch']}, step={resume_state['global_step']}"
)

if os.path.exists(training_state_path):
resume_state = torch.load(
training_state_path, map_location="cpu", weights_only=False
)
print_on_rank0(
f"Loaded training state from {training_state_path}: "
f"epoch={resume_state['epoch']}, step={resume_state['global_step']}"
)

draft_model.load_embedding(args.target_model_path, embedding_key=args.embedding_key)
draft_model.freeze_embedding()
Expand Down Expand Up @@ -699,13 +697,10 @@ def save_checkpoints(
state_to_save.update(optimizer.state_dict())
draft_model_state_dict = filter_draft_state_dict(model_state_dict)

torch.save(state_to_save, get_training_state_path(epoch_output_dir))
if dist.get_rank() == 0:
torch.save(
state_to_save,
os.path.join(epoch_output_dir, "training_state.pt"),
)
print_on_rank0(
f"Saved full training state to {epoch_output_dir}/training_state.pt"
f"Saved full training state to {epoch_output_dir}/training_state_rank_*.pt"
)
eagle3_model.draft_model.save_pretrained(
epoch_output_dir,
Expand Down
1 change: 1 addition & 0 deletions specforge/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ def state_dict(self):
return {
"optimizer_state_dict": self.optimizer.state_dict(),
"scheduler_state_dict": self.scheduler.state_dict(),
"lr": self.get_learning_rate(),
}

def get_learning_rate(self):
Expand Down
5 changes: 5 additions & 0 deletions specforge/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,11 @@ def sort_key(x):
return os.path.join(folder, last_checkpoint), (epoch, step)


def get_training_state_path(checkpoint_dir: str) -> str:
    """Return the path of the calling rank's training-state file.

    The file name embeds the process rank
    (``training_state_rank_<rank>.pt``), so each rank resolves to its own
    file inside the same checkpoint directory.

    Args:
        checkpoint_dir: Directory that holds (or will hold) the checkpoint.

    Returns:
        ``checkpoint_dir`` joined with ``training_state_rank_<rank>.pt``.
    """
    # Outside an initialized process group (torch.distributed unavailable, or
    # init_process_group not called yet) fall back to rank 0 so that
    # single-process runs still resolve to a valid file name.
    if dist.is_available() and dist.is_initialized():
        rank = dist.get_rank()
    else:
        rank = 0
    return os.path.join(checkpoint_dir, f"training_state_rank_{rank}.pt")


def generate_draft_model_config(
target_model_path: str, template_config_path: str = None, cache_dir: str = None
):
Expand Down