Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 54 additions & 17 deletions packages/prime-rl-configs/src/prime_rl/configs/orchestrator.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import math
import warnings
from pathlib import Path
from typing import Annotated, Any, Literal, TypeAlias
from typing import Annotated, Any, ClassVar, Literal, TypeAlias

from pydantic import AliasChoices, Field, model_serializer, model_validator
from pydantic_core.core_schema import SerializerFunctionWrapHandler
Expand Down Expand Up @@ -422,13 +422,53 @@ class CheckpointConfig(BaseConfig):
"""Skip loading the progress from checkpoint."""


FilterAction: TypeAlias = Literal["monitor", "drop", "penalize"]


class BaseFilterConfig(BaseConfig):
    """Action settings common to every rollout filter.

    ``action`` is the preferred way to choose what happens on detection;
    ``enforce`` is the legacy boolean equivalent. Supplying both is accepted
    only when they agree (validation fails otherwise). When neither is set,
    ``resolved_action`` falls back to the per-type ``_default_action``.
    """

    action: FilterAction | None = None
    """What to do when the filter detects a rollout. ``monitor``: only track detection metrics. ``drop``: skip the rollout entirely so it is not sent to the trainer. ``penalize``: cap the rollout's reward at ``penalty_reward`` before advantage computation while keeping it trainable. If None, resolves from the legacy ``enforce`` flag, falling back to the filter's default action."""

    enforce: bool | None = None
    """Legacy flag kept for backwards compatibility: ``true`` resolves to ``action="drop"``, ``false`` to ``action="monitor"``. Prefer setting ``action`` instead."""

    penalty_reward: float = -1.0
    """Reward cap applied when ``action="penalize"``: final reward = ``min(raw_reward, penalty_reward)``. Ignored by other actions."""

    # Class-level fallback used by ``resolved_action``; subclasses override it.
    _default_action: ClassVar[FilterAction] = "monitor"

    @model_validator(mode="after")
    def _validate_action_and_enforce(self):
        """Reject configs where ``action`` contradicts the legacy ``enforce`` flag.

        Raises:
            ValueError: if both fields are set and ``action`` differs from the
                action that ``enforce`` maps to.
        """
        # Nothing to cross-check unless both knobs were supplied.
        if self.action is None or self.enforce is None:
            return self

        implied = "drop" if self.enforce else "monitor"
        if self.action == implied:
            return self

        raise ValueError(
            f"Conflicting filter config: action={self.action!r} but enforce={self.enforce} "
            f"implies action={implied!r}. Set only `action` (preferred) or only `enforce`."
        )

    @property
    def resolved_action(self) -> FilterAction:
        """Return the effective action.

        Precedence: an explicit ``action``, then the legacy ``enforce`` flag
        (``True`` -> ``"drop"``, ``False`` -> ``"monitor"``), then the
        class-level ``_default_action``.
        """
        if self.action is not None:
            return self.action
        if self.enforce is None:
            return self._default_action
        return "drop" if self.enforce else "monitor"


# Flags rare tokens generated at high entropy (Section 5.2, https://arxiv.org/abs/2510.02387).
class GibberishFilterConfig(BaseConfig):
class GibberishFilterConfig(BaseFilterConfig):
type: Literal["gibberish"] = "gibberish"

enforce: bool = False
"""When True, skip detected rollouts entirely so they are not sent to the trainer. When False, only track detection metrics."""

token_id_threshold: int = 100_000
"""Token IDs above this are candidates for gibberish. BPE tokens are sorted by merge order."""

Expand All @@ -439,12 +479,9 @@ class GibberishFilterConfig(BaseConfig):
# Flags rollouts stuck in a repetition loop: emits high-confidence tokens for an extended stretch.
# Flagged when `window` consecutive tokens are each sampled with probability above `prob_threshold`.
# (Section 3.2, https://arxiv.org/abs/2506.13585)
class RepetitionFilterConfig(BaseConfig):
class RepetitionFilterConfig(BaseFilterConfig):
type: Literal["repetition"] = "repetition"

