Skip to content

Add --loss-aggregation for the four ScaleRL pg_loss aggregation modes#2090

Open
EazyReal wants to merge 1 commit into
THUDM:mainfrom
EazyReal:upstream-pr/loss-aggregation-modes
Open

Add --loss-aggregation for the four ScaleRL pg_loss aggregation modes#2090
EazyReal wants to merge 1 commit into
THUDM:mainfrom
EazyReal:upstream-pr/loss-aggregation-modes

Conversation

@EazyReal

@EazyReal EazyReal commented Jun 16, 2026

Copy link
Copy Markdown
Contributor

Problem

slime aggregates pg_loss across a training step as a per-rollout token-weighted sample mean. Recent RL recipes deliberately choose a different aggregation: DAPO averages per prompt group, Dr.GRPO divides by a constant, and some recipes use a global per-token mean. ScaleRL (arXiv:2510.13786 §3.2) catalogs these as one knob and reports the choice materially affects stability and final reward. Today the only escape hatch is --custom-pg-loss-reducer-function-path (write your own reducer per recipe); there is no first-class flag, and the open #2060 only adds a single Dr.GRPO --pg-loss-divisor.

What this adds

A single --loss-aggregation {sample_mean,prompt_mean,token_mean,constant} flag (plus --loss-aggregation-divisor L for constant) selecting how pg_loss is aggregated. Modes follow the ScaleRL taxonomy:

Mode Paper pg_loss denominator
sample_mean (default) GRPO sample average Per-rollout token-weighted mean (each rollout contributes equally regardless of fan-out). Byte-identical to slime's prior default.
prompt_mean DAPO prompt average Per-prompt-group token-weighted mean (all rollouts sharing a Sample.group_index share one denominator). ScaleRL's recommended default for new recipes.
token_mean token average Global per-token mean. Equivalent to --calculate-per-token-loss (which it sets).
constant Dr.GRPO (arXiv:2503.20783) sum(token_loss * loss_mask) / L, L = --loss-aggregation-divisor (e.g. the max context length).

Before / after

  • Before: the pg_loss reducer is the per-rollout sum_of_sample_mean. No flag.
  • After: the pg_loss reducer is chosen by --loss-aggregation. The default sample_mean returns the same reducer object — so an existing run's pg_loss is byte-identical (verified by test_default_reduces_to_per_sample_mean and the validation-side test_loss_aggregation_default_leaves_per_token_loss_off). L is validated > 0 at startup, only for constant.

Why this shape

It rides the sample_denoms seam already added in #1933 to cp_utils.get_sum_of_sample_mean. That function takes pre-computed per-sample denominators that are CP-correct and remain correct when a rollout's samples are packed across micro-batches. The default path feeds rollout_mask_sums (per-rollout totals).

  • prompt_mean is the only new step-level computation: prompt_mask_sums (per-prompt-group mask totals grouped by Sample.group_index), computed in RolloutManager right beside rollout_mask_sums and plumbed through the identical path (data.py log filter, DP split, model.py pad list, actor GPU promotion). It is a per-sample broadcast of the whole-group total, so CP, the DP split, and micro-batch packing all sum partial (x·m)/D_group to (Σx·m)/D_group — same correctness as sample_mean, for free. A Sample.group_index of None under prompt_mean fails loud (raises ValueError): a None means the sample belongs to no prompt group, so its per-prompt denominator is undefined — silently renumbering it into its own singleton group would degrade prompt_meansample_mean for that sample, which is a real break of the prompt-grouping invariant, not a benign special value.
  • constant is a small constant_divisor branch in get_sum_of_sample_mean (sum_of_token(x) / L); being identical on every CP rank, Megatron's gradient sum-allreduce already yields the full-batch value.
  • token_mean aliases onto --calculate-per-token-loss so the loss-scaling and reporting path stays consistent rather than introducing a second per-token codepath.

prompt_mask_sums is computed only under --loss-aggregation prompt_mean (the mode that consumes it); the other three modes never read Sample.group_index and never build the key, so the default (and every non-prompt_mean) batch is unchanged — no extra batch key, no always-on group-aggregation compute/bandwidth. The reducer still batch.get("prompt_mask_sums")-checks and fails loud if a custom convert path selected prompt_mean but dropped the key.

This keeps the new code minimal and reuses the verified reducer rather than adding a parallel aggregation stack.

Scope: pg_loss only (deliberate)

