diff --git a/packages/prime-rl-configs/src/prime_rl/configs/orchestrator.py b/packages/prime-rl-configs/src/prime_rl/configs/orchestrator.py index 55a2210abe..1bda57633f 100644 --- a/packages/prime-rl-configs/src/prime_rl/configs/orchestrator.py +++ b/packages/prime-rl-configs/src/prime_rl/configs/orchestrator.py @@ -1,7 +1,7 @@ import math import warnings from pathlib import Path -from typing import Annotated, Any, Literal, TypeAlias +from typing import Annotated, Any, ClassVar, Literal, TypeAlias from pydantic import AliasChoices, Field, model_serializer, model_validator from pydantic_core.core_schema import SerializerFunctionWrapHandler @@ -422,13 +422,53 @@ class CheckpointConfig(BaseConfig): """Skip loading the progress from checkpoint.""" +FilterAction: TypeAlias = Literal["monitor", "drop", "penalize"] + + +class BaseFilterConfig(BaseConfig): + """Shared action fields for rollout filters. + + At most one of ``action`` / ``enforce`` should be set (``enforce`` is legacy + compatibility; setting both is rejected unless they agree). If neither is set, + the filter falls back to its per-type default action. + """ + + action: FilterAction | None = None + """What to do when the filter detects a rollout. ``monitor``: only track detection metrics. ``drop``: skip the rollout entirely so it is not sent to the trainer. ``penalize``: cap the rollout's reward at ``penalty_reward`` before advantage computation while keeping it trainable. If None, resolves from the legacy ``enforce`` flag, falling back to the filter's default action.""" + + enforce: bool | None = None + """Legacy flag kept for backwards compatibility: ``true`` resolves to ``action="drop"``, ``false`` to ``action="monitor"``. Prefer setting ``action`` instead.""" + + penalty_reward: float = -1.0 + """Reward cap applied when ``action="penalize"``: final reward = ``min(raw_reward, penalty_reward)``. 
Ignored by other actions.""" + + _default_action: ClassVar[FilterAction] = "monitor" + + @model_validator(mode="after") + def _validate_action_and_enforce(self): + if self.action is not None and self.enforce is not None: + implied = "drop" if self.enforce else "monitor" + if self.action != implied: + raise ValueError( + f"Conflicting filter config: action={self.action!r} but enforce={self.enforce} " + f"implies action={implied!r}. Set only `action` (preferred) or only `enforce`." + ) + return self + + @property + def resolved_action(self) -> FilterAction: + """The effective action: explicit ``action`` wins, then legacy ``enforce``, then the per-type default.""" + if self.action is not None: + return self.action + if self.enforce is not None: + return "drop" if self.enforce else "monitor" + return self._default_action + + # Flags rare tokens generated at high entropy (Section 5.2, https://arxiv.org/abs/2510.02387). -class GibberishFilterConfig(BaseConfig): +class GibberishFilterConfig(BaseFilterConfig): type: Literal["gibberish"] = "gibberish" - enforce: bool = False - """When True, skip detected rollouts entirely so they are not sent to the trainer. When False, only track detection metrics.""" - token_id_threshold: int = 100_000 """Token IDs above this are candidates for gibberish. BPE tokens are sorted by merge order.""" @@ -439,12 +479,9 @@ class GibberishFilterConfig(BaseConfig): # Flags rollouts stuck in a repetition loop: emits high-confidence tokens for an extended stretch. # Flagged when `window` consecutive tokens are each sampled with probability above `prob_threshold`. # (Section 3.2, https://arxiv.org/abs/2506.13585) -class RepetitionFilterConfig(BaseConfig): +class RepetitionFilterConfig(BaseFilterConfig): type: Literal["repetition"] = "repetition" - enforce: bool = False - """When True, skip detected rollouts entirely so they are not sent to the trainer. 
When False, only track detection metrics.""" - window: int = Field(3_000, ge=1) """Consecutive high-probability steps required to flag the rollout.""" @@ -453,11 +490,10 @@ class RepetitionFilterConfig(BaseConfig): # Flags rollouts with zero advantage. -class ZeroAdvantageFilterConfig(BaseConfig): +class ZeroAdvantageFilterConfig(BaseFilterConfig): type: Literal["zero_advantage"] = "zero_advantage" - enforce: bool = True - """When True, skip detected rollouts entirely so they are not sent to the trainer. When False, only track detection metrics.""" + _default_action: ClassVar[FilterAction] = "drop" FilterConfig: TypeAlias = Annotated[ @@ -552,14 +588,15 @@ def _preserve_mito_renderer(self, handler: SerializerFunctionWrapHandler) -> dic advantage: AdvantageConfig | None = DefaultAdvantageConfig() pre_batch_filters: list[FilterConfig] = [ - GibberishFilterConfig(enforce=False), - RepetitionFilterConfig(enforce=False), - ZeroAdvantageFilterConfig(enforce=False), + GibberishFilterConfig(), + RepetitionFilterConfig(), + ZeroAdvantageFilterConfig(action="monitor"), ] """Filters applied *before* a rollout enters the training batch buffer. - All three filter types are registered in monitor mode by default; flip ``enforce=true`` per type + All three filter types are registered in monitor mode by default; set ``action="drop"`` per type to drop matching rollouts before they consume a slot in the batch (e.g. 
a zero-advantage group - never makes it into a training batch).""" + never makes it into a training batch), or ``action="penalize"`` (gibberish/repetition) to cap the + rollout's reward at ``penalty_reward`` before advantage computation while keeping it trainable.""" post_batch_filters: list[FilterConfig] = [ GibberishFilterConfig(), diff --git a/src/prime_rl/orchestrator/filters.py b/src/prime_rl/orchestrator/filters.py index f8deda1230..4ff66148f2 100644 --- a/src/prime_rl/orchestrator/filters.py +++ b/src/prime_rl/orchestrator/filters.py @@ -1,24 +1,41 @@ """Orchestrator-side rollout filters for detecting degenerate generations. Filters run after rollouts complete, inspecting token IDs and logprobs to -detect gibberish or repetition. Detection metrics are always tracked. -When enforce=True, detected rollouts are skipped entirely during training and -are not sent to the trainer. Reward is kept as-is for baseline calculation. +detect gibberish or repetition. Detection metrics are always tracked. Each +filter resolves to one of three actions: + +- ``monitor``: only record detection metrics; +- ``drop``: detected rollouts are skipped entirely during training and are + not sent to the trainer. Reward is kept as-is for baseline calculation; +- ``penalize``: detected rollouts stay trainable, but their reward is capped + at ``penalty_reward``. Penalties must be applied before advantage + computation to create negative policy-gradient signal — token/logprob + based filters are ``pre_advantage`` phase so ``TrainSink.process_group`` + runs them before ``assign_advantages``. 
""" from __future__ import annotations import math from dataclasses import dataclass -from typing import TYPE_CHECKING, Protocol +from typing import TYPE_CHECKING, Literal, Protocol, TypeAlias -from prime_rl.configs.orchestrator import FilterConfig +from prime_rl.configs.orchestrator import FilterAction, FilterConfig from prime_rl.utils.logger import get_logger if TYPE_CHECKING: from prime_rl.orchestrator.types import TrainRollout +FilterPhase: TypeAlias = Literal["pre_advantage", "post_advantage"] + +_ACTION_LOG_NAMES: dict[FilterAction, str] = { + "monitor": "Monitoring", + "drop": "Dropping", + "penalize": "Penalizing", +} + + @dataclass class FilterResult: detected: bool @@ -27,7 +44,9 @@ class FilterResult: class RolloutFilter(Protocol): name: str - enforce: bool + action: FilterAction + phase: FilterPhase + penalty_reward: float def check(self, rollout: "TrainRollout") -> FilterResult: ... @@ -47,7 +66,9 @@ class GibberishFilter: name: str token_id_threshold: int logprob_threshold: float - enforce: bool = False + action: FilterAction = "monitor" + penalty_reward: float = -1.0 + phase: FilterPhase = "pre_advantage" def check(self, rollout: "TrainRollout") -> FilterResult: global_idx = 0 @@ -77,7 +98,9 @@ class RepetitionFilter: name: str window: int logprob_threshold: float - enforce: bool = False + action: FilterAction = "monitor" + penalty_reward: float = -1.0 + phase: FilterPhase = "pre_advantage" def check(self, rollout: "TrainRollout") -> FilterResult: consecutive = 0 @@ -103,7 +126,9 @@ class ZeroAdvantageFilter: GRPO group earned the same reward, so the centered advantage collapses).""" name: str - enforce: bool = True + action: FilterAction = "drop" + penalty_reward: float = -1.0 + phase: FilterPhase = "post_advantage" def check(self, rollout: "TrainRollout") -> FilterResult: if rollout.advantage is not None and rollout.advantage == 0.0: @@ -118,19 +143,22 @@ def setup_filter(config: FilterConfig, vocab_size: int) -> RolloutFilter: name="gibberish", 
token_id_threshold=config.token_id_threshold, logprob_threshold=-math.log(vocab_size) - config.logprob_offset, - enforce=config.enforce, + action=config.resolved_action, + penalty_reward=config.penalty_reward, ) elif config.type == "repetition": return RepetitionFilter( name="repetition", window=config.window, logprob_threshold=math.log(config.prob_threshold), - enforce=config.enforce, + action=config.resolved_action, + penalty_reward=config.penalty_reward, ) elif config.type == "zero_advantage": return ZeroAdvantageFilter( name="zero_advantage", - enforce=config.enforce, + action=config.resolved_action, + penalty_reward=config.penalty_reward, ) raise ValueError(f"Unknown filter type: {config.type}") @@ -141,24 +169,59 @@ def setup_filters(configs: list[FilterConfig], vocab_size: int, *, kind: str) -> if filters: get_logger().info(f"Configured {len(filters)} {kind} rollout filter(s):") for config, filt in zip(configs, filters): - mode = "Enforcing" if filt.enforce else "Monitoring" + mode = _ACTION_LOG_NAMES[filt.action] params = ", ".join(f"{k}={v}" for k, v in config.model_dump().items()) get_logger().info(f" {mode} {filt.name} filter ({params})") return filters +def split_filters(filters: list[RolloutFilter]) -> tuple[list[RolloutFilter], list[RolloutFilter]]: + """Split filters into ``(pre_advantage, post_advantage)`` phase lists.""" + return ( + [f for f in filters if f.phase == "pre_advantage"], + [f for f in filters if f.phase == "post_advantage"], + ) + + +def penalize_reward( + rollout: "TrainRollout", filter_name: str, penalty_reward: float, detection_index: int | None +) -> None: + """Cap the rollout's reward at ``penalty_reward`` and record penalty metadata. + + Uses ``min(...)`` so the penalty is a cap: rewards already below + ``penalty_reward`` are never improved. The original env reward is + preserved in ``rollout.raw_reward`` (first penalty wins) and per-filter + details in ``rollout.reward_penalties``. 
+ """ + raw_reward = rollout.reward + penalized_reward = min(raw_reward, penalty_reward) + if rollout.raw_reward is None: + rollout.raw_reward = raw_reward + rollout.raw["reward"] = penalized_reward + rollout.reward_penalties[filter_name] = { + "raw_reward": raw_reward, + "penalized_reward": penalized_reward, + "detection_index": detection_index, + } + + def apply_filters(filters: list[RolloutFilter], rollouts: list["TrainRollout"]) -> None: # noqa: F821 (forward ref) - """Flag ``TrainRollout``\\ s in place with per-filter detection + drop decision. + """Flag ``TrainRollout``\\ s in place with per-filter detection + action. Each rollout's ``filter_results`` dict records per-filter detection bools; - ``is_filtered`` is True iff an enforcing filter detected it. First matching - filter wins per rollout (no double-counting). Reward and trajectory tokens - are left untouched so the rollout can still contribute to baseline - calculations and metric aggregation. + ``is_filtered`` is True iff a ``drop`` filter detected it. A ``penalize`` + filter caps the rollout's reward at its ``penalty_reward`` but leaves the + rollout trainable. First matching filter wins per rollout within a call + (no double-counting). Trajectory tokens are left untouched so the rollout + can still contribute to baseline calculations and metric aggregation. + + Safe to call more than once on the same rollouts (e.g. once per filter + phase): missing ``filter_results`` keys are initialized without wiping + results, drops, or penalties recorded by an earlier call. 
""" for rollout in rollouts: - rollout.filter_results = {f.name: False for f in filters} - rollout.is_filtered = False + for filt in filters: + rollout.filter_results.setdefault(filt.name, False) if not filters: return @@ -168,6 +231,8 @@ def apply_filters(filters: list[RolloutFilter], rollouts: list["TrainRollout"]) result = filt.check(rollout) if result.detected: rollout.filter_results[filt.name] = True - if filt.enforce: + if filt.action == "drop": rollout.is_filtered = True + elif filt.action == "penalize": + penalize_reward(rollout, filt.name, filt.penalty_reward, result.detection_index) break diff --git a/src/prime_rl/orchestrator/metrics.py b/src/prime_rl/orchestrator/metrics.py index 87ec99c424..ecb71dd133 100644 --- a/src/prime_rl/orchestrator/metrics.py +++ b/src/prime_rl/orchestrator/metrics.py @@ -130,6 +130,19 @@ def compute_solve_rates(df): "step": step, } + # Reward-penalty metrics: only emitted when a `penalize` filter fired + # somewhere in the batch (keeps dashboards clean otherwise) + penalized_names = sorted({name for r in rollouts for name in r.reward_penalties}) + for name in penalized_names: + to_log[f"filters/all/{name}_penalized"] = sum(1.0 for r in rollouts if name in r.reward_penalties) / max( + num_rollouts, 1 + ) + if any(r.raw_reward is not None for r in rollouts): + raw_rewards = pd.Series([r.reward if r.raw_reward is None else r.raw_reward for r in rollouts]) + to_log["raw_reward/all/mean"] = raw_rewards.mean() + to_log["raw_reward/all/max"] = raw_rewards.max() + to_log["raw_reward/all/min"] = raw_rewards.min() + # Per-env metrics per_env_columns = [ "seq_len", diff --git a/src/prime_rl/orchestrator/train_sink.py b/src/prime_rl/orchestrator/train_sink.py index f79a0d5eff..56b0674bc6 100644 --- a/src/prime_rl/orchestrator/train_sink.py +++ b/src/prime_rl/orchestrator/train_sink.py @@ -2,8 +2,9 @@ 1. ``process_rollout`` — eager per-rollout tokenization (overlaps with dispatcher producing more rollouts). Errored rollouts skip this. -2. 
``process_group`` — filters errored rollouts, computes advantages over - survivors, runs the pre-batch filter pass. +2. ``process_group`` — filters errored rollouts, runs the pre-advantage + pre-batch filter pass (``penalize`` reward caps land before the baseline), + computes advantages over survivors, runs the post-advantage pass. 3. ``process_batch`` — applies post-batch filter annotations and assembles the trainer-bound ``TrainingSample`` list. Returns a ``TrainBatch``. @@ -20,7 +21,7 @@ from prime_rl.configs.orchestrator import OrchestratorConfig from prime_rl.orchestrator.advantage import assign_advantages from prime_rl.orchestrator.envs import TrainEnvs -from prime_rl.orchestrator.filters import RolloutFilter, apply_filters +from prime_rl.orchestrator.filters import RolloutFilter, apply_filters, split_filters from prime_rl.orchestrator.trajectories import ( backfill_rollout_tokens, interleave_rollout, @@ -170,8 +171,10 @@ async def process_rollout(self, rollout: TrainRollout) -> None: def process_group(self, group_id: uuid.UUID) -> None: """Finalize one GRPO group: drop errored rollouts (the whole group - when ``requires_group_scoring`` and any failed), assign advantages, - run pre-batch filters, append survivors to ``pending_batch``.""" + when ``requires_group_scoring`` and any failed), run pre-advantage + pre-batch filters (so ``penalize`` reward caps are visible to the + baseline), assign advantages, run post-advantage pre-batch filters, + append survivors to ``pending_batch``.""" group = self.pending_groups.pop(group_id, []) if not group: return @@ -196,8 +199,20 @@ def process_group(self, group_id: uuid.UUID) -> None: ) return + # Pre-advantage filters run before advantage assignment so a + # `penalize` action's reward cap is visible to the group baseline — + # the penalized rollout ends up with a lower advantage than its + # peers. Dropped rollouts still participate in the baseline (reward + # untouched), matching prior behavior. 
+ pre_advantage_filters, post_advantage_filters = split_filters(self.pre_filters) + if pre_advantage_filters: + apply_filters(pre_advantage_filters, survivors) + assign_advantages(survivors, self.train_envs.get(env_name).advantage_fn) + if post_advantage_filters: + apply_filters(post_advantage_filters, survivors) + # Propagate to the pre-tokenized samples so the orchestrator can # collect samples at ship time without re-walking rollouts. The env # has a single sampling temperature; fan it out across each sample's @@ -211,8 +226,7 @@ def process_group(self, group_id: uuid.UUID) -> None: sample.training_mode = self.config.training_mode sample.completion_temperatures = [temperature] * len(sample.completion_ids) - if self.pre_filters: - apply_filters(self.pre_filters, survivors) + drop_filter_names = {f.name for f in self.pre_filters if f.action == "drop"} filtered_by_name: dict[str, int] = {} num_filtered = 0 for r in survivors: @@ -221,7 +235,7 @@ def process_group(self, group_id: uuid.UUID) -> None: self.pre_filter_dropped += 1 num_filtered += 1 for name, hit in r.filter_results.items(): - if hit: + if hit and name in drop_filter_names: self.pre_filter_dropped_by_name[name] = self.pre_filter_dropped_by_name.get(name, 0) + 1 filtered_by_name[name] = filtered_by_name.get(name, 0) + 1 continue @@ -264,6 +278,16 @@ def process_batch(self) -> TrainBatch: if self.post_filters: apply_filters(self.post_filters, cohort) + # A post-batch ``penalize`` filter caps the rollout reward after + # ``process_group`` already stamped it onto the samples — re-sync + # so trainer-bound samples agree with the rollout reward used in + # metrics. Advantage is intentionally untouched: post-batch runs + # after advantage computation, so a penalty here is metadata-only. 
+ for r in cohort: + if not r.reward_penalties: + continue + for sample in r.samples: + sample.reward = r.reward # Samples are pre-built by ``process_rollout``; ``process_group`` # already set advantage/reward on each sample diff --git a/src/prime_rl/orchestrator/types.py b/src/prime_rl/orchestrator/types.py index c2a3f5de79..fc0c985a32 100644 --- a/src/prime_rl/orchestrator/types.py +++ b/src/prime_rl/orchestrator/types.py @@ -103,6 +103,11 @@ def to_dict(self) -> vf.RolloutOutput: if f.name == "filter_results": out["filters"] = dict(val) continue + # Only surface penalty metadata when a penalty actually fired + if f.name == "raw_reward" and val is None: + continue + if f.name == "reward_penalties" and not val: + continue out[f.name] = str(val) if isinstance(val, uuid.UUID) else val return out @@ -113,6 +118,10 @@ class TrainRollout(FinishedRollout): advantage: float | None = None is_filtered: bool = False filter_results: dict[str, bool] = field(default_factory=dict) + raw_reward: float | None = None + """Original env reward, recorded the first time a ``penalize`` filter fires on the rollout (even if the cap leaves ``reward`` unchanged).""" + reward_penalties: dict[str, dict] = field(default_factory=dict) + """Per-filter penalty metadata (``raw_reward`` / ``penalized_reward`` / ``detection_index``), keyed by filter name.""" @dataclass diff --git a/tests/unit/orchestrator/test_filters.py b/tests/unit/orchestrator/test_filters.py index 2643bf71bb..967a60631d 100644 --- a/tests/unit/orchestrator/test_filters.py +++ b/tests/unit/orchestrator/test_filters.py @@ -1,13 +1,24 @@ import math import uuid -from prime_rl.configs.orchestrator import GibberishFilterConfig, RepetitionFilterConfig +import pytest +from pydantic import ValidationError + +from prime_rl.configs.orchestrator import ( + DefaultAdvantageConfig, + GibberishFilterConfig, + RepetitionFilterConfig, + ZeroAdvantageFilterConfig, +) +from prime_rl.orchestrator.advantage import assign_advantages, setup_advantage_fn +from prime_rl.orchestrator.filters import ( 
GibberishFilter, RepetitionFilter, + ZeroAdvantageFilter, apply_filters, setup_filter, setup_filters, + split_filters, ) from prime_rl.orchestrator.types import TrainRollout @@ -66,16 +77,35 @@ def _make_rollout( ) -def _make_gibberish_filter(vocab_size=128_000, token_id_threshold=100_000, logprob_offset=2.0, enforce=False): +def _make_gibberish_filter( + vocab_size=128_000, token_id_threshold=100_000, logprob_offset=2.0, action="monitor", penalty_reward=-1.0 +): logprob_threshold = -math.log(vocab_size) - logprob_offset return GibberishFilter( - name="gibberish", token_id_threshold=token_id_threshold, logprob_threshold=logprob_threshold, enforce=enforce + name="gibberish", + token_id_threshold=token_id_threshold, + logprob_threshold=logprob_threshold, + action=action, + penalty_reward=penalty_reward, ) -def _make_repetition_filter(window=5, prob_threshold=0.99, enforce=False): +def _make_repetition_filter(window=5, prob_threshold=0.99, action="monitor", penalty_reward=-1.0): return RepetitionFilter( - name="repetition", window=window, logprob_threshold=math.log(prob_threshold), enforce=enforce + name="repetition", + window=window, + logprob_threshold=math.log(prob_threshold), + action=action, + penalty_reward=penalty_reward, + ) + + +def _make_dirty_rollout(gibberish_filter, *, reward: float = 1.0) -> TrainRollout: + """A rollout that triggers the given gibberish filter.""" + return _make_rollout( + completion_ids=[120_000], + completion_logprobs=[gibberish_filter.logprob_threshold - 1.0], + reward=reward, ) @@ -187,6 +217,47 @@ def test_repetition_varied_probs_no_trigger(): assert result.detected is False +# --- config resolution / validation tests --- + + +def test_config_default_actions(): + assert GibberishFilterConfig().resolved_action == "monitor" + assert RepetitionFilterConfig().resolved_action == "monitor" + assert ZeroAdvantageFilterConfig().resolved_action == "drop" + + +def test_config_legacy_enforce_true_resolves_to_drop(): + assert 
GibberishFilterConfig(enforce=True).resolved_action == "drop" + assert RepetitionFilterConfig(enforce=True).resolved_action == "drop" + assert ZeroAdvantageFilterConfig(enforce=True).resolved_action == "drop" + + +def test_config_legacy_enforce_false_resolves_to_monitor(): + assert GibberishFilterConfig(enforce=False).resolved_action == "monitor" + assert RepetitionFilterConfig(enforce=False).resolved_action == "monitor" + assert ZeroAdvantageFilterConfig(enforce=False).resolved_action == "monitor" + + +def test_config_penalize_parses_with_penalty_reward(): + config = GibberishFilterConfig(action="penalize", penalty_reward=-0.5) + assert config.resolved_action == "penalize" + assert config.penalty_reward == -0.5 + + +def test_config_conflicting_action_and_enforce_raises(): + with pytest.raises(ValidationError): + GibberishFilterConfig(action="penalize", enforce=True) + with pytest.raises(ValidationError): + RepetitionFilterConfig(action="monitor", enforce=True) + with pytest.raises(ValidationError): + ZeroAdvantageFilterConfig(action="drop", enforce=False) + + +def test_config_consistent_action_and_enforce_ok(): + assert GibberishFilterConfig(action="drop", enforce=True).resolved_action == "drop" + assert RepetitionFilterConfig(action="monitor", enforce=False).resolved_action == "monitor" + + # --- setup_filter / setup_filters tests --- @@ -197,13 +268,21 @@ def test_setup_filter_gibberish(): assert gibberish_filter.name == "gibberish" assert gibberish_filter.token_id_threshold == 100_000 assert abs(gibberish_filter.logprob_threshold - (-math.log(128_000) - 2.0)) < 1e-10 - assert gibberish_filter.enforce is False + assert gibberish_filter.action == "monitor" + assert gibberish_filter.phase == "pre_advantage" -def test_setup_filter_gibberish_enforce(): +def test_setup_filter_gibberish_legacy_enforce(): config = GibberishFilterConfig(enforce=True) gibberish_filter = setup_filter(config, vocab_size=128_000) - assert gibberish_filter.enforce is True + assert 
gibberish_filter.action == "drop" + + +def test_setup_filter_gibberish_penalize(): + config = GibberishFilterConfig(action="penalize", penalty_reward=-0.5) + gibberish_filter = setup_filter(config, vocab_size=128_000) + assert gibberish_filter.action == "penalize" + assert gibberish_filter.penalty_reward == -0.5 def test_setup_filter_repetition(): @@ -213,13 +292,22 @@ def test_setup_filter_repetition(): assert repetition_filter.name == "repetition" assert repetition_filter.window == 3_000 assert abs(repetition_filter.logprob_threshold - math.log(0.99)) < 1e-10 - assert repetition_filter.enforce is False + assert repetition_filter.action == "monitor" + assert repetition_filter.phase == "pre_advantage" -def test_setup_filter_repetition_enforce(): +def test_setup_filter_repetition_legacy_enforce(): config = RepetitionFilterConfig(enforce=True) repetition_filter = setup_filter(config, vocab_size=128_000) - assert repetition_filter.enforce is True + assert repetition_filter.action == "drop" + + +def test_setup_filter_zero_advantage_defaults_to_drop(): + config = ZeroAdvantageFilterConfig() + zero_advantage_filter = setup_filter(config, vocab_size=128_000) + assert isinstance(zero_advantage_filter, ZeroAdvantageFilter) + assert zero_advantage_filter.action == "drop" + assert zero_advantage_filter.phase == "post_advantage" def test_setup_filters_multiple(): @@ -233,11 +321,22 @@ def test_setup_filters_multiple(): assert filters[1].name == "repetition" -# --- apply_filters tests (enforce=True) --- +def test_split_filters_by_phase(): + filters = [ + _make_gibberish_filter(), + _make_repetition_filter(), + ZeroAdvantageFilter(name="zero_advantage"), + ] + pre, post = split_filters(filters) + assert [f.name for f in pre] == ["gibberish", "repetition"] + assert [f.name for f in post] == ["zero_advantage"] + +# --- apply_filters tests (action="drop") --- -def test_apply_filters_enforced_flags_rollout(): - gibberish_filter = _make_gibberish_filter(enforce=True) + +def 
test_apply_filters_drop_flags_rollout(): + gibberish_filter = _make_gibberish_filter(action="drop") rollout = _make_rollout( completion_ids=[120_000], @@ -256,7 +355,7 @@ def test_apply_filters_enforced_flags_rollout(): def test_apply_filters_preserves_clean_rollouts(): - gibberish_filter = _make_gibberish_filter(enforce=True) + gibberish_filter = _make_gibberish_filter(action="drop") rollout = _make_rollout( completion_ids=[50, 60, 70], @@ -275,8 +374,8 @@ def test_apply_filters_preserves_clean_rollouts(): def test_apply_filters_first_filter_wins(): - gibberish_filter = _make_gibberish_filter(enforce=True) - repetition_filter = _make_repetition_filter(window=2, enforce=True) + gibberish_filter = _make_gibberish_filter(action="drop") + repetition_filter = _make_repetition_filter(window=2, action="drop") rollout = _make_rollout( completion_ids=[120_000, 1, 2], @@ -303,12 +402,10 @@ def test_apply_filters_empty_list(): def test_apply_filters_mixed_batch(): - gibberish_filter = _make_gibberish_filter(enforce=True) + gibberish_filter = _make_gibberish_filter(action="drop") clean = _make_rollout(completion_ids=[50], completion_logprobs=[-1.0], reward=1.0) - dirty = _make_rollout( - completion_ids=[120_000], completion_logprobs=[gibberish_filter.logprob_threshold - 1.0], reward=1.0 - ) + dirty = _make_dirty_rollout(gibberish_filter) apply_filters([gibberish_filter], [clean, dirty]) @@ -318,8 +415,8 @@ def test_apply_filters_mixed_batch(): assert dirty.is_filtered is True -def test_apply_filters_enforced_preserves_rollout_tokens(): - gibberish_filter = _make_gibberish_filter(enforce=True) +def test_apply_filters_drop_preserves_rollout_tokens(): + gibberish_filter = _make_gibberish_filter(action="drop") rollout = _make_rollout( completion_ids=[10, 120_000, 30], @@ -340,13 +437,9 @@ def test_apply_filters_enforced_preserves_rollout_tokens(): def test_apply_filters_preserves_existing_stop_condition(): - gibberish_filter = _make_gibberish_filter(enforce=True) + 
gibberish_filter = _make_gibberish_filter(action="drop") - rollout = _make_rollout( - completion_ids=[120_000], - completion_logprobs=[gibberish_filter.logprob_threshold - 1.0], - reward=1.0, - ) + rollout = _make_dirty_rollout(gibberish_filter) rollout.raw["stop_condition"] = "generation_truncated" apply_filters([gibberish_filter], [rollout]) @@ -355,17 +448,13 @@ def test_apply_filters_preserves_existing_stop_condition(): assert rollout.is_filtered is True -# --- apply_filters tests (monitor-only, enforce=False) --- +# --- apply_filters tests (action="monitor") --- def test_apply_filters_monitor_only_tracks_detection(): - gibberish_filter = _make_gibberish_filter(enforce=False) + gibberish_filter = _make_gibberish_filter(action="monitor") - rollout = _make_rollout( - completion_ids=[120_000], - completion_logprobs=[gibberish_filter.logprob_threshold - 1.0], - reward=1.0, - ) + rollout = _make_dirty_rollout(gibberish_filter) apply_filters([gibberish_filter], [rollout]) @@ -374,15 +463,15 @@ def test_apply_filters_monitor_only_tracks_detection(): assert rollout.raw["stop_condition"] is None assert rollout.filter_results == {"gibberish": True} assert rollout.is_filtered is False + assert rollout.raw_reward is None + assert rollout.reward_penalties == {} def test_apply_filters_monitor_only_mixed_batch(): - gibberish_filter = _make_gibberish_filter(enforce=False) + gibberish_filter = _make_gibberish_filter(action="monitor") clean = _make_rollout(completion_ids=[50], completion_logprobs=[-1.0], reward=1.0) - dirty = _make_rollout( - completion_ids=[120_000], completion_logprobs=[gibberish_filter.logprob_threshold - 1.0], reward=1.0 - ) + dirty = _make_dirty_rollout(gibberish_filter) apply_filters([gibberish_filter], [clean, dirty]) @@ -390,3 +479,156 @@ def test_apply_filters_monitor_only_mixed_batch(): assert dirty.reward == 1.0 assert clean.is_filtered is False assert dirty.is_filtered is False + + +# --- apply_filters tests (action="penalize") --- + + +def 
test_penalize_gibberish_caps_reward(): + gibberish_filter = _make_gibberish_filter(action="penalize") + + rollout = _make_dirty_rollout(gibberish_filter, reward=1.0) + + apply_filters([gibberish_filter], [rollout]) + + assert rollout.reward == -1.0 + assert rollout.raw_reward == 1.0 + assert rollout.filter_results == {"gibberish": True} + assert rollout.reward_penalties == { + "gibberish": {"raw_reward": 1.0, "penalized_reward": -1.0, "detection_index": 0} + } + + +def test_penalize_repetition_caps_reward(): + repetition_filter = _make_repetition_filter(window=3, action="penalize") + + rollout = _make_rollout( + completion_ids=list(range(3)), + completion_logprobs=[-0.001] * 3, + reward=0.0, + ) + + apply_filters([repetition_filter], [rollout]) + + assert rollout.reward == -1.0 + assert rollout.raw_reward == 0.0 + assert rollout.reward_penalties["repetition"]["detection_index"] == 2 + + +def test_penalize_does_not_filter_rollout(): + gibberish_filter = _make_gibberish_filter(action="penalize") + + rollout = _make_dirty_rollout(gibberish_filter) + + apply_filters([gibberish_filter], [rollout]) + + assert rollout.is_filtered is False + # Trajectory tokens stay untouched — rollout remains trainable + assert rollout.raw["trajectory"][0]["tokens"]["completion_mask"] == [1] + + +def test_penalize_respects_custom_penalty_reward(): + gibberish_filter = _make_gibberish_filter(action="penalize", penalty_reward=-0.5) + + rollout = _make_dirty_rollout(gibberish_filter, reward=1.0) + + apply_filters([gibberish_filter], [rollout]) + + assert rollout.reward == -0.5 + + +def test_penalize_does_not_improve_already_negative_reward(): + gibberish_filter = _make_gibberish_filter(action="penalize", penalty_reward=-1.0) + + rollout = _make_dirty_rollout(gibberish_filter, reward=-2.0) + + apply_filters([gibberish_filter], [rollout]) + + assert rollout.reward == -2.0 + assert rollout.raw_reward == -2.0 + + +def test_penalize_skips_clean_rollouts(): + gibberish_filter = 
_make_gibberish_filter(action="penalize") + + rollout = _make_rollout(completion_ids=[50], completion_logprobs=[-1.0], reward=1.0) + + apply_filters([gibberish_filter], [rollout]) + + assert rollout.reward == 1.0 + assert rollout.raw_reward is None + assert rollout.reward_penalties == {} + + +def test_penalize_first_match_wins_single_penalty(): + gibberish_filter = _make_gibberish_filter(action="penalize") + repetition_filter = _make_repetition_filter(window=2, action="penalize") + + # Triggers both gibberish (token 0) and repetition (tokens 1-2) + rollout = _make_rollout( + completion_ids=[120_000, 1, 2], + completion_logprobs=[gibberish_filter.logprob_threshold - 1.0, -0.001, -0.001], + reward=1.0, + ) + + apply_filters([gibberish_filter, repetition_filter], [rollout]) + + assert rollout.reward == -1.0 + assert rollout.filter_results == {"gibberish": True, "repetition": False} + assert list(rollout.reward_penalties) == ["gibberish"] + + +def test_repeated_apply_filters_preserves_prior_metadata(): + gibberish_filter = _make_gibberish_filter(action="penalize") + zero_advantage_filter = ZeroAdvantageFilter(name="zero_advantage", action="drop") + + rollout = _make_dirty_rollout(gibberish_filter, reward=1.0) + + # Phase 1: pre-advantage penalty + apply_filters([gibberish_filter], [rollout]) + rollout.advantage = 0.0 + # Phase 2: post-advantage drop must not wipe phase-1 results + apply_filters([zero_advantage_filter], [rollout]) + + assert rollout.filter_results == {"gibberish": True, "zero_advantage": True} + assert rollout.is_filtered is True + assert rollout.reward == -1.0 + assert rollout.raw_reward == 1.0 + assert "gibberish" in rollout.reward_penalties + + +# --- ordering tests: penalty visible to advantage computation --- + + +def test_penalized_reward_changes_advantage(): + gibberish_filter = _make_gibberish_filter(action="penalize") + advantage_fn = setup_advantage_fn(DefaultAdvantageConfig()) + + clean = _make_rollout(completion_ids=[50], 
completion_logprobs=[-1.0], reward=1.0) + dirty = _make_dirty_rollout(gibberish_filter, reward=1.0) + + apply_filters([gibberish_filter], [clean, dirty]) + assign_advantages([clean, dirty], advantage_fn) + + # Group rewards are (1.0, -1.0): mean 0 → clean +1, dirty -1 + assert clean.advantage == pytest.approx(1.0) + assert dirty.advantage == pytest.approx(-1.0) + assert dirty.advantage < clean.advantage + + +def test_equally_penalized_group_collapses_to_zero_advantage(): + gibberish_filter = _make_gibberish_filter(action="penalize") + zero_advantage_filter = ZeroAdvantageFilter(name="zero_advantage", action="drop") + advantage_fn = setup_advantage_fn(DefaultAdvantageConfig()) + + rollouts = [_make_dirty_rollout(gibberish_filter, reward=1.0) for _ in range(2)] + + apply_filters([gibberish_filter], rollouts) + assign_advantages(rollouts, advantage_fn) + apply_filters([zero_advantage_filter], rollouts) + + for rollout in rollouts: + assert rollout.reward == -1.0 + assert rollout.advantage == pytest.approx(0.0) + assert rollout.filter_results == {"gibberish": True, "zero_advantage": True} + assert rollout.is_filtered is True diff --git a/tests/unit/orchestrator/test_train_sink.py b/tests/unit/orchestrator/test_train_sink.py new file mode 100644 index 0000000000..3855a37af9 --- /dev/null +++ b/tests/unit/orchestrator/test_train_sink.py @@ -0,0 +1,212 @@ +"""Sink-level tests for filter actions: ``process_group`` ordering (penalty +visible to the advantage baseline + sample propagation) and ``process_batch`` +post-filter sample re-sync. + +``TrainSink``'s heavy constructor args (tokenizer / renderer / real envs) +are only used by ``add()`` / ``process_rollout``; these tests bypass them by +pre-building ``rollout.samples`` and driving ``process_group`` / +``process_batch`` directly. 
+""" + +import math +import uuid +from types import SimpleNamespace + +from prime_rl.configs.orchestrator import DefaultAdvantageConfig +from prime_rl.orchestrator.advantage import setup_advantage_fn +from prime_rl.orchestrator.filters import GibberishFilter, ZeroAdvantageFilter +from prime_rl.orchestrator.train_sink import TrainSink +from prime_rl.orchestrator.types import TrainRollout +from prime_rl.transport import TrainingSample + +VOCAB_SIZE = 128_000 + + +def _make_gibberish_filter(action="penalize", penalty_reward=-1.0): + return GibberishFilter( + name="gibberish", + token_id_threshold=100_000, + logprob_threshold=-math.log(VOCAB_SIZE) - 2.0, + action=action, + penalty_reward=penalty_reward, + ) + + +def _make_sample(reward=None): + return TrainingSample( + prompt_ids=[1, 2], + prompt_mask=[True, True], + completion_ids=[3, 4], + completion_mask=[True, True], + completion_logprobs=[-0.1, -0.2], + completion_temperatures=[1.0, 1.0], + env_name="test", + reward=reward, + ) + + +def _make_rollout(*, dirty: bool, reward: float = 1.0, group_id=None, gibberish_filter=None) -> TrainRollout: + """Rollout with one pre-built sample; ``dirty=True`` triggers the + gibberish filter (rare token at high entropy).""" + if dirty: + assert gibberish_filter is not None + completion_ids = [120_000] + completion_logprobs = [gibberish_filter.logprob_threshold - 1.0] + else: + completion_ids = [50] + completion_logprobs = [-1.0] + raw = { + "trajectory": [ + { + "tokens": { + "completion_ids": completion_ids, + "completion_logprobs": completion_logprobs, + "completion_mask": [1] * len(completion_ids), + } + } + ], + "reward": reward, + "stop_condition": None, + "metrics": {}, + } + rollout = TrainRollout( + raw=raw, + env_name="test", + example_id=0, + group_id=group_id or uuid.uuid4(), + policy_version=0, + off_policy_steps=0, + ) + rollout.samples = [_make_sample(reward=None)] + return rollout + + +class _FakeEnv: + def __init__(self, group_size=2): + self.config = 
SimpleNamespace(group_size=group_size) + self.requires_group_scoring = False + self.advantage_fn = setup_advantage_fn(DefaultAdvantageConfig()) + self.sampling_args = {"temperature": 1.0} + + +class _FakeTrainEnvs: + def __init__(self, env: _FakeEnv): + self._env = env + + def get(self, name: str) -> _FakeEnv: + return self._env + + +def _make_sink(*, pre_filters=(), post_filters=(), batch_size=1) -> TrainSink: + return TrainSink( + config=SimpleNamespace(training_mode="rl"), # process_group reads only training_mode + tokenizer=None, + renderer=None, + train_envs=_FakeTrainEnvs(_FakeEnv()), + mm_token_type_ids_mapping=None, + batch_size=batch_size, + token_batch_size=None, + pre_filters=list(pre_filters), + post_filters=list(post_filters), + ) + + +# --- process_group: pre-advantage penalty ordering --- + + +def test_process_group_penalize_lands_before_advantage_and_samples(): + gibberish_filter = _make_gibberish_filter(action="penalize") + sink = _make_sink(pre_filters=[gibberish_filter, ZeroAdvantageFilter(name="zero_advantage", action="drop")]) + + group_id = uuid.uuid4() + clean = _make_rollout(dirty=False, reward=1.0, group_id=group_id) + dirty = _make_rollout(dirty=True, reward=1.0, group_id=group_id, gibberish_filter=gibberish_filter) + sink.pending_groups[group_id] = [clean, dirty] + + sink.process_group(group_id) + + # Penalty applied before the group baseline: rewards (1.0, -1.0) → advantages (+1, -1) + assert dirty.reward == -1.0 + assert dirty.raw_reward == 1.0 + assert clean.advantage == 1.0 + assert dirty.advantage == -1.0 + + # Samples stamped with post-penalty reward and advantage + assert clean.samples[0].reward == 1.0 + assert clean.samples[0].advantage == 1.0 + assert dirty.samples[0].reward == -1.0 + assert dirty.samples[0].advantage == -1.0 + + # Penalized rollout stays trainable; nonzero advantages → no drops + assert sink.pending_batch == [clean, dirty] + assert sink.pre_filter_seen == 2 + assert sink.pre_filter_dropped == 0 + + +def 
test_process_group_equally_penalized_group_dropped_by_zero_advantage(): + gibberish_filter = _make_gibberish_filter(action="penalize") + sink = _make_sink(pre_filters=[gibberish_filter, ZeroAdvantageFilter(name="zero_advantage", action="drop")]) + + group_id = uuid.uuid4() + rollouts = [ + _make_rollout(dirty=True, reward=1.0, group_id=group_id, gibberish_filter=gibberish_filter) for _ in range(2) + ] + sink.pending_groups[group_id] = rollouts + + sink.process_group(group_id) + + # Both capped to the same reward → zero advantages → post-advantage drop + for rollout in rollouts: + assert rollout.reward == -1.0 + assert rollout.advantage == 0.0 + assert rollout.is_filtered is True + assert sink.pending_batch == [] + assert sink.pre_filter_dropped == 2 + # Drop attribution counts only drop-action filters, not the penalty + assert sink.pre_filter_dropped_by_name == {"zero_advantage": 2} + + +# --- process_batch: post-batch penalize re-syncs samples --- + + +def test_process_batch_penalize_resyncs_sample_reward(): + gibberish_filter = _make_gibberish_filter(action="penalize") + sink = _make_sink(post_filters=[gibberish_filter], batch_size=2) + + clean = _make_rollout(dirty=False, reward=1.0) + dirty = _make_rollout(dirty=True, reward=1.0, gibberish_filter=gibberish_filter) + # Simulate process_group's propagation: samples stamped with pre-penalty reward + for rollout in (clean, dirty): + rollout.samples[0].reward = rollout.reward + rollout.samples[0].advantage = 0.5 + sink.pending_batch = [clean, dirty] + + batch = sink.process_batch() + + # Shipped samples agree with the (penalized) rollout reward used in metrics + assert dirty.reward == -1.0 + assert dirty.samples[0].reward == -1.0 + assert clean.samples[0].reward == 1.0 + # Penalize keeps the rollout trainable; advantage is metadata-only post-batch + assert dirty.is_filtered is False + assert dirty.samples[0].advantage == 0.5 + assert batch.metrics.n_trainable == 2 + assert len(batch.samples) == 2 + + +def 
test_process_batch_drop_still_excludes_samples(): + gibberish_filter = _make_gibberish_filter(action="drop") + sink = _make_sink(post_filters=[gibberish_filter], batch_size=2) + + clean = _make_rollout(dirty=False, reward=1.0) + dirty = _make_rollout(dirty=True, reward=1.0, gibberish_filter=gibberish_filter) + for rollout in (clean, dirty): + rollout.samples[0].reward = rollout.reward + sink.pending_batch = [clean, dirty] + + batch = sink.process_batch() + + assert dirty.is_filtered is True + assert dirty.reward == 1.0 # drop leaves reward untouched + assert batch.metrics.n_trainable == 1 + assert len(batch.samples) == 1