enforce: bool = False
"""When True, skip detected rollouts entirely so they are not sent to the trainer. When False, only track detection metrics."""

window: int = Field(3_000, ge=1)
"""Consecutive high-probability steps required to flag the rollout."""

Expand All @@ -453,11 +490,10 @@ class RepetitionFilterConfig(BaseConfig):


# Flags rollouts with zero advantage.
class ZeroAdvantageFilterConfig(BaseFilterConfig):
    """Config for the filter that flags rollouts whose advantage is zero.

    Unlike the other filter types, the fallback action here is ``"drop"``
    (used when neither ``action`` nor ``enforce`` is set).
    """

    type: Literal["zero_advantage"] = "zero_advantage"

    _default_action: ClassVar[FilterAction] = "drop"

    @model_validator(mode="after")
    def _reject_penalize(self):
        """Refuse ``action="penalize"`` for this filter type.

        The zero-advantage check is a ``post_advantage``-phase filter, while
        reward penalties only have an effect when applied before advantages
        are computed. Accepting ``penalize`` here would rewrite the reward
        without changing the training signal, so fail loudly at config time.
        NOTE(review): confirm no caller relies on penalizing after advantage
        assignment.

        Raises:
            ValueError: if ``action`` is ``"penalize"``.
        """
        if self.action == "penalize":
            raise ValueError(
                "The zero_advantage filter does not support action='penalize': it runs after "
                "advantage computation, so a reward penalty would have no effect. "
                "Use action='monitor' or action='drop'."
            )
        return self


FilterConfig: TypeAlias = Annotated[
Expand Down Expand Up @@ -552,14 +588,15 @@ def _preserve_mito_renderer(self, handler: SerializerFunctionWrapHandler) -> dic
advantage: AdvantageConfig | None = DefaultAdvantageConfig()

pre_batch_filters: list[FilterConfig] = [
GibberishFilterConfig(enforce=False),
RepetitionFilterConfig(enforce=False),
ZeroAdvantageFilterConfig(enforce=False),
GibberishFilterConfig(),
RepetitionFilterConfig(),
ZeroAdvantageFilterConfig(action="monitor"),
]
"""Filters applied *before* a rollout enters the training batch buffer.
All three filter types are registered in monitor mode by default; flip ``enforce=true`` per type
All three filter types are registered in monitor mode by default; set ``action="drop"`` per type
to drop matching rollouts before they consume a slot in the batch (e.g. a zero-advantage group
never makes it into a training batch)."""
never makes it into a training batch), or ``action="penalize"`` (gibberish/repetition) to cap the
rollout's reward at ``penalty_reward`` before advantage computation while keeping it trainable."""