Aggregation applies to pg_loss only. The diagnostic metrics — pg_clipfrac, ppo_kl, entropy_loss, kl_loss — keep the default sample-mean reducer so they stay interpretable and comparable across runs (e.g. a constant /L must not crush ppo_kl by the same factor and make it unreadable). This matches the existing scope of --custom-pg-loss-reducer-function-path, which still takes precedence when set. token_mean is the documented exception: because it reuses --calculate-per-token-loss, it is per-token everywhere.

Alternative for reviewers: one could apply the chosen aggregation uniformly to the metrics too (single normalizer for loss and diagnostics). We chose not to, to preserve metric comparability across aggregation modes; this is a one-line change at the call site if the project prefers the uniform convention.

Honesty: prompt_mean absolute scale

prompt_mean weights every prompt group equally — each group's token-weighted mean enters the step sum once, all under the same / step_global_batch_size divisor — so the relative per-prompt weighting (the property DAPO is about) is exact. Its absolute scale differs from a strict 1/P DAPO average by a constant factor (P / N, prompts over rollouts), which the learning rate absorbs. Documented in docs/en/get_started/customization.md and the flag help.

Tests

  • tests/test_cp_utils.py: constant divides the masked token-sum by L; prompt_mean's per-group denominator is distinct from the per-rollout (sample_mean) and per-token (token_mean) results on uneven fixtures; the constant / --calculate-per-token-loss mutual-exclusion guard; CP rank-sum invariance for the new constant branch (prompt_mean reuses the per-rollout sample-mean CP path already pinned in the same file).
  • tests/test_megatron_argument_validation.py: --loss-aggregation-divisor rejected when missing / non-positive / NaN under constant, accepted when positive; token_mean aliases --calculate-per-token-loss; default leaves it off.
  • tests/test_rollout_validation.py: the prompt_mask_sums build in _convert_samples_to_train_data — fails loud on a None group_index under prompt_mean, builds the correct per-prompt-group totals, and (for sample_mean / constant / token_mean) neither consults group_index nor adds the prompt_mask_sums key, keeping the non-prompt_mean batch unchanged. Mutation-verified (removing the fail-loud guard fails the test).

Supersedes #2060.

Add a unified `--loss-aggregation {sample_mean,prompt_mean,token_mean,constant}`
(+ `--loss-aggregation-divisor L` for constant) selecting how pg_loss is
aggregated across a training step, riding the existing `sample_denoms` seam in
`get_sum_of_sample_mean`.

Modes follow the ScaleRL taxonomy (arXiv:2510.13786 §3.2):
- sample_mean (default): GRPO sample average — per-rollout token-weighted mean
  via `rollout_mask_sums`. Byte-identical to the prior default (no extra batch
  key in any non-prompt_mean mode).
- prompt_mean: DAPO prompt average — step-level `prompt_mask_sums` grouped by
  Sample.group_index, built ONLY under prompt_mean and plumbed like
  `rollout_mask_sums` (CP- and variable-GBS-correct). The other three modes
  never read group_index and never build the key. A None group_index under
  prompt_mean fails loud (the prompt-grouping invariant is broken; silently
  renumbering it into a singleton group would degrade prompt_mean -> sample_mean
  for that sample). Every prompt group enters the step sum once under the same
  `/ step_global_batch_size` divisor, so relative per-prompt weighting is
  uniform; absolute scale differs from a strict 1/P DAPO average by a constant
  factor (P/N), which the learning rate absorbs.
- token_mean: token average — aliased onto `--calculate-per-token-loss` so the
  whole loss-scaling/reporting path stays per-token.
- constant: Dr.GRPO (arXiv:2503.20783) — masked token sum / L via a new
  `constant_divisor` branch in cp_utils.

Aggregation applies to pg_loss only (metrics keep sum_of_sample_mean);
`--custom-pg-loss-reducer-function-path` still takes precedence. `L` is validated
> 0 at startup only for constant. Supersedes the open THUDM#2060 `--pg-loss-divisor`.

Tests: test_cp_utils.py pins constant divides by L, prompt_mean's per-group
denominator distinct from sample/token mean, the constant/per-token mutual-
exclusion guard, and CP rank-sum invariance for the new constant branch
(prompt_mean reuses the per-rollout sample-mean CP path already pinned).
test_megatron_argument_validation.py pins divisor validation and the token_mean
alias. test_rollout_validation.py pins the prompt_mask_sums build: fail-loud on a
None group_index under prompt_mean, no key built for the other three modes
(default batch unchanged), and correct per-group totals. Mutation-verified.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
@EazyReal EazyReal force-pushed the upstream-pr/loss-aggregation-modes branch from fcd9aa9 to ccbb81a Compare June 16, 2026 08:34
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant