[megatron] don't re-assert no_sync_func every step with overlap_grad_reduce #2066

Draft
HaozheZhang6 wants to merge 1 commit into THUDM:main from HaozheZhang6:fix/overlap-grad-reduce-no-sync-assert

Summary

train() in slime/backends/megatron_utils/model.py sets up config.no_sync_func on every step, but config comes from get_model_config(model[0]), so it is the model's own config object and persists across steps. With --overlap-grad-reduce, the first step passes assert config.no_sync_func is None and sets the field; the second step then trips that same assert and crashes.

assert config.no_sync_func is None, (...)   # step 1: None -> ok, sets it
                                            # step 2: not None -> AssertionError
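
The failure mode can be shown without Megatron. This is a minimal sketch, not the actual train() code: FakeChunk, setup_sync_funcs_unguarded, and run_steps are hypothetical stand-ins, and SimpleNamespace plays the role of the persistent model config.

```python
from types import SimpleNamespace


class FakeChunk:
    """Hypothetical stand-in for a model chunk that exposes no_sync."""

    def no_sync(self):
        return None


def setup_sync_funcs_unguarded(config, model):
    # Mirrors the failing pattern: assert, then assign, on every call.
    assert config.no_sync_func is None, "no_sync_func is already set"
    config.no_sync_func = [chunk.no_sync for chunk in model]


def run_steps(setup, num_steps):
    # The config object is created once and reused, like the model config.
    config = SimpleNamespace(no_sync_func=None)
    model = [FakeChunk()]
    for step in range(1, num_steps + 1):
        try:
            setup(config, model)
        except AssertionError:
            return step  # first step at which the assert fires
    return None


failed_at = run_steps(setup_sync_funcs_unguarded, 3)
print(failed_at)  # 2
```

The assert fires on the second call because the first call left the field populated on the shared config.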

Fix

Guard the setup with if config.no_sync_func is None: so the sync funcs are set once. They're constant (the model chunks' own no_sync / start_grad_sync), so skipping the setup on later steps changes nothing; the only effect of re-running it every step was tripping the assert.
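
A rough sketch of the guarded setup, under the same caveat: FakeChunk and setup_sync_funcs_guarded are hypothetical names, the real code lives in train(), and setting grad_sync_func alongside no_sync_func is an assumption based on the mention of start_grad_sync above.

```python
from types import SimpleNamespace


class FakeChunk:
    """Hypothetical stand-in for a model chunk."""

    def no_sync(self):
        return None

    def start_grad_sync(self):
        return None


def setup_sync_funcs_guarded(config, model):
    # Install the sync funcs only when nothing has been installed yet.
    if config.no_sync_func is None:
        config.no_sync_func = [chunk.no_sync for chunk in model]
        # Assumption: the grad-sync hooks are set in the same block.
        config.grad_sync_func = [chunk.start_grad_sync for chunk in model]


config = SimpleNamespace(no_sync_func=None, grad_sync_func=None)
model = [FakeChunk(), FakeChunk()]

setup_sync_funcs_guarded(config, model)  # step 1: installs the funcs
installed = config.no_sync_func

for _ in range(3):  # later steps: the guard skips the setup
    setup_sync_funcs_guarded(config, model)

# The list installed on step 1 is still the one on the config.
assert config.no_sync_func is installed
```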

Reproduction

Run any training with --overlap-grad-reduce (from #1779) — it crashes on the 2nd step without this change.

Notes

  • I couldn't add an automated test: reproducing needs a multi-GPU Megatron training run with --overlap-grad-reduce, which I can't stand up locally. The fix is small and verified by inspection against the reported repro; happy to add a test if you can point me at a lightweight harness for this path.
  • Opening as a draft to confirm the approach: the original assert also doubled as a guard against a pre-supplied custom no_sync_func. Under the is None guard it becomes redundant, and a hypothetical pre-set custom func would now be left as-is rather than rejected. If you'd rather keep that rejection explicit (e.g. assert only on first setup), I'll adjust.
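
One possible shape for the "assert only on first setup" alternative is sketched below. It is only an illustration: the _sync_funcs_installed marker is a hypothetical attribute, and whether the real config object accepts an extra attribute like this is an assumption.

```python
from types import SimpleNamespace


class FakeChunk:
    """Hypothetical stand-in for a model chunk that exposes no_sync."""

    def no_sync(self):
        return None


def setup_sync_funcs_strict(config, model):
    # Hypothetical marker: set only by this function, so later steps return early.
    if getattr(config, "_sync_funcs_installed", False):
        return
    # First setup: still reject a custom no_sync_func supplied by the caller.
    assert config.no_sync_func is None, "custom no_sync_func is not supported here"
    config.no_sync_func = [chunk.no_sync for chunk in model]
    config._sync_funcs_installed = True


model = [FakeChunk()]

# Repeated steps on a fresh config no longer trip the assert.
config = SimpleNamespace(no_sync_func=None)
for _ in range(3):
    setup_sync_funcs_strict(config, model)

# A config that arrives with a custom func is still rejected.
custom = SimpleNamespace(no_sync_func=lambda: None)
try:
    setup_sync_funcs_strict(custom, model)
    rejected = False
except AssertionError:
    rejected = True
```

The trade-off is one extra piece of state on the config in exchange for keeping the original rejection behaviour.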

Fixes #1779

…reduce

`train()` sets up `config.no_sync_func` on every step, but `config` is the
model config and persists across steps. With `--overlap-grad-reduce` the first
step sets it, then the second step trips `assert config.no_sync_func is None`
and crashes. Guard the setup with `if config.no_sync_func is None:` so the sync
funcs are set once (they are constant, so skipping later steps is a no-op).

Fixes THUDM#1779

Development

Successfully merging this pull request may close these issues.

[Bug] assert config.no_sync_func is None always gets hit second step
