Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions packages/prime/src/prime_cli/api/rl.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,6 @@ def create_run(
secrets: Optional[Dict[str, str]] = None,
team_id: Optional[str] = None,
eval_config: Optional[Dict[str, Any]] = None,
val_config: Optional[Dict[str, Any]] = None,
buffer_config: Optional[Dict[str, Any]] = None,
learning_rate: Optional[float] = None,
lora_alpha: Optional[int] = None,
Expand Down Expand Up @@ -266,9 +265,6 @@ def create_run(
if eval_config:
payload["eval"] = eval_config

if val_config:
payload["val"] = val_config

if buffer_config:
payload["buffer"] = buffer_config

Expand Down
33 changes: 0 additions & 33 deletions packages/prime/src/prime_cli/commands/rl.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,12 +258,6 @@ def generate_rl_config_template(environment: str | None = None) -> str:
# num_examples = 30
# rollouts_per_example = 4

# Optional: validation during training
# [val]
# num_examples = 64
# rollouts_per_example = 1
# interval = 5

# Optional: buffer configuration for difficulty filtering
# [buffer]
# easy_threshold = 1.0
Expand Down Expand Up @@ -409,24 +403,6 @@ def to_api_dict(self) -> Dict[str, Any] | None:
return result


class ValConfig(BaseModel):
model_config = ConfigDict(extra="forbid")

num_examples: int | None = None
rollouts_per_example: int | None = None
interval: int | None = None

def to_api_dict(self) -> Dict[str, Any] | None:
result: Dict[str, Any] = {}
if self.num_examples is not None:
result["num_examples"] = self.num_examples
if self.rollouts_per_example is not None:
result["rollouts_per_example"] = self.rollouts_per_example
if self.interval is not None:
result["interval"] = self.interval
return result if result else None


class BufferConfig(BaseModel):
model_config = ConfigDict(extra="forbid")

Expand Down Expand Up @@ -602,7 +578,6 @@ class RLConfig(BaseModel):
env: List[EnvConfig] = Field(default_factory=list)
sampling: SamplingConfig = Field(default_factory=SamplingConfig)
eval: EvalConfig = Field(default_factory=EvalConfig)
val: ValConfig = Field(default_factory=ValConfig)
buffer: BufferConfig = Field(default_factory=BufferConfig)
wandb: WandbConfig = Field(default_factory=WandbConfig)
checkpoints: CheckpointsConfig = Field(default_factory=CheckpointsConfig)
Expand Down Expand Up @@ -964,13 +939,6 @@ def _fetch_pricing() -> None:
if cfg.eval.interval:
console.print(f" Interval: {cfg.eval.interval}")

# Validation
if cfg.val.num_examples is not None:
console.print("\n[cyan]Validation[/cyan]")
console.print(f" Num Examples: {cfg.val.num_examples}")
if cfg.val.interval:
console.print(f" Interval: {cfg.val.interval}")

# Infrastructure
if cfg.infrastructure.compute_size:
console.print("\n[cyan]Infrastructure[/cyan]")
Expand Down Expand Up @@ -1096,7 +1064,6 @@ def _format(list_p: Any, eff_p: Any) -> str:
secrets=secrets if secrets else None,
team_id=app_config.team_id,
eval_config=cfg.eval.to_api_dict(),
val_config=cfg.val.to_api_dict(),
buffer_config=cfg.buffer.to_api_dict(),
learning_rate=cfg.learning_rate,
lora_alpha=cfg.lora_alpha,
Expand Down
9 changes: 1 addition & 8 deletions packages/prime/src/prime_lab_app/training_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,6 @@ def training_config_toml(raw: dict[str, Any]) -> str:

if isinstance(raw.get("eval_config"), dict):
config["eval"] = _rl_eval_config(raw["eval_config"])
if isinstance(raw.get("val_config"), dict):
config["val"] = raw["val_config"]
if isinstance(raw.get("buffer_config"), dict):
config["buffer"] = raw["buffer_config"]

Expand Down Expand Up @@ -76,12 +74,7 @@ def normalize_rl_config(config: dict[str, Any]) -> dict[str, Any]:
updated["eval"] = _rl_eval_config(eval_config)
else:
updated.pop("eval_config", None)
if "val_config" in updated and "val" not in updated:
val_config = updated.pop("val_config")
if isinstance(val_config, dict):
updated["val"] = val_config
else:
updated.pop("val_config", None)
updated.pop("val_config", None)
if "buffer_config" in updated and "buffer" not in updated:
buffer_config = updated.pop("buffer_config")
if isinstance(buffer_config, dict):
Expand Down
31 changes: 31 additions & 0 deletions packages/prime/tests/test_lab_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -1760,6 +1760,37 @@ def test_lab_view_training_config_uses_sampling_for_max_tokens() -> None:
HostedRLConfig.model_validate(parsed)


def test_lab_view_training_config_drops_deprecated_val_config() -> None:
config = _training_config_toml(
{
"base_model": "Qwen/Qwen3.5-2B",
"environments": ["primeintellect/alphabet-sort"],
"val_config": {"interval": 5},
}
)

parsed = toml.loads(config)
assert "val" not in parsed
HostedRLConfig.model_validate(parsed)


def test_config_builder_preserves_deprecated_val_for_schema_error() -> None:
config = {
"model": "Qwen/Qwen3.5-2B",
"env": [{"id": "primeintellect/alphabet-sort"}],
"val": {"interval": 5},
}
values = {
"config-model": "Qwen/Qwen3.5-2B",
"config-envs": "primeintellect/alphabet-sort",
}

build = build_config_from_fields(config, "rl", lambda field_id: values.get(field_id, ""))

assert "[val]" in build.toml_text
assert build.errors == ("val: Extra inputs are not permitted",)


def test_lab_config_factory_renders_shared_eval_and_rl_templates() -> None:
eval_toml = format_lab_config(
evaluation_config(env_id="primeintellect/wordle", num_examples=-1, max_tokens=None)
Expand Down
14 changes: 14 additions & 0 deletions packages/prime/tests/test_rl_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ def test_generate_rl_config_template_keeps_default_surface_minimal() -> None:

assert 'model = "Qwen/Qwen3.5-0.8B"' in template
assert "# learning_rate = 3e-5 # optional; default is 1e-4" in template
assert "# [val]" not in template
assert "validation during training" not in template.lower()

hidden_fields = [
"oversampling_factor",
Expand Down Expand Up @@ -91,6 +93,18 @@ def test_flatten_config_schema_preserves_optional_array_item_types() -> None:
assert rows["buffer.env_ratios"] == "list[number]"


def test_load_config_rejects_deprecated_val_section(tmp_path: Path) -> None:
config_path = tmp_path / "rl.toml"
config_path.write_text(
'model = "dummy"\n'
"[val]\n"
"interval = 5\n"
)

with pytest.raises(typer.Exit):
load_config(str(config_path))


def test_load_config_accepts_sampling_reasoning_effort(tmp_path: Path) -> None:
config_path = tmp_path / "rl.toml"
config_path.write_text('model = "openai/gpt-oss-20b"\n[sampling]\nreasoning_effort = "high"\n')
Expand Down
9 changes: 9 additions & 0 deletions packages/prime/tests/test_train_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,15 @@ def test_train_init_defaults_to_rl_toml() -> None:
assert Path("rl.toml").exists()


def test_train_configs_omits_deprecated_val_section() -> None:
result = runner.invoke(app, ["train", "configs", "--output", "json"], env=TEST_ENV)

assert result.exit_code == 0, result.output
data = json.loads(result.stdout)
sections = {item["section"] for item in data["configs"]}
assert "val" not in sections


def test_train_request_submits_model_request(monkeypatch) -> None:
captured: dict[str, Any] = {}

Expand Down