Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 31 additions & 4 deletions docs/en/get_started/customization.md
Original file line number Diff line number Diff line change
Expand Up @@ -295,10 +295,37 @@ def get_pg_loss_reducer(
```

**Use Cases**:
- Dr.GRPO: Divide by a constant instead of effective token count
- Custom loss normalization strategies

**Example**: `examples/DrGRPO/custom_reducer.py:get_pg_loss_reducer`
- Custom loss normalization strategies not covered by `--loss-aggregation`

> The four standard loss-aggregation modes (GRPO sample average, DAPO prompt
> average, token average, Dr.GRPO constant divisor) are available first-class via
> `--loss-aggregation` (see below) — no custom reducer needed. Reach for this
> hook only for a normalization those modes do not express. When set, it takes
> precedence over `--loss-aggregation`.

**Built-in modes — `--loss-aggregation`**

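`--loss-aggregation {sample_mean,prompt_mean,token_mean,constant}` selects how
pg_loss is aggregated across a training step. It applies to pg_loss only — the
same scope as the custom hook above. Every other metric keeps the default
reducer, which is the sample mean unless `--calculate-per-token-loss` is in
effect (`token_mean` turns it on, so under `token_mean` the other metrics use
the per-token path as well).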
Modes follow the ScaleRL taxonomy ([arXiv:2510.13786](https://arxiv.org/abs/2510.13786) §3.2):

| Mode | Paper | pg_loss denominator |
| :--- | :--- | :--- |
| `sample_mean` (default) | GRPO sample average | Per-rollout token-weighted mean — each rollout contributes equally regardless of fan-out. Byte-identical to slime's prior default. |
| `prompt_mean` | DAPO prompt average | Per-prompt-group token-weighted mean (all rollouts sharing a prompt share one denominator). ScaleRL's recommended default for new recipes. |
| `token_mean` | token average | Global per-token mean. Equivalent to `--calculate-per-token-loss` (which it sets). |
| `constant` | Dr.GRPO ([arXiv:2503.20783](https://arxiv.org/abs/2503.20783)) | `sum(token_loss * loss_mask) / L`, where `L = --loss-aggregation-divisor` (e.g. the max context length). |

`--loss-aggregation-divisor L` is required (validated `> 0` at startup) only for
`constant`; it is ignored for the other modes. The default (`sample_mean`) leaves
behavior unchanged.

`prompt_mean` weights every prompt group equally (each group's token-weighted
mean enters the step sum once, all under the same `/ step_global_batch_size`
divisor). Its absolute scale differs from a strict `1/P` DAPO average by a
constant factor (`P / N`, prompts over rollouts), which the learning rate
absorbs; the relative per-prompt weighting is uniform.

---

Expand Down
11 changes: 5 additions & 6 deletions slime/backends/megatron_utils/actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,12 +235,11 @@ def _get_rollout_data(self, rollout_data_ref: Box) -> RolloutBatch:
rollout_data["loss_masks"] = [
t.to(device=device, dtype=torch.int, non_blocking=True) for t in rollout_data["loss_masks"]
]
if "rollout_mask_sums" in rollout_data:
# Promote precomputed per-rollout mask totals to GPU tensors here
# (matching loss_masks) so the loss reducer can just divide.
rollout_data["rollout_mask_sums"] = rollout_data["rollout_mask_sums"].to(
device=device, dtype=torch.float32, non_blocking=True
)
for mask_sums_key in ("rollout_mask_sums", "prompt_mask_sums"):
if mask_sums_key in rollout_data:
rollout_data[mask_sums_key] = rollout_data[mask_sums_key].to(
device=device, dtype=torch.float32, non_blocking=True
)
if "multimodal_train_inputs" in rollout_data:
# Move multimodal training tensors to GPU in advance
rollout_data["multimodal_train_inputs"] = [
Expand Down
18 changes: 18 additions & 0 deletions slime/backends/megatron_utils/cp_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ def get_sum_of_sample_mean(
calculate_per_token_loss: bool = False,
qkv_format: str = "thd",
max_seq_lens: list[int] | None = None,
constant_divisor: float | None = None,
) -> Callable[[torch.Tensor], torch.Tensor]:
"""
Calculate correct sample mean for CP.
Expand All @@ -71,7 +72,17 @@ def get_sum_of_sample_mean(
step level rather than per-mb is required — otherwise a rollout whose
samples land in different micro-batches would get a partial denominator
on each side.

``constant_divisor`` (Dr.GRPO, arXiv:2503.20783) divides the masked
token-sum by a fixed ``L``; being identical on every CP rank, Megatron's
gradient sum-allreduce already yields the full-batch value (no extra
all-reduce). Mutually exclusive with ``calculate_per_token_loss``.
"""
if constant_divisor is not None and calculate_per_token_loss:
raise ValueError(
"constant_divisor (loss-aggregation=constant) and calculate_per_token_loss "
"(loss-aggregation=token_mean) are mutually exclusive aggregation modes."
)
if sample_denoms is None:
sample_denoms = [m.sum() for m in loss_masks]

Expand Down Expand Up @@ -133,6 +144,13 @@ def sum_of_token(x: torch.Tensor) -> torch.Tensor:
]
)

if constant_divisor is not None:

def sum_of_constant(x: torch.Tensor) -> torch.Tensor:
return sum_of_token(x) / constant_divisor

return sum_of_constant

return sum_of_sample_mean if not calculate_per_token_loss else sum_of_token


Expand Down
1 change: 1 addition & 0 deletions slime/backends/megatron_utils/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,7 @@ def log_rollout_data(
"sample_indices",
"rollout_ids",
"rollout_mask_sums",
"prompt_mask_sums",
"rollout_routed_experts",
"max_seq_lens",
"global_batch_sizes",
Expand Down
67 changes: 63 additions & 4 deletions slime/backends/megatron_utils/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -798,6 +798,56 @@ def icepop_function(
return pg_loss, loss_masks, metrics


def get_pg_loss_reducer(
args: Namespace,
batch: RolloutBatch,
*,
total_lengths: list[int],
response_lengths: list[int],
pg_loss_masks: list[torch.Tensor],
max_seq_lens: list[int] | None,
default_reducer: Callable[[torch.Tensor], torch.Tensor],
) -> Callable[[torch.Tensor], torch.Tensor]:
"""The ``--loss-aggregation`` reducer for pg_loss. ``sample_mean`` returns
``default_reducer`` unchanged, keeping the prior default byte-identical.
"""
mode = getattr(args, "loss_aggregation", "sample_mean")
if mode in ("sample_mean", "token_mean"):
# token_mean is aliased onto --calculate-per-token-loss, so
# default_reducer is already the per-token path.
return default_reducer
if mode == "prompt_mean":
prompt_mask_sums = batch.get("prompt_mask_sums")
if prompt_mask_sums is None:
# None would silently fall back to the per-sample mean; a custom
# convert path that drops prompt_mask_sums must fail, not degrade.
raise ValueError(
"--loss-aggregation=prompt_mean requires per-prompt-group mask sums "
"(batch['prompt_mask_sums']), but they are missing. A custom "
"--custom-convert-samples-to-train-data-path must populate "
"'prompt_mask_sums' (grouped by Sample.group_index)."
)
return get_sum_of_sample_mean(
total_lengths,
response_lengths,
pg_loss_masks,
prompt_mask_sums,
args.calculate_per_token_loss,
args.qkv_format,
max_seq_lens,
)
if mode == "constant":
return get_sum_of_sample_mean(
total_lengths,
response_lengths,
pg_loss_masks,
qkv_format=args.qkv_format,
max_seq_lens=max_seq_lens,
constant_divisor=args.loss_aggregation_divisor,
)
raise ValueError(f"Unknown --loss-aggregation mode: {mode!r}")


def policy_loss_function(
args: Namespace,
batch: RolloutBatch,
Expand Down Expand Up @@ -951,16 +1001,25 @@ def policy_loss_function(
max_seq_lens,
)

# Determine pg_loss reducer: use custom if specified, otherwise default
# Under TIS/RS rejected tokens are zeroed in modified_response_masks.
pg_loss_masks = modified_response_masks if (args.get_mismatch_metrics or args.use_tis) else batch["loss_masks"]

# Custom reducer path takes precedence over --loss-aggregation.
if getattr(args, "custom_pg_loss_reducer_function_path", None) is not None:
custom_pg_loss_reducer_func = load_function(args.custom_pg_loss_reducer_function_path)
# Determine which loss_masks to use for pg_loss reducer
pg_loss_masks = modified_response_masks if (args.get_mismatch_metrics or args.use_tis) else batch["loss_masks"]
pg_loss_reducer = custom_pg_loss_reducer_func(
total_lengths, response_lengths, pg_loss_masks, args.calculate_per_token_loss
)
else:
pg_loss_reducer = sum_of_sample_mean
pg_loss_reducer = get_pg_loss_reducer(
args,
batch,
total_lengths=total_lengths,
response_lengths=response_lengths,
pg_loss_masks=pg_loss_masks,
max_seq_lens=max_seq_lens,
default_reducer=sum_of_sample_mean,
)

pg_loss = pg_loss_reducer(pg_loss)
pg_clipfrac = sum_of_sample_mean(pg_clipfrac)
Expand Down
1 change: 1 addition & 0 deletions slime/backends/megatron_utils/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -502,6 +502,7 @@ def forward_step(data_iterator: DataIterator, model: GPTModel, return_schedule_p
"max_seq_lens",
"teacher_log_probs",
"rollout_mask_sums",
"prompt_mask_sums",
],
args.data_pad_size_multiplier,
args.qkv_format,
Expand Down
33 changes: 28 additions & 5 deletions slime/ray/rollout.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,11 +93,12 @@ def _tensorize_rollout_data_for_training(rollout_data: dict[str, Any]) -> None:
for mm_dict in rollout_data["multimodal_train_inputs"]
]

if "rollout_mask_sums" in rollout_data:
rollout_data["rollout_mask_sums"] = _cpu_tensor(
rollout_data["rollout_mask_sums"],
dtype=torch.float32,
)
for mask_sums_key in ("rollout_mask_sums", "prompt_mask_sums"):
if mask_sums_key in rollout_data:
rollout_data[mask_sums_key] = _cpu_tensor(
rollout_data[mask_sums_key],
dtype=torch.float32,
)


@dataclasses.dataclass
Expand Down Expand Up @@ -775,6 +776,27 @@ def _convert_samples_to_train_data(self, samples: list[Sample] | list[list[Sampl
rollout_total_mask[rid] = rollout_total_mask.get(rid, 0) + ms
train_data["rollout_mask_sums"] = [rollout_total_mask[rid] for rid in rollout_id_list]

# prompt_mask_sums: per-prompt-group mask total, summed here at the step
# level (every sibling is visible) and broadcast per-sample so a group
# split across micro-batches by packing still divides by its whole total.
# Built only under prompt_mean — the other modes never read it, so the
# default (sample_mean) batch stays byte-identical with no extra key.
if getattr(self.args, "loss_aggregation", "sample_mean") == "prompt_mean":
group_total_mask: dict[int, int] = {}
for sample, ms in zip(samples, mask_sums_per_sample, strict=True):
# A None group_index would collapse unrelated prompts into one
# denominator, silently degrading prompt_mean -> sample_mean for
# that sample. The prompt-grouping invariant is violated, so fail.
if sample.group_index is None:
raise ValueError(
"--loss-aggregation prompt_mean requires every Sample.group_index to be set, "
"but a sample has group_index=None. prompt_mean divides each sample by its "
"prompt group's total mask; a None group_index means the sample belongs to no "
"prompt group, so its denominator is undefined."
)
group_total_mask[sample.group_index] = group_total_mask.get(sample.group_index, 0) + ms
train_data["prompt_mask_sums"] = [group_total_mask[sample.group_index] for sample in samples]

# Overwrite raw_reward when available. Mixed-source batches may only
# populate this field for a subset of samples (e.g. SWE but not code).
if any(sample.metadata and "raw_reward" in sample.metadata for sample in samples):
Expand Down Expand Up @@ -848,6 +870,7 @@ def _split_train_data_by_dp(self, data):
"sample_indices",
"rollout_ids",
"rollout_mask_sums",
"prompt_mask_sums",
"rollout_log_probs",
"rollout_routed_experts",
"prompt",
Expand Down
50 changes: 50 additions & 0 deletions slime/utils/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -1052,6 +1052,41 @@ def add_algo_arguments(parser):
default=None,
help="Path to a custom reducer function for pg_loss only. When set, pg_loss will use this custom reducer while other metrics (pg_clipfrac, ppo_kl, entropy_loss, etc.) still use the default sum_of_sample_mean. (e.g., examples/Dr.GRPO/custom_reducer.py:get_pg_loss_reducer).",
)
parser.add_argument(
"--loss-aggregation",
type=str,
default="sample_mean",
choices=["sample_mean", "prompt_mean", "token_mean", "constant"],
help=(
"How pg_loss is aggregated across the step (applies to pg_loss only; "
"pg_clipfrac, ppo_kl, entropy_loss, kl_loss keep the default sample-mean "
"reducer — same scope as --custom-pg-loss-reducer-function-path, which still "
"takes precedence when set). Modes follow the ScaleRL taxonomy "
"(arXiv:2510.13786 §3.2): "
"'sample_mean' (default; GRPO sample average) — each rollout's tokens are "
"averaged with the per-rollout token-weighted denominator, so every rollout "
"contributes equally regardless of fan-out (byte-identical to slime's prior "
"default); "
"'prompt_mean' (DAPO prompt average; ScaleRL's recommended default for new "
"recipes) — tokens are averaged over each prompt group (all rollouts sharing a "
"Sample.group_index share one denominator); "
"'token_mean' (token average) — global per-token mean, equivalent to "
"--calculate-per-token-loss; "
"'constant' (Dr.GRPO, arXiv:2503.20783) — masked token sum divided by a fixed "
"--loss-aggregation-divisor (e.g. the max context length)."
),
)
parser.add_argument(
"--loss-aggregation-divisor",
type=float,
default=None,
help=(
"Constant divisor L for --loss-aggregation=constant (Dr.GRPO). pg_loss is "
"aggregated as sum(token_loss * loss_mask) / L instead of any data-dependent "
"denominator. Required and validated > 0 at startup only when "
"--loss-aggregation=constant; ignored for the other modes."
),
)

parser.add_argument(
"--use-routing-replay",
Expand Down Expand Up @@ -1823,6 +1858,21 @@ def slime_validate_args(args):

assert not (args.kl_coef != 0 and args.kl_loss_coef != 0), "Only one of kl_coef and kl_loss_coef can be set"

loss_aggregation = getattr(args, "loss_aggregation", "sample_mean")
if loss_aggregation == "constant":
# Dr.GRPO needs a fixed, positive divisor; fail at startup, not mid-train.
divisor = getattr(args, "loss_aggregation_divisor", None)
if divisor is None or not (divisor > 0):
raise ValueError(
"--loss-aggregation-divisor must be set to a positive value when "
f"--loss-aggregation=constant (got {divisor!r})."
)
elif loss_aggregation == "token_mean":
# Alias onto --calculate-per-token-loss so pg_loss does not desync from
# the normalizer/metric path (both already implement the per-token mean).
if not getattr(args, "calculate_per_token_loss", False):
args.calculate_per_token_loss = True

if args.advantage_estimator in ["reinforce_plus_plus", "reinforce_plus_plus_baseline"]:
assert args.normalize_advantages, (
"The 'reinforce_plus_plus' and 'reinforce_plus_plus_baseline' advantage estimators "
Expand Down
Loading
Loading