Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions configs/debug/algorithms/echo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,13 @@ batch_size = 32
group_size = 4

# Assembled (no preset name): presets are atomic. alphabet-sort's feedback
# arrives as user messages, so we train all observation tokens, not just
# tool responses (the echo preset's default).
# arrives as user messages, so we train the user role instead of the echo
# preset's tool default.
[orchestrator.algo.advantage]
type = "echo"
observation_weight = 0.1
observations = "all"

[orchestrator.algo.advantage.roles.user]
alpha = 0.1

[[orchestrator.train.env]]
id = "alphabet-sort"
Expand Down
29 changes: 25 additions & 4 deletions docs/algorithms.md
Original file line number Diff line number Diff line change
Expand Up @@ -69,24 +69,45 @@ name = "grpo" # the default
| `opd` | policy | `ref_kl` | `ref_kl` on actions | On-policy distillation ([Thinking Machines](https://thinkingmachines.ai/blog/on-policy-distillation/)): the policy samples, per-token reverse KL against a reference model as the gradient signal. Needs an inline `model`. |
| `sft_distill` | *(set via `model`)* | `supervised` | `ce` on actions | Hard distillation: a frozen model generates rollouts, the policy trains with CE on its tokens. Needs an inline `model`. |
| `self_distill` | policy | `demo_ref_kl` | `ref_kl` on actions | SDFT ([arXiv:2601.19897](https://arxiv.org/abs/2601.19897)): the model is its own reference, conditioned on an expert demonstration. Defaults to the live policy (the paper's setting, no extra deployment); set an inline `model` to score under a frozen copy instead. |
| `echo` | policy | `echo` | `rl` on actions + weighted `ce` on observations | ECHO: standard GRPO plus a cross-entropy loss on tool-response tokens already present in the rollout (`observation_weight` is ECHO's λ, default 0.1; needs the renderer's role attribution). Assemble with `observations = "all"` to train every env-provided token instead. |
| `echo` | policy | `echo` | `rl` on actions + weighted `ce` on observations | ECHO: standard GRPO plus a cross-entropy loss on env-provided tokens already present in the rollout, selected by message role (needs the renderer's role attribution). The preset trains tool-response bodies at `alpha = 0.1` (ECHO's λ); assemble `roles` to train other roles, each at its own weight. |

### Customizing Components

Presets are **atomic**: a preset name fixes both components, and the only keys that may accompany it are the `model` / `teacher` shorthand (the distillation presets are incomplete without an endpoint by design). To customize anything else, drop `name` and assemble the components directly — presets are thin deltas, so assembly costs one `type` key:

```toml
# echo with a custom lambda — assembled, no preset name:
# echo on tool AND user feedback tokens, each at its own weight — assembled,
# no preset name. Setting any role replaces the whole table.
[orchestrator.algo.advantage]
type = "echo"
observation_weight = 0.25

[orchestrator.algo.advantage.roles.tool]
alpha = 0.25

[orchestrator.algo.advantage.roles.user]
alpha = 0.05

# or a custom advantage strategy:
# [orchestrator.algo.advantage]
# type = "custom"
# import_path = "my_module.normalized_advantage"
```

Echo also takes an optional user-supplied token filter that narrows the role selection per rollout — e.g. dropping warning lines from tool output, or tokens the sampler found unlikely:

```toml
[orchestrator.algo.advantage.filter]
import_path = "my_module.drop_warnings"
kwargs = { patterns = ["WARNING"] }
```

```python
# my_module.py — sees the raw rollout (message text, sampling logprobs);
# returns one keep-mask per trajectory step, spanning that step's
# prompt_ids + completion_ids. False = never echo-trained.
def drop_warnings(rollout, *, patterns: list[str]) -> list[list[bool]]: ...
```

A preset name with explicit `advantage` / `sampling` keys is a parse-time error: a modified preset is not the preset, so the config must state what it actually runs.

Component compatibility is validated at config time: frozen-model sampling cannot feed an advantage with the `rl` loss component (no policy sampling logprobs for importance ratios), `ref_kl` pointed at `"policy"` is rejected as degenerate (zero KL), and group-relative advantage with `group_size = 1` warns that every advantage collapses to zero.
Expand Down Expand Up @@ -252,7 +273,7 @@ The advantage strategy is the `advantage` component of the [algorithm](#the-algo
| Type | Component | Effect |
|---|---|---|
| `group_norm` | `rl` | Group-norm (GRPO): reward minus per-group baseline, optional length penalty. |
| `echo` | `rl` + `ce` | Group-norm on action tokens, plus weighted CE on env-observation tokens (`observation_weight`, ECHO's λ). |
| `echo` | `rl` + `ce` | Group-norm on action tokens, plus weighted CE on env-provided tokens selected by message role (each role's `alpha` is its ECHO λ), optionally narrowed by a user filter. |
| `reward` | `rl` | Advantage = raw reward, no baseline. |
| `ref_kl` | `ref_kl` | On-policy distillation: per-token reverse KL to a reference model (`model`, an inline frozen hosted model), evaluated in the trainer from shipped reference logprobs. No scalars — rollouts keep `advantage = None` (advantage-based filters never fire) and ship a neutral 0.0; `group_size` only fans out sampling. |
| `demo_ref_kl` | `ref_kl` | SDFT: per-token reverse KL to a demo-conditioned reference. No scalars — rollouts keep `advantage = None` (advantage-based filters never fire) and ship a neutral 0.0. |
Expand Down
66 changes: 53 additions & 13 deletions packages/prime-rl-configs/src/prime_rl/configs/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,23 +119,63 @@ class GroupNormAdvantageConfig(BaseConfig):
"""Correctness-gated length penalty. ``tokens`` shapes by weighted token cost; ``turns`` shapes by trajectory turn count; None disables shaping. In mixed groups, lower-cost correct rollouts get amplified advantage (up to 2x), higher-cost correct rollouts are unchanged, incorrect untouched. In all-correct groups, below-average-cost rollouts get advantage in [0, 1], others get 0."""


class EchoRoleConfig(BaseConfig):
"""Echo CE supervision for one message role."""

alpha: float = Field(0.1, gt=0)
"""Per-token ce weight for this role's env-provided tokens (ECHO's lambda)."""


class EchoRolesConfig(BaseConfig):
"""Which env-provided message roles train, each at its own weight.
Setting any role replaces the whole table — unset roles stay disabled."""

system: EchoRoleConfig | None = None
user: EchoRoleConfig | None = None
assistant: EchoRoleConfig | None = None
tool: EchoRoleConfig | None = None

@model_validator(mode="after")
def require_a_role(self):
if self.system is None and self.user is None and self.assistant is None and self.tool is None:
raise ValueError("echo needs at least one role enabled (system, user, assistant, or tool)")
return self


class EchoFilterConfig(BaseConfig):
"""User-supplied per-token filter narrowing the role-selected echo tokens.

The callable is imported at startup and invoked once per rollout as
``filter_fn(rollout, **kwargs) -> list[list[bool]]`` — one keep-mask per
trajectory step, each spanning that step's ``prompt_ids`` +
``completion_ids``. Tokens with ``False`` never receive echo weight; the
filter can only narrow the role selection, not widen it. The raw rollout
exposes message text and sampling logprobs, so content filters (e.g.
dropping tool-output warnings) and sampling-probability filters need no
extra framework surface."""

import_path: str
"""Import path to the filter callable (e.g. ``my_module.drop_warnings``)."""

kwargs: dict[str, Any] = Field(default_factory=dict)
"""Kwargs forwarded to the filter."""


class EchoAdvantageConfig(GroupNormAdvantageConfig):
type: Literal["echo"] = "echo" # type: ignore[assignment]
"""ECHO: group-relative advantage on action tokens (GRPO), plus weighted
CE on env-provided observation tokens of later turns (tool output,
terminal responses). The observation tokens feed the ``ce`` loss component
at ``observation_weight`` and stay outside the rl mask and its
denominator."""
CE on env-provided tokens of later turns (tool output, user feedback),
selected by message role via the renderer's per-token attribution
(requires ``orchestrator.renderer``; MITO rollouts carry no attribution).
Selected tokens feed the ``ce`` loss component at their role's ``alpha``
and stay outside the rl mask and its denominator."""

observation_weight: float = Field(0.1, gt=0)
"""Per-token ce weight for observation tokens (ECHO's lambda)."""
roles: EchoRolesConfig = EchoRolesConfig(tool=EchoRoleConfig())
"""The role table. The default — tool-response bodies at ``alpha = 0.1``
— is the vetted ECHO setting."""

observations: Literal["tool", "all"] = "tool"
"""Which env-provided tokens train. ``tool`` (the vetted default — the
ECHO setting) trains tool/terminal response bodies only, using the
renderer's per-token role attribution (requires ``orchestrator.renderer``;
MITO rollouts carry no attribution). ``all`` trains every env-provided
token — tool and user feedback alike."""
filter: EchoFilterConfig | None = None
"""Optional user-supplied filter narrowing the role-selected tokens."""


class RewardAdvantageConfig(BaseConfig):
Expand Down Expand Up @@ -287,7 +327,7 @@ class AlgorithmConfig(BaseConfig):
- ``opd`` — on-policy distillation: policy samples, ``ref_kl`` advantage against a reference model. Needs ``model``.
- ``sft_distill`` — a frozen model samples, the policy trains with CE on its tokens (``supervised``). Needs ``model``.
- ``self_distill`` — SDFT: policy samples, ``demo_ref_kl`` advantage against the live policy by default.
- ``echo`` — GRPO on action tokens + weighted CE on env-observation tokens.
- ``echo`` — GRPO on action tokens + weighted CE on tool-response observation tokens.
"""

model: ModelReference | None = Field(None, exclude=True, validation_alias=AliasChoices("model", "teacher"))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -767,11 +767,10 @@ def validate_renderer_for_demo_scoring(self):
"'renderer = \"None\"' (and note the renderer is forced off when no train env "
"samples from the policy)."
)
if env.algo is not None and getattr(env.algo.advantage, "observations", None) == "tool":
if env.algo is not None and env.algo.advantage.type == "echo":
raise ValueError(
f"env '{env.resolved_name}' trains on tool observation tokens, which needs the "
"renderer's per-token role attribution — set orchestrator.renderer, or assemble "
"echo with observations = 'all'."
f"env '{env.resolved_name}' trains env-provided tokens by message role (echo), "
"which needs the renderer's per-token attribution — set orchestrator.renderer."
)
return self

Expand Down
2 changes: 1 addition & 1 deletion skills/configs/SKILL.md
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ CLI: `--env.0.id reverse-text --env.1.id math-env`.

**Discriminated unions** — set the `type` field to pick the variant (`[orchestrator.advantage] type = "custom"`). Omit `type` to keep the default variant.

**Algorithm presets** — `[orchestrator.algo] name = "grpo" | "opd" | "sft_distill" | "self_distill" | "echo"` bundles sampling and the advantage (the per-token training signal: credit assignment + loss routing, fused — echo's `observation_weight` lives on `advantage`). The preset merges under your keys, so fields you don't set are kept; a different `advantage.type` replaces the strategy wholesale. Per-env override: `[[orchestrator.train.env]]` `algo = { name = "echo" }`. prime-rl only hosts the trainable policy; frozen models are inline external endpoints on the algorithm — `[orchestrator.algo.model]` (alias: `[orchestrator.algo.teacher]`) with `name` + `base_url` folds into the unresolved component reference (`advantage.model` for opd, `sampling.source` for sft_distill). `model = "policy"` points a component at the live policy (`self_distill`'s default). See `docs/algorithms.md`.
**Algorithm presets** — `[orchestrator.algo] name = "grpo" | "opd" | "sft_distill" | "self_distill" | "echo"` bundles sampling and the advantage (the per-token training signal: credit assignment + loss routing, fused). Presets are **atomic**: a name fixes both components, and only the `model` / `teacher` shorthand may accompany it — to customize anything else, drop `name` and assemble the components (`[orchestrator.algo.advantage] type = "echo"` + `[orchestrator.algo.advantage.roles.user] alpha = 0.1`; setting any echo role replaces the whole role table). Per-env override: `[[orchestrator.train.env]]` `algo = { name = "echo" }`. prime-rl only hosts the trainable policy; frozen models are inline external endpoints on the algorithm — `[orchestrator.algo.model]` (alias: `[orchestrator.algo.teacher]`) with `name` + `base_url` folds into the unresolved component reference (`advantage.model` for opd, `sampling.source` for sft_distill). `model = "policy"` points a component at the live policy (`self_distill`'s default). See `docs/algorithms.md`.

**`BaseModel | None` fields** — bare flag enables defaults; nested override enables and sets:

Expand Down
34 changes: 24 additions & 10 deletions src/prime_rl/orchestrator/algo/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@

import asyncio
from collections import defaultdict
from collections.abc import Callable
from functools import partial
from itertools import cycle
from typing import TYPE_CHECKING, ClassVar

Expand Down Expand Up @@ -108,8 +110,11 @@ def __init__(self, config: AlgorithmConfig, policy_pool: InferencePool, renderer
self.renderer = renderer
self.reference_pool: InferencePool | None = None # resolved in setup() when the algorithm declares a model
self.connected_pools: list[InferencePool] = [] # client pools connected in setup(); closed at shutdown
self.observation_weight: float | None = None # ce weight for env-provided tokens; None masks them out
self.observation_tokens: str | None = None # which env tokens interleave tags: "tool", "all", or None
# Echo selection: message role -> per-token ce weight for env-provided
# tokens (None masks them all out), plus an optional user filter
# narrowing the selection per rollout. Consumed by interleave_rollout.
self.echo_roles: dict[str, float] | None = None
self.echo_filter_fn: Callable[..., list[list[bool]]] | None = None

async def setup(self) -> None:
"""Connect a client pool to the algorithm's frozen reference model and
Expand Down Expand Up @@ -142,7 +147,7 @@ def finalize_group(self, rollouts: list[TrainRollout]) -> None:
sample.advantage = rollout.advantage if rollout.advantage is not None else 0.0
sample.reward = rollout.reward
sample.env_name = rollout.env_name
stamp_loss_routing(sample, self.action_loss_type, self.observation_weight)
stamp_loss_routing(sample, self.action_loss_type)

def _reference_pool(self) -> InferencePool:
pool = self.reference_pool
Expand All @@ -165,16 +170,25 @@ def assign(self, rollouts: list[TrainRollout]) -> None:


class EchoAlgorithm(GRPOAlgorithm):
"""GRPO on action tokens, plus weighted CE on env-provided observation
tokens (tool output, terminal responses). The observation tokens feed the
``ce`` loss component at ``observation_weight`` and stay outside the rl
mask and its denominator."""
"""GRPO on action tokens, plus weighted CE on env-provided tokens of
later turns (tool output, user feedback), selected by message role —
tool-response bodies at the vetted default. Selected tokens feed the
``ce`` loss component at their role's ``alpha`` and stay outside the rl
mask and its denominator. An optional user filter narrows the selection
per rollout (e.g. dropping tool-output warnings)."""

def __init__(self, config: AlgorithmConfig, policy_pool: InferencePool, renderer: Renderer | None):
super().__init__(config, policy_pool, renderer)
assert isinstance(config.advantage, EchoAdvantageConfig)
self.observation_weight = config.advantage.observation_weight
self.observation_tokens = config.advantage.observations
advantage = config.advantage
assert isinstance(advantage, EchoAdvantageConfig)
self.echo_roles = {
role: role_config.alpha
for role in ("system", "user", "assistant", "tool")
if (role_config := getattr(advantage.roles, role)) is not None
}
if advantage.filter is not None:
filter_fn = import_object(advantage.filter.import_path)
self.echo_filter_fn = partial(filter_fn, **advantage.filter.kwargs)


class OPDAlgorithm(Algorithm):
Expand Down
30 changes: 14 additions & 16 deletions src/prime_rl/orchestrator/algo/routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,25 +19,23 @@
from prime_rl.orchestrator.types import TrainRollout


def stamp_loss_routing(
sample: TrainingSample, action_loss_type: ActionLossType, observation_weight: float | None
) -> None:
def stamp_loss_routing(sample: TrainingSample, action_loss_type: ActionLossType) -> None:
"""Stamp the algorithm's loss routing onto one sample's component weight
streams.

Action tokens (the trainable completion tokens) feed the algorithm's
component: ``rl`` is the default (absent streams ship nothing), while
``ce``/``ref_kl`` stamp that component's weights over the action tokens
and zero the rl stream. When the algorithm trains on observations
(``observation_weight`` is set), env-provided tokens (tagged by
``interleave_rollout`` in ``completion_obs_mask``) get that ce weight —
they stay out of ``completion_mask``, so the ce component is the only one
that trains them. ``completion_obs_mask`` is orchestrator-internal and
cleared here so it never ships.
and zero the rl stream. When the algorithm trains on observations,
env-provided tokens carry their per-token ce weights (tagged by
``interleave_rollout`` in ``completion_obs_weights``) — they stay out of
``completion_mask``, so the ce component is the only one that trains
them. ``completion_obs_weights`` is orchestrator-internal and cleared
here so it never ships.
"""
obs_mask = sample.completion_obs_mask
sample.completion_obs_mask = None
train_obs = observation_weight is not None and obs_mask is not None and any(obs_mask)
obs_weights = sample.completion_obs_weights
sample.completion_obs_weights = None
train_obs = obs_weights is not None and any(obs_weights)
if action_loss_type == "rl" and not train_obs:
return

Expand All @@ -57,11 +55,11 @@ def stamp_loss_routing(
sample.ref_kl_weights = action_weights

if train_obs:
assert obs_mask is not None and observation_weight is not None
assert obs_weights is not None
ce_weights = sample.ce_weights if sample.ce_weights is not None else [0.0] * seq_len
for i, is_obs in enumerate(obs_mask):
if is_obs:
ce_weights[prompt_len + i] = observation_weight
for i, weight in enumerate(obs_weights):
if weight:
ce_weights[prompt_len + i] = weight
sample.ce_weights = ce_weights


Expand Down
4 changes: 3 additions & 1 deletion src/prime_rl/orchestrator/train_sink.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,12 +156,14 @@ async def process_rollout(self, rollout: TrainRollout) -> None:
needs_backfill = any(s["tokens"] is None for s in raw.get("trajectory") or [])
if needs_backfill:
await asyncio.to_thread(backfill_rollout_tokens, raw, self.tokenizer, renderer=self.renderer)
algorithm = self.train_envs.get(rollout.env_name).algorithm
samples = await asyncio.to_thread(
interleave_rollout,
raw,
mm_token_type_ids_mapping=self.mm_token_type_ids_mapping,
env_name=rollout.env_name,
observation_tokens=self.train_envs.get(rollout.env_name).algorithm.observation_tokens,
echo_roles=algorithm.echo_roles,
echo_filter_fn=algorithm.echo_filter_fn,
)
rollout.samples = samples or []
# Offload base64 image bytes to disk as soon as the rollout is
Expand Down
Loading
Loading