From 84370ab405728291f83d376e7cf859f7ede241d5 Mon Sep 17 00:00:00 2001 From: hallerite Date: Tue, 9 Jun 2026 17:53:31 +0000 Subject: [PATCH] feat(orchestrator): EnvMixStrategy seam for env selection MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Extract TrainSource's weighted round-robin env pick into a swappable EnvMixStrategy (default WeightedRoundRobin). Example selection (the reshuffling cursor) stays in TrainSource. The strategy draws from TrainSource's RNG, so the example sequence is unchanged — pure extraction, no behavior delta. Separates 'which env' from 'which example' as the seam slice (c) builds on. Co-Authored-By: Claude Opus 4.8 --- src/prime_rl/orchestrator/sampling.py | 41 +++++++++++++++++++++++ src/prime_rl/orchestrator/train_source.py | 28 ++++++++++------ tests/unit/orchestrator/test_sampling.py | 27 +++++++++++++++ 3 files changed, 86 insertions(+), 10 deletions(-) create mode 100644 src/prime_rl/orchestrator/sampling.py create mode 100644 tests/unit/orchestrator/test_sampling.py diff --git a/src/prime_rl/orchestrator/sampling.py b/src/prime_rl/orchestrator/sampling.py new file mode 100644 index 0000000000..2eb25be6f8 --- /dev/null +++ b/src/prime_rl/orchestrator/sampling.py @@ -0,0 +1,41 @@ +"""Sampling strategies for training rollouts. + +``EnvMixStrategy`` (global) decides *which* env to draw from next — a swappable +seam between the train envs and the dispatcher. The default +(``WeightedRoundRobin``) reproduces the previous ``TrainSource`` behavior: a +weighted random choice over env names, weighted by configured ``ratio`` (when +every env sets one) or per-env dataset size. +""" + +from __future__ import annotations + +import random +from abc import ABC, abstractmethod + + +class EnvMixStrategy(ABC): + """Global: which env to draw from next. ``pick`` returns an env name.""" + + @abstractmethod + def pick(self) -> str: + """Return the env name to sample from next.""" + ... + + +class WeightedRoundRobin(EnvMixStrategy): + """Default env mix: weighted random choice over env names. Weights are the + configured per-env ratios (when all set) or per-env dataset sizes. + + Draws from the caller's RNG so env selection stays in the same stream as + ``TrainSource``'s dataset shuffles — the example sequence is unchanged. + """ + + def __init__(self, env_names: list[str], weights: list[float], *, rng: random.Random) -> None: + if not env_names: + raise ValueError("WeightedRoundRobin needs at least one env") + self._rng = rng + self._env_names = list(env_names) + self._weights = list(weights) + + def pick(self) -> str: + return self._rng.choices(self._env_names, weights=self._weights, k=1)[0] diff --git a/src/prime_rl/orchestrator/train_source.py b/src/prime_rl/orchestrator/train_source.py index db439f7539..722cfb6e11 100644 --- a/src/prime_rl/orchestrator/train_source.py +++ b/src/prime_rl/orchestrator/train_source.py @@ -1,20 +1,25 @@ """TrainSource: weighted round-robin across train envs, infinite pull. -Weights default to configured ``ratio`` (when every env sets one) or to -per-env dataset size. ``next_example`` reshuffles on cursor exhaustion.""" +Env selection is delegated to a swappable ``EnvMixStrategy`` (default: +weighted round-robin by configured ``ratio`` when every env sets one, else by +per-env dataset size); example selection stays here (a reshuffling cursor per +env). ``next_example`` reshuffles on cursor exhaustion. Returned dicts carry +``env_name`` + ``example_id``. +""" from __future__ import annotations import random from prime_rl.orchestrator.envs import TrainEnvs +from prime_rl.orchestrator.sampling import WeightedRoundRobin class TrainSource: - """``next_example(available_permits)`` picks a weighted-RR env and - returns its next example (or ``None`` when the env's per-call permit - cost doesn't fit — the dispatch loop retries when permits free up). - Returned dicts carry ``env_name`` + ``example_id``.""" + """``next_example(available_permits)`` picks an env via the mix strategy and + returns its next example (or ``None`` when the env's per-call permit cost + doesn't fit — the dispatch loop retries when permits free up). Returned + dicts carry ``env_name`` + ``example_id``.""" def __init__(self, train_envs: TrainEnvs, *, seed: int | None) -> None: self.rng = random.Random(seed) @@ -38,15 +43,18 @@ def __init__(self, train_envs: TrainEnvs, *, seed: int | None) -> None: self.cursors[env.name] = 0 self.env_costs[env.name] = env.config.group_size if env.requires_group_scoring else 1 - self.env_names = [e.name for e in self.envs] + env_names = [e.name for e in self.envs] configured_ratios = [e.config.ratio for e in self.envs] if all(r is not None for r in configured_ratios): - self.weights: list[float] = [float(r) for r in configured_ratios] # type: ignore[arg-type] + weights: list[float] = [float(r) for r in configured_ratios] # type: ignore[arg-type] else: - self.weights = [float(len(self.examples[name])) for name in self.env_names] + weights = [float(len(self.examples[name])) for name in env_names] + # Shares ``self.rng`` so env selection draws from the same stream as the + # dataset shuffles above — the example sequence matches the pre-seam path. + self.env_mix = WeightedRoundRobin(env_names, weights, rng=self.rng) def next_example(self, available_permits: int) -> dict | None: - env_name = self.rng.choices(self.env_names, weights=self.weights, k=1)[0] + env_name = self.env_mix.pick() if self.env_costs[env_name] > available_permits: return None rows = self.examples[env_name] diff --git a/tests/unit/orchestrator/test_sampling.py b/tests/unit/orchestrator/test_sampling.py new file mode 100644 index 0000000000..4a419af132 --- /dev/null +++ b/tests/unit/orchestrator/test_sampling.py @@ -0,0 +1,27 @@ +import random + +import pytest + +from prime_rl.orchestrator.sampling import WeightedRoundRobin + + +def test_weighted_round_robin_is_deterministic_per_rng(): + """Same seed → same pick sequence (env selection is reproducible).""" + names = ["a", "b", "c"] + weights = [1.0, 2.0, 3.0] + a = WeightedRoundRobin(names, weights, rng=random.Random(0)) + b = WeightedRoundRobin(names, weights, rng=random.Random(0)) + assert [a.pick() for _ in range(100)] == [b.pick() for _ in range(100)] + + +def test_weighted_round_robin_respects_weights(): + """A heavily-weighted env dominates; a zero-weight env is never picked.""" + wrr = WeightedRoundRobin(["rare", "common", "never"], [1.0, 99.0, 0.0], rng=random.Random(0)) + picks = [wrr.pick() for _ in range(10_000)] + assert "never" not in picks + assert picks.count("common") > picks.count("rare") + + +def test_weighted_round_robin_rejects_empty(): + with pytest.raises(ValueError): + WeightedRoundRobin([], [], rng=random.Random(0))