diff --git a/docs/en/get_started/customization.md b/docs/en/get_started/customization.md index 77f5cd5e3..bed10ac91 100644 --- a/docs/en/get_started/customization.md +++ b/docs/en/get_started/customization.md @@ -295,10 +295,37 @@ def get_pg_loss_reducer( ``` **Use Cases**: -- Dr.GRPO: Divide by a constant instead of effective token count -- Custom loss normalization strategies - -**Example**: `examples/DrGRPO/custom_reducer.py:get_pg_loss_reducer` +- Custom loss normalization strategies not covered by `--loss-aggregation` + +> The four standard loss-aggregation modes (GRPO sample average, DAPO prompt +> average, token average, Dr.GRPO constant divisor) are available first-class via +> `--loss-aggregation` (see below) — no custom reducer needed. Reach for this +> hook only for a normalization those modes do not express. When set, it takes +> precedence over `--loss-aggregation`. + +**Built-in modes — `--loss-aggregation`** + +`--loss-aggregation {sample_mean,prompt_mean,token_mean,constant}` selects how +pg_loss is aggregated across a training step (pg_loss only; every other metric +keeps the default sample-mean reducer — same scope as the custom hook above). +Modes follow the ScaleRL taxonomy ([arXiv:2510.13786](https://arxiv.org/abs/2510.13786) §3.2): + +| Mode | Paper | pg_loss denominator | +| :--- | :--- | :--- | +| `sample_mean` (default) | GRPO sample average | Per-rollout token-weighted mean — each rollout contributes equally regardless of fan-out. Byte-identical to slime's prior default. | +| `prompt_mean` | DAPO prompt average | Per-prompt-group token-weighted mean (all rollouts sharing a prompt share one denominator). ScaleRL's recommended default for new recipes. | +| `token_mean` | token average | Global per-token mean. Equivalent to `--calculate-per-token-loss` (which it sets). | +| `constant` | Dr.GRPO ([arXiv:2503.20783](https://arxiv.org/abs/2503.20783)) | `sum(token_loss * loss_mask) / L`, where `L = --loss-aggregation-divisor` (e.g. the max context length). | + +`--loss-aggregation-divisor L` is required (validated `> 0` at startup) only for +`constant`; it is ignored for the other modes. The default (`sample_mean`) leaves +behavior unchanged. + +`prompt_mean` weights every prompt group equally (each group's token-weighted +mean enters the step sum once, all under the same `/ step_global_batch_size` +divisor). Its absolute scale differs from a strict `1/P` DAPO average by a +constant factor (`P / N`, prompts over rollouts), which the learning rate +absorbs; the relative per-prompt weighting is uniform. --- diff --git a/slime/backends/megatron_utils/actor.py b/slime/backends/megatron_utils/actor.py index 941659f1f..842e121d7 100644 --- a/slime/backends/megatron_utils/actor.py +++ b/slime/backends/megatron_utils/actor.py @@ -235,12 +235,11 @@ def _get_rollout_data(self, rollout_data_ref: Box) -> RolloutBatch: rollout_data["loss_masks"] = [ t.to(device=device, dtype=torch.int, non_blocking=True) for t in rollout_data["loss_masks"] ] - if "rollout_mask_sums" in rollout_data: - # Promote precomputed per-rollout mask totals to GPU tensors here - # (matching loss_masks) so the loss reducer can just divide. - rollout_data["rollout_mask_sums"] = rollout_data["rollout_mask_sums"].to( - device=device, dtype=torch.float32, non_blocking=True - ) + for mask_sums_key in ("rollout_mask_sums", "prompt_mask_sums"): + if mask_sums_key in rollout_data: + rollout_data[mask_sums_key] = rollout_data[mask_sums_key].to( + device=device, dtype=torch.float32, non_blocking=True + ) if "multimodal_train_inputs" in rollout_data: # Move multimodal training tensors to GPU in advance rollout_data["multimodal_train_inputs"] = [ diff --git a/slime/backends/megatron_utils/cp_utils.py b/slime/backends/megatron_utils/cp_utils.py index 448c154c6..ec469f21e 100644 --- a/slime/backends/megatron_utils/cp_utils.py +++ b/slime/backends/megatron_utils/cp_utils.py @@ -58,6 +58,7 @@ def get_sum_of_sample_mean( calculate_per_token_loss: bool = False, qkv_format: str = "thd", max_seq_lens: list[int] | None = None, + constant_divisor: float | None = None, ) -> Callable[[torch.Tensor], torch.Tensor]: """ Calculate correct sample mean for CP. @@ -71,7 +72,17 @@ def get_sum_of_sample_mean( step level rather than per-mb is required — otherwise a rollout whose samples land in different micro-batches would get a partial denominator on each side. + + ``constant_divisor`` (Dr.GRPO, arXiv:2503.20783) divides the masked + token-sum by a fixed ``L``; being identical on every CP rank, Megatron's + gradient sum-allreduce already yields the full-batch value (no extra + all-reduce). Mutually exclusive with ``calculate_per_token_loss``. """ + if constant_divisor is not None and calculate_per_token_loss: + raise ValueError( + "constant_divisor (loss-aggregation=constant) and calculate_per_token_loss " + "(loss-aggregation=token_mean) are mutually exclusive aggregation modes." + ) if sample_denoms is None: sample_denoms = [m.sum() for m in loss_masks] @@ -133,6 +144,13 @@ def sum_of_token(x: torch.Tensor) -> torch.Tensor: ] ) + if constant_divisor is not None: + + def sum_of_constant(x: torch.Tensor) -> torch.Tensor: + return sum_of_token(x) / constant_divisor + + return sum_of_constant + return sum_of_sample_mean if not calculate_per_token_loss else sum_of_token diff --git a/slime/backends/megatron_utils/data.py b/slime/backends/megatron_utils/data.py index 7c12a5a77..8c6d907d7 100644 --- a/slime/backends/megatron_utils/data.py +++ b/slime/backends/megatron_utils/data.py @@ -298,6 +298,7 @@ def log_rollout_data( "sample_indices", "rollout_ids", "rollout_mask_sums", + "prompt_mask_sums", "rollout_routed_experts", "max_seq_lens", "global_batch_sizes", diff --git a/slime/backends/megatron_utils/loss.py b/slime/backends/megatron_utils/loss.py index a456939e7..df222ffe9 100644 --- a/slime/backends/megatron_utils/loss.py +++ b/slime/backends/megatron_utils/loss.py @@ -798,6 +798,56 @@ def icepop_function( return pg_loss, loss_masks, metrics +def get_pg_loss_reducer( + args: Namespace, + batch: RolloutBatch, + *, + total_lengths: list[int], + response_lengths: list[int], + pg_loss_masks: list[torch.Tensor], + max_seq_lens: list[int] | None, + default_reducer: Callable[[torch.Tensor], torch.Tensor], +) -> Callable[[torch.Tensor], torch.Tensor]: + """The ``--loss-aggregation`` reducer for pg_loss. ``sample_mean`` returns + ``default_reducer`` unchanged, keeping the prior default byte-identical. + """ + mode = getattr(args, "loss_aggregation", "sample_mean") + if mode in ("sample_mean", "token_mean"): + # token_mean is aliased onto --calculate-per-token-loss, so + # default_reducer is already the per-token path. + return default_reducer + if mode == "prompt_mean": + prompt_mask_sums = batch.get("prompt_mask_sums") + if prompt_mask_sums is None: + # None would silently fall back to the per-sample mean; a custom + # convert path that drops prompt_mask_sums must fail, not degrade. + raise ValueError( + "--loss-aggregation=prompt_mean requires per-prompt-group mask sums " + "(batch['prompt_mask_sums']), but they are missing. A custom " + "--custom-convert-samples-to-train-data-path must populate " + "'prompt_mask_sums' (grouped by Sample.group_index)." + ) + return get_sum_of_sample_mean( + total_lengths, + response_lengths, + pg_loss_masks, + prompt_mask_sums, + args.calculate_per_token_loss, + args.qkv_format, + max_seq_lens, + ) + if mode == "constant": + return get_sum_of_sample_mean( + total_lengths, + response_lengths, + pg_loss_masks, + qkv_format=args.qkv_format, + max_seq_lens=max_seq_lens, + constant_divisor=args.loss_aggregation_divisor, + ) + raise ValueError(f"Unknown --loss-aggregation mode: {mode!r}") + + def policy_loss_function( args: Namespace, batch: RolloutBatch, @@ -951,16 +1001,25 @@ def policy_loss_function( max_seq_lens, ) - # Determine pg_loss reducer: use custom if specified, otherwise default + # Under TIS/RS rejected tokens are zeroed in modified_response_masks. + pg_loss_masks = modified_response_masks if (args.get_mismatch_metrics or args.use_tis) else batch["loss_masks"] + + # Custom reducer path takes precedence over --loss-aggregation. if getattr(args, "custom_pg_loss_reducer_function_path", None) is not None: custom_pg_loss_reducer_func = load_function(args.custom_pg_loss_reducer_function_path) - # Determine which loss_masks to use for pg_loss reducer - pg_loss_masks = modified_response_masks if (args.get_mismatch_metrics or args.use_tis) else batch["loss_masks"] pg_loss_reducer = custom_pg_loss_reducer_func( total_lengths, response_lengths, pg_loss_masks, args.calculate_per_token_loss ) else: - pg_loss_reducer = sum_of_sample_mean + pg_loss_reducer = get_pg_loss_reducer( + args, + batch, + total_lengths=total_lengths, + response_lengths=response_lengths, + pg_loss_masks=pg_loss_masks, + max_seq_lens=max_seq_lens, + default_reducer=sum_of_sample_mean, + ) pg_loss = pg_loss_reducer(pg_loss) pg_clipfrac = sum_of_sample_mean(pg_clipfrac) diff --git a/slime/backends/megatron_utils/model.py b/slime/backends/megatron_utils/model.py index db6020a94..41f6141b0 100644 --- a/slime/backends/megatron_utils/model.py +++ b/slime/backends/megatron_utils/model.py @@ -502,6 +502,7 @@ def forward_step(data_iterator: DataIterator, model: GPTModel, return_schedule_p "max_seq_lens", "teacher_log_probs", "rollout_mask_sums", + "prompt_mask_sums", ], args.data_pad_size_multiplier, args.qkv_format, diff --git a/slime/ray/rollout.py b/slime/ray/rollout.py index 5766d6b17..56122be99 100644 --- a/slime/ray/rollout.py +++ b/slime/ray/rollout.py @@ -93,11 +93,12 @@ def _tensorize_rollout_data_for_training(rollout_data: dict[str, Any]) -> None: for mm_dict in rollout_data["multimodal_train_inputs"] ] - if "rollout_mask_sums" in rollout_data: - rollout_data["rollout_mask_sums"] = _cpu_tensor( - rollout_data["rollout_mask_sums"], - dtype=torch.float32, - ) + for mask_sums_key in ("rollout_mask_sums", "prompt_mask_sums"): + if mask_sums_key in rollout_data: + rollout_data[mask_sums_key] = _cpu_tensor( + rollout_data[mask_sums_key], + dtype=torch.float32, + ) @dataclasses.dataclass @@ -775,6 +776,27 @@ def _convert_samples_to_train_data(self, samples: list[Sample] | list[list[Sampl rollout_total_mask[rid] = rollout_total_mask.get(rid, 0) + ms train_data["rollout_mask_sums"] = [rollout_total_mask[rid] for rid in rollout_id_list] + # prompt_mask_sums: per-prompt-group mask total, summed here at the step + # level (every sibling is visible) and broadcast per-sample so a group + # split across micro-batches by packing still divides by its whole total. + # Built only under prompt_mean — the other modes never read it, so the + # default (sample_mean) batch stays byte-identical with no extra key. + if getattr(self.args, "loss_aggregation", "sample_mean") == "prompt_mean": + group_total_mask: dict[int, int] = {} + for sample, ms in zip(samples, mask_sums_per_sample, strict=True): + # A None group_index would collapse unrelated prompts into one + # denominator, silently degrading prompt_mean -> sample_mean for + # that sample. The prompt-grouping invariant is violated, so fail. + if sample.group_index is None: + raise ValueError( + "--loss-aggregation prompt_mean requires every Sample.group_index to be set, " + "but a sample has group_index=None. prompt_mean divides each sample by its " + "prompt group's total mask; a None group_index means the sample belongs to no " + "prompt group, so its denominator is undefined." + ) + group_total_mask[sample.group_index] = group_total_mask.get(sample.group_index, 0) + ms + train_data["prompt_mask_sums"] = [group_total_mask[sample.group_index] for sample in samples] + # Overwrite raw_reward when available. Mixed-source batches may only # populate this field for a subset of samples (e.g. SWE but not code). if any(sample.metadata and "raw_reward" in sample.metadata for sample in samples): @@ -848,6 +870,7 @@ def _split_train_data_by_dp(self, data): "sample_indices", "rollout_ids", "rollout_mask_sums", + "prompt_mask_sums", "rollout_log_probs", "rollout_routed_experts", "prompt", diff --git a/slime/utils/arguments.py b/slime/utils/arguments.py index 6efe85eae..c535b6f5a 100644 --- a/slime/utils/arguments.py +++ b/slime/utils/arguments.py @@ -1052,6 +1052,41 @@ def add_algo_arguments(parser): default=None, help="Path to a custom reducer function for pg_loss only. When set, pg_loss will use this custom reducer while other metrics (pg_clipfrac, ppo_kl, entropy_loss, etc.) still use the default sum_of_sample_mean. (e.g., examples/Dr.GRPO/custom_reducer.py:get_pg_loss_reducer).", ) + parser.add_argument( + "--loss-aggregation", + type=str, + default="sample_mean", + choices=["sample_mean", "prompt_mean", "token_mean", "constant"], + help=( + "How pg_loss is aggregated across the step (applies to pg_loss only; " + "pg_clipfrac, ppo_kl, entropy_loss, kl_loss keep the default sample-mean " + "reducer — same scope as --custom-pg-loss-reducer-function-path, which still " + "takes precedence when set). Modes follow the ScaleRL taxonomy " + "(arXiv:2510.13786 §3.2): " + "'sample_mean' (default; GRPO sample average) — each rollout's tokens are " + "averaged with the per-rollout token-weighted denominator, so every rollout " + "contributes equally regardless of fan-out (byte-identical to slime's prior " + "default); " + "'prompt_mean' (DAPO prompt average; ScaleRL's recommended default for new " + "recipes) — tokens are averaged over each prompt group (all rollouts sharing a " + "Sample.group_index share one denominator); " + "'token_mean' (token average) — global per-token mean, equivalent to " + "--calculate-per-token-loss; " + "'constant' (Dr.GRPO, arXiv:2503.20783) — masked token sum divided by a fixed " + "--loss-aggregation-divisor (e.g. the max context length)." + ), + ) + parser.add_argument( + "--loss-aggregation-divisor", + type=float, + default=None, + help=( + "Constant divisor L for --loss-aggregation=constant (Dr.GRPO). pg_loss is " + "aggregated as sum(token_loss * loss_mask) / L instead of any data-dependent " + "denominator. Required and validated > 0 at startup only when " + "--loss-aggregation=constant; ignored for the other modes." + ), + ) parser.add_argument( "--use-routing-replay", @@ -1823,6 +1858,21 @@ def slime_validate_args(args): assert not (args.kl_coef != 0 and args.kl_loss_coef != 0), "Only one of kl_coef and kl_loss_coef can be set" + loss_aggregation = getattr(args, "loss_aggregation", "sample_mean") + if loss_aggregation == "constant": + # Dr.GRPO needs a fixed, positive divisor; fail at startup, not mid-train. + divisor = getattr(args, "loss_aggregation_divisor", None) + if divisor is None or not (divisor > 0): + raise ValueError( + "--loss-aggregation-divisor must be set to a positive value when " + f"--loss-aggregation=constant (got {divisor!r})." + ) + elif loss_aggregation == "token_mean": + # Alias onto --calculate-per-token-loss so pg_loss does not desync from + # the normalizer/metric path (both already implement the per-token mean). + if not getattr(args, "calculate_per_token_loss", False): + args.calculate_per_token_loss = True + if args.advantage_estimator in ["reinforce_plus_plus", "reinforce_plus_plus_baseline"]: assert args.normalize_advantages, ( "The 'reinforce_plus_plus' and 'reinforce_plus_plus_baseline' advantage estimators " diff --git a/tests/test_cp_utils.py b/tests/test_cp_utils.py index c7d3abe9a..567e4bc84 100644 --- a/tests/test_cp_utils.py +++ b/tests/test_cp_utils.py @@ -176,5 +176,107 @@ def test_cp_chunking_preserves_per_rollout_mean_report(monkeypatch): assert cp_total == pytest.approx(baseline) +@pytest.mark.unit +def test_constant_divisor_divides_masked_token_sum_by_L(): + """``constant`` (Dr.GRPO) aggregation: masked token sum / L, NOT any + data-dependent denominator.""" + total_lengths, response_lengths, loss_masks = _make_inputs([3, 3]) + L = 40.0 + reducer = get_sum_of_sample_mean(total_lengths, response_lengths, loss_masks, constant_divisor=L) + x = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]) + # sum of all masked tokens = 21; / 40 = 0.525. + assert reducer(x).item() == pytest.approx(21.0 / L) + + +@pytest.mark.unit +def test_prompt_mean_denom_is_per_group_token_sum(): + """``prompt_mean`` (DAPO): two prompts × G rollouts. Each sample's + denominator is its WHOLE prompt-group's mask total (all rollouts of that + prompt), distinct from per-rollout (sample_mean) and per-token (token_mean). + + Fixture: prompt P0 has 2 rollouts of length 2 (group mask sum = 4); + prompt P1 has 2 rollouts of length 4 (group mask sum = 8). + """ + # 4 samples: [P0r0=2, P0r1=2, P1r0=4, P1r1=4]. + total_lengths, response_lengths, loss_masks = _make_inputs([2, 2, 4, 4]) + # prompt_mask_sums: P0 group total = 2+2 = 4 (both P0 samples); P1 = 4+4 = 8. + prompt_denoms = _denoms(4, 4, 8, 8) + reducer = get_sum_of_sample_mean(total_lengths, response_lengths, loss_masks, prompt_denoms) + + # x laid out per sample: P0r0=[1,1], P0r1=[1,1], P1r0=[1,1,1,1], P1r1=[1,1,1,1] + x = torch.tensor([1.0] * 2 + [1.0] * 2 + [1.0] * 4 + [1.0] * 4) + # P0 group mean: (sum of P0 tokens)/4 = 4/4 = 1. P1 group mean: 8/8 = 1. Sum = 2. + assert reducer(x).item() == pytest.approx(2.0) + + # Now make the per-prompt content uneven so prompt_mean is numerically + # distinct from both sample_mean and token_mean. + x2 = torch.tensor([2.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]) + # prompt_mean: P0 sum = 2, /4 = 0.5; P1 sum = 8, /8 = 1.0 → 1.5. + assert reducer(x2).item() == pytest.approx(1.5) + # sample_mean (per-rollout denoms = own length): P0r0 2/2=1, P0r1 0/2=0, + # P1r0 4/4=1, P1r1 4/4=1 → 3.0. Distinct from prompt_mean. + sample_mean = get_sum_of_sample_mean(total_lengths, response_lengths, loss_masks) + assert sample_mean(x2).item() == pytest.approx(3.0) + assert sample_mean(x2).item() != pytest.approx(reducer(x2).item()) + # token_mean numerator (raw masked sum) = 10; distinct again. + token_sum = get_sum_of_sample_mean(total_lengths, response_lengths, loss_masks, calculate_per_token_loss=True) + assert token_sum(x2).item() == pytest.approx(10.0) + assert token_sum(x2).item() != pytest.approx(reducer(x2).item()) + + +@pytest.mark.unit +def test_constant_and_per_token_loss_are_mutually_exclusive(): + """The constant divisor and per-token-loss are distinct aggregation modes; + asking for both is a configuration error, rejected eagerly.""" + total_lengths, response_lengths, loss_masks = _make_inputs([3]) + with pytest.raises(ValueError, match="mutually exclusive"): + get_sum_of_sample_mean( + total_lengths, + response_lengths, + loss_masks, + calculate_per_token_loss=True, + constant_divisor=40.0, + ) + + +@pytest.mark.unit +def test_cp_chunking_preserves_constant_divisor(monkeypatch): + """CP rank-sum invariance for the constant divisor: the divisor is identical + on every CP rank, so summing per-rank reducer outputs reproduces cp=1.""" + from megatron.core import mpu as _mpu + + total_lengths = [12, 12] + response_lengths = [8, 8] + loss_masks = [torch.ones(r, dtype=torch.float32) for r in response_lengths] + L = 40.0 + x_full = [ + torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]), + torch.tensor([10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0]), + ] + x_concat = torch.cat(x_full) + + monkeypatch.setattr(_mpu, "get_context_parallel_world_size", lambda: 1) + monkeypatch.setattr(_mpu, "get_context_parallel_rank", lambda: 0) + reducer_cp1 = get_sum_of_sample_mean(total_lengths, response_lengths, loss_masks, constant_divisor=L) + baseline = reducer_cp1(x_concat).item() + + monkeypatch.setattr(_mpu, "get_context_parallel_world_size", lambda: 2) + cp_total = 0.0 + for cp_rank in range(2): + monkeypatch.setattr(_mpu, "get_context_parallel_rank", lambda r=cp_rank: r) + x_chunks_per_sample = [] + for tl, rl, x in zip(total_lengths, response_lengths, x_full, strict=True): + prompt_length = tl - rl + _, _, _, tokens_offset = get_logits_and_tokens_offset_with_cp(tl, rl) + chunk_0 = x[tokens_offset[0][0] - prompt_length : tokens_offset[0][1] - prompt_length] + chunk_1 = x[tokens_offset[1][0] - prompt_length : tokens_offset[1][1] - prompt_length] + x_chunks_per_sample.append(torch.cat([chunk_0, chunk_1])) + x_for_rank = torch.cat(x_chunks_per_sample) + reducer_cp2 = get_sum_of_sample_mean(total_lengths, response_lengths, loss_masks, constant_divisor=L) + cp_total += reducer_cp2(x_for_rank).item() + + assert cp_total == pytest.approx(baseline) + + if __name__ == "__main__": raise SystemExit(pytest.main([__file__])) diff --git a/tests/test_megatron_argument_validation.py b/tests/test_megatron_argument_validation.py index 1f435cb57..5ab038fbe 100644 --- a/tests/test_megatron_argument_validation.py +++ b/tests/test_megatron_argument_validation.py @@ -304,6 +304,9 @@ def make_slime_validate_args(**overrides): update_weight_disk_dir=None, update_weight_delta_dir=None, update_weight_mode="full", + loss_aggregation="sample_mean", + loss_aggregation_divisor=None, + calculate_per_token_loss=False, ) values.update(overrides) return types.SimpleNamespace(**values) @@ -352,6 +355,48 @@ def test_slime_validate_args_preserves_zero_rollout_gpus_without_colocate(monkey assert args.offload_rollout is False +@pytest.mark.unit +@pytest.mark.parametrize("divisor", [None, 0.0, -1.0, float("nan")]) +def test_loss_aggregation_constant_rejects_nonpositive_divisor(monkeypatch, divisor): + module = load_slime_arguments_module(monkeypatch) + args = make_slime_validate_args(loss_aggregation="constant", loss_aggregation_divisor=divisor) + + with pytest.raises(ValueError, match="loss-aggregation-divisor"): + module.slime_validate_args(args) + + +@pytest.mark.unit +def test_loss_aggregation_constant_accepts_positive_divisor(monkeypatch): + module = load_slime_arguments_module(monkeypatch) + args = make_slime_validate_args(loss_aggregation="constant", loss_aggregation_divisor=40960.0) + + module.slime_validate_args(args) + + assert args.loss_aggregation_divisor == 40960.0 + + +@pytest.mark.unit +def test_loss_aggregation_token_mean_aliases_calculate_per_token_loss(monkeypatch): + module = load_slime_arguments_module(monkeypatch) + args = make_slime_validate_args(loss_aggregation="token_mean", calculate_per_token_loss=False) + + module.slime_validate_args(args) + + assert args.calculate_per_token_loss is True + + +@pytest.mark.unit +def test_loss_aggregation_default_leaves_per_token_loss_off(monkeypatch): + module = load_slime_arguments_module(monkeypatch) + args = make_slime_validate_args() # default sample_mean + + module.slime_validate_args(args) + + assert args.calculate_per_token_loss is False + # No divisor required for non-constant modes. + assert args.loss_aggregation == "sample_mean" + + @pytest.mark.unit def test_update_weight_delta_rejects_colocate(monkeypatch): module = load_slime_arguments_module(monkeypatch) diff --git a/tests/test_rollout_validation.py b/tests/test_rollout_validation.py index fa2ff3433..ddc63b30d 100644 --- a/tests/test_rollout_validation.py +++ b/tests/test_rollout_validation.py @@ -1,3 +1,5 @@ +import types + import pytest from slime.ray.rollout_validation import validate_server_group_gpu_indices @@ -59,5 +61,101 @@ def test_validate_server_group_gpu_indices_reports_config_context(): assert "rollout_num_gpus_per_engine=2" in message +# --------------------------------------------------------------------------- +# _convert_samples_to_train_data: prompt_mask_sums build (--loss-aggregation) +# --------------------------------------------------------------------------- +# +# prompt_mask_sums (the per-prompt-group denominator for prompt_mean) is built +# only under --loss-aggregation=prompt_mean, and that build fails loud if any +# sample is missing its prompt group (group_index is None) — a None would +# silently collapse unrelated prompts into one denominator, degrading +# prompt_mean to sample_mean for that sample. The other three modes never read +# group_index, so they must neither build the key nor consult group_index. + + +def _make_convert_manager(loss_aggregation): + """A bare RolloutManager (no Ray/sglang init) wired just enough to call + ``_convert_samples_to_train_data``: no custom hooks, and reward + post-processing reduced to identity (advantage_estimator outside the + group-norm set), so the only behavior under test is the prompt_mask_sums + build + its group_index guard.""" + from slime.ray.rollout import RolloutManager + + manager = RolloutManager.__new__(RolloutManager) + manager.custom_convert_samples_to_train_data_func = None + manager.custom_reward_post_process_func = None + manager.args = types.SimpleNamespace( + loss_aggregation=loss_aggregation, + reward_key=None, + advantage_estimator="reinforce", # outside the group-norm reshape path + rewards_normalization=False, + grpo_std_normalization=False, + ) + return manager + + +def _make_grouped_samples(group_indices): + """One Sample per entry; each carries a length-2 loss_mask so the + per-group mask totals are non-trivial.""" + from slime.utils.types import Sample + + samples = [] + for i, gid in enumerate(group_indices): + samples.append( + Sample( + index=i, + group_index=gid, + rollout_id=i, + tokens=[0, 1, 2, 3], + response_length=2, + reward=0.0, + loss_mask=[1, 1], + ) + ) + return samples + + +@pytest.mark.unit +def test_prompt_mean_fails_loud_on_none_group_index(): + """prompt_mean with a None group_index is a real break (the prompt-grouping + invariant is violated), so the convert step must raise — not silently + renumber the sample into its own singleton group.""" + pytest.importorskip("sglang") # RolloutManager import pulls sglang + manager = _make_convert_manager("prompt_mean") + samples = _make_grouped_samples([0, 0, None, 1]) + + with pytest.raises(ValueError, match="group_index"): + manager._convert_samples_to_train_data(samples) + + +@pytest.mark.unit +def test_prompt_mean_builds_per_group_mask_sums(): + """Sanity: with every group_index set, prompt_mask_sums is the per-group + mask total broadcast per sample (group 0 has two length-2 samples → 4).""" + pytest.importorskip("sglang") + manager = _make_convert_manager("prompt_mean") + samples = _make_grouped_samples([0, 0, 1]) # group 0: 2 samples, group 1: 1 + + train_data = manager._convert_samples_to_train_data(samples) + + # group 0 = 2+2 = 4 (both samples), group 1 = 2. + assert train_data["prompt_mask_sums"] == [4, 4, 2] + + +@pytest.mark.unit +@pytest.mark.parametrize("mode", ["sample_mean", "constant", "token_mean"]) +def test_non_prompt_mean_modes_ignore_none_group_index(mode): + """The other three modes never read group_index and never build + prompt_mask_sums, so a None group_index must NOT raise and the key must be + absent (keeping the default batch byte-identical — no extra key).""" + pytest.importorskip("sglang") + manager = _make_convert_manager(mode) + samples = _make_grouped_samples([0, None, 1]) + + train_data = manager._convert_samples_to_train_data(samples) + + assert "prompt_mask_sums" not in train_data + + if __name__ == "__main__": raise SystemExit(pytest.main([__file__]))