diff --git a/configs/debug/algorithms/README.md b/configs/debug/algorithms/README.md index a670af2247..72b7c3a8bd 100644 --- a/configs/debug/algorithms/README.md +++ b/configs/debug/algorithms/README.md @@ -13,6 +13,7 @@ Minimal end-to-end configs for the algorithms against bundled verifiers envs, us | `sft_distill_external.toml` | `sft` | PI inference (`openai/gpt-5-mini`) | external OAI endpoint; no local server | | `self_distill.toml` | `opsd` | none (`model = "policy"`) | SDFT against the live policy; demo from reverse-text's `answer` field | | `echo.toml` | `echo` | none | multi-turn `alphabet-sort`; CE on observation tokens | +| `rlcsd.toml` | `rlcsd` | none (`model = "policy"`) | contrastive self-distillation modulating GRPO; hints from sibling rollouts | | `mixed_grpo_opd.toml` | `grpo` + `opd` (per env) | local vLLM (`Qwen3-0.6B-Reverse-Text-RL`) | two envs, one run; heterogeneous batches (with/without `ref_logprobs`) | The policy inference server is auto-launched on GPU 0 at `http://localhost:8000/v1` with `gpu_memory_utilization=0.5`. The local frozen model (used by `opd*.toml`, `sft_distill.toml` / `sft_distill_lora.toml`, and `mixed_grpo_opd.toml`) is **not** auto-launched — start it manually on GPU 1. 
@@ -58,6 +59,9 @@ uv run rl @ configs/debug/algorithms/self_distill.toml # ECHO (no frozen model; multi-turn env) uv run rl @ configs/debug/algorithms/echo.toml +# RLCSD (no frozen model; teacher = live policy on sibling hints) +uv run rl @ configs/debug/algorithms/rlcsd.toml + # Mixed per-env algorithms: GRPO + OPD in one run (needs the frozen model on port 8001) uv run rl @ configs/debug/algorithms/mixed_grpo_opd.toml ``` diff --git a/configs/debug/algorithms/rlcsd.toml b/configs/debug/algorithms/rlcsd.toml new file mode 100644 index 0000000000..2fe255bf45 --- /dev/null +++ b/configs/debug/algorithms/rlcsd.toml @@ -0,0 +1,65 @@ +# RLCSD (arXiv:2606.11709) on reverse-text: GRPO anchored by the verifier, +# with a contrastive self-distillation signal modulating the advantage at +# high-signal tokens. The teacher is the live policy conditioned on correct / +# incorrect sibling rollouts — no extra server needed. +# uv run rl @ configs/debug/algorithms/rlcsd.toml + +max_steps = 20 +seq_len = 2048 + +[model] +name = "PrimeIntellect/Qwen3-0.6B-Reverse-Text-SFT" + +[wandb] +project = "algorithms-debug" +name = "debug-rlcsd" + +[orchestrator] +batch_size = 128 +group_size = 16 + +# The default template's continuation instruction is math-flavored +# (\boxed{}); reverse-text just wants the answer re-attempted. Reverse-text's +# reward is continuous (LCS) and exact reversals are rare at 0.6B, so the +# binary-verifier default threshold of 1.0 would leave every correct-hint +# pool empty — mostly-right reversals (>= 0.5) serve as positive hints. +[orchestrator.algo.advantage] +type = "rlcsd" +correct_threshold = 0.5 +# Negative hints must be clearly failed reversals (< 0.2), not borderline: +# the band in between serves as neither hint, so noise contrasts stop firing +# as groups tighten around the threshold. 
+min_contrast_gap = 0.3 +template = """{question} + +Here is a reference solution to this problem: +=== Reference Solution Begin === +{hint} +=== Reference Solution End === + +After reading the reference solution above, answer the original problem yourself.""" + +[orchestrator.renderer] +name = "qwen3" + +[orchestrator.train.sampling] +max_completion_tokens = 128 + +[[orchestrator.train.env]] +id = "reverse-text" + +[orchestrator.eval] +interval = 5 +num_examples = 128 + +[orchestrator.eval.sampling] +max_completion_tokens = 128 + +[[orchestrator.eval.env]] +id = "reverse-text" + +[trainer.optim] +lr = 1e-6 + +[inference] +gpu_memory_utilization = 0.5 diff --git a/docs/algorithms.md b/docs/algorithms.md index 2580b644c8..2e815f2f43 100644 --- a/docs/algorithms.md +++ b/docs/algorithms.md @@ -71,6 +71,7 @@ type = "grpo" # the default | `sft` | *(the teacher)* | `ce` on actions | Hard distillation: a frozen model generates rollouts, the policy trains with CE on its tokens. Needs a `teacher` (folds into `sampling.source`). | | `opsd` | policy | `ref_kl` on actions | SDFT ([arXiv:2601.19897](https://arxiv.org/abs/2601.19897)): the model is its own reference, conditioned on an expert demonstration. Defaults to the live policy (the paper's setting, no extra deployment); set an inline `model` to score under a frozen copy instead. | | `echo` | policy | `rl` on actions + weighted `ce` on observations | ECHO: standard GRPO plus a cross-entropy loss on env-provided tokens already present in the rollout, selected by message role (needs the renderer's role attribution). Defaults to tool-response bodies at `alpha = 0.1` (ECHO's λ); set `roles` to train other roles, each at its own weight. | +| `rlcsd` | policy | `rl` on actions | RLCSD ([arXiv:2606.11709](https://arxiv.org/abs/2606.11709)): GRPO anchored by the verifier, with a contrastive self-distillation signal — the teacher's logprobs under a correct sibling-rollout hint vs. 
under K incorrect sibling hints — modulating the advantage magnitude at high-signal tokens (sign-preserving). The identical hint template makes the privilege-induced style shift cancel in the subtraction, concentrating the signal on task-bearing tokens. Teacher defaults to the live policy. | | `reward` | policy | `rl` on actions | REINFORCE-style: advantage = raw reward, no group baseline. | | `custom` | policy | `rl` on actions | Your own advantage function (`import_path`), per-token advantages per rollout — see [Custom Advantage](#custom-advantage). | @@ -133,6 +134,19 @@ advantage = { type = "echo" } # shorthand: the env assembles its own algorithm At runtime, each env's resolved config builds two objects: a `Sampler` (`prime_rl.orchestrator.sampler`) from the `sampling` component — the pool rollouts are generated from, and the home of future sampling strategies like replay buffers or branching — and one of the named algorithm classes in `prime_rl.orchestrator.algo` (one module per algorithm: `algo/grpo.py`, `algo/opd.py`, …) from the `advantage` component. 
Algorithm dispatch is keyed on `advantage.type` — it names the algorithm, and each config class's defaults are its vetted parameterization: +<<<<<<< HEAD +| `advantage.type` | Class | `assign_advantages` (group time) | `query_references` (ship time) | +|---|---|---|---| +| `grpo` | `GRPOAlgorithm` | group-norm credit (optional length penalty) | — | +| `echo` | `EchoAlgorithm` | group-norm credit, plus weighted ce on observation tokens | — | +| `max_rl` | `MaxRLAlgorithm` | mean-normalized group credit | — | +| `opd` | `OPDAlgorithm` | — | own-context prefill under the teacher | +| `opsd` | `OPSDAlgorithm` | — | demo-conditioned prefill under the teacher | +| `rlcsd` | `RLCSDAlgorithm` | std-normalized group credit | contrastive hinted prefills → per-token advantage modulation | +| `sft` | `SFTDistillAlgorithm` | group-norm credit (feeds filters) | — | +| `reward` | `RewardAlgorithm` | raw reward | — | +| `custom` | `CustomAlgorithm` | your function | — | +======= | `advantage.type` | Class | hook(s) — stage | |---|---|---| | `grpo` | `GRPOAlgorithm` | `score_group`: group-norm credit (optional length penalty) | @@ -143,6 +157,7 @@ At runtime, each env's resolved config builds two objects: a `Sampler` (`prime_r | `sft` | `SFTDistillAlgorithm` | `score_group`: group-norm credit (feeds filters) | | `reward` | `RewardAlgorithm` | `score_rollout`: raw reward | | `custom` | `CustomAlgorithm` | `score_group`: your function | +>>>>>>> feat/algorithm-abstraction Each class owns its hooks outright — reading one top to bottom reads the algorithm, and everything on the class is an override point. The three hooks are one scope-and-timing ladder — each wider scope is unlocked by a later barrier, so the two axes coincide. 
Each is handed a `RolloutView` (a writable handle exposing only what is valid at its stage: `raw`, `samples`, `reward`, and `assign_advantages` — never not-yet-assigned credit or pipeline-internal lifecycle fields): @@ -150,7 +165,11 @@ Each class owns its hooks outright — reading one top to bottom reads the algor - `score_group(group)` — the cohort, **before filtering** (filters read the streams), synchronous: group-relative credit (GRPO/MaxRL baselines). `group` is a list of `RolloutView`. - `async score_batch(batch)` — the batch's survivors, **after filtering** (dropped rollouts never cost reference compute), async: the only stage with model access — query the algorithm's reference pool (e.g. `self.teacher_pool`, connected in its `setup()` override via `self.connect(...)` — the live policy pool when the reference is `"policy"`, a freshly connected client pool when frozen) and attach per-token results, or modulate advantages. +<<<<<<< HEAD +The two hooks are pinned by the filter barrier: everything the orchestrator computes locally runs before it, everything that queries a model runs after it. The pipeline drives them through two module-level phase functions it never looks inside: `finalize_group(algorithm, rollouts)` per group (credit + wire stamping; after this the records are frozen — groups die at stamping) and `finalize_batch(train_envs, rollouts)` per batch. Sample construction (interleaving) is pure pipeline — it records observation-token provenance as `obs_spans` for any algorithm that trains on env-provided tokens. +======= The pipeline drives the hooks through three module-level phase functions it never looks inside: `finalize_rollout(algorithm, rollout)` per arrival, `finalize_group(algorithm, rollouts)` per group (scoring + wire stamping; after this the records are frozen — groups die at stamping), and `finalize_batch(train_envs, rollouts)` per batch. 
Sample construction (interleaving) is pure pipeline — it records the `obs_spans` provenance for any algorithm that trains on env-provided tokens. +>>>>>>> feat/algorithm-abstraction Class-level declarations state what the algorithm needs: which loss component its action tokens feed (`action_loss_type`) and what it calls its reference model (`model_role`, e.g. `"teacher"`). Every class is constructed with its advantage config — the component it interprets; the bundle dissolves at construction — plus the two host-owned resources: the policy pool and the policy's renderer. Text → token ids always goes through the renderer, the same path the policy's own prompts take (`opsd` requires one, validated at config time). The pipeline only ever calls the phase functions — writing your own algorithm is subclassing `Algorithm` and overriding the hooks its signal needs. For pure credit assignment, no subclass is needed: `advantage.type = "custom"` imports a plain advantage function (see [Custom Advantage](#custom-advantage)); custom reference scoring means forking one of the named classes. Shared math (group normalization, prefill alignment) lives as plain functions in `prime_rl.orchestrator.algo.advantage`. @@ -283,6 +302,7 @@ The advantage strategy is the `advantage` component of the [algorithm](#the-algo | `reward` | `rl` | Advantage = raw reward, no baseline. | | `opd` | `ref_kl` | On-policy distillation: per-token reverse KL to a reference model (`model`, an inline frozen hosted model), evaluated in the trainer from shipped reference logprobs. No credit — rollouts keep `advantages = None` (advantage-based filters never fire) and ship no advantage stream; `group_size` only fans out sampling. | | `opsd` | `ref_kl` | SDFT: per-token reverse KL to a demo-conditioned reference. No credit — rollouts keep `advantages = None` (advantage-based filters never fire) and ship no advantage stream. 
| +| `rlcsd` | `rl` | Std-normalized group credit, modulated per token at ship time by the contrastive hinted-teacher signal (`λ·tanh(e_ctr/τ)`, masked at `δ`, sign-preserving clamp, two-path normalization via `η`). | | `sft` | `ce` | Cross-entropy on the sampled tokens. The loss ignores advantages, but group-relative credit is still assigned so reward-based filtering keeps working. | | `custom` | `rl` | Your function (below); per-token advantages per rollout. | @@ -305,17 +325,31 @@ type = "tokens" ### Custom Advantage +<<<<<<< HEAD +Advantages are computed **per group**. You write a function that takes one group's `TrainRollout`s — the same objects the algorithm hooks see — and returns per-token advantages: one list per rollout, aligned to its training samples' completion tokens (for multi-turn envs the merged completion, including interleaved observation tokens). There is no scalar advantage anywhere in the pipeline — uniform group credit goes through `broadcast(rollouts, values)`, which spreads one value per rollout over its completion tokens. The orchestrator handles groups of varying size automatically — partial-group training kicks in when some rollouts in a group errored. +======= Advantages are computed **per group**. You write a function that takes one group's `RolloutView`s — the same handles the `score_group` hook sees — and returns one value per rollout: a scalar (broadcast over that rollout's completion tokens) or a per-token list aligned to them (for multi-turn envs the merged completion, including interleaved observation tokens). There is no scalar advantage stored anywhere in the pipeline — the scalar is just a convenience the view broadcasts at write time. The orchestrator handles groups of varying size automatically — partial-group training kicks in when some rollouts in a group errored. 
+>>>>>>> feat/algorithm-abstraction ```python # my_module.py import statistics +<<<<<<< HEAD +from prime_rl.orchestrator.algo import broadcast + +def normalized_advantage(rollouts, eps: float = 1e-8) -> list[list[float]]: + rewards = [r.reward for r in rollouts] + mean = statistics.fmean(rewards) + std = statistics.pstdev(rewards) if len(rewards) > 1 else 0.0 + return broadcast(rollouts, [(r - mean) / (std + eps) for r in rewards]) +======= def normalized_advantage(group, eps: float = 1e-8) -> list[float]: rewards = [v.reward for v in group] mean = statistics.fmean(rewards) std = statistics.pstdev(rewards) if len(rewards) > 1 else 0.0 return [(r - mean) / (std + eps) for r in rewards] # one scalar per rollout +>>>>>>> feat/algorithm-abstraction ``` ```toml @@ -325,11 +359,27 @@ import_path = "my_module.normalized_advantage" kwargs = { eps = 1e-8 } ``` +<<<<<<< HEAD +Each `TrainRollout` carries `raw` (the env's untouched `verifiers.RolloutOutput`: turns, tool calls, custom metadata), `samples` (the merged token sequences), and the interleaving provenance — so you have the full interleaved rollout, not just the reward. Use this for anything reward-shaping-like that needs trajectory context. +======= Each `RolloutView` exposes `raw` (the env's untouched `verifiers.RolloutOutput`: turns, tool calls, custom metadata), `samples` (the merged token sequences), and `reward` — so you have the full interleaved rollout, not just the reward. Use this for anything reward-shaping-like that needs trajectory context. 
+>>>>>>> feat/algorithm-abstraction Genuinely per-token credit (process rewards, step-level credit assignment) returns shaped lists instead of scalars: ```python +<<<<<<< HEAD +def step_weighted_advantage(rollouts) -> list[list[float]]: + rewards = [r.reward for r in rollouts] + baseline = statistics.fmean(rewards) + return [ + [(reward - baseline) * w for w in my_token_weights(rollout.raw)] # one float per completion token + for reward, rollout in zip(rewards, rollouts) + ] +``` + +Each list must match the rollout's completion-token count exactly — validated loudly at group finalization. Advantage-based filters and metrics derive from the streams (the zero-advantage filter checks for all-zero streams; logged distributions use per-rollout means). Signals that depend on the live policy's weights (like OPD's reverse KL) cannot be precomputed here; those are reference-scoring algorithms, evaluated in the trainer. +======= def step_weighted_advantage(group) -> list[list[float]]: rewards = [v.reward for v in group] baseline = statistics.fmean(rewards) @@ -340,6 +390,7 @@ def step_weighted_advantage(group) -> list[list[float]]: ``` Each per-token list must match the rollout's completion-token count exactly — validated loudly when the view writes it. Advantage-based filters and metrics derive from the streams (the zero-advantage filter checks for all-zero streams; logged distributions use per-rollout means). Signals that depend on the live policy's weights (like OPD's reverse KL) cannot be precomputed here; those are reference-scoring algorithms, evaluated in the trainer. +>>>>>>> feat/algorithm-abstraction ### Reference Scoring @@ -356,6 +407,8 @@ demo_key = "demonstration" max_concurrent = 64 ``` +`rlcsd` also scores at ship time — `1 + num_negative_hints` hinted prefills per rollout, hints drawn from the rollout's own group siblings — but ships modulated per-token advantages instead of reference logprobs. 
+ Only batch survivors get scored — rollouts that are filtered or cancelled never cost reference compute. The time shows up as `time/scoring` in the step timing. ## Filters diff --git a/docs/training.md b/docs/training.md index 9c332b7eca..399b90cf90 100644 --- a/docs/training.md +++ b/docs/training.md @@ -93,6 +93,7 @@ The RL entrypoint supports several training algorithms, switched via `[orchestra | `sft` | Required, any OpenAI-compatible endpoint | Hard-distill: a frozen model generates rollouts, the policy trains on its tokens | | `opsd` | `"policy"` (the default, no deployment) or a vLLM endpoint serving a frozen copy | [SDFT](https://arxiv.org/abs/2601.19897): the model is its own reference conditioned on expert demonstrations | | `echo` | None | GRPO plus cross-entropy on env-observation tokens | +| `rlcsd` | `"policy"` (the default) or a vLLM endpoint | [RLCSD](https://arxiv.org/abs/2606.11709): GRPO with a contrastive self-distillation signal from sibling-rollout hints modulating per-token advantages | `reward` (raw-reward credit, no baseline) and `custom` (your own advantage function) complete the set — see [Algorithms § The Algorithms](algorithms.md#the-algorithms). diff --git a/packages/prime-rl-configs/src/prime_rl/configs/algorithm.py b/packages/prime-rl-configs/src/prime_rl/configs/algorithm.py index d7ff988d61..1b72a72005 100644 --- a/packages/prime-rl-configs/src/prime_rl/configs/algorithm.py +++ b/packages/prime-rl-configs/src/prime_rl/configs/algorithm.py @@ -261,6 +261,80 @@ class OPSDAdvantageConfig(BaseConfig): """Maximum concurrent prefill requests per batch.""" +class RLCSDAdvantageConfig(BaseConfig): + type: Literal["rlcsd"] = "rlcsd" + """RLCSD (arXiv:2606.11709): GRPO with a contrastive self-distillation + modulation. 
The teacher scores each rollout's tokens under a correct + sibling rollout as a hint and under ``num_negative_hints`` incorrect + sibling hints (identical template, so the privilege-induced style shift + cancels in the subtraction); the squashed contrast ``λ·tanh(e/τ)`` + modulates the group-relative advantage at tokens where it exceeds + ``delta``, with a sign-preserving clamp so the verifier keeps the update + direction. Ships per-token advantages on the ``rl`` loss component. + Groups without both correct and incorrect rollouts get no modulation + (uniform groups already die in the zero-advantage filter, matching the + paper's group-discard rule).""" + + action_loss_type: ClassVar[ActionLossType] = "rl" + group_relative: ClassVar[bool] = True + model_role: ClassVar[str] = "teacher" + + model: ModelReference = "policy" + """The teacher the hinted distributions are computed under. ``"policy"`` + (the default) approximates the paper's setting — there the teacher is a + snapshot of the student refreshed every 10 steps; the live policy + refreshes every weight update. 
Set an inline frozen hosted model to + contrast under a fixed teacher instead.""" + + num_negative_hints: int = Field(4, ge=1) + """K: incorrect sibling hints whose probabilities average into the + negative branch — marginalizing over error types stabilizes the + contrast.""" + + tau: float = Field(0.02, gt=0) + """Soft-threshold slope of the tanh squash on the raw contrast.""" + + lam: float = Field(0.5, gt=0) + """Scale of the modulation: ``r_t = lam · tanh(e_ctr / tau)`` ∈ (-lam, lam).""" + + delta: float = Field(0.02, ge=0) + """Modulation mask threshold: only tokens with ``|r_t| > delta`` get + their advantage modulated (~20-30% of tokens at the defaults).""" + + eta: float = Field(1.0, ge=0) + """Weight of the modulated path relative to the unmodulated path; both + paths are normalized independently per rollout so the modulated tokens + never dilute.""" + + correct_threshold: float = 1.0 + """Rollouts with ``reward >= correct_threshold`` form the correct hint + pool — the binary verifier generalized to continuous rewards.""" + + min_contrast_gap: float = Field(0.0, ge=0) + """Exclusion band below ``correct_threshold``: negative hints need + ``reward < correct_threshold - min_contrast_gap``, so borderline rollouts + never serve as wrong hints and near-threshold noise stops producing + contrast as the group tightens. ``0.0`` (the default) disables the band; + on binary rewards any value in (0, 1) is equivalent to it; >= 1 empties the negative pool.""" + + template: str = ( + "{question}\n\n" + "Here is a reference solution to this problem:\n" + "=== Reference Solution Begin ===\n{hint}\n=== Reference Solution End ===\n\n" + "After reading the reference solution above, make sure you understand the " + "reasoning behind each step. Please reason step by step, and put your final " + "answer within \\boxed{{}}." + ) + """Template for the hinted teacher context. Receives ``{question}`` (the + original user message text) and ``{hint}`` (a sibling rollout's full + completion text). 
Byte-for-byte identical for correct and incorrect + hints — that symmetry is what cancels the style component.""" + + max_concurrent: int = Field(32, ge=1) + """Maximum concurrent prefill requests per batch (each rollout costs + ``1 + num_negative_hints`` prefills).""" + + class SFTAdvantageConfig(BaseConfig): type: Literal["sft"] = "sft" """SFT distillation: cross-entropy on the sampled tokens. The ``ce`` @@ -299,6 +373,7 @@ class CustomAdvantageConfig(BaseConfig): | RewardAdvantageConfig | OPDAdvantageConfig | OPSDAdvantageConfig + | RLCSDAdvantageConfig | SFTAdvantageConfig | CustomAdvantageConfig, Field(discriminator="type"), diff --git a/packages/prime-rl-configs/src/prime_rl/configs/orchestrator.py b/packages/prime-rl-configs/src/prime_rl/configs/orchestrator.py index fbed7d8cd4..c15399d0a3 100644 --- a/packages/prime-rl-configs/src/prime_rl/configs/orchestrator.py +++ b/packages/prime-rl-configs/src/prime_rl/configs/orchestrator.py @@ -744,10 +744,10 @@ def validate_renderer_for_demo_scoring(self): if self.renderer is not None: return self for env in self.train.env: - if env.algo is not None and env.algo.advantage.type == "opsd": + if env.algo is not None and env.algo.advantage.type in ("opsd", "rlcsd"): raise ValueError( - f"env '{env.resolved_name}' uses opsd, which renders its demo-conditioned " - "scoring prefix client-side and requires orchestrator.renderer — remove " + f"env '{env.resolved_name}' uses {env.algo.advantage.type}, which renders its " + "hinted scoring prefixes client-side and requires orchestrator.renderer — remove " "'renderer = \"None\"'." ) if env.algo is not None and env.algo.advantage.type == "echo": diff --git a/skills/configs/SKILL.md b/skills/configs/SKILL.md index 0321033da9..f7e97b4c13 100644 --- a/skills/configs/SKILL.md +++ b/skills/configs/SKILL.md @@ -51,7 +51,7 @@ CLI: `--env.0.id reverse-text --env.1.id math-env`. 
**Discriminated unions** — set the `type` field to pick the variant (`[orchestrator.advantage] type = "max_rl"`). Omit `type` to keep the default variant. -**Algorithms** — `[orchestrator.algo.advantage] type = "grpo" | "max_rl" | "opd" | "opsd" | "sft" | "echo" | "reward" | "custom"` — the advantage type names the algorithm (credit assignment + loss routing, fused), and each type's class defaults are its vetted setting; any other key you set is your own assembly (e.g. `[orchestrator.algo.advantage.roles.user] alpha = 0.1` for echo — setting any echo role replaces the whole role table). There is no preset layer. Per-env override: `[[orchestrator.train.env]]` `advantage = { type = "echo" }` (the env assembles its own algorithm). prime-rl only hosts the trainable policy; frozen models are inline external endpoints on the algorithm — `[orchestrator.algo.teacher]` (alias for `model`) with `name` + `base_url` folds into the slot the type declares (`advantage.model` for opd/opsd, `sampling.source` for sft). `model = "policy"` points a component at the live policy (opsd's default). See `docs/algorithms.md`. +**Algorithms** — `[orchestrator.algo.advantage] type = "grpo" | "max_rl" | "opd" | "opsd" | "rlcsd" | "sft" | "echo" | "reward" | "custom"` — the advantage type names the algorithm (credit assignment + loss routing, fused), and each type's class defaults are its vetted setting; any other key you set is your own assembly (e.g. `[orchestrator.algo.advantage.roles.user] alpha = 0.1` for echo — setting any echo role replaces the whole role table). There is no preset layer. Per-env override: `[[orchestrator.train.env]]` `advantage = { type = "echo" }` (the env assembles its own algorithm). prime-rl only hosts the trainable policy; frozen models are inline external endpoints on the algorithm — `[orchestrator.algo.teacher]` (alias for `model`) with `name` + `base_url` folds into the slot the type declares (`advantage.model` for opd/opsd/rlcsd, `sampling.source` for sft). 
`model = "policy"` points a component at the live policy (the default for opsd and rlcsd). See `docs/algorithms.md`. **`BaseModel | None` fields** — bare flag enables defaults; nested override enables and sets: diff --git a/src/prime_rl/orchestrator/algo/__init__.py b/src/prime_rl/orchestrator/algo/__init__.py index 3506444bb5..2dd038ec62 100644 --- a/src/prime_rl/orchestrator/algo/__init__.py +++ b/src/prime_rl/orchestrator/algo/__init__.py @@ -6,7 +6,7 @@ :class:`~prime_rl.orchestrator.sampler.Sampler`): - one module per algorithm (``grpo``, ``echo``, ``max_rl``, ``opd``, - ``opsd``, ``sft``, ``reward``, ``custom``) — each named class owns its + ``opsd``, ``rlcsd``, ``sft``, ``reward``, ``custom``) — each named class owns its scoring hooks (``score_rollout`` / ``score_group`` / ``score_batch``) and declares what it needs (loss component, a "teacher", ...). One instance per env, built by :func:`build_algorithm`. Custom credit assignment plugs in @@ -49,6 +49,7 @@ from prime_rl.orchestrator.algo.opd import OPDAlgorithm from prime_rl.orchestrator.algo.opsd import OPSDAlgorithm from prime_rl.orchestrator.algo.reward import RewardAlgorithm +from prime_rl.orchestrator.algo.rlcsd import RLCSDAlgorithm from prime_rl.orchestrator.algo.routing import stamp_advantages, stamp_loss_routing from prime_rl.orchestrator.algo.sft import SFTDistillAlgorithm from prime_rl.orchestrator.types import RolloutView @@ -67,6 +68,7 @@ "max_rl": MaxRLAlgorithm, "opd": OPDAlgorithm, "opsd": OPSDAlgorithm, + "rlcsd": RLCSDAlgorithm, "sft": SFTDistillAlgorithm, "reward": RewardAlgorithm, "custom": CustomAlgorithm, @@ -90,6 +92,7 @@ def build_algorithm(config: AlgorithmConfig, policy_pool: InferencePool, rendere "MaxRLAlgorithm", "OPDAlgorithm", "OPSDAlgorithm", + "RLCSDAlgorithm", "RewardAlgorithm", "RolloutView", "SFTDistillAlgorithm", diff --git a/src/prime_rl/orchestrator/algo/rlcsd.py b/src/prime_rl/orchestrator/algo/rlcsd.py new file mode 100644 index 0000000000..c37da59a96 --- /dev/null +++ 
b/src/prime_rl/orchestrator/algo/rlcsd.py @@ -0,0 +1,243 @@ +from __future__ import annotations + +import asyncio +import math +import random +import statistics +from collections import defaultdict +from itertools import cycle +from typing import TYPE_CHECKING + +from prime_rl.configs.algorithm import AdvantageConfig, RLCSDAdvantageConfig +from prime_rl.orchestrator.algo.advantage import apply_advantage_fn +from prime_rl.orchestrator.algo.base import Algorithm +from prime_rl.orchestrator.utils import compute_prefill_logprobs +from prime_rl.utils.logger import get_logger + +if TYPE_CHECKING: + from renderers.base import Renderer + + from prime_rl.orchestrator.types import RolloutView + from prime_rl.utils.client import InferencePool + +_ADV_EPS = 1e-6 + + +def _std_norm_advantage_fn(group: list[RolloutView]) -> list[float]: + """Std-normalized group-relative advantage (the paper's Eq. 8): + ``(r - mean) / (std + eps)`` — one scalar per rollout (the view broadcasts + it over the completion tokens).""" + rewards = [v.reward for v in group] + mean = statistics.fmean(rewards) + std = statistics.pstdev(rewards) if len(rewards) > 1 else 0.0 + return [(r - mean) / (std + _ADV_EPS) for r in rewards] + + +def _hint_pools( + group: list[RolloutView], correct_threshold: float, min_contrast_gap: float +) -> tuple[list[RolloutView], list[RolloutView]]: + """Partition one group into hint pools. Positives are verified correct + (``reward >= correct_threshold``); negatives must be clearly wrong + (``reward < correct_threshold - min_contrast_gap``). 
Rollouts in the band + between are neither — they never serve as hints, so near-threshold noise + stops producing contrast as the group tightens.""" + correct = [r for r in group if (r.reward or 0.0) >= correct_threshold] + wrong = [r for r in group if (r.reward or 0.0) < correct_threshold - min_contrast_gap] + return correct, wrong + + +def _contrastive_signal(pos_logprobs: list[float], neg_logprobs: list[list[float]]) -> list[float]: + """Per-token contrast ``e_ctr`` (Eq. 7): the teacher's logprob under the + correct hint minus the log of the *mean probability* over the K incorrect + hints (log-mean-exp, not mean logprob).""" + k = len(neg_logprobs) + signal = [] + for t, pos in enumerate(pos_logprobs): + neg_t = [neg[t] for neg in neg_logprobs] + peak = max(neg_t) + log_mean_neg = peak + math.log(sum(math.exp(v - peak) for v in neg_t) / k) + signal.append(pos - log_mean_neg) + return signal + + +def _modulated_token_advantages( + signal: list[float], + advantages: list[float], + completion_mask: list[bool], + *, + lam: float, + tau: float, + delta: float, + eta: float, +) -> list[float] | None: + """Two-path token advantages (Eqs. 9-15): squash the contrast through + ``lam·tanh(·/tau)``, modulate the per-token base advantages (uniform + under group-norm assignment) at tokens above the ``delta`` mask with a + sign-preserving clamp, and fold the paper's independent path + normalization into the magnitudes — the clipped surrogate is positively + homogeneous in the advantage, so per-rollout weights ``L/|U|`` and + ``eta·L/|M|`` reproduce the two-path objective without touching the + trainer. 
class RLCSDAlgorithm(Algorithm):
    """RLCSD (arXiv:2606.11709): GRPO anchored by the verifier, with a
    contrastive self-distillation signal modulating the advantage magnitude
    at high-signal tokens.

    At group time, std-normalized group-relative credit (broadcast per
    token). At ship time, each surviving rollout's tokens are prefill-scored
    under the teacher conditioned on a correct sibling rollout and on K
    incorrect siblings (byte-identical hint template, so the
    privilege-induced style shift cancels in the subtraction); the squashed
    contrast modulates the base advantages with a sign-preserving clamp and
    overwrites the sample's advantage stream on the ``rl`` component.

    Rollouts whose group offers no contrast (no correct or no incorrect
    sibling) keep their plain group-norm stream, as do rollouts with an
    empty completion (there is nothing to modulate)."""

    action_loss_type = "rl"
    model_role = "teacher"

    def __init__(self, advantage: AdvantageConfig, policy_pool: InferencePool, renderer: Renderer | None):
        """Copy the RLCSD knobs off the advantage config.

        ``advantage`` must be an ``RLCSDAdvantageConfig`` and ``renderer``
        must be provided; both are asserted. The teacher pool is left
        unconnected until :meth:`setup` runs.
        """
        super().__init__(advantage, policy_pool, renderer)
        assert isinstance(advantage, RLCSDAdvantageConfig)
        assert renderer is not None, "rlcsd requires the renderer (validated at config time)"
        self.num_negative_hints = advantage.num_negative_hints
        self.tau = advantage.tau
        self.lam = advantage.lam
        self.delta = advantage.delta
        self.eta = advantage.eta
        self.correct_threshold = advantage.correct_threshold
        self.min_contrast_gap = advantage.min_contrast_gap
        self.template = advantage.template
        self.max_concurrent = advantage.max_concurrent
        self.teacher = advantage.model
        self.teacher_pool: InferencePool | None = None  # connected in setup()

    async def setup(self) -> None:
        """Connect the teacher inference pool used by :meth:`score_batch`."""
        self.teacher_pool = await self.connect(self.teacher)

    def score_group(self, group: list[RolloutView]) -> None:
        """Stamp the std-normalized group-relative advantage onto ``group``."""
        apply_advantage_fn(group, _std_norm_advantage_fn)

    async def score_batch(self, batch: list[RolloutView]) -> None:
        """Modulate the advantages of every rollout whose group offers contrast.

        The batch is partitioned by ``group_key``; within each group a
        rollout is scored only if at least one correct and one incorrect
        *sibling* exist. All scoring tasks run concurrently, with teacher
        requests bounded by ``max_concurrent``.
        """
        pool = self.teacher_pool
        assert pool is not None, "teacher pool not connected — Algorithm.setup() must run first"
        semaphore = asyncio.Semaphore(self.max_concurrent)
        clients = cycle(pool.train_clients)

        groups: dict[object, list[RolloutView]] = defaultdict(list)
        for view in batch:
            groups[view.group_key].append(view)

        tasks = []
        for group in groups.values():
            correct, wrong = _hint_pools(group, self.correct_threshold, self.min_contrast_gap)
            for view in group:
                # Hints come from siblings only — conditioning the teacher on
                # the rollout itself shifts it toward degenerate over-confidence.
                pos_pool = [s for s in correct if s is not view]
                neg_pool = [s for s in wrong if s is not view]
                if not pos_pool or not neg_pool:
                    continue  # no contrast available — the rollout keeps its plain group-norm stream
                tasks.append(self._score_one(view, pos_pool, neg_pool, semaphore, pool, next(clients)))
        # Contrast needs a correct AND an incorrect sibling; early in training
        # (or with a miscalibrated correct_threshold) most groups offer none
        # and the batch silently trains as plain GRPO — make that visible.
        get_logger().debug(f"rlcsd: contrast available for {len(tasks)}/{len(batch)} rollouts")
        if tasks:
            await asyncio.gather(*tasks)

    async def _score_one(
        self,
        view: RolloutView,
        pos_pool: list[RolloutView],
        neg_pool: list[RolloutView],
        semaphore: asyncio.Semaphore,
        pool: InferencePool,
        client,
    ) -> None:
        """Score one rollout against sibling hints and overwrite its advantages.

        Picks one correct hint and up to ``num_negative_hints`` incorrect
        hints at random, prefill-scores the rollout's completion under each
        hinted prompt, and writes the modulated per-token advantages back to
        the sample. The sample is left untouched when the completion is empty
        or when no completion token is trainable.
        """
        assert len(view.samples) == 1  # single-step trajectory → one sample
        sample = view.samples[0]
        completion_ids = list(sample.completion_ids)
        if not completion_ids:
            # Nothing to modulate. Bail out before scoring: with a zero-length
            # completion the tail slice below would be `full[-0:]`, i.e. the
            # WHOLE sequence, and differently-hinted prefixes would then yield
            # streams of different lengths.
            return
        prompt_len = len(sample.prompt_ids)

        pos_hint = random.choice(pos_pool)
        neg_hints = random.sample(neg_pool, min(self.num_negative_hints, len(neg_pool)))

        async def hinted_logprobs(hint: RolloutView) -> list[float]:
            prefix_ids = self._hinted_prefix_ids(view, hint)
            # Only the teacher request itself is rate-limited.
            async with semaphore:
                full = await compute_prefill_logprobs(client, pool.model_name, prefix_ids + completion_ids)
            # Keep the completion's logprobs; the hinted prefix is discarded.
            return full[-len(completion_ids) :]

        results = await asyncio.gather(hinted_logprobs(pos_hint), *(hinted_logprobs(h) for h in neg_hints))
        signal = _contrastive_signal(results[0], list(results[1:]))

        # Base credit = the group-norm stream already stamped onto the sample
        # at group time (prompt-padded); modulate its completion portion.
        base = sample.advantages[prompt_len:] if sample.advantages is not None else [0.0] * len(completion_ids)
        token_advantages = _modulated_token_advantages(
            signal,
            base,
            list(sample.completion_mask),
            lam=self.lam,
            tau=self.tau,
            delta=self.delta,
            eta=self.eta,
        )
        if token_advantages is not None:
            sample.advantages = [0.0] * prompt_len + token_advantages

    def _hint_text(self, rollout: RolloutView) -> str:
        """A sibling rollout's full completion text — the reference solution
        the teacher is conditioned on.

        String ``content`` fields of the completion messages are joined with
        newlines; non-string contents are dropped.

        Raises:
            ValueError: If the rollout's trajectory is not exactly one step.
        """
        trajectory = rollout.raw.get("trajectory") or []
        if len(trajectory) != 1:
            raise ValueError(
                f"rlcsd supports single-step trajectories only; "
                f"env '{rollout.env_name}' produced {len(trajectory)} steps."
            )
        parts = [m.get("content") for m in trajectory[0]["completion"] if isinstance(m.get("content"), str)]
        return "\n".join(parts)

    def _hinted_prefix_ids(self, rollout: RolloutView, hint: RolloutView) -> list[int]:
        """Rebuild the rollout's first-turn prompt with the hint woven into
        the last user message, rendered through the policy's renderer — the
        same messages → token ids path the policy's own prompts take.

        Raises:
            ValueError: If the trajectory is not exactly one step, the prompt
                has no user message, or the last user content is not a string.
        """
        trajectory = rollout.raw.get("trajectory") or []
        if len(trajectory) != 1:
            raise ValueError(
                f"rlcsd supports single-step trajectories only; "
                f"env '{rollout.env_name}' produced {len(trajectory)} steps."
            )
        # Shallow-copy each message so the rollout's stored prompt is not
        # mutated when the last user message is rewritten below.
        messages = [dict(m) for m in trajectory[0]["prompt"]]
        user_indices = [i for i, m in enumerate(messages) if m.get("role") == "user"]
        if not user_indices:
            raise ValueError(f"rlcsd found no user message to condition (env '{rollout.env_name}').")
        last_user = messages[user_indices[-1]]
        question = last_user.get("content")
        if not isinstance(question, str):
            raise ValueError("rlcsd supports text-only prompts (user content must be a string).")
        last_user["content"] = self.template.format(question=question, hint=self._hint_text(hint))
        assert self.renderer is not None
        return self.renderer.render_ids(messages, add_generation_prompt=True)
Deliberately does *not* expose pipeline-internal + lifecycle fields (``is_filtered``, ``filter_results``) or not-yet-assigned credit (``advantages``) — a hook can only touch what is valid at its stage.""" @@ -162,6 +163,14 @@ def env_name(self) -> str: def example_id(self) -> int | str: return self._rollout.example_id + @property + def group_key(self) -> uuid.UUID: + """The rollout's group identity — the safe key for partitioning a + batch's survivors back into their cohorts at the batch stage (the only + stage that sees more than one group). Use over ``example_id``, which + collides when an example is re-sampled.""" + return self._rollout.group_id + def assign_advantages(self, values: float | list[float]) -> None: """Write the rl advantage stream: a scalar broadcast over the rollout's completion tokens, or a per-token list aligned to them diff --git a/tests/unit/orchestrator/test_algorithms.py b/tests/unit/orchestrator/test_algorithms.py index afec1fa52a..c105427ba9 100644 --- a/tests/unit/orchestrator/test_algorithms.py +++ b/tests/unit/orchestrator/test_algorithms.py @@ -27,6 +27,7 @@ def _ref_kind(ref): ("sft", FROZEN, "frozen", None, "ce"), ("opsd", None, "policy", "policy", "ref_kl"), ("echo", None, "policy", None, "rl"), + ("rlcsd", None, "policy", "policy", "rl"), ], ) def test_type_defaults_are_the_vetted_algorithms(advantage_type, model, source, advantage_model, action_loss_type): @@ -303,6 +304,60 @@ def keep_last_only(output): _echo_algorithm(filter_fn=lambda output: [[True] * 4, [True] * 6]).score_rollout(RolloutView(rollout)) +def test_rlcsd_contrastive_signal_is_log_mean_exp(): + from prime_rl.orchestrator.algo.rlcsd import _contrastive_signal + + # One negative hint: plain logprob difference + assert _contrastive_signal([-1.0], [[-2.0]])[0] == pytest.approx(1.0) + # K negatives: the negative branch is the log of the MEAN probability, + # not the mean logprob + import math + + expected = -1.0 - math.log((math.exp(-1.0) + math.exp(-3.0)) / 2) + assert 
def test_rlcsd_hint_pools_gap_band():
    """``_hint_pools`` splits a group into correct / wrong hint pools and
    leaves the band just below the threshold out of both."""
    from types import SimpleNamespace

    from prime_rl.orchestrator.algo.rlcsd import _hint_pools

    def rewards_of(rollouts):
        return [rollout.reward for rollout in rollouts]

    graded = [SimpleNamespace(reward=value) for value in (1.0, 0.6, 0.45, 0.1)]

    # A zero gap is the plain threshold split: everything lands in a pool.
    positives, negatives = _hint_pools(graded, 0.5, 0.0)
    assert rewards_of(positives) == [1.0, 0.6]
    assert rewards_of(negatives) == [0.45, 0.1]

    # Rewards inside [threshold - gap, threshold) are used as neither hint.
    positives, negatives = _hint_pools(graded, 0.5, 0.3)
    assert rewards_of(positives) == [1.0, 0.6]
    assert rewards_of(negatives) == [0.1]

    # Binary rewards with threshold 1.0: any gap in [0, 1) keeps the 0.0
    # rollouts in the wrong pool, i.e. the paper's partition.
    zero_one = [SimpleNamespace(reward=value) for value in (1.0, 0.0, 1.0)]
    positives, negatives = _hint_pools(zero_one, 1.0, 0.5)
    assert len(positives) == 2
    assert len(negatives) == 1
+ out = _modulated_token_advantages([10.0, 0.0], [1.0, 1.0], [True, True], **knobs) + assert out[0] == pytest.approx((1.0 + 0.5) * 2.0) # modulated path, |M| = 1 + assert out[1] == pytest.approx(1.0 * 2.0) # plain path, |U| = 1 + + # Sign-preserving clamp: modulation never flips the verifier's direction + out = _modulated_token_advantages([10.0], [-0.2], [True], **knobs) + assert out == [0.0] + + # Below the mask threshold everything stays plain GRPO at unit weight + assert _modulated_token_advantages([0.0], [1.0], [True], **knobs) == [1.0] + + # No trainable tokens -> no per-token advantages + assert _modulated_token_advantages([10.0], [1.0], [False], **knobs) is None + + def test_interleave_records_obs_spans(): samples = interleave_rollout(_two_step_rollout(), env_name="test-env") assert samples is not None