Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 41 additions & 0 deletions src/prime_rl/orchestrator/sampling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
"""Sampling strategies for training rollouts.

``EnvMixStrategy`` (global) decides *which* env to draw from next — a swappable
seam between the train envs and the dispatcher. The default
(``WeightedRoundRobin``) reproduces the previous ``TrainSource`` behavior: a
weighted random choice over env names, weighted by configured ``ratio`` (when
every env sets one) or per-env dataset size.
"""

from __future__ import annotations

import random
from abc import ABC, abstractmethod


class EnvMixStrategy(ABC):
    """Interface for the global choice of which env supplies the next draw.

    Concrete strategies implement ``pick`` and hand back the chosen env's name.
    """

    @abstractmethod
    def pick(self) -> str:
        """Choose the env to sample from next and return its name."""


class WeightedRoundRobin(EnvMixStrategy):
    """Default env mix: a weighted random choice over env names.

    Despite the name, selection is random rather than cyclic: every ``pick``
    makes a single ``rng.choices`` draw over ``env_names`` using ``weights``.
    The weights are whatever the caller supplies (configured per-env ratios
    when all are set, otherwise per-env dataset sizes).

    Draws come from the caller's RNG rather than a private one, so env
    selection stays in the same stream as ``TrainSource``'s dataset shuffles
    and the example sequence is unchanged.

    Args:
        env_names: Names to choose between; must be non-empty.
        weights: One non-negative weight per entry of ``env_names``, with a
            positive, finite sum. A zero weight means "never picked".
        rng: Random generator used for every draw (shared, not copied).

    Raises:
        ValueError: If ``env_names`` is empty, the two sequences differ in
            length, any weight is negative, or the weights do not sum to a
            positive finite number.
    """

    def __init__(self, env_names: list[str], weights: list[float], *, rng: random.Random) -> None:
        if not env_names:
            raise ValueError("WeightedRoundRobin needs at least one env")

        # Snapshot both inputs so later mutation by the caller cannot change
        # the distribution (and so one-shot iterables are consumed only once).
        names = list(env_names)
        env_weights = list(weights)

        # Validate eagerly: ``random.choices`` would otherwise reject a bad
        # configuration only on the first ``pick``, far from where the
        # weights were built.
        if len(env_weights) != len(names):
            raise ValueError(
                f"WeightedRoundRobin got {len(env_weights)} weights for {len(names)} envs"
            )
        if any(w < 0 for w in env_weights):
            raise ValueError("WeightedRoundRobin weights must be non-negative")
        total = float(sum(env_weights))
        # The chained comparison is also False for NaN and for +inf.
        if not 0.0 < total < float("inf"):
            raise ValueError("WeightedRoundRobin weights must sum to a positive, finite number")

        self._rng = rng
        self._env_names = names
        self._weights = env_weights

    def pick(self) -> str:
        """Return one env name drawn at random in proportion to its weight."""
        return self._rng.choices(self._env_names, weights=self._weights, k=1)[0]
28 changes: 18 additions & 10 deletions src/prime_rl/orchestrator/train_source.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,25 @@
"""TrainSource: weighted round-robin across train envs, infinite pull.

Weights default to configured ``ratio`` (when every env sets one) or to
per-env dataset size. ``next_example`` reshuffles on cursor exhaustion."""
Env selection is delegated to a swappable ``EnvMixStrategy`` (default:
weighted round-robin by configured ``ratio`` when every env sets one, else by
per-env dataset size); example selection stays here (a reshuffling cursor per
env). ``next_example`` reshuffles on cursor exhaustion. Returned dicts carry
``env_name`` + ``example_id``.
"""

from __future__ import annotations

import random

from prime_rl.orchestrator.envs import TrainEnvs
from prime_rl.orchestrator.sampling import WeightedRoundRobin


class TrainSource:
"""``next_example(available_permits)`` picks a weighted-RR env and
returns its next example (or ``None`` when the env's per-call permit
cost doesn't fit — the dispatch loop retries when permits free up).
Returned dicts carry ``env_name`` + ``example_id``."""
"""``next_example(available_permits)`` picks an env via the mix strategy and
returns its next example (or ``None`` when the env's per-call permit cost
doesn't fit — the dispatch loop retries when permits free up). Returned
dicts carry ``env_name`` + ``example_id``."""

def __init__(self, train_envs: TrainEnvs, *, seed: int | None) -> None:
self.rng = random.Random(seed)
Expand All @@ -38,15 +43,18 @@ def __init__(self, train_envs: TrainEnvs, *, seed: int | None) -> None:
self.cursors[env.name] = 0
self.env_costs[env.name] = env.config.group_size if env.requires_group_scoring else 1

self.env_names = [e.name for e in self.envs]
env_names = [e.name for e in self.envs]
configured_ratios = [e.config.ratio for e in self.envs]
if all(r is not None for r in configured_ratios):
self.weights: list[float] = [float(r) for r in configured_ratios] # type: ignore[arg-type]
weights: list[float] = [float(r) for r in configured_ratios] # type: ignore[arg-type]
else:
self.weights = [float(len(self.examples[name])) for name in self.env_names]
weights = [float(len(self.examples[name])) for name in env_names]
# Shares ``self.rng`` so env selection draws from the same stream as the
# dataset shuffles above — the example sequence matches the pre-seam path.
self.env_mix = WeightedRoundRobin(env_names, weights, rng=self.rng)

def next_example(self, available_permits: int) -> dict | None:
env_name = self.rng.choices(self.env_names, weights=self.weights, k=1)[0]
env_name = self.env_mix.pick()
if self.env_costs[env_name] > available_permits:
return None
rows = self.examples[env_name]
Expand Down
27 changes: 27 additions & 0 deletions tests/unit/orchestrator/test_sampling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import random

import pytest

from prime_rl.orchestrator.sampling import WeightedRoundRobin


def test_weighted_round_robin_is_deterministic_per_rng():
    """Two strategies fed identically seeded RNGs yield the same 100 picks."""
    env_names = ["a", "b", "c"]
    env_weights = [1.0, 2.0, 3.0]

    first = WeightedRoundRobin(env_names, env_weights, rng=random.Random(0))
    second = WeightedRoundRobin(env_names, env_weights, rng=random.Random(0))

    first_sequence = [first.pick() for _ in range(100)]
    second_sequence = [second.pick() for _ in range(100)]
    assert first_sequence == second_sequence


def test_weighted_round_robin_respects_weights():
    """The 99-weight env out-draws the 1-weight env; the 0-weight env never appears."""
    mix = WeightedRoundRobin(["rare", "common", "never"], [1.0, 99.0, 0.0], rng=random.Random(0))

    # Tally the 10,000 draws per env instead of keeping the full pick list.
    tally = {"rare": 0, "common": 0, "never": 0}
    for _ in range(10_000):
        tally[mix.pick()] += 1

    assert tally["never"] == 0
    assert tally["common"] > tally["rare"]


def test_weighted_round_robin_rejects_empty():
    """Constructing the strategy with no envs raises ``ValueError``."""
    seeded_rng = random.Random(0)
    no_names, no_weights = [], []
    with pytest.raises(ValueError):
        WeightedRoundRobin(no_names, no_weights, rng=seeded_rng)
Loading