Add --loss-aggregation for the four ScaleRL pg_loss aggregation modes #2090
Open
EazyReal wants to merge 1 commit into
Add a unified `--loss-aggregation {sample_mean,prompt_mean,token_mean,constant}`
(+ `--loss-aggregation-divisor L` for constant) selecting how pg_loss is
aggregated across a training step, riding the existing `sample_denoms` seam in
`get_sum_of_sample_mean`.
Modes follow the ScaleRL taxonomy (arXiv:2510.13786 §3.2):
- sample_mean (default): GRPO sample average — per-rollout token-weighted mean
via `rollout_mask_sums`. Byte-identical to the prior default (no extra batch
key in any non-prompt_mean mode).
- prompt_mean: DAPO prompt average — step-level `prompt_mask_sums` grouped by
Sample.group_index, built ONLY under prompt_mean and plumbed like
`rollout_mask_sums` (CP- and variable-GBS-correct). The other three modes
never read group_index and never build the key. A None group_index under
prompt_mean fails loud (the prompt-grouping invariant is broken; silently
renumbering it into a singleton group would degrade prompt_mean -> sample_mean
for that sample). Every prompt group enters the step sum once under the same
`/ step_global_batch_size` divisor, so relative per-prompt weighting is
uniform; absolute scale differs from a strict 1/P DAPO average by a constant
factor (P/N), which the learning rate absorbs.
- token_mean: token average — aliased onto `--calculate-per-token-loss` so the
whole loss-scaling/reporting path stays per-token.
- constant: Dr.GRPO (arXiv:2503.20783) — masked token sum / L via a new
`constant_divisor` branch in cp_utils.
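The four modes above can be compared on a toy step. This is a minimal sketch, not the PR's code: `aggregate_pg_loss` is a hypothetical helper, every token is assumed unmasked, `n` stands in for `step_global_batch_size`, and applying the `/ n` step divisor in the `constant` branch is an assumption.

```python
from collections import defaultdict


def aggregate_pg_loss(rollouts, mode, divisor=None):
    """Illustrative only: reduce per-token losses for one training step.

    rollouts: list of (group_index, per_token_losses). Every token is
    treated as unmasked, so a rollout's mask sum is its token count.
    """
    n = len(rollouts)  # stands in for step_global_batch_size
    sums = [sum(tokens) for _, tokens in rollouts]
    counts = [len(tokens) for _, tokens in rollouts]

    if mode == "sample_mean":
        # Each rollout contributes its own token mean.
        return sum(s / c for s, c in zip(sums, counts)) / n
    if mode == "prompt_mean":
        # Each rollout is divided by its prompt group's total token count.
        group_tokens = defaultdict(int)
        for (group, _), c in zip(rollouts, counts):
            group_tokens[group] += c
        return sum(s / group_tokens[g] for (g, _), s in zip(rollouts, sums)) / n
    if mode == "token_mean":
        # One global mean over every token in the step.
        return sum(sums) / sum(counts)
    if mode == "constant":
        # Assumption: the per-step "/ n" divisor also applies here.
        return sum(s / divisor for s in sums) / n
    raise ValueError(f"unknown mode: {mode}")


# Two prompts, two rollouts each, with uneven lengths.
rollouts = [
    (0, [1.0, 1.0]),
    (0, [2.0] * 6),
    (1, [3.0, 3.0]),
    (1, [4.0, 4.0]),
]
modes = ("sample_mean", "prompt_mean", "token_mean", "constant")
results = {m: aggregate_pg_loss(rollouts, m, divisor=8) for m in modes}

# Rescaling prompt_mean by N / P recovers a strict 1/P average over prompts.
P, N = 2, len(rollouts)
strict_dapo = results["prompt_mean"] * N / P
```

On this fixture the modes give 2.5 (`sample_mean`), 1.3125 (`prompt_mean`), about 2.333 (`token_mean`), and 0.875 (`constant` with `L = 8`). The strict per-prompt average is 2.625, which is `prompt_mean` rescaled by `N / P = 2`.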
Aggregation applies to pg_loss only (metrics keep sum_of_sample_mean);
`--custom-pg-loss-reducer-function-path` still takes precedence. `L` is validated
> 0 at startup only for constant. Supersedes the open THUDM#2060 `--pg-loss-divisor`.
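The startup validation described above might look roughly like the following sketch. The flag names come from this PR, but `parse_args` and `validate_loss_aggregation` are hypothetical names, and the real check lives in slime's Megatron argument validation, which may be structured differently.

```python
import argparse
import math


def validate_loss_aggregation(args):
    """Hypothetical validator: the divisor is checked only under constant."""
    if args.loss_aggregation == "constant":
        d = args.loss_aggregation_divisor
        if d is None or math.isnan(d) or d <= 0:
            raise ValueError(
                "--loss-aggregation constant requires "
                "--loss-aggregation-divisor > 0"
            )
    elif args.loss_aggregation == "token_mean":
        # token_mean is an alias: it switches the per-token loss path on.
        args.calculate_per_token_loss = True
    return args


def parse_args(argv):
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--loss-aggregation",
        choices=["sample_mean", "prompt_mean", "token_mean", "constant"],
        default="sample_mean",
    )
    parser.add_argument("--loss-aggregation-divisor", type=float, default=None)
    parser.add_argument("--calculate-per-token-loss", action="store_true")
    return validate_loss_aggregation(parser.parse_args(argv))


default_args = parse_args([])
token_args = parse_args(["--loss-aggregation", "token_mean"])
constant_args = parse_args(
    ["--loss-aggregation", "constant", "--loss-aggregation-divisor", "8192"]
)
```

With no flags the per-token path stays off, `token_mean` turns it on, and a missing, non-positive, or NaN divisor raises `ValueError` only when `constant` is selected.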
Tests: test_cp_utils.py pins constant divides by L, prompt_mean's per-group
denominator distinct from sample/token mean, the constant/per-token mutual-
exclusion guard, and CP rank-sum invariance for the new constant branch
(prompt_mean reuses the per-rollout sample-mean CP path already pinned).
test_megatron_argument_validation.py pins divisor validation and the token_mean
alias. test_rollout_validation.py pins the prompt_mask_sums build: fail-loud on a
None group_index under prompt_mean, no key built for the other three modes
(default batch unchanged), and correct per-group totals. Mutation-verified.
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Problem
slime aggregates `pg_loss` across a training step as a per-rollout token-weighted sample mean. Recent RL recipes deliberately choose a different aggregation: DAPO averages per prompt group, Dr.GRPO divides by a constant, and some recipes use a global per-token mean. ScaleRL (arXiv:2510.13786 §3.2) catalogs these as one knob and reports that the choice materially affects stability and final reward. Today the only escape hatch is `--custom-pg-loss-reducer-function-path` (write your own reducer per recipe); there is no first-class flag, and the open #2060 only adds a single Dr.GRPO `--pg-loss-divisor`.
What this adds
A single `--loss-aggregation {sample_mean,prompt_mean,token_mean,constant}` flag (plus `--loss-aggregation-divisor L` for `constant`) selecting how `pg_loss` is aggregated. Modes follow the ScaleRL taxonomy:
- `sample_mean` (default): GRPO sample average, the per-rollout token-weighted mean via `rollout_mask_sums`.
- `prompt_mean`: DAPO prompt average (rollouts with the same `Sample.group_index` share one denominator). ScaleRL's recommended default for new recipes.
- `token_mean`: token average, aliased onto `--calculate-per-token-loss` (which it sets).
- `constant`: Dr.GRPO, `sum(token_loss * loss_mask) / L`, with `L = --loss-aggregation-divisor` (e.g. the max context length).
Before / after
Before: the `pg_loss` reducer is the per-rollout `sum_of_sample_mean`. No flag.
After: the `pg_loss` reducer is chosen by `--loss-aggregation`. The default `sample_mean` returns the same reducer object, so an existing run's `pg_loss` is byte-identical (verified by `test_default_reduces_to_per_sample_mean` and the validation-side `test_loss_aggregation_default_leaves_per_token_loss_off`). `L` is validated `> 0` at startup, only for `constant`.
Why this shape
It rides the `sample_denoms` seam already added in #1933 to `cp_utils.get_sum_of_sample_mean`. That function takes pre-computed per-sample denominators that are CP-correct and remain correct when a rollout's samples are packed across micro-batches. The default path feeds `rollout_mask_sums` (per-rollout totals).
`prompt_mean` is the only new step-level computation: `prompt_mask_sums` (per-prompt-group mask totals grouped by `Sample.group_index`), computed in `RolloutManager` right beside `rollout_mask_sums` and plumbed through the identical path (`data.py` log filter, DP split, `model.py` pad list, actor GPU promotion). It is a per-sample broadcast of the whole-group total, so CP, the DP split, and micro-batch packing all sum the partial `(x·m)/D_group` terms to `(Σ x·m)/D_group`: the same correctness as `sample_mean`, for free.
A `Sample.group_index` of `None` under `prompt_mean` fails loud (raises `ValueError`). A `None` means the sample belongs to no prompt group, so its per-prompt denominator is undefined. Silently renumbering it into its own singleton group would degrade `prompt_mean` → `sample_mean` for that sample, which is a real break of the prompt-grouping invariant, not a benign special value.
`constant` is a small `constant_divisor` branch in `get_sum_of_sample_mean` (`sum_of_token(x) / L`). Because `L` is identical on every CP rank, Megatron's gradient sum-allreduce already yields the full-batch value.
`token_mean` aliases onto `--calculate-per-token-loss` so the loss-scaling and reporting path stays consistent rather than introducing a second per-token codepath.
`prompt_mask_sums` is computed only under `--loss-aggregation prompt_mean` (the mode that consumes it). The other three modes never read `Sample.group_index` and never build the key, so the default (and every non-`prompt_mean`) batch is unchanged: no extra batch key, no always-on group-aggregation compute or bandwidth. The reducer still checks `batch.get("prompt_mask_sums")` and fails loud if a custom convert path selected `prompt_mean` but dropped the key.
This keeps the new code minimal and reuses the verified reducer rather than adding a parallel aggregation stack.
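The broadcast-and-sum argument can be checked on a small example. This is a sketch under stated assumptions: `build_prompt_mask_sums` and `reduce_with_denoms` are hypothetical stand-ins for the `RolloutManager` build and for `get_sum_of_sample_mean` with `sample_denoms`, and the two CP "ranks" are simulated by cutting each sample into contiguous halves, which is simpler than Megatron's actual chunking.

```python
from collections import defaultdict


def build_prompt_mask_sums(group_indices, rollout_mask_sums):
    """Broadcast each prompt group's mask total to every sample in it."""
    if any(g is None for g in group_indices):
        # Fail loud: a None group_index has no defined per-prompt denominator.
        raise ValueError("prompt_mean requires group_index on every sample")
    totals = defaultdict(int)
    for g, mask_sum in zip(group_indices, rollout_mask_sums):
        totals[g] += mask_sum
    return [totals[g] for g in group_indices]


def reduce_with_denoms(losses, masks, denoms):
    """Sum over samples of (masked token sum) / (pre-computed denominator)."""
    return sum(
        sum(x * m for x, m in zip(xs, ms)) / d
        for xs, ms, d in zip(losses, masks, denoms)
    )


group_indices = [0, 0, 1]
losses = [[1.0, 2.0, 9.0, 3.0], [4.0, 5.0], [1.0, 1.0, 1.0, 1.0]]
masks = [[1, 1, 0, 1], [1, 1], [1, 1, 1, 1]]

rollout_mask_sums = [sum(m) for m in masks]  # per-rollout totals: [3, 2, 4]
prompt_mask_sums = build_prompt_mask_sums(group_indices, rollout_mask_sums)

full = reduce_with_denoms(losses, masks, prompt_mask_sums)


def half(seqs, rank):
    """Simulated CP split: rank 0 takes the first half of every sample."""
    return [s[: len(s) // 2] if rank == 0 else s[len(s) // 2 :] for s in seqs]


# Each rank reduces its own tokens against the same whole-group denominators.
per_rank = [
    reduce_with_denoms(half(losses, r), half(masks, r), prompt_mask_sums)
    for r in (0, 1)
]
```

Because the denominators are computed once from whole-group totals, the per-rank partial results add up to the full-sequence value (4.0 on this fixture, up to floating-point rounding).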
Scope: pg_loss only (deliberate)
Aggregation applies to `pg_loss` only. The diagnostic metrics (`pg_clipfrac`, `ppo_kl`, `entropy_loss`, `kl_loss`) keep the default sample-mean reducer so they stay interpretable and comparable across runs (e.g. a `constant` divisor `L` must not crush `ppo_kl` by the same factor and make it unreadable). This matches the existing scope of `--custom-pg-loss-reducer-function-path`, which still takes precedence when set. `token_mean` is the documented exception: because it reuses `--calculate-per-token-loss`, it is per-token everywhere.
Alternative for reviewers: one could apply the chosen aggregation uniformly to the metrics too (single normalizer for loss and diagnostics). We chose not to, to preserve metric comparability across aggregation modes; this is a one-line change at the call site if the project prefers the uniform convention.
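A rough numeric illustration of the `ppo_kl` concern, using made-up per-token values and hypothetical helper names: under a sample mean the metric reads as a per-token KL, while dividing by a large constant ties it to response length and `L`.

```python
def sample_mean(per_rollout_values):
    """Mean over rollouts of each rollout's per-token mean."""
    return sum(sum(v) / len(v) for v in per_rollout_values) / len(per_rollout_values)


def constant_divided(per_rollout_values, divisor):
    """Mean over rollouts of each rollout's token sum divided by a constant."""
    return sum(sum(v) / divisor for v in per_rollout_values) / len(per_rollout_values)


# Two rollouts with the same per-token KL but different lengths (illustrative).
ppo_kl_tokens = [[0.01] * 100, [0.01] * 200]

kl_sample_mean = sample_mean(ppo_kl_tokens)                   # about 0.01
kl_constant = constant_divided(ppo_kl_tokens, divisor=8192)   # about 1.8e-4
```

The second number depends on both the response lengths and the chosen `L`, so it would not be comparable across runs that change either one.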
Honesty: prompt_mean absolute scale
`prompt_mean` weights every prompt group equally: each group's token-weighted mean enters the step sum once, all under the same `/ step_global_batch_size` divisor, so the relative per-prompt weighting (the property DAPO is about) is exact. Its absolute scale differs from a strict `1/P` DAPO average by a constant factor (`P / N`, prompts over rollouts), which the learning rate absorbs. Documented in `docs/en/get_started/customization.md` and the flag help.
Tests
- `tests/test_cp_utils.py`: `constant` divides the masked token-sum by `L`; `prompt_mean`'s per-group denominator is distinct from the per-rollout (`sample_mean`) and per-token (`token_mean`) results on uneven fixtures; the `constant` / `--calculate-per-token-loss` mutual-exclusion guard; CP rank-sum invariance for the new `constant` branch (`prompt_mean` reuses the per-rollout sample-mean CP path already pinned in the same file).
- `tests/test_megatron_argument_validation.py`: `--loss-aggregation-divisor` rejected when missing / non-positive / NaN under `constant`, accepted when positive; `token_mean` aliases `--calculate-per-token-loss`; default leaves it off.
- `tests/test_rollout_validation.py`: the `prompt_mask_sums` build in `_convert_samples_to_train_data` fails loud on a `None` `group_index` under `prompt_mean`, builds the correct per-prompt-group totals, and (for `sample_mean` / `constant` / `token_mean`) neither consults `group_index` nor adds the `prompt_mask_sums` key, keeping the non-`prompt_mean` batch unchanged. Mutation-verified (removing the fail-loud guard fails the test).
Supersedes #2060.