post_batch_filters: list[FilterConfig] = [
GibberishFilterConfig(),
Expand Down
107 changes: 86 additions & 21 deletions src/prime_rl/orchestrator/filters.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,41 @@
"""Orchestrator-side rollout filters for detecting degenerate generations.

Filters run after rollouts complete, inspecting token IDs and logprobs to
detect gibberish or repetition. Detection metrics are always tracked.
When enforce=True, detected rollouts are skipped entirely during training and
are not sent to the trainer. Reward is kept as-is for baseline calculation.
detect gibberish or repetition. Detection metrics are always tracked. Each
filter resolves to one of three actions:

- ``monitor``: only record detection metrics;
- ``drop``: detected rollouts are skipped entirely during training and are
not sent to the trainer. Reward is kept as-is for baseline calculation;
- ``penalize``: detected rollouts stay trainable, but their reward is capped
at ``penalty_reward``. Penalties must be applied before advantage
computation to create negative policy-gradient signal — token/logprob
based filters are ``pre_advantage`` phase so ``TrainSink.process_group``
runs them before ``assign_advantages``.
"""

from __future__ import annotations

import math
from dataclasses import dataclass
from typing import TYPE_CHECKING, Protocol
from typing import TYPE_CHECKING, Literal, Protocol, TypeAlias

from prime_rl.configs.orchestrator import FilterConfig
from prime_rl.configs.orchestrator import FilterAction, FilterConfig
from prime_rl.utils.logger import get_logger

if TYPE_CHECKING:
from prime_rl.orchestrator.types import TrainRollout


FilterPhase: TypeAlias = Literal["pre_advantage", "post_advantage"]

_ACTION_LOG_NAMES: dict[FilterAction, str] = {
"monitor": "Monitoring",
"drop": "Dropping",
"penalize": "Penalizing",
}


@dataclass
class FilterResult:
detected: bool
Expand All @@ -27,7 +44,9 @@ class FilterResult:

class RolloutFilter(Protocol):
name: str
enforce: bool
action: FilterAction
phase: FilterPhase
penalty_reward: float

def check(self, rollout: "TrainRollout") -> FilterResult: ...

Expand All @@ -47,7 +66,9 @@ class GibberishFilter:
name: str
token_id_threshold: int
logprob_threshold: float
enforce: bool = False
action: FilterAction = "monitor"
penalty_reward: float = -1.0
phase: FilterPhase = "pre_advantage"

def check(self, rollout: "TrainRollout") -> FilterResult:
global_idx = 0
Expand Down Expand Up @@ -77,7 +98,9 @@ class RepetitionFilter:
name: str
window: int
logprob_threshold: float
enforce: bool = False
action: FilterAction = "monitor"
penalty_reward: float = -1.0
phase: FilterPhase = "pre_advantage"

def check(self, rollout: "TrainRollout") -> FilterResult:
consecutive = 0
Expand All @@ -103,7 +126,9 @@ class ZeroAdvantageFilter:
GRPO group earned the same reward, so the centered advantage collapses)."""

name: str
enforce: bool = True
action: FilterAction = "drop"
penalty_reward: float = -1.0
phase: FilterPhase = "post_advantage"

def check(self, rollout: "TrainRollout") -> FilterResult:
if rollout.advantage is not None and rollout.advantage == 0.0:
Expand All @@ -118,19 +143,22 @@ def setup_filter(config: FilterConfig, vocab_size: int) -> RolloutFilter:
name="gibberish",
token_id_threshold=config.token_id_threshold,
logprob_threshold=-math.log(vocab_size) - config.logprob_offset,
enforce=config.enforce,
action=config.resolved_action,
penalty_reward=config.penalty_reward,
)
elif config.type == "repetition":
return RepetitionFilter(
name="repetition",
window=config.window,
logprob_threshold=math.log(config.prob_threshold),
enforce=config.enforce,
action=config.resolved_action,
penalty_reward=config.penalty_reward,
)
elif config.type == "zero_advantage":
return ZeroAdvantageFilter(
name="zero_advantage",
enforce=config.enforce,
action=config.resolved_action,
penalty_reward=config.penalty_reward,
)
raise ValueError(f"Unknown filter type: {config.type}")

Expand All @@ -141,24 +169,59 @@ def setup_filters(configs: list[FilterConfig], vocab_size: int, *, kind: str) ->
if filters:
get_logger().info(f"Configured {len(filters)} {kind} rollout filter(s):")
for config, filt in zip(configs, filters):
mode = "Enforcing" if filt.enforce else "Monitoring"
mode = _ACTION_LOG_NAMES[filt.action]
params = ", ".join(f"{k}={v}" for k, v in config.model_dump().items())
get_logger().info(f" {mode} {filt.name} filter ({params})")
return filters


def split_filters(filters: list[RolloutFilter]) -> tuple[list[RolloutFilter], list[RolloutFilter]]:
    """Partition ``filters`` by their ``phase`` attribute.

    Returns a ``(pre_advantage, post_advantage)`` tuple of lists. Relative
    order within each list follows the input order. A filter whose phase is
    neither value ends up in neither list.
    """
    pre_advantage: list[RolloutFilter] = []
    post_advantage: list[RolloutFilter] = []
    for filt in filters:
        if filt.phase == "pre_advantage":
            pre_advantage.append(filt)
        elif filt.phase == "post_advantage":
            post_advantage.append(filt)
    return pre_advantage, post_advantage


def penalize_reward(
    rollout: "TrainRollout", filter_name: str, penalty_reward: float, detection_index: int | None
) -> None:
    """Cap ``rollout``'s reward at ``penalty_reward`` and log the penalty on it.

    The new reward is ``min(current_reward, penalty_reward)``, so a reward
    that is already lower than the cap is left as it is. Mutates ``rollout``
    in place:

    - ``rollout.raw_reward`` is set to the pre-penalty reward only if it is
      still ``None``, so the first penalty's value is the one kept;
    - ``rollout.raw["reward"]`` receives the capped reward;
    - ``rollout.reward_penalties[filter_name]`` receives a record with the
      reward seen by this call, the capped reward, and ``detection_index``.
    """
    # Read the current reward before anything is written back to the rollout.
    reward_before = rollout.reward
    capped = min(reward_before, penalty_reward)
    penalty_record = {
        "raw_reward": reward_before,
        "penalized_reward": capped,
        "detection_index": detection_index,
    }

    if rollout.raw_reward is None:
        rollout.raw_reward = reward_before
    rollout.raw["reward"] = capped
    rollout.reward_penalties[filter_name] = penalty_record


def apply_filters(filters: list[RolloutFilter], rollouts: list["TrainRollout"]) -> None: # noqa: F821 (forward ref)
"""Flag ``TrainRollout``\\ s in place with per-filter detection + drop decision.
"""Flag ``TrainRollout``\\ s in place with per-filter detection + action.

Each rollout's ``filter_results`` dict records per-filter detection bools;
``is_filtered`` is True iff an enforcing filter detected it. First matching
filter wins per rollout (no double-counting). Reward and trajectory tokens
are left untouched so the rollout can still contribute to baseline
calculations and metric aggregation.
``is_filtered`` is True iff a ``drop`` filter detected it. A ``penalize``
filter caps the rollout's reward at its ``penalty_reward`` but leaves the
rollout trainable. First matching filter wins per rollout within a call
(no double-counting). Trajectory tokens are left untouched so the rollout
can still contribute to baseline calculations and metric aggregation.

Safe to call more than once on the same rollouts (e.g. once per filter
phase): missing ``filter_results`` keys are initialized without wiping
results, drops, or penalties recorded by an earlier call.
"""
for rollout in rollouts:
rollout.filter_results = {f.name: False for f in filters}
rollout.is_filtered = False
for filt in filters:
rollout.filter_results.setdefault(filt.name, False)

