Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 102 additions & 2 deletions packages/prime/src/prime_cli/verifiers_bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,17 @@

from __future__ import annotations

import json
import os
import re
import subprocess
import sys
import tempfile
import uuid
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import Optional
from typing import Any, Optional

import httpx
import toml
Expand Down Expand Up @@ -71,6 +73,97 @@ def is_help_request(primary_arg: str, passthrough_args: list[str]) -> bool:
return any(arg in ("-h", "--help") for arg in passthrough_args)


def _normalize_eval_sampling_args_dict(sampling_args: dict[str, Any]) -> bool:
if "enable_thinking" not in sampling_args:
return False

extra_body = sampling_args.get("extra_body")
if extra_body is not None and not isinstance(extra_body, dict):
return False

chat_template_kwargs = None
if isinstance(extra_body, dict):
chat_template_kwargs = extra_body.get("chat_template_kwargs")
if chat_template_kwargs is not None and not isinstance(chat_template_kwargs, dict):
return False

enable_thinking = sampling_args.pop("enable_thinking")
extra_body_dict = dict(extra_body) if isinstance(extra_body, dict) else {}
chat_template_kwargs_dict = (
dict(chat_template_kwargs) if isinstance(chat_template_kwargs, dict) else {}
)
chat_template_kwargs_dict.setdefault("enable_thinking", enable_thinking)
extra_body_dict["chat_template_kwargs"] = chat_template_kwargs_dict
sampling_args["extra_body"] = extra_body_dict
return True


def _normalize_eval_sampling_args(passthrough_args: list[str]) -> list[str]:
normalized = list(passthrough_args)

for i, arg in enumerate(normalized[:-1]):
if arg != "--sampling-args":
continue

try:
sampling_args = json.loads(normalized[i + 1])
except json.JSONDecodeError:
return normalized

if not isinstance(sampling_args, dict):
return normalized

if not _normalize_eval_sampling_args_dict(sampling_args):
return normalized

normalized[i + 1] = json.dumps(sampling_args, separators=(",", ":"))
return normalized
Comment thread
peter941221 marked this conversation as resolved.
Outdated

return normalized


def _normalize_eval_config_target(environment: str) -> tuple[str, Optional[Path]]:
if not _is_config_target(environment):
return environment, None

config_path = Path(environment)
try:
raw = toml.load(config_path)
except Exception:
return environment, None

if not isinstance(raw, dict):
return environment, None

eval_entries = raw.get("eval")
if not isinstance(eval_entries, list):
return environment, None

mutated = False
for entry in eval_entries:
if not isinstance(entry, dict):
continue
sampling_args = entry.get("sampling_args")
if not isinstance(sampling_args, dict):
continue
if _normalize_eval_sampling_args_dict(sampling_args):
mutated = True

if not mutated:
return environment, None

temp_file = tempfile.NamedTemporaryFile(
mode="w",
encoding="utf-8",
suffix=config_path.suffix or ".toml",
delete=False,
)
with temp_file:
toml.dump(raw, temp_file)

return temp_file.name, Path(temp_file.name)


def _sanitize_help_text(help_text: str, module_name: str, prime_command: str) -> str:
lines = help_text.splitlines()
for idx, line in enumerate(lines):
Expand Down Expand Up @@ -950,6 +1043,7 @@ def run_eval_passthrough(
) -> None:
plugin = load_verifiers_prime_plugin(console=console)
config = Config()
temp_config_path: Optional[Path] = None

