Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 43 additions & 0 deletions configs/debug/algorithms/max_rl.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
max_steps = 20
seq_len = 2048

[model]
name = "PrimeIntellect/Qwen3-0.6B-Reverse-Text-SFT"

[wandb]
project = "reverse-text-debug"
name = "debug-max-rl"

[orchestrator]
batch_size = 128
group_size = 16

[orchestrator.algo]
name = "max_rl"

[orchestrator.renderer]
name = "qwen3"

[orchestrator.train.sampling]
max_completion_tokens = 128

[[orchestrator.train.env]]
id = "reverse-text"

[orchestrator.eval]
interval = 1
num_examples = 128

[orchestrator.eval.sampling]
max_completion_tokens = 128

[[orchestrator.eval.env]]
id = "reverse-text"

[trainer.optim]
lr = 3e-6

[ckpt]

[inference]
gpu_memory_utilization = 0.5
2 changes: 2 additions & 0 deletions docs/algorithms.md
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ name = "grpo" # the default
| Preset | Sampling | Advantage | Loss | What it is |
|---|---|---|---|---|
| `grpo` | policy | `group_norm` | `rl` on actions | Standard group-relative RL. |
| `max_rl` | policy | `max_rl` | `rl` on actions | MaxRL ([arXiv:2602.02710](https://arxiv.org/abs/2602.02710)): GRPO's centered reward normalized by the group **mean** instead of the standard deviation — the gradient is unbiased for the order-`group_size` truncation of the maximum-likelihood objective, upweighting hard examples like `1/p`. |
| `opd` | policy | `ref_kl` | `ref_kl` on actions | On-policy distillation ([Thinking Machines](https://thinkingmachines.ai/blog/on-policy-distillation/)): the policy samples, per-token reverse KL against a reference model as the gradient signal. Needs an inline `model`. |
| `sft_distill` | *(set via `model`)* | `supervised` | `ce` on actions | Hard distillation: a frozen model generates rollouts, the policy trains with CE on its tokens. Needs an inline `model`. |
| `self_distill` | policy | `demo_ref_kl` | `ref_kl` on actions | SDFT ([arXiv:2601.19897](https://arxiv.org/abs/2601.19897)): the model is its own reference, conditioned on an expert demonstration. Defaults to the live policy (the paper's setting, no extra deployment); set an inline `model` to score under a frozen copy instead. |
Expand Down Expand Up @@ -115,6 +116,7 @@ At runtime, each env's resolved config builds two objects: a `Sampler` (`prime_r
|---|---|---|---|
| `group_norm` | `GRPOAlgorithm` | group-norm credit (optional length penalty) | — |
| `echo` | `EchoAlgorithm` | group-norm credit, plus weighted ce on observation tokens | — |
| `max_rl` | `MaxRLAlgorithm` | mean-normalized group credit | — |
| `ref_kl` | `OPDAlgorithm` | — | own-context prefill under the teacher |
| `demo_ref_kl` | `OPSDAlgorithm` | — | demo-conditioned prefill under the teacher |
| `supervised` | `SFTDistillAlgorithm` | group-norm credit (feeds filters) | — |
Expand Down
20 changes: 19 additions & 1 deletion packages/prime-rl-configs/src/prime_rl/configs/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
from prime_rl.configs.shared import ClientConfig
from prime_rl.utils.config import BaseConfig

AlgorithmName: TypeAlias = Literal["grpo", "opd", "sft_distill", "self_distill", "echo"]
AlgorithmName: TypeAlias = Literal["grpo", "max_rl", "opd", "sft_distill", "self_distill", "echo"]


class FrozenModelConfig(ClientConfig):
Expand Down Expand Up @@ -138,6 +138,22 @@ class EchoAdvantageConfig(GroupNormAdvantageConfig):
token — tool and user feedback alike."""


class MaxRLAdvantageConfig(BaseConfig):
type: Literal["max_rl"] = "max_rl"
"""MaxRL (arXiv:2602.02710): scalar advantage = (reward − group mean) /
group mean, consumed by the ``rl`` loss component. Normalizing by the
mean instead of GRPO's standard deviation makes the policy gradient
unbiased for the order-``group_size`` truncation of the maximum-likelihood
objective: low-pass-rate examples get ~1/p weight, and ``group_size`` is
the truncation order interpolating REINFORCE (1) → exact maximum
likelihood (∞). Designed for non-negative (canonically binary) rewards;
a group with mean reward 0 carries zero advantages everywhere (the
zero-advantage filter drops it, matching the paper's K=0 convention)."""

action_loss_type: ClassVar[ActionLossType] = "rl"
group_relative: ClassVar[bool] = True


class RewardAdvantageConfig(BaseConfig):
type: Literal["reward"] = "reward"
"""Scalar advantage = raw reward, no group baseline. Consumed by the
Expand Down Expand Up @@ -242,6 +258,7 @@ class CustomAdvantageConfig(BaseConfig):
AdvantageConfig: TypeAlias = Annotated[
GroupNormAdvantageConfig
| EchoAdvantageConfig
| MaxRLAdvantageConfig
| RewardAdvantageConfig
| RefKLAdvantageConfig
| DemoRefKLAdvantageConfig
Expand All @@ -264,6 +281,7 @@ class CustomAdvantageConfig(BaseConfig):
# live policy (the SDFT setting).
_PRESETS: dict[AlgorithmName, dict[str, dict[str, Any]]] = {
"grpo": {},
"max_rl": {"advantage": {"type": "max_rl"}},
"opd": {"advantage": {"type": "ref_kl"}},
"sft_distill": {"sampling": {"source": None}, "advantage": {"type": "supervised"}},
"self_distill": {"advantage": {"type": "demo_ref_kl"}},
Expand Down
4 changes: 4 additions & 0 deletions src/prime_rl/orchestrator/algo/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,14 @@
AdvantageOutputs,
assign_advantages,
default_advantage_fn,
max_rl_advantage_fn,
)
from prime_rl.orchestrator.algo.algorithm import (
Algorithm,
CustomAlgorithm,
EchoAlgorithm,
GRPOAlgorithm,
MaxRLAlgorithm,
OPDAlgorithm,
OPSDAlgorithm,
RewardAlgorithm,
Expand All @@ -49,6 +51,7 @@
"CustomAlgorithm",
"EchoAlgorithm",
"GRPOAlgorithm",
"MaxRLAlgorithm",
"OPDAlgorithm",
"OPSDAlgorithm",
"RewardAlgorithm",
Expand All @@ -57,6 +60,7 @@
"build_algorithm",
"connect_frozen_pool",
"default_advantage_fn",
"max_rl_advantage_fn",
"score_train_batch",
"spread_token_advantages",
"stamp_loss_routing",
Expand Down
16 changes: 16 additions & 0 deletions src/prime_rl/orchestrator/algo/advantage.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,22 @@ def default_advantage_fn(
return AdvantageOutputs(advantages=(rewards - rewards.mean()).tolist())


def max_rl_advantage_fn(inputs: AdvantageInputs) -> AdvantageOutputs:
"""MaxRL advantage for a single group (arXiv:2602.02710): reward minus the
per-group mean, divided by that mean — equivalent to averaging score
functions over successful rollouts only, which makes the policy gradient
unbiased for the order-``group_size`` truncation of the maximum-likelihood
objective instead of pass@1. Assumes non-negative (canonically binary)
rewards; a group with mean reward <= 0 carries no signal and gets zero
advantages (the zero-advantage filter drops it, matching the paper's
no-success convention)."""
rewards = torch.tensor([r["reward"] for r in inputs.rollouts], dtype=torch.float32)
mean = rewards.mean()
if mean <= 0:
return AdvantageOutputs(advantages=torch.zeros_like(rewards).tolist())
return AdvantageOutputs(advantages=((rewards - mean) / mean).tolist())


def _efficiency_shaping(
rewards: Float[Tensor, "group_size"],
costs: Float[Tensor, "group_size"],
Expand Down
14 changes: 14 additions & 0 deletions src/prime_rl/orchestrator/algo/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
AdvantageOutputs,
assign_advantages,
default_advantage_fn,
max_rl_advantage_fn,
)
from prime_rl.orchestrator.algo.routing import spread_token_advantages, stamp_loss_routing
from prime_rl.orchestrator.utils import compute_prefill_logprobs
Expand Down Expand Up @@ -177,6 +178,18 @@ def __init__(self, config: AlgorithmConfig, policy_pool: InferencePool, renderer
self.observation_tokens = config.advantage.observations


class MaxRLAlgorithm(Algorithm):
"""Maximum-likelihood RL (arXiv:2602.02710): the GRPO pipeline with
mean-normalized advantages — ``(reward − group mean) / group mean``
instead of plain centering. The mean normalization upweights low-pass-rate
examples like the maximum-likelihood gradient does, and ``group_size``
doubles as the truncation order of the likelihood expansion the gradient
is unbiased for (REINFORCE at 1 → exact maximum likelihood as it grows)."""

def assign(self, rollouts: list[TrainRollout]) -> None:
assign_advantages(rollouts, max_rl_advantage_fn)


class OPDAlgorithm(Algorithm):
"""On-policy distillation. Needs a teacher: the frozen reference model the
per-token reverse KL is computed against.
Expand Down Expand Up @@ -331,6 +344,7 @@ def assign(self, rollouts: list[TrainRollout]) -> None:
ALGORITHM_CLASSES: dict[str, type[Algorithm]] = {
"group_norm": GRPOAlgorithm,
"echo": EchoAlgorithm,
"max_rl": MaxRLAlgorithm,
"ref_kl": OPDAlgorithm,
"demo_ref_kl": OPSDAlgorithm,
"supervised": SFTDistillAlgorithm,
Expand Down
12 changes: 12 additions & 0 deletions tests/unit/orchestrator/test_advantage.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
AdvantageOutputs,
assign_advantages,
default_advantage_fn,
max_rl_advantage_fn,
)
from prime_rl.orchestrator.types import TrainRollout

Expand Down Expand Up @@ -65,6 +66,17 @@ def test_default_advantage_fn_simple_mean():
assert sum(result.advantages) == pytest.approx(0.0, abs=1e-6)


def test_max_rl_advantage_fn_mean_normalized():
# mean 0.25: the success gets (1 - 0.25)/0.25 = 3, failures (0 - 0.25)/0.25 = -1
result = max_rl_advantage_fn(_make_group(rewards=[1.0, 0.0, 0.0, 0.0]))
assert result.advantages == pytest.approx([3.0, -1.0, -1.0, -1.0])

# no-success groups carry no signal (the paper's K=0 convention) ...
assert max_rl_advantage_fn(_make_group(rewards=[0.0, 0.0])).advantages == [0.0, 0.0]
# ... and all-success groups center to zero like GRPO
assert max_rl_advantage_fn(_make_group(rewards=[1.0, 1.0])).advantages == pytest.approx([0.0, 0.0])


def test_efficiency_mixed_group():
"""Mixed group: reward shaping preserves zero-mean, shorter correct gets higher advantage."""
inputs = _make_group(rewards=[1.0, 1.0, 0.0, 1.0], completion_lengths=[10, 30, 20, 20])
Expand Down
Loading