diff --git a/configs/debug/algorithms/echo.toml b/configs/debug/algorithms/echo.toml index 832a5a44d9..1d93e53dfa 100644 --- a/configs/debug/algorithms/echo.toml +++ b/configs/debug/algorithms/echo.toml @@ -17,12 +17,13 @@ batch_size = 32 group_size = 4 # Assembled (no preset name): presets are atomic. alphabet-sort's feedback -# arrives as user messages, so we train all observation tokens, not just -# tool responses (the echo preset's default). +# arrives as user messages, so we train the user role instead of the echo +# preset's tool default. [orchestrator.algo.advantage] type = "echo" -observation_weight = 0.1 -observations = "all" + +[orchestrator.algo.advantage.roles.user] +alpha = 0.1 [[orchestrator.train.env]] id = "alphabet-sort" diff --git a/docs/algorithms.md b/docs/algorithms.md index 9526e6afe0..8f2b8fdca5 100644 --- a/docs/algorithms.md +++ b/docs/algorithms.md @@ -69,17 +69,23 @@ name = "grpo" # the default | `opd` | policy | `ref_kl` | `ref_kl` on actions | On-policy distillation ([Thinking Machines](https://thinkingmachines.ai/blog/on-policy-distillation/)): the policy samples, per-token reverse KL against a reference model as the gradient signal. Needs an inline `model`. | | `sft_distill` | *(set via `model`)* | `supervised` | `ce` on actions | Hard distillation: a frozen model generates rollouts, the policy trains with CE on its tokens. Needs an inline `model`. | | `self_distill` | policy | `demo_ref_kl` | `ref_kl` on actions | SDFT ([arXiv:2601.19897](https://arxiv.org/abs/2601.19897)): the model is its own reference, conditioned on an expert demonstration. Defaults to the live policy (the paper's setting, no extra deployment); set an inline `model` to score under a frozen copy instead. | -| `echo` | policy | `echo` | `rl` on actions + weighted `ce` on observations | ECHO: standard GRPO plus a cross-entropy loss on tool-response tokens already present in the rollout (`observation_weight` is ECHO's λ, default 0.1; needs the renderer's role attribution). Assemble with `observations = "all"` to train every env-provided token instead. | +| `echo` | policy | `echo` | `rl` on actions + weighted `ce` on observations | ECHO: standard GRPO plus a cross-entropy loss on env-provided tokens already present in the rollout, selected by message role (needs the renderer's role attribution). The preset trains tool-response bodies at `alpha = 0.1` (ECHO's λ); assemble `roles` to train other roles, each at its own weight. | ### Customizing Components Presets are **atomic**: a preset name fixes both components, and the only keys that may accompany it are the `model` / `teacher` shorthand (the distillation presets are incomplete without an endpoint by design). To customize anything else, drop `name` and assemble the components directly — presets are thin deltas, so assembly costs one `type` key: ```toml -# echo with a custom lambda — assembled, no preset name: +# echo on tool AND user feedback tokens, each at its own weight — assembled, +# no preset name. Setting any role replaces the whole table. [orchestrator.algo.advantage] type = "echo" -observation_weight = 0.25 + +[orchestrator.algo.advantage.roles.tool] +alpha = 0.25 + +[orchestrator.algo.advantage.roles.user] +alpha = 0.05 # or a custom advantage strategy: # [orchestrator.algo.advantage] @@ -87,6 +93,21 @@ observation_weight = 0.25 # import_path = "my_module.normalized_advantage" ``` +Echo also takes an optional user-supplied token filter that narrows the role selection per rollout — e.g. dropping warning lines from tool output, or tokens the sampler found unlikely: + +```toml +[orchestrator.algo.advantage.filter] +import_path = "my_module.drop_warnings" +kwargs = { patterns = ["WARNING"] } +``` + +```python +# my_module.py — sees the raw rollout (message text, sampling logprobs); +# returns one keep-mask per trajectory step, spanning that step's +# prompt_ids + completion_ids. False = never echo-trained. +def drop_warnings(rollout, *, patterns: list[str]) -> list[list[bool]]: ... +``` + A preset name with explicit `advantage` / `sampling` keys is a parse-time error: a modified preset is not the preset, so the config must state what it actually runs. Component compatibility is validated at config time: frozen-model sampling cannot feed an advantage with the `rl` loss component (no policy sampling logprobs for importance ratios), `ref_kl` pointed at `"policy"` is rejected as degenerate (zero KL), and group-relative advantage with `group_size = 1` warns that every advantage collapses to zero. @@ -252,7 +273,7 @@ The advantage strategy is the `advantage` component of the [algorithm](#the-algo | Type | Component | Effect | |---|---|---| | `group_norm` | `rl` | Group-norm (GRPO): reward minus per-group baseline, optional length penalty. | -| `echo` | `rl` + `ce` | Group-norm on action tokens, plus weighted CE on env-observation tokens (`observation_weight`, ECHO's λ). | +| `echo` | `rl` + `ce` | Group-norm on action tokens, plus weighted CE on env-provided tokens selected by message role (each role's `alpha` is its ECHO λ), optionally narrowed by a user filter. | | `reward` | `rl` | Advantage = raw reward, no baseline. | | `ref_kl` | `ref_kl` | On-policy distillation: per-token reverse KL to a reference model (`model`, an inline frozen hosted model), evaluated in the trainer from shipped reference logprobs. No scalars — rollouts keep `advantage = None` (advantage-based filters never fire) and ship a neutral 0.0; `group_size` only fans out sampling. | | `demo_ref_kl` | `ref_kl` | SDFT: per-token reverse KL to a demo-conditioned reference. No scalars — rollouts keep `advantage = None` (advantage-based filters never fire) and ship a neutral 0.0. | diff --git a/packages/prime-rl-configs/src/prime_rl/configs/algorithm.py b/packages/prime-rl-configs/src/prime_rl/configs/algorithm.py index e475a92682..bd0cee727a 100644 --- a/packages/prime-rl-configs/src/prime_rl/configs/algorithm.py +++ b/packages/prime-rl-configs/src/prime_rl/configs/algorithm.py @@ -119,23 +119,63 @@ class GroupNormAdvantageConfig(BaseConfig): """Correctness-gated length penalty. ``tokens`` shapes by weighted token cost; ``turns`` shapes by trajectory turn count; None disables shaping. In mixed groups, lower-cost correct rollouts get amplified advantage (up to 2x), higher-cost correct rollouts are unchanged, incorrect untouched. In all-correct groups, below-average-cost rollouts get advantage in [0, 1], others get 0.""" +class EchoRoleConfig(BaseConfig): + """Echo CE supervision for one message role.""" + + alpha: float = Field(0.1, gt=0) + """Per-token ce weight for this role's env-provided tokens (ECHO's lambda).""" + + +class EchoRolesConfig(BaseConfig): + """Which env-provided message roles train, each at its own weight. + Setting any role replaces the whole table — unset roles stay disabled.""" + + system: EchoRoleConfig | None = None + user: EchoRoleConfig | None = None + assistant: EchoRoleConfig | None = None + tool: EchoRoleConfig | None = None + + @model_validator(mode="after") + def require_a_role(self): + if self.system is None and self.user is None and self.assistant is None and self.tool is None: + raise ValueError("echo needs at least one role enabled (system, user, assistant, or tool)") + return self + + +class EchoFilterConfig(BaseConfig): + """User-supplied per-token filter narrowing the role-selected echo tokens. + + The callable is imported at startup and invoked once per rollout as + ``filter_fn(rollout, **kwargs) -> list[list[bool]]`` — one keep-mask per + trajectory step, each spanning that step's ``prompt_ids`` + + ``completion_ids``. Tokens with ``False`` never receive echo weight; the + filter can only narrow the role selection, not widen it. The raw rollout + exposes message text and sampling logprobs, so content filters (e.g. + dropping tool-output warnings) and sampling-probability filters need no + extra framework surface.""" + + import_path: str + """Import path to the filter callable (e.g. ``my_module.drop_warnings``).""" + + kwargs: dict[str, Any] = Field(default_factory=dict) + """Kwargs forwarded to the filter.""" + + class EchoAdvantageConfig(GroupNormAdvantageConfig): type: Literal["echo"] = "echo" # type: ignore[assignment] """ECHO: group-relative advantage on action tokens (GRPO), plus weighted - CE on env-provided observation tokens of later turns (tool output, - terminal responses). The observation tokens feed the ``ce`` loss component - at ``observation_weight`` and stay outside the rl mask and its - denominator.""" + CE on env-provided tokens of later turns (tool output, user feedback), + selected by message role via the renderer's per-token attribution + (requires ``orchestrator.renderer``; MITO rollouts carry no attribution). + Selected tokens feed the ``ce`` loss component at their role's ``alpha`` + and stay outside the rl mask and its denominator.""" - observation_weight: float = Field(0.1, gt=0) - """Per-token ce weight for observation tokens (ECHO's lambda).""" + roles: EchoRolesConfig = EchoRolesConfig(tool=EchoRoleConfig()) + """The role table. The default — tool-response bodies at ``alpha = 0.1`` + — is the vetted ECHO setting.""" - observations: Literal["tool", "all"] = "tool" - """Which env-provided tokens train. ``tool`` (the vetted default — the - ECHO setting) trains tool/terminal response bodies only, using the - renderer's per-token role attribution (requires ``orchestrator.renderer``; - MITO rollouts carry no attribution). ``all`` trains every env-provided - token — tool and user feedback alike.""" + filter: EchoFilterConfig | None = None + """Optional user-supplied filter narrowing the role-selected tokens.""" class RewardAdvantageConfig(BaseConfig): @@ -287,7 +327,7 @@ class AlgorithmConfig(BaseConfig): - ``opd`` — on-policy distillation: policy samples, ``ref_kl`` advantage against a reference model. Needs ``model``. - ``sft_distill`` — a frozen model samples, the policy trains with CE on its tokens (``supervised``). Needs ``model``. - ``self_distill`` — SDFT: policy samples, ``demo_ref_kl`` advantage against the live policy by default. - - ``echo`` — GRPO on action tokens + weighted CE on env-observation tokens. + - ``echo`` — GRPO on action tokens + weighted CE on tool-response observation tokens. """ model: ModelReference | None = Field(None, exclude=True, validation_alias=AliasChoices("model", "teacher")) diff --git a/packages/prime-rl-configs/src/prime_rl/configs/orchestrator.py b/packages/prime-rl-configs/src/prime_rl/configs/orchestrator.py index f9bf8275b5..cc9e4eff30 100644 --- a/packages/prime-rl-configs/src/prime_rl/configs/orchestrator.py +++ b/packages/prime-rl-configs/src/prime_rl/configs/orchestrator.py @@ -767,11 +767,10 @@ def validate_renderer_for_demo_scoring(self): "'renderer = \"None\"' (and note the renderer is forced off when no train env " "samples from the policy)." ) - if env.algo is not None and getattr(env.algo.advantage, "observations", None) == "tool": + if env.algo is not None and env.algo.advantage.type == "echo": raise ValueError( - f"env '{env.resolved_name}' trains on tool observation tokens, which needs the " - "renderer's per-token role attribution — set orchestrator.renderer, or assemble " - "echo with observations = 'all'." + f"env '{env.resolved_name}' trains env-provided tokens by message role (echo), " + "which needs the renderer's per-token attribution — set orchestrator.renderer." ) return self diff --git a/skills/configs/SKILL.md b/skills/configs/SKILL.md index bd63f7a15a..a3665ebcfb 100644 --- a/skills/configs/SKILL.md +++ b/skills/configs/SKILL.md @@ -51,7 +51,7 @@ CLI: `--env.0.id reverse-text --env.1.id math-env`. **Discriminated unions** — set the `type` field to pick the variant (`[orchestrator.advantage] type = "custom"`). Omit `type` to keep the default variant. -**Algorithm presets** — `[orchestrator.algo] name = "grpo" | "opd" | "sft_distill" | "self_distill" | "echo"` bundles sampling and the advantage (the per-token training signal: credit assignment + loss routing, fused — echo's `observation_weight` lives on `advantage`). The preset merges under your keys, so fields you don't set are kept; a different `advantage.type` replaces the strategy wholesale. Per-env override: `[[orchestrator.train.env]]` `algo = { name = "echo" }`. prime-rl only hosts the trainable policy; frozen models are inline external endpoints on the algorithm — `[orchestrator.algo.model]` (alias: `[orchestrator.algo.teacher]`) with `name` + `base_url` folds into the unresolved component reference (`advantage.model` for opd, `sampling.source` for sft_distill). `model = "policy"` points a component at the live policy (`self_distill`'s default). See `docs/algorithms.md`. +**Algorithm presets** — `[orchestrator.algo] name = "grpo" | "opd" | "sft_distill" | "self_distill" | "echo"` bundles sampling and the advantage (the per-token training signal: credit assignment + loss routing, fused). Presets are **atomic**: a name fixes both components, and only the `model` / `teacher` shorthand may accompany it — to customize anything else, drop `name` and assemble the components (`[orchestrator.algo.advantage] type = "echo"` + `[orchestrator.algo.advantage.roles.user] alpha = 0.1`; setting any echo role replaces the whole role table). Per-env override: `[[orchestrator.train.env]]` `algo = { name = "echo" }`. prime-rl only hosts the trainable policy; frozen models are inline external endpoints on the algorithm — `[orchestrator.algo.model]` (alias: `[orchestrator.algo.teacher]`) with `name` + `base_url` folds into the unresolved component reference (`advantage.model` for opd, `sampling.source` for sft_distill). `model = "policy"` points a component at the live policy (`self_distill`'s default). See `docs/algorithms.md`. **`BaseModel | None` fields** — bare flag enables defaults; nested override enables and sets: diff --git a/src/prime_rl/orchestrator/algo/algorithm.py b/src/prime_rl/orchestrator/algo/algorithm.py index 7236782f8a..f88779ea46 100644 --- a/src/prime_rl/orchestrator/algo/algorithm.py +++ b/src/prime_rl/orchestrator/algo/algorithm.py @@ -28,6 +28,8 @@ import asyncio from collections import defaultdict +from collections.abc import Callable +from functools import partial from itertools import cycle from typing import TYPE_CHECKING, ClassVar @@ -108,8 +110,11 @@ def __init__(self, config: AlgorithmConfig, policy_pool: InferencePool, renderer self.renderer = renderer self.reference_pool: InferencePool | None = None # resolved in setup() when the algorithm declares a model self.connected_pools: list[InferencePool] = [] # client pools connected in setup(); closed at shutdown - self.observation_weight: float | None = None # ce weight for env-provided tokens; None masks them out - self.observation_tokens: str | None = None # which env tokens interleave tags: "tool", "all", or None + # Echo selection: message role -> per-token ce weight for env-provided + # tokens (None masks them all out), plus an optional user filter + # narrowing the selection per rollout. Consumed by interleave_rollout. + self.echo_roles: dict[str, float] | None = None + self.echo_filter_fn: Callable[..., list[list[bool]]] | None = None async def setup(self) -> None: """Connect a client pool to the algorithm's frozen reference model and @@ -142,7 +147,7 @@ def finalize_group(self, rollouts: list[TrainRollout]) -> None: sample.advantage = rollout.advantage if rollout.advantage is not None else 0.0 sample.reward = rollout.reward sample.env_name = rollout.env_name - stamp_loss_routing(sample, self.action_loss_type, self.observation_weight) + stamp_loss_routing(sample, self.action_loss_type) def _reference_pool(self) -> InferencePool: pool = self.reference_pool @@ -165,16 +170,25 @@ def assign(self, rollouts: list[TrainRollout]) -> None: class EchoAlgorithm(GRPOAlgorithm): - """GRPO on action tokens, plus weighted CE on env-provided observation - tokens (tool output, terminal responses). The observation tokens feed the - ``ce`` loss component at ``observation_weight`` and stay outside the rl - mask and its denominator.""" + """GRPO on action tokens, plus weighted CE on env-provided tokens of + later turns (tool output, user feedback), selected by message role — + tool-response bodies at the vetted default. Selected tokens feed the + ``ce`` loss component at their role's ``alpha`` and stay outside the rl + mask and its denominator. An optional user filter narrows the selection + per rollout (e.g. dropping tool-output warnings).""" def __init__(self, config: AlgorithmConfig, policy_pool: InferencePool, renderer: Renderer | None): super().__init__(config, policy_pool, renderer) - assert isinstance(config.advantage, EchoAdvantageConfig) - self.observation_weight = config.advantage.observation_weight - self.observation_tokens = config.advantage.observations + advantage = config.advantage + assert isinstance(advantage, EchoAdvantageConfig) + self.echo_roles = { + role: role_config.alpha + for role in ("system", "user", "assistant", "tool") + if (role_config := getattr(advantage.roles, role)) is not None + } + if advantage.filter is not None: + filter_fn = import_object(advantage.filter.import_path) + self.echo_filter_fn = partial(filter_fn, **advantage.filter.kwargs) class OPDAlgorithm(Algorithm): diff --git a/src/prime_rl/orchestrator/algo/routing.py b/src/prime_rl/orchestrator/algo/routing.py index cdeea94de5..d1f4357cb4 100644 --- a/src/prime_rl/orchestrator/algo/routing.py +++ b/src/prime_rl/orchestrator/algo/routing.py @@ -19,25 +19,23 @@ from prime_rl.orchestrator.types import TrainRollout -def stamp_loss_routing( - sample: TrainingSample, action_loss_type: ActionLossType, observation_weight: float | None -) -> None: +def stamp_loss_routing(sample: TrainingSample, action_loss_type: ActionLossType) -> None: """Stamp the algorithm's loss routing onto one sample's component weight streams. Action tokens (the trainable completion tokens) feed the algorithm's component: ``rl`` is the default (absent streams ship nothing), while ``ce``/``ref_kl`` stamp that component's weights over the action tokens - and zero the rl stream. When the algorithm trains on observations - (``observation_weight`` is set), env-provided tokens (tagged by - ``interleave_rollout`` in ``completion_obs_mask``) get that ce weight — - they stay out of ``completion_mask``, so the ce component is the only one - that trains them. ``completion_obs_mask`` is orchestrator-internal and - cleared here so it never ships. + and zero the rl stream. When the algorithm trains on observations, + env-provided tokens carry their per-token ce weights (tagged by + ``interleave_rollout`` in ``completion_obs_weights``) — they stay out of + ``completion_mask``, so the ce component is the only one that trains + them. ``completion_obs_weights`` is orchestrator-internal and cleared + here so it never ships. """ - obs_mask = sample.completion_obs_mask - sample.completion_obs_mask = None - train_obs = observation_weight is not None and obs_mask is not None and any(obs_mask) + obs_weights = sample.completion_obs_weights + sample.completion_obs_weights = None + train_obs = obs_weights is not None and any(obs_weights) if action_loss_type == "rl" and not train_obs: return @@ -57,11 +55,11 @@ def stamp_loss_routing( sample.ref_kl_weights = action_weights if train_obs: - assert obs_mask is not None and observation_weight is not None + assert obs_weights is not None ce_weights = sample.ce_weights if sample.ce_weights is not None else [0.0] * seq_len - for i, is_obs in enumerate(obs_mask): - if is_obs: - ce_weights[prompt_len + i] = observation_weight + for i, weight in enumerate(obs_weights): + if weight: + ce_weights[prompt_len + i] = weight sample.ce_weights = ce_weights diff --git a/src/prime_rl/orchestrator/train_sink.py b/src/prime_rl/orchestrator/train_sink.py index 7746400c75..62e94c41f1 100644 --- a/src/prime_rl/orchestrator/train_sink.py +++ b/src/prime_rl/orchestrator/train_sink.py @@ -156,12 +156,14 @@ async def process_rollout(self, rollout: TrainRollout) -> None: needs_backfill = any(s["tokens"] is None for s in raw.get("trajectory") or []) if needs_backfill: await asyncio.to_thread(backfill_rollout_tokens, raw, self.tokenizer, renderer=self.renderer) + algorithm = self.train_envs.get(rollout.env_name).algorithm samples = await asyncio.to_thread( interleave_rollout, raw, mm_token_type_ids_mapping=self.mm_token_type_ids_mapping, env_name=rollout.env_name, - observation_tokens=self.train_envs.get(rollout.env_name).algorithm.observation_tokens, + echo_roles=algorithm.echo_roles, + echo_filter_fn=algorithm.echo_filter_fn, ) rollout.samples = samples or [] # Offload base64 image bytes to disk as soon as the rollout is diff --git a/src/prime_rl/orchestrator/trajectories.py b/src/prime_rl/orchestrator/trajectories.py index e7228a9aba..8cc92959e2 100644 --- a/src/prime_rl/orchestrator/trajectories.py +++ b/src/prime_rl/orchestrator/trajectories.py @@ -1,5 +1,6 @@ import base64 import hashlib +from collections.abc import Callable from pathlib import Path from typing import Any @@ -201,22 +202,18 @@ def backfill_rollout_tokens( return True -def _observation_span_mask(tokens: dict[str, Any], prefix_len: int, mode: str | None) -> list[bool]: - """Observation tags for one later-turn prompt extension +def _observation_span_weights(tokens: dict[str, Any], prefix_len: int, role_weights: dict[str, float]) -> list[float]: + """Per-token echo weights for one later-turn prompt extension (``prompt_ids[prefix_len:]``). - ``"all"`` tags the whole span. ``"tool"`` tags tool-message tokens only, - via the renderer's per-token attribution — response bodies when the - renderer provides ``is_content``, whole tool messages otherwise.""" - span = range(prefix_len, len(tokens["prompt_ids"])) - if mode == "all": - return [True] * len(span) - assert mode == "tool", f"unknown observation_tokens mode: {mode!r}" + Each token gets its message role's weight (0.0 for unselected roles), via + the renderer's per-token attribution — message content bodies when the + renderer provides ``is_content``, whole messages otherwise.""" attribution = tokens.get("prompt_attribution") if attribution is None: raise ValueError( - "observation_tokens='tool' needs the renderer's per-token role attribution, " - "which MITO rollouts don't carry — use the renderer, or observations='all'." + "echo selects env-provided tokens by message role, which needs the renderer's " + "per-token attribution — MITO rollouts don't carry it; set orchestrator.renderer." ) # Serialized steps carry the attribution as a dict of RenderedTokens @@ -227,14 +224,37 @@ def field(key: str) -> Any: indices = field("message_indices") roles = field("message_roles") is_content = field("is_content") or [] - mask = [] - for k in span: + weights = [] + for k in range(prefix_len, len(tokens["prompt_ids"])): idx = indices[k] - tool = idx >= 0 and roles[idx] == "tool" - if tool and is_content: - tool = bool(is_content[k]) - mask.append(tool) - return mask + selected = idx >= 0 and roles[idx] in role_weights + if selected and is_content: + selected = bool(is_content[k]) + weights.append(role_weights[roles[idx]] if selected else 0.0) + return weights + + +def _echo_filter_masks(output: vf.RolloutOutput, filter_fn: Callable[..., list[list[bool]]]) -> list[list[bool]]: + """Invoke the user echo filter and validate its shape: one keep-mask per + trajectory step, each spanning that step's ``prompt_ids`` + + ``completion_ids``.""" + trajectory = output["trajectory"] + masks = filter_fn(output) + if not isinstance(masks, list) or len(masks) != len(trajectory): + got = len(masks) if isinstance(masks, list) else type(masks).__name__ + raise ValueError( + f"echo filter must return one keep-mask per trajectory step: got {got}, expected {len(trajectory)}" + ) + for step_idx, (step, mask) in enumerate(zip(trajectory, masks)): + tokens = step["tokens"] + expected = len(tokens["prompt_ids"]) + len(tokens["completion_ids"]) + if not isinstance(mask, list) or len(mask) != expected: + got = len(mask) if isinstance(mask, list) else type(mask).__name__ + raise ValueError( + f"echo filter mask for step {step_idx} must span the step's prompt+completion " + f"tokens: got {got}, expected {expected}" + ) + return masks def interleave_rollout( @@ -242,7 +262,8 @@ def interleave_rollout( mm_token_type_ids_mapping: dict[int, int] | None = None, *, env_name: str = "", - observation_tokens: str | None = None, + echo_roles: dict[str, float] | None = None, + echo_filter_fn: Callable[..., list[list[bool]]] | None = None, ) -> list[TrainingSample] | None: """ Convert vf.RolloutOutput to trainable rollouts by interleaving trajectory steps @@ -259,13 +280,14 @@ def interleave_rollout( Returns a list of samples - could be 1 (extension always held) or up to T (extension never held). - With ``observation_tokens``, each sample additionally carries - ``completion_obs_mask`` marking env-provided tokens within - ``completion_ids`` (the later-turn prompt extensions). ``"all"`` marks - every env-provided token; ``"tool"`` marks tool-response bodies only, - via the renderer's per-token role attribution. Algorithms that train on - observations (ECHO) route these tokens to the CE loss type instead of - dropping them. + With ``echo_roles`` (message role → per-token ce weight), each sample + additionally carries ``completion_obs_weights`` over its + ``completion_ids``: env-provided tokens within the later-turn prompt + extensions get their message role's weight (via the renderer's per-token + attribution), everything else 0.0. ``echo_filter_fn`` optionally narrows + the selection with per-step keep-masks (see ``_echo_filter_masks``). + Algorithms that train on observations (ECHO) fold these weights into the + CE component instead of dropping the tokens. For VLM models, each renderer-produced trajectory step carries its per-image processed tensors inline on ``multi_modal_data``; the last @@ -325,6 +347,8 @@ def prepare_step_tokens(step: vf.TrajectoryStep, step_idx: int) -> dict[str, Any return None prepared_steps.append(prepared) + echo_filter_masks = _echo_filter_masks(output, echo_filter_fn) if echo_roles and echo_filter_fn else None + # Deferred routed_experts state per sample: O(N) chunk list concatenated # once at finalize, replacing the prior O(N²) per-extension unpack/repack. sample_routed_state: dict[int, dict[str, Any]] = {} @@ -357,7 +381,7 @@ def make_sample(tokens: dict[str, Any], step_idx: int) -> TrainingSample: mm_token_type_ids=None, routed_experts=None, # deferred — finalized at end of interleave_rollout # A step's own completion tokens are actions, not observations - completion_obs_mask=[False] * len(completion_ids) if observation_tokens else None, + completion_obs_weights=[0.0] * len(completion_ids) if echo_roles else None, ) # Initialize routed-experts state for this sample. First chunk is the # raw step routed_experts (no pad, no copy). running_len is the @@ -426,8 +450,12 @@ def extend_sample( sample.completion_ids.extend(new_prompt_ids) sample.completion_mask.extend([False] * len(new_prompt_ids)) sample.completion_logprobs.extend([0.0] * len(new_prompt_ids)) - if sample.completion_obs_mask is not None: - sample.completion_obs_mask.extend(_observation_span_mask(tokens, prefix_len, observation_tokens)) + if sample.completion_obs_weights is not None: + weights = _observation_span_weights(tokens, prefix_len, echo_roles) + if echo_filter_masks is not None: + step_mask = echo_filter_masks[step_idx] + weights = [w if step_mask[prefix_len + j] else 0.0 for j, w in enumerate(weights)] + sample.completion_obs_weights.extend(weights) # Extend with new completion tokens completion_ids = tokens["completion_ids"] @@ -437,8 +465,8 @@ def extend_sample( else: sample.completion_mask.extend(tokens["completion_mask"]) sample.completion_logprobs.extend(tokens["completion_logprobs"]) - if sample.completion_obs_mask is not None: - sample.completion_obs_mask.extend([False] * len(completion_ids)) + if sample.completion_obs_weights is not None: + sample.completion_obs_weights.extend([0.0] * len(completion_ids)) step_routed = tokens.get("routed_experts") state = sample_routed_state.get(id(sample)) diff --git a/src/prime_rl/transport/types.py b/src/prime_rl/transport/types.py index 95825070de..d5373d55c9 100644 --- a/src/prime_rl/transport/types.py +++ b/src/prime_rl/transport/types.py @@ -65,11 +65,11 @@ class TrainingSample(msgspec.Struct, array_like=True, gc=False, omit_defaults=Tr # rollout-level ``advantage`` scalar over the sequence. token_advantages: list[float] | None = None - # Orchestrator-internal: marks env-provided observation tokens within - # ``completion_ids`` (set by ``interleave_rollout`` when the env's - # algorithm trains on observations). Consumed by the env algorithm when - # stamping loss routing and cleared before transport. - completion_obs_mask: list[bool] | None = None + # Orchestrator-internal: per-token echo ce weights for env-provided + # tokens within ``completion_ids`` (set by ``interleave_rollout`` when the + # env's algorithm trains on observations; 0.0 = not selected). Folded into + # ``ce_weights`` when stamping loss routing and cleared before transport. + completion_obs_weights: list[float] | None = None class TrainingBatch(msgspec.Struct, array_like=True, gc=False, omit_defaults=True): diff --git a/tests/unit/orchestrator/test_algorithms.py b/tests/unit/orchestrator/test_algorithms.py index 6cb4056f3e..5ec21352d3 100644 --- a/tests/unit/orchestrator/test_algorithms.py +++ b/tests/unit/orchestrator/test_algorithms.py @@ -38,7 +38,7 @@ def test_preset_expansion(name, model, source, advantage_type, advantage_model, def test_preset_with_component_override_is_rejected(): with pytest.raises(ValueError, match="presets are atomic"): - AlgorithmConfig(name="echo", advantage={"observation_weight": 0.5}) + AlgorithmConfig(name="echo", advantage={"roles": {"user": {"alpha": 0.5}}}) with pytest.raises(ValueError, match="presets are atomic"): AlgorithmConfig(name="opd", model=FROZEN, advantage={"max_concurrent": 64}) with pytest.raises(ValueError, match="presets are atomic"): @@ -46,9 +46,24 @@ def test_preset_with_component_override_is_rejected(): def test_assembled_components_without_preset_name(): - algo = AlgorithmConfig(advantage={"type": "echo", "observation_weight": 0.5}) + algo = AlgorithmConfig(advantage={"type": "echo", "roles": {"user": {"alpha": 0.5}}}) assert algo.advantage.type == "echo" - assert algo.advantage.observation_weight == 0.5 + assert algo.advantage.roles.user.alpha == 0.5 + # Setting any role replaces the whole table: the tool default is gone + assert algo.advantage.roles.tool is None + + +def test_echo_preset_defaults_to_tool_bodies(): + algo = AlgorithmConfig(name="echo") + assert algo.advantage.roles.tool.alpha == 0.1 + assert algo.advantage.roles.system is None + assert algo.advantage.roles.user is None + assert algo.advantage.roles.assistant is None + + +def test_echo_roles_require_at_least_one(): + with pytest.raises(ValueError, match="at least one role"): + AlgorithmConfig(advantage={"type": "echo", "roles": {}}) def test_ref_kl_requires_model_reference(): @@ -87,7 +102,7 @@ def test_rl_loss_type_incompatible_with_frozen_sampling(): AlgorithmConfig(sampling={"source": FROZEN}, advantage={"type": "group_norm"}) -def _make_sample(obs_mask: list[bool] | None) -> TrainingSample: +def _make_sample(obs_weights: list[float] | None) -> TrainingSample: return TrainingSample( prompt_ids=[1, 2], prompt_mask=[False, False], @@ -96,13 +111,13 @@ def _make_sample(obs_mask: list[bool] | None) -> TrainingSample: completion_logprobs=[-0.1, -0.2, 0.0, -0.3], completion_temperatures=[], env_name="test-env", - completion_obs_mask=obs_mask, + completion_obs_weights=obs_weights, ) def test_stamp_loss_routing_uniform_rl(): - sample = _make_sample(obs_mask=None) - stamp_loss_routing(sample, "rl", None) + sample = _make_sample(obs_weights=None) + stamp_loss_routing(sample, "rl") # Hot path: absent streams mean rl weight 1.0 on the loss mask assert sample.rl_weights is None assert sample.ce_weights is None @@ -110,8 +125,8 @@ def test_stamp_loss_routing_uniform_rl(): def test_stamp_loss_routing_ref_kl_action(): - sample = _make_sample(obs_mask=None) - stamp_loss_routing(sample, "ref_kl", None) + sample = _make_sample(obs_weights=None) + stamp_loss_routing(sample, "ref_kl") # Action tokens (completion_mask True) feed the ref_kl component; rl is off assert sample.rl_weights == [0.0] * 6 assert sample.ref_kl_weights == [0.0, 0.0] + [1.0, 1.0, 0.0, 1.0] @@ -119,8 +134,8 @@ def test_stamp_loss_routing_ref_kl_action(): def test_stamp_loss_routing_ce_action(): - sample = _make_sample(obs_mask=None) - stamp_loss_routing(sample, "ce", None) + sample = _make_sample(obs_weights=None) + stamp_loss_routing(sample, "ce") assert sample.rl_weights == [0.0] * 6 assert sample.ce_weights == [0.0, 0.0] + [1.0, 1.0, 0.0, 1.0] assert sample.ref_kl_weights is None @@ -128,11 +143,11 @@ def test_stamp_loss_routing_ce_action(): def test_stamp_loss_routing_echo_observations(): # Token at completion index 2 is an env observation (masked out today) - sample = _make_sample(obs_mask=[False, False, True, False]) - stamp_loss_routing(sample, "rl", 0.1) + sample = _make_sample(obs_weights=[0.0, 0.0, 0.1, 0.0]) + stamp_loss_routing(sample, "rl") - assert sample.completion_obs_mask is None # cleared, never ships - # The observation token trains on the ce component with the configured + assert sample.completion_obs_weights is None # cleared, never ships + # The observation token trains on the ce component with its role's # weight; it stays out of completion_mask (the rl mask), so the rl # component and its denominator never see it. assert sample.completion_mask == [True, True, False, True] @@ -141,10 +156,10 @@ def test_stamp_loss_routing_echo_observations(): assert sample.ref_kl_weights is None -def test_stamp_loss_routing_clears_obs_mask_when_unused(): - sample = _make_sample(obs_mask=[False, False, True, False]) - stamp_loss_routing(sample, "rl", None) - assert sample.completion_obs_mask is None +def test_stamp_loss_routing_clears_obs_weights_when_all_zero(): + sample = _make_sample(obs_weights=[0.0, 0.0, 0.0, 0.0]) + stamp_loss_routing(sample, "rl") + assert sample.completion_obs_weights is None assert sample.ce_weights is None assert sample.completion_mask == [True, True, False, True] @@ -163,20 +178,20 @@ def _make_rollout(samples: list[TrainingSample], token_advantages: list[float] | def test_spread_token_advantages_pads_prompt(): - rollout = _make_rollout([_make_sample(obs_mask=None)], token_advantages=[0.5, -0.5, 0.0, 1.0]) + rollout = _make_rollout([_make_sample(obs_weights=None)], token_advantages=[0.5, -0.5, 0.0, 1.0]) spread_token_advantages(rollout) # 2 prompt positions padded with 0.0 + 4 completion-aligned advantages assert rollout.samples[0].token_advantages == [0.0, 0.0, 0.5, -0.5, 0.0, 1.0] def test_spread_token_advantages_rejects_misaligned(): - rollout = _make_rollout([_make_sample(obs_mask=None)], token_advantages=[0.5]) + rollout = _make_rollout([_make_sample(obs_weights=None)], token_advantages=[0.5]) with pytest.raises(ValueError, match="align"): spread_token_advantages(rollout) def test_spread_token_advantages_rejects_multi_sample_rollouts(): - samples = [_make_sample(obs_mask=None), _make_sample(obs_mask=None)] + samples = [_make_sample(obs_weights=None), _make_sample(obs_weights=None)] rollout = _make_rollout(samples, token_advantages=[0.5, -0.5, 0.0, 1.0]) with pytest.raises(ValueError, match="exactly one training sample"): spread_token_advantages(rollout) @@ -219,40 +234,66 @@ def step(prompt_ids, completion_ids, logprobs, prompt_attribution=None): ) -def test_interleave_tags_observation_tokens(): - samples = interleave_rollout(_two_step_rollout(), env_name="test-env", observation_tokens="all") - assert samples is not None and len(samples) == 1 - sample = samples[0] - assert sample.completion_ids == [3, 4, 5, 6, 7, 8] - # [3,4] step-1 action, [5,6] observation, [7,8] step-2 action - assert sample.completion_obs_mask == [False, False, True, True, False, False] - assert sample.completion_mask == [True, True, False, False, True, True] - - -def test_interleave_tags_tool_observation_tokens(): +def test_interleave_tags_observation_weights_by_role(): # Span tokens [5,6] (positions 4,5) belong to a tool message; is_content - # excludes the wrap token, so only the body token is tagged. + # excludes the wrap token, so only the body token gets the tool weight. attribution = { "message_indices": [0, 0, 1, 1, 2, 2], "message_roles": ["user", "assistant", "tool"], "is_content": [True, True, True, True, False, True], } - samples = interleave_rollout(_two_step_rollout(attribution), env_name="test-env", observation_tokens="tool") - assert samples is not None - assert samples[0].completion_obs_mask == [False, False, False, True, False, False] + samples = interleave_rollout(_two_step_rollout(attribution), env_name="test-env", echo_roles={"tool": 0.1}) + assert samples is not None and len(samples) == 1 + sample = samples[0] + assert sample.completion_ids == [3, 4, 5, 6, 7, 8] + # [3,4] step-1 action, [5,6] observation, [7,8] step-2 action + assert sample.completion_obs_weights == [0.0, 0.0, 0.0, 0.1, 0.0, 0.0] + assert sample.completion_mask == [True, True, False, False, True, True] - # Without is_content, whole tool messages are tagged. - attribution = {"message_indices": [0, 0, 1, 1, 2, 2], "message_roles": ["user", "assistant", "tool"]} - samples = interleave_rollout(_two_step_rollout(attribution), env_name="test-env", observation_tokens="tool") + # Without is_content, whole messages count; each role carries its own weight. + attribution = {"message_indices": [0, 0, 1, 1, 2, 3], "message_roles": ["user", "assistant", "tool", "user"]} + samples = interleave_rollout( + _two_step_rollout(attribution), env_name="test-env", echo_roles={"tool": 0.1, "user": 0.05} + ) assert samples is not None - assert samples[0].completion_obs_mask == [False, False, True, True, False, False] + assert samples[0].completion_obs_weights == [0.0, 0.0, 0.1, 0.05, 0.0, 0.0] # MITO rollouts carry no attribution: loud error, not a silent no-op. - with pytest.raises(ValueError, match="role attribution"): - interleave_rollout(_two_step_rollout(), env_name="test-env", observation_tokens="tool") + with pytest.raises(ValueError, match="attribution"): + interleave_rollout(_two_step_rollout(), env_name="test-env", echo_roles={"tool": 0.1}) + + +def test_interleave_echo_filter_narrows_selection(): + attribution = {"message_indices": [0, 0, 1, 1, 2, 2], "message_roles": ["user", "assistant", "tool"]} + + def keep_last_only(rollout): + # One keep-mask per step over prompt+completion; drops span position 4. + return [[True] * 4, [True, True, True, True, False, True, True, True]] + + samples = interleave_rollout( + _two_step_rollout(attribution), env_name="test-env", echo_roles={"tool": 0.1}, echo_filter_fn=keep_last_only + ) + assert samples is not None + assert samples[0].completion_obs_weights == [0.0, 0.0, 0.0, 0.1, 0.0, 0.0] + + # Shape violations fail loudly: wrong step count, wrong per-step length. + with pytest.raises(ValueError, match="per trajectory step"): + interleave_rollout( + _two_step_rollout(attribution), + env_name="test-env", + echo_roles={"tool": 0.1}, + echo_filter_fn=lambda r: [[True] * 4], + ) + with pytest.raises(ValueError, match="prompt\\+completion"): + interleave_rollout( + _two_step_rollout(attribution), + env_name="test-env", + echo_roles={"tool": 0.1}, + echo_filter_fn=lambda r: [[True] * 4, [True] * 6], + ) -def test_interleave_obs_mask_off_by_default(): +def test_interleave_obs_weights_off_by_default(): samples = interleave_rollout(_two_step_rollout(), env_name="test-env") assert samples is not None - assert samples[0].completion_obs_mask is None + assert samples[0].completion_obs_weights is None diff --git a/tests/unit/test_configs.py b/tests/unit/test_configs.py index fc5c980f38..f4c7d4e4a3 100644 --- a/tests/unit/test_configs.py +++ b/tests/unit/test_configs.py @@ -170,7 +170,7 @@ def test_removed_fused_lm_head_chunk_size_field_is_rejected(): def test_env_advantage_shorthand_assembles_own_algorithm(): config = OrchestratorConfig.model_validate( { - "renderer": {"name": "qwen3"}, # echo's tool-mode default needs role attribution + "renderer": {"name": "qwen3"}, # echo needs the renderer's role attribution "algo": {"name": "echo"}, "train": {"env": [{"id": "a", "advantage": {"type": "reward"}}, {"id": "b"}]}, }