if not config.api_key:
console.print(
Expand All @@ -958,6 +1052,7 @@ def run_eval_passthrough(
)
raise typer.Exit(1)

passthrough_args = _normalize_eval_sampling_args(passthrough_args)
args, env, model, base_url = _add_default_inference_and_key_args(passthrough_args, config)
configured_base_url = (config.inference_url or "").strip().rstrip("/")
_validate_model(model, base_url, configured_base_url)
Expand All @@ -969,6 +1064,7 @@ def run_eval_passthrough(
env_name_for_upload: Optional[str] = None
resolved_env: Optional[ResolvedEnvironment] = None
config_envs: list[tuple[str, str]] = []
run_target, temp_config_path = _normalize_eval_config_target(environment)
Comment thread
peter941221 marked this conversation as resolved.

if _is_config_target(environment):
config_envs = _collect_eval_config_envs(Path(environment), env_dir_path)
Expand Down Expand Up @@ -1000,7 +1096,11 @@ def run_eval_passthrough(

console.print(f"[dim]Eval job_id: {job_id}[/dim]")
command = plugin.build_module_command(plugin.eval_module, [run_target, *args])
_run_command(command, env=env)
try:
_run_command(command, env=env)
finally:
if temp_config_path is not None:
temp_config_path.unlink(missing_ok=True)
Comment thread
peter941221 marked this conversation as resolved.

if skip_upload:
_print_environment_source_footer(resolved_env)
Expand Down
97 changes: 97 additions & 0 deletions packages/prime/tests/test_eval_billing.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from pathlib import Path

import httpx
import pytest
import typer
Expand Down Expand Up @@ -560,6 +562,101 @@ def fake_prepare(_plugin, env_reference, env_dir_path):
assert prepared == [("wiki-search", "./environments")]


def test_eval_run_rewrites_enable_thinking_sampling_arg(monkeypatch):
monkeypatch.setattr(
"prime_cli.verifiers_bridge.load_verifiers_prime_plugin", lambda console: DummyPlugin()
)
monkeypatch.setattr("prime_cli.verifiers_bridge.Config", lambda: DummyConfig())
monkeypatch.setattr("prime_cli.verifiers_bridge._validate_model", lambda *args: None)
monkeypatch.setattr(
"prime_cli.verifiers_bridge._preflight_inference_billing",
lambda *args: None,
)
monkeypatch.setattr(
"prime_cli.verifiers_bridge._prepare_single_environment",
lambda *args, **kwargs: ResolvedEnvironment(
original="primeintellect/gsm8k",
env_name="gsm8k",
install_mode="remote",
),
)

commands = []

def fake_run_command(command, env=None):
commands.append(command)

monkeypatch.setattr("prime_cli.verifiers_bridge._run_command", fake_run_command)

run_eval_passthrough(
environment="primeintellect/gsm8k",
passthrough_args=[
"-m",
"Qwen/Qwen3.5-122B-A10B",
"--sampling-args",
'{"enable_thinking":false,"temperature":0.2}',
],
skip_upload=True,
env_path=None,
)

assert commands
sampling_args_index = commands[0].index("--sampling-args")
assert commands[0][sampling_args_index + 1] == (
'{"temperature":0.2,"extra_body":{"chat_template_kwargs":{"enable_thinking":false}}}'
)


def test_eval_run_rewrites_enable_thinking_in_config_sampling_args(monkeypatch, tmp_path):
monkeypatch.setattr(
"prime_cli.verifiers_bridge.load_verifiers_prime_plugin", lambda console: DummyPlugin()
)
monkeypatch.setattr("prime_cli.verifiers_bridge.Config", lambda: DummyConfig())
monkeypatch.setattr("prime_cli.verifiers_bridge._validate_model", lambda *args: None)
monkeypatch.setattr(
"prime_cli.verifiers_bridge._preflight_inference_billing",
lambda *args: None,
)

config_path = tmp_path / "eval.toml"
config_path.write_text(
"""
model = "Qwen/Qwen3.5-122B-A10B"

[[eval]]
env_id = "primeintellect/gsm8k"
sampling_args = { enable_thinking = false, temperature = 0.2 }
""".strip(),
encoding="utf-8",
)

prepared = []
commands = []

def fake_prepare(_plugin, env_reference, env_dir_path):
prepared.append((env_reference, env_dir_path))

def fake_run_command(command, env=None):
commands.append(command)

monkeypatch.setattr("prime_cli.verifiers_bridge._prepare_single_environment", fake_prepare)
monkeypatch.setattr("prime_cli.verifiers_bridge._run_command", fake_run_command)

run_eval_passthrough(
environment=str(config_path),
passthrough_args=[],
skip_upload=True,
env_path=None,
)

assert prepared == [("primeintellect/gsm8k", "./environments")]
assert commands
assert commands[0][1] != str(config_path)

rewritten_config = Path(commands[0][1])
assert not rewritten_config.exists()


def test_inference_client_uses_custom_timeout(monkeypatch):
monkeypatch.setattr("prime_cli.api.inference.Config", lambda: DummyConfig())

Expand Down