Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion deps/pydantic-config
2 changes: 2 additions & 0 deletions packages/prime-rl-configs/src/prime_rl/configs/env_server.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from pathlib import Path
from typing import ClassVar

from pydantic import model_validator

Expand All @@ -8,6 +9,7 @@


class EnvServerConfig(BaseConfig):
env_prefix: ClassVar[str] = "PRIME_RL_ENV_SERVER_"
env: EnvConfig = EnvConfig()

log: LogConfig = LogConfig()
Expand Down
6 changes: 3 additions & 3 deletions packages/prime-rl-configs/src/prime_rl/configs/inference.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
from argparse import Namespace
from pathlib import Path
from typing import Annotated, Any, Literal, TypeAlias
from typing import Annotated, Any, ClassVar, Literal, TypeAlias

from pydantic import Field, model_validator
from pydantic_config import BaseConfig

from prime_rl.configs.shared import BaseModelConfig, LogConfig, SlurmConfig
from prime_rl.utils.config import find_package_resource, rgetattr, rsetattr
from prime_rl.utils.config import BaseConfig, find_package_resource, rgetattr, rsetattr
from prime_rl.utils.parsers import resolve_reasoning_parser, resolve_tool_call_parser

# TODO: Set thinking/solution budget
Expand Down Expand Up @@ -270,6 +269,7 @@ class InferenceExperimentalConfig(BaseConfig):


class InferenceConfig(BaseConfig):
env_prefix: ClassVar[str] = "PRIME_RL_INFER_"
server: ServerConfig = ServerConfig()

model: ModelConfig = Field(default_factory=ModelConfig)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import math
import warnings
from pathlib import Path
from typing import Annotated, Any, Literal, TypeAlias
from typing import Annotated, Any, ClassVar, Literal, TypeAlias

from pydantic import AliasChoices, Field, model_serializer, model_validator
from pydantic_core.core_schema import SerializerFunctionWrapHandler
Expand Down Expand Up @@ -500,6 +500,7 @@ class RolloutModelConfig(BaseConfig):


class OrchestratorConfig(BaseConfig):
env_prefix: ClassVar[str] = "PRIME_RL_ORCH_"
training_mode: Literal["rl", "opd", "sft"] = "rl"
"""Training mode. ``rl``: student generates rollouts, no teacher. ``opd``: student generates rollouts, teacher computes logprobs (teacher_tau > 0). ``sft``: teacher generates rollouts, student inference pool used for evals and weight sync."""

Expand Down
3 changes: 2 additions & 1 deletion packages/prime-rl-configs/src/prime_rl/configs/rl.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import warnings
from pathlib import Path
from typing import Annotated, Any, Literal, TypeAlias
from typing import Annotated, Any, ClassVar, Literal, TypeAlias

from pydantic import Field, model_validator

Expand Down Expand Up @@ -178,6 +178,7 @@ def total_infer_nodes(self) -> int:


class RLConfig(BaseConfig):
env_prefix: ClassVar[str] = "PRIME_RL_"
trainer: TrainerConfig

orchestrator: OrchestratorConfig
Expand Down
3 changes: 2 additions & 1 deletion packages/prime-rl-configs/src/prime_rl/configs/sft.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import warnings
from pathlib import Path
from typing import Annotated, Literal, TypeAlias
from typing import Annotated, ClassVar, Literal, TypeAlias

from pydantic import Field, model_validator
from renderers import RendererConfig
Expand Down Expand Up @@ -171,6 +171,7 @@ class SFTExperimentalConfig(BaseConfig):


class SFTConfig(BaseConfig):
env_prefix: ClassVar[str] = "PRIME_RL_SFT_"
model: ModelConfig = ModelConfig()

tokenizer: TokenizerConfig = TokenizerConfig()
Expand Down
3 changes: 2 additions & 1 deletion packages/prime-rl-configs/src/prime_rl/configs/trainer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import warnings
from pathlib import Path
from typing import Annotated, Any, Literal, TypeAlias
from typing import Annotated, Any, ClassVar, Literal, TypeAlias

from pydantic import Field, model_validator

Expand Down Expand Up @@ -500,6 +500,7 @@ class TrainerExperimentalConfig(BaseConfig):


class TrainerConfig(BaseConfig):
env_prefix: ClassVar[str] = "PRIME_RL_TRAINER_"
model: ModelConfig = ModelConfig()

