diff --git a/packages/prime/src/prime_cli/api/rl.py b/packages/prime/src/prime_cli/api/rl.py index dd1cd68f5..a0c51e461 100644 --- a/packages/prime/src/prime_cli/api/rl.py +++ b/packages/prime/src/prime_cli/api/rl.py @@ -192,7 +192,6 @@ def create_run( secrets: Optional[Dict[str, str]] = None, team_id: Optional[str] = None, eval_config: Optional[Dict[str, Any]] = None, - val_config: Optional[Dict[str, Any]] = None, buffer_config: Optional[Dict[str, Any]] = None, learning_rate: Optional[float] = None, lora_alpha: Optional[int] = None, @@ -266,9 +265,6 @@ def create_run( if eval_config: payload["eval"] = eval_config - if val_config: - payload["val"] = val_config - if buffer_config: payload["buffer"] = buffer_config diff --git a/packages/prime/src/prime_cli/commands/rl.py b/packages/prime/src/prime_cli/commands/rl.py index 70143e1cf..48638524b 100644 --- a/packages/prime/src/prime_cli/commands/rl.py +++ b/packages/prime/src/prime_cli/commands/rl.py @@ -258,12 +258,6 @@ def generate_rl_config_template(environment: str | None = None) -> str: # num_examples = 30 # rollouts_per_example = 4 -# Optional: validation during training -# [val] -# num_examples = 64 -# rollouts_per_example = 1 -# interval = 5 - # Optional: buffer configuration for difficulty filtering # [buffer] # easy_threshold = 1.0 @@ -409,24 +403,6 @@ def to_api_dict(self) -> Dict[str, Any] | None: return result -class ValConfig(BaseModel): - model_config = ConfigDict(extra="forbid") - - num_examples: int | None = None - rollouts_per_example: int | None = None - interval: int | None = None - - def to_api_dict(self) -> Dict[str, Any] | None: - result: Dict[str, Any] = {} - if self.num_examples is not None: - result["num_examples"] = self.num_examples - if self.rollouts_per_example is not None: - result["rollouts_per_example"] = self.rollouts_per_example - if self.interval is not None: - result["interval"] = self.interval - return result if result else None - - class BufferConfig(BaseModel): model_config = ConfigDict(extra="forbid") @@ -602,7 +578,6 @@ class RLConfig(BaseModel): env: List[EnvConfig] = Field(default_factory=list) sampling: SamplingConfig = Field(default_factory=SamplingConfig) eval: EvalConfig = Field(default_factory=EvalConfig) - val: ValConfig = Field(default_factory=ValConfig) buffer: BufferConfig = Field(default_factory=BufferConfig) wandb: WandbConfig = Field(default_factory=WandbConfig) checkpoints: CheckpointsConfig = Field(default_factory=CheckpointsConfig) @@ -964,13 +939,6 @@ def _fetch_pricing() -> None: if cfg.eval.interval: console.print(f" Interval: {cfg.eval.interval}") - # Validation - if cfg.val.num_examples is not None: - console.print("\n[cyan]Validation[/cyan]") - console.print(f" Num Examples: {cfg.val.num_examples}") - if cfg.val.interval: - console.print(f" Interval: {cfg.val.interval}") - # Infrastructure if cfg.infrastructure.compute_size: console.print("\n[cyan]Infrastructure[/cyan]") @@ -1096,7 +1064,6 @@ def _format(list_p: Any, eff_p: Any) -> str: secrets=secrets if secrets else None, team_id=app_config.team_id, eval_config=cfg.eval.to_api_dict(), - val_config=cfg.val.to_api_dict(), buffer_config=cfg.buffer.to_api_dict(), learning_rate=cfg.learning_rate, lora_alpha=cfg.lora_alpha, diff --git a/packages/prime/src/prime_lab_app/training_config.py b/packages/prime/src/prime_lab_app/training_config.py index 406a6d80c..b594353e3 100644 --- a/packages/prime/src/prime_lab_app/training_config.py +++ b/packages/prime/src/prime_lab_app/training_config.py @@ -46,8 +46,6 @@ def training_config_toml(raw: dict[str, Any]) -> str: if isinstance(raw.get("eval_config"), dict): config["eval"] = _rl_eval_config(raw["eval_config"]) - if isinstance(raw.get("val_config"), dict): - config["val"] = raw["val_config"] if isinstance(raw.get("buffer_config"), dict): config["buffer"] = raw["buffer_config"] @@ -76,12 +74,7 @@ def normalize_rl_config(config: dict[str, Any]) -> dict[str, Any]: updated["eval"] = _rl_eval_config(eval_config) else: updated.pop("eval_config", None) - if "val_config" in updated and "val" not in updated: - val_config = updated.pop("val_config") - if isinstance(val_config, dict): - updated["val"] = val_config - else: - updated.pop("val_config", None) + updated.pop("val_config", None) if "buffer_config" in updated and "buffer" not in updated: buffer_config = updated.pop("buffer_config") if isinstance(buffer_config, dict): diff --git a/packages/prime/tests/test_lab_view.py b/packages/prime/tests/test_lab_view.py index 395776877..722fbc624 100644 --- a/packages/prime/tests/test_lab_view.py +++ b/packages/prime/tests/test_lab_view.py @@ -1760,6 +1760,37 @@ def test_lab_view_training_config_uses_sampling_for_max_tokens() -> None: HostedRLConfig.model_validate(parsed) +def test_lab_view_training_config_drops_deprecated_val_config() -> None: + config = _training_config_toml( + { + "base_model": "Qwen/Qwen3.5-2B", + "environments": ["primeintellect/alphabet-sort"], + "val_config": {"interval": 5}, + } + ) + + parsed = toml.loads(config) + assert "val" not in parsed + HostedRLConfig.model_validate(parsed) + + +def test_config_builder_preserves_deprecated_val_for_schema_error() -> None: + config = { + "model": "Qwen/Qwen3.5-2B", + "env": [{"id": "primeintellect/alphabet-sort"}], + "val": {"interval": 5}, + } + values = { + "config-model": "Qwen/Qwen3.5-2B", + "config-envs": "primeintellect/alphabet-sort", + } + + build = build_config_from_fields(config, "rl", lambda field_id: values.get(field_id, "")) + + assert "[val]" in build.toml_text + assert build.errors == ("val: Extra inputs are not permitted",) + + def test_lab_config_factory_renders_shared_eval_and_rl_templates() -> None: eval_toml = format_lab_config( evaluation_config(env_id="primeintellect/wordle", num_examples=-1, max_tokens=None) diff --git a/packages/prime/tests/test_rl_config.py b/packages/prime/tests/test_rl_config.py index e55108b4d..9a2a13f3c 100644 --- a/packages/prime/tests/test_rl_config.py +++ b/packages/prime/tests/test_rl_config.py @@ -55,6 +55,8 @@ def test_generate_rl_config_template_keeps_default_surface_minimal() -> None: assert 'model = "Qwen/Qwen3.5-0.8B"' in template assert "# learning_rate = 3e-5 # optional; default is 1e-4" in template + assert "# [val]" not in template + assert "validation during training" not in template.lower() hidden_fields = [ "oversampling_factor", @@ -91,6 +93,18 @@ def test_flatten_config_schema_preserves_optional_array_item_types() -> None: assert rows["buffer.env_ratios"] == "list[number]" +def test_load_config_rejects_deprecated_val_section(tmp_path: Path) -> None: + config_path = tmp_path / "rl.toml" + config_path.write_text( + 'model = "dummy"\n' + "[val]\n" + "interval = 5\n" + ) + + with pytest.raises(typer.Exit): + load_config(str(config_path)) + + def test_load_config_accepts_sampling_reasoning_effort(tmp_path: Path) -> None: config_path = tmp_path / "rl.toml" config_path.write_text('model = "openai/gpt-oss-20b"\n[sampling]\nreasoning_effort = "high"\n') diff --git a/packages/prime/tests/test_train_cli.py b/packages/prime/tests/test_train_cli.py index 3259213d1..b466c7519 100644 --- a/packages/prime/tests/test_train_cli.py +++ b/packages/prime/tests/test_train_cli.py @@ -65,6 +65,15 @@ def test_train_init_defaults_to_rl_toml() -> None: assert Path("rl.toml").exists() +def test_train_configs_omits_deprecated_val_section() -> None: + result = runner.invoke(app, ["train", "configs", "--output", "json"], env=TEST_ENV) + + assert result.exit_code == 0, result.output + data = json.loads(result.stdout) + sections = {item["section"] for item in data["configs"]} + assert "val" not in sections + + def test_train_request_submits_model_request(monkeypatch) -> None: captured: dict[str, Any] = {}