diff --git a/configs/debug/algorithms/max_rl.toml b/configs/debug/algorithms/max_rl.toml new file mode 100644 index 0000000000..e9d2e77a15 --- /dev/null +++ b/configs/debug/algorithms/max_rl.toml @@ -0,0 +1,43 @@ +max_steps = 20 +seq_len = 2048 + +[model] +name = "PrimeIntellect/Qwen3-0.6B-Reverse-Text-SFT" + +[wandb] +project = "reverse-text-debug" +name = "debug-max-rl" + +[orchestrator] +batch_size = 128 +group_size = 16 + +[orchestrator.algo] +name = "max_rl" + +[orchestrator.renderer] +name = "qwen3" + +[orchestrator.train.sampling] +max_completion_tokens = 128 + +[[orchestrator.train.env]] +id = "reverse-text" + +[orchestrator.eval] +interval = 1 +num_examples = 128 + +[orchestrator.eval.sampling] +max_completion_tokens = 128 + +[[orchestrator.eval.env]] +id = "reverse-text" + +[trainer.optim] +lr = 3e-6 + +[ckpt] + +[inference] +gpu_memory_utilization = 0.5 diff --git a/docs/algorithms.md b/docs/algorithms.md index 9526e6afe0..a3858c56ed 100644 --- a/docs/algorithms.md +++ b/docs/algorithms.md @@ -66,6 +66,7 @@ name = "grpo" # the default | Preset | Sampling | Advantage | Loss | What it is | |---|---|---|---|---| | `grpo` | policy | `group_norm` | `rl` on actions | Standard group-relative RL. | +| `max_rl` | policy | `max_rl` | `rl` on actions | MaxRL ([arXiv:2602.02710](https://arxiv.org/abs/2602.02710)): GRPO's centered reward normalized by the group **mean** instead of the standard deviation — the gradient is unbiased for the order-`group_size` truncation of the maximum-likelihood objective, upweighting hard examples like `1/p`. | | `opd` | policy | `ref_kl` | `ref_kl` on actions | On-policy distillation ([Thinking Machines](https://thinkingmachines.ai/blog/on-policy-distillation/)): the policy samples, per-token reverse KL against a reference model as the gradient signal. Needs an inline `model`. 
# Advantage config for the MaxRL algorithm. It is selected by the
# discriminator value ``type = "max_rl"``; the ``max_rl`` preset in
# ``_PRESETS`` resolves to exactly this advantage type.
class MaxRLAdvantageConfig(BaseConfig):
    # Discriminator used to pick this member out of the ``AdvantageConfig``
    # union. The string below it documents the advantage formula.
    type: Literal["max_rl"] = "max_rl"
    """MaxRL (arXiv:2602.02710): scalar advantage = (reward − group mean) /
    group mean, consumed by the ``rl`` loss component. Normalizing by the
    mean instead of GRPO's standard deviation makes the policy gradient
    unbiased for the order-``group_size`` truncation of the maximum-likelihood
    objective: low-pass-rate examples get ~1/p weight, and ``group_size`` is
    the truncation order interpolating REINFORCE (1) → exact maximum
    likelihood (∞). Designed for non-negative (canonically binary) rewards;
    a group with mean reward 0 carries zero advantages everywhere (the
    zero-advantage filter drops it, matching the paper's K=0 convention)."""
    # NOTE(review): the text above says "mean reward 0", but the visible
    # implementation (``max_rl_advantage_fn``) zeroes every group whose mean
    # is <= 0, so negative-mean groups are zeroed as well.

    # Class-level constant (not a config field): advantages from this config
    # feed the ``rl`` action loss.
    action_loss_type: ClassVar[ActionLossType] = "rl"
    # Class-level constant (not a config field). Presumably marks the
    # advantage as computed relative to the rollout group — TODO confirm
    # against the code that reads ``group_relative``.
    group_relative: ClassVar[bool] = True
def max_rl_advantage_fn(inputs: AdvantageInputs) -> AdvantageOutputs:
    """Compute mean-normalized advantages for one rollout group.

    Implements the MaxRL advantage (arXiv:2602.02710): each rollout's reward
    is centered on the group mean and then divided by that same mean, i.e.
    ``(reward - mean) / mean``.

    The formula is intended for non-negative (canonically binary) rewards.
    When the group mean is zero or negative the division is skipped and every
    rollout in the group receives an advantage of ``0.0``.

    Args:
        inputs: The group to score; each entry of ``inputs.rollouts`` must
            expose its scalar reward under the ``"reward"`` key.

    Returns:
        ``AdvantageOutputs`` whose ``advantages`` list holds one float per
        rollout, in the same order as ``inputs.rollouts``.
    """
    group_rewards = torch.tensor(
        [rollout["reward"] for rollout in inputs.rollouts],
        dtype=torch.float32,
    )
    group_mean = group_rewards.mean()

    # Keep the comparison as ``<= 0`` (rather than testing ``> 0`` first) so
    # that a NaN mean still falls through to the normalizing branch.
    if group_mean <= 0:
        scaled = torch.zeros_like(group_rewards)
    else:
        centered = group_rewards - group_mean
        scaled = centered / group_mean

    return AdvantageOutputs(advantages=scaled.tolist())
def test_max_rl_advantage_fn_mean_normalized():
    """MaxRL advantages are (reward - mean) / mean, with zeros for dead groups."""
    # One success out of four gives a group mean of 0.25, so the success
    # scores (1 - 0.25) / 0.25 = 3 and each failure (0 - 0.25) / 0.25 = -1.
    one_in_four = max_rl_advantage_fn(_make_group(rewards=[1.0, 0.0, 0.0, 0.0]))
    assert one_in_four.advantages == pytest.approx([3.0, -1.0, -1.0, -1.0])

    # A group without any success has mean 0 and must come back as exact zeros.
    all_failed = max_rl_advantage_fn(_make_group(rewards=[0.0, 0.0]))
    assert all_failed.advantages == [0.0, 0.0]

    # A group where every rollout succeeds centers to zero, same as GRPO.
    all_passed = max_rl_advantage_fn(_make_group(rewards=[1.0, 1.0]))
    assert all_passed.advantages == pytest.approx([0.0, 0.0])