if not filters:
return
Expand All @@ -168,6 +231,8 @@ def apply_filters(filters: list[RolloutFilter], rollouts: list["TrainRollout"])
result = filt.check(rollout)
if result.detected:
rollout.filter_results[filt.name] = True
if filt.enforce:
if filt.action == "drop":
rollout.is_filtered = True
elif filt.action == "penalize":
penalize_reward(rollout, filt.name, filt.penalty_reward, result.detection_index)
break
13 changes: 13 additions & 0 deletions src/prime_rl/orchestrator/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,19 @@ def compute_solve_rates(df):
"step": step,
}

# Reward-penalty metrics: only emitted when a `penalize` filter fired
# somewhere in the batch (keeps dashboards clean otherwise)
penalized_names = sorted({name for r in rollouts for name in r.reward_penalties})
for name in penalized_names:
to_log[f"filters/all/{name}_penalized"] = sum(
1.0 for r in rollouts if name in r.reward_penalties
) / max(num_rollouts, 1)
if any(r.raw_reward is not None for r in rollouts):
raw_rewards = pd.Series([r.reward if r.raw_reward is None else r.raw_reward for r in rollouts])
to_log["raw_reward/all/mean"] = raw_rewards.mean()
to_log["raw_reward/all/max"] = raw_rewards.max()
to_log["raw_reward/all/min"] = raw_rewards.min()

# Per-env metrics
per_env_columns = [
"seq_len",
Expand Down
Loading
Loading