6 changes: 5 additions & 1 deletion training/bf16_master_weight/train.py

@@ -292,7 +292,11 @@ def main():
input_ids = torch.randint(0, actual_vocab_size, (args.batch_size, args.seq_length), device=device)
labels = torch.randint(0, actual_vocab_size, (args.batch_size, args.seq_length), device=device)

# Forward pass with optional autocast.
# DeepSpeed already applies torch.autocast inside engine.forward(), but
# we wrap the entire forward+loss block so that loss_fn also runs under
# autocast. The nested autocast on engine.forward() is harmless:
# nesting torch.autocast with the same device type and dtype is a no-op.
if use_autocast:
with torch.autocast(device_type="cuda", dtype=autocast_dtype):
logits = model_engine(input_ids)
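The safety of the wrapping above rests on nested autocast being benign. A minimal sketch (not part of this PR; uses CPU autocast so it runs without a GPU, whereas the diff uses `device_type="cuda"`) showing that an inner `torch.autocast` with the same device type and dtype does not change the behavior established by the outer one:

```python
import torch

a = torch.randn(4, 4)
b = torch.randn(4, 4)

with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    # matmul is autocast-eligible, so it runs in bfloat16 here
    outer = a @ b
    with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
        # nested context with identical settings: same result dtype
        inner = a @ b

assert outer.dtype == torch.bfloat16
assert inner.dtype == torch.bfloat16
```

This mirrors the PR's situation: the outer context is the explicit `with torch.autocast(...)` in `train.py`, and the inner one is the autocast DeepSpeed applies inside `engine.forward()`.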