11 changes: 8 additions & 3 deletions docs/en/get_started/customization.md
@@ -18,7 +18,7 @@ Below is a summary of all available customization interfaces and their purposes.
| [`--rollout-data-postprocess-path`](#8-rollout-data-postprocess---rollout-data-postprocess-path) | Post-process rollout data after log probs are computed. |
| [`--custom-loss-function-path`](#9-custom-loss-function---custom-loss-function-path) | Implement custom training loss computation. |
| [`--custom-tis-function-path`](#10-custom-tisrs-function---custom-tis-function-path) | Implement custom importance sampling for off-policy correction. |
| [`--custom-pg-loss-reducer-function-path`](#11-custom-pg-loss-reducer---custom-pg-loss-reducer-function-path) | Customize pg_loss reduction (e.g., for Dr.GRPO). |
| [`--custom-pg-loss-reducer-function-path`](#11-custom-pg-loss-reducer---custom-pg-loss-reducer-function-path) | Customize pg_loss reduction. |
| [`--custom-reward-post-process-path`](#12-reward-post-processing---custom-reward-post-process-path) | Custom post-processing of rewards before advantage computation. |
| [`--custom-convert-samples-to-train-data-path`](#13-samples-to-train-data-conversion---custom-convert-samples-to-train-data-path) | Override the conversion of samples to training data format. |
| [`--custom-rollout-log-function-path`](#14-logging-functions) | Custom logging for training rollouts. |
@@ -295,10 +295,15 @@ def get_pg_loss_reducer(
```

**Use Cases**:
- Dr.GRPO: Divide by a constant instead of effective token count
- Custom loss normalization strategies

**Example**: `examples/DrGRPO/custom_reducer.py:get_pg_loss_reducer`
**Note**: For the Dr.GRPO normalization — divide pg_loss by a constant instead of the
data-dependent active-token count (arXiv:2503.20783, Eq. 2; also used by DeepSWE) — no
custom reducer is needed. It is built in:

```bash
--pg-loss-divisor 40960 # a constant, e.g. the max context length
```

---

10 changes: 7 additions & 3 deletions docs/zh/get_started/customization.md
@@ -18,7 +18,7 @@ slime 通过函数路径参数提供了广泛的自定义能力。这些参数
| [`--rollout-data-postprocess-path`](#8-rollout-数据后处理---rollout-data-postprocess-path) | 在计算 log probabilities 后对 rollout 数据进行后处理。 |
| [`--custom-loss-function-path`](#9-自定义损失函数---custom-loss-function-path) | 实现自定义训练损失计算。 |
| [`--custom-tis-function-path`](#10-自定义-tisrs-函数---custom-tis-function-path) | 实现用于离策略(off-policy)校正的自定义重要性采样。 |
| [`--custom-pg-loss-reducer-function-path`](#11-自定义-pg-loss-reducer---custom-pg-loss-reducer-function-path) | 自定义 pg_loss 的归约方式(如 Dr.GRPO)。 |
| [`--custom-pg-loss-reducer-function-path`](#11-自定义-pg-loss-reducer---custom-pg-loss-reducer-function-path) | 自定义 pg_loss 的归约方式。 |
| [`--custom-reward-post-process-path`](#12-奖励后处理---custom-reward-post-process-path) | 在优势计算前对奖励进行自定义后处理。 |
| [`--custom-convert-samples-to-train-data-path`](#13-样本转训练数据---custom-convert-samples-to-train-data-path) | 覆盖样本到训练数据格式的转换逻辑。 |
| [`--custom-rollout-log-function-path`](#14-日志函数) | 训练 rollout 的自定义日志记录。 |
@@ -295,10 +295,14 @@ def get_pg_loss_reducer(
```

**使用场景**:
- Dr.GRPO:除以常数而非有效 token 数
- 自定义损失归一化策略

**示例**: `examples/DrGRPO/custom_reducer.py:get_pg_loss_reducer`
**注**:对于 Dr.GRPO 归一化——用常数而非依赖数据的有效 token 数来除 pg_loss
(arXiv:2503.20783,式 2;DeepSWE 也采用此方式)——无需自定义 reducer,已内置支持:

```bash
--pg-loss-divisor 40960 # 常数,例如最大上下文长度
```

---

16 changes: 16 additions & 0 deletions examples/DrGRPO/README.md
@@ -0,0 +1,16 @@
# Dr.GRPO

Dr.GRPO's constant-divisor loss normalization (arXiv:2503.20783, Eq. 2; also used by
DeepSWE) is built into slime — no custom code is needed:

```bash
--pg-loss-divisor 40960 # a constant, e.g. the max context length
```

When set, pg_loss is aggregated as `sum(token_loss * loss_mask) / divisor` instead of the
default sum of per-sample active-token means, removing the length bias of per-sample
normalization. Other metrics (pg_clipfrac, ppo_kl, entropy_loss, etc.) keep the default
reducer.

For normalizations that need more than a constant divisor, see the *Custom pg_loss
Reducer* section in [docs/en/get_started/customization.md](../../docs/en/get_started/customization.md).
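
To make the difference between the two aggregations concrete, here is a minimal pure-Python sketch that reuses the numbers from this PR's unit test. The helper names `per_sample_mean_sum` and `constant_divisor_sum` are hypothetical and not part of slime; the real reducer works on torch tensors and handles context parallelism.

```python
def per_sample_mean_sum(token_losses, loss_masks):
    """Default aggregation (sketch): sum over samples of the masked mean.

    Assumes every sample has at least one active (mask == 1) token.
    """
    total = 0.0
    for losses, mask in zip(token_losses, loss_masks):
        masked_sum = sum(l * m for l, m in zip(losses, mask))
        total += masked_sum / sum(mask)
    return total


def constant_divisor_sum(token_losses, loss_masks, divisor):
    """Dr.GRPO-style aggregation (sketch): masked token sum over a constant."""
    masked_sum = sum(
        l * m
        for losses, mask in zip(token_losses, loss_masks)
        for l, m in zip(losses, mask)
    )
    return masked_sum / divisor


token_losses = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
loss_masks = [[1.0, 1.0, 1.0], [1.0, 0.0, 1.0]]

# Per-sample means: 6/3 + 10/2 = 7.0
default_value = per_sample_mean_sum(token_losses, loss_masks)
# Masked sum (1+2+3) + (4+6) = 16, divided by the constant 8 = 2.0
drgrpo_value = constant_divisor_sum(token_losses, loss_masks, 8.0)
```

In the default form each sample contributes its own mean, so a token in a short response carries more weight than a token in a long one. With the constant divisor every active token carries the same weight, `1 / divisor`.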
21 changes: 20 additions & 1 deletion slime/backends/megatron_utils/cp_utils.py
@@ -58,6 +58,7 @@ def get_sum_of_sample_mean(
    calculate_per_token_loss: bool = False,
    qkv_format: str = "thd",
    max_seq_lens: list[int] | None = None,
    constant_divisor: float | None = None,
) -> Callable[[torch.Tensor], torch.Tensor]:
    """
    Calculate correct sample mean for CP.
@@ -71,6 +72,14 @@ def get_sum_of_sample_mean(
    step level rather than per-mb is required — otherwise a rollout whose
    samples land in different micro-batches would get a partial denominator
    on each side.

    ``constant_divisor`` switches to the Dr.GRPO normalization (see
    ``--pg-loss-divisor``): the masked token-loss sum is divided by this
    constant instead of per-sample active-token means. The constant is
    identical on every CP rank, so the gradient sum-allreduce across CP ranks
    needs no denominator correction. Intentionally not applied under
    ``calculate_per_token_loss`` — Megatron already divides by the
    all-reduced token count there.
    """
    if sample_denoms is None:
        sample_denoms = [m.sum() for m in loss_masks]
@@ -133,7 +142,17 @@ def sum_of_token(x: torch.Tensor) -> torch.Tensor:
]
)

    return sum_of_sample_mean if not calculate_per_token_loss else sum_of_token
    if calculate_per_token_loss:
        return sum_of_token

    if constant_divisor is not None:

        def sum_of_token_over_constant(x: torch.Tensor) -> torch.Tensor:
            return sum_of_token(x) / constant_divisor

        return sum_of_token_over_constant

    return sum_of_sample_mean


def reduce_train_step_metrics(
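
The docstring's claim that a constant divisor needs no denominator correction across CP ranks follows from linearity: dividing each rank's partial sum by the same constant and then adding gives the same result as dividing the full sum once. The sketch below illustrates this in pure Python. It assumes a simple contiguous two-way split, which is not slime's actual CP chunk layout, and the helper name `masked_sum_over_constant` is hypothetical.

```python
def masked_sum_over_constant(losses, mask, divisor):
    """Masked token-loss sum divided by a rank-independent constant."""
    return sum(l * m for l, m in zip(losses, mask)) / divisor


losses = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
mask = [1.0] * 8
divisor = 32.0

# Value computed on the full sequence (cp_size == 1): 36 / 32 = 1.125.
full = masked_sum_over_constant(losses, mask, divisor)

# Assumed split: rank 0 holds the first half, rank 1 the second half.
rank_chunks = [(losses[:4], mask[:4]), (losses[4:], mask[4:])]
cp_total = sum(masked_sum_over_constant(l, m, divisor) for l, m in rank_chunks)

# Contrast: a mean that uses each rank's *local* token count does not add up
# to the full-sequence mean, which is why the per-sample-mean path has to use
# a denominator computed over the whole sample.
full_mean = sum(losses) / sum(mask)  # 36 / 8 = 4.5
local_mean_total = sum(
    sum(l_i * m_i for l_i, m_i in zip(l, m)) / sum(m) for l, m in rank_chunks
)  # 10/4 + 26/4 = 9.0
```

Here `cp_total` equals `full`, while `local_mean_total` is twice `full_mean`.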
13 changes: 12 additions & 1 deletion slime/backends/megatron_utils/loss.py
@@ -947,14 +947,25 @@ def policy_loss_function(
        max_seq_lens,
    )

    # Determine pg_loss reducer: use custom if specified, otherwise default
    # Determine pg_loss reducer: custom hook first, then --pg-loss-divisor, otherwise default
    if getattr(args, "custom_pg_loss_reducer_function_path", None) is not None:
        custom_pg_loss_reducer_func = load_function(args.custom_pg_loss_reducer_function_path)
        # Determine which loss_masks to use for pg_loss reducer
        pg_loss_masks = modified_response_masks if (args.get_mismatch_metrics or args.use_tis) else batch["loss_masks"]
        pg_loss_reducer = custom_pg_loss_reducer_func(
            total_lengths, response_lengths, pg_loss_masks, args.calculate_per_token_loss
        )
    elif getattr(args, "pg_loss_divisor", None) is not None:
        pg_loss_masks = modified_response_masks if (args.get_mismatch_metrics or args.use_tis) else batch["loss_masks"]
        pg_loss_reducer = get_sum_of_sample_mean(
            total_lengths,
            response_lengths,
            pg_loss_masks,
            calculate_per_token_loss=args.calculate_per_token_loss,
            qkv_format=args.qkv_format,
            max_seq_lens=max_seq_lens,
            constant_divisor=args.pg_loss_divisor,
        )
    else:
        pg_loss_reducer = sum_of_sample_mean
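
The selection order in this hunk, combined with the early return in `get_sum_of_sample_mean`, can be summarized as a small decision function. This is only an illustrative sketch: `select_pg_loss_reducer_kind` and its string labels are hypothetical and do not exist in slime.

```python
from types import SimpleNamespace


def select_pg_loss_reducer_kind(args):
    """Return a label for the pg_loss reducer that would be chosen (sketch)."""
    # 1. A custom reducer hook always wins.
    if getattr(args, "custom_pg_loss_reducer_function_path", None) is not None:
        return "custom"
    # 2. Otherwise --pg-loss-divisor is consulted.
    if getattr(args, "pg_loss_divisor", None) is not None:
        # Under per-token loss the divisor is not applied: the reducer returns
        # the raw masked token sum and Megatron normalizes by the token count.
        if getattr(args, "calculate_per_token_loss", False):
            return "raw_token_sum"
        return "constant_divisor"
    # 3. Fall back to the default reducer.
    return "default"


both_set = SimpleNamespace(
    custom_pg_loss_reducer_function_path="my_module:get_pg_loss_reducer",  # hypothetical path
    pg_loss_divisor=40960.0,
    calculate_per_token_loss=False,
)
divisor_only = SimpleNamespace(
    custom_pg_loss_reducer_function_path=None,
    pg_loss_divisor=40960.0,
    calculate_per_token_loss=False,
)
divisor_per_token = SimpleNamespace(
    custom_pg_loss_reducer_function_path=None,
    pg_loss_divisor=40960.0,
    calculate_per_token_loss=True,
)
neither = SimpleNamespace()

kinds = [
    select_pg_loss_reducer_kind(a)
    for a in (both_set, divisor_only, divisor_per_token, neither)
]
```

The practical consequence is that setting both flags silently uses the custom reducer, which matches the "Ignored when" wording in the `--pg-loss-divisor` help text.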

17 changes: 16 additions & 1 deletion slime/utils/arguments.py
@@ -1039,7 +1039,19 @@ def add_algo_arguments(parser):
                "--custom-pg-loss-reducer-function-path",
                type=str,
                default=None,
                help="Path to a custom reducer function for pg_loss only. When set, pg_loss will use this custom reducer while other metrics (pg_clipfrac, ppo_kl, entropy_loss, etc.) still use the default sum_of_sample_mean. (e.g., examples/Dr.GRPO/custom_reducer.py:get_pg_loss_reducer).",
                help="Path to a custom reducer function for pg_loss only. When set, pg_loss will use this custom reducer while other metrics (pg_clipfrac, ppo_kl, entropy_loss, etc.) still use the default sum_of_sample_mean. See the Custom pg_loss Reducer section of docs/en/get_started/customization.md for the expected signature.",
            )
            parser.add_argument(
                "--pg-loss-divisor",
                type=float,
                default=None,
                help="Constant divisor for pg_loss aggregation. When set, pg_loss is the masked "
                "token-loss sum divided by this constant (e.g. the max context length) instead of "
                "the default sum of per-sample active-token means, removing the length bias of "
                "per-sample normalization (Dr.GRPO, arXiv:2503.20783, Eq. 2; also used by DeepSWE). "
                "Other metrics keep the default reducer. Ignored when "
                "--custom-pg-loss-reducer-function-path is set, and under "
                "--calculate-per-token-loss (Megatron already divides by the token count there).",
            )

            parser.add_argument(
@@ -1839,6 +1851,9 @@ def slime_validate_args(args):
    if args.use_rollout_logprobs:
        assert not args.use_tis, "use_rollout_logprobs and use_tis cannot be set at the same time."

    if args.pg_loss_divisor is not None and not args.pg_loss_divisor > 0:
        raise ValueError(f"--pg-loss-divisor must be a positive number, got {args.pg_loss_divisor}.")

    if args.get_mismatch_metrics:
        assert (
            args.custom_tis_function_path is not None
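
The check is written as `not divisor > 0` rather than `divisor <= 0`, and the difference matters for NaN: every ordered comparison with NaN is false, so `nan <= 0` would let NaN through while `not nan > 0` rejects it. A standalone sketch of the same condition follows; `validate_pg_loss_divisor` and `naive_check_rejects` are hypothetical helpers written for this illustration.

```python
def validate_pg_loss_divisor(value):
    """Mirror of the validation condition in this hunk (sketch)."""
    if value is not None and not value > 0:
        raise ValueError(f"--pg-loss-divisor must be a positive number, got {value}.")
    return value


def naive_check_rejects(value):
    """A tempting alternative condition that misses NaN."""
    return value is not None and value <= 0


nan = float("nan")

rejected = []
for candidate in (0.0, -1.0, nan):
    try:
        validate_pg_loss_divisor(candidate)
        rejected.append(False)
    except ValueError:
        rejected.append(True)

naive_rejects_nan = naive_check_rejects(nan)  # False: NaN slips through `<= 0`
accepted = validate_pg_loss_divisor(40960.0)
unset = validate_pg_loss_divisor(None)
```

One observation under the same reading of the condition: `float("inf") > 0` is true, so positive infinity would pass this check. Whether that case deserves its own rejection is a design decision for the PR authors.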
83 changes: 83 additions & 0 deletions tests/test_cp_utils.py
@@ -176,5 +176,88 @@ def test_cp_chunking_preserves_per_rollout_mean_report(monkeypatch):
    assert cp_total == pytest.approx(baseline)


@pytest.mark.unit
def test_constant_divisor_replaces_per_sample_means():
    """``--pg-loss-divisor`` contract: masked token sum over a constant, NOT the
    sum of per-sample active-token means. Masked-out tokens drop from the
    numerator while the denominator stays the constant."""
    total_lengths, response_lengths, loss_masks = _make_inputs([3, 3])
    loss_masks[1] = torch.tensor([1.0, 0.0, 1.0])
    x = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])

    reducer = get_sum_of_sample_mean(total_lengths, response_lengths, loss_masks, constant_divisor=8.0)
    # masked sum = (1+2+3) + (4+6) = 16; divided by the constant 8.
    assert reducer(x).item() == pytest.approx(2.0)

    # Default (divisor unset) keeps the per-sample means: 2 + 5 = 7.
    default = get_sum_of_sample_mean(total_lengths, response_lengths, loss_masks)
    assert default(x).item() == pytest.approx(7.0)


@pytest.mark.unit
def test_constant_divisor_not_applied_under_per_token_loss():
    """With ``calculate_per_token_loss`` the reducer must return the raw masked
    token sum — Megatron divides by the all-reduced token count itself, so
    applying the constant here would double-normalize."""
    total_lengths, response_lengths, loss_masks = _make_inputs([3])
    x = torch.tensor([1.0, 2.0, 3.0])

    reducer = get_sum_of_sample_mean(total_lengths, response_lengths, loss_masks, None, True, constant_divisor=4.0)
    assert reducer(x).item() == pytest.approx(6.0)


@pytest.mark.unit
@pytest.mark.parametrize("qkv_format", ["thd", "bshd"])
def test_cp_chunking_preserves_constant_divisor_total(monkeypatch, qkv_format):
    """Summing per-CP-rank reducer outputs reproduces the cp=1 value for both
    layouts: the constant divisor is identical on every CP rank, so the
    gradient sum-allreduce needs no denominator correction."""
    from megatron.core import mpu as _mpu

    total_lengths = [12, 12]
    response_lengths = [8, 8]
    max_seq_lens = [16, 16] if qkv_format == "bshd" else None
    loss_masks = [torch.ones(r, dtype=torch.float32) for r in response_lengths]
    divisor = 32.0
    x_full = [
        torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]),
        torch.tensor([10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0]),
    ]

    monkeypatch.setattr(_mpu, "get_context_parallel_world_size", lambda: 1)
    monkeypatch.setattr(_mpu, "get_context_parallel_rank", lambda: 0)
    reducer_cp1 = get_sum_of_sample_mean(
        total_lengths, response_lengths, loss_masks, None, False, qkv_format, max_seq_lens, constant_divisor=divisor
    )
    baseline = reducer_cp1(torch.cat(x_full)).item()
    assert baseline == pytest.approx(sum(x.sum().item() for x in x_full) / divisor)

    monkeypatch.setattr(_mpu, "get_context_parallel_world_size", lambda: 2)
    cp_total = 0.0
    for cp_rank in range(2):
        monkeypatch.setattr(_mpu, "get_context_parallel_rank", lambda r=cp_rank: r)
        x_chunks_per_sample = []
        for i, (tl, rl, x) in enumerate(zip(total_lengths, response_lengths, x_full, strict=True)):
            prompt_length = tl - rl
            max_seq_len = max_seq_lens[i] if max_seq_lens is not None else None
            _, _, _, tokens_offset = get_logits_and_tokens_offset_with_cp(tl, rl, qkv_format, max_seq_len)
            chunk_0 = x[tokens_offset[0][0] - prompt_length : tokens_offset[0][1] - prompt_length]
            chunk_1 = x[tokens_offset[1][0] - prompt_length : tokens_offset[1][1] - prompt_length]
            x_chunks_per_sample.append(torch.cat([chunk_0, chunk_1]))
        reducer_cp2 = get_sum_of_sample_mean(
            total_lengths,
            response_lengths,
            loss_masks,
            None,
            False,
            qkv_format,
            max_seq_lens,
            constant_divisor=divisor,
        )
        cp_total += reducer_cp2(torch.cat(x_chunks_per_sample)).item()

    assert cp_total == pytest.approx(baseline)


if __name__ == "__main__":
    raise SystemExit(pytest.main([__file__]))
20 changes: 20 additions & 0 deletions tests/test_megatron_argument_validation.py
@@ -252,6 +252,7 @@ def make_slime_validate_args(**overrides):
        normalize_advantages=False,
        use_rollout_logprobs=False,
        use_tis=False,
        pg_loss_divisor=None,
        get_mismatch_metrics=False,
        custom_tis_function_path=None,
        use_dynamic_batch_size=False,
@@ -310,6 +311,25 @@ def make_slime_validate_args(**overrides):
    return types.SimpleNamespace(**values)


@pytest.mark.unit
@pytest.mark.parametrize("bad_divisor", [0.0, -1.0, float("nan")])
def test_slime_validate_args_rejects_non_positive_pg_loss_divisor(monkeypatch, bad_divisor):
    module = load_slime_arguments_module(monkeypatch)

    with pytest.raises(ValueError, match="--pg-loss-divisor"):
        module.slime_validate_args(make_slime_validate_args(pg_loss_divisor=bad_divisor))


@pytest.mark.unit
def test_slime_validate_args_accepts_positive_pg_loss_divisor(monkeypatch):
    module = load_slime_arguments_module(monkeypatch)

    args = make_slime_validate_args(pg_loss_divisor=40960.0)
    module.slime_validate_args(args)

    assert args.pg_loss_divisor == 40960.0


@pytest.mark.unit
def test_slime_validate_args_preserves_zero_rollout_gpus_under_colocate(monkeypatch):
    module = load_slime_arguments_module(monkeypatch)