feat(orchestrator): add penalize action for gibberish/repetition filters #2775
Open
anravich13-cloud wants to merge 11 commits into main from feat/repetition-gibberish-penalty
Changes from 4 commits
Commits
5208ce8 feat(orchestrator): add penalize action for gibberish/repetition filters (anravich13-cloud)
695cdf1 style: ruff format metrics.py (anravich13-cloud)
4e1613d fix(orchestrator): re-sync sample rewards after post-batch penalize (anravich13-cloud)
85b696e test(orchestrator): add TrainSink-level tests for filter actions (anravich13-cloud)
b644be5 docs: launch RL trainer validation step via torchrun (anravich13-cloud)
b9f2b01 fix(orchestrator): make debug validation run standalone and self-term… (anravich13-cloud)
530a285 docs: fix eval validation step to use vf-eval (anravich13-cloud)
732578f refactor: drop legacy enforce compat, trim tests per review (anravich13-cloud)
62894cf Merge pull request #2787 into repetition filter branch
8c8a808 Merge origin/main into repetition filter branch
10190e1 Clean up post_batch_filters in orch.toml (anravich13-cloud)
@@ -1,7 +1,7 @@
 import math
 import warnings
 from pathlib import Path
-from typing import Annotated, Any, Literal, TypeAlias
+from typing import Annotated, Any, ClassVar, Literal, TypeAlias

 from pydantic import AliasChoices, Field, model_serializer, model_validator
 from pydantic_core.core_schema import SerializerFunctionWrapHandler
@@ -422,13 +422,53 @@ class CheckpointConfig(BaseConfig):
     """Skip loading the progress from checkpoint."""


+FilterAction: TypeAlias = Literal["monitor", "drop", "penalize"]
+
+
+class BaseFilterConfig(BaseConfig):
+    """Shared action fields for rollout filters.
+
+    Exactly one of ``action`` / ``enforce`` should be set (``enforce`` is legacy
+    compatibility). If neither is set, the filter falls back to its per-type
+    default action.
+    """
+
+    action: FilterAction | None = None
+    """What to do when the filter detects a rollout. ``monitor``: only track detection metrics. ``drop``: skip the rollout entirely so it is not sent to the trainer. ``penalize``: cap the rollout's reward at ``penalty_reward`` before advantage computation while keeping it trainable. If None, resolves from the legacy ``enforce`` flag, falling back to the filter's default action."""
+
+    enforce: bool | None = None
+    """Legacy flag kept for backwards compatibility: ``true`` resolves to ``action="drop"``, ``false`` to ``action="monitor"``. Prefer setting ``action`` instead."""
+
+    penalty_reward: float = -1.0
+    """Reward cap applied when ``action="penalize"``: final reward = ``min(raw_reward, penalty_reward)``. Ignored by other actions."""
+
+    _default_action: ClassVar[FilterAction] = "monitor"
+
+    @model_validator(mode="after")
+    def _validate_action_and_enforce(self):
+        if self.action is not None and self.enforce is not None:
+            implied = "drop" if self.enforce else "monitor"
+            if self.action != implied:
+                raise ValueError(
+                    f"Conflicting filter config: action={self.action!r} but enforce={self.enforce} "
+                    f"implies action={implied!r}. Set only `action` (preferred) or only `enforce`."
+                )
+        return self
+
+    @property
+    def resolved_action(self) -> FilterAction:
+        """The effective action: explicit ``action`` wins, then legacy ``enforce``, then the per-type default."""
+        if self.action is not None:
+            return self.action
+        if self.enforce is not None:
+            return "drop" if self.enforce else "monitor"
+        return self._default_action
Member
Let's use a proper default value rather than having `| None` and doing this logic here.
+
+
 # Flags rare tokens generated at high entropy (Section 5.2, https://arxiv.org/abs/2510.02387).
-class GibberishFilterConfig(BaseConfig):
+class GibberishFilterConfig(BaseFilterConfig):
     type: Literal["gibberish"] = "gibberish"

-    enforce: bool = False
-    """When True, skip detected rollouts entirely so they are not sent to the trainer. When False, only track detection metrics."""
-
     token_id_threshold: int = 100_000
     """Token IDs above this are candidates for gibberish. BPE tokens are sorted by merge order."""

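The gibberish check described in the comment and the `token_id_threshold` docstring can be sketched roughly as below. This is an illustration only, not the project's implementation: the function name `has_gibberish`, the `entropy_threshold` parameter and its value, and the use of per-token entropies as input are all assumptions made for the example. Only `token_id_threshold` and its default come from the diff.

```python
from collections.abc import Sequence


def has_gibberish(
    token_ids: Sequence[int],
    token_entropies: Sequence[float],
    token_id_threshold: int = 100_000,
    entropy_threshold: float = 2.0,  # hypothetical knob, not taken from the PR
) -> bool:
    """Flag a rollout if any rare token was generated at high entropy.

    A token counts as "rare" when its ID is above `token_id_threshold`
    (BPE vocabularies are sorted by merge order, so late IDs are rare).
    """
    if len(token_ids) != len(token_entropies):
        raise ValueError("token_ids and token_entropies must have the same length")
    return any(
        tid > token_id_threshold and ent > entropy_threshold
        for tid, ent in zip(token_ids, token_entropies)
    )
```

A rare token produced with low entropy, or a common token produced with high entropy, does not trip this check; only the combination does.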
@@ -439,12 +479,9 @@ class GibberishFilterConfig(BaseConfig):
 # Flags rollouts stuck in a repetition loop: emits high-confidence tokens for an extended stretch.
 # Flagged when `window` consecutive tokens are each sampled with probability above `prob_threshold`.
 # (Section 3.2, https://arxiv.org/abs/2506.13585)
-class RepetitionFilterConfig(BaseConfig):
+class RepetitionFilterConfig(BaseFilterConfig):
     type: Literal["repetition"] = "repetition"

-    enforce: bool = False
-    """When True, skip detected rollouts entirely so they are not sent to the trainer. When False, only track detection metrics."""
-
     window: int = Field(3_000, ge=1)
     """Consecutive high-probability steps required to flag the rollout."""

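The repetition rule in the comment above (flag when `window` consecutive tokens are each sampled with probability above `prob_threshold`) amounts to a streak counter. A minimal sketch, assuming the per-token sampling probabilities are available as a flat sequence; the function name `has_repetition_loop` is made up for the example and is not from the PR:

```python
from collections.abc import Sequence


def has_repetition_loop(
    token_probs: Sequence[float], window: int, prob_threshold: float
) -> bool:
    """Return True if `window` consecutive tokens each exceed `prob_threshold`."""
    if window < 1:
        raise ValueError("window must be >= 1")
    run = 0  # length of the current streak of high-confidence tokens
    for p in token_probs:
        # Extend the streak on a high-probability token, otherwise reset it.
        run = run + 1 if p > prob_threshold else 0
        if run >= window:
            return True
    return False
```

The streak resets on any single token at or below the threshold, so only an unbroken stretch of `window` high-probability tokens flags the rollout.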
@@ -453,11 +490,10 @@ class RepetitionFilterConfig(BaseConfig):


 # Flags rollouts with zero advantage.
-class ZeroAdvantageFilterConfig(BaseConfig):
+class ZeroAdvantageFilterConfig(BaseFilterConfig):
     type: Literal["zero_advantage"] = "zero_advantage"

-    enforce: bool = True
-    """When True, skip detected rollouts entirely so they are not sent to the trainer. When False, only track detection metrics."""
+    _default_action: ClassVar[FilterAction] = "drop"


 FilterConfig: TypeAlias = Annotated[
@@ -552,14 +588,15 @@ def _preserve_mito_renderer(self, handler: SerializerFunctionWrapHandler) -> dic
     advantage: AdvantageConfig | None = DefaultAdvantageConfig()

     pre_batch_filters: list[FilterConfig] = [
-        GibberishFilterConfig(enforce=False),
-        RepetitionFilterConfig(enforce=False),
-        ZeroAdvantageFilterConfig(enforce=False),
+        GibberishFilterConfig(),
+        RepetitionFilterConfig(),
+        ZeroAdvantageFilterConfig(action="monitor"),
     ]
     """Filters applied *before* a rollout enters the training batch buffer.
-    All three filter types are registered in monitor mode by default; flip ``enforce=true`` per type
+    All three filter types are registered in monitor mode by default; set ``action="drop"`` per type
     to drop matching rollouts before they consume a slot in the batch (e.g. a zero-advantage group
-    never makes it into a training batch)."""
+    never makes it into a training batch), or ``action="penalize"`` (gibberish/repetition) to cap the
+    rollout's reward at ``penalty_reward`` before advantage computation while keeping it trainable."""

     post_batch_filters: list[FilterConfig] = [
         GibberishFilterConfig(),
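To make the three actions concrete, here is a small sketch of how a filter result might be applied to a list of rollouts. The `Rollout` dataclass, its `flagged` field, and `apply_filter_action` are hypothetical stand-ins for illustration; the orchestrator's real data structures are not shown in this diff. The reward cap `min(raw_reward, penalty_reward)` is the formula given in the `penalty_reward` docstring.

```python
from dataclasses import dataclass, replace
from typing import Literal

FilterAction = Literal["monitor", "drop", "penalize"]


@dataclass(frozen=True)
class Rollout:  # hypothetical container, not the project's type
    reward: float
    flagged: bool  # True if the filter detected this rollout


def apply_filter_action(
    rollouts: list[Rollout], action: FilterAction, penalty_reward: float = -1.0
) -> tuple[list[Rollout], int]:
    """Return the rollouts to keep and the number of detections."""
    kept: list[Rollout] = []
    detected = 0
    for rollout in rollouts:
        if not rollout.flagged:
            kept.append(rollout)
            continue
        detected += 1
        if action == "drop":
            continue  # never reaches the trainer
        if action == "penalize":
            # Cap the reward; the rollout stays trainable.
            rollout = replace(rollout, reward=min(rollout.reward, penalty_reward))
        kept.append(rollout)  # "monitor" keeps the rollout unchanged
    return kept, detected
```

Because the cap is a `min`, a flagged rollout whose raw reward is already below `penalty_reward` keeps its lower reward rather than being raised to the cap.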
I think we can just remove the backwards compat here
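One way the two review suggestions (a proper default value instead of `| None`, and no `enforce` backwards compatibility) could look is sketched below. It uses plain dataclasses as a stand-in for the project's pydantic `BaseConfig` so that it runs without dependencies, and the `__post_init__` check only imitates the validation a `Literal` field would get from pydantic. It is an assumption about the shape of the refactor, not the code that landed.

```python
from dataclasses import dataclass
from typing import Literal, get_args

FilterAction = Literal["monitor", "drop", "penalize"]


@dataclass
class BaseFilterConfig:
    # A concrete default replaces `FilterAction | None` plus `resolved_action`.
    action: FilterAction = "monitor"
    penalty_reward: float = -1.0

    def __post_init__(self) -> None:
        # Stand-in for pydantic's Literal validation.
        if self.action not in get_args(FilterAction):
            raise ValueError(f"Unknown filter action: {self.action!r}")


@dataclass
class ZeroAdvantageFilterConfig(BaseFilterConfig):
    # Per-type default expressed by overriding the field default.
    action: FilterAction = "drop"
```

With this shape, `ZeroAdvantageFilterConfig()` resolves to `"drop"` while `ZeroAdvantageFilterConfig(action="monitor")` still overrides it, and no `_default_action` class variable or conflict validator is needed.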