tokenizer: TokenizerConfig = TokenizerConfig()
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ readme = "README.md"
requires-python = "~=3.12.0"
dependencies = [
"prime-rl-configs",
"prime-pydantic-config",
"beartype>=0.21.0",
"datasets>=4.0.0",
"jaxtyping>=0.3.2",
Expand Down
79 changes: 79 additions & 0 deletions tests/unit/utils/test_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
"""Tests for environment variable injection in prime-rl config classes."""

import os

import pytest

from prime_rl.configs.inference import InferenceConfig
from prime_rl.configs.orchestrator import OrchestratorConfig
from prime_rl.configs.rl import RLConfig
from prime_rl.configs.trainer import TrainerConfig
from prime_rl.utils.config import cli


@pytest.fixture(autouse=True)
def _clean_env(monkeypatch):
    """Scrub every ``PRIME_RL_``-prefixed variable from the environment for each test."""
    # Snapshot the matching names first so the environment is not mutated mid-iteration.
    stale_names = [name for name in os.environ if name.startswith("PRIME_RL_")]
    for name in stale_names:
        monkeypatch.delenv(name, raising=False)


# ── OrchestratorConfig ──────────────────────────────────────────────


def test_orch_env_var_batch_size(monkeypatch):
    """``PRIME_RL_ORCH_BATCH_SIZE`` is picked up as ``OrchestratorConfig.batch_size``."""
    monkeypatch.setenv("PRIME_RL_ORCH_BATCH_SIZE", "512")

    orchestrator = OrchestratorConfig.model_validate({})

    # The string from the environment is expected to arrive as an int.
    assert orchestrator.batch_size == 512


def test_orch_env_var_nested(monkeypatch):
    """A ``__``-delimited env var name reaches the nested ``student.model.name`` field."""
    monkeypatch.setenv("PRIME_RL_ORCH_STUDENT__MODEL__NAME", "Qwen/Qwen3-0.6B")

    orchestrator = OrchestratorConfig.model_validate({})

    student_model = orchestrator.student.model
    assert student_model.name == "Qwen/Qwen3-0.6B"


def test_orch_env_var_cli_wins(monkeypatch):
    """An explicit ``--batch-size`` CLI argument takes precedence over the env var."""
    monkeypatch.setenv("PRIME_RL_ORCH_BATCH_SIZE", "512")

    cli_args = ["--batch-size", "256"]
    orchestrator = cli(OrchestratorConfig, args=cli_args)

    # 256 comes from the CLI; 512 from the environment must be overridden.
    assert orchestrator.batch_size == 256


# ── TrainerConfig ───────────────────────────────────────────────────


def test_trainer_env_var_max_steps(monkeypatch):
    """``PRIME_RL_TRAINER_MAX_STEPS`` is picked up as ``TrainerConfig.max_steps``."""
    monkeypatch.setenv("PRIME_RL_TRAINER_MAX_STEPS", "1000")

    trainer = TrainerConfig.model_validate({})

    assert trainer.max_steps == 1000


# ── InferenceConfig ────────────────────────────────────────────────


def test_infer_env_var(monkeypatch):
    """A nested inference field (``model.max_model_len``) is set from its prefixed env var."""
    monkeypatch.setenv("PRIME_RL_INFER_MODEL__MAX_MODEL_LEN", "4096")

    inference = InferenceConfig.model_validate({})

    model_section = inference.model
    assert model_section.max_model_len == 4096


# ── RLConfig ───────────────────────────────────────────────────────


def test_rl_env_var_propagates_to_orchestrator(monkeypatch):
    """Env vars on RLConfig (prefix PRIME_RL_) propagate to sub-configs
    before the auto_setup_shared_configs validator runs.

    The variable addresses the ``orchestrator`` sub-config through the ``__``
    delimiter, and the resulting ``batch_size`` is checked on that sub-config.
    """
    monkeypatch.setenv("PRIME_RL_ORCHESTRATOR__BATCH_SIZE", "512")

    # Empty sections for each sub-config passed to the top-level RL config.
    empty_sections = {"trainer": {}, "orchestrator": {}, "inference": {}}
    rl = RLConfig.model_validate(empty_sections)

    assert rl.orchestrator.batch_size == 512


# ── No prefix leakage ──────────────────────────────────────────────


def test_orch_prefix_does_not_leak_to_trainer(monkeypatch):
    """Variables under the orchestrator prefix must not alter ``TrainerConfig``.

    The baseline ``max_steps`` is captured *before* any variable is set. If it
    were computed afterwards, a genuine leak would shift both the baseline and
    the value under test, and the comparison would pass vacuously.
    """
    # Reference value taken from a clean environment (see the autouse fixture).
    baseline_max_steps = TrainerConfig().max_steps

    monkeypatch.setenv("PRIME_RL_ORCH_BATCH_SIZE", "999")
    # Also target the very field asserted below, so a prefix mix-up would
    # actually be observable on TrainerConfig.
    monkeypatch.setenv("PRIME_RL_ORCH_MAX_STEPS", "999")

    config = TrainerConfig.model_validate({})

    # TrainerConfig has prefix PRIME_RL_TRAINER_; ORCH_ vars shouldn't affect it
    assert config.max_steps == baseline_max_steps
28 changes: 22 additions & 6 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading