diff --git a/.github/workflows/cpu_tests.yaml b/.github/workflows/cpu_tests.yaml index 4ed424bcb6..12fdd881ed 100644 --- a/.github/workflows/cpu_tests.yaml +++ b/.github/workflows/cpu_tests.yaml @@ -77,8 +77,11 @@ jobs: from prime_rl.utils.config import BaseConfig, find_package_resource, rgetattr, rsetattr from prime_rl.utils.validation import validate_shared_ckpt_config + # `verifiers` (+ its `datasets` dep) is a declared slim dep: the v1 config types + # (EnvConfig, Task, ...) extend verifiers.v1, which is pure-pydantic and pulls no + # GPU/ML deps. We still forbid the actual heavy training deps below. forbidden = ["torch", "transformers", "vllm", "wandb", "ring_flash_attn", - "verifiers", "prime", "datasets", "liger_kernel", "loguru"] + "prime", "liger_kernel", "loguru"] leaked = [m for m in forbidden if m in sys.modules] if leaked: raise SystemExit(f"slim install leaked heavy deps into sys.modules: {leaked}") diff --git a/.gitmodules b/.gitmodules index 2041f460ee..50058c2d22 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,6 +1,3 @@ -[submodule "verifiers"] - path = deps/verifiers - url = git@github.com:PrimeIntellect-ai/verifiers.git [submodule "renderers"] path = deps/renderers url = git@github.com:PrimeIntellect-ai/renderers.git @@ -13,3 +10,7 @@ [submodule "pydantic-config"] path = deps/pydantic-config url = https://github.com/PrimeIntellect-ai/pydantic-config +[submodule "deps/verifiers"] + path = deps/verifiers + url = git@github.com:PrimeIntellect-ai/verifiers.git + branch = feat/nano-as-v1 diff --git a/configs/ci/integration/reverse_text_rl_sft/start.toml b/configs/ci/integration/reverse_text_rl_sft/start.toml index 6b26bb3335..e658fab5d4 100644 --- a/configs/ci/integration/reverse_text_rl_sft/start.toml +++ b/configs/ci/integration/reverse_text_rl_sft/start.toml @@ -19,6 +19,10 @@ training_mode = "sft" batch_size = 128 group_size = 16 +# Teacher rolls out over plain chat-completions (no tokens); the renderer backfills them. +[orchestrator.renderer] +name = "qwen3" + [orchestrator.train.sampling] max_completion_tokens = 128 diff --git a/configs/debug/reverse_text_v1.toml b/configs/debug/reverse_text_v1.toml deleted file mode 100644 index 1d23b965ac..0000000000 --- a/configs/debug/reverse_text_v1.toml +++ /dev/null @@ -1,31 +0,0 @@ -max_steps = 20 -seq_len = 2048 - -[model] -name = "PrimeIntellect/Qwen3-0.6B-Reverse-Text-SFT" - -[wandb] -project = "reverse-text" -name = "reverse-text-v1" - -[orchestrator] -batch_size = 128 -group_size = 16 - -[orchestrator.train.sampling] -max_completion_tokens = 128 - -[[orchestrator.train.env]] -id = "reverse-text" -args = { v1 = true } - -[trainer.optim] -lr = 3e-6 - -[ckpt] # Checkpoint at the end of training - -[inference] - -# Model not in MODEL_RENDERER_MAP — opt into DefaultRenderer (apply_chat_template). -[orchestrator.renderer] -name = "default" diff --git a/configs/debug/training_modes/README.md b/configs/debug/training_modes/README.md index 96ccebb009..c789f43b18 100644 --- a/configs/debug/training_modes/README.md +++ b/configs/debug/training_modes/README.md @@ -9,9 +9,8 @@ Minimal end-to-end configs for the three training modes (`rl` / `opd` / `sft`) a | `opd_lora.toml` | `opd` | local vLLM (`Qwen3-0.6B-Reverse-Text-RL`) | trains a LoRA adapter (rank 8) | | `sft.toml` | `sft` | local vLLM (`Qwen3-0.6B-Reverse-Text-RL`) | | | `sft_lora.toml` | `sft` | local vLLM (`Qwen3-0.6B-Reverse-Text-RL`) | trains a LoRA adapter (rank 8) | -| `sft_external.toml` | `sft` | PI inference (`openai/gpt-5-mini`) | external OAI endpoint; no local teacher | -The student inference server is auto-launched on GPU 0 at `http://localhost:8000/v1` with `gpu_memory_utilization=0.5`. The local teacher (used by everything except `rl.toml` and `sft_external.toml`) is **not** auto-launched — start it manually on GPU 1. +The student inference server is auto-launched on GPU 0 at `http://localhost:8000/v1` with `gpu_memory_utilization=0.5`. The local teacher (used by everything except `rl.toml`) is **not** auto-launched — start it manually on GPU 1. ## Start the local teacher @@ -38,10 +37,6 @@ uv run rl @ configs/debug/training_modes/opd_lora.toml # SFT hard distill (needs teacher on port 8001) uv run rl @ configs/debug/training_modes/sft.toml uv run rl @ configs/debug/training_modes/sft_lora.toml - -# SFT hard distill from openai/gpt-5-mini via PI inference -# (requires PRIME_API_KEY + PRIME_TEAM_ID in env; no local teacher needed) -uv run rl @ configs/debug/training_modes/sft_external.toml ``` See [docs/training.md](../../docs/training.md#training-modes-rl--opd--sft-via-orchestrator) for what each mode does. diff --git a/configs/debug/training_modes/sft.toml b/configs/debug/training_modes/sft.toml index aed5b30cb3..46c2afd1e8 100644 --- a/configs/debug/training_modes/sft.toml +++ b/configs/debug/training_modes/sft.toml @@ -20,6 +20,10 @@ training_mode = "sft" batch_size = 128 group_size = 4 +# Teacher rolls out over plain chat-completions (no tokens); the renderer backfills them. +[orchestrator.renderer] +name = "qwen3" + [orchestrator.train.sampling] max_completion_tokens = 128 diff --git a/configs/debug/training_modes/sft_external.toml b/configs/debug/training_modes/sft_external.toml deleted file mode 100644 index cb9ea8d09e..0000000000 --- a/configs/debug/training_modes/sft_external.toml +++ /dev/null @@ -1,55 +0,0 @@ -# SFT from openai/gpt-5-mini via PI inference. -# Requires PRIME_API_KEY + PRIME_TEAM_ID in the environment. -# -# Run with: -# uv run rl @ configs/debug/training_modes/sft_external.toml - -max_steps = 20 -seq_len = 2048 - -[model] -name = "PrimeIntellect/Qwen3-0.6B-Reverse-Text-SFT" - -[wandb] -project = "reverse-text-debug" -name = "debug-sft-external" - -[orchestrator] -training_mode = "sft" -batch_size = 128 -group_size = 4 - -[orchestrator.train.sampling] -max_completion_tokens = 2048 -extra_body = { reasoning_effort = "minimal" } - -[[orchestrator.train.env]] -id = "reverse-text" - -[orchestrator.eval] -interval = 1 -num_examples = 128 - -[orchestrator.eval.sampling] -max_completion_tokens = 128 - -[[orchestrator.eval.env]] -id = "reverse-text" - -[orchestrator.teacher.model] -name = "openai/gpt-5-mini" - -[orchestrator.teacher.client] -base_url = ["https://api.pinference.ai/api/v1"] -api_key_var = "PRIME_API_KEY" - -[orchestrator.teacher.client.headers_from_env] -X-Prime-Team-ID = "PRIME_TEAM_ID" - -[trainer.optim] -lr = 3e-6 - -[ckpt] - -[inference] -gpu_memory_utilization = 0.5 diff --git a/configs/debug/training_modes/sft_lora.toml b/configs/debug/training_modes/sft_lora.toml index 687b45bbe3..b6ad2a8df6 100644 --- a/configs/debug/training_modes/sft_lora.toml +++ b/configs/debug/training_modes/sft_lora.toml @@ -20,6 +20,10 @@ training_mode = "sft" batch_size = 128 group_size = 4 +# Teacher rolls out over plain chat-completions (no tokens); the renderer backfills them. +[orchestrator.renderer] +name = "qwen3" + [orchestrator.train.sampling] max_completion_tokens = 128 diff --git a/configs/debug/v1/alphabet_sort.toml b/configs/debug/v1/alphabet_sort.toml new file mode 100644 index 0000000000..39c40db77f --- /dev/null +++ b/configs/debug/v1/alphabet_sort.toml @@ -0,0 +1,42 @@ +# v1 port of examples/alphabet_sort/rl.toml — identical except the env block, which loads +# the v1 `alphabet-sort-v1` taskset (multi-turn via a colocated vf.User) instead of the v0 +# `primeintellect/alphabet-sort` env. Harness runs on the subprocess runtime. + +max_steps = 200 +seq_len = 2048 + +[ckpt] # Checkpoint at the end of training + +[model] +name = "Qwen/Qwen3-4B-Instruct-2507" + +[wandb] +project = "alphabet-sort" +name = "alphabet-sort" + +[trainer.model] +impl = "auto" + +[trainer.model.ac] +freq = 1 + +[trainer.model.lora] +rank = 32 +alpha = 64 + +[trainer.optim] +lr = 1e-5 + +[orchestrator] +batch_size = 512 +group_size = 8 + +[orchestrator.train.sampling] +max_completion_tokens = 768 + +[[orchestrator.train.env]] +name = "alphabet-sort" +taskset = { id = "alphabet-sort-v1", min_turns = 3, max_turns = 5, power_per_turn = false } +harness = { id = "default", enable_bash = false, runtime = { type = "subprocess" } } + +[inference] # Default inference config diff --git a/configs/debug/v1/hendrycks_sanity.toml b/configs/debug/v1/hendrycks_sanity.toml new file mode 100644 index 0000000000..aae81b33f9 --- /dev/null +++ b/configs/debug/v1/hendrycks_sanity.toml @@ -0,0 +1,52 @@ +# v1 analog of examples/hendrycks_sanity/rl.toml — identical config with only +# the env sections swapped to v1 taskset/harness syntax (math-env taskset, +# default harness, subprocess runtime). Submits to slurm by default; pass --no-slurm +# to run locally. + +output_dir = "/beegfs/mika/hendrycks-sanity-v1-subprocess" +max_steps = 5000 + +[wandb] +project = "hendrycks-sanity" +name = "v1" + +[deployment] +num_train_gpus = 4 +num_infer_gpus = 4 + +[slurm] +job_name = "v1-subprocess" +partition = "cluster" + +[model] +name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B" + +[orchestrator] +batch_size = 512 +group_size = 8 +seq_len = 8192 + +[[orchestrator.train.env]] +name = "hendrycks-math" +taskset = { id = "math-env-v1", dataset_name = "mikasenghaas/Sanity-Test-R1D-1.5B", dataset_subset = "default" } +harness = { id = "default", enable_bash = false, runtime = { type = "subprocess" } } + +[orchestrator.eval] +interval = 50 + +[[orchestrator.eval.env]] +name = "aime2024" +taskset = { id = "aime24-v1" } +harness = { id = "default", enable_bash = false, runtime = { type = "subprocess" } } +group_size = 32 + +[trainer.model] +seq_len = 16384 + +[inference.model] +max_model_len = 8192 + +# Model not in MODEL_RENDERER_MAP — opt into DefaultRenderer (apply_chat_template). +[orchestrator.renderer] +name = "default" +reasoning_parser = "think" diff --git a/configs/debug/v1/r2e_gym.toml b/configs/debug/v1/r2e_gym.toml new file mode 100644 index 0000000000..2f45d8561f --- /dev/null +++ b/configs/debug/v1/r2e_gym.toml @@ -0,0 +1,86 @@ +# v1 port of configs/rlm_swe/qwen35_4b.toml — identical except the env blocks, which load +# the v1 `r2e-gym-v1` taskset through the rlm harness on the prime runtime instead of the +# v0 `rlm_swe` composable env. Same model / trainer / inference / orchestrator knobs. + +output_dir = "/beegfs/mika/rlm-swe-qwen35-4b" +max_steps = 400 +seq_len = 65536 + +[slurm] +job_name = "rlm-swe-qwen35-4b" +project_dir = "." +pre_run_command = "prime sandbox delete --label rlm-swe-qwen35-4b -y --plain || true" + +[deployment] +type = "multi_node" +num_train_nodes = 1 +num_infer_nodes = 1 +num_infer_replicas = 2 + +[wandb] +project = "rlm-swe-debug" +name = "qwen35-4b" + +[weight_broadcast] +type = "nccl" + +[ckpt] +interval = 50 +keep_last = 1 +resume_step = -1 + +[model] +name = "Qwen/Qwen3.5-4B" + +# --- Trainer --- + +[trainer] + +[trainer.model] +cp = 4 +cp_style = "ulysses" + +[trainer.model.ac] +freq = 1 + +[trainer.model.compile] + +# --- Orchestrator --- + +[orchestrator] +batch_size = 256 +group_size = 8 +max_inflight_rollouts = 512 +max_off_policy_steps = 16 + +# Thinking enabled for the Qwen3.5 renderer. +[orchestrator.renderer] +name = "qwen3.5" +enable_thinking = true + +[orchestrator.train.sampling] +temperature = 1.0 + +[[orchestrator.train.env]] +name = "rlm-swe-r2e" +taskset = { id = "r2e-gym-v1" } +harness = { id = "rlm", runtime = { type = "prime", labels = ["rlm-swe-qwen35-4b"] } } + +[orchestrator.prime_monitor] + +# --- Inference --- + +[inference] +gpu_memory_utilization = 0.85 +enable_prefix_caching = true + +[inference.model] +max_model_len = 65536 + +[inference.parallel] +dp = 8 + +# Qwen3.5-4B is a VL model; skip the vision tower for text-only SWE. +# `language_model_only` is a vLLM MultiModalConfig arg (no prime-rl field) → pass via vllm_extra. +[inference.vllm_extra] +language_model_only = true diff --git a/configs/debug/v1/reverse_text.toml b/configs/debug/v1/reverse_text.toml new file mode 100644 index 0000000000..220c427d53 --- /dev/null +++ b/configs/debug/v1/reverse_text.toml @@ -0,0 +1,42 @@ +# Debug RL run on the v1 env server (reverse-text starter). +# The orchestrator spawns a v1 EnvServer per env (it never loads the env +# itself), dispatches rollouts by task index, and trains on the returned Traces +# (renderer-tokenized). Light settings for a quick end-to-end smoke. + +max_steps = 20 +seq_len = 2048 + +[wandb] +project = "reverse-text" +name = "v1" + +[model] +name = "PrimeIntellect/Qwen3-0.6B-Reverse-Text-SFT" + +[orchestrator] +training_mode = "rl" +batch_size = 128 +group_size = 16 + +[orchestrator.renderer] +name = "qwen3" + +[orchestrator.train.sampling] +max_completion_tokens = 128 + +[[orchestrator.train.env]] +taskset = { id = "reverse-text-v1" } +# reverse-text is a pure-text single-turn env: disable the bash tool (the model answers +# directly) and use the subprocess runtime (no docker). +harness = { id = "default", enable_bash = false, runtime = { type = "subprocess" } } + +# No eval block, mirroring examples/reverse_text/rl.toml — this is a train-only +# smoke. Add an [orchestrator.eval] block (with an interval) to exercise eval. + +[trainer.optim] +lr = 3e-6 + +[ckpt] + +[inference] +gpu_memory_utilization = 0.5 diff --git a/configs/rlm_swe/qwen35_4b.toml b/configs/rlm_swe/qwen35_4b.toml index 626b62b8ef..0ae908c7de 100644 --- a/configs/rlm_swe/qwen35_4b.toml +++ b/configs/rlm_swe/qwen35_4b.toml @@ -60,7 +60,6 @@ temperature = 1.0 [[orchestrator.train.env]] id = "rlm_swe" name = "rlm-swe-r2e" -num_workers = 4 [orchestrator.train.env.args] labels = ["rlm-swe-qwen35-4b"] @@ -73,8 +72,7 @@ interval = 20 [[orchestrator.eval.env]] id = "rlm_swe" name = "rlm-swe-swebench-verified-quick" -num_workers = 4 -timeout = 3600 +timeout = { rollout = 3600 } [orchestrator.eval.env.args] task_type = "swebench" diff --git a/configs/v1/multimodal_color_codeword.toml b/configs/v1/multimodal_color_codeword.toml new file mode 100644 index 0000000000..e6bc07456c --- /dev/null +++ b/configs/v1/multimodal_color_codeword.toml @@ -0,0 +1,88 @@ +# 2-GPU v1 RL run for the multimodal (renderer) path: Qwen3-VL-4B on the color-codeword-v1 +# taskset. The v1 port of configs/debug/multimodal.toml — same model / sizing, but the env is +# the native v1 taskset (turn-0 images seeded in the Messages instruction; later turns injected +# by the colocated user simulator) rather than the v0 `color-codeword` env id. +# uv run rl @ configs/v1/multimodal_color_codeword.toml + +max_steps = 15 +seq_len = 4096 + +[model] +name = "Qwen/Qwen3-VL-4B-Instruct" + +[model.vlm] +vision_encoder_attr = "model.visual" +language_model_attr = "model.language_model" + +[deployment] +num_train_gpus = 1 +num_infer_gpus = 1 +gpus_per_node = 2 + +[wandb] +project = "multimodal-color-codeword-debug" +name = "v1" + +[orchestrator] +training_mode = "rl" +batch_size = 256 +group_size = 16 +# Image processor is CPU-bound and dominates for VLMs; returns diminish past 4. +pool_size = 4 + +# Step 0 on Qwen3-VL-4B vs color-codeword can be uniform (all-correct or all-wrong), so don't +# enforce zero-advantage dropping or training would crash before any progress. +[[orchestrator.post_batch_filters]] +type = "gibberish" + +[[orchestrator.post_batch_filters]] +type = "repetition" + +[[orchestrator.post_batch_filters]] +type = "zero_advantage" +enforce = false + +# Renderer left as the default AutoRendererConfig: it resolves Qwen3-VL-4B-Instruct to the +# Qwen3VLRenderer via MODEL_RENDERER_MAP, so no explicit name is needed. + +[orchestrator.train.sampling] +max_completion_tokens = 64 + +[[orchestrator.train.env]] +taskset = { id = "color-codeword-v1", images_per_turn = 1, num_examples = 1000 } +# Multi-turn VLM env: turn-0 squares ride in the task's Messages instruction, later turns come +# from the colocated user simulator. Subprocess runtime (no docker). +harness = { id = "default", enable_bash = false, runtime = { type = "subprocess" } } + +[orchestrator.eval] +interval = 5 +num_examples = 32 + +[orchestrator.eval.sampling] +max_completion_tokens = 64 + +[[orchestrator.eval.env]] +taskset = { id = "color-codeword-v1", images_per_turn = 1, num_examples = 1000 } +harness = { id = "default", enable_bash = false, runtime = { type = "subprocess" } } + +[trainer] + +[trainer.model] +optimization_dtype = "bfloat16" +reduce_dtype = "bfloat16" + +[trainer.optim] +lr = 3e-6 + +[ckpt] + +[inference] + +[inference.model] +# vLLM 0.20.1 Qwen3-VL deepstack buffer bug under cudagraph padding: eager mode keeps +# num_input_tokens == num_scheduled_tokens. +enforce_eager = true + +[inference.parallel] +dp = 1 +tp = 1 diff --git a/configs/v1/training_mode/opd.toml b/configs/v1/training_mode/opd.toml new file mode 100644 index 0000000000..d8223d1ceb --- /dev/null +++ b/configs/v1/training_mode/opd.toml @@ -0,0 +1,58 @@ +# Debug OPD (on-policy distillation) run on the v1 env server (reverse-text-v1). +# Student generates rollouts (renderer client); teacher computes logprobs. +# Start the teacher inference server first (on a separate GPU): +# CUDA_VISIBLE_DEVICES=1 uv run inference \ +# --model.name PrimeIntellect/Qwen3-0.6B-Reverse-Text-RL \ +# --server.port 8001 --gpu-memory-utilization 0.5 --model.enforce-eager +# Then: +# uv run rl @ configs/v1/training_mode/opd.toml + +max_steps = 20 +seq_len = 2048 + +[model] +name = "PrimeIntellect/Qwen3-0.6B-Reverse-Text-SFT" + +[wandb] +project = "reverse-text-v1-debug" +name = "debug-opd" + +[orchestrator] +training_mode = "opd" +batch_size = 128 +group_size = 16 + +[orchestrator.renderer] +name = "qwen3" + +[orchestrator.train.sampling] +max_completion_tokens = 128 + +[[orchestrator.train.env]] +taskset = { id = "reverse-text-v1" } +harness = { id = "default", enable_bash = false, runtime = { type = "subprocess" } } + +[orchestrator.eval] +interval = 1 +num_examples = 128 + +[orchestrator.eval.sampling] +max_completion_tokens = 128 + +[[orchestrator.eval.env]] +taskset = { id = "reverse-text-v1" } +harness = { id = "default", enable_bash = false, runtime = { type = "subprocess" } } + +[orchestrator.teacher.model] +name = "PrimeIntellect/Qwen3-0.6B-Reverse-Text-RL" + +[orchestrator.teacher.client] +base_url = ["http://localhost:8001/v1"] + +[trainer.optim] +lr = 3e-6 + +[ckpt] + +[inference] +gpu_memory_utilization = 0.5 diff --git a/configs/v1/training_mode/opd_lora.toml b/configs/v1/training_mode/opd_lora.toml new file mode 100644 index 0000000000..4c2ef8d493 --- /dev/null +++ b/configs/v1/training_mode/opd_lora.toml @@ -0,0 +1,63 @@ +# Debug OPD run on the v1 env server (reverse-text-v1), training a LoRA adapter. +# Start the teacher inference server first (on a separate GPU): +# CUDA_VISIBLE_DEVICES=1 uv run inference \ +# --model.name PrimeIntellect/Qwen3-0.6B-Reverse-Text-RL \ +# --server.port 8001 --gpu-memory-utilization 0.5 --model.enforce-eager +# Then: +# uv run rl @ configs/v1/training_mode/opd_lora.toml + +max_steps = 20 +seq_len = 2048 + +[model] +name = "PrimeIntellect/Qwen3-0.6B-Reverse-Text-SFT" + +[wandb] +project = "reverse-text-v1-debug" +name = "debug-opd-lora" + +[orchestrator] +training_mode = "opd" +batch_size = 128 +group_size = 16 + +[orchestrator.renderer] +name = "qwen3" + +[orchestrator.train.sampling] +max_completion_tokens = 128 + +[[orchestrator.train.env]] +taskset = { id = "reverse-text-v1" } +harness = { id = "default", enable_bash = false, runtime = { type = "subprocess" } } + +[orchestrator.eval] +interval = 1 +num_examples = 128 + +[orchestrator.eval.sampling] +max_completion_tokens = 128 + +[[orchestrator.eval.env]] +taskset = { id = "reverse-text-v1" } +harness = { id = "default", enable_bash = false, runtime = { type = "subprocess" } } + +[orchestrator.teacher.model] +name = "PrimeIntellect/Qwen3-0.6B-Reverse-Text-RL" + +[orchestrator.teacher.client] +base_url = ["http://localhost:8001/v1"] + +[trainer.optim] +lr = 1e-4 + +[trainer.model.lora] +rank = 8 + +[trainer.ckpt.weights] +save_adapter_separately = true + +[ckpt] + +[inference] +gpu_memory_utilization = 0.5 diff --git a/configs/v1/training_mode/rl.toml b/configs/v1/training_mode/rl.toml new file mode 100644 index 0000000000..7dfd736b7a --- /dev/null +++ b/configs/v1/training_mode/rl.toml @@ -0,0 +1,48 @@ +# Debug RL run on the v1 env server (reverse-text-v1 taskset). +# Single GPU: the rl entrypoint auto-launches the student inference server. +# uv run rl @ configs/v1/training_mode/rl.toml + +max_steps = 20 +seq_len = 2048 + +[model] +name = "PrimeIntellect/Qwen3-0.6B-Reverse-Text-SFT" + +[wandb] +project = "reverse-text-v1-debug" +name = "debug-rl" + +[orchestrator] +training_mode = "rl" +batch_size = 128 +group_size = 16 + +[orchestrator.renderer] +name = "qwen3" + +[orchestrator.train.sampling] +max_completion_tokens = 128 + +[[orchestrator.train.env]] +taskset = { id = "reverse-text-v1" } +# Pure-text single-turn env: the model answers directly (no bash tool), subprocess runtime (no docker). +harness = { id = "default", enable_bash = false, runtime = { type = "subprocess" } } + +[orchestrator.eval] +interval = 1 +num_examples = 128 + +[orchestrator.eval.sampling] +max_completion_tokens = 128 + +[[orchestrator.eval.env]] +taskset = { id = "reverse-text-v1" } +harness = { id = "default", enable_bash = false, runtime = { type = "subprocess" } } + +[trainer.optim] +lr = 3e-6 + +[ckpt] + +[inference] +gpu_memory_utilization = 0.5 diff --git a/configs/v1/training_mode/sft.toml b/configs/v1/training_mode/sft.toml new file mode 100644 index 0000000000..8100414a06 --- /dev/null +++ b/configs/v1/training_mode/sft.toml @@ -0,0 +1,59 @@ +# Debug SFT (on-policy hard distillation) run on the v1 env server (reverse-text-v1). +# The teacher rolls out through the renderer client (token-in/out), so it must share the +# student's tokenizer; the student trains on the teacher's sampled tokens directly. +# Start the teacher inference server first (on a separate GPU): +# CUDA_VISIBLE_DEVICES=1 uv run inference \ +# --model.name PrimeIntellect/Qwen3-0.6B-Reverse-Text-RL \ +# --server.port 8001 --gpu-memory-utilization 0.5 --model.enforce-eager +# Then: +# uv run rl @ configs/v1/training_mode/sft.toml + +max_steps = 20 +seq_len = 2048 + +[model] +name = "PrimeIntellect/Qwen3-0.6B-Reverse-Text-SFT" + +[wandb] +project = "reverse-text-v1-debug" +name = "debug-sft" + +[orchestrator] +training_mode = "sft" +batch_size = 128 +group_size = 4 + +[orchestrator.renderer] +name = "qwen3" + +[orchestrator.train.sampling] +max_completion_tokens = 128 + +[[orchestrator.train.env]] +taskset = { id = "reverse-text-v1" } +harness = { id = "default", enable_bash = false, runtime = { type = "subprocess" } } + +[orchestrator.eval] +interval = 1 +num_examples = 128 + +[orchestrator.eval.sampling] +max_completion_tokens = 128 + +[[orchestrator.eval.env]] +taskset = { id = "reverse-text-v1" } +harness = { id = "default", enable_bash = false, runtime = { type = "subprocess" } } + +[orchestrator.teacher.model] +name = "PrimeIntellect/Qwen3-0.6B-Reverse-Text-RL" + +[orchestrator.teacher.client] +base_url = ["http://localhost:8001/v1"] + +[trainer.optim] +lr = 3e-6 + +[ckpt] + +[inference] +gpu_memory_utilization = 0.5 diff --git a/configs/v1/training_mode/sft_lora.toml b/configs/v1/training_mode/sft_lora.toml new file mode 100644 index 0000000000..db1f94bc32 --- /dev/null +++ b/configs/v1/training_mode/sft_lora.toml @@ -0,0 +1,65 @@ +# Debug SFT run on the v1 env server (reverse-text-v1), training a LoRA adapter. +# The teacher generates rollouts over plain chat-completions; the renderer backfills tokens. +# Start the teacher inference server first (on a separate GPU): +# CUDA_VISIBLE_DEVICES=1 uv run inference \ +# --model.name PrimeIntellect/Qwen3-0.6B-Reverse-Text-RL \ +# --server.port 8001 --gpu-memory-utilization 0.5 --model.enforce-eager +# Then: +# uv run rl @ configs/v1/training_mode/sft_lora.toml + +max_steps = 20 +seq_len = 2048 + +[model] +name = "PrimeIntellect/Qwen3-0.6B-Reverse-Text-SFT" + +[wandb] +project = "reverse-text-v1-debug" +name = "debug-sft-lora" + +[orchestrator] +training_mode = "sft" +batch_size = 128 +group_size = 4 + +# Teacher rolls out over plain chat-completions (no tokens); the renderer backfills them. +[orchestrator.renderer] +name = "qwen3" + +[orchestrator.train.sampling] +max_completion_tokens = 128 + +[[orchestrator.train.env]] +taskset = { id = "reverse-text-v1" } +harness = { id = "default", enable_bash = false, runtime = { type = "subprocess" } } + +[orchestrator.eval] +interval = 1 +num_examples = 128 + +[orchestrator.eval.sampling] +max_completion_tokens = 128 + +[[orchestrator.eval.env]] +taskset = { id = "reverse-text-v1" } +harness = { id = "default", enable_bash = false, runtime = { type = "subprocess" } } + +[orchestrator.teacher.model] +name = "PrimeIntellect/Qwen3-0.6B-Reverse-Text-RL" + +[orchestrator.teacher.client] +base_url = ["http://localhost:8001/v1"] + +[trainer.optim] +lr = 1e-4 + +[trainer.model.lora] +rank = 8 + +[trainer.ckpt.weights] +save_adapter_separately = true + +[ckpt] + +[inference] +gpu_memory_utilization = 0.5 diff --git a/configs/wordle/rl.toml b/configs/wordle/rl.toml new file mode 100644 index 0000000000..826f3cd519 --- /dev/null +++ b/configs/wordle/rl.toml @@ -0,0 +1,41 @@ +# Wordle (classic v0 TextArena env) on 2 GPUs (1 trainer + 1 inference), run via the +# legacy bridge. The Wordle-SFT model plays a full ~6-turn game per rollout; the bridge +# serves each multi-turn rollout to the orchestrator as a vf.Trace. + +max_steps = 200 +seq_len = 8192 + +[deployment] +num_train_gpus = 1 +num_infer_gpus = 1 + +[wandb] +project = "wordle" +name = "wordle" + +[ckpt] # Checkpoint at the end of training + +[model] +name = "PrimeIntellect/Qwen3-1.7B-Wordle-SFT" + +[orchestrator] +batch_size = 128 +group_size = 8 + +[[orchestrator.train.env]] +id = "primeintellect/wordle" +name = "wordle" + +[orchestrator.train.sampling] +max_completion_tokens = 1024 + +[trainer] # Default trainer config + +[inference] # single inference GPU (no data parallelism) + +# Qwen3 finetune with the standard PI template patch (byte-identical to +# PrimeIntellect/Qwen3-0.6B base); always re-emits prior blocks. +# Match that with the qwen3 renderer's preserve_all_thinking. +[orchestrator.renderer] +name = "qwen3" +preserve_all_thinking = true diff --git a/deps/verifiers b/deps/verifiers index 05c66c2358..f8425c9117 160000 --- a/deps/verifiers +++ b/deps/verifiers @@ -1 +1 @@ -Subproject commit 05c66c235875d785754f2b7078db0e7deeddbeae +Subproject commit f8425c911714318da7ba93f6e6541a5b4a6b1bad diff --git a/packages/prime-rl-configs/pyproject.toml b/packages/prime-rl-configs/pyproject.toml index 55ae0a5b47..bd361045a5 100644 --- a/packages/prime-rl-configs/pyproject.toml +++ b/packages/prime-rl-configs/pyproject.toml @@ -10,6 +10,7 @@ dependencies = [ "renderers>=0.1.8.dev28", "tomli>=2.2.1", "tomli-w>=1.2.0", + "verifiers", ] [build-system] diff --git a/packages/prime-rl-configs/src/prime_rl/configs/env_server.py b/packages/prime-rl-configs/src/prime_rl/configs/env_server.py index 50c99adcff..9afec7adff 100644 --- a/packages/prime-rl-configs/src/prime_rl/configs/env_server.py +++ b/packages/prime-rl-configs/src/prime_rl/configs/env_server.py @@ -1,14 +1,12 @@ from pathlib import Path -from pydantic import model_validator - from prime_rl.configs.orchestrator import EnvConfig from prime_rl.configs.shared import LogConfig from prime_rl.utils.config import BaseConfig class EnvServerConfig(BaseConfig): - env: EnvConfig = EnvConfig() + env: EnvConfig log: LogConfig = LogConfig() @@ -17,9 +15,3 @@ class EnvServerConfig(BaseConfig): output_dir: Path = Path("outputs") """Directory to write outputs to — logs and any generated artifacts are written as subdirectories.""" - - @model_validator(mode="after") - def validate_num_workers(self): - if self.env.num_workers == "auto": - self.env.num_workers = 1 - return self diff --git a/packages/prime-rl-configs/src/prime_rl/configs/orchestrator.py b/packages/prime-rl-configs/src/prime_rl/configs/orchestrator.py index be5fe249f3..f69b19f6ad 100644 --- a/packages/prime-rl-configs/src/prime_rl/configs/orchestrator.py +++ b/packages/prime-rl-configs/src/prime_rl/configs/orchestrator.py @@ -1,10 +1,9 @@ -import math import warnings from pathlib import Path from typing import Annotated, Any, Literal, TypeAlias -from pydantic import AliasChoices, Field, model_serializer, model_validator -from pydantic_core.core_schema import SerializerFunctionWrapHandler +import verifiers.v1 as vf +from pydantic import AliasChoices, Field, model_validator from renderers import AutoRendererConfig, RendererConfig from prime_rl.configs.shared import ( @@ -143,24 +142,18 @@ def _deprecate_max_tokens(cls, data: Any) -> Any: return data -class EnvConfig(BaseConfig): - id: str = "reverse-text" - """Registered verifiers environment ID (e.g. ``math-env``, ``primeintellect/math-env``). May include an ``@version`` suffix for installation.""" +class EnvConfig(vf.EnvServerConfig): + """A v1 environment — its ``taskset`` + ``harness`` (reused from ``vf.EnvConfig``, + resolved to their specific config types by ``id`` via vf's shared validator) plus the + worker ``pool`` (from ``vf.EnvServerConfig``: ``static`` or ``elastic``, default + elastic) and prime-rl's orchestration knobs. Timeouts come from ``vf.TimeoutConfig`` + (``timeout.rollout`` / ``timeout.scoring``).""" name: str | None = None - """Display name for this environment in logs, metrics, and buffer keys. Defaults to the ``id`` without ``@version``. Must be unique across all envs in the same group.""" - - args: dict = {} - """Keyword arguments forwarded to ``vf.load_environment``. See the environment's docstring for accepted args.""" - - extra_env_kwargs: dict[str, Any] = {} - """Extra kwargs passed to the env (e.g. ``seq_len``, ``max_total_completion_tokens``). Auto-populated by the orchestrator; user overrides are generally discouraged. The main use case is matching ``extra_env_kwargs`` when running an env in an isolated environment server.""" + """Display name for this environment in logs, metrics, and buffer keys. Defaults to the taskset id. Must be unique across all envs in the same group.""" address: str | None = None - """ZMQ address of an external env server (e.g. ``tcp://host:5000``). When set, the orchestrator connects to this server instead of spawning one; when None, a subprocess env server is spawned automatically.""" - - num_workers: int | Literal["auto"] = "auto" - """Worker processes for the spawned env server. ``auto`` scales to 1 worker per 256 concurrent rollouts. Ignored when ``address`` is set.""" + """ZMQ address of an external env server (e.g. ``tcp://host:5000``). When set, the orchestrator connects to this server instead of spawning one; when None, a subprocess env server is spawned automatically. The ``pool`` sizes the spawned server.""" ratio: float | None = Field(None, gt=0) """Sampling weight for this environment in the buffer. When None for all envs, samples uniformly across all available problems. When set, must be set on all envs — values are relative weights normalized to probabilities (e.g. [1, 1] and [0.5, 0.5] are equivalent).""" @@ -168,26 +161,42 @@ class EnvConfig(BaseConfig): max_retries: int = Field(3, ge=0) """Times the env server retries a failed rollout before returning an error.""" - max_total_completion_tokens: int = -1 - """Maximum total completion tokens across all turns in a multi-turn rollout. ``-1`` disables. Auto-populated into ``extra_env_kwargs``.""" + id: str | None = None + """Classic (v0) env id, loaded via verifiers ``load_environment(id, **args)`` and served + through the legacy bridge. Set this instead of ``taskset`` to run a legacy v0 environment.""" + args: dict = Field(default_factory=dict) + """Kwargs passed to the v0 env's ``load_environment`` (only used when ``id`` is set).""" - timeout: float | None = Field(None, validation_alias=AliasChoices("timeout", "timeout_seconds")) - """Per-rollout wall-clock timeout in seconds. None disables.""" + @model_validator(mode="before") + @classmethod + def _migrate_num_workers(cls, data): + """Back-compat: the removed ``num_workers`` maps onto ``pool`` — an int becomes a + fixed ``static`` pool, ``"auto"`` falls through to the default ``elastic`` pool. An + explicit ``pool`` always wins.""" + if isinstance(data, dict) and "num_workers" in data: + num_workers = data.pop("num_workers") + if "pool" not in data and num_workers != "auto": + data["pool"] = {"type": "static", "num_workers": num_workers} + return data - state_columns: list[str] = [] - """Extra ``State`` fields to persist into the saved rollout records (in addition to the always-saved ``trajectory`` and ``sampling_args``). Values must be JSON-serializable.""" + @property + def is_legacy(self) -> bool: + """A v0/legacy env (run via the bridge): an ``id`` is set and no v1 ``taskset`` is.""" + return not self.taskset.id @property - def stripped_id(self) -> str: - """Environment ID without the @version suffix.""" - return self.id.split("@")[0] + def env_id(self) -> str: + """The env identifier — the v1 taskset id (v1) or the legacy env id (v0).""" + return self.taskset.id or self.id or "" @property def resolved_name(self) -> str: - return self.name or self.stripped_id + return self.name or self.env_id @model_validator(mode="after") - def validate_env_name(self): + def validate_env(self): + if not self.taskset.id and not self.id: + raise ValueError('no env configured — set taskset = { id = "" } (v1) or id = "" (v0/legacy)') if self.resolved_name == "all": raise ValueError( 'Environment name "all" is reserved for global metric aggregation. Use a different name or id.' @@ -195,14 +204,16 @@ def validate_env_name(self): return self @model_validator(mode="after") - def resolve_max_total_completion_tokens(self): - self.extra_env_kwargs["max_total_completion_tokens"] = self.max_total_completion_tokens - return self - - @model_validator(mode="after") - def resolve_timeout(self): - if self.timeout is not None: - self.extra_env_kwargs["timeout_seconds"] = self.timeout + def resolve_legacy_env_kwargs(self): + """For a v0/legacy env, surface the v1 knobs the legacy bridge applies via + ``extra_env_kwargs`` (``env.set_kwargs(...)``): the per-rollout wall-clock timeout and + the multi-turn completion-token budget. (``max_seq_len`` is added per train run in + ``OrchestratorConfig.resolve_env_config``, which knows ``seq_len``.)""" + if self.is_legacy: + if self.timeout.rollout is not None: + self.extra_env_kwargs["timeout_seconds"] = self.timeout.rollout + if self.max_output_tokens is not None: + self.extra_env_kwargs["max_total_completion_tokens"] = self.max_output_tokens return self @@ -230,21 +241,19 @@ class EvalEnvConfig(EnvConfig): class TrainConfig(BaseConfig): - env: list[TrainEnvConfig] = [TrainEnvConfig()] + env: list[TrainEnvConfig] = Field(default_factory=list) """Training environments.""" sampling: TrainSamplingConfig = TrainSamplingConfig() """Shared training sampling configuration.""" - num_workers: int | Literal["auto"] = "auto" - """Default worker processes for env servers. Can be overridden per env.""" - max_retries: int = Field(3, ge=0) """Default retries for failed rollouts. Can be overridden per env.""" @model_validator(mode="after") def resolve_env_defaults(self): - """Resolve per-env overrides: inherit group-level sampling, num_workers, and max_retries.""" + """Resolve per-env overrides: inherit group-level sampling and max_retries (the + worker ``pool`` is configured per env, defaulting to elastic).""" group_sampling = self.sampling.model_dump() for env in self.env: if "sampling" not in env.model_fields_set: @@ -252,8 +261,6 @@ def resolve_env_defaults(self): else: merged = group_sampling | env.sampling.model_dump(exclude_unset=True) env.sampling = TrainSamplingConfig(**merged) - if "num_workers" not in env.model_fields_set: - env.num_workers = self.num_workers if "max_retries" not in env.model_fields_set: env.max_retries = self.max_retries return self @@ -279,7 +286,7 @@ def validate_env_ratios(self): class EvalConfig(BaseConfig): - env: list[EvalEnvConfig] = [EvalEnvConfig()] + env: list[EvalEnvConfig] = Field(default_factory=list) """Evaluation environments.""" sampling: EvalSamplingConfig = Field(default_factory=EvalSamplingConfig) @@ -291,9 +298,6 @@ class EvalConfig(BaseConfig): group_size: int = Field(1, ge=1, validation_alias=AliasChoices("group_size", "rollouts_per_example")) """Default rollouts per example. Can be overridden per env.""" - num_workers: int | Literal["auto"] = "auto" - """Default worker processes for env servers. Can be overridden per env.""" - max_retries: int = Field(3, ge=0) """Default retries for failed rollouts. Can be overridden per env.""" @@ -306,7 +310,8 @@ class EvalConfig(BaseConfig): @model_validator(mode="after") def resolve_env_defaults(self): - """Resolve per-env overrides: inherit group-level sampling, num_workers, max_retries, num_examples, group_size, and interval. Then resolve auto num_workers.""" + """Resolve per-env overrides: inherit group-level sampling, max_retries, num_examples, + group_size, and interval (the worker ``pool`` is configured per env, default elastic).""" group_sampling = self.sampling.model_dump() for env in self.env: if "sampling" not in env.model_fields_set: @@ -320,17 +325,8 @@ def resolve_env_defaults(self): env.group_size = self.group_size if "interval" not in env.model_fields_set: env.interval = self.interval - if "num_workers" not in env.model_fields_set: - env.num_workers = self.num_workers if "max_retries" not in env.model_fields_set: env.max_retries = self.max_retries - # Resolve auto num_workers now that num_examples and group_size are set - if env.num_workers == "auto": - if env.num_examples == -1: - env.num_workers = 4 - else: - max_concurrent = env.num_examples * env.group_size - env.num_workers = max(1, math.ceil(max_concurrent / 256)) return self @model_validator(mode="after") @@ -513,30 +509,16 @@ class OrchestratorConfig(BaseConfig): tokenizer: TokenizerConfig = TokenizerConfig() - renderer: RendererConfig | None = AutoRendererConfig() - """Typed renderer config (``renderers.RendererConfig`` discriminated - union). Defaults to ``"auto"``, which resolves from - ``tokenizer.name_or_path`` via ``MODEL_RENDERER_MAP``. ``None`` - opts into MITO (``openai_chat_completions``); SFT mode forces this.""" + renderer: RendererConfig = AutoRendererConfig() + """Typed renderer config (``renderers.RendererConfig`` discriminated union), required — + training is renderer-only. Defaults to ``"auto"``, which resolves from + ``tokenizer.name_or_path`` via ``MODEL_RENDERER_MAP``. RL/OPD roll out through the renderer + client; SFT uses it to backfill tokens for its chat-completions teacher.""" pool_size: int | None = Field(None, ge=1) """Number of renderer slots shared across concurrent rollouts. Bump for long multi-turn prompts where client-side jinja tokenization - serializes. Only meaningful when ``renderer`` is not ``None``.""" - - @model_serializer(mode="wrap") - def _preserve_mito_renderer(self, handler: SerializerFunctionWrapHandler) -> dict[str, Any]: - """Emit ``renderer = "None"`` (string) when MITO so - ``model_dump(exclude_none=True)`` round-trips: dumped TOML has - ``renderer = "None"``, and on reload - ``BaseConfig._none_str_to_none`` coerces it back to ``None``. - Without this, a MITO orchestrator config saved to - ``control/orch.toml`` would lose the renderer key entirely and - reload as the default ``AutoRendererConfig()`` (TITO).""" - result = handler(self) - if self.renderer is None: - result["renderer"] = "None" - return result + serializes.""" optim: OptimizerConfig = OptimizerConfig() """Per-run optimizer configuration for multi-run training.""" @@ -752,16 +734,6 @@ def validate_unique_filter_types(self): ) return self - @model_validator(mode="after") - def _force_no_renderer_for_sft(self): - """SFT rolls out via the teacher's plain chat-completions endpoint; the - renderer client doesn't apply. Force ``renderer=None`` so the user - doesn't have to remember to set it. Declared before the renderer - validators below so they see the corrected value.""" - if self.training_mode == "sft": - self.renderer = None - return self - @model_validator(mode="after") def validate_training_mode(self): """Enforce training mode invariants that involve only orchestrator fields.""" @@ -772,34 +744,6 @@ def validate_training_mode(self): raise ValueError(f"orchestrator.teacher must be configured when training_mode = '{self.training_mode}'.") return self - @model_validator(mode="after") - def validate_pool_size(self): - """``pool_size`` is only meaningful when the renderer is enabled - (``renderer is not None``). Reject otherwise so callers don't - silently pass it and wonder why it's ignored.""" - if self.renderer is None and self.pool_size is not None: - raise ValueError( - f"orchestrator.pool_size={self.pool_size!r} is set but " - "orchestrator.renderer is None (MITO mode). Either configure a renderer " - "or remove pool_size." - ) - return self - - @model_validator(mode="after") - def vlm_requires_renderer(self): - """VLMs (``[model.vlm]`` block set) must go through the renderer. - - The renderer owns the processor per-slot, produces byte-identical - tokens, and ships generic ``mm_kwargs`` keyed by whatever the - model's forward signature expects. - """ - if self.student.model.vlm is not None and self.renderer is None: - raise ValueError( - "orchestrator.renderer must be set when model.vlm is set. " - "VLMs must go through a renderer (e.g. Qwen3VLRenderer) that owns the processor." - ) - return self - @model_validator(mode="after") def validate_renderer_auto_resolves(self): """Reject the silent DefaultRenderer fallback at config time. @@ -813,7 +757,7 @@ def validate_renderer_auto_resolves(self): ``DefaultRendererConfig.tool_parser`` is configured. Surface at config time so ``--dry-run`` reports the error. """ - if self.renderer is None or self.renderer.name != "auto": + if self.renderer.name != "auto": return self from renderers.base import MODEL_RENDERER_MAP @@ -831,9 +775,7 @@ def validate_renderer_auto_resolves(self): f"(b) [orchestrator.renderer] name= — " f"if {model_id!r} is template-identical to a mapped family " f"(and ideally also add it upstream to " - f"renderers.base.MODEL_RENDERER_MAP). " - f"(c) orchestrator.renderer='none' — opt out of the renderer " - f"client entirely (MITO)." + f"renderers.base.MODEL_RENDERER_MAP)." ) @model_validator(mode="after") @@ -876,12 +818,6 @@ def resolve_batching(self): if "group_size" not in env_cfg.model_fields_set: env_cfg.group_size = self.group_size - # Resolve train env num_workers from max_inflight_rollouts - for env_cfg in self.train.env: - if env_cfg.num_workers == "auto": - assert self.max_inflight_rollouts is not None - env_cfg.num_workers = max(1, math.ceil(self.max_inflight_rollouts / 256)) - return self @model_validator(mode="after") @@ -900,12 +836,15 @@ def auto_setup_bench(self): @model_validator(mode="after") def resolve_env_config(self): - """Populate extra_env_kwargs and vLLM sampling defaults from top-level fields.""" - is_vllm = self.training_mode != "sft" + """Set vLLM sampling defaults on each train env from top-level fields.""" + if self.training_mode == "sft": + return self for env in self.train.env: - env.extra_env_kwargs.update(max_seq_len=self.seq_len) - if is_vllm: - env.sampling.extra_body.setdefault("top_k", -1) - env.sampling.extra_body.setdefault("min_p", 0.0) - env.sampling.extra_body.setdefault("return_token_ids", True) + env.sampling.extra_body.setdefault("top_k", -1) + env.sampling.extra_body.setdefault("min_p", 0.0) + env.sampling.extra_body.setdefault("return_token_ids", True) + if env.is_legacy: + # v0 env: cap per-turn response tokens to the training budget (the legacy + # bridge applies extra_env_kwargs via env.set_kwargs). + env.extra_env_kwargs["max_seq_len"] = self.seq_len return self diff --git a/packages/prime-rl-configs/src/prime_rl/configs/rl.py b/packages/prime-rl-configs/src/prime_rl/configs/rl.py index dab46a9ce1..a07a934240 100644 --- a/packages/prime-rl-configs/src/prime_rl/configs/rl.py +++ b/packages/prime-rl-configs/src/prime_rl/configs/rl.py @@ -191,6 +191,11 @@ class RLConfig(BaseConfig): clean_output_dir: bool = False """Delete the output directory before starting training. Required to overwrite an output directory that contains checkpoints from a previous run when not resuming.""" + env_server_base_port: int = 5000 + """Base TCP port for launcher-spawned env servers; the i-th launcher-managed env + (train then eval) binds ``127.0.0.1:base_port + i``. Envs that set their own + ``address`` are skipped. Pick a range clear of the inference/weight-broadcast ports.""" + ### Shared configurations log: SharedLogConfig = SharedLogConfig() diff --git a/pyproject.toml b/pyproject.toml index 625a879d5a..c4e77e156f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,6 +38,10 @@ dependencies = [ "flash-linear-attention", "nvidia-ml-py>=12.575.51", "pybase64>=1.4.2", + # Previously pulled in transitively by v1 verifiers (now dropped on this branch). + "orjson>=3.10", + "pandas>=2.0", + "msgspec>=0.18", ] [project.scripts] @@ -60,6 +64,9 @@ flash-attn-3 = ["flash_attn_3"] flash-attn-cute = [ "flash-attn-4", ] +# Environments prime-rl can run, resolved by id: the classic v0 envs (driven through the +# legacy bridge) plus the v1 tasksets + harnesses (`-v1` suffix avoids clashing with the +# same-named v0 envs). envs = [ "aime2024", "aime2025", @@ -91,6 +98,16 @@ envs = [ "tau2-bench", "wiki-search", "wordle", + "reverse-text-v1", + "gsm8k-v1", + "math-env-v1", + "aime24-v1", + "alphabet-sort-v1", + "scaleswe-v1", + "color-codeword-v1", + "r2e-gym-v1", + "tasksets", + "harnesses", ] disagg = [ "deep-ep ; platform_machine == 'x86_64'", @@ -196,9 +213,10 @@ nixl-cu12 = false [tool.uv.sources] prime-rl-configs = { path = "packages/prime-rl-configs", editable = true } +# prime-rl depends on the one `verifiers` package (v0 + v1 + the legacy bridge). v0 envs +# resolve from deps/{verifiers,research-environments}/environments; v1 tasksets (`-v1`) and +# harnesses from deps/verifiers/{examples/tasksets,packages}. verifiers = { path = "deps/verifiers", editable = true } -renderers = { path = "deps/renderers", editable = true } -prime-pydantic-config = { path = "deps/pydantic-config", editable = true } aime2024 = { path = "deps/research-environments/environments/aime2024", editable = true } aime2025 = { path = "deps/research-environments/environments/aime2025", editable = true } alphabet-sort = { path = "deps/verifiers/environments/alphabet_sort", editable = true } @@ -229,6 +247,18 @@ simpleqa-verified = { path = "deps/research-environments/environments/simpleqa_v tau2-bench = { path = "deps/research-environments/environments/tau2_bench", editable = true } wiki-search = { path = "deps/verifiers/environments/wiki_search", editable = true } wordle = { path = "deps/verifiers/environments/wordle", editable = true } +reverse-text-v1 = { path = "deps/verifiers/examples/tasksets/reverse_text_v1", editable = true } +gsm8k-v1 = { path = "deps/verifiers/examples/tasksets/gsm8k_v1", editable = true } +math-env-v1 = { path = "deps/verifiers/examples/tasksets/math_env_v1", editable = true } +aime24-v1 = { path = "deps/verifiers/examples/tasksets/aime24_v1", editable = true } +alphabet-sort-v1 = { path = "deps/verifiers/examples/tasksets/alphabet_sort_v1", editable = true } +scaleswe-v1 = { path = "deps/verifiers/examples/tasksets/scaleswe_v1", editable = true } +color-codeword-v1 = { path = "deps/verifiers/examples/tasksets/color_codeword_v1", editable = true } +r2e-gym-v1 = { path = "deps/verifiers/examples/tasksets/r2e_gym_v1", editable = true } +tasksets = { path = "deps/verifiers/packages/tasksets", editable = true } +harnesses = { path = "deps/verifiers/packages/harnesses", editable = true } +renderers = { path = "deps/renderers", editable = true } +prime-pydantic-config = { path = "deps/pydantic-config", editable = true } torch = { index = "pytorch-cu128" } torchvision = { index = "pytorch-cu128" } torchaudio = { index = "pytorch-cu128" } diff --git a/src/prime_rl/entrypoints/rl.py b/src/prime_rl/entrypoints/rl.py index 0ccdd7c53b..f4cf958086 100644 --- a/src/prime_rl/entrypoints/rl.py +++ b/src/prime_rl/entrypoints/rl.py @@ -12,6 +12,7 @@ import pynvml import tomli_w +from prime_rl.configs.orchestrator import EnvConfig from prime_rl.configs.rl import RLConfig from prime_rl.utils.config import cli from prime_rl.utils.logger import get_logger, setup_logger @@ -32,6 +33,10 @@ ORCHESTRATOR_TOML = "orchestrator.toml" INFERENCE_TOML = "inference.toml" +# Gap between the train and eval env-server port blocks, so each kind has headroom +# for many envs without the blocks overlapping (train: base+i; eval: base+stride+i). +ENV_SERVER_KIND_STRIDE = 1000 + def get_physical_gpu_ids() -> list[int]: """Return physical GPU IDs visible to the launcher.""" @@ -67,6 +72,44 @@ def write_subconfigs(config: RLConfig, output_dir: Path) -> None: tomli_w.dump(config.inference.model_dump(exclude=exclude_inference, exclude_none=True, mode="json"), f) +def setup_env_servers(config: RLConfig, config_dir: Path) -> list[dict]: + """Give each env its own launcher-spawned ``env-server`` process on a fixed port, + point the orchestrator at it (set ``env.address`` so it attaches instead of + sidecar-spawning), and write a per-env ``EnvServerConfig`` TOML. Train envs bind + ``base_port + i``; eval envs bind ``base_port + ENV_SERVER_KIND_STRIDE + i``, so the + two kinds sit in separate blocks with headroom for many envs each. Envs that already + set ``address`` (a user-managed external server) are left alone. Must run before + ``write_subconfigs`` so the addresses land in the orchestrator config. + + Returns one spawn spec per server: ``{label, kind, name, toml}``. + """ + config_dir.mkdir(parents=True, exist_ok=True) + env_lists = [("train", config.orchestrator.train.env)] + if config.orchestrator.eval is not None: + env_lists.append(("eval", config.orchestrator.eval.env)) + + specs: list[dict] = [] + for kind_index, (kind, env_list) in enumerate(env_lists): + for i, env in enumerate(env_list): + if env.address is not None: + continue # user-managed external server — don't spawn one + port = config.env_server_base_port + kind_index * ENV_SERVER_KIND_STRIDE + i + env.address = f"tcp://127.0.0.1:{port}" + env_dict = { + k: v for k, v in env.model_dump(mode="json", exclude_none=True).items() if k in EnvConfig.model_fields + } + server_dict: dict = {"env": env_dict, "output_dir": config.output_dir.as_posix()} + if config.log.level is not None: + server_dict["log"] = {"level": config.log.level} + toml_path = config_dir / f"env_server_{kind}_{env.resolved_name}.toml" + with open(toml_path, "wb") as f: + tomli_w.dump(server_dict, f) + specs.append( + {"label": f"env-{kind}-{env.resolved_name}", "kind": kind, "name": env.resolved_name, "toml": toml_path} + ) + return specs + + def rl_local(config: RLConfig): assert config.deployment.type == "single_node" @@ -76,6 +119,9 @@ def rl_local(config: RLConfig): ) config_dir = config.output_dir / "configs" + # Assign each env its own env-server (sets env.address) *before* writing subconfigs, + # so the orchestrator config points at the launcher-spawned servers. + env_server_specs = setup_env_servers(config, config_dir) write_subconfigs(config, config_dir) logger.info(f"Wrote subconfigs to {config_dir}") @@ -194,6 +240,32 @@ def sigterm_handler(signum, frame): "orchestrator starts, otherwise rollouts will hang." ) + # Start one env server per env (before the orchestrator, which attaches to + # them by address). CPU-only — keep them off the GPUs. + for spec in env_server_specs: + env_log = log_dir / "envs" / spec["kind"] / f"{spec['name']}.log" + env_log.parent.mkdir(parents=True, exist_ok=True) + env_cmd = ["env-server", "@", spec["toml"].as_posix()] + logger.info(f"Starting env server {spec['label']}") + logger.debug(f"Env server start command: {' '.join(env_cmd)}") + with open(env_log, "w") as log_file: + env_process = Popen( + env_cmd, + env={**os.environ, "CUDA_VISIBLE_DEVICES": ""}, + stdout=log_file, + stderr=log_file, + ) + processes.append(env_process) + stop_event = Event() + stop_events[spec["label"]] = stop_event + monitor_thread = Thread( + target=monitor_process, + args=(env_process, stop_event, error_queue, spec["label"]), + daemon=True, + ) + monitor_thread.start() + monitor_threads.append(monitor_thread) + orchestrator_cmd = ["orchestrator", "@", (config_dir / ORCHESTRATOR_TOML).as_posix()] logger.info("Starting orchestrator process") logger.debug(f"Orchestrator start command: {' '.join(orchestrator_cmd)}") diff --git a/src/prime_rl/orchestrator/advantage.py b/src/prime_rl/orchestrator/advantage.py index b58a410326..7a1df12eca 100644 --- a/src/prime_rl/orchestrator/advantage.py +++ b/src/prime_rl/orchestrator/advantage.py @@ -4,12 +4,13 @@ from typing import TYPE_CHECKING, Callable import torch -import verifiers as vf from jaxtyping import Float from torch import Tensor if TYPE_CHECKING: - from prime_rl.orchestrator.types import TrainRollout + import verifiers.v1 as vf + + from prime_rl.orchestrator.types import Rollout from prime_rl.configs.orchestrator import ( AdvantageConfig, @@ -18,7 +19,7 @@ TokensLengthPenaltyConfig, TurnsLengthPenaltyConfig, ) -from prime_rl.orchestrator.utils import get_model_completion_len, get_tool_response_len +from prime_rl.orchestrator.utils import get_tool_response_len from prime_rl.utils.utils import import_object @@ -26,7 +27,7 @@ class AdvantageInputs: """Inputs for advantage computation of a single group (one example × N rollouts).""" - rollouts: list[vf.RolloutOutput] + rollouts: list[vf.Trace] @dataclass @@ -57,18 +58,18 @@ def default_advantage_fn( `length_penalty` enables correctness-gated efficiency shaping over a per-rollout cost: tokens (weighted completion + tool-response) or trajectory turn count. """ - rewards = torch.tensor([r["reward"] for r in inputs.rollouts], dtype=torch.float32) + rewards = torch.tensor([r.reward for r in inputs.rollouts], dtype=torch.float32) if isinstance(length_penalty, TokensLengthPenaltyConfig): w_c = length_penalty.completion_weight w_t = length_penalty.tool_response_weight costs = torch.tensor( - [w_c * get_model_completion_len(r) + w_t * get_tool_response_len(r) for r in inputs.rollouts], + [w_c * r.completion_len + w_t * get_tool_response_len(r) for r in inputs.rollouts], dtype=rewards.dtype, ) return AdvantageOutputs(advantages=_efficiency_shaping(rewards, costs).tolist()) if isinstance(length_penalty, TurnsLengthPenaltyConfig): - costs = torch.tensor([len(r["trajectory"]) for r in inputs.rollouts], dtype=rewards.dtype) + costs = torch.tensor([r.num_turns for r in inputs.rollouts], dtype=rewards.dtype) return AdvantageOutputs(advantages=_efficiency_shaping(rewards, costs).tolist()) return AdvantageOutputs(advantages=(rewards - rewards.mean()).tolist()) @@ -129,19 +130,19 @@ def advantage_fn(inputs: AdvantageInputs) -> AdvantageOutputs: def assign_advantages( - rollouts: list["TrainRollout"], # noqa: F821 (forward ref) + rollouts: list[Rollout], advantage_fn: AdvantageFn | None, ) -> None: """Compute and assign advantages for one finished group of rollouts (``TrainSink.process_group`` hands in a single group's surviving rollouts). ``advantage_fn=None`` is the trivial case (advantage = reward); a custom - ``advantage_fn`` receives the raw ``vf.RolloutOutput``\\ s via + ``advantage_fn`` receives the ``vf.Trace``\\ s via ``AdvantageInputs.rollouts``. """ if advantage_fn is None: for rollout in rollouts: rollout.advantage = rollout.reward return - result = advantage_fn(AdvantageInputs(rollouts=[r.raw for r in rollouts])) + result = advantage_fn(AdvantageInputs(rollouts=[r for r in rollouts])) for rollout, advantage in zip(rollouts, result.advantages): rollout.advantage = advantage diff --git a/src/prime_rl/orchestrator/dispatcher.py b/src/prime_rl/orchestrator/dispatcher.py index 133bc08da0..251825fcbc 100644 --- a/src/prime_rl/orchestrator/dispatcher.py +++ b/src/prime_rl/orchestrator/dispatcher.py @@ -6,7 +6,7 @@ - Emit-everything invariant: every dispatched rollout eventually reaches ``out_q`` exactly once as a ``TrainRollout`` / ``EvalRollout``. Failures (env error, empty trajectory, task exception, off-policy cancel) carry - ``raw["error"]`` set; sinks decide drop / partial-train policy. + ``trace.error`` set; sinks decide drop / partial-train policy. - ``DispatcherMode.PREFER_TRAIN`` / ``PREFER_EVAL`` controls which kind to schedule next. Transitions are level-triggered (driven by the eval source's emptiness), so in-flight rollouts of the opposite kind drain @@ -20,26 +20,25 @@ from __future__ import annotations import asyncio +import traceback import uuid from collections import Counter, defaultdict from dataclasses import dataclass, field from enum import Enum, auto from typing import Literal -import verifiers as vf +import verifiers.v1 as vf from aiolimiter import AsyncLimiter from prime_rl.orchestrator.envs import EvalEnvs, TrainEnvs from prime_rl.orchestrator.eval_source import EvalSource from prime_rl.orchestrator.train_source import TrainSource from prime_rl.orchestrator.types import ( - EvalRollout, - FinishedRollout, GroupState, InflightRollout, Policy, + Rollout, RolloutKind, - TrainRollout, ) from prime_rl.utils.async_utils import safe_cancel, safe_cancel_all from prime_rl.utils.client import InferencePool, client_identity @@ -154,7 +153,7 @@ def __init__( self.groups: dict[uuid.UUID, GroupState] = {} # Bounded so the dispatcher backpressures on a slow sink - self.out_q: asyncio.Queue[FinishedRollout] = asyncio.Queue(maxsize=max(8, self.max_inflight)) + self.out_q: asyncio.Queue[Rollout] = asyncio.Queue(maxsize=max(8, self.max_inflight)) self.mode: DispatcherMode = DispatcherMode.PREFER_TRAIN # Set by the orchestrator after the final train step; pipeline then @@ -366,7 +365,7 @@ def next_fresh_group(self, kind: RolloutKind, envs) -> GroupState | None: return GroupState( kind=kind, env_name=env_name, - example=example, + task_idx=example["task_idx"], rollouts_to_schedule=group_size, target_rollouts=group_size, eval_step=eval_step, @@ -423,7 +422,7 @@ async def schedule_group_rollout(self, group_id: uuid.UUID, group: GroupState) - task: asyncio.Task = asyncio.create_task( env.run_group( client=client, - example=group.example, + task_idx=group.task_idx, model_name=model_name, group_size=permits, cache_salt=cache_salt, @@ -436,7 +435,7 @@ async def schedule_group_rollout(self, group_id: uuid.UUID, group: GroupState) - task = asyncio.create_task( env.run_rollout( client=client, - example=group.example, + task_idx=group.task_idx, model_name=model_name, cache_salt=cache_salt, ) @@ -480,97 +479,66 @@ async def handle_completed_rollout(self, task: asyncio.Task) -> None: is_synth_exception = False try: result = task.result() - rollouts: list[vf.RolloutOutput] = result if isinstance(result, list) else [result] + rollouts: list[Rollout] = result if isinstance(result, list) else [result] except asyncio.CancelledError: return except Exception as exc: get_logger().warning(f"Rollout task failed in group {meta.group_id} ({meta.env_name}): {exc!r}") + task_idx = group.task_idx if group is not None else -1 + tb = traceback.format_exc() rollouts = [ - self.error_rollout_output(error_type=type(exc).__name__, error_repr=repr(exc)) + Rollout( + task=vf.Task(idx=task_idx, instruction=""), + errors=[vf.Error(type=type(exc).__name__, message=str(exc), traceback=tb)], + stop_condition="error", + ) for _ in range(meta.rollout_count) ] is_synth_exception = True for r in rollouts: - if r.get("error") is None and len(r.get("trajectory") or []) == 0: + if not r.has_error and r.num_turns == 0: # Empty trajectory: promote to an explicit error so the sink # treats it like any other failure - r["error"] = { - "error": "EmptyTrajectory", - "error_chain_repr": "Rollout returned with no trajectory steps", - "error_chain_str": "", - } + r.errors.append(vf.Error(type="EmptyTrajectory", message="Rollout returned with no trajectory steps")) get_logger().warning(f"Empty trajectory in group {meta.group_id} ({meta.env_name})") - if r.get("error") is not None: - err_type = r["error"].get("error", "Unknown") + if r.has_error: self.metrics.record_error(kind=meta.kind, env_name=meta.env_name) if not is_synth_exception: get_logger().warning( - f"Rollout failed in group {meta.group_id} ({meta.env_name}) — " - f"{r['error'].get('error_chain_repr', err_type)}" + f"Rollout failed in group {meta.group_id} ({meta.env_name}) — {r.error.type}: {r.error.message}" ) await self.emit_rollout(meta, group, r) - async def emit_rollout(self, meta: InflightRollout, group: GroupState | None, raw: vf.RolloutOutput) -> None: - """Build a ``TrainRollout`` / ``EvalRollout`` and put it on ``out_q``. + async def emit_rollout(self, meta: InflightRollout, group: GroupState | None, rollout: Rollout) -> None: + """Stamp prime-rl metadata onto the completed rollout and put it on ``out_q``. Pops the group from ``self.groups`` once every member has been emitted.""" eval_step = meta.eval_step policy_version = meta.policy_version - example_id = raw.get("example_id") if group is not None: eval_step = group.eval_step policy_version = group.policy_version_at_start - example_id = group.example["example_id"] group.emitted += 1 if group.emitted >= group.target_rollouts: self.groups.pop(meta.group_id, None) - common = dict( - raw=raw, - env_name=meta.env_name, - example_id=example_id if example_id is not None else -1, - group_id=meta.group_id, - policy_version=policy_version, - off_policy_steps=meta.off_policy_steps, - ) - rollout: FinishedRollout - if meta.kind == "train": - rollout = TrainRollout(**common) - else: + rollout.kind = meta.kind + rollout.env_name = meta.env_name + rollout.group_id = meta.group_id + rollout.policy_version = policy_version + rollout.off_policy_steps = meta.off_policy_steps + if meta.kind == "eval": assert eval_step is not None, "eval rollout missing eval_step" - rollout = EvalRollout(**common, eval_step=eval_step) + rollout.eval_step = eval_step await self.out_q.put(rollout) - @staticmethod - def error_rollout_output(*, error_type: str, error_repr: str) -> vf.RolloutOutput: - """Minimal ``vf.RolloutOutput`` for rollouts that never produced - real output (task exception, off-policy cancel).""" - out: vf.RolloutOutput = vf.RolloutOutput() - out["error"] = { - "error": error_type, - "error_chain_repr": error_repr, - "error_chain_str": error_repr, - } - out["trajectory"] = [] - out["completion"] = None - out["reward"] = 0.0 - out["is_truncated"] = False - out["metrics"] = {} - out["stop_condition"] = None - out["token_usage"] = { - "input_tokens": 0.0, - "output_tokens": 0.0, - "final_input_tokens": 0.0, - "final_output_tokens": 0.0, - } - return out - async def drop_group(self, group_id: uuid.UUID) -> int: """Cancel remaining in-flight tasks for this group and emit a ``Cancelled`` marker for every rollout it still owes the sink (both in-flight and not-yet-scheduled). Returns the count for off-policy metrics.""" group = self.groups.pop(group_id, None) + task_idx = group.task_idx if group is not None else -1 # Sync claim phase: pop matching tasks from ``self.inflight`` and # release their permits in one non-yielding sweep. After this loop @@ -590,8 +558,12 @@ async def drop_group(self, group_id: uuid.UUID) -> int: last_meta: InflightRollout | None = claimed[-1][1] if claimed else None for _, meta in claimed: for _ in range(meta.rollout_count): - raw = self.error_rollout_output(error_type="Cancelled", error_repr="Off-policy cancel") - await self.emit_rollout(meta, group, raw) + trace = Rollout( + task=vf.Task(idx=task_idx, instruction=""), + errors=[vf.Error(type="Cancelled", message="Off-policy cancel")], + stop_condition="error", + ) + await self.emit_rollout(meta, group, trace) # For non-group-scoring envs, the group may have rollouts that # were never dispatched (``rollouts_to_schedule > 0``). Emit @@ -612,8 +584,12 @@ async def drop_group(self, group_id: uuid.UUID) -> int: ) unscheduled_cancelled = group.rollouts_to_schedule for _ in range(unscheduled_cancelled): - raw = self.error_rollout_output(error_type="Cancelled", error_repr="Off-policy cancel") - await self.emit_rollout(fallback_meta, group, raw) + trace = Rollout( + task=vf.Task(idx=task_idx, instruction=""), + errors=[vf.Error(type="Cancelled", message="Off-policy cancel")], + stop_condition="error", + ) + await self.emit_rollout(fallback_meta, group, trace) cancelled = inflight_cancelled + unscheduled_cancelled if cancelled > 0: diff --git a/src/prime_rl/orchestrator/env_server/env_server.py b/src/prime_rl/orchestrator/env_server/env_server.py index 8ff057b9d9..2e0fd7ebc7 100644 --- a/src/prime_rl/orchestrator/env_server/env_server.py +++ b/src/prime_rl/orchestrator/env_server/env_server.py @@ -1,38 +1,31 @@ -import asyncio +from functools import partial -from verifiers.serve import ZMQEnvServer +from verifiers.v1 import pool_serve_kwargs +from verifiers.v1.serve import serve_env from prime_rl.configs.env_server import EnvServerConfig +from prime_rl.orchestrator.utils import setup_env_server_logging from prime_rl.utils.config import cli -from prime_rl.utils.logger import setup_logger -from prime_rl.utils.pathing import get_log_dir from prime_rl.utils.process import set_proc_title -from prime_rl.utils.utils import clean_exit, get_env_ids_to_install, install_env +from prime_rl.utils.utils import clean_exit @clean_exit def run_server(config: EnvServerConfig): - setup_logger(config.log.level, json_logging=config.log.json_logging) - - # install environment if not already installed - env_ids_to_install = set() - env_ids_to_install.update(get_env_ids_to_install([config.env])) - for env_id in env_ids_to_install: - install_env(env_id, prerelease=config.env_install_prerelease) - - log_dir = (get_log_dir(config.output_dir) / config.env.resolved_name).as_posix() - - server = ZMQEnvServer( - env_id=config.env.stripped_id, - env_args=config.env.args, - extra_env_kwargs=config.env.extra_env_kwargs, - log_level=config.log.level, - log_dir=log_dir, - json_logging=config.log.json_logging, - num_workers=config.env.num_workers, - **{"address": config.env.address} if config.env.address is not None else {}, + env = config.env + address = env.address or "tcp://127.0.0.1:5000" + # The env's ``pool`` (static or elastic) sizes the server; a v0/legacy env runs through + # the bridge, a v1 env is a native taskset — both serve vf.Trace over the same protocol, + # so the orchestrator is agnostic. serve_env applies the logging setup in this process + # and in every spawned worker. + server_kwargs = {"env_id": env.env_id, "env_args": env.args} if env.is_legacy else {"config": env} + serve_env( + **pool_serve_kwargs(env.pool), + legacy=env.is_legacy, + address=address, + log_setup=partial(setup_env_server_logging, config.log.level, config.log.json_logging), + **server_kwargs, ) - asyncio.run(server.run()) def main(): diff --git a/src/prime_rl/orchestrator/envs.py b/src/prime_rl/orchestrator/envs.py index fe02d2e61a..82ba344f3f 100644 --- a/src/prime_rl/orchestrator/envs.py +++ b/src/prime_rl/orchestrator/envs.py @@ -1,38 +1,88 @@ +"""Env wrappers over a v1 env server. + +Each ``Env`` owns a v1 ``EnvServer`` (spawned as a child process, or an +external one given by ``config.address``) and an ``EnvClient`` to drive it. The +orchestrator never *runs* an environment: it asks the server for ``info`` +(``num_tasks`` + whether group scoring is needed), then runs rollouts purely by +**task index**. The server returns a ``Trace`` (minus its computed fields) which we validate into a +``Trace[WireTask]`` — a real ``vf.Trace`` (never a loose dict) whose task keeps the env's +task-specific fields as extras (``WireTask`` allows them). The orchestrator never imports the +env package: the env's *type* and *runtime* both live only in the server, and the orchestrator +drives it purely by task index. (Nothing here reads typed env task fields — only ``task.idx`` +and a full ``task.model_dump``, both of which ``WireTask`` preserves.) +""" + from __future__ import annotations import asyncio import atexit import multiprocessing as mp -import time -from collections.abc import Awaitable, Callable, Iterator, Sequence +import os +import queue +import sys +from collections.abc import Iterator, Sequence from multiprocessing.process import BaseProcess from pathlib import Path from typing import Generic, TypeVar -import pandas as pd -import verifiers as vf -from verifiers.serve import ZMQEnvClient, ZMQEnvServer -from verifiers.utils.serve_utils import get_free_port +import verifiers.v1 as vf +from verifiers.v1.serve import EnvClient from prime_rl.configs.orchestrator import EnvConfig, EvalEnvConfig, TrainEnvConfig -from prime_rl.orchestrator.eval_utils import compute_pass_at_k -from prime_rl.utils.logger import ProgressTracker, get_logger -from prime_rl.utils.monitor import get_monitor -from prime_rl.utils.utils import capitalize - -REQUIRED_STATE_COLUMNS = ["trajectory"] +from prime_rl.orchestrator.types import Rollout +from prime_rl.utils.logger import get_logger + +# Every wire trace validates into this type. WireTask (extra="allow") keeps the env's task +# fields without importing the env package — the orchestrator never reads them typed (only +# task.idx + task.model_dump). +ROLLOUT_TYPE = Rollout[vf.WireTask] + +# Max wait for a spawned env server to bind and report its address. The child +# loads the taskset (possibly downloading a dataset) before reporting, so this +# is generous. +ENV_SERVER_SPAWN_TIMEOUT = 600.0 + + +def _run_env_server( + *, + log_file: str, + log_level: str, + json_logging: bool, + legacy: bool = False, + **kwargs, +) -> None: + """Spawned-process entry point: redirect this process's output to ``log_file`` (the + server's logging + any subprocess-runtime output), then serve via ``serve_env``. The + worker-pool sizing arrives in ``kwargs`` (``max_workers`` / ``multiplex`` / ``elastic`` + from the env's ``pool``). ``serve_env`` applies ``log_setup`` here and in every spawned + worker; a worker inherits this process's redirected stdout/stderr, so its per-rollout + logs reach ``log_file`` too. Top-level so it stays picklable for the ``spawn`` start + method. ``legacy`` picks the v0 bridge.""" + from functools import partial + + from verifiers.v1.serve import serve_env + + from prime_rl.orchestrator.utils import setup_env_server_logging + + fh = open(log_file, "w", buffering=1) + os.dup2(fh.fileno(), sys.stdout.fileno()) + os.dup2(fh.fileno(), sys.stderr.fileno()) + serve_env( + legacy=legacy, + log_setup=partial(setup_env_server_logging, log_level, json_logging), + **kwargs, + ) class Env: - """Wraps a vf.Environment - only exposes features used in PRIME-RL.""" + """Wraps a v1 env server + client. The orchestrator never loads the env.""" def __init__(self, config: EnvConfig): self.config = config self.sampling_args: dict = {} - - get_logger().debug(f"Initializing {config.resolved_name} ({config})") - self._env: vf.Environment = vf.load_environment(config.stripped_id, **config.args) - self._env_client: ZMQEnvClient | None = None + self.num_tasks: int = 0 + self.requires_group_scoring: bool = False + self._env_client: EnvClient | None = None self._env_server_process: BaseProcess | None = None @property @@ -40,122 +90,105 @@ def name(self) -> str: return self.config.resolved_name @property - def env(self) -> vf.Environment: - return self._env - - @property - def env_client(self) -> ZMQEnvClient: - if not self._env_client: - raise RuntimeError( - f"Env {self.name} has no env client connected. Call connect() first to connect to an env server." - ) + def env_client(self) -> EnvClient: + if self._env_client is None: + raise RuntimeError(f"Env {self.name} not started — call start() first.") return self._env_client - @property - def requires_group_scoring(self) -> bool: - return any(self.env.rubric._is_group_func(func) for func in self.env.rubric._get_reward_funcs()) - - async def start( - self, - log_dir: Path, - log_level: str | None = None, - json_logging: bool = False, - ) -> None: - """Spawn an env server (if needed) and connect to it.""" - if self.config.address is None: - address = self._spawn(log_dir=log_dir, log_level=log_level, json_logging=json_logging) - else: - address = self.config.address + async def start(self, log_dir: Path, log_level: str | None = None, json_logging: bool = False) -> None: + """Spawn the env server (if needed), connect, and cache its ``info``.""" + external = self.config.address is not None + address = self.config.address or await self._spawn(log_dir, log_level or "INFO", json_logging) get_logger().debug(f"Connecting {self.name} to env server {address}") - self._env_client = ZMQEnvClient(address=address, name=self.name) - await self.env_client.wait_for_server_startup() - - def _spawn( - self, - log_dir: Path, - log_level: str | None = None, - json_logging: bool = False, - ) -> str: - assert isinstance(self.config.num_workers, int), ( - f"num_workers must be resolved before spawn, got {self.config.num_workers!r}" + self._env_client = EnvClient(address=address) + # A spawned server already reported its address *after* binding + loading, + # so it's up — the untimed ``info`` below is enough. An external server has + # no such handshake, so poll until it answers before we block on ``info``. + if external: + await self.env_client.wait_for_server_startup() + info = await self.env_client.info() + self.num_tasks = info.num_tasks + self.requires_group_scoring = info.requires_group_scoring + get_logger().info( + f"Env {self.name} ready: num_tasks={self.num_tasks} group_scoring={self.requires_group_scoring}" ) - num_workers = self.config.num_workers - address = f"tcp://127.0.0.1:{get_free_port()}" - get_logger().debug(f"Spawning env server {self.name} ({address=}, {num_workers=})") - process = mp.get_context("spawn").Process( - target=ZMQEnvServer.run_server, - args=( - self.config.stripped_id, - self.config.args, - self.config.extra_env_kwargs, - log_level, - (log_dir / self.name).as_posix(), - ), + + async def _spawn(self, log_dir: Path, log_level: str, json_logging: bool) -> str: + """Spawn a v1 EnvServer child process (it loads the env; we never do). + The server binds an OS-assigned port (``:0``) and reports the concrete + address back over a queue — no free-port guess, no TOCTOU race. Its output + goes to ``/.log`` (``log_dir`` is already the train/eval-split + ``.../logs/envs/{train,eval}`` the orchestrator passes in).""" + ctx = mp.get_context("spawn") + address_queue: mp.Queue = ctx.Queue() + log_file = log_dir / f"{self.name}.log" + log_file.parent.mkdir(parents=True, exist_ok=True) + get_logger().debug(f"Spawning env server {self.name} (id={self.config.env_id}, log={log_file})") + server_kwargs = ( + dict( + legacy=True, + env_id=self.config.env_id, + env_args=self.config.args, + extra_env_kwargs=self.config.extra_env_kwargs, + ) + if self.config.is_legacy + else dict(legacy=False, config=self.config) + ) + process = ctx.Process( + target=_run_env_server, kwargs=dict( - address=address, + log_file=str(log_file), + log_level=log_level, json_logging=json_logging, - console_logging=False, - num_workers=num_workers, + **vf.pool_serve_kwargs(self.config.pool), + address="tcp://127.0.0.1:0", + address_queue=address_queue, + **server_kwargs, ), daemon=False, ) process.start() self._env_server_process = process + try: + address = await asyncio.to_thread(address_queue.get, timeout=ENV_SERVER_SPAWN_TIMEOUT) + except queue.Empty: + raise RuntimeError(f"Env server {self.name} did not report its address within {ENV_SERVER_SPAWN_TIMEOUT}s") + finally: + address_queue.close() + address_queue.join_thread() + get_logger().debug(f"Env server {self.name} bound at {address}") return address - def _sampling_args_with_salt(self, cache_salt: str | None) -> dict: - sampling_args = {**self.sampling_args} - if cache_salt is None: - return sampling_args - extra_body = {**sampling_args.get("extra_body", {}), "cache_salt": cache_salt} - sampling_args["extra_body"] = extra_body - return sampling_args - - @property - def state_columns(self) -> list[str]: - """Required columns plus any extras configured on the env, deduped (required first).""" - merged: list[str] = [] - for col in (*REQUIRED_STATE_COLUMNS, *self.config.state_columns): - if col not in merged: - merged.append(col) - return merged + def _sampling(self, cache_salt: str | None) -> vf.SamplingConfig: + sampling = {**self.sampling_args} + if cache_salt is not None: + sampling["extra_body"] = {**sampling.get("extra_body", {}), "cache_salt": cache_salt} + return vf.SamplingConfig(**sampling) async def run_rollout( - self, - client: vf.ClientConfig, - example: dict, - model_name: str, - cache_salt: str | None, - ) -> vf.RolloutOutput: - """Run a single rollout for an example.""" - return await self.env.run_rollout( - vf.RolloutInput(**example), + self, client: vf.ClientConfig, task_idx: int, model_name: str, cache_salt: str | None + ) -> Rollout: + """Run a single rollout for ``task_idx``; return a typed Trace.""" + wire = await self.env_client.run_rollout( + task_idx=task_idx, client=client, model=model_name, - sampling_args=self._sampling_args_with_salt(cache_salt), - max_retries=self.config.max_retries, - state_columns=self.state_columns, - env_client=self.env_client, + sampling=self._sampling(cache_salt), ) + return ROLLOUT_TYPE.model_validate(wire.to_wire()) async def run_group( - self, - client: vf.ClientConfig, - example: dict, - model_name: str, - group_size: int, - cache_salt: str | None, - ) -> list[vf.RolloutOutput]: - """Run a group of rollouts for an example. Required for group-scoring envs.""" - return await self.env.run_group( - [vf.RolloutInput(**example) for _ in range(group_size)], + self, client: vf.ClientConfig, task_idx: int, model_name: str, group_size: int, cache_salt: str | None + ) -> list[Rollout]: + """Run a group of rollouts for ``task_idx`` (group-scoring envs); return typed Traces.""" + wires = await self.env_client.run_group( + task_idx=task_idx, + n=group_size, client=client, model=model_name, - sampling_args=self._sampling_args_with_salt(cache_salt), - max_retries=self.config.max_retries, - state_columns=self.state_columns, - env_client=self.env_client, + sampling=self._sampling(cache_salt), ) + return [ROLLOUT_TYPE.model_validate(wire.to_wire()) for wire in wires] def shutdown(self) -> None: if self._env_server_process is None: @@ -171,9 +204,6 @@ def __init__(self, config: TrainEnvConfig): super().__init__(config) self.sampling_args = config.sampling.to_sampling_args() - def get_dataset(self, seed: int | None = None): - return self.env.get_dataset(seed=seed) - class EvalEnv(Env): config: EvalEnvConfig @@ -181,148 +211,12 @@ class EvalEnv(Env): def __init__(self, config: EvalEnvConfig): super().__init__(config) self.sampling_args = config.sampling.to_sampling_args() - self.examples = self.env.get_eval_dataset(n=config.num_examples).to_list() - - async def evaluate( - self, - model_name: str, - get_client: Callable[[], Awaitable[vf.ClientConfig]], - step: int, - cache_salt: str, - ) -> list[vf.RolloutOutput]: - num_examples = len(self.examples) - group_size = self.config.group_size - get_logger().info(f"Evaluating {self.name} ({num_examples=}, {group_size=})") - total_rollouts = num_examples * group_size - pbar = ProgressTracker(total=total_rollouts, desc=f"Evaluating {self.name}") - eval_start = time.perf_counter() - - if self.requires_group_scoring: - - async def run_with_progress(example: dict) -> list[vf.RolloutOutput] | None: - """Run group_size rollouts as a scored group for one example.""" - try: - client = await get_client() - outputs = await self.run_group( - client=client, - example=example, - model_name=model_name, - group_size=group_size, - cache_salt=cache_salt, - ) - pbar.update(group_size) - return outputs - except Exception as e: - get_logger().warning(f"Group failed: {e}") - pbar.update(group_size) - return None - - coros = [run_with_progress(example) for example in self.examples] - - else: - - async def run_with_progress(example: dict) -> list[vf.RolloutOutput] | None: - """Run a single rollout for one example.""" - try: - client = await get_client() - output = await self.run_rollout( - client=client, example=example, model_name=model_name, cache_salt=cache_salt - ) - pbar.update(1) - return [output] - except Exception as e: - get_logger().warning(f"Rollout failed: {e}") - pbar.update(1) - return None - - coros = [run_with_progress(example) for example in self.examples for _ in range(group_size)] - - try: - results = await asyncio.gather(*coros) - finally: - pbar.close() + self.examples: list[dict] = [] - successful_outputs = [o for group in results if group is not None for o in group] - failed_count = total_rollouts - len(successful_outputs) - eval_time = time.perf_counter() - eval_start - - if failed_count: - get_logger().warning( - f"{failed_count}/{total_rollouts} ({failed_count / total_rollouts * 100:.1f}%) rollouts failed" - ) - - if not successful_outputs: - get_logger().warning(f"All rollouts failed for {self.name}, skipping logging metrics") - get_monitor().log( - { - f"eval/{self.name}/failed_rollouts": failed_count / total_rollouts, - "step": step, - }, - step=step, - ) - return [] - - # Log metrics - monitor = get_monitor() - - rows = [ - { - "example_id": o["example_id"], - "reward": o["reward"], - "completion_len": o["token_usage"]["final_output_tokens"], - "is_truncated": o["is_truncated"], - "has_error": o.get("error") is not None, - "no_response": not o.get("completion"), - } - for o in successful_outputs - ] - results_df = pd.DataFrame(rows) - - unique_rewards = results_df.reward.dropna().unique() - could_be_binary = set(unique_rewards).issubset({0.0, 1.0}) - if could_be_binary: - pass_at_k = ( - results_df.groupby("example_id") - .apply(lambda x: compute_pass_at_k(x.reward.dropna()), include_groups=False) - .apply(pd.Series) - ) - else: - pass_at_k = None - get_logger().warning("Skipping computing pass@k rates because the task rewards appear to be non-binary") - - message = f"Evaluated {self.name} in {eval_time:.2f}s (Avg@{group_size}={results_df.reward.mean():.4f}" - if could_be_binary: - assert pass_at_k is not None - for pass_rate, pass_rate_score in pd.Series(pass_at_k.mean()).items(): - message += f", {capitalize(str(pass_rate))}: {pass_rate_score:.4f}" - - message += ( - f", No-response: {results_df.no_response.mean() * 100:.1f}%" - f", Completion Length: {results_df.completion_len.mean():.2f} (±{results_df.completion_len.std():.2f}, ∈[{results_df.completion_len.min():.2f}, {results_df.completion_len.max():.2f}])" - f", Truncated: {results_df.is_truncated.mean() * 100:.1f}%)" - ) - get_logger().success(message) - - eval_metrics = { - f"avg@{group_size}": float(results_df.reward.mean()), - "no_response/mean": float(results_df.no_response.mean()), - "no_response/count": int(results_df.no_response.sum()), - "completion_len/mean": results_df.completion_len.mean().item(), - "completion_len/max": results_df.completion_len.max().item(), - "completion_len/min": results_df.completion_len.min().item(), - "is_truncated/mean": results_df.is_truncated.mean().item(), - "failed_rollouts": failed_count / total_rollouts, - "time": eval_time, - } - if could_be_binary: - assert pass_at_k is not None - eval_metrics.update(pd.Series(pass_at_k.mean()).to_dict()) - eval_metrics = {f"eval/{self.name}/{key}": v for key, v in eval_metrics.items()} - eval_metrics["step"] = step - monitor.log(eval_metrics, step=step) - monitor.log_eval_samples(successful_outputs, env_name=self.name, step=step) - - return successful_outputs + async def start(self, log_dir: Path, log_level: str | None = None, json_logging: bool = False) -> None: + await super().start(log_dir=log_dir, log_level=log_level, json_logging=json_logging) + n = self.num_tasks if self.config.num_examples < 0 else min(self.config.num_examples, self.num_tasks) + self.examples = [{"task_idx": i} for i in range(n)] EnvT = TypeVar("EnvT", bound=Env) @@ -350,23 +244,15 @@ def __iter__(self) -> Iterator[EnvT]: def __len__(self) -> int: return len(self._envs) - async def start( - self, - log_dir: Path, - log_level: str | None = None, - json_logging: bool = False, - ) -> None: - """Spawn env servers (where needed) and connect env clients one at a time. - - Serialized to avoid a TOCTOU port race: get_free_port() only holds the port - until it returns, so parallel spawns can hand the same port to two children. - """ + async def start(self, log_dir: Path, log_level: str | None = None, json_logging: bool = False) -> None: + """Spawn env servers (where needed) and connect, one at a time. Each server + binds an OS-assigned port and reports it back, so there's no port race.""" for env in self: await env.start(log_dir=log_dir, log_level=log_level, json_logging=json_logging) atexit.register(self.shutdown) def shutdown(self) -> None: - """Terminate all spawned env server processes in parallel.""" + """Terminate all spawned env server processes.""" processes = [env._env_server_process for env in self if env._env_server_process is not None] if not processes: return diff --git a/src/prime_rl/orchestrator/eval_sink.py b/src/prime_rl/orchestrator/eval_sink.py index 2c21841678..f6ceda5f58 100644 --- a/src/prime_rl/orchestrator/eval_sink.py +++ b/src/prime_rl/orchestrator/eval_sink.py @@ -18,7 +18,7 @@ from prime_rl.orchestrator.envs import EvalEnvs from prime_rl.orchestrator.eval_utils import compute_pass_at_k -from prime_rl.orchestrator.types import EvalBatch, EvalBatchMetrics, EvalRollout +from prime_rl.orchestrator.types import EvalBatch, EvalBatchMetrics, Rollout from prime_rl.utils.logger import get_logger @@ -27,12 +27,12 @@ class EvalSink: def __init__(self, *, eval_envs: EvalEnvs) -> None: self.eval_envs = eval_envs - self.pending_groups: dict[uuid.UUID, list[EvalRollout]] = defaultdict(list) + self.pending_groups: dict[uuid.UUID, list[Rollout]] = defaultdict(list) # Bucket size IS the arrival count — ``process_group`` flushes # everything in without filtering - self.pending_batches: dict[tuple[str, int], list[EvalRollout]] = defaultdict(list) + self.pending_batches: dict[tuple[str, int], list[Rollout]] = defaultdict(list) - def add(self, rollout: EvalRollout) -> EvalBatch | None: + def add(self, rollout: Rollout) -> EvalBatch | None: """Process one arrival; finalize the group on the ``group_size``-th arrival and the per-env epoch on the ``num_examples × group_size``-th.""" env_name = rollout.env_name @@ -82,7 +82,7 @@ def batch_progress(self) -> list[tuple[str, int, int, int, int]]: # ── level 1: per-rollout (no-op for eval) ───────────────────────────── - def process_rollout(self, rollout: EvalRollout) -> None: + def process_rollout(self, rollout: Rollout) -> None: """No-op. Eval rollouts don't need trainer-bound tokenization; the method exists to keep the three-level structure uniform with ``TrainSink``. @@ -96,17 +96,17 @@ def process_group(self, group_id: uuid.UUID) -> None: if not group: return env_name = group[0].env_name - example_id = group[0].example_id + task_idx = group[0].task.idx eval_step = group[0].eval_step bucket = self.pending_batches[(env_name, eval_step)] bucket.extend(group) - survivors = [r for r in group if r.error is None] + survivors = [r for r in group if not r.has_error] num_errored = len(group) - len(survivors) rewards = [r.reward for r in survivors] avg_reward = sum(rewards) / len(rewards) if rewards else 0.0 get_logger().debug( - f"Finished group | env={env_name} example_id={example_id} eval_step={eval_step} | " + f"Finished group | env={env_name} task_idx={task_idx} eval_step={eval_step} | " f"rollouts={len(group)} (errored={num_errored}) | reward={avg_reward:.4f}" ) @@ -120,9 +120,9 @@ def process_batch(self, key: tuple[str, int]) -> EvalBatch: rollouts = self.pending_batches.pop(key, []) n_total = len(rollouts) - n_cancelled = sum(1 for r in rollouts if (r.error or {}).get("error") == "Cancelled") - n_errored = sum(1 for r in rollouts if r.error is not None) - n_cancelled - valid = [r for r in rollouts if r.error is None] + n_cancelled = sum(1 for r in rollouts if r.has_error and r.error.type == "Cancelled") + n_errored = sum(1 for r in rollouts if r.has_error) - n_cancelled + valid = [r for r in rollouts if not r.has_error] metrics = EvalBatchMetrics( n_rollouts=n_total, n_cancelled=n_cancelled, @@ -131,23 +131,23 @@ def process_batch(self, key: tuple[str, int]) -> EvalBatch: if valid: rewards = [r.reward for r in valid] - lens = [r.raw["token_usage"]["final_output_tokens"] for r in valid] + lens = [r.completion_len for r in valid] metrics.group_size = self.group_size_for(env_name) metrics.reward_mean = float(sum(rewards) / len(rewards)) metrics.completion_len_mean = float(sum(lens) / len(lens)) metrics.completion_len_max = float(max(lens)) metrics.completion_len_min = float(min(lens)) metrics.truncation_rate = float(sum(1 for r in valid if r.is_truncated) / len(valid)) - metrics.no_response_rate = float(sum(1 for r in valid if not r.raw.get("completion")) / len(valid)) - num_turns = [len(r.raw.get("trajectory") or []) for r in valid] + metrics.no_response_rate = float(sum(1 for r in valid if not r.has_response) / len(valid)) + num_turns = [r.num_turns for r in valid] metrics.num_turns_mean = float(sum(num_turns) / len(num_turns)) metrics.num_turns_min = float(min(num_turns)) metrics.num_turns_max = float(max(num_turns)) # pass@k: errored attempts don't count toward k tries - by_example: dict[int | str, list[float]] = {} + by_example: dict[int, list[float]] = {} for r in valid: - by_example.setdefault(r.example_id, []).append(r.reward) + by_example.setdefault(r.task.idx, []).append(r.reward) metrics.n_examples = len(by_example) unique_rewards = {float(r) for r in rewards} if unique_rewards.issubset({0.0, 1.0}) and by_example: diff --git a/src/prime_rl/orchestrator/filters.py b/src/prime_rl/orchestrator/filters.py index f8deda1230..bd849d8200 100644 --- a/src/prime_rl/orchestrator/filters.py +++ b/src/prime_rl/orchestrator/filters.py @@ -16,7 +16,7 @@ from prime_rl.utils.logger import get_logger if TYPE_CHECKING: - from prime_rl.orchestrator.types import TrainRollout + from prime_rl.orchestrator.types import Rollout @dataclass @@ -29,7 +29,7 @@ class RolloutFilter(Protocol): name: str enforce: bool - def check(self, rollout: "TrainRollout") -> FilterResult: ... + def check(self, rollout: "Rollout") -> FilterResult: ... @dataclass @@ -49,13 +49,11 @@ class GibberishFilter: logprob_threshold: float enforce: bool = False - def check(self, rollout: "TrainRollout") -> FilterResult: + def check(self, rollout: "Rollout") -> FilterResult: global_idx = 0 - for step in rollout.raw["trajectory"]: - tokens = step["tokens"] - if tokens is None: - continue - for token_id, logprob in zip(tokens["completion_ids"], tokens["completion_logprobs"]): + for node in rollout.nodes: + completion = [t for t, m in zip(node.token_ids, node.mask) if m] + for token_id, logprob in zip(completion, node.logprobs): if token_id > self.token_id_threshold and logprob < self.logprob_threshold: return FilterResult(detected=True, detection_index=global_idx) global_idx += 1 @@ -79,14 +77,11 @@ class RepetitionFilter: logprob_threshold: float enforce: bool = False - def check(self, rollout: "TrainRollout") -> FilterResult: + def check(self, rollout: "Rollout") -> FilterResult: consecutive = 0 global_idx = 0 - for step in rollout.raw["trajectory"]: - tokens = step["tokens"] - if tokens is None: - continue - for logprob in tokens["completion_logprobs"]: + for node in rollout.nodes: + for logprob in node.logprobs: if logprob > self.logprob_threshold: consecutive += 1 else: @@ -105,7 +100,7 @@ class ZeroAdvantageFilter: name: str enforce: bool = True - def check(self, rollout: "TrainRollout") -> FilterResult: + def check(self, rollout: "Rollout") -> FilterResult: if rollout.advantage is not None and rollout.advantage == 0.0: return FilterResult(detected=True) return FilterResult(detected=False) @@ -147,8 +142,8 @@ def setup_filters(configs: list[FilterConfig], vocab_size: int, *, kind: str) -> return filters -def apply_filters(filters: list[RolloutFilter], rollouts: list["TrainRollout"]) -> None: # noqa: F821 (forward ref) - """Flag ``TrainRollout``\\ s in place with per-filter detection + drop decision. +def apply_filters(filters: list[RolloutFilter], rollouts: list["Rollout"]) -> None: # noqa: F821 (forward ref) + """Flag ``Rollout``\\ s in place with per-filter detection + drop decision. Each rollout's ``filter_results`` dict records per-filter detection bools; ``is_filtered`` is True iff an enforcing filter detected it. First matching diff --git a/src/prime_rl/orchestrator/metrics.py b/src/prime_rl/orchestrator/metrics.py index 87ec99c424..459a674268 100644 --- a/src/prime_rl/orchestrator/metrics.py +++ b/src/prime_rl/orchestrator/metrics.py @@ -7,7 +7,7 @@ import pandas as pd from prime_rl.configs.orchestrator import OrchestratorConfig -from prime_rl.orchestrator.types import Progress, TrainBatchMetrics, TrainRollout +from prime_rl.orchestrator.types import Progress, Rollout, TrainBatchMetrics class MetricsBuilder: @@ -18,7 +18,7 @@ def build( self, *, step: int, - rollouts: list[TrainRollout], + rollouts: list[Rollout], metrics: TrainBatchMetrics, progress: Progress, step_time: float, @@ -32,30 +32,25 @@ def build( existing dashboards / alerts keep working.""" num_rollouts = len(rollouts) num_unique_examples = len({r.group_id for r in rollouts}) - num_tokens = sum( - r.raw["token_usage"]["final_input_tokens"] + r.raw["token_usage"]["final_output_tokens"] for r in rollouts - ) + num_tokens = sum(r.total_tokens for r in rollouts) results_df = pd.DataFrame( { "group_id": [r.group_id for r in rollouts], - "example_id": [r.example_id for r in rollouts], + "task_idx": [r.task.idx for r in rollouts], "env_name": [r.env_name for r in rollouts], "reward": [r.reward for r in rollouts], "is_truncated": [r.is_truncated for r in rollouts], "is_filtered": [r.is_filtered for r in rollouts], - "stop_condition": [r.raw.get("stop_condition") for r in rollouts], - "seq_len": [ - r.raw["token_usage"]["final_input_tokens"] + r.raw["token_usage"]["final_output_tokens"] - for r in rollouts - ], + "stop_condition": [r.stop_condition for r in rollouts], + "seq_len": [r.total_tokens for r in rollouts], "prefill_len": metrics.rollout_prefill_lens, "decode_len": metrics.rollout_decode_lens, "samples_per_rollout": metrics.samples_per_rollout, - "num_turns": [len(r.raw["trajectory"]) for r in rollouts], + "num_turns": [r.num_turns for r in rollouts], } ) - metrics_df = pd.DataFrame([(r.raw.get("metrics") or {}) for r in rollouts]) + metrics_df = pd.DataFrame([r.metrics for r in rollouts]) filter_df = pd.DataFrame([r.filter_results for r in rollouts]) timing_df = self.timing_df(rollouts) @@ -182,18 +177,11 @@ def compute_solve_rates(df): return to_log @staticmethod - def timing_df(rollouts: list[TrainRollout]) -> pd.DataFrame: - return pd.DataFrame( - [ - { - "total": r.raw["timing"]["total"], - "setup": r.raw["timing"]["setup"]["duration"], - "generation": r.raw["timing"]["generation"]["duration"], - "model": r.raw["timing"]["model"]["duration"], - "env": r.raw["timing"]["env"]["duration"], - "scoring": r.raw["timing"]["scoring"]["duration"], - "overhead": r.raw["timing"]["overhead"], - } - for r in rollouts - ] - ) + def timing_df(rollouts: list[Rollout]) -> pd.DataFrame: + """Per-rollout timing from the v1 Trace (`generation`/`scoring` spans).""" + rows = [] + for r in rollouts: + timing = r.timing + generation, scoring = timing.generation.duration, timing.scoring.duration + rows.append({"total": generation + scoring, "generation": generation, "scoring": scoring}) + return pd.DataFrame(rows) diff --git a/src/prime_rl/orchestrator/orchestrator.py b/src/prime_rl/orchestrator/orchestrator.py index 902c8b963b..575ad576b6 100644 --- a/src/prime_rl/orchestrator/orchestrator.py +++ b/src/prime_rl/orchestrator/orchestrator.py @@ -22,7 +22,6 @@ import asyncio import ctypes -import logging import os import time from typing import TYPE_CHECKING @@ -37,8 +36,6 @@ from prime_rl.transport.base import TrainingBatchSender from prime_rl.utils.client import InferencePool from prime_rl.utils.monitor.base import Monitor -from verifiers.utils.async_utils import EventLoopLagMonitor, EventLoopLagStats - import prime_rl._compat # noqa: F401 — patch ring_flash_attn compat before transitive imports from prime_rl.configs.orchestrator import OrchestratorConfig from prime_rl.orchestrator.ckpt import setup_ckpt_manager @@ -58,12 +55,10 @@ from prime_rl.orchestrator.train_source import TrainSource from prime_rl.orchestrator.types import ( EvalBatch, - EvalRollout, - FinishedRollout, Policy, Progress, + Rollout, TrainBatch, - TrainRollout, ) from prime_rl.orchestrator.utils import ( compute_teacher_logprobs, @@ -76,7 +71,7 @@ from prime_rl.orchestrator.watcher import WeightWatcher from prime_rl.trainer.model import setup_tokenizer from prime_rl.transport import TrainingBatch, setup_training_batch_sender -from prime_rl.utils.async_utils import safe_cancel +from prime_rl.utils.async_utils import EventLoopLagMonitor, EventLoopLagStats, safe_cancel from prime_rl.utils.client import init_nccl_broadcast, setup_inference_pool from prime_rl.utils.heartbeat import Heartbeat from prime_rl.utils.logger import format_time, get_logger, setup_logger @@ -85,8 +80,6 @@ from prime_rl.utils.usage_reporter import UsageReporter from prime_rl.utils.utils import ( clean_exit, - get_env_ids_to_install, - install_env, resolve_latest_ckpt_step, ) @@ -153,10 +146,9 @@ class Orchestrator: def __init__(self, config: OrchestratorConfig) -> None: self.config = config setup_logger(config.log.level, json_logging=config.log.json_logging) - # Silence in-process ``verifiers.*`` library noise but keep - # ``verifiers.serve`` (env-server lifecycle) through our handler - logging.getLogger("verifiers").setLevel(logging.CRITICAL + 1) - intercept_vf_logging(logger="verifiers.serve", level="WARN") + # Route the in-process v1 library logging through our handler. The + # env server runs in a child process, so its logging is separate. + intercept_vf_logging(logger="verifiers.v1", level="WARN") get_logger().info(f"Starting orchestrator ({config.training_mode})") if config.bench: @@ -206,11 +198,9 @@ async def setup(self) -> None: with open(config_dir / "orch.toml", "wb") as f: tomli_w.dump(config.model_dump(exclude_none=True, mode="json"), f) - env_ids_to_install = set(get_env_ids_to_install(config.train.env)) - if config.eval is not None: - env_ids_to_install.update(get_env_ids_to_install(config.eval.env)) - for env_id in env_ids_to_install: - install_env(env_id, prerelease=config.env_install_prerelease) + # TODO(v1, experimental): temporary. v1 envs are local packages + # installed in this venv (no prime-env hub install); the env server imports + # them in its own child process. get_logger().info(f"Initializing tokenizer ({config.tokenizer})") self.tokenizer = setup_tokenizer(config.tokenizer) @@ -237,7 +227,11 @@ async def setup(self) -> None: self.teacher_inference = await setup_inference_pool( config.teacher.client, model_name=config.teacher.model.name, - train_client_type="openai_chat_completions", + # SFT rolls the teacher out through the renderer client (token-in/out) so its + # rollouts carry tokens directly — training is renderer-only. (OPD reads teacher + # logprobs via prefill, so its pool client type is moot.) + train_client_type="renderer", + renderer_config=config.renderer, ) get_logger().info(f"Initializing monitor (wandb={config.wandb}, prime_monitor={config.prime_monitor})") @@ -384,7 +378,6 @@ async def setup(self) -> None: self.train_sink = TrainSink( config, tokenizer=self.tokenizer, - renderer=self.renderer, train_envs=self.train_envs, mm_token_type_ids_mapping=self.mm_token_type_ids_mapping, batch_size=config.batch_size, @@ -491,18 +484,17 @@ async def main_loop(self) -> None: break try: - rollout: FinishedRollout = await asyncio.wait_for(self.dispatcher.out_q.get(), timeout=0.5) + rollout: Rollout = await asyncio.wait_for(self.dispatcher.out_q.get(), timeout=0.5) except asyncio.TimeoutError: continue - if isinstance(rollout, EvalRollout): + if rollout.kind == "eval": assert self.eval_sink is not None # eval rollouts only emitted when eval is configured eval_batch = self.eval_sink.add(rollout) if eval_batch is not None: self.finalize_eval_batch(eval_batch) continue - assert isinstance(rollout, TrainRollout) train_batch = await self.train_sink.add(rollout) # In drain mode any late-arriving train batch is dropped — we # don't want to ship past ``max_steps`` @@ -555,13 +547,10 @@ async def finalize_train_batch(self, batch: TrainBatch) -> None: f"({batch.metrics.n_trainable / len(batch.rollouts):.1%}) — consider reviewing task difficulty / filter config" ) - # Materialize at the I/O boundary so prime-rl metadata travels with - # the raw vf payload on disk + in wandb sample tables - rollout_dicts = [r.to_dict() for r in batch.rollouts] + # Serialize the typed Trace at the I/O boundary (disk + wandb sample tables). + rollout_dicts = [r.model_dump(mode="json") for r in batch.rollouts] step_path = get_step_path(get_rollout_dir(config.output_dir), step) - await asyncio.to_thread( - save_rollouts, rollout_dicts, step_path / "train_rollouts.jsonl", exclude_keys={"trajectory"} - ) + await asyncio.to_thread(save_rollouts, rollout_dicts, step_path / "train_rollouts.jsonl") teacher_logprobs_time = 0.0 # opd only if config.training_mode == "opd" and self.teacher_inference is not None: @@ -592,7 +581,7 @@ async def finalize_train_batch(self, batch: TrainBatch) -> None: pre_filter_dropped_by_name=dict(self.train_sink.pre_filter_dropped_by_name), ) self.monitor.log(metrics, step=step) - self.monitor.log_samples(rollout_dicts, step=step) + self.monitor.log_samples(batch.rollouts, step=step) self.monitor.log_distributions( distributions={ "rewards": [r.reward for r in batch.rollouts], @@ -614,10 +603,7 @@ async def finalize_train_batch(self, batch: TrainBatch) -> None: num_rollouts = len(batch.rollouts) num_unique_examples = len({r.group_id for r in batch.rollouts}) - num_tokens = sum( - r.raw["token_usage"]["final_input_tokens"] + r.raw["token_usage"]["final_output_tokens"] - for r in batch.rollouts - ) + num_tokens = sum(r.total_tokens for r in batch.rollouts) self.progress.total_tokens += num_tokens self.progress.total_samples += num_rollouts self.progress.total_problems += num_unique_examples @@ -724,7 +710,7 @@ def log_train_batch(self, batch: TrainBatch, *, step: int, step_time: float) -> trainable_rate = (n_trainable / n_survivors) if n_survivors else 0.0 reward_mean = sum(r.reward for r in batch.rollouts) / max(n_survivors, 1) max_off_policy = max((r.off_policy_steps for r in batch.rollouts), default=0) - turns_mean = sum(len(r.raw.get("trajectory") or []) for r in batch.rollouts) / max(n_survivors, 1) + turns_mean = sum(r.num_turns for r in batch.rollouts) / max(n_survivors, 1) truncation_rate = sum(1 for r in batch.rollouts if r.is_truncated) / max(n_survivors, 1) head = ( @@ -748,11 +734,7 @@ def log_train_batch(self, batch: TrainBatch, *, step: int, step_time: float) -> env_error_rate = (n_env_errors / n_env_arrivals) if n_env_arrivals else 0.0 env_reward = (sum(r.reward for r in env_rollouts) / len(env_rollouts)) if env_rollouts else 0.0 env_max_off_policy = max((r.off_policy_steps for r in env_rollouts), default=0) - env_turns = ( - sum(len(r.raw.get("trajectory") or []) for r in env_rollouts) / len(env_rollouts) - if env_rollouts - else 0.0 - ) + env_turns = sum(r.num_turns for r in env_rollouts) / len(env_rollouts) if env_rollouts else 0.0 env_truncation = sum(1 for r in env_rollouts if r.is_truncated) / len(env_rollouts) if env_rollouts else 0.0 lines.append( f"╰─ {env_name:<{name_width}} | Ratio {ratio:.1%} | Reward {env_reward:.4f} | " @@ -768,14 +750,13 @@ def finalize_eval_batch(self, batch: EvalBatch) -> None: get_logger().warning(f"Eval @ step={batch.step} env={batch.env_name}: no surviving rollouts, skipping log") return - rollout_dicts = [r.to_dict() for r in batch.rollouts] + rollout_dicts = [r.model_dump(mode="json") for r in batch.rollouts] step_path = get_step_path(get_rollout_dir(self.config.output_dir), batch.step) save_rollouts( rollout_dicts, step_path / f"eval_rollouts_{batch.env_name}.jsonl", - exclude_keys={"trajectory"}, ) - self.monitor.log_eval_samples(rollout_dicts, env_name=batch.env_name, step=batch.step) + self.monitor.log_eval_samples(batch.rollouts, env_name=batch.env_name, step=batch.step) self.monitor.log(batch.metrics.to_wandb_dict(env_name=batch.env_name, step=batch.step), step=batch.step) n_total = batch.metrics.n_rollouts diff --git a/src/prime_rl/orchestrator/train_sink.py b/src/prime_rl/orchestrator/train_sink.py index 26e7b915b0..d28cfcb525 100644 --- a/src/prime_rl/orchestrator/train_sink.py +++ b/src/prime_rl/orchestrator/train_sink.py @@ -21,12 +21,8 @@ from prime_rl.orchestrator.advantage import assign_advantages, setup_advantage_fn from prime_rl.orchestrator.envs import TrainEnvs from prime_rl.orchestrator.filters import RolloutFilter, apply_filters -from prime_rl.orchestrator.trajectories import ( - backfill_rollout_tokens, - interleave_rollout, - offload_images_to_disk, -) -from prime_rl.orchestrator.types import TrainBatch, TrainBatchMetrics, TrainRollout +from prime_rl.orchestrator.trajectories import trace_to_samples +from prime_rl.orchestrator.types import Rollout, TrainBatch, TrainBatchMetrics from prime_rl.transport import TrainingSample from prime_rl.utils.logger import get_logger @@ -39,7 +35,6 @@ def __init__( config: OrchestratorConfig, *, tokenizer, - renderer, train_envs: TrainEnvs, mm_token_type_ids_mapping: dict[int, int] | None, batch_size: int | None, @@ -53,7 +48,6 @@ def __init__( ) self.config = config self.tokenizer = tokenizer - self.renderer = renderer self.train_envs = train_envs self.mm_token_type_ids_mapping = mm_token_type_ids_mapping self.batch_size = batch_size @@ -64,11 +58,11 @@ def __init__( self.pre_filters = pre_filters self.post_filters = post_filters - # Keyed by the dispatcher's group UUID. ``(env_name, example_id)`` - # isn't unique — the same example can be re-sampled while an + # Keyed by the dispatcher's group UUID. ``(env_name, task_idx)`` + # isn't unique — the same task can be re-sampled while an # earlier group is still in flight - self.pending_groups: dict[uuid.UUID, list[TrainRollout]] = defaultdict(list) - self.pending_batch: list[TrainRollout] = [] + self.pending_groups: dict[uuid.UUID, list[Rollout]] = defaultdict(list) + self.pending_batch: list[Rollout] = [] # Reset by the orchestrator after each ship via ``reset_pre_filter_stats`` self.pre_filter_seen = 0 @@ -83,13 +77,13 @@ def __init__( def group_size_for(self, env_name: str) -> int: return self.train_envs.get(env_name).config.group_size - def in_progress_groups(self) -> list[list[TrainRollout]]: + def in_progress_groups(self) -> list[list[Rollout]]: """Per-rollout groups currently accumulating in ``pending_groups`` — i.e. groups that haven't hit ``group_size`` yet, so the pipeline log can reflect partial-group progress. Skips group-scoring envs (whose rollouts only make sense as a unit — the user expects per-group fill, not per-rollout, for those).""" - out: list[list[TrainRollout]] = [] + out: list[list[Rollout]] = [] for rollouts in self.pending_groups.values(): if not rollouts: continue @@ -107,10 +101,7 @@ def batch_progress(self) -> tuple[int, int, str]: if self.batch_size is not None: return len(self.pending_batch), self.batch_size, "rollouts" assert self.token_batch_size is not None - tokens = sum( - r.raw["token_usage"]["final_input_tokens"] + r.raw["token_usage"]["final_output_tokens"] - for r in self.pending_batch - ) + tokens = sum(r.total_tokens for r in self.pending_batch) return tokens, self.token_batch_size, "tokens" def buffered_count(self) -> int: @@ -126,13 +117,13 @@ def pending_batch_by_env(self) -> dict[str, int]: counts[r.env_name] += 1 return dict(counts) - async def add(self, rollout: TrainRollout) -> TrainBatch | None: + async def add(self, rollout: Rollout) -> TrainBatch | None: """Process one arrival; finalize the group on the ``group_size``-th arrival; return a ``TrainBatch`` if the batch threshold is met.""" await self.process_rollout(rollout) env_name = rollout.env_name self.arrivals_by_env[env_name] += 1 - if rollout.error is not None: + if rollout.has_error: self.errors_by_env[env_name] += 1 self.pending_groups[rollout.group_id].append(rollout) if len(self.pending_groups[rollout.group_id]) >= self.group_size_for(env_name): @@ -140,37 +131,26 @@ async def add(self, rollout: TrainRollout) -> TrainBatch | None: ready = ( len(self.pending_batch) >= self.batch_size if self.batch_size is not None - else sum( - r.raw["token_usage"]["final_input_tokens"] + r.raw["token_usage"]["final_output_tokens"] - for r in self.pending_batch - ) - >= (self.token_batch_size or 0) + else sum(r.total_tokens for r in self.pending_batch) >= (self.token_batch_size or 0) ) if ready: return self.process_batch() return None - async def process_rollout(self, rollout: TrainRollout) -> None: - """Tokenize the rollout eagerly. Backfills tokens if the env didn't - return them (SFT against external teacher APIs); errored rollouts - skip tokenization and get dropped at the group level.""" - if rollout.error is not None: + async def process_rollout(self, rollout: Rollout) -> None: + """Build training samples from the rollout's Trace (one per branch), walking the + message graph. Training is renderer-only across all modes (RL/OPD student, SFT teacher), + so every node already carries its tokens. Errored rollouts are dropped at the group + level, so skip them here.""" + if rollout.has_error: return - raw = rollout.raw - needs_backfill = any(s["tokens"] is None for s in raw.get("trajectory") or []) - if needs_backfill: - await asyncio.to_thread(backfill_rollout_tokens, raw, self.tokenizer, renderer=self.renderer) samples = await asyncio.to_thread( - interleave_rollout, - raw, - mm_token_type_ids_mapping=self.mm_token_type_ids_mapping, + trace_to_samples, + rollout, env_name=rollout.env_name, + mm_token_type_ids_mapping=self.mm_token_type_ids_mapping, ) rollout.samples = samples or [] - # Offload base64 image bytes to disk as soon as the rollout is - # tokenized, so memory stays flat instead of holding every buffered - # rollout's images until the batch ships (no-op for text-only). - await asyncio.to_thread(offload_images_to_disk, [raw], self.config.output_dir) def process_group(self, group_id: uuid.UUID) -> None: """Finalize one GRPO group: drop errored rollouts (the whole group @@ -180,8 +160,8 @@ def process_group(self, group_id: uuid.UUID) -> None: if not group: return env_name = group[0].env_name - example_id = group[0].example_id - survivors = [r for r in group if r.error is None] + task_idx = group[0].task.idx + survivors = [r for r in group if not r.has_error] num_errored = len(group) - len(survivors) # Group-scoring envs: any failure makes survivors' rewards unsafe @@ -189,13 +169,13 @@ def process_group(self, group_id: uuid.UUID) -> None: env = self.train_envs.get(env_name) if num_errored > 0 and env.requires_group_scoring: get_logger().debug( - f"Finished group | env={env_name} example_id={example_id} | " + f"Finished group | env={env_name} task_idx={task_idx} | " f"rollouts={len(group)} (errored={num_errored}) | dropped: group-scored partial" ) return if not survivors: get_logger().debug( - f"Finished group | env={env_name} example_id={example_id} | " + f"Finished group | env={env_name} task_idx={task_idx} | " f"rollouts={len(group)} (errored={num_errored}) | dropped: all failed" ) return @@ -204,8 +184,8 @@ def process_group(self, group_id: uuid.UUID) -> None: # Propagate to the pre-tokenized samples so the orchestrator can # collect samples at ship time without re-walking rollouts. The env - # has a single sampling temperature; fan it out across each sample's - # completion tokens here (interleave leaves it empty). + # has a single sampling temperature; fan it out per token (context + # tokens are masked out, so their temperature is don't-care). temperature = env.sampling_args["temperature"] for r in survivors: for sample in r.samples: @@ -213,7 +193,7 @@ def process_group(self, group_id: uuid.UUID) -> None: sample.reward = r.reward sample.env_name = r.env_name sample.training_mode = self.config.training_mode - sample.completion_temperatures = [temperature] * len(sample.completion_ids) + sample.temperatures = [temperature] * len(sample.token_ids) if self.pre_filters: apply_filters(self.pre_filters, survivors) @@ -240,7 +220,7 @@ def process_group(self, group_id: uuid.UUID) -> None: avg_reward = sum(rewards) / len(rewards) if rewards else 0.0 filter_str = ", ".join(f"{n}={c}" for n, c in filtered_by_name.items()) if filtered_by_name else "—" get_logger().debug( - f"Finished group | env={env_name} example_id={example_id} | " + f"Finished group | env={env_name} task_idx={task_idx} | " f"rollouts={len(group)} (errored={num_errored}, filtered={num_filtered}) | " f"reward={avg_reward:.4f} | filters: {filter_str}" ) @@ -259,7 +239,7 @@ def process_batch(self) -> TrainBatch: cut = 0 running = 0 for i, r in enumerate(self.pending_batch): - running += r.raw["token_usage"]["final_input_tokens"] + r.raw["token_usage"]["final_output_tokens"] + running += r.total_tokens cut = i + 1 if running >= self.token_batch_size: break @@ -282,8 +262,8 @@ def process_batch(self) -> TrainBatch: prefill = 0 decode = 0 for sample in r.samples: - sample_decode = sum(sample.completion_mask) - sample_prefill = len(sample.prompt_ids) + len(sample.completion_mask) - sample_decode + sample_decode = sum(sample.mask) + sample_prefill = len(sample.token_ids) - sample_decode decode += sample_decode prefill += sample_prefill if not r.is_filtered: diff --git a/src/prime_rl/orchestrator/train_source.py b/src/prime_rl/orchestrator/train_source.py index db439f7539..35c8165a89 100644 --- a/src/prime_rl/orchestrator/train_source.py +++ b/src/prime_rl/orchestrator/train_source.py @@ -14,7 +14,7 @@ class TrainSource: """``next_example(available_permits)`` picks a weighted-RR env and returns its next example (or ``None`` when the env's per-call permit cost doesn't fit — the dispatch loop retries when permits free up). - Returned dicts carry ``env_name`` + ``example_id``.""" + Returned dicts carry ``env_name`` + ``task_idx``.""" def __init__(self, train_envs: TrainEnvs, *, seed: int | None) -> None: self.rng = random.Random(seed) @@ -28,11 +28,9 @@ def __init__(self, train_envs: TrainEnvs, *, seed: int | None) -> None: # per-rollout envs need 1 self.env_costs: dict[str, int] = {} for env in self.envs: - rows: list[dict] = [] - for row in env.get_dataset(seed=seed): - ex = dict(row) - ex["env_name"] = env.name - rows.append(ex) + # The orchestrator never loads the env: sample over the task-index + # range the server reported via info() (num_tasks). + rows: list[dict] = [{"task_idx": i, "env_name": env.name} for i in range(env.num_tasks)] self.rng.shuffle(rows) self.examples[env.name] = rows self.cursors[env.name] = 0 diff --git a/src/prime_rl/orchestrator/trajectories.py b/src/prime_rl/orchestrator/trajectories.py index 3e8431c12a..39522c45d8 100644 --- a/src/prime_rl/orchestrator/trajectories.py +++ b/src/prime_rl/orchestrator/trajectories.py @@ -1,652 +1,97 @@ -import base64 -import hashlib -from pathlib import Path -from typing import Any +"""Turn a v1 `Trace` (the env server's native, typed output) into training data. -import numpy as np -import pybase64 -import torch -import verifiers as vf -from transformers.tokenization_utils import PreTrainedTokenizer - -from prime_rl.transport import RoutedExperts, TrainingSample -from prime_rl.utils.chat_template import ( - common_prefix_len, - deserialize_tool_calls, - normalize_messages, - render_messages, - strip_message_content, -) -from prime_rl.utils.logger import get_logger - -# We use list() instead of deepcopy() for flat lists (token IDs, logprobs) - safe because -# primitives are immutable. mm_kwargs payloads are not mutated after creation. - - -def align_routed_experts( - routed_experts: np.ndarray | None, - expected_len: int, -) -> np.ndarray | None: - """Align routed_experts length with the expected token count. - - VLLM's capturer uses `num_tokens - 1` slot mappings because the final - generated token was never fed as input to a forward pass and has no - routing decision. Append zero-filled entries for the missing positions. - """ - if routed_experts is None: - return routed_experts - assert routed_experts.ndim == 3 - if routed_experts.shape[0] > expected_len: - return np.ascontiguousarray(routed_experts[:expected_len]) - deficit = expected_len - routed_experts.shape[0] - if deficit <= 0: - return routed_experts - padding = np.zeros((deficit, routed_experts.shape[1], routed_experts.shape[2]), dtype=routed_experts.dtype) - return np.concatenate((routed_experts, padding), axis=0) - - -def _common_prefix_len(a: list[int], b: list[int]) -> int: - return common_prefix_len(a, b) - - -def _normalize_messages(messages: Any, default_role: str) -> list[dict[str, Any]]: - return normalize_messages(messages, default_role) - - -def _deserialize_tool_calls(messages: list[dict[str, Any]]) -> list[dict[str, Any]]: - return deserialize_tool_calls(messages) - - -def _strip_message_content(messages: list[dict[str, Any]]) -> list[dict[str, Any]]: - return strip_message_content(messages) - - -def _render_messages( - tokenizer: PreTrainedTokenizer, - messages: list[dict[str, Any]], - add_generation_prompt: bool = False, - tools: list[dict[str, Any]] | None = None, -) -> list[int]: - return render_messages( - tokenizer, - messages, - add_generation_prompt=add_generation_prompt, - tools=tools, - ) - - -def _tokenize_step_from_messages( - step: vf.TrajectoryStep, - tokenizer: PreTrainedTokenizer, - tools: list[dict[str, Any]] | None = None, -) -> dict[str, Any]: - prompt = _normalize_messages(step.get("prompt"), default_role="user") - completion = _normalize_messages(step.get("completion"), default_role="assistant") - - prompt = _strip_message_content(_deserialize_tool_calls(prompt)) - completion = _strip_message_content(_deserialize_tool_calls(completion)) - - assert all(m.get("role") == "assistant" for m in completion), ( - "Expected all completion messages to be assistant role for SFT distillation, " - f"got roles: {[m.get('role') for m in completion]}" - ) - - all_messages = prompt + completion - prompt_has_assistant_completion = len(completion) > 0 and completion[0].get("role") == "assistant" - prompt_ids = _render_messages( - tokenizer, - prompt, - add_generation_prompt=prompt_has_assistant_completion, - tools=tools, - ) - full_ids = _render_messages( - tokenizer, - all_messages, - tools=tools, - ) - - split_idx = _common_prefix_len(prompt_ids, full_ids) - original_prompt_len = len(prompt_ids) - - prompt_ids = full_ids[:split_idx] - completion_ids = full_ids[split_idx:] - completion_mask = [True] * len(completion_ids) - completion_logprobs = [0.0] * len(completion_ids) - - return { - "prompt_ids": prompt_ids, - "prompt_mask": [False] * len(prompt_ids), - "completion_ids": completion_ids, - "completion_mask": completion_mask, - "completion_logprobs": completion_logprobs, - "routed_experts": None, - "prompt_prefix_len": split_idx, - "original_prompt_len": original_prompt_len, - } +The orchestrator holds a real `vf.Trace` (validated in `envs.py`), so everything here is +attribute access — no dicts. The trace is a message graph (`trace.nodes`); each `trace.branches` +entry (a root→leaf path) is first-class and carries its own flat token sequence +(`branch.token_ids` / `branch.sampled_mask` / `branch.logprobs`), so a branch yields one +training sample directly. Token-length readers (`completion_len`, `total_tokens`, `num_turns`) +live on `vf.Trace` itself. +Training is renderer-only across every mode (RL/OPD student, SFT teacher), so every node +always carries its tokens — no backfill needed. For multimodal rollouts the branch also carries +the images it introduced (`branch.multi_modal_data`), rebuilt here into the flat `mm_kwargs` / +`mm_token_type_ids` the trainer forwards. +""" -def _convert_tools_to_oai_format(tool_defs: list) -> list[dict[str, Any]] | None: - """Convert verifiers Tool objects or dicts to OAI function-calling format.""" - if not tool_defs: - return None +from __future__ import annotations - def _get(tool: Any, key: str) -> Any: - if isinstance(tool, dict): - return tool.get(key) - return getattr(tool, key, None) - - return [ - { - "type": "function", - "function": { - "name": _get(tool, "name"), - "description": _get(tool, "description"), - "parameters": _get(tool, "parameters"), - **({} if _get(tool, "strict") is None else {"strict": _get(tool, "strict")}), - }, - } - for tool in tool_defs - ] - - -def _tokenize_step_with_renderer( - step: vf.TrajectoryStep, - renderer, - tools: list[dict[str, Any]] | None = None, -) -> dict[str, Any]: - """Tokenize a trajectory step using a Renderer.""" - from renderers.base import build_trajectory_step - - prompt = _normalize_messages(step.get("prompt"), default_role="user") - completion = _normalize_messages(step.get("completion"), default_role="assistant") - prompt = _strip_message_content(_deserialize_tool_calls(prompt)) - completion = _strip_message_content(_deserialize_tool_calls(completion)) - return build_trajectory_step(renderer, prompt, completion, tools=tools) - - -def backfill_rollout_tokens( - output: vf.RolloutOutput, - tokenizer: PreTrainedTokenizer, - renderer=None, -) -> bool: - """Populate missing step tokens from prompt/completion messages. - - When a renderer is provided, uses it for tokenization (faster, deterministic). - Otherwise falls back to the tokenizer + apply_chat_template path. - """ - if all(step["tokens"] is not None for step in output["trajectory"]): - return True - - logger = get_logger() - tools = _convert_tools_to_oai_format(output.get("tool_defs", [])) - - for step_idx, step in enumerate(output["trajectory"]): - if step["tokens"] is not None: - continue - - if renderer is not None: - step["tokens"] = _tokenize_step_with_renderer(step, renderer, tools=tools) - else: - reconstructed = _tokenize_step_from_messages(step, tokenizer, tools=tools) - if reconstructed["prompt_prefix_len"] < reconstructed["original_prompt_len"]: - logger.debug( - f"Prompt tokenization was non-prefix for example {output['example_id']} step {step_idx}. " - f"Using longest common prefix length {reconstructed['prompt_prefix_len']} " - f"(original prompt had {reconstructed['original_prompt_len']} tokens)." - ) - reconstructed.pop("prompt_prefix_len") - reconstructed.pop("original_prompt_len") - step["tokens"] = reconstructed +import numpy as np +import verifiers.v1 as vf - return True +from prime_rl.transport import TrainingSample +from prime_rl.transport.types import EncodedTensor +from prime_rl.utils.logger import get_logger -def interleave_rollout( - output: vf.RolloutOutput, - mm_token_type_ids_mapping: dict[int, int] | None = None, +def _to_numpy(val) -> np.ndarray: + """A renderer mm item value (torch tensor or numpy array) -> a contiguous numpy array.""" + if hasattr(val, "detach"): # torch tensor + val = val.detach().cpu().numpy() + return np.ascontiguousarray(val) + + +def _encode_mm_kwargs(mm_items: dict[str, list[dict]]) -> dict[str, EncodedTensor] | None: + """Concatenate the branch's per-image renderer items into the flat `mm_kwargs` the trainer + forwards — one `EncodedTensor` per kwarg key (e.g. `pixel_values`, `image_grid_thw`), images + cat'd along dim 0 in branch token order. Model-agnostic: the keys are whatever the processor + emits. Returns None when there are no items.""" + bins: dict[str, list[np.ndarray]] = {} + for items in mm_items.values(): # per modality + for item in items: # per image + for key, val in item.items(): + bins.setdefault(key, []).append(_to_numpy(val)) + encoded: dict[str, EncodedTensor] = {} + for key, arrs in bins.items(): + arr = np.concatenate(arrs, axis=0) + encoded[key] = EncodedTensor(dtype=str(arr.dtype), shape=list(arr.shape), data=arr.tobytes()) + return encoded or None + + +def trace_to_samples( + trace: vf.Trace, *, env_name: str = "", -) -> list[TrainingSample] | None: - """ - Convert vf.RolloutOutput to trainable rollouts by interleaving trajectory steps - where the extension property holds. - - When consecutive steps share token prefixes (extension property), they are - merged into a single sample. When extension breaks (e.g., due to context - compaction or a change in control-flow), a new sample is started. - - Supports multi-prefix matching to handle interleaved agents. For example, - [agent1-step1, agent1-step2, agent2-step1, agent1-step3] produces two samples: - agent1 steps merged together, agent2 step separate. - - Returns a list of samples - could be 1 (extension always held) or up to T - (extension never held). - - For VLM models, each renderer-produced trajectory step carries its - per-image processed tensors inline on ``multi_modal_data``; the last - merged step's sidecar covers every image in the sample. + mm_token_type_ids_mapping: dict[int, int] | None = None, +) -> list[TrainingSample]: + """Convert a v1 `Trace` into `TrainingSample`s — one per branch. + + Each `trace.branches` entry is already a flat token sequence (`branch.token_ids` / + `branch.sampled_mask` / `branch.logprobs`), so a sample carries it directly: `mask` marks + the trainable (model-sampled) tokens, the context tokens between completions stay masked + out. On a rollout error the whole completion is masked out. A branch carrying images also + gets `mm_kwargs` (the concatenated pixel tensors) and `mm_token_type_ids` (the renderer's + `mm_token_type_id_map` applied to the branch tokens). Branches with no sampled tokens + (e.g. an openai client carrying none) yield nothing. """ - logger = get_logger() - - trajectory = output["trajectory"] - if len(trajectory) == 0: - error = output.get("error") - stop = output.get("stop_condition") - logger.warning( - f"No trajectory steps for example {output['example_id']} (error={error}, stop={stop}). Skipping rollout." - ) - return None - - has_error = output["error"] is not None - - def prepare_step_tokens(step: vf.TrajectoryStep, step_idx: int) -> dict[str, Any] | None: - tokens = step["tokens"] - if tokens is not None: - routed_experts_payload = tokens.get("routed_experts") - routed_experts = None - routed_experts_start = None - if routed_experts_payload is not None: - decoded_routed_experts = pybase64.b64decode_as_bytearray(routed_experts_payload["data"]) - routed_experts = np.frombuffer(decoded_routed_experts, dtype=np.uint8).reshape( - routed_experts_payload["shape"] - ) - routed_experts_start = routed_experts_payload["start"] - - return { - "prompt_ids": list(tokens["prompt_ids"]), - "prompt_mask": list(map(bool, tokens["prompt_mask"])), - "completion_ids": list(tokens["completion_ids"]), - "completion_mask": list(map(bool, tokens["completion_mask"])), - "completion_logprobs": list(tokens["completion_logprobs"]), - "routed_experts": routed_experts, - "routed_experts_start": routed_experts_start, - # Renderer-emitted multimodal sidecar (placeholders + per-item - # processed tensors). Populated when the rollout went through - # a multimodal-aware renderer (e.g. Qwen3VLRenderer); absent - # for text-only rollouts. - "multi_modal_data": tokens.get("multi_modal_data"), - } - - logger.warning(f"Missing rollout tokens for example {output['example_id']} step {step_idx}.") - return None - - prepared_steps: list[dict[str, Any]] = [] - for step_idx, step in enumerate(trajectory): - prepared = prepare_step_tokens(step, step_idx) - if prepared is None: - return None - prepared_steps.append(prepared) - - # Deferred routed_experts state per sample: O(N) chunk list concatenated - # once at finalize, replacing the prior O(N²) per-extension unpack/repack. - sample_routed_state: dict[int, dict[str, Any]] = {} - routed_prefix_states: dict[int, list[tuple[list[int], list[int], dict[str, Any]]]] = {} - - # Track (prefix_tokens, sample, step_indices) per active sample. step_indices - # is the explicit list of prepared_steps positions merged into this sample — - # non-contiguous when other agents' steps interleave. - active_samples: list[tuple[list[int], TrainingSample, list[int]]] = [] - - def make_sample(tokens: dict[str, Any], step_idx: int) -> TrainingSample: - """Create a new TrainingSample from a trajectory step.""" - if has_error: - completion_mask = [False] * len(tokens["completion_mask"]) - else: - completion_mask = list(tokens["completion_mask"]) - completion_ids = list(tokens["completion_ids"]) - - prompt_ids = list(tokens["prompt_ids"]) - sample = TrainingSample( - prompt_ids=prompt_ids, - prompt_mask=list(tokens["prompt_mask"]), - completion_ids=completion_ids, - completion_mask=completion_mask, - completion_logprobs=list(tokens["completion_logprobs"]), - completion_temperatures=[], - teacher_logprobs=None, - advantage=None, - env_name=env_name, - mm_token_type_ids=None, - routed_experts=None, # deferred — finalized at end of interleave_rollout - ) - # Initialize routed-experts state for this sample. First chunk is the - # raw step routed_experts (no pad, no copy). running_len is the - # cumulative count across chunks; tracked so the boundary fix-up at - # each extension is a no-op append rather than a destructive write. - step_routed = tokens.get("routed_experts") - if step_routed is not None: - routed_start = tokens["routed_experts_start"] - assert routed_start is not None, f"Missing routed_experts_start for step {step_idx}" - chunks: list[np.ndarray] = [] - running_len = 0 - if routed_start > 0: - source_len = routed_start + 1 - assert source_len in routed_prefix_states, ( - f"Missing routed prefix state for step {step_idx}: " - f"routed_start={routed_start}, prompt_len={len(tokens['prompt_ids'])}" - ) - source_state = None - for prompt_ids, completion_ids, candidate_state in routed_prefix_states[source_len]: - prompt_len = len(prompt_ids) - if ( - tokens["prompt_ids"][:prompt_len] == prompt_ids - and tokens["prompt_ids"][prompt_len:source_len] == completion_ids - ): - source_state = candidate_state - break - assert source_state is not None, ( - f"No matching routed prefix for step {step_idx}: " - f"routed_start={routed_start}, prompt_len={len(tokens['prompt_ids'])}" - ) - assert source_state["running_len"] >= routed_start, ( - f"Routed prefix too short for step {step_idx}: " - f"running_len={source_state['running_len']}, routed_start={routed_start}" - ) - remaining = routed_start - for chunk in source_state["chunks"]: - if remaining == 0: - break - take = min(remaining, int(chunk.shape[0])) - chunks.append(chunk[:take]) - remaining -= take - assert remaining == 0, ( - f"Could not reconstruct routed prefix for step {step_idx}: " - f"remaining={remaining}, routed_start={routed_start}" - ) - running_len = routed_start - chunks.append(step_routed) - running_len += int(step_routed.shape[0]) - sample_routed_state[id(sample)] = { - "chunks": chunks, - "running_len": running_len, - } - return sample - - def extend_sample( - sample: TrainingSample, - prefix_len: int, - step_idx: int, - ) -> None: - """Extend an existing sample with a new trajectory step (extension property holds).""" - tokens = prepared_steps[step_idx] - - # Extend with new prompt tokens (mask=False, no gradient) - new_prompt_ids = tokens["prompt_ids"][prefix_len:] - sample.completion_ids.extend(new_prompt_ids) - sample.completion_mask.extend([False] * len(new_prompt_ids)) - sample.completion_logprobs.extend([0.0] * len(new_prompt_ids)) - - # Extend with new completion tokens - completion_ids = tokens["completion_ids"] - sample.completion_ids.extend(completion_ids) - if has_error: - sample.completion_mask.extend([False] * len(tokens["completion_mask"])) - else: - sample.completion_mask.extend(tokens["completion_mask"]) - sample.completion_logprobs.extend(tokens["completion_logprobs"]) - - step_routed = tokens.get("routed_experts") - state = sample_routed_state.get(id(sample)) - if state is not None: - assert step_routed is not None, f"Missing routed experts for routed sample extension at step {step_idx}" - if step_routed is not None: - assert state is not None, f"Unexpected routed experts for unrouted sample at step {step_idx}" - assert tokens["routed_experts_start"] == prefix_len - 1, ( - f"Routed experts delta start mismatch at step {step_idx}: " - f"start={tokens['routed_experts_start']}, expected={prefix_len - 1}, prefix_len={prefix_len}" - ) - # Delta payloads start at prefix_len - 1. Row 0 fills the boundary - # token missing from the previous request; the rest is the new suffix. - if prefix_len > 0: - boundary_chunk = step_routed[:1] - state["chunks"].append(boundary_chunk) - state["running_len"] += 1 - step_routed = step_routed[1:] - new_chunk = step_routed - state["chunks"].append(new_chunk) - state["running_len"] += int(new_chunk.shape[0]) - - first_tokens = prepared_steps[0] - first_prefix = first_tokens["prompt_ids"] + first_tokens["completion_ids"] - first_sample = make_sample(first_tokens, step_idx=0) - active_samples.append((first_prefix, first_sample, [0])) - first_routed_state = sample_routed_state.get(id(first_sample)) - if first_routed_state is not None: - routed_prefix_states.setdefault(len(first_prefix), []).append( - (first_tokens["prompt_ids"], first_tokens["completion_ids"], first_routed_state) - ) - - for step_idx, _step in enumerate(trajectory[1:], start=1): - tokens = prepared_steps[step_idx] - step_prompt_ids = tokens["prompt_ids"] - - # Pick the *longest* matching active prefix. With compaction/rollback, - # one active sample's prefix can be a strict prefix of another (e.g. a - # later sample re-generated tokens that overlap an earlier sample's - # prefix). Both would satisfy the slice check; the shorter would - # silently absorb the longer sample's generated tokens as user input. - matched_idx = None - matched_len = -1 - matching_prefix_lens: list[int] = [] - for idx, (prefix_tokens, _, _) in enumerate(active_samples): - pl = len(prefix_tokens) - if step_prompt_ids[:pl] == prefix_tokens: - matching_prefix_lens.append(pl) - if pl > matched_len: - matched_idx = idx - matched_len = pl - - if len(matching_prefix_lens) > 1: - # Ambiguous extension: rare, but reachable via compaction/rollback - # where a new sample's prefix happens to start with an older - # sample's prefix. Longest-match is the correct choice; surface - # the ambiguity so we can audit if it shows up in real rollouts. - logger.warning( - f"Ambiguous prefix match at step {step_idx} for example {output['example_id']}: " - f"{len(matching_prefix_lens)} of {len(active_samples)} active prefixes match " - f"(lens={sorted(matching_prefix_lens)}, step_prompt_len={len(step_prompt_ids)}). " - f"Extending the longest (len={matched_len})." - ) - - if matched_idx is not None: - # Extension holds - merge into matched sample - prefix_tokens, sample, step_indices = active_samples[matched_idx] - extend_sample(sample, len(prefix_tokens), step_idx=step_idx) - new_prefix = tokens["prompt_ids"] + tokens["completion_ids"] - active_samples[matched_idx] = ( - new_prefix, - sample, - step_indices + [step_idx], - ) - routed_state = sample_routed_state.get(id(sample)) - if routed_state is not None: - routed_prefix_states.setdefault(len(new_prefix), []).append( - (tokens["prompt_ids"], tokens["completion_ids"], routed_state) - ) - else: - # No prefix matches - start a new sample - logger.debug( - f"Extension property broke at step {step_idx + 1} for example {output['example_id']}. " - f"Starting new sample (active_prefixes={len(active_samples)}, step_prompt_len={len(step_prompt_ids)})." - ) - new_prefix = tokens["prompt_ids"] + tokens["completion_ids"] - sample = make_sample(tokens, step_idx=step_idx) - active_samples.append((new_prefix, sample, [step_idx])) - routed_state = sample_routed_state.get(id(sample)) - if routed_state is not None: - routed_prefix_states.setdefault(len(new_prefix), []).append( - (tokens["prompt_ids"], tokens["completion_ids"], routed_state) - ) - - # Finalize routed_experts for each sample. One concat per sample (O(N) byte - # work) replaces the previous per-step unpack/concat/repack (O(N²)). The - # boundary entries between steps were already inserted as one-entry chunks - # during extend_sample, so a straight concat is correct. - for _, sample, _ in active_samples: - state = sample_routed_state.get(id(sample)) - if state is None: - continue - chunks = state["chunks"] - if not chunks: + has_error = trace.has_error + samples: list[TrainingSample] = [] + for branch in trace.branches: + mask = branch.sampled_mask + if not any(mask): continue - combined = np.concatenate(chunks, axis=0) if len(chunks) > 1 else np.ascontiguousarray(chunks[0]) - expected_len = len(sample.prompt_ids) + len(sample.completion_ids) - combined = align_routed_experts(combined, expected_len) - combined = np.ascontiguousarray(combined) - sample.routed_experts = RoutedExperts( - data=combined.tobytes(), - shape=list(combined.shape), - dtype=str(combined.dtype), + token_ids = branch.token_ids + mm_kwargs: dict[str, EncodedTensor] | None = None + mm_token_type_ids: list[int] | None = None + mmd = branch.multi_modal_data + if mmd is not None: + mm_kwargs = _encode_mm_kwargs(mmd.mm_items) + mapping = mm_token_type_ids_mapping or {} + mm_token_type_ids = [mapping.get(t, 0) for t in token_ids] + samples.append( + TrainingSample( + token_ids=token_ids, + mask=[m and not has_error for m in mask], + logprobs=branch.logprobs, + temperatures=[], # filled by TrainSink.process_group + teacher_logprobs=None, + advantage=None, + env_name=env_name, + mm_kwargs=mm_kwargs, + mm_token_type_ids=mm_token_type_ids, + ) ) - - # Attach images by concatenating mm_items across every step the - # sample covers. verifiers' ``state_to_output`` ships per-step - # *delta* mm_data (each step contains only items not present in the - # prior step's cumulative set, with multiset-aware dedup), so - # reading the last step alone would miss every earlier-turn image. - # Concat in step order recovers the per-sample cumulative set; - # deduping again here would drop legitimate duplicate placeholders. - for _, sample, step_indices in active_samples: - renderer_mm = _union_step_mm_data(prepared_steps, step_indices) - if renderer_mm is not None: - mm_kwargs = _pack_mm_kwargs_from_renderer(renderer_mm) - if mm_kwargs is not None: - sample.mm_kwargs = mm_kwargs - # ``mm_token_type_ids``: 1 for image-placeholder tokens, 2 - # for video, 0 otherwise. Renderer-supplied via - # ``mm_token_type_id_map`` (single source of truth). - if mm_token_type_ids_mapping is not None: - sample.mm_token_type_ids = [ - mm_token_type_ids_mapping.get(token_id, 0) - for token_id in sample.prompt_ids + sample.completion_ids - ] - - return [sample for _, sample, _ in active_samples] - - -def _union_step_mm_data( - prepared_steps: list[dict[str, Any]], - step_indices: list[int], -) -> "dict[str, Any] | None": - """Concatenate renderer-emitted mm_items across this sample's owned steps. - - ``step_indices`` lists exactly the ``prepared_steps`` positions merged into - the sample — explicit, not a range, so interleaved-agent trajectories skip - steps owned by other agents. - - Verifiers ≥ c7731bbb ships per-step *delta* mm_data instead of - cumulative — see ``verifiers/utils/save_utils.py::_delta_intermediate_mm_data``. - The cross-step dedup is already done there with multiset semantics - (preserving multiplicity for an image that appears in multiple - placeholder runs in the token stream). We just concatenate in step - order to recover the per-sample cumulative; deduping again here - would drop legitimate duplicate placeholders. - """ - union_items: dict[str, list] = {} - union_hashes: dict[str, list] = {} - has_any = False - for i in step_indices: - mm = prepared_steps[i].get("multi_modal_data") - if mm is None: - continue - items = mm.mm_items if hasattr(mm, "mm_items") else (mm or {}).get("mm_items") or {} - hashes = mm.mm_hashes if hasattr(mm, "mm_hashes") else (mm or {}).get("mm_hashes") or {} - for modality, item_lst in items.items(): - hash_lst = hashes.get(modality, []) or [] - for j, item in enumerate(item_lst or []): - h = hash_lst[j] if j < len(hash_lst) else None - union_items.setdefault(modality, []).append(item) - union_hashes.setdefault(modality, []).append(h) - has_any = True - if not has_any: - return None - return {"mm_items": union_items, "mm_hashes": union_hashes} - - -def _pack_mm_kwargs_from_renderer(mm_data: Any) -> "dict[str, Any] | None": - """Batch the renderer's per-image ``mm_items`` into model-agnostic - forward kwargs. - - ``mm_data`` may arrive as a ``MultiModalData`` instance (in-process - for tests) or as a plain dict (after msgpack round-trip from the - env-worker). Each item is a dict keyed by the names the model's - ``forward`` expects (``pixel_values`` + ``image_grid_thw`` for - Qwen3-VL, just ``pixel_values`` for Gemma3-VL, etc.). We batch by - ``torch.cat(..., dim=0)`` per key — generic because every HF VLM - processor emits a leading batch/patch dimension, and the renderer - always processes one image per call. - - Returns a dict of ``EncodedTensor`` payloads keyed by kwarg name, - or ``None`` when no multimodal data is present. - """ - from verifiers.utils.serve_utils import decode_tensor_payload - - from prime_rl.transport.types import EncodedTensor - - mm_items = mm_data.mm_items if hasattr(mm_data, "mm_items") else (mm_data or {}).get("mm_items") or {} - # Flatten across modalities into one kwarg dict — the model's - # forward signature is the schema. ``mm_items`` is typically - # ``{"image": [...], "video": [...]}`` but each modality's keys - # don't collide for any HF VLM we ship today. - per_kwarg: dict[str, list] = {} - for _modality, items in mm_items.items(): - for item in items or []: - for key, payload in item.items(): - per_kwarg.setdefault(key, []).append(decode_tensor_payload(payload)) - if not per_kwarg: - return None - out: dict[str, EncodedTensor] = {} - for key, tensors in per_kwarg.items(): - cat = torch.cat(tensors, dim=0).contiguous() - arr = cat.detach().cpu().numpy() - out[key] = EncodedTensor( - dtype=str(arr.dtype), - shape=list(arr.shape), - data=arr.tobytes(), + if not samples: + get_logger().warning( + f"No trainable samples (error={has_error}, stop={trace.stop_condition}, num_turns={trace.num_turns})." ) - return out - - -_FILE_URL_PREFIX = "file://" - - -def offload_images_to_disk(rollouts: list[vf.RolloutOutput], output_dir: Path) -> int: - """Replace base64 image data in rollout trajectories with file paths on disk. - - Scans all trajectory step prompts for data:image URLs, writes the decoded - image bytes to ``{output_dir}/assets/images/{hash}.png``, and replaces the - URL in-place with ``file://{path}``. Deduplicates by content hash so each - unique image is written only once. - - Returns the number of unique images written to disk. - """ - images_dir = output_dir / "assets" / "images" - images_dir.mkdir(parents=True, exist_ok=True) - - written: set[str] = set() - - for output in rollouts: - for step in output.get("trajectory", []): - prompt = step.get("prompt") - if not prompt or not isinstance(prompt, list): - continue - for msg in prompt: - content = msg.get("content", []) - if not isinstance(content, list): - continue - for item in content: - if item.get("type") != "image_url": - continue - url = item.get("image_url", {}).get("url", "") - if not url.startswith("data:image"): - continue - b64_data = url.split(",", 1)[1] - content_hash = hashlib.sha256(b64_data.encode()).hexdigest()[:16] - path = images_dir / f"{content_hash}.png" - if content_hash not in written: - if not path.exists(): - path.write_bytes(base64.b64decode(b64_data)) - written.add(content_hash) - item["image_url"]["url"] = f"{_FILE_URL_PREFIX}{path}" - - return len(written) + return samples diff --git a/src/prime_rl/orchestrator/types.py b/src/prime_rl/orchestrator/types.py index c2a3f5de79..e59b4bdb43 100644 --- a/src/prime_rl/orchestrator/types.py +++ b/src/prime_rl/orchestrator/types.py @@ -3,10 +3,12 @@ from __future__ import annotations import uuid -from dataclasses import dataclass, field, fields -from typing import Literal, Protocol +from dataclasses import dataclass, field +from typing import Generic, Literal, Protocol -import verifiers as vf +import verifiers.v1 as vf +from pydantic import ConfigDict, Field +from verifiers.v1.task import TaskT from prime_rl.transport import TrainingSample @@ -55,7 +57,7 @@ class GroupState: kind: RolloutKind env_name: str - example: dict + task_idx: int rollouts_to_schedule: int target_rollouts: int emitted: int = 0 @@ -64,60 +66,25 @@ class GroupState: policy_version_at_start: int = 0 -@dataclass -class FinishedRollout: - """A completed rollout the sink receives. ``raw`` is the env's untouched - ``vf.RolloutOutput``; prime-rl metadata lives on typed fields. Train vs - eval is discriminated via ``isinstance``. ``rollout_id`` is the only - safe key for tracing one rollout — ``(env_name, example_id)`` collides - on re-sampling and ``group_id`` covers a whole group.""" - - raw: vf.RolloutOutput - env_name: str - example_id: int | str - group_id: uuid.UUID - policy_version: int - off_policy_steps: int - rollout_id: uuid.UUID = field(default_factory=uuid.uuid4) - - @property - def error(self) -> dict | None: - return self.raw.get("error") - - @property - def reward(self) -> float: - return float(self.raw.get("reward", 0.0)) - - @property - def is_truncated(self) -> bool: - return bool(self.raw.get("is_truncated", False)) - - def to_dict(self) -> vf.RolloutOutput: - """``raw`` + metadata merged for I/O (``save_rollouts``, - ``monitor.log_samples``). Shallow copy; never mutates ``self.raw``.""" - out: vf.RolloutOutput = dict(self.raw) # type: ignore[assignment] - for f in fields(self): - if f.name in ("raw", "samples"): - continue - val = getattr(self, f.name) - if f.name == "filter_results": - out["filters"] = dict(val) - continue - out[f.name] = str(val) if isinstance(val, uuid.UUID) else val - return out +class Rollout(vf.Trace[TaskT], Generic[TaskT]): + """A completed rollout: the env's typed ``vf.Trace`` *is* the rollout — prime-rl's + orchestration metadata lives on it directly (set by the dispatcher once the rollout + returns), so there's no wrapper. Train vs eval is the ``kind`` discriminator. All metadata + fields are ``exclude=True``, so dumping a Rollout yields a plain trace — the on-disk + ``results.jsonl`` is unchanged.""" + model_config = ConfigDict(arbitrary_types_allowed=True) # ``samples`` holds msgspec structs -@dataclass -class TrainRollout(FinishedRollout): - samples: list[TrainingSample] = field(default_factory=list) - advantage: float | None = None - is_filtered: bool = False - filter_results: dict[str, bool] = field(default_factory=dict) - - -@dataclass -class EvalRollout(FinishedRollout): - eval_step: int = 0 + kind: RolloutKind = Field(default="train", exclude=True) + env_name: str = Field(default="", exclude=True) + group_id: uuid.UUID = Field(default_factory=uuid.uuid4, exclude=True) + policy_version: int = Field(default=0, exclude=True) + off_policy_steps: int = Field(default=0, exclude=True) + samples: list[TrainingSample] = Field(default_factory=list, exclude=True) + advantage: float | None = Field(default=None, exclude=True) + is_filtered: bool = Field(default=False, exclude=True) + filter_results: dict[str, bool] = Field(default_factory=dict, exclude=True) + eval_step: int | None = Field(default=None, exclude=True) @dataclass @@ -142,7 +109,7 @@ class TrainBatch: """``samples`` is the trainer-bound payload (post-filter survivors); ``rollouts`` is the full cohort kept for orchestrator-side I/O.""" - rollouts: list[TrainRollout] + rollouts: list[Rollout] samples: list[TrainingSample] metrics: TrainBatchMetrics @@ -197,7 +164,7 @@ class EvalBatch: env_name: str step: int - rollouts: list[EvalRollout] + rollouts: list[Rollout] metrics: EvalBatchMetrics diff --git a/src/prime_rl/orchestrator/utils.py b/src/prime_rl/orchestrator/utils.py index 5675ba3f34..500e0697f3 100644 --- a/src/prime_rl/orchestrator/utils.py +++ b/src/prime_rl/orchestrator/utils.py @@ -1,19 +1,18 @@ import asyncio import logging +import os import time from concurrent.futures import ThreadPoolExecutor from itertools import cycle from pathlib import Path import orjson -import verifiers as vf -from verifiers.utils.client_utils import setup_openai_client -from verifiers.utils.save_utils import make_serializable +import verifiers.v1 as vf from prime_rl.configs.orchestrator import OrchestratorConfig from prime_rl.transport import TrainingSample from prime_rl.utils.client import setup_inference_pool -from prime_rl.utils.logger import InterceptHandler, get_logger +from prime_rl.utils.logger import InterceptHandler, get_logger, setup_logger from prime_rl.utils.utils import ( get_broadcast_dir, get_ckpt_dir, @@ -22,63 +21,48 @@ async def setup_student_inference_pool(*, config: OrchestratorConfig, tokenizer): - """Build the student inference pool + matching renderer. Returns - ``(renderer | None, inference_pool)``; ``renderer`` is ``None`` on the - MITO path (``config.renderer is None``).""" + """Build the student renderer and inference pool, returning ``(renderer, inference_pool)``. + + Training is renderer-only: RL/OPD roll out through the env server's renderer client + (token-in/out), and SFT — which rolls out against a chat-completions teacher that returns + no tokens — re-renders the conversation with this renderer to backfill them. The renderer + is built here from the (always-set) ``config.renderer`` and also supplies the multimodal + token-type-id map. The eval client is plain chat-completions (eval traces aren't trained).""" from renderers.base import create_renderer client_config = config.student.client model_name = config.student.model.name - - if config.renderer is not None: - renderer = create_renderer(tokenizer, config.renderer) - get_logger().info(f"Initialized {type(renderer).__name__} for {model_name}") - inference_pool = await setup_inference_pool( - client_config, - model_name=model_name, - train_client_type="renderer", - eval_client_type="openai_chat_completions", - renderer_config=config.renderer, - pool_size=config.pool_size, - ) - get_logger().info("Using direct renderer rollout client") - return renderer, inference_pool - - get_logger().info("Using MITO (openai_chat_completions) for rollouts") + renderer = create_renderer(tokenizer, config.renderer) + get_logger().info("Using renderer rollout client") inference_pool = await setup_inference_pool( client_config, model_name=model_name, - train_client_type="openai_chat_completions", + train_client_type="renderer", eval_client_type="openai_chat_completions", + renderer_config=config.renderer, + pool_size=config.pool_size, ) - return None, inference_pool - + return renderer, inference_pool -def get_model_completion_len(output: vf.RolloutOutput) -> int: - """Sum of model-generated completion tokens across all turns (excludes - environment-injected tokens between turns).""" - return sum(len(step["tokens"]["completion_ids"]) for step in output["trajectory"] if step.get("tokens")) - -def get_tool_response_len(output: vf.RolloutOutput) -> int: +def get_tool_response_len(output: vf.Trace) -> int: """Total tool-response tokens consumed across the whole rollout, read from a harness-emitted metric (e.g. RLM's `rlm_total_tool_response_tokens`, deduped across turns/branches/sub-RLMs). Returns 0 when no such metric is present.""" - metrics = output.get("metrics") or {} - for key, value in metrics.items(): + for key, value in output.metrics.items(): if key.endswith("total_tool_response_tokens") and isinstance(value, (int, float)): return int(value) return 0 -def save_rollouts(rollouts: list[vf.RolloutOutput], path: Path, exclude_keys: set[str] | None = None) -> None: - """Save rollouts to a JSONL file using verifiers serialization.""" +def save_rollouts(rollouts: list[dict], path: Path, exclude_keys: set[str] | None = None) -> None: + """Save rollouts (Trace dicts, already JSON-serializable) to a JSONL file.""" path.parent.mkdir(parents=True, exist_ok=True) opts = orjson.OPT_APPEND_NEWLINE | orjson.OPT_SERIALIZE_NUMPY with open(path, "wb") as f: for rollout in rollouts: row = {k: v for k, v in rollout.items() if k not in exclude_keys} if exclude_keys else rollout - f.write(orjson.dumps(row, default=make_serializable, option=opts)) + f.write(orjson.dumps(row, default=str, option=opts)) def intercept_vf_logging(logger: str = "verifiers", level: str = "DEBUG", prefix: str | None = None): @@ -90,6 +74,15 @@ def intercept_vf_logging(logger: str = "verifiers", level: str = "DEBUG", prefix vf_logger.propagate = False +def setup_env_server_logging(log_level: str, json_logging: bool = False) -> None: + """Configure logging for an env-server process: prime-rl's logger + routing v1's stdlib + logs through it. Passed to verifiers' ``serve_env`` so it runs in the broker and in every + spawned worker — fresh ``spawn`` processes that otherwise have no handlers and would drop + their per-rollout logs.""" + setup_logger(log_level, json_logging=json_logging) + intercept_vf_logging(logger="verifiers.v1", level=log_level) + + def set_default_executor(max_workers: int = 64) -> None: """Scale the default asyncio thread pool so asyncio.to_thread has enough capacity.""" get_logger().info(f"Setting default executor to ThreadPoolExecutor(max_workers={max_workers})") @@ -103,10 +96,15 @@ async def compute_teacher_logprobs( ) -> list[list[float]]: """Compute teacher model logprobs for a batch of training samples via prefill.""" import httpx + from openai import AsyncOpenAI from vllm.entrypoints.serve.disagg.protocol import GenerateResponse async def _compute_single(client_config: vf.ClientConfig, sample: TrainingSample) -> list[float]: - client = setup_openai_client(client_config) + client = AsyncOpenAI( + base_url=client_config.base_url, + api_key=os.environ.get(client_config.api_key_var, "EMPTY"), + default_headers=client_config.headers or None, + ) # Two escape hatches from ``AsyncOpenAI.post``: # 1. URL — ``/inference/v1/generate`` is mounted at server root, not @@ -125,7 +123,7 @@ async def _compute_single(client_config: vf.ClientConfig, sample: TrainingSample cast_to=httpx.Response, body={ "model": model_name, - "token_ids": list(sample.prompt_ids) + list(sample.completion_ids), + "token_ids": list(sample.token_ids), "sampling_params": { "max_tokens": 1, "temperature": 1.0, diff --git a/src/prime_rl/trainer/batch.py b/src/prime_rl/trainer/batch.py index ea99859a35..3999a554a3 100644 --- a/src/prime_rl/trainer/batch.py +++ b/src/prime_rl/trainer/batch.py @@ -1,6 +1,8 @@ import copy -from prime_rl.transport.types import MicroBatch, RoutedExperts, TrainingSample +import numpy as np + +from prime_rl.transport.types import EncodedTensor, MicroBatch, RoutedExperts, TrainingSample ROUTED_EXPERTS_DTYPE_ITEMSIZE = { "uint8": 1, @@ -49,26 +51,77 @@ def _pad_routed_experts(micro_batch: MicroBatch, padding_size: int) -> None: routed_experts.shape[0] += padding_size +def _slice_encoded(tensor: EncodedTensor, n_rows: int) -> EncodedTensor: + """First `n_rows` rows of a dim-0-stacked encoded tensor (e.g. pixel_values, image_grid_thw).""" + row = int(np.prod(tensor.shape[1:])) if len(tensor.shape) > 1 else 1 + itemsize = np.dtype(tensor.dtype).itemsize + return EncodedTensor( + dtype=tensor.dtype, + shape=[n_rows, *tensor.shape[1:]], + data=tensor.data[: n_rows * row * itemsize], + ) + + +def _truncate_mm( + mm_token_type_ids: list[int], mm_kwargs: dict[str, EncodedTensor], seq_len: int +) -> tuple[int, dict[str, EncodedTensor] | None]: + """Truncating a sample must not split an image's placeholder block, else the surviving image + token count no longer matches the image embeddings in `mm_kwargs`. Returns the cut point + (<= seq_len, never inside an image block) and `mm_kwargs` sliced to the images whose + placeholders fully survive (None if no image survives).""" + grid = np.frombuffer(bytearray(mm_kwargs["image_grid_thw"].data), dtype=mm_kwargs["image_grid_thw"].dtype).reshape( + mm_kwargs["image_grid_thw"].shape + ) + patches_per_image = [int(g.prod()) for g in grid] + total_patches = mm_kwargs["pixel_values"].shape[0] + total_tokens = sum(1 for t in mm_token_type_ids if t) + ppt = total_patches // total_tokens if total_tokens else 1 # patches per token (merge^2) + tokens_per_image = [p // ppt for p in patches_per_image] + + surviving = sum(1 for t in mm_token_type_ids[:seq_len] if t) + kept = acc = 0 + for n in tokens_per_image: + if acc + n > surviving: + break + acc += n + kept += 1 + if acc == surviving: + cut = seq_len # surviving image tokens are exactly `kept` whole images + else: + # `surviving` lands inside image `kept`; cut to its first placeholder, dropping it. + seen, cut = 0, seq_len + for i, t in enumerate(mm_token_type_ids): + if t: + seen += 1 + if seen == acc + 1: + cut = i + break + if not kept: + return cut, None + kept_patches = sum(patches_per_image[:kept]) + sliced = {k: _slice_encoded(v, kept if k == "image_grid_thw" else kept_patches) for k, v in mm_kwargs.items()} + return cut, sliced + + def prepare_sample(training_example: TrainingSample, seq_len: int) -> MicroBatch: """ Prepare a problem for sequence packing training. Tokenize and prepare tensors. """ - input_ids = training_example.prompt_ids + training_example.completion_ids - loss_mask = training_example.prompt_mask + training_example.completion_mask - inference_logprobs = [0.0] * len(training_example.prompt_ids) + training_example.completion_logprobs + input_ids = training_example.token_ids + loss_mask = training_example.mask + inference_logprobs = training_example.logprobs advantages = [training_example.advantage] * len(input_ids) reward = training_example.reward if training_example.reward is not None else float("nan") rewards = [reward] * len(input_ids) position_ids = list(range(len(input_ids))) mm_token_type_ids = training_example.mm_token_type_ids + mm_kwargs = training_example.mm_kwargs assert training_example.env_name != "all", "env_name='all' is reserved for aggregate metric keys" env_names = [training_example.env_name] * len(input_ids) - # Per-token temperatures: prompt tokens use first completion temp (masked out anyway) - # Default to 1.0 if completion is empty (e.g., model generated only tool calls with no text) - prompt_temp = training_example.completion_temperatures[0] if training_example.completion_temperatures else 1.0 - temperatures = [prompt_temp] * len(training_example.prompt_ids) + training_example.completion_temperatures + # Per-token sampling temperatures (context tokens are masked out, so theirs are don't-care). + temperatures = training_example.temperatures # Teacher logprobs already cover the full sequence (prompt + completion), # computed via prefill in the orchestrator when a teacher model is configured @@ -78,20 +131,25 @@ def prepare_sample(training_example: TrainingSample, seq_len: int) -> MicroBatch ) if len(input_ids) > seq_len: - input_ids = input_ids[:seq_len] - loss_mask = loss_mask[:seq_len] - inference_logprobs = inference_logprobs[:seq_len] - position_ids = position_ids[:seq_len] - advantages = advantages[:seq_len] - rewards = rewards[:seq_len] - temperatures = temperatures[:seq_len] + # Multimodal: never split an image's placeholder block — cut to a whole-image boundary + # and slice mm_kwargs to match, so image-token count == image-embedding count. + cut = seq_len + if mm_token_type_ids is not None and mm_kwargs is not None: + cut, mm_kwargs = _truncate_mm(mm_token_type_ids, mm_kwargs, seq_len) + input_ids = input_ids[:cut] + loss_mask = loss_mask[:cut] + inference_logprobs = inference_logprobs[:cut] + position_ids = position_ids[:cut] + advantages = advantages[:cut] + rewards = rewards[:cut] + temperatures = temperatures[:cut] if teacher_logprobs is not None: - teacher_logprobs = teacher_logprobs[:seq_len] + teacher_logprobs = teacher_logprobs[:cut] if routed_experts is not None: - routed_experts = _slice_routed_experts(routed_experts, seq_len) + routed_experts = _slice_routed_experts(routed_experts, cut) if mm_token_type_ids is not None: - mm_token_type_ids = mm_token_type_ids[:seq_len] - env_names = env_names[:seq_len] + mm_token_type_ids = mm_token_type_ids[:cut] + env_names = env_names[:cut] assert ( len(input_ids) @@ -131,7 +189,7 @@ def prepare_sample(training_example: TrainingSample, seq_len: int) -> MicroBatch routed_experts=routed_experts, mm_token_type_ids=mm_token_type_ids, env_names=env_names, - mm_kwargs=training_example.mm_kwargs, + mm_kwargs=mm_kwargs, training_mode=training_example.training_mode, ) diff --git a/src/prime_rl/trainer/rl/packer.py b/src/prime_rl/trainer/rl/packer.py index cf9dcfa02e..b69d59cbeb 100644 --- a/src/prime_rl/trainer/rl/packer.py +++ b/src/prime_rl/trainer/rl/packer.py @@ -148,27 +148,17 @@ def _on_run_data_deleted(self, idx: int, run_id: str) -> None: def _validate_sample(self, sample: TrainingSample) -> tuple[bool, str | None]: """Validate a sample to ensure it won't crash the trainer.""" - sample_length = len(sample.prompt_ids) + len(sample.completion_ids) - if len(sample.prompt_mask) != len(sample.prompt_ids): - return ( - False, - f"Run wrote a sample with prompt mask length != prompt ids length ({len(sample.prompt_mask)} != {len(sample.prompt_ids)})", - ) - if len(sample.completion_mask) != len(sample.completion_ids): - return ( - False, - f"Run wrote a sample with completion mask length != completion ids length ({len(sample.completion_mask)} != {len(sample.completion_ids)})", - ) - if len(sample.completion_logprobs) != len(sample.completion_ids): - return ( - False, - f"Run wrote a sample with completion logprobs length != completion ids length ({len(sample.completion_logprobs)} != {len(sample.completion_ids)})", - ) - if len(sample.completion_temperatures) != len(sample.completion_ids): - return ( - False, - f"Run wrote a sample with completion temperatures length != completion ids length ({len(sample.completion_temperatures)} != {len(sample.completion_ids)})", - ) + sample_length = len(sample.token_ids) + for name, arr in ( + ("mask", sample.mask), + ("logprobs", sample.logprobs), + ("temperatures", sample.temperatures), + ): + if len(arr) != sample_length: + return ( + False, + f"Run wrote a sample with {name} length != token_ids length ({len(arr)} != {sample_length})", + ) if sample_length == 0: return False, "Run wrote a sample with no tokens" if sample_length > self.seq_len: @@ -216,7 +206,7 @@ def _count_tokens(self, threshold: int | None = None) -> int: for sample, step in buffer: if step > current_step: break - tokens += len(sample.prompt_ids) + len(sample.completion_ids) + tokens += len(sample.token_ids) if threshold is not None and tokens >= threshold: return tokens return tokens @@ -253,10 +243,10 @@ def _select_samples_round_robin(self, token_budget: int) -> list[tuple[int, Trai if step > current_step: # Samples from different steps should be consumed later break - tokens_collected += len(sample.prompt_ids) + len(sample.completion_ids) + tokens_collected += len(sample.token_ids) if tokens_collected > token_budget: - if tokens_collected == (len(sample.prompt_ids) + len(sample.completion_ids)): - tokens_collected -= len(sample.prompt_ids) + len(sample.completion_ids) + if tokens_collected == (len(sample.token_ids)): + tokens_collected -= len(sample.token_ids) # This means we have a sample that has more tokens than max seqlen self.buffers[run_idx].popleft() continue @@ -306,7 +296,7 @@ def pack(self): samples_by_run[run_idx] = [] samples_by_run[run_idx].append(sample) - num_tokens = len(sample.prompt_ids) + len(sample.completion_ids) + num_tokens = len(sample.token_ids) if run_idx in per_run_stats: cur_samples, cur_tokens = per_run_stats[run_idx] per_run_stats[run_idx] = (cur_samples + 1, cur_tokens + num_tokens) diff --git a/src/prime_rl/transport/types.py b/src/prime_rl/transport/types.py index 1bb31c9325..bdf31a8e64 100644 --- a/src/prime_rl/transport/types.py +++ b/src/prime_rl/transport/types.py @@ -24,14 +24,17 @@ class RoutedExperts(msgspec.Struct, array_like=True, gc=False, omit_defaults=Tru # Orchestrator -> Packer class TrainingSample(msgspec.Struct, array_like=True, gc=False, omit_defaults=True): - """A single training example.""" - - prompt_ids: list[int] - prompt_mask: list[bool] - completion_ids: list[int] - completion_mask: list[bool] - completion_logprobs: list[float] - completion_temperatures: list[float] # Per-token temperatures used during generation + """A single training example — one branch of a rollout as a flat token sequence. + + There is no prompt/completion split: an agentic, multi-turn branch interleaves context and + model-sampled spans, so ``mask`` marks which tokens are trainable (model-sampled) and + ``logprobs`` / ``temperatures`` are aligned per token. All four arrays share the length of + ``token_ids``.""" + + token_ids: list[int] + mask: list[bool] # per-token: True = model-sampled (trainable), False = context/scaffold + logprobs: list[float] # per-token sampling logprobs (0.0 on non-sampled tokens) + temperatures: list[float] # per-token temperature used during generation env_name: str teacher_logprobs: list[float] | None = None advantage: float | None = None diff --git a/src/prime_rl/utils/async_utils.py b/src/prime_rl/utils/async_utils.py index 1de097592f..27374bac75 100644 --- a/src/prime_rl/utils/async_utils.py +++ b/src/prime_rl/utils/async_utils.py @@ -1,4 +1,9 @@ import asyncio +from collections import deque +from time import perf_counter + +import numpy as np +from pydantic import BaseModel async def safe_cancel(task: asyncio.Task) -> None: @@ -13,3 +18,54 @@ async def safe_cancel(task: asyncio.Task) -> None: async def safe_cancel_all(tasks: list[asyncio.Task]) -> None: """Safely cancels and awaits all asyncio.Tasks.""" await asyncio.gather(*[safe_cancel(task) for task in tasks]) + + +class EventLoopLagMonitor: + """Monitors how busy the main event loop is by timing short sleeps. + + Vendored from verifiers.utils.async_utils (the orchestrator now runs on + v1 and no longer depends on v1 verifiers).""" + + def __init__(self, measure_interval: float = 0.1, max_measurements: int = 1000): + assert measure_interval > 0 and max_measurements > 0 + self.measure_interval = measure_interval + self.max_measurements = max_measurements + self.lags: deque[float] = deque(maxlen=max_measurements) + + async def measure_lag(self) -> float: + next_time = perf_counter() + self.measure_interval + await asyncio.sleep(self.measure_interval) + return perf_counter() - next_time + + async def run(self) -> None: + """Loop measuring event-loop lag; run as a background task.""" + while True: + self.lags.append(await self.measure_lag()) + + +class EventLoopLagStats(BaseModel): + """Snapshot of event-loop lag statistics.""" + + min: float = 0.0 + mean: float = 0.0 + median: float = 0.0 + p90: float = 0.0 + p99: float = 0.0 + max: float = 0.0 + n: int = 0 + + @classmethod + def from_monitor(cls, monitor: EventLoopLagMonitor) -> "EventLoopLagStats": + n = len(monitor.lags) + if n == 0: + return cls(n=0) + arr = np.array(monitor.lags) + return cls( + min=float(arr.min()), + mean=float(arr.mean()), + median=float(np.median(arr)), + p90=float(np.percentile(arr, 90)), + p99=float(np.percentile(arr, 99)), + max=float(arr.max()), + n=n, + ) diff --git a/src/prime_rl/utils/client.py b/src/prime_rl/utils/client.py index b9ee8f4b9d..e8f4059196 100644 --- a/src/prime_rl/utils/client.py +++ b/src/prime_rl/utils/client.py @@ -8,16 +8,17 @@ from typing import Protocol, runtime_checkable import httpx -import verifiers as vf +import verifiers.v1 as vf from httpx import AsyncClient from openai import NotFoundError from renderers import RendererConfig from tenacity import retry, retry_if_exception, stop_after_attempt, stop_after_delay, wait_exponential +from verifiers.v1.clients.config import OpenAIClientConfig, RendererClientConfig from prime_rl.configs.shared import ClientConfig from prime_rl.utils.logger import get_logger -# Identity tuple used by ``select_train_client`` to key load counts. ``api_base_url`` +# Identity tuple used by ``select_train_client`` to key load counts. ``base_url`` # distinguishes servers; ``X-data-parallel-rank`` distinguishes DP shards within a # server, since the router uses that header to route to specific GPU ranks. ClientIdentity = tuple[str, str | None] @@ -25,7 +26,7 @@ def client_identity(client: vf.ClientConfig) -> ClientIdentity: """Stable identity for load balancing across inference clients.""" - return (client.api_base_url, client.extra_headers.get("X-data-parallel-rank")) + return (client.base_url, client.headers.get("X-data-parallel-rank")) @runtime_checkable @@ -185,42 +186,34 @@ def setup_clients( renderer_model_name: str | None = None, pool_size: int | None = None, ) -> list[vf.ClientConfig]: - clients = [] - client_idx = 0 - # Only forward the renderer config when the client actually uses a - # renderer — MITO/TITO clients ignore it. + """Build v1 client configs (one per base_url × DP rank). ``client_type`` + ``renderer`` → token-in/out (``RendererClientConfig``, with the renderer the env + server should use forwarded as a serialized config so it doesn't fall back to the + default renderer); otherwise plain chat-completions (``OpenAIClientConfig``).""" + is_renderer = client_type == "renderer" + config_cls = RendererClientConfig if is_renderer else OpenAIClientConfig renderer_extra: dict = {} - if client_type == "renderer": + if is_renderer: + # Pass the shared renderers.RendererConfig straight through (v1's + # RendererClientConfig.renderer is the same type; pydantic round-trips it + # over the wire). prime-rl and v1 share one renderer config. renderer_extra = { - "renderer_config": renderer_config, + "renderer": renderer_config, + "pool_size": pool_size or 1, "renderer_model_name": renderer_model_name, - "renderer_pool_size": pool_size, } env_headers = { k: v for k, v in ((k, os.getenv(v)) for k, v in client_config.headers_from_env.items()) if v is not None } + clients: list[vf.ClientConfig] = [] for base_url in client_config.base_url: for dp_rank in range(client_config.dp_rank_count): headers = {**client_config.headers, **env_headers} if client_config.dp_rank_count > 1: headers["X-data-parallel-rank"] = str(dp_rank) clients.append( - vf.ClientConfig( - client_idx=client_idx, - client_type=client_type, - api_base_url=base_url, - api_key_var=client_config.api_key_var, - timeout=client_config.timeout, - connect_timeout=client_config.connect_timeout, - max_connections=8192, - max_keepalive_connections=8192, - max_retries=10, - extra_headers=headers, - extra_headers_from_state=client_config.extra_headers_from_state, - **renderer_extra, - ) + config_cls(base_url=base_url, api_key_var=client_config.api_key_var, headers=headers, **renderer_extra) ) - client_idx += 1 return clients diff --git a/src/prime_rl/utils/elastic.py b/src/prime_rl/utils/elastic.py index 951b3673c1..f483b0c7b9 100644 --- a/src/prime_rl/utils/elastic.py +++ b/src/prime_rl/utils/elastic.py @@ -17,7 +17,7 @@ from typing import Literal import httpx -import verifiers as vf +import verifiers.v1 as vf from httpx import AsyncClient from renderers import RendererConfig diff --git a/src/prime_rl/utils/monitor/base.py b/src/prime_rl/utils/monitor/base.py index bc88080451..d79cbe3140 100644 --- a/src/prime_rl/utils/monitor/base.py +++ b/src/prime_rl/utils/monitor/base.py @@ -1,8 +1,9 @@ import random from abc import ABC, abstractmethod -from typing import Any +from typing import TYPE_CHECKING, Any -import verifiers as vf +if TYPE_CHECKING: + from prime_rl.orchestrator.types import Rollout def sample_items_for_logging(items: list[Any], sample_ratio: float | None) -> list[Any]: @@ -39,11 +40,11 @@ def log(self, metrics: dict[str, Any], step: int) -> None: pass @abstractmethod - def log_samples(self, rollouts: list[vf.RolloutOutput], step: int) -> None: + def log_samples(self, rollouts: list["Rollout"], step: int) -> None: pass @abstractmethod - def log_eval_samples(self, rollouts: list[vf.RolloutOutput], env_name: str, step: int) -> None: + def log_eval_samples(self, rollouts: list["Rollout"], env_name: str, step: int) -> None: pass @abstractmethod @@ -72,10 +73,10 @@ def log(self, metrics: dict[str, Any], step: int) -> None: else: self.history = [metrics] - def log_samples(self, rollouts: list[vf.RolloutOutput], step: int) -> None: + def log_samples(self, rollouts: list["Rollout"], step: int) -> None: pass - def log_eval_samples(self, rollouts: list[vf.RolloutOutput], env_name: str, step: int) -> None: + def log_eval_samples(self, rollouts: list["Rollout"], env_name: str, step: int) -> None: pass def save_final_summary(self, filename: str = "final_summary.json") -> None: diff --git a/src/prime_rl/utils/monitor/multi.py b/src/prime_rl/utils/monitor/multi.py index 0f32ea5906..620aee88c4 100644 --- a/src/prime_rl/utils/monitor/multi.py +++ b/src/prime_rl/utils/monitor/multi.py @@ -1,10 +1,11 @@ -from typing import Any - -import verifiers as vf +from typing import TYPE_CHECKING, Any from prime_rl.utils.logger import get_logger from prime_rl.utils.monitor.base import Monitor +if TYPE_CHECKING: + from prime_rl.orchestrator.types import Rollout + class MultiMonitor(Monitor): """Monitor that wraps multiple monitors and delegates calls to all of them.""" @@ -26,14 +27,14 @@ def log(self, metrics: dict[str, Any], step: int) -> None: except Exception as e: self.logger.warning(f"Failed to log metrics to {monitor.__class__.__name__}: {e}") - def log_samples(self, rollouts: list[vf.RolloutOutput], step: int) -> None: + def log_samples(self, rollouts: list["Rollout"], step: int) -> None: for monitor in self.monitors: try: monitor.log_samples(rollouts=rollouts, step=step) except Exception as e: self.logger.warning(f"Failed to log samples to {monitor.__class__.__name__}: {e}") - def log_eval_samples(self, rollouts: list[vf.RolloutOutput], env_name: str, step: int) -> None: + def log_eval_samples(self, rollouts: list["Rollout"], env_name: str, step: int) -> None: for monitor in self.monitors: try: monitor.log_eval_samples(rollouts=rollouts, env_name=env_name, step=step) diff --git a/src/prime_rl/utils/monitor/prime.py b/src/prime_rl/utils/monitor/prime.py index 657037c6f7..cce45278de 100644 --- a/src/prime_rl/utils/monitor/prime.py +++ b/src/prime_rl/utils/monitor/prime.py @@ -7,12 +7,11 @@ from datetime import datetime, timezone from pathlib import Path from threading import Thread -from typing import Any +from typing import TYPE_CHECKING, Any import httpx import pyarrow as pa import pyarrow.parquet as pq -import verifiers as vf from prime_cli.core.config import Config as PrimeConfig from transformers.tokenization_utils import PreTrainedTokenizer @@ -21,14 +20,8 @@ from prime_rl.utils.logger import get_logger from prime_rl.utils.monitor.base import Monitor, sample_items_for_logging - -def _json(val: Any) -> str: - """JSON-serialize dicts/lists, pass strings through, default to empty string for None.""" - if isinstance(val, str): - return val - if val is None: - return "" - return json.dumps(val) +if TYPE_CHECKING: + from prime_rl.orchestrator.types import Rollout _SAMPLE_SCHEMA = pa.schema( @@ -285,7 +278,7 @@ def log(self, metrics: dict[str, Any], step: int) -> None: }, ) - def log_samples(self, rollouts: list[vf.RolloutOutput], step: int) -> None: + def log_samples(self, rollouts: list["Rollout"], step: int) -> None: """Logs rollouts to Prime Intellect API using presigned URLs for direct R2 upload.""" if not self.is_master: return @@ -328,35 +321,35 @@ def log_samples(self, rollouts: list[vf.RolloutOutput], step: int) -> None: f"Initiated samples upload at step {step} to Prime Intellect API in {time.perf_counter() - start_time:.2f}s" ) - def _rollouts_to_parquet_bytes(self, rollouts: list[vf.RolloutOutput], step: int) -> bytes | None: - """Convert rollouts directly to Parquet bytes for upload.""" + def _rollouts_to_parquet_bytes(self, rollouts: list["Rollout"], step: int) -> bytes | None: + """Convert rollouts to Parquet bytes for upload. One row per rollout, built from the + message graph: the conversation is the unit (no prompt/completion split — meaningless in + a multi-turn branch), so `completion` carries the main (last) branch's full message list + and `trajectory` carries one message list per branch (`trace.branches`).""" now = datetime.now(timezone.utc) rows = [] for sample_id, rollout in enumerate(rollouts): - prompt = rollout.get("prompt") - completion = rollout.get("completion") - trajectory = rollout.get("trajectory") or [] - if prompt is None or completion is None or not trajectory: + branches = rollout.branches + if not branches: continue + main_messages = [m.model_dump(mode="json") for m in branches[-1].messages] - example_id = rollout.get("example_id") + task_idx = rollout.task.idx try: - problem_id = int(example_id) if example_id is not None else sample_id + problem_id = int(task_idx) if task_idx is not None else sample_id except (TypeError, ValueError): problem_id = sample_id trajectory_data = [ { - "prompt": ts["prompt"], - "completion": ts["completion"], - "reward": ts.get("reward"), - "advantage": ts.get("advantage"), - "extras": ts.get("extras", {}), - "num_input_tokens": len(ts["tokens"]["prompt_ids"]) if ts.get("tokens") else None, - "num_output_tokens": len(ts["tokens"]["completion_ids"]) if ts.get("tokens") else None, + "messages": [m.model_dump(mode="json") for m in branch.messages], + "reward": rollout.reward, + "advantage": rollout.advantage, + "num_input_tokens": branch.prompt_len, + "num_output_tokens": branch.completion_len, } - for ts in trajectory + for branch in branches ] rows.append( @@ -366,19 +359,19 @@ def _rollouts_to_parquet_bytes(self, rollouts: list[vf.RolloutOutput], step: int "tag": "", "problem_id": problem_id, "sample_id": sample_id, - "prompt": json.dumps(prompt), - "completion": json.dumps(completion), + "prompt": "", + "completion": json.dumps(main_messages), "trajectory": json.dumps(trajectory_data), - "answer": rollout.get("answer") or "", - "env_name": rollout.get("env_name") or "", - "task": rollout.get("task") or "", - "info": _json(rollout.get("info")), - "reward": rollout.get("reward"), - "advantage": rollout.get("advantage"), - "metrics": _json(rollout.get("metrics")), - "timing": _json(rollout.get("timing")), - "num_input_tokens": 0, - "num_output_tokens": 0, + "answer": "", + "env_name": rollout.env_name, + "task": json.dumps(rollout.task.model_dump(mode="json")), + "info": "", + "reward": rollout.reward, + "advantage": rollout.advantage, + "metrics": json.dumps(rollout.metrics), + "timing": json.dumps(rollout.timing.model_dump(mode="json")), + "num_input_tokens": branches[-1].prompt_len, + "num_output_tokens": branches[-1].completion_len, "created_at": now, } ) @@ -491,7 +484,7 @@ async def _confirm_samples_upload(self, step: int, s3_key: str, max_retries: int await asyncio.sleep(delay) return False - def log_eval_samples(self, rollouts: list[vf.RolloutOutput], env_name: str, step: int) -> None: + def log_eval_samples(self, rollouts: list["Rollout"], env_name: str, step: int) -> None: pass def log_distributions(self, distributions: dict[str, list[float]], step: int) -> None: diff --git a/src/prime_rl/utils/monitor/wandb.py b/src/prime_rl/utils/monitor/wandb.py index 7935d13a6a..c94584f1f3 100644 --- a/src/prime_rl/utils/monitor/wandb.py +++ b/src/prime_rl/utils/monitor/wandb.py @@ -3,20 +3,38 @@ import sys import time from pathlib import Path -from typing import Any +from typing import TYPE_CHECKING, Any -import verifiers as vf import wandb from transformers.tokenization_utils import PreTrainedTokenizer from wandb.errors import CommError from wandb.sdk.mailbox.mailbox_handle import ServerResponseError from prime_rl.configs.shared import WandbConfig, WandbWithExtrasConfig -from prime_rl.utils.chat_template import deserialize_tool_calls from prime_rl.utils.config import BaseConfig from prime_rl.utils.logger import get_logger from prime_rl.utils.monitor.base import Monitor, sample_items_for_logging +if TYPE_CHECKING: + from prime_rl.orchestrator.types import Rollout + + +def _loggable_task(task) -> str: + """A Table-safe JSON string of the task for sample logging. Image content parts are elided to + a short placeholder — their base64 data bloats the table and breaks wandb Table's nested-type + inference on the variable-length content list (a plain dict would otherwise crash on it).""" + + def elide(obj): + if isinstance(obj, dict): + if obj.get("type") == "image_url": + return {"type": "image_url", "image_url": ""} + return {k: elide(v) for k, v in obj.items()} + if isinstance(obj, list): + return [elide(v) for v in obj] + return obj + + return json.dumps(elide(task.model_dump(mode="json"))) + class WandbMonitor(Monitor): """Logs to Weights and Biases.""" @@ -113,13 +131,13 @@ def init_wandb(max_retries: int): if config is not None and isinstance(config, WandbWithExtrasConfig) and config.log_extras: if config.log_extras.samples: self.last_log_samples_step = -1 - self.samples_cols = ["step", "env_name", "task", "example_id", "messages", "input_ids", "reward"] + self.samples_cols = ["step", "env_name", "task", "task_idx", "messages", "input_ids", "reward"] self.samples_table = wandb.Table( columns=self.samples_cols, log_mode="INCREMENTAL", ) self.tokenizer = tokenizer - self.eval_samples_cols = ["step", "env", "task", "example_id", "completion", "reward"] + self.eval_samples_cols = ["step", "env", "task", "task_idx", "completion", "reward"] self.eval_samples_table = wandb.Table( columns=self.eval_samples_cols, log_mode="INCREMENTAL", @@ -143,7 +161,7 @@ def log(self, metrics: dict[str, Any], step: int) -> None: return wandb.log({**metrics, "step": step}) - def log_samples(self, rollouts: list[vf.RolloutOutput], step: int) -> None: + def log_samples(self, rollouts: list["Rollout"], step: int) -> None: """Logs rollouts to W&B table.""" if not self.is_master: return @@ -172,32 +190,30 @@ def log_samples(self, rollouts: list[vf.RolloutOutput], step: int) -> None: start_time = time.perf_counter() for rollout in rollouts: - trajectory = rollout["trajectory"] - if not trajectory: - continue - last_step = trajectory[-1] - tokens = last_step["tokens"] - full_ids = tokens["prompt_ids"] + tokens["completion_ids"] - messages_text = self.tokenizer.decode(full_ids) - sample = { - "step": step, - "env_name": rollout.get("env_name"), - "task": rollout.get("task"), - "example_id": rollout["example_id"], - "messages": messages_text, - "input_ids": str(full_ids), - "reward": rollout["reward"], - } - assert list(sample.keys()) == self.samples_cols, ( - "Order of columns in the table must be the same as order of the keys here" - ) - self.samples_table.add_data(*sample.values()) + trace = rollout + for branch in trace.branches: + token_ids = branch.token_ids + if not token_ids: + continue + sample = { + "step": step, + "env_name": rollout.env_name, + "task": _loggable_task(trace.task), + "task_idx": trace.task.idx, + "messages": self.tokenizer.decode(token_ids), + "input_ids": str(token_ids), + "reward": trace.reward, + } + assert list(sample.keys()) == self.samples_cols, ( + "Order of columns in the table must be the same as order of the keys here" + ) + self.samples_table.add_data(*sample.values()) wandb.log({"samples": self.samples_table, "step": step}) self.last_log_samples_step = step self.logger.debug(f"Logged samples at step {step} to W&B table in {time.perf_counter() - start_time:.2f}s") - def log_eval_samples(self, rollouts: list[vf.RolloutOutput], env_name: str, step: int) -> None: + def log_eval_samples(self, rollouts: list["Rollout"], env_name: str, step: int) -> None: """Logs eval rollouts to a separate W&B table.""" if not self.is_master: return @@ -210,23 +226,22 @@ def log_eval_samples(self, rollouts: list[vf.RolloutOutput], env_name: str, step return for rollout in rollouts: - completion = rollout.get("completion") - if not completion: - continue - if isinstance(completion, list): - try: - completion = self.tokenizer.apply_chat_template(deserialize_tool_calls(completion), tokenize=False) - except Exception: - completion = str(completion) - sample = { - "step": step, - "env": env_name, - "task": rollout.get("task"), - "example_id": rollout["example_id"], - "completion": completion, - "reward": rollout["reward"], - } - self.eval_samples_table.add_data(*sample.values()) + trace = rollout + for branch in trace.branches: + # Eval runs the openai client (no token ids), so show the assistant message + # content rather than decoded tokens. + completion = "".join(m.content or "" for m in branch.messages if m.role == "assistant") + if not completion: + continue + sample = { + "step": step, + "env": env_name, + "task": _loggable_task(trace.task), + "task_idx": trace.task.idx, + "completion": completion, + "reward": trace.reward, + } + self.eval_samples_table.add_data(*sample.values()) wandb.log({"eval/samples": self.eval_samples_table, "step": step}) diff --git a/tests/unit/orchestrator/test_advantage.py b/tests/unit/orchestrator/test_advantage.py index 89022c8428..a401bd3960 100644 --- a/tests/unit/orchestrator/test_advantage.py +++ b/tests/unit/orchestrator/test_advantage.py @@ -2,6 +2,7 @@ import uuid import pytest +import verifiers.v1 as vf from prime_rl.configs.orchestrator import ( CustomAdvantageConfig, @@ -16,7 +17,7 @@ default_advantage_fn, setup_advantage_fn, ) -from prime_rl.orchestrator.types import TrainRollout +from prime_rl.orchestrator.types import Rollout def _make_rollout( @@ -24,23 +25,27 @@ def _make_rollout( completion_len: int = 0, num_turns: int = 1, env_name: str = "test", - example_id: int = 0, -) -> dict: - """Create a minimal rollout dict for advantage testing. - - `completion_len` tokens are split across `num_turns` trajectory steps. - """ + tool_response_len: int = 0, +) -> Rollout: + """Build a ``Rollout`` (message-graph trace) for advantage testing: ``reward`` via the + reward dict, ``completion_len`` sampled tokens split across ``num_turns`` assistant nodes, + and an optional tool-response token count surfaced as a metric.""" per_turn, rem = divmod(completion_len, max(num_turns, 1)) - trajectory = [ - {"tokens": {"prompt_ids": [0], "completion_ids": list(range(per_turn + (rem if i == 0 else 0)))}} + nodes = [ + vf.MessageNode( + message=vf.AssistantMessage(content="x"), + token_ids=list(range(n := per_turn + (rem if i == 0 else 0))), + mask=[True] * n, + logprobs=[0.0] * n, + ) for i in range(num_turns) ] - return { - "reward": reward, - "trajectory": trajectory, - "env_name": env_name, - "example_id": example_id, - } + metrics = {"rlm_total_tool_response_tokens": tool_response_len} if tool_response_len else {} + rollout = Rollout[vf.Task]( + task=vf.Task(idx=0, instruction=""), nodes=nodes, rewards={"reward": reward}, metrics=metrics + ) + rollout.env_name = env_name + return rollout def _make_group(rewards, completion_lengths=None, num_turns=None) -> AdvantageInputs: @@ -83,7 +88,7 @@ def test_efficiency_mixed_group(): # All correct rollouts have positive advantage for rollout, adv in zip(inputs.rollouts, result.advantages): - if rollout["reward"] >= 1.0: + if rollout.reward >= 1.0: assert adv > 0 @@ -164,21 +169,9 @@ def test_efficiency_amplification_bounded(): def test_efficiency_tokens_with_tool_response_weight(): """`tool_response_weight` shifts shaping onto tool-response tokens read from rollout metrics.""" rollouts = [ - { - "reward": 1.0, - "trajectory": [{"tokens": {"prompt_ids": [0], "completion_ids": list(range(10))}}], - "metrics": {"rlm_total_tool_response_tokens": 200}, - }, - { - "reward": 1.0, - "trajectory": [{"tokens": {"prompt_ids": [0], "completion_ids": list(range(10))}}], - "metrics": {"rlm_total_tool_response_tokens": 0}, - }, - { - "reward": 1.0, - "trajectory": [{"tokens": {"prompt_ids": [0], "completion_ids": list(range(10))}}], - "metrics": {"rlm_total_tool_response_tokens": 100}, - }, + _make_rollout(1.0, completion_len=10, tool_response_len=200), + _make_rollout(1.0, completion_len=10, tool_response_len=0), + _make_rollout(1.0, completion_len=10, tool_response_len=100), ] inputs = AdvantageInputs(rollouts=rollouts) @@ -198,12 +191,10 @@ def test_efficiency_tokens_with_tool_response_weight(): def test_efficiency_fractional_weight_with_int_rewards(): """Fractional weights must not truncate when rollout rewards are emitted as ints.""" - rollouts_int = [ - {"reward": 1, "trajectory": [{"tokens": {"prompt_ids": [0], "completion_ids": list(range(7))}}]}, - {"reward": 1, "trajectory": [{"tokens": {"prompt_ids": [0], "completion_ids": list(range(11))}}]}, - {"reward": 0, "trajectory": [{"tokens": {"prompt_ids": [0], "completion_ids": list(range(13))}}]}, - ] - rollouts_float = [{**r, "reward": float(r["reward"])} for r in rollouts_int] + lens = [7, 11, 13] + int_rewards = [1, 1, 0] + rollouts_int = [_make_rollout(r, completion_len=n) for r, n in zip(int_rewards, lens)] + rollouts_float = [_make_rollout(float(r), completion_len=n) for r, n in zip(int_rewards, lens)] fractional = TokensLengthPenaltyConfig(completion_weight=0.3, tool_response_weight=0.0) int_result = default_advantage_fn(AdvantageInputs(rollouts=rollouts_int), length_penalty=fractional) @@ -214,12 +205,7 @@ def test_efficiency_fractional_weight_with_int_rewards(): def test_efficiency_zero_costs_falls_back_to_plain_grpo(): """When all effective costs are zero, shaping is a no-op (no NaNs from div-by-zero).""" # tool-only weights but no harness metric → all costs == 0 - rollouts = [ - {"reward": 1.0, "trajectory": [{"tokens": {"prompt_ids": [0], "completion_ids": list(range(10))}}]}, - {"reward": 1.0, "trajectory": [{"tokens": {"prompt_ids": [0], "completion_ids": list(range(10))}}]}, - {"reward": 0.0, "trajectory": [{"tokens": {"prompt_ids": [0], "completion_ids": list(range(10))}}]}, - ] - inputs = AdvantageInputs(rollouts=rollouts) + inputs = _make_group(rewards=[1.0, 1.0, 0.0], completion_lengths=[10, 10, 10]) result = default_advantage_fn(inputs, length_penalty=_TOKENS_TOOL_ONLY) expected = default_advantage_fn(inputs) # plain GRPO assert not any(math.isnan(a) for a in result.advantages) @@ -249,22 +235,17 @@ def test_efficiency_turns_penalty(): assert result.advantages == pytest.approx([0.625, 0.125, -0.875, 0.125], abs=1e-6) -def _train_rollouts(rewards: list[float]) -> list[TrainRollout]: - """Wrap a list of rewards into ``TrainRollout``\\ s sharing a single - ``group_id`` — ``assign_advantages`` works on one group at a time - (the sink groups by ``group_id`` upstream).""" +def _train_rollouts(rewards: list[float]) -> list[Rollout]: + """Build ``Rollout``\\ s sharing a single ``group_id`` — ``assign_advantages`` works on one + group at a time (the sink groups by ``group_id`` upstream).""" gid = uuid.uuid4() - return [ - TrainRollout( - raw={"reward": r, "trajectory": []}, - env_name="test", - example_id=0, - group_id=gid, - policy_version=0, - off_policy_steps=0, - ) - for r in rewards - ] + rollouts = [] + for r in rewards: + rollout = Rollout[vf.Task](task=vf.Task(idx=0, instruction=""), rewards={"reward": r}) + rollout.env_name = "test" + rollout.group_id = gid + rollouts.append(rollout) + return rollouts def test_assign_advantages_writes_field(): @@ -306,4 +287,4 @@ def test_setup_advantage_fn_with_custom_config(): def _dummy_custom_advantage(inputs: AdvantageInputs, scale: float = 1.0) -> AdvantageOutputs: """A simple custom advantage for testing.""" - return AdvantageOutputs(advantages=[r["reward"] * scale for r in inputs.rollouts]) + return AdvantageOutputs(advantages=[r.reward * scale for r in inputs.rollouts]) diff --git a/tests/unit/orchestrator/test_batch.py b/tests/unit/orchestrator/test_batch.py index 7531423c72..6f8d6adcc5 100644 --- a/tests/unit/orchestrator/test_batch.py +++ b/tests/unit/orchestrator/test_batch.py @@ -2,7 +2,7 @@ import pytest from prime_rl.trainer.batch import prepare_batch, prepare_sample -from prime_rl.transport.types import RoutedExperts, TrainingSample +from prime_rl.transport.types import EncodedTensor, RoutedExperts, TrainingSample def _routed_experts(data, dtype=np.uint8): @@ -22,12 +22,10 @@ def _make_training_example( env_name: str = "test-env", ) -> TrainingSample: return TrainingSample( - prompt_ids=[1, 2], - prompt_mask=[False, False], - completion_ids=[3, 4], - completion_mask=[True, True], - completion_logprobs=[-0.1, -0.2], - completion_temperatures=[temperature, temperature], # Per-token temperatures + token_ids=[1, 2, 3, 4], + mask=[False, False, True, True], + logprobs=[0.0, 0.0, -0.1, -0.2], + temperatures=[temperature, temperature, temperature, temperature], teacher_logprobs=[0.0, 0.0, 0.0, 0.0], advantage=1.0, env_name=env_name, @@ -40,12 +38,10 @@ def _make_training_example( def test_training_sample_requires_env_name(): with pytest.raises(TypeError, match="env_name"): TrainingSample( - prompt_ids=[1, 2], - prompt_mask=[False, False], - completion_ids=[3, 4], - completion_mask=[True, True], - completion_logprobs=[-0.1, -0.2], - completion_temperatures=[1.0, 1.0], + token_ids=[1, 2, 3, 4], + mask=[False, False, True, True], + logprobs=[0.0, 0.0, -0.1, -0.2], + temperatures=[1.0, 1.0, 1.0, 1.0], advantage=1.0, ) @@ -140,12 +136,10 @@ def test_prepare_sample_with_routed_experts(): routed_experts = [[[0, 1], [2, 3]], [[4, 5], [6, 7]], [[0, 2], [1, 3]], [[1, 0], [3, 2]]] routed_payload = _routed_experts(routed_experts) sample = TrainingSample( - prompt_ids=[1, 2], - prompt_mask=[False, False], - completion_ids=[3, 4], - completion_mask=[True, True], - completion_logprobs=[-0.1, -0.2], - completion_temperatures=[1.0, 1.0], + token_ids=[1, 2, 3, 4], + mask=[False, False, True, True], + logprobs=[0.0, 0.0, -0.1, -0.2], + temperatures=[1.0, 1.0, 1.0, 1.0], advantage=1.0, env_name="test-env", routed_experts=routed_payload, @@ -162,12 +156,10 @@ def test_prepare_sample_truncates_routed_experts(): routed_payload = _routed_experts(routed_experts) expected_payload = _routed_experts(routed_experts[:3]) sample = TrainingSample( - prompt_ids=[1, 2], - prompt_mask=[False, False], - completion_ids=[3, 4], - completion_mask=[True, True], - completion_logprobs=[-0.1, -0.2], - completion_temperatures=[1.0, 1.0], + token_ids=[1, 2, 3, 4], + mask=[False, False, True, True], + logprobs=[0.0, 0.0, -0.1, -0.2], + temperatures=[1.0, 1.0, 1.0, 1.0], advantage=1.0, env_name="test-env", routed_experts=routed_payload, @@ -179,15 +171,50 @@ def test_prepare_sample_truncates_routed_experts(): assert micro_batch.env_names == ["test-env"] * 3 +def _encoded(arr) -> EncodedTensor: + a = np.asarray(arr) + return EncodedTensor(data=a.tobytes(), shape=list(a.shape), dtype=str(a.dtype)) + + +def test_prepare_sample_truncates_mm_at_image_boundary(): + """Truncation never splits an image's placeholder block: it cuts to a whole-image boundary + and slices mm_kwargs to match, so image-token count stays == image-embedding count.""" + # Two 2-token images (patches-per-token = 1): image-pad at indices 1,2 (img0) and 4,5 (img1). + mm_token_type_ids = [0, 1, 1, 0, 1, 1, 0] + pixel_values = np.array([[1.0], [1.0], [2.0], [2.0]], dtype=np.float32) # img0=1.0, img1=2.0 + grid = np.array([[1, 2, 1], [1, 2, 1]], dtype=np.int64) + sample = TrainingSample( + token_ids=[10, 11, 12, 13, 14, 15, 16], + mask=[False, False, False, False, False, True, True], + logprobs=[0.0] * 7, + temperatures=[1.0] * 7, + advantage=1.0, + env_name="test-env", + mm_token_type_ids=mm_token_type_ids, + mm_kwargs={"pixel_values": _encoded(pixel_values), "image_grid_thw": _encoded(grid)}, + ) + + # seq_len=5 falls inside img1 (one of its two placeholders survives) -> drop img1 entirely. + mb = prepare_sample(sample, seq_len=5) + assert len(mb.input_ids) == 4 # cut back to img1's first placeholder (index 4) + assert len(mb.mm_token_type_ids) == len(mb.input_ids) + n_placeholders = sum(1 for t in mb.mm_token_type_ids if t) + assert n_placeholders == 2 # only img0's two placeholders remain + # No mismatch: placeholders == image embeddings, and only img0's pixels are kept. + assert mb.mm_kwargs["pixel_values"].shape == [2, 1] + assert mb.mm_kwargs["image_grid_thw"].shape == [1, 3] + kept = np.frombuffer(bytearray(mb.mm_kwargs["pixel_values"].data), dtype=np.float32) + assert kept.tolist() == [1.0, 1.0] + assert n_placeholders == mb.mm_kwargs["pixel_values"].shape[0] # ppt == 1 here + + def test_prepare_sample_none_routed_experts(): """When routed_experts is None, micro_batch.routed_experts is None.""" sample = TrainingSample( - prompt_ids=[1, 2], - prompt_mask=[False, False], - completion_ids=[3, 4], - completion_mask=[True, True], - completion_logprobs=[-0.1, -0.2], - completion_temperatures=[1.0, 1.0], + token_ids=[1, 2, 3, 4], + mask=[False, False, True, True], + logprobs=[0.0, 0.0, -0.1, -0.2], + temperatures=[1.0, 1.0, 1.0, 1.0], advantage=1.0, env_name="test-env", ) diff --git a/tests/unit/orchestrator/test_filters.py b/tests/unit/orchestrator/test_filters.py index 2643bf71bb..baa3921ad1 100644 --- a/tests/unit/orchestrator/test_filters.py +++ b/tests/unit/orchestrator/test_filters.py @@ -1,6 +1,8 @@ import math import uuid +import verifiers.v1 as vf + from prime_rl.configs.orchestrator import GibberishFilterConfig, RepetitionFilterConfig from prime_rl.orchestrator.filters import ( GibberishFilter, @@ -9,7 +11,18 @@ setup_filter, setup_filters, ) -from prime_rl.orchestrator.types import TrainRollout +from prime_rl.orchestrator.types import Rollout + + +def _assistant_node(token_ids: list[int], logprobs: list[float]) -> vf.MessageNode: + """An assistant node whose tokens are all model-sampled (the filters read each node's + masked-True tokens + logprobs).""" + return vf.MessageNode( + message=vf.AssistantMessage(content="x"), + token_ids=token_ids, + mask=[True] * len(token_ids), + logprobs=logprobs, + ) def _make_rollout( @@ -18,52 +31,21 @@ def _make_rollout( *, reward: float = 1.0, multi_step: bool = False, -) -> TrainRollout: - """Build a ``TrainRollout`` with a minimal ``vf.RolloutOutput``-shaped - raw payload — enough for the filters to inspect ``trajectory`` / - ``stop_condition`` / etc.""" +) -> Rollout: + """Build a ``Rollout`` (a message-graph trace) carrying the completion tokens — enough for + the filters to inspect each node's sampled tokens / logprobs.""" if multi_step: mid = len(completion_ids) // 2 - trajectory = [ - { - "tokens": { - "completion_ids": completion_ids[:mid], - "completion_logprobs": completion_logprobs[:mid], - "completion_mask": [1] * mid, - } - }, - { - "tokens": { - "completion_ids": completion_ids[mid:], - "completion_logprobs": completion_logprobs[mid:], - "completion_mask": [1] * (len(completion_ids) - mid), - } - }, + nodes = [ + _assistant_node(completion_ids[:mid], completion_logprobs[:mid]), + _assistant_node(completion_ids[mid:], completion_logprobs[mid:]), ] else: - trajectory = [ - { - "tokens": { - "completion_ids": completion_ids, - "completion_logprobs": completion_logprobs, - "completion_mask": [1] * len(completion_ids), - } - } - ] - raw = { - "trajectory": trajectory, - "reward": reward, - "stop_condition": None, - "metrics": {}, - } - return TrainRollout( - raw=raw, - env_name="test", - example_id=0, - group_id=uuid.uuid4(), - policy_version=0, - off_policy_steps=0, - ) + nodes = [_assistant_node(completion_ids, completion_logprobs)] + rollout = Rollout[vf.Task](task=vf.Task(idx=0, instruction=""), nodes=nodes, rewards={"reward": reward}) + rollout.env_name = "test" + rollout.group_id = uuid.uuid4() + return rollout def _make_gibberish_filter(vocab_size=128_000, token_id_threshold=100_000, logprob_offset=2.0, enforce=False): @@ -248,9 +230,9 @@ def test_apply_filters_enforced_flags_rollout(): apply_filters([gibberish_filter], [rollout]) assert rollout.reward == 1.0 - assert rollout.raw["trajectory"][0]["tokens"]["completion_ids"] == [120_000] - assert rollout.raw["trajectory"][0]["tokens"]["completion_mask"] == [1] - assert rollout.raw["stop_condition"] is None + assert rollout.nodes[0].token_ids == [120_000] + assert rollout.nodes[0].mask == [True] + assert rollout.stop_condition is None assert rollout.filter_results == {"gibberish": True} assert rollout.is_filtered is True @@ -267,9 +249,9 @@ def test_apply_filters_preserves_clean_rollouts(): apply_filters([gibberish_filter], [rollout]) assert rollout.reward == 1.0 - assert rollout.raw["trajectory"][0]["tokens"]["completion_ids"] == [50, 60, 70] - assert all(m == 1 for m in rollout.raw["trajectory"][0]["tokens"]["completion_mask"]) - assert rollout.raw["stop_condition"] is None + assert rollout.nodes[0].token_ids == [50, 60, 70] + assert all(rollout.nodes[0].mask) + assert rollout.stop_condition is None assert rollout.filter_results == {"gibberish": False} assert rollout.is_filtered is False @@ -286,7 +268,7 @@ def test_apply_filters_first_filter_wins(): apply_filters([gibberish_filter, repetition_filter], [rollout]) - assert rollout.raw["stop_condition"] is None + assert rollout.stop_condition is None assert rollout.filter_results == {"gibberish": True, "repetition": False} assert rollout.is_filtered is True @@ -329,13 +311,13 @@ def test_apply_filters_enforced_preserves_rollout_tokens(): apply_filters([gibberish_filter], [rollout]) - assert rollout.raw["trajectory"][0]["tokens"]["completion_ids"] == [10, 120_000, 30] - assert rollout.raw["trajectory"][0]["tokens"]["completion_logprobs"] == [ + assert rollout.nodes[0].token_ids == [10, 120_000, 30] + assert rollout.nodes[0].logprobs == [ -1.0, gibberish_filter.logprob_threshold - 1.0, -0.5, ] - assert rollout.raw["trajectory"][0]["tokens"]["completion_mask"] == [1, 1, 1] + assert rollout.nodes[0].mask == [True, True, True] assert rollout.is_filtered is True @@ -347,11 +329,11 @@ def test_apply_filters_preserves_existing_stop_condition(): completion_logprobs=[gibberish_filter.logprob_threshold - 1.0], reward=1.0, ) - rollout.raw["stop_condition"] = "generation_truncated" + rollout.stop_condition = "generation_truncated" apply_filters([gibberish_filter], [rollout]) - assert rollout.raw["stop_condition"] == "generation_truncated" + assert rollout.stop_condition == "generation_truncated" assert rollout.is_filtered is True @@ -370,8 +352,8 @@ def test_apply_filters_monitor_only_tracks_detection(): apply_filters([gibberish_filter], [rollout]) assert rollout.reward == 1.0 - assert all(m == 1 for m in rollout.raw["trajectory"][0]["tokens"]["completion_mask"]) - assert rollout.raw["stop_condition"] is None + assert all(rollout.nodes[0].mask) + assert rollout.stop_condition is None assert rollout.filter_results == {"gibberish": True} assert rollout.is_filtered is False diff --git a/tests/unit/orchestrator/test_orchestrator_setup.py b/tests/unit/orchestrator/test_orchestrator_setup.py index 2372b004fd..07ac09500c 100644 --- a/tests/unit/orchestrator/test_orchestrator_setup.py +++ b/tests/unit/orchestrator/test_orchestrator_setup.py @@ -48,44 +48,3 @@ async def run() -> None: ) asyncio.run(run()) - - -def test_setup_student_inference_pool_defaults_to_mito(): - """No renderer -> plain MITO chat completions.""" - - async def run() -> None: - tokenizer = object() - config = SimpleNamespace( - training_mode="rl", - renderer=None, - pool_size=None, - student=SimpleNamespace( - client=SimpleNamespace(base_url=["http://localhost:8000/v1"]), - model=SimpleNamespace(name="student-model"), - ), - ) - inference_pool = object() - - with ( - patch("renderers.base.create_renderer") as create_renderer_mock, - patch( - "prime_rl.orchestrator.utils.setup_inference_pool", - new=AsyncMock(return_value=inference_pool), - ) as setup_pool_mock, - ): - renderer, returned_pool = await setup_student_inference_pool( - config=config, - tokenizer=tokenizer, - ) - - assert renderer is None - assert returned_pool is inference_pool - create_renderer_mock.assert_not_called() - setup_pool_mock.assert_awaited_once_with( - config.student.client, - model_name="student-model", - train_client_type="openai_chat_completions", - eval_client_type="openai_chat_completions", - ) - - asyncio.run(run()) diff --git a/tests/unit/orchestrator/test_sft_trajectories.py b/tests/unit/orchestrator/test_sft_trajectories.py deleted file mode 100644 index 8252c6cff5..0000000000 --- a/tests/unit/orchestrator/test_sft_trajectories.py +++ /dev/null @@ -1,141 +0,0 @@ -from unittest.mock import MagicMock - -import verifiers as vf - -from prime_rl.orchestrator.trajectories import backfill_rollout_tokens, interleave_rollout - - -class SimpleChatTokenizer: - def __init__(self): - self._tok2id: dict[str, int] = {} - self._next_id = 1 - - def _id(self, token: str) -> int: - if token not in self._tok2id: - self._tok2id[token] = self._next_id - self._next_id += 1 - return self._tok2id[token] - - def apply_chat_template(self, messages, add_generation_prompt=False, return_dict=False): - del return_dict - ids = [] - for message in messages: - role = message.get("role", "unknown") - ids.append(self._id(f"<|{role}|>")) - content = message.get("content", "") - if isinstance(content, str): - if content: - ids.append(self._id(content)) - else: - ids.append(self._id(str(content))) - if add_generation_prompt: - ids.append(self._id("<|assistant|>")) - return ids - - -def test_interleave_rollout_missing_tokens_returns_none(): - output = vf.RolloutOutput( - example_id=1, - trajectory=[ - vf.TrajectoryStep( - prompt=[{"role": "user", "content": "U1"}], - completion=[{"role": "assistant", "content": "A1"}], - response=MagicMock(), - tokens=None, - reward=None, - advantage=None, - is_truncated=False, - trajectory_id="1", - extras={}, - ) - ], - sampling_args={"temperature": 1.0}, - error=None, - ) - - assert interleave_rollout(output) is None - - -def test_backfill_rollout_tokens_for_sft(): - tokenizer = SimpleChatTokenizer() - output = vf.RolloutOutput( - example_id=42, - env_name="test-env", - trajectory=[ - vf.TrajectoryStep( - prompt=[{"role": "user", "content": "U1"}], - completion=[{"role": "assistant", "content": "A1"}], - response=MagicMock(), - tokens=None, - reward=None, - advantage=None, - is_truncated=False, - trajectory_id="1", - extras={}, - ), - vf.TrajectoryStep( - prompt=[ - {"role": "user", "content": "U1"}, - {"role": "assistant", "content": "A1"}, - {"role": "user", "content": "U2"}, - ], - completion=[{"role": "assistant", "content": "A2"}], - response=MagicMock(), - tokens=None, - reward=None, - advantage=None, - is_truncated=False, - trajectory_id="1", - extras={}, - ), - ], - sampling_args={"temperature": 1.0}, - error=None, - ) - - backfill_rollout_tokens(output, tokenizer) - - rollouts = interleave_rollout(output) - assert rollouts is not None - assert len(rollouts) == 1 - - rollout = rollouts[0] - step1_prompt_ids = tokenizer.apply_chat_template( - [{"role": "user", "content": "U1"}], - add_generation_prompt=True, - ) - step1_full_ids = tokenizer.apply_chat_template( - [{"role": "user", "content": "U1"}, {"role": "assistant", "content": "A1"}], - add_generation_prompt=False, - ) - step2_prompt_ids = tokenizer.apply_chat_template( - [ - {"role": "user", "content": "U1"}, - {"role": "assistant", "content": "A1"}, - {"role": "user", "content": "U2"}, - ], - add_generation_prompt=True, - ) - step2_full_ids = tokenizer.apply_chat_template( - [ - {"role": "user", "content": "U1"}, - {"role": "assistant", "content": "A1"}, - {"role": "user", "content": "U2"}, - {"role": "assistant", "content": "A2"}, - ], - add_generation_prompt=False, - ) - - prefix_len_1 = len(step1_prompt_ids) - prefix_len_2 = len(step2_prompt_ids) - step1_completion_ids = step1_full_ids[prefix_len_1:] - step2_completion_ids = step2_full_ids[prefix_len_2:] - step1_prefix = step1_prompt_ids + step1_completion_ids - step2_new_prompt_ids = step2_prompt_ids[len(step1_prefix) :] - - assert rollout.prompt_ids == step1_prompt_ids - assert rollout.completion_ids == step1_completion_ids + step2_new_prompt_ids + step2_completion_ids - assert rollout.completion_mask == ( - [True] * len(step1_completion_ids) + [False] * len(step2_new_prompt_ids) + [True] * len(step2_completion_ids) - ) - assert rollout.completion_logprobs == [0.0] * len(rollout.completion_ids) diff --git a/tests/unit/orchestrator/test_teacher_logprobs.py b/tests/unit/orchestrator/test_teacher_logprobs.py index d63fdce792..3d2fa95f30 100644 --- a/tests/unit/orchestrator/test_teacher_logprobs.py +++ b/tests/unit/orchestrator/test_teacher_logprobs.py @@ -2,7 +2,8 @@ import json import httpx -import verifiers as vf +import openai +from verifiers.v1.clients.config import OpenAIClientConfig from prime_rl.orchestrator import utils as orchestrator_utils from prime_rl.transport import TrainingSample @@ -41,20 +42,19 @@ async def _run(): "kv_transfer_params": None, } ) - monkeypatch.setattr(orchestrator_utils, "setup_openai_client", lambda _: fake_client) + # compute_teacher_logprobs constructs AsyncOpenAI directly; hand back the fake. + monkeypatch.setattr(openai, "AsyncOpenAI", lambda **kwargs: fake_client) sample = TrainingSample( - prompt_ids=[1], - prompt_mask=[True], - completion_ids=[2, 3], - completion_mask=[True, True], - completion_logprobs=[-0.1, -0.2], - completion_temperatures=[1.0, 1.0], + token_ids=[1, 2, 3], + mask=[False, True, True], + logprobs=[0.0, -0.1, -0.2], + temperatures=[1.0, 1.0, 1.0], env_name="test-env", ) result = await orchestrator_utils.compute_teacher_logprobs( - clients=[vf.ClientConfig()], + clients=[OpenAIClientConfig(base_url="http://fake-host:8000/v1")], model_name="teacher-model", samples=[sample], ) diff --git a/tests/unit/orchestrator/test_trajectories.py b/tests/unit/orchestrator/test_trajectories.py deleted file mode 100644 index bda129cd43..0000000000 --- a/tests/unit/orchestrator/test_trajectories.py +++ /dev/null @@ -1,1380 +0,0 @@ -from unittest.mock import MagicMock - -import numpy as np -import pybase64 -import pytest -import verifiers as vf - -from prime_rl.orchestrator.trajectories import ( - _deserialize_tool_calls, - align_routed_experts, - interleave_rollout, -) - -_interleave_rollout = interleave_rollout - - -def interleave_rollout(output, *args, **kwargs): - kwargs.setdefault("env_name", output.get("env_name", "test-env")) - return _interleave_rollout(output, *args, **kwargs) - - -def _decode_mm_pixels(sample) -> list: - """Decode ``sample.mm_kwargs['pixel_values']`` to a nested list.""" - p = sample.mm_kwargs["pixel_values"] - return np.frombuffer(p.data, dtype=np.dtype(p.dtype)).reshape(p.shape).tolist() - - -def _decode_mm_thw(sample) -> list: - """Decode ``sample.mm_kwargs['image_grid_thw']`` to a nested list.""" - g = sample.mm_kwargs["image_grid_thw"] - return np.frombuffer(g.data, dtype=np.dtype(g.dtype)).reshape(g.shape).tolist() - - -def _routed_experts_payload(data, start: int = 0) -> dict: - arr = np.asarray(data, dtype=np.uint8) - return { - "data": pybase64.b64encode(memoryview(np.ascontiguousarray(arr))).decode("ascii"), - "shape": list(arr.shape), - "start": start, - } - - -def _sample_routed_experts(sample) -> np.ndarray: - assert sample.routed_experts is not None - return np.frombuffer(sample.routed_experts.data, dtype=np.dtype(sample.routed_experts.dtype)).reshape( - sample.routed_experts.shape - ) - - -def test_deserialize_tool_calls_does_not_inject_missing_key(): - messages = [{"role": "assistant", "content": "hello"}] - - deserialized = _deserialize_tool_calls(messages) - - assert "tool_calls" not in deserialized[0] - - -def test_deserialize_tool_calls_parses_arguments_when_present(): - messages = [ - { - "role": "assistant", - "tool_calls": [ - { - "id": "1", - "type": "function", - "function": {"name": "lookup", "arguments": '{"x": 1}'}, - } - ], - } - ] - - deserialized = _deserialize_tool_calls(messages) - - assert deserialized[0]["tool_calls"][0]["function"]["arguments"] == {"x": 1} - - -@pytest.fixture -def single_step_trajectory_output(): - output = vf.RolloutOutput( - example_id=0, - trajectory=[ - vf.TrajectoryStep( - prompt=[{"role": "user", "content": "U1"}], - completion=[{"role": "assistant", "content": "A1"}], - response=MagicMock(), - tokens=vf.TrajectoryStepTokens( - prompt_ids=[1, 2], - prompt_mask=[0, 0], - completion_ids=[3, 4], - completion_mask=[1, 1], - completion_logprobs=[-0.1, -0.2], - overlong_prompt=False, - is_truncated=False, - ), - reward=None, - advantage=None, - is_truncated=False, - trajectory_id="1", - extras={}, - ) - ], - sampling_args={"temperature": 1.0}, - error=None, - ) - return output - - -@pytest.fixture -def multi_step_trajectory_output(): - output = vf.RolloutOutput( - example_id=0, - trajectory=[ - vf.TrajectoryStep( - prompt=[{"role": "user", "content": "U1"}], - completion=[{"role": "assistant", "content": "A1"}], - response=MagicMock(), - tokens=vf.TrajectoryStepTokens( - prompt_ids=[1, 2], - prompt_mask=[0, 0], - completion_ids=[3, 4], - completion_mask=[1, 1], - completion_logprobs=[-0.1, -0.2], - overlong_prompt=False, - is_truncated=False, - ), - reward=None, - advantage=None, - is_truncated=False, - trajectory_id="1", - extras={}, - ), - vf.TrajectoryStep( - prompt=[ - {"role": "user", "content": "U1"}, - {"role": "assistant", "content": "A1"}, - {"role": "user", "content": "U2"}, - ], - completion=[{"role": "assistant", "content": "A2"}], - response=MagicMock(), - tokens=vf.TrajectoryStepTokens( - prompt_ids=[1, 2, 3, 4, 5, 6], - prompt_mask=[0, 0, 0, 0, 0, 0], - completion_ids=[7, 8], - completion_mask=[1, 1], - completion_logprobs=[-0.3, -0.4], - overlong_prompt=False, - is_truncated=False, - ), - reward=None, - advantage=None, - is_truncated=False, - trajectory_id="1", - extras={}, - ), - ], - sampling_args={"temperature": 1.0}, - error=None, - ) - return output - - -@pytest.fixture -def multi_step_trajectory_with_tool_calls_output(): - output = vf.RolloutOutput( - example_id=0, - trajectory=[ - vf.TrajectoryStep( - prompt=[{"role": "user", "content": "U1"}], - completion=[{"role": "assistant", "content": "A1 + TC1"}], - response=MagicMock(), - tokens=vf.TrajectoryStepTokens( - prompt_ids=[1, 2], - prompt_mask=[0, 0], - completion_ids=[3, 4], - completion_mask=[1, 1], - completion_logprobs=[-0.1, -0.2], - overlong_prompt=False, - is_truncated=False, - ), - reward=None, - advantage=None, - is_truncated=False, - trajectory_id="1", - extras={}, - ), - vf.TrajectoryStep( - prompt=[ - {"role": "user", "content": "U1"}, - {"role": "assistant", "content": "A1 + TC1"}, - {"role": "tool", "tool_call_id": "TR1", "content": "TR1"}, - ], - completion=[{"role": "assistant", "content": "A2 + TC2"}], - response=MagicMock(), - tokens=vf.TrajectoryStepTokens( - prompt_ids=[1, 2, 3, 4, 5, 6], - prompt_mask=[0, 0, 0, 0, 0, 0], - completion_ids=[7, 8], - completion_mask=[1, 1], - completion_logprobs=[-0.3, -0.4], - overlong_prompt=False, - is_truncated=False, - ), - reward=None, - advantage=None, - is_truncated=False, - trajectory_id="1", - extras={}, - ), - ], - reward=1.0, - advantage=None, - stop_condition=None, - metrics={"has_error": 0.0, "tool_calls": 1.0}, - sampling_args={"temperature": 1.0}, - error=None, - ) - return output - - -@pytest.fixture -def multi_step_trajectory_extension_never_holds(): - """ - 2-step trajectory where extension NEVER holds (step 2 has completely different tokens). - This simulates e.g. a chat template that re-renders the entire conversation differently. - """ - output = vf.RolloutOutput( - example_id=0, - trajectory=[ - vf.TrajectoryStep( - prompt=[{"role": "user", "content": "U1"}], - completion=[{"role": "assistant", "content": "A1"}], - response=MagicMock(), - tokens=vf.TrajectoryStepTokens( - prompt_ids=[1, 2], - prompt_mask=[0, 0], - completion_ids=[3, 4], - completion_mask=[1, 1], - completion_logprobs=[-0.1, -0.2], - overlong_prompt=False, - is_truncated=False, - ), - reward=None, - advantage=None, - is_truncated=False, - trajectory_id="1", - extras={}, - ), - vf.TrajectoryStep( - prompt=[ - {"role": "user", "content": "U1"}, - {"role": "assistant", "content": "A1"}, - {"role": "user", "content": "U2"}, - ], - completion=[{"role": "assistant", "content": "A2"}], - response=MagicMock(), - tokens=vf.TrajectoryStepTokens( - # Different tokens - extension breaks (e.g. thinking was stripped) - prompt_ids=[10, 20, 30, 40, 50, 60], - prompt_mask=[0, 0, 0, 0, 0, 0], - completion_ids=[7, 8], - completion_mask=[1, 1], - completion_logprobs=[-0.3, -0.4], - overlong_prompt=False, - is_truncated=False, - ), - reward=None, - advantage=None, - is_truncated=False, - trajectory_id="1", - extras={}, - ), - ], - sampling_args={"temperature": 1.0}, - error=None, - ) - return output - - -@pytest.fixture -def multi_step_trajectory_with_tool_calls_extension_never_holds(): - """2-step trajectory with tool calls where extension NEVER holds.""" - output = vf.RolloutOutput( - example_id=0, - trajectory=[ - vf.TrajectoryStep( - prompt=[{"role": "user", "content": "U1"}], - completion=[{"role": "assistant", "content": "A1 + TC1"}], - response=MagicMock(), - tokens=vf.TrajectoryStepTokens( - prompt_ids=[1, 2], - prompt_mask=[0, 0], - completion_ids=[3, 4], - completion_mask=[1, 1], - completion_logprobs=[-0.1, -0.2], - overlong_prompt=False, - is_truncated=False, - ), - reward=None, - advantage=None, - extras={}, - is_truncated=False, - trajectory_id="1", - ), - vf.TrajectoryStep( - prompt=[ - {"role": "user", "content": "U1"}, - {"role": "assistant", "content": "A1 + TC1"}, - {"role": "tool", "tool_call_id": "TR1", "content": "TR1"}, - ], - completion=[{"role": "assistant", "content": "A2 + TC2"}], - response=MagicMock(), - tokens=vf.TrajectoryStepTokens( - # Different tokens - extension breaks - prompt_ids=[10, 20, 30, 40, 50, 60], - prompt_mask=[0, 0, 0, 0, 0, 0], - completion_ids=[7, 8], - completion_mask=[1, 1], - completion_logprobs=[-0.3, -0.4], - overlong_prompt=False, - is_truncated=False, - ), - reward=None, - advantage=None, - extras={}, - is_truncated=False, - trajectory_id="1", - ), - ], - reward=1.0, - advantage=None, - stop_condition=None, - sampling_args={"temperature": 1.0}, - metrics={"has_error": 0.0, "tool_calls": 1.0}, - error=None, - ) - return output - - -def test_branching_equivalent_multi_step_trajectory(multi_step_trajectory_extension_never_holds): - """When extension never holds, each step becomes its own sample (same as old branching).""" - rollouts = interleave_rollout(multi_step_trajectory_extension_never_holds) - assert rollouts is not None - assert len(rollouts) == 2 - - # first step - rollout = rollouts[0] - assert rollout.prompt_ids == [1, 2] - assert rollout.prompt_mask == [False, False] - assert rollout.completion_ids == [3, 4] - assert rollout.completion_mask == [True, True] - assert rollout.completion_logprobs == [-0.1, -0.2] - assert rollout.completion_temperatures == [] - - # second step - rollout = rollouts[1] - assert rollout.prompt_ids == [10, 20, 30, 40, 50, 60] - assert rollout.prompt_mask == [False, False, False, False, False, False] - assert rollout.completion_ids == [7, 8] - assert rollout.completion_mask == [True, True] - assert rollout.completion_logprobs == [-0.3, -0.4] - assert rollout.completion_temperatures == [] - - -def test_branching_equivalent_multi_step_trajectory_with_tool_calls( - multi_step_trajectory_with_tool_calls_extension_never_holds, -): - """When extension never holds (with tool calls), same as old branching.""" - rollouts = interleave_rollout(multi_step_trajectory_with_tool_calls_extension_never_holds) - assert rollouts is not None - assert len(rollouts) == 2 - - # first step - rollout = rollouts[0] - assert rollout.prompt_ids == [1, 2] - assert rollout.prompt_mask == [False, False] - assert rollout.completion_ids == [3, 4] - assert rollout.completion_mask == [True, True] - assert rollout.completion_logprobs == [-0.1, -0.2] - assert rollout.completion_temperatures == [] - - # second step - rollout = rollouts[1] - assert rollout.prompt_ids == [10, 20, 30, 40, 50, 60] - assert rollout.prompt_mask == [False, False, False, False, False, False] - assert rollout.completion_ids == [7, 8] - assert rollout.completion_mask == [True, True] - assert rollout.completion_logprobs == [-0.3, -0.4] - assert rollout.completion_temperatures == [] - - -def test_interleave_rollout_single_step_trajectory(single_step_trajectory_output): - single_step_trajectory_output["env_name"] = "test-env" - rollouts = interleave_rollout(single_step_trajectory_output) - assert rollouts is not None - assert len(rollouts) == 1 - rollout = rollouts[0] - - assert rollout.prompt_ids == [1, 2] - assert rollout.prompt_mask == [False, False] - assert rollout.completion_ids == [3, 4] - assert rollout.completion_mask == [True, True] - assert rollout.completion_logprobs == [-0.1, -0.2] - assert rollout.completion_temperatures == [] - assert rollout.env_name == "test-env" - - -def test_interleave_rollout_multi_step_trajectory(multi_step_trajectory_output): - rollouts = interleave_rollout(multi_step_trajectory_output) - assert rollouts is not None - assert len(rollouts) == 1 - rollout = rollouts[0] - - assert rollout.prompt_ids == [1, 2] - assert rollout.prompt_mask == [False, False] - assert rollout.completion_ids == [3, 4, 5, 6, 7, 8] - assert rollout.completion_mask == [True, True, False, False, True, True] - assert rollout.completion_logprobs == [-0.1, -0.2, 0, 0, -0.3, -0.4] - # ``completion_temperatures`` is filled by the orchestrator post-interleave; empty here. - assert rollout.completion_temperatures == [] - - -def test_interleave_rollout_multi_step_trajectory_with_tool_calls(multi_step_trajectory_with_tool_calls_output): - rollouts = interleave_rollout(multi_step_trajectory_with_tool_calls_output) - assert rollouts is not None - assert len(rollouts) == 1 - rollout = rollouts[0] - - assert rollout.prompt_ids == [1, 2] - assert rollout.prompt_mask == [False, False] - assert rollout.completion_ids == [3, 4, 5, 6, 7, 8] - assert rollout.completion_mask == [True, True, False, False, True, True] - assert rollout.completion_logprobs == [-0.1, -0.2, 0, 0, -0.3, -0.4] - # ``completion_temperatures`` is filled by the orchestrator post-interleave; empty here. - assert rollout.completion_temperatures == [] - - -@pytest.fixture -def five_step_trajectory_with_extension_break(): - """ - 5-step trajectory where extension property breaks at step 4. - - Steps 1-3: extension holds (tokens grow by appending) - Step 4: extension breaks (completely different prefix, e.g. context compaction) - Steps 4-5: extension holds again - - Expected: 2 samples (steps 1-3 merged, steps 4-5 merged) - """ - output = vf.RolloutOutput( - example_id=0, - trajectory=[ - # Step 1: initial prompt and completion - vf.TrajectoryStep( - prompt=[{"role": "user", "content": "U1"}], - completion=[{"role": "assistant", "content": "A1"}], - response=MagicMock(), - tokens=vf.TrajectoryStepTokens( - prompt_ids=[1, 2], - prompt_mask=[0, 0], - completion_ids=[3, 4], - completion_mask=[1, 1], - completion_logprobs=[-0.1, -0.2], - overlong_prompt=False, - is_truncated=False, - ), - reward=None, - advantage=None, - extras={}, - is_truncated=False, - trajectory_id="1", - ), - # Step 2: extends step 1 (prefix [1,2,3,4] matches) - vf.TrajectoryStep( - prompt=[ - {"role": "user", "content": "U1"}, - {"role": "assistant", "content": "A1"}, - {"role": "user", "content": "U2"}, - ], - completion=[{"role": "assistant", "content": "A2"}], - response=MagicMock(), - tokens=vf.TrajectoryStepTokens( - prompt_ids=[1, 2, 3, 4, 5, 6], - prompt_mask=[0, 0, 0, 0, 0, 0], - completion_ids=[7, 8], - completion_mask=[1, 1], - completion_logprobs=[-0.3, -0.4], - overlong_prompt=False, - is_truncated=False, - ), - reward=None, - advantage=None, - extras={}, - is_truncated=False, - trajectory_id="1", - ), - # Step 3: extends step 2 (prefix [1,2,3,4,5,6,7,8] matches) - vf.TrajectoryStep( - prompt=[ - {"role": "user", "content": "U1"}, - {"role": "assistant", "content": "A1"}, - {"role": "user", "content": "U2"}, - {"role": "assistant", "content": "A2"}, - {"role": "user", "content": "U3"}, - ], - completion=[{"role": "assistant", "content": "A3"}], - response=MagicMock(), - tokens=vf.TrajectoryStepTokens( - prompt_ids=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10], - prompt_mask=[0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - completion_ids=[11, 12], - completion_mask=[1, 1], - completion_logprobs=[-0.5, -0.6], - overlong_prompt=False, - is_truncated=False, - ), - reward=None, - advantage=None, - extras={}, - is_truncated=False, - trajectory_id="1", - ), - # Step 4: EXTENSION BREAKS - different prefix (e.g. thinking stripped, context compacted) - vf.TrajectoryStep( - prompt=[ - {"role": "user", "content": "U1"}, - {"role": "assistant", "content": "A1"}, # thinking stripped - {"role": "user", "content": "U2"}, - {"role": "assistant", "content": "A2"}, - {"role": "user", "content": "U3"}, - {"role": "assistant", "content": "A3"}, - {"role": "user", "content": "U4"}, - ], - completion=[{"role": "assistant", "content": "A4"}], - response=MagicMock(), - tokens=vf.TrajectoryStepTokens( - prompt_ids=[100, 101, 102, 103], # completely different tokens (re-rendered) - prompt_mask=[0, 0, 0, 0], - completion_ids=[104, 105], - completion_mask=[1, 1], - completion_logprobs=[-0.7, -0.8], - overlong_prompt=False, - is_truncated=False, - ), - reward=None, - advantage=None, - extras={}, - is_truncated=False, - trajectory_id="1", - ), - # Step 5: extends step 4 (prefix [100,101,102,103,104,105] matches) - vf.TrajectoryStep( - prompt=[ - {"role": "user", "content": "U1"}, - {"role": "assistant", "content": "A1"}, - {"role": "user", "content": "U2"}, - {"role": "assistant", "content": "A2"}, - {"role": "user", "content": "U3"}, - {"role": "assistant", "content": "A3"}, - {"role": "user", "content": "U4"}, - {"role": "assistant", "content": "A4"}, - {"role": "user", "content": "U5"}, - ], - completion=[{"role": "assistant", "content": "A5"}], - response=MagicMock(), - tokens=vf.TrajectoryStepTokens( - prompt_ids=[100, 101, 102, 103, 104, 105, 106, 107], - prompt_mask=[0, 0, 0, 0, 0, 0, 0, 0], - completion_ids=[108, 109], - completion_mask=[1, 1], - completion_logprobs=[-0.9, -1.0], - overlong_prompt=False, - is_truncated=False, - ), - reward=None, - advantage=None, - extras={}, - is_truncated=False, - trajectory_id="1", - ), - ], - sampling_args={"temperature": 1.0}, - error=None, - ) - return output - - -def test_interleave_rollout_extension_break_creates_multiple_samples(five_step_trajectory_with_extension_break): - """ - When extension property breaks mid-trajectory, interleave_rollout should: - - Merge steps 1-3 into first sample (extension held) - - Start new sample at step 4 (extension broke) - - Merge steps 4-5 into second sample (extension held again) - """ - rollouts = interleave_rollout(five_step_trajectory_with_extension_break) - - assert rollouts is not None - assert len(rollouts) == 2, "Should produce 2 samples when extension breaks at step 4" - - # First sample: steps 1-3 merged - sample1 = rollouts[0] - assert sample1.prompt_ids == [1, 2] - assert sample1.prompt_mask == [False, False] - # completion_ids: step1 completion [3,4] + step2 new prompt [5,6] + step2 completion [7,8] - # + step3 new prompt [9,10] + step3 completion [11,12] - assert sample1.completion_ids == [3, 4, 5, 6, 7, 8, 9, 10, 11, 12] - # completion_mask: step1 [T,T] + step2 prompt [F,F] + step2 completion [T,T] - # + step3 prompt [F,F] + step3 completion [T,T] - assert sample1.completion_mask == [True, True, False, False, True, True, False, False, True, True] - assert sample1.completion_logprobs == [-0.1, -0.2, 0, 0, -0.3, -0.4, 0, 0, -0.5, -0.6] - - # Second sample: steps 4-5 merged (fresh start after extension break) - sample2 = rollouts[1] - assert sample2.prompt_ids == [100, 101, 102, 103] - assert sample2.prompt_mask == [False, False, False, False] - # completion_ids: step4 completion [104,105] + step5 new prompt [106,107] + step5 completion [108,109] - assert sample2.completion_ids == [104, 105, 106, 107, 108, 109] - # completion_mask: step4 [T,T] + step5 prompt [F,F] + step5 completion [T,T] - assert sample2.completion_mask == [True, True, False, False, True, True] - assert sample2.completion_logprobs == [-0.7, -0.8, 0, 0, -0.9, -1.0] - - -@pytest.fixture -def interleaved_agents_trajectory(): - """ - Trajectory with interleaved agents: agent1 steps, then agent2 step, then agent1 continues. - This tests multi-prefix tracking where agent1-step3 should merge back with agent1 sample. - - agent1-step1: prompt=[1,2], completion=[3,4] - agent1-step2: prompt=[1,2,3,4,5,6], completion=[7,8] (extends agent1-step1) - agent2-step1: prompt=[100,101], completion=[102,103] (different prefix, new sample) - agent1-step3: prompt=[1,2,3,4,5,6,7,8,9,10], completion=[11,12] (extends agent1-step2!) - """ - output = vf.RolloutOutput( - example_id=1, - task="test", - trajectory=[ - # agent1-step1 - vf.TrajectoryStep( - prompt="agent1 turn 1", - completion="response 1", - response=None, - tokens=vf.TrajectoryStepTokens( - prompt_ids=[1, 2], - prompt_mask=[0, 0], - completion_ids=[3, 4], - completion_mask=[1, 1], - completion_logprobs=[-0.1, -0.2], - overlong_prompt=False, - is_truncated=False, - ), - reward=None, - advantage=None, - is_truncated=False, - trajectory_id="traj1", - extras={}, - ), - # agent1-step2 (extends agent1-step1) - vf.TrajectoryStep( - prompt="agent1 turn 2", - completion="response 2", - response=None, - tokens=vf.TrajectoryStepTokens( - prompt_ids=[1, 2, 3, 4, 5, 6], - prompt_mask=[0, 0, 0, 0, 0, 0], - completion_ids=[7, 8], - completion_mask=[1, 1], - completion_logprobs=[-0.3, -0.4], - overlong_prompt=False, - is_truncated=False, - ), - reward=None, - advantage=None, - is_truncated=False, - trajectory_id="traj1", - extras={}, - ), - # agent2-step1 (different prefix, starts new sample) - vf.TrajectoryStep( - prompt="agent2 turn 1", - completion="agent2 response", - response=None, - tokens=vf.TrajectoryStepTokens( - prompt_ids=[100, 101], - prompt_mask=[0, 0], - completion_ids=[102, 103], - completion_mask=[1, 1], - completion_logprobs=[-0.5, -0.6], - overlong_prompt=False, - is_truncated=False, - ), - reward=None, - advantage=None, - is_truncated=False, - trajectory_id="traj2", - extras={}, - ), - # agent1-step3 (extends agent1-step2, should merge back!) - vf.TrajectoryStep( - prompt="agent1 turn 3", - completion="response 3", - response=None, - tokens=vf.TrajectoryStepTokens( - prompt_ids=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10], - prompt_mask=[0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - completion_ids=[11, 12], - completion_mask=[1, 1], - completion_logprobs=[-0.7, -0.8], - overlong_prompt=False, - is_truncated=False, - ), - reward=None, - advantage=None, - is_truncated=False, - trajectory_id="traj1", - extras={}, - ), - ], - sampling_args={"temperature": 1.0}, - error=None, - ) - return output - - -def test_interleave_rollout_interleaved_agents(interleaved_agents_trajectory): - """ - When agents are interleaved (agent1, agent1, agent2, agent1), the multi-prefix - tracking should merge agent1-step3 back into the agent1 sample, not start a new one. - """ - rollouts = interleave_rollout(interleaved_agents_trajectory) - - assert rollouts is not None - assert len(rollouts) == 2, "Should produce 2 samples (agent1 merged, agent2 separate)" - - # First sample: agent1 steps 1, 2, 3 merged - agent1_sample = rollouts[0] - assert agent1_sample.prompt_ids == [1, 2] - assert agent1_sample.prompt_mask == [False, False] - # completion_ids: step1 [3,4] + step2 new prompt [5,6] + step2 completion [7,8] - # + step3 new prompt [9,10] + step3 completion [11,12] - assert agent1_sample.completion_ids == [3, 4, 5, 6, 7, 8, 9, 10, 11, 12] - assert agent1_sample.completion_mask == [True, True, False, False, True, True, False, False, True, True] - assert agent1_sample.completion_logprobs == [-0.1, -0.2, 0, 0, -0.3, -0.4, 0, 0, -0.7, -0.8] - - # Second sample: agent2 step 1 only - agent2_sample = rollouts[1] - assert agent2_sample.prompt_ids == [100, 101] - assert agent2_sample.prompt_mask == [False, False] - assert agent2_sample.completion_ids == [102, 103] - assert agent2_sample.completion_mask == [True, True] - assert agent2_sample.completion_logprobs == [-0.5, -0.6] - - -@pytest.fixture -def prefix_of_prefix_trajectory(): - """ - Trajectory where one active sample's prefix is a strict prefix of another's. - - Construction: - - step 0: prompt=[1,2], completion=[3,4] -> sample A, P_A=[1,2,3,4] - - step 1: extends A. prompt=[1,2,3,4,5], completion=[6] -> P_A=[1,2,3,4,5,6] - - step 2: rollback/regenerate. prompt=[1,2] (shorter than P_A so no match), - completion=[3,4,5,6,7] -> sample B, P_B=[1,2,3,4,5,6,7] - P_B starts with P_A. - - step 3: extends B. prompt=[1,2,3,4,5,6,7,8], completion=[9] - Both P_A and P_B are token-prefixes of the step's prompt. - - The correct match is the longer P_B. First-match-wins picks P_A and silently - folds B's generated tokens into A as user-input tokens (mask=False). - """ - output = vf.RolloutOutput( - example_id=2, - task="test", - trajectory=[ - vf.TrajectoryStep( - prompt="step 0", - completion="completion 0", - response=None, - tokens=vf.TrajectoryStepTokens( - prompt_ids=[1, 2], - prompt_mask=[0, 0], - completion_ids=[3, 4], - completion_mask=[1, 1], - completion_logprobs=[-0.1, -0.2], - overlong_prompt=False, - is_truncated=False, - ), - reward=None, - advantage=None, - is_truncated=False, - trajectory_id="traj_A", - extras={}, - ), - vf.TrajectoryStep( - prompt="step 1", - completion="completion 1", - response=None, - tokens=vf.TrajectoryStepTokens( - prompt_ids=[1, 2, 3, 4, 5], - prompt_mask=[0, 0, 0, 0, 0], - completion_ids=[6], - completion_mask=[1], - completion_logprobs=[-0.3], - overlong_prompt=False, - is_truncated=False, - ), - reward=None, - advantage=None, - is_truncated=False, - trajectory_id="traj_A", - extras={}, - ), - vf.TrajectoryStep( - prompt="step 2 (rollback)", - completion="completion 2", - response=None, - tokens=vf.TrajectoryStepTokens( - prompt_ids=[1, 2], - prompt_mask=[0, 0], - completion_ids=[3, 4, 5, 6, 7], - completion_mask=[1, 1, 1, 1, 1], - completion_logprobs=[-0.4, -0.5, -0.6, -0.7, -0.8], - overlong_prompt=False, - is_truncated=False, - ), - reward=None, - advantage=None, - is_truncated=False, - trajectory_id="traj_B", - extras={}, - ), - vf.TrajectoryStep( - prompt="step 3 (extends B)", - completion="completion 3", - response=None, - tokens=vf.TrajectoryStepTokens( - prompt_ids=[1, 2, 3, 4, 5, 6, 7, 8], - prompt_mask=[0, 0, 0, 0, 0, 0, 0, 0], - completion_ids=[9], - completion_mask=[1], - completion_logprobs=[-0.9], - overlong_prompt=False, - is_truncated=False, - ), - reward=None, - advantage=None, - is_truncated=False, - trajectory_id="traj_B", - extras={}, - ), - ], - sampling_args={"temperature": 1.0}, - error=None, - ) - return output - - -def test_interleave_rollout_picks_longest_matching_prefix(prefix_of_prefix_trajectory): - """ - When two active samples both match (one's prefix is a strict prefix of the - other's), the longer prefix is the correct extension. Previously the first- - match-wins loop folded the longer sample's generated tokens into the shorter - sample as user input (mask=False) and left the longer sample stale. - """ - rollouts = interleave_rollout(prefix_of_prefix_trajectory) - - assert rollouts is not None - assert len(rollouts) == 2 - - # Sample A: steps 0 and 1 only. Step 3 must NOT have been folded in here. - sample_a = rollouts[0] - assert sample_a.prompt_ids == [1, 2] - # step 0 completion [3,4] + step 1 new prompt [5] + step 1 completion [6] - assert sample_a.completion_ids == [3, 4, 5, 6] - assert sample_a.completion_mask == [True, True, False, True] - assert sample_a.completion_logprobs == [-0.1, -0.2, 0.0, -0.3] - - # Sample B: steps 2 and 3 merged. The token 7 (from step 2's completion) - # must remain masked as a generated token, not silently re-classified. - sample_b = rollouts[1] - assert sample_b.prompt_ids == [1, 2] - # step 2 completion [3,4,5,6,7] + step 3 new prompt [8] + step 3 completion [9] - assert sample_b.completion_ids == [3, 4, 5, 6, 7, 8, 9] - assert sample_b.completion_mask == [True, True, True, True, True, False, True] - assert sample_b.completion_logprobs == [-0.4, -0.5, -0.6, -0.7, -0.8, 0.0, -0.9] - - -def test_interleave_rollout_empty_trajectory(): - """Empty trajectory returns None.""" - output = vf.RolloutOutput( - example_id=1, - trajectory=[], - error=None, - ) - assert interleave_rollout(output) is None - - -def test_interleave_rollout_error_masks_all_false(): - """ - When rollout output has an error, all completion_mask values should be False - across both make_sample (step 0) and extend_sample (step 1). - """ - output = vf.RolloutOutput( - example_id=1, - trajectory=[ - vf.TrajectoryStep( - prompt=[{"role": "user", "content": "U1"}], - completion=[{"role": "assistant", "content": "A1"}], - response=MagicMock(), - tokens=vf.TrajectoryStepTokens( - prompt_ids=[1, 2], - prompt_mask=[0, 0], - completion_ids=[3, 4], - completion_mask=[1, 1], - completion_logprobs=[-0.1, -0.2], - overlong_prompt=False, - is_truncated=False, - ), - reward=None, - advantage=None, - is_truncated=False, - trajectory_id="1", - extras={}, - ), - vf.TrajectoryStep( - prompt=[{"role": "user", "content": "U2"}], - completion=[{"role": "assistant", "content": "A2"}], - response=MagicMock(), - tokens=vf.TrajectoryStepTokens( - prompt_ids=[1, 2, 3, 4, 5, 6], - prompt_mask=[0, 0, 0, 0, 0, 0], - completion_ids=[7, 8], - completion_mask=[1, 1], - completion_logprobs=[-0.3, -0.4], - overlong_prompt=False, - is_truncated=False, - ), - reward=None, - advantage=None, - is_truncated=False, - trajectory_id="1", - extras={}, - ), - ], - error="timeout: environment exceeded time limit", - sampling_args={"temperature": 0.8}, - ) - - rollouts = interleave_rollout(output) - - assert rollouts is not None - assert len(rollouts) == 1 - rollout = rollouts[0] - # Extension holds so tokens merge, but ALL completion_mask should be False - assert rollout.completion_ids == [3, 4, 5, 6, 7, 8] - assert rollout.completion_mask == [False, False, False, False, False, False] - # Logprobs preserved; ``completion_temperatures`` is filled by the orchestrator post-interleave. - assert rollout.completion_logprobs == [-0.1, -0.2, 0.0, 0.0, -0.3, -0.4] - assert rollout.completion_temperatures == [] - - -def test_align_routed_experts_none(): - assert align_routed_experts(None, 10) is None - - -def test_align_routed_experts_empty(): - experts = np.empty((0, 2, 2), dtype=np.uint8) - result = align_routed_experts(experts, 10) - assert result is not None - assert result.shape == (10, 2, 2) - assert np.all(result == 0) - - -def test_align_routed_experts_no_deficit(): - # 3 tokens, 2 layers, topk=2 - experts = np.asarray([[[0, 1], [2, 3]], [[4, 5], [6, 7]], [[0, 2], [1, 3]]], dtype=np.uint8) - result = align_routed_experts(experts, expected_len=3) - np.testing.assert_array_equal(result, experts) - - -def test_align_routed_experts_with_deficit(): - # 2 tokens but expected 4 (deficit of 2) - experts = np.asarray([[[1, 2], [3, 4]], [[5, 6], [7, 0]]], dtype=np.uint8) - result = align_routed_experts(experts, expected_len=4) - assert result is not None - assert result.shape == (4, 2, 2) - np.testing.assert_array_equal(result[:2], experts) - # Padded entries should be zero-filled with same shape [layers=2, topk=2] - np.testing.assert_array_equal(result[2], [[0, 0], [0, 0]]) - np.testing.assert_array_equal(result[3], [[0, 0], [0, 0]]) - - -def test_align_routed_experts_excess_length(): - experts = np.asarray([[[1, 2]], [[3, 4]], [[5, 6]]], dtype=np.uint8) - result = align_routed_experts(experts, expected_len=2) - np.testing.assert_array_equal(result, experts[:2]) - - -def test_interleave_rollout_single_step_with_routed_experts(): - """Routed experts are aligned and passed through for a single-step trajectory.""" - # prompt_ids=[1,2], completion_ids=[3,4] -> total 4 tokens - # vLLM returns num_tokens-1 = 3 routed expert entries - routed_experts_from_vllm = np.asarray([[[0, 1]], [[2, 3]], [[4, 5]]], dtype=np.uint8) - output = vf.RolloutOutput( - example_id=0, - trajectory=[ - vf.TrajectoryStep( - prompt=[{"role": "user", "content": "U1"}], - completion=[{"role": "assistant", "content": "A1"}], - response=MagicMock(), - tokens=vf.TrajectoryStepTokens( - prompt_ids=[1, 2], - prompt_mask=[0, 0], - completion_ids=[3, 4], - completion_mask=[1, 1], - completion_logprobs=[-0.1, -0.2], - overlong_prompt=False, - is_truncated=False, - routed_experts=_routed_experts_payload(routed_experts_from_vllm), - ), - reward=None, - advantage=None, - is_truncated=False, - trajectory_id="1", - extras={}, - ) - ], - sampling_args={"temperature": 1.0}, - error=None, - ) - - rollouts = interleave_rollout(output) - assert rollouts is not None - assert len(rollouts) == 1 - sample = rollouts[0] - - # Should be aligned to 4 tokens (2 prompt + 2 completion) - assert sample.routed_experts is not None - routed_experts = _sample_routed_experts(sample) - assert routed_experts.shape == (4, 1, 2) - # First 3 are original, last one is zero-padded - np.testing.assert_array_equal(routed_experts[:3], routed_experts_from_vllm) - np.testing.assert_array_equal(routed_experts[3], [[0, 0]]) - - -def test_interleave_rollout_multi_step_with_routed_experts(): - """Routed experts are extended and aligned across multi-step trajectories.""" - # Step 1: prompt=[1,2], completion=[3,4] -> 4 tokens, vLLM returns 3 - step1_experts = np.asarray([[[1, 2]], [[3, 4]], [[5, 6]]], dtype=np.uint8) - # Step 2: prompt=[1,2,3,4,5,6], completion=[7,8], bridged from prefix len 4. - # vLLM returns routed experts starting at row 3: boundary token 4, then 5, 6, 7. - step2_experts = np.asarray([[[40, 41]], [[50, 51]], [[60, 61]], [[70, 71]]], dtype=np.uint8) - - output = vf.RolloutOutput( - example_id=0, - trajectory=[ - vf.TrajectoryStep( - prompt=[{"role": "user", "content": "U1"}], - completion=[{"role": "assistant", "content": "A1"}], - response=MagicMock(), - tokens=vf.TrajectoryStepTokens( - prompt_ids=[1, 2], - prompt_mask=[0, 0], - completion_ids=[3, 4], - completion_mask=[1, 1], - completion_logprobs=[-0.1, -0.2], - overlong_prompt=False, - is_truncated=False, - routed_experts=_routed_experts_payload(step1_experts), - ), - reward=None, - advantage=None, - is_truncated=False, - trajectory_id="1", - extras={}, - ), - vf.TrajectoryStep( - prompt=[ - {"role": "user", "content": "U1"}, - {"role": "assistant", "content": "A1"}, - {"role": "user", "content": "U2"}, - ], - completion=[{"role": "assistant", "content": "A2"}], - response=MagicMock(), - tokens=vf.TrajectoryStepTokens( - prompt_ids=[1, 2, 3, 4, 5, 6], - prompt_mask=[0, 0, 0, 0, 0, 0], - completion_ids=[7, 8], - completion_mask=[1, 1], - completion_logprobs=[-0.3, -0.4], - overlong_prompt=False, - is_truncated=False, - routed_experts=_routed_experts_payload(step2_experts, start=3), - ), - reward=None, - advantage=None, - is_truncated=False, - trajectory_id="1", - extras={}, - ), - ], - sampling_args={"temperature": 1.0}, - error=None, - ) - - rollouts = interleave_rollout(output) - assert rollouts is not None - assert len(rollouts) == 1 - sample = rollouts[0] - - # Merged sample: prompt=[1,2], completion=[3,4,5,6,7,8] -> 8 tokens total - assert len(sample.prompt_ids) + len(sample.completion_ids) == 8 - assert sample.routed_experts is not None - routed_experts = _sample_routed_experts(sample) - assert routed_experts.shape == (8, 1, 2) - np.testing.assert_array_equal( - routed_experts, - np.asarray( - [ - [[1, 2]], - [[3, 4]], - [[5, 6]], - [[40, 41]], - [[50, 51]], - [[60, 61]], - [[70, 71]], - [[0, 0]], - ], - dtype=np.uint8, - ), - ) - - -def test_interleave_rollout_branch_delta_uses_prior_routed_prefix(): - step1_experts = np.asarray([[[1, 2]], [[3, 4]], [[5, 6]]], dtype=np.uint8) - step2_experts = np.asarray([[[40, 41]], [[50, 51]], [[60, 61]], [[70, 71]]], dtype=np.uint8) - step3_experts = np.asarray([[[80, 81]], [[90, 91]], [[100, 101]], [[110, 111]]], dtype=np.uint8) - - output = vf.RolloutOutput( - example_id=0, - trajectory=[ - vf.TrajectoryStep( - prompt=[{"role": "user", "content": "U1"}], - completion=[{"role": "assistant", "content": "A1"}], - response=MagicMock(), - tokens=vf.TrajectoryStepTokens( - prompt_ids=[1, 2], - prompt_mask=[0, 0], - completion_ids=[3, 4], - completion_mask=[1, 1], - completion_logprobs=[-0.1, -0.2], - overlong_prompt=False, - is_truncated=False, - routed_experts=_routed_experts_payload(step1_experts), - ), - reward=None, - advantage=None, - is_truncated=False, - trajectory_id="1", - extras={}, - ), - vf.TrajectoryStep( - prompt=[ - {"role": "user", "content": "U1"}, - {"role": "assistant", "content": "A1"}, - {"role": "user", "content": "U2"}, - ], - completion=[{"role": "assistant", "content": "A2"}], - response=MagicMock(), - tokens=vf.TrajectoryStepTokens( - prompt_ids=[1, 2, 3, 4, 5, 6], - prompt_mask=[0, 0, 0, 0, 0, 0], - completion_ids=[7, 8], - completion_mask=[1, 1], - completion_logprobs=[-0.3, -0.4], - overlong_prompt=False, - is_truncated=False, - routed_experts=_routed_experts_payload(step2_experts, start=3), - ), - reward=None, - advantage=None, - is_truncated=False, - trajectory_id="1", - extras={}, - ), - vf.TrajectoryStep( - prompt=[ - {"role": "user", "content": "U1"}, - {"role": "assistant", "content": "A1"}, - {"role": "user", "content": "branch"}, - ], - completion=[{"role": "assistant", "content": "A3"}], - response=MagicMock(), - tokens=vf.TrajectoryStepTokens( - prompt_ids=[1, 2, 3, 4, 9, 10], - prompt_mask=[0, 0, 0, 0, 0, 0], - completion_ids=[11, 12], - completion_mask=[1, 1], - completion_logprobs=[-0.5, -0.6], - overlong_prompt=False, - is_truncated=False, - routed_experts=_routed_experts_payload(step3_experts, start=3), - ), - reward=None, - advantage=None, - is_truncated=False, - trajectory_id="1", - extras={}, - ), - ], - sampling_args={"temperature": 1.0}, - error=None, - ) - - rollouts = interleave_rollout(output) - - assert rollouts is not None - assert len(rollouts) == 2 - branched = rollouts[1] - assert branched.prompt_ids == [1, 2, 3, 4, 9, 10] - assert branched.completion_ids == [11, 12] - routed_experts = _sample_routed_experts(branched) - assert routed_experts.shape == (8, 1, 2) - np.testing.assert_array_equal( - routed_experts, - np.asarray( - [ - [[1, 2]], - [[3, 4]], - [[5, 6]], - [[80, 81]], - [[90, 91]], - [[100, 101]], - [[110, 111]], - [[0, 0]], - ], - dtype=np.uint8, - ), - ) - - -def test_interleave_rollout_none_routed_experts_stays_none(): - """When routed_experts is None, sample.routed_experts remains None.""" - output = vf.RolloutOutput( - example_id=0, - trajectory=[ - vf.TrajectoryStep( - prompt=[{"role": "user", "content": "U1"}], - completion=[{"role": "assistant", "content": "A1"}], - response=MagicMock(), - tokens=vf.TrajectoryStepTokens( - prompt_ids=[1, 2], - prompt_mask=[0, 0], - completion_ids=[3, 4], - completion_mask=[1, 1], - completion_logprobs=[-0.1, -0.2], - overlong_prompt=False, - is_truncated=False, - routed_experts=None, - ), - reward=None, - advantage=None, - is_truncated=False, - trajectory_id="1", - extras={}, - ) - ], - sampling_args={"temperature": 1.0}, - error=None, - ) - - rollouts = interleave_rollout(output) - assert rollouts is not None - assert rollouts[0].routed_experts is None - - -# ============================================================================= -# Renderer-emitted multimodal data -# ============================================================================= - - -def test_interleave_rollout_packs_pixels_from_renderer_mm_data(): - """``interleave_rollout`` packs renderer-emitted ``multi_modal_data`` - (pixel_values / image_grid_thw / mm_token_type_ids) onto the - TrainingSample. - - verifiers' ``_delta_intermediate_mm_data`` ships per-step *delta* - mm_data (each step contains only items not present in the prior - step's cumulative set). Prime-rl unions across the sample's step - range to recover the cumulative set in image-placeholder order. - """ - import torch as _torch - from renderers.base import MultiModalData, PlaceholderRange - - # Two synthetic single-image items — values are arbitrary, what - # matters is that the packer concatenates them correctly. - item1_pv = _torch.tensor([[1.0, 2.0]], dtype=_torch.float32) - item2_pv = _torch.tensor([[3.0, 4.0]], dtype=_torch.float32) - item1_thw = _torch.tensor([[1, 2, 3]], dtype=_torch.int64) - item2_thw = _torch.tensor([[1, 4, 4]], dtype=_torch.int64) - - # Step 0: image h1 (first time it's seen, included in delta). - mm_step_0 = MultiModalData( - mm_hashes={"image": ["h1"]}, - mm_placeholders={"image": [PlaceholderRange(offset=1, length=1)]}, - mm_items={"image": [{"pixel_values": item1_pv, "image_grid_thw": item1_thw}]}, - ) - # Step 1: post-delta — only h2 (h1 was dropped because it was in - # the prior step's cumulative set). Renderer's bridge would have - # produced cumulative [h1, h2] before verifiers' delta rewrite. - mm_step_1 = MultiModalData( - mm_hashes={"image": ["h2"]}, - mm_placeholders={"image": [PlaceholderRange(offset=4, length=1)]}, - mm_items={"image": [{"pixel_values": item2_pv, "image_grid_thw": item2_thw}]}, - ) - - output = vf.RolloutOutput( - example_id=1, - trajectory=[ - vf.TrajectoryStep( - prompt=[{"role": "user", "content": "Turn 1"}], - completion=[{"role": "assistant", "content": "Response 1"}], - response=MagicMock(), - tokens=vf.TrajectoryStepTokens( - prompt_ids=[1, 2], - prompt_mask=[0, 0], - completion_ids=[3, 4], - completion_mask=[1, 1], - completion_logprobs=[-0.1, -0.2], - overlong_prompt=False, - is_truncated=False, - multi_modal_data=mm_step_0, - ), - reward=None, - advantage=None, - is_truncated=False, - trajectory_id="1", - extras={}, - ), - vf.TrajectoryStep( - prompt=[{"role": "user", "content": "Turn 2"}], - completion=[{"role": "assistant", "content": "Response 2"}], - response=MagicMock(), - tokens=vf.TrajectoryStepTokens( - prompt_ids=[1, 2, 3, 4, 5], - prompt_mask=[0, 0, 0, 0, 0], - completion_ids=[6, 7], - completion_mask=[1, 1], - completion_logprobs=[-0.3, -0.4], - overlong_prompt=False, - is_truncated=False, - multi_modal_data=mm_step_1, - ), - reward=None, - advantage=None, - is_truncated=False, - trajectory_id="1", - extras={}, - ), - ], - sampling_args={"temperature": 1.0}, - error=None, - ) - - # Token 2 is the image placeholder, token 5 is the video placeholder. - mm_mapping = {2: 1, 5: 2} - rollouts = interleave_rollout(output, mm_token_type_ids_mapping=mm_mapping) - - assert rollouts is not None and len(rollouts) == 1 - sample = rollouts[0] - # Extension holds; both steps merge into one sample. mm_data is - # the union of step 0's delta ([h1]) and step 1's delta ([h2]). - assert sample.prompt_ids == [1, 2] - assert sample.completion_ids == [3, 4, 5, 6, 7] - # Pixel values packed by concatenating step 0's item then step 1's. - assert _decode_mm_pixels(sample) == [ - [1.0, 2.0], - [3.0, 4.0], - ] - assert _decode_mm_thw(sample) == [[1, 2, 3], [1, 4, 4]] - # mm_token_type_ids: image at token 2, video at token 5, rest 0. - assert sample.mm_token_type_ids == [0, 1, 0, 0, 2, 0, 0] diff --git a/tests/unit/test_configs.py b/tests/unit/test_configs.py index fcdee7a843..ed62ad54a1 100644 --- a/tests/unit/test_configs.py +++ b/tests/unit/test_configs.py @@ -169,7 +169,7 @@ def test_removed_fused_lm_head_chunk_size_field_is_rejected(): def test_orchestrator_vlm_requires_renderer(): - with pytest.raises(ValidationError, match="orchestrator.renderer must be set when model.vlm is set"): + with pytest.raises(ValidationError, match="renderer"): OrchestratorConfig.model_validate( { "student": { @@ -213,7 +213,7 @@ def test_shared_model_name_propagates_to_subconfigs(): { "model": {"name": model_name}, "trainer": {}, - "orchestrator": {"renderer": None}, + "orchestrator": {"renderer": {"name": "default"}}, "inference": {}, } ) @@ -230,7 +230,7 @@ def test_shared_tokenizer_propagates_when_subconfigs_unset(): "model": {"name": "my-model"}, "tokenizer": {"name": "my-tokenizer"}, "trainer": {}, - "orchestrator": {"renderer": None}, + "orchestrator": {"renderer": {"name": "default"}}, } ) assert config.trainer.tokenizer.name == "my-tokenizer" @@ -247,7 +247,7 @@ def test_shared_and_sub_tokenizer_name_conflict_raises(): "model": {"name": "my-model"}, "tokenizer": {"name": "shared-tok"}, "trainer": {"tokenizer": {"name": "trainer-tok"}}, - "orchestrator": {"renderer": None}, + "orchestrator": {"renderer": {"name": "default"}}, } ) @@ -258,7 +258,7 @@ def test_tokenizer_name_falls_back_to_model_name_when_unset(): "model": {"name": "my-model"}, "tokenizer": {"trust_remote_code": True}, "trainer": {}, - "orchestrator": {"renderer": None}, + "orchestrator": {"renderer": {"name": "default"}}, } ) assert config.trainer.tokenizer.name == "my-model" @@ -286,7 +286,7 @@ def test_explicit_subconfig_tokenizer_name_survives_shared_model_propagation(): "model": {"name": "M"}, "trainer": {}, "orchestrator": { - "renderer": None, + "renderer": {"name": "default"}, "tokenizer": {"name": "explicit-orch-tok"}, }, } @@ -305,7 +305,7 @@ def test_tokenizer_chat_template_mismatch_raises(): RLConfig.model_validate( { "trainer": {"tokenizer": {"chat_template": "A"}}, - "orchestrator": {"renderer": None, "tokenizer": {"chat_template": "B"}}, + "orchestrator": {"renderer": {"name": "default"}, "tokenizer": {"chat_template": "B"}}, } ) @@ -315,7 +315,7 @@ def test_shared_seq_len_propagates_to_subconfigs(): { "seq_len": 4096, "trainer": {}, - "orchestrator": {"renderer": None}, + "orchestrator": {"renderer": {"name": "default"}}, } ) assert config.trainer.model.seq_len == 4096 @@ -331,7 +331,7 @@ def test_shared_and_sub_seq_len_conflict_raises(): { "seq_len": 4096, "trainer": {"model": {"seq_len": 8192}}, - "orchestrator": {"renderer": None}, + "orchestrator": {"renderer": {"name": "default"}}, } ) @@ -343,7 +343,7 @@ def test_shared_and_sub_model_name_conflict_raises(): { "model": {"name": "X"}, "trainer": {"model": {"name": "Y"}}, - "orchestrator": {"renderer": None}, + "orchestrator": {"renderer": {"name": "default"}}, } ) @@ -355,7 +355,7 @@ def test_shared_and_sub_max_steps_conflict_raises(): { "max_steps": 100, "trainer": {}, - "orchestrator": {"renderer": None, "max_steps": 200}, + "orchestrator": {"renderer": {"name": "default"}, "max_steps": 200}, } ) @@ -370,7 +370,7 @@ def test_trainer_chat_template_cascades_to_inference(): { "model": {"name": "Qwen/Qwen3-0.6B"}, "trainer": {"tokenizer": {"chat_template": "TPL"}}, - "orchestrator": {"renderer": None, "tokenizer": {"chat_template": "TPL"}}, + "orchestrator": {"renderer": {"name": "default"}, "tokenizer": {"chat_template": "TPL"}}, "inference": {}, } ) @@ -396,7 +396,7 @@ def test_shared_wandb_fields_propagate_to_subconfigs(): "offline": False, }, "trainer": {}, - "orchestrator": {"renderer": None}, + "orchestrator": {"renderer": {"name": "default"}}, } ) for component in (config.trainer.wandb, config.orchestrator.wandb): @@ -416,7 +416,7 @@ def test_empty_shared_ckpt_block_does_not_conflict_with_subconfig_ckpt(): { "ckpt": {}, # empty block, no field set "trainer": {"ckpt": {"interval": 50}}, - "orchestrator": {"renderer": None, "ckpt": {"interval": 50}}, + "orchestrator": {"renderer": {"name": "default"}, "ckpt": {"interval": 50}}, } ) assert config.trainer.ckpt is not None @@ -430,7 +430,7 @@ def test_shared_and_subconfig_disjoint_fields_coexist(): { "model": {"name": "Qwen/Qwen3-0.6B"}, "trainer": {"model": {"impl": "custom"}}, - "orchestrator": {"renderer": None}, + "orchestrator": {"renderer": {"name": "default"}}, } ) assert config.trainer.model.name == "Qwen/Qwen3-0.6B" @@ -482,15 +482,15 @@ def test_orchestrator_explicit_renderer_skips_unmapped_check(): assert config.renderer.name == "qwen3" -def test_orchestrator_renderer_none_skips_unmapped_check(): - """renderer=None (MITO mode) means the renderer client isn't used, so MODEL_RENDERER_MAP doesn't apply.""" - config = OrchestratorConfig.model_validate( - { - "model": {"name": "not-a-real-org/not-a-real-model"}, - "renderer": None, - } - ) - assert config.renderer is None +def test_orchestrator_renderer_none_rejected(): + """A renderer is required (training is renderer-only): the non-optional type rejects None.""" + with pytest.raises(ValidationError, match="renderer"): + OrchestratorConfig.model_validate( + { + "model": {"name": "not-a-real-org/not-a-real-model"}, + "renderer": None, + } + ) def test_orchestrator_explicit_default_renderer_with_unmapped_model(): @@ -515,7 +515,7 @@ def test_shared_model_name_resolves_inference_parsers(): { "model": {"name": "Qwen/Qwen3-Coder-30B-A3B-Instruct"}, "trainer": {}, - "orchestrator": {"renderer": None}, + "orchestrator": {"renderer": {"name": "default"}}, "inference": {}, } ) @@ -531,7 +531,7 @@ def test_explicit_inference_parser_wins_over_auto(): { "model": {"name": "Qwen/Qwen3-Coder-30B-A3B-Instruct"}, "trainer": {}, - "orchestrator": {"renderer": None}, + "orchestrator": {"renderer": {"name": "default"}}, "inference": {"model": {"tool_call_parser": "hermes"}}, } ) diff --git a/tests/unit/train/rl/test_packer.py b/tests/unit/train/rl/test_packer.py index d1f1c3634f..8d9bf2ce22 100644 --- a/tests/unit/train/rl/test_packer.py +++ b/tests/unit/train/rl/test_packer.py @@ -32,8 +32,8 @@ def create_run_with_config(output_dir: Path, run_name: str) -> Path: "group_size": 1, "env": [{"id": "test-env"}], "sampling": {"temperature": 1.0}, - # test-model isn't in MODEL_RENDERER_MAP; bypass the renderer-resolution validator. - "renderer": "None", + # test-model isn't in MODEL_RENDERER_MAP; use the explicit default renderer. + "renderer": {"name": "default"}, } with open(control_dir / "orch.toml", "wb") as f: tomli_w.dump(config, f) @@ -42,12 +42,10 @@ def create_run_with_config(output_dir: Path, run_name: str) -> Path: def make_training_sample() -> TrainingSample: return TrainingSample( - prompt_ids=[1], - prompt_mask=[False], - completion_ids=[2], - completion_mask=[True], - completion_logprobs=[-0.1], - completion_temperatures=[1.0], + token_ids=[1, 2], + mask=[False, True], + logprobs=[0.0, -0.1], + temperatures=[1.0, 1.0], env_name="test-env", ) diff --git a/tests/unit/train/test_runs.py b/tests/unit/train/test_runs.py index 6bb6623c06..cb8d6e5f9e 100644 --- a/tests/unit/train/test_runs.py +++ b/tests/unit/train/test_runs.py @@ -42,8 +42,8 @@ def create_run_with_config( "batch_size": 32, "group_size": 4, "env": [{"id": "test-env"}], - # test-model isn't in MODEL_RENDERER_MAP; bypass the renderer-resolution validator. - "renderer": "None", + # test-model isn't in MODEL_RENDERER_MAP; use the explicit default renderer. + "renderer": {"name": "default"}, } with open(config_dir / "orch.toml", "wb") as f: @@ -203,7 +203,7 @@ def test_config_loading(tmp_path: Path) -> None: "max_steps": 1000, "group_size": 4, "env": [{"id": "test-env"}], - "renderer": "None", + "renderer": {"name": "default"}, } create_run_with_config(tmp_path, "run_test123", config=test_config) @@ -248,7 +248,7 @@ def test_config_cleanup_on_deletion(tmp_path: Path) -> None: "batch_size": 16, "group_size": 4, "env": [{"id": "test-env"}], - "renderer": "None", + "renderer": {"name": "default"}, } run_dir = create_run_with_config(tmp_path, "run_delete_me", config=test_config) diff --git a/uv.lock b/uv.lock index 5c15a1668a..8bcb0d88b2 100644 --- a/uv.lock +++ b/uv.lock @@ -96,6 +96,17 @@ requires-dist = [ { name = "verifiers", specifier = ">=0.1.12.dev1" }, ] +[[package]] +name = "aime24-v1" +version = "0.1.0" +source = { editable = "deps/verifiers/examples/tasksets/aime24_v1" } +dependencies = [ + { name = "datasets", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, +] + +[package.metadata] +requires-dist = [{ name = "datasets" }] + [[package]] name = "aiofiles" version = "25.1.0" @@ -168,6 +179,17 @@ dependencies = [ [package.metadata] requires-dist = [{ name = "verifiers", specifier = ">=0.1.9" }] +[[package]] +name = "alphabet-sort-v1" +version = "0.1.0" +source = { editable = "deps/verifiers/examples/tasksets/alphabet_sort_v1" } +dependencies = [ + { name = "datasets", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, +] + +[package.metadata] +requires-dist = [{ name = "datasets" }] + [[package]] name = "annotated-doc" version = "0.0.4" @@ -557,6 +579,17 @@ requires-dist = [ { name = "verifiers", specifier = ">=0.1.10" }, ] +[[package]] +name = "color-codeword-v1" +version = "0.1.0" +source = { editable = "deps/verifiers/examples/tasksets/color_codeword_v1" } +dependencies = [ + { name = "pillow", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, +] + +[package.metadata] +requires-dist = [{ name = "pillow" }] + [[package]] name = "comm" version = "0.2.3" @@ -1097,6 +1130,7 @@ dependencies = [ ] sdist = { url = "https://files.pythonhosted.org/packages/c8/33/c97b2bcbe06e0f011eedee0f41d4060f6344901a53c2703acc3dd7429713/fastsafetensors-0.3.2.tar.gz", hash = "sha256:9e358fce238684613a5c3ebb7800c52c5b3270c0bb5e4ed2191ee8f3d0431de1", size = 70409, upload-time = "2026-05-22T05:39:34.787Z" } wheels = [ + { url = "https://files.pythonhosted.org/packages/c9/bb/9f821eac9bddd41ea1c5cd9b6a597c002741f022ecf6f3ba5cfcc3e9c950/fastsafetensors-0.3.2-cp312-cp312-manylinux_2_26_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:69f4d8cbd3b542e5ddf7fee8136cf35e1524f9c30e118f64a0e846dab7e8de6b", size = 1877989, upload-time = "2026-06-04T09:02:56.11Z" }, { url = "https://files.pythonhosted.org/packages/e9/68/a31c1661adf4d1b5ec29470ff991bde9094e4f347b0e6d1af8ba6b560d32/fastsafetensors-0.3.2-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:6a932d7166c9e17e48aca3e5503d326bc6fc73fce6dc985ae6bd2ccc0f308b14", size = 1907188, upload-time = "2026-05-22T05:39:30.242Z" }, ] @@ -1426,6 +1460,17 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/5c/90/b0cbbd9efcc82816c58f31a34963071aa19fb792a212a5d9caf8e0fc3097/grpclib-0.4.9-py3-none-any.whl", hash = "sha256:7762ec1c8ed94dfad597475152dd35cbd11aecaaca2f243e29702435ca24cf0e", size = 77063, upload-time = "2025-12-14T22:23:13.224Z" }, ] +[[package]] +name = "gsm8k-v1" +version = "0.1.0" +source = { editable = "deps/verifiers/examples/tasksets/gsm8k_v1" } +dependencies = [ + { name = "datasets", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, +] + +[package.metadata] +requires-dist = [{ name = "datasets" }] + [[package]] name = "gymnasium" version = "1.3.0" @@ -1463,6 +1508,11 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/69/b2/119f6e6dcbd96f9069ce9a2665e0146588dc9f88f29549711853645e736a/h2-4.3.0-py3-none-any.whl", hash = "sha256:c438f029a25f7945c69e0ccf0fb951dc3f73a5f6412981daee861431b70e2bdd", size = 61779, upload-time = "2025-08-23T18:12:17.779Z" }, ] +[[package]] +name = "harnesses" +version = "0.1.0" +source = { editable = "deps/verifiers/packages/harnesses" } + [[package]] name = "hf-xet" version = "1.5.0" @@ -1951,7 +2001,7 @@ wheels = [ [[package]] name = "kubernetes" -version = "36.0.0" +version = "36.0.2" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "aiohttp", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, @@ -1965,9 +2015,9 @@ dependencies = [ { name = "urllib3", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "websocket-client", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/bf/59/dc635e4e9afb3884bc5c57f14fe23783e4c04601aa20b835ac75c41d1625/kubernetes-36.0.0.tar.gz", hash = "sha256:027b606bb8032e6c6464a53236bdd9bd9a94c237e1063bc45a303c25b304ced9", size = 2346728, upload-time = "2026-05-20T20:44:24.28Z" } +sdist = { url = "https://files.pythonhosted.org/packages/2f/57/8b538af5076bc3372949d76f70ba3449bdfe52f9e6488170fa5d4f7cbe70/kubernetes-36.0.2.tar.gz", hash = "sha256:03551fcb49cae1f708f63624041e37403545b7aaed10cbf54e2b01a37a5438e3", size = 2336738, upload-time = "2026-06-01T18:20:30.785Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/cd/d2/6f99ca9c7eb961dfdd45b9643101399a8ee20922c662c362c91e9cc7e832/kubernetes-36.0.0-py2.py3-none-any.whl", hash = "sha256:a766433357ec9f90db7565cccf52e28e7fca40b0ef366c80a6022adbc0ac0425", size = 4660469, upload-time = "2026-05-20T20:44:20.893Z" }, + { url = "https://files.pythonhosted.org/packages/46/2c/5c160dbdef7123f8cc97fd8ece7e0198627a426a2a49614845e9086feb8d/kubernetes-36.0.2-py2.py3-none-any.whl", hash = "sha256:faf9b5241b58de0c4a5069f2a0ffc8ac06fece7215156cd3d3ba081a78a858b6", size = 4617568, upload-time = "2026-06-01T18:20:28.737Z" }, ] [[package]] @@ -1981,7 +2031,7 @@ sdist = { url = "https://files.pythonhosted.org/packages/0e/72/a3add0e4eec4eb9e2 [[package]] name = "langfuse" -version = "4.6.1" +version = "4.7.1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "backoff", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, @@ -1993,9 +2043,9 @@ dependencies = [ { name = "pydantic", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "wrapt", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/a6/31/4b7157be23e7c8c3581ac5f6547c5c003e232e7044c92398c468ef78a809/langfuse-4.6.1.tar.gz", hash = "sha256:7f256c669e610909c2e93ca3e9e4168dbef344b753b6874f14b0edd673863f17", size = 281379, upload-time = "2026-05-08T14:08:15.909Z" } +sdist = { url = "https://files.pythonhosted.org/packages/2c/74/a6f1a99893ee6d1a69439ae7eb92f8fe8806103492dc26531d5942dbd3bf/langfuse-4.7.1.tar.gz", hash = "sha256:f9e262eceedb353b191c1da1f8452d1e8ebf52297ca20e160cda0206608e3a40", size = 320620, upload-time = "2026-05-29T18:06:22.435Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/4b/bf/3a6082f7809bdcc1269e9920c07d7c7f92a53cc265a4a879e59c92b23b36/langfuse-4.6.1-py3-none-any.whl", hash = "sha256:a696ac3089a0c8431bf7f1b47b7f6417da311f418dd04ce9ef62d63608fd8797", size = 481237, upload-time = "2026-05-08T14:08:17.141Z" }, + { url = "https://files.pythonhosted.org/packages/9f/9a/bd3368f46b6c72ee2068b80536826b02ae86df53eff1c79941344503098f/langfuse-4.7.1-py3-none-any.whl", hash = "sha256:a4e59c81ad5e5b16a65d3849f4923ebc3ad6e67ec803ada83d50c0cb66149490", size = 562571, upload-time = "2026-05-29T18:06:20.517Z" }, ] [[package]] @@ -2047,7 +2097,7 @@ wheels = [ [[package]] name = "litellm" -version = "1.86.0" +version = "1.87.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "aiohttp", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, @@ -2063,9 +2113,9 @@ dependencies = [ { name = "tiktoken", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "tokenizers", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/2f/a7/26b8b04e4fcff26b60200ffe7458a255552ae51014468188f5db45674eb2/litellm-1.86.0.tar.gz", hash = "sha256:eccab86e0820b60b3f9484b233fb8d818b97afb19d5b4fa08d0d045621350ba4", size = 15379195, upload-time = "2026-05-24T02:42:10.865Z" } +sdist = { url = "https://files.pythonhosted.org/packages/77/0d/ccdf682ccfd7f18bf0e179c39d85616b8f8ef05a798588285310412db13d/litellm-1.87.0.tar.gz", hash = "sha256:cafc1882cb0cbab8374c41180af86e4a067796e4524e15f59e99f6e689cd1bd8", size = 15453755, upload-time = "2026-06-02T03:53:29.076Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/ff/0b/9a044c061a69e801de042e962c34f5bc2e094810e28b49ce0b3bedee9327/litellm-1.86.0-py3-none-any.whl", hash = "sha256:9d8171ca1a17705b7c7a6fdce8cfc07bbf641284b46c1b6047f83a779159990c", size = 17011225, upload-time = "2026-05-24T02:42:00.629Z" }, + { url = "https://files.pythonhosted.org/packages/98/20/88a372fa7e50fc2c33458c6eef94a79afcf7bdfa43610079531b82b484a3/litellm-1.87.0-py3-none-any.whl", hash = "sha256:fbbba7e47ae29b55f878fe1acc80effb92761bc168f6236bd81a0cb6e147d855", size = 17103948, upload-time = "2026-06-02T03:53:25.677Z" }, ] [[package]] @@ -2228,6 +2278,17 @@ requires-dist = [ { name = "verifiers", specifier = ">=0.1.12.dev1" }, ] +[[package]] +name = "math-env-v1" +version = "0.1.0" +source = { editable = "deps/verifiers/examples/tasksets/math_env_v1" } +dependencies = [ + { name = "datasets", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, +] + +[package.metadata] +requires-dist = [{ name = "datasets" }] + [[package]] name = "math-python" version = "0.1.10" @@ -2618,11 +2679,11 @@ wheels = [ [[package]] name = "narwhals" -version = "2.21.2" +version = "2.22.0" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/cf/a0/6198c56d42ef2f3c6ed0c42ba30dbcefdc86a91262d7d449010770ae085b/narwhals-2.21.2.tar.gz", hash = "sha256:5c5b2d0b47aef7c73ea412cfcbcd467f2f2d5be73e3c2ab19d78f4a97718790a", size = 632176, upload-time = "2026-05-16T08:49:08.314Z" } +sdist = { url = "https://files.pythonhosted.org/packages/9c/1c/c80cb7719721a44846c6301ef118434bae30a423924bfad3a47f16bdc064/narwhals-2.22.0.tar.gz", hash = "sha256:6486282bb7e4b4ab55963efbd8be1451b764cc4874b74d1fd625eba9dc60b86f", size = 417565, upload-time = "2026-06-01T13:34:36.249Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/1d/77/928ea2e70641ca177a11140062cc5840d421795f2e82749d408d0cce900a/narwhals-2.21.2-py3-none-any.whl", hash = "sha256:7bb57c3700486039215455b9bf2d64261915cc0fd845cc30272d631df696b251", size = 451201, upload-time = "2026-05-16T08:49:05.536Z" }, + { url = "https://files.pythonhosted.org/packages/1f/b6/e7cdde7b8e90d5dff25b622f95833ef26567ad184c977278b93a1cbd5717/narwhals-2.22.0-py3-none-any.whl", hash = "sha256:1421797ede01789cc1537619dbc3f36f840737240f748fdb24a60a0225fc80be", size = 453815, upload-time = "2026-06-01T13:34:34.127Z" }, ] [[package]] @@ -3505,9 +3566,12 @@ dependencies = [ { name = "liger-kernel", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "loguru", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "mooncake-transfer-engine", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "msgspec", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "numpy", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "nvidia-ml-py", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "openai", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "orjson", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "pandas", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "prime", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "prime-rl-configs", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "pyarrow", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, @@ -3554,17 +3618,23 @@ disagg = [ envs = [ { name = "aime2024", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "aime2025", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "aime24-v1", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "alphabet-sort", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "alphabet-sort-v1", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "code-env", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "color-codeword", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "color-codeword-v1", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "deepdive", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "general-agent", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "gpqa", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "gsm8k-v1", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "harnesses", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "hle", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "ifeval", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "livecodebench", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "logic-env", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "math-env", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "math-env-v1", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "math-python", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "math500", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "mini-swe-agent-plus", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, @@ -3575,10 +3645,14 @@ envs = [ { name = "opencode-math", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "opencode-science", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "opencode-swe", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "r2e-gym-v1", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "reverse-text", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "reverse-text-v1", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "rlm-swe", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "scaleswe-v1", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "science-env", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "simpleqa-verified", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "tasksets", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "tau2-bench", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "wiki-search", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "wordle", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, @@ -3619,11 +3693,14 @@ mamba-ssm = [ requires-dist = [ { name = "aime2024", marker = "extra == 'envs'", editable = "deps/research-environments/environments/aime2024" }, { name = "aime2025", marker = "extra == 'envs'", editable = "deps/research-environments/environments/aime2025" }, + { name = "aime24-v1", marker = "extra == 'envs'", editable = "deps/verifiers/examples/tasksets/aime24_v1" }, { name = "aiolimiter", specifier = ">=1.2.1" }, { name = "alphabet-sort", marker = "extra == 'envs'", editable = "deps/verifiers/environments/alphabet_sort" }, + { name = "alphabet-sort-v1", marker = "extra == 'envs'", editable = "deps/verifiers/examples/tasksets/alphabet_sort_v1" }, { name = "beartype", specifier = ">=0.21.0" }, { name = "code-env", marker = "extra == 'envs'", editable = "deps/research-environments/environments/code_env" }, { name = "color-codeword", marker = "extra == 'envs'", editable = "deps/research-environments/environments/color_codeword" }, + { name = "color-codeword-v1", marker = "extra == 'envs'", editable = "deps/verifiers/examples/tasksets/color_codeword_v1" }, { name = "datasets", specifier = ">=4.0.0" }, { name = "deep-ep", marker = "platform_machine == 'x86_64' and extra == 'disagg'", url = "https://github.com/PrimeIntellect-ai/prime-rl/releases/download/v0.5.0/deep_ep-1.2.1+29d31c0-cp312-cp312-linux_x86_64.whl" }, { name = "deep-gemm", marker = "platform_machine == 'x86_64' and extra == 'disagg'", url = "https://github.com/PrimeIntellect-ai/prime-rl/releases/download/v0.5.0/deep_gemm-2.5.0+891d57b-cp312-cp312-linux_x86_64.whl" }, @@ -3635,6 +3712,8 @@ requires-dist = [ { name = "flash-linear-attention", git = "https://github.com/fla-org/flash-linear-attention" }, { name = "general-agent", marker = "extra == 'envs'", editable = "deps/research-environments/environments/general_agent" }, { name = "gpqa", marker = "extra == 'envs'", editable = "deps/research-environments/environments/gpqa" }, + { name = "gsm8k-v1", marker = "extra == 'envs'", editable = "deps/verifiers/examples/tasksets/gsm8k_v1" }, + { name = "harnesses", marker = "extra == 'envs'", editable = "deps/verifiers/packages/harnesses" }, { name = "hle", marker = "extra == 'envs'", editable = "deps/research-environments/environments/hle" }, { name = "ifeval", marker = "extra == 'envs'", editable = "deps/research-environments/environments/ifeval" }, { name = "jaxtyping", specifier = ">=0.3.2" }, @@ -3644,6 +3723,7 @@ requires-dist = [ { name = "logic-env", marker = "extra == 'envs'", editable = "deps/research-environments/environments/logic_env" }, { name = "loguru", specifier = ">=0.7.3" }, { name = "math-env", marker = "extra == 'envs'", editable = "deps/research-environments/environments/math_env" }, + { name = "math-env-v1", marker = "extra == 'envs'", editable = "deps/verifiers/examples/tasksets/math_env_v1" }, { name = "math-python", marker = "extra == 'envs'", editable = "deps/verifiers/environments/math_python" }, { name = "math500", marker = "extra == 'envs'", editable = "deps/research-environments/environments/math500" }, { name = "mini-swe-agent-plus", marker = "extra == 'envs'", editable = "deps/research-environments/environments/mini_swe_agent_plus" }, @@ -3651,6 +3731,7 @@ requires-dist = [ { name = "mmlu-pro", marker = "extra == 'envs'", editable = "deps/research-environments/environments/mmlu_pro" }, { name = "modelexpress", marker = "extra == 'modelexpress'", specifier = "==0.3.0" }, { name = "mooncake-transfer-engine", specifier = ">=0.3.10.post2" }, + { name = "msgspec", specifier = ">=0.18" }, { name = "nixl", marker = "extra == 'disagg'" }, { name = "nixl-cu12", marker = "platform_machine == 'x86_64' and extra == 'disagg'", url = "https://github.com/PrimeIntellect-ai/prime-rl/releases/download/v0.5.0/nixl_cu12-0.10.1-cp312-cp312-linux_x86_64.whl" }, { name = "numpy", specifier = ">=2.2.6" }, @@ -3661,6 +3742,8 @@ requires-dist = [ { name = "opencode-math", marker = "extra == 'envs'", editable = "deps/research-environments/environments/opencode_math" }, { name = "opencode-science", marker = "extra == 'envs'", editable = "deps/research-environments/environments/opencode_science" }, { name = "opencode-swe", marker = "extra == 'envs'", editable = "deps/research-environments/environments/opencode_swe" }, + { name = "orjson", specifier = ">=3.10" }, + { name = "pandas", specifier = ">=2.0" }, { name = "prime", specifier = ">=0.6.4" }, { name = "prime-rl", extras = ["disagg"], marker = "extra == 'all'" }, { name = "prime-rl", extras = ["flash-attn"], marker = "extra == 'all'" }, @@ -3672,14 +3755,18 @@ requires-dist = [ { name = "pybase64", specifier = ">=1.4.2" }, { name = "pyzmq", specifier = ">=27.1.0" }, { name = "quack-kernels", marker = "extra == 'quack'", specifier = ">=0.4.1" }, + { name = "r2e-gym-v1", marker = "extra == 'envs'", editable = "deps/verifiers/examples/tasksets/r2e_gym_v1" }, { name = "renderers", editable = "deps/renderers" }, { name = "reverse-text", marker = "extra == 'envs'", editable = "deps/verifiers/environments/reverse_text" }, + { name = "reverse-text-v1", marker = "extra == 'envs'", editable = "deps/verifiers/examples/tasksets/reverse_text_v1" }, { name = "rich", specifier = ">=14.0.0" }, { name = "ring-flash-attn", specifier = ">=0.1.8" }, { name = "rlm-swe", marker = "extra == 'envs'", editable = "deps/research-environments/environments/rlm_swe" }, + { name = "scaleswe-v1", marker = "extra == 'envs'", editable = "deps/verifiers/examples/tasksets/scaleswe_v1" }, { name = "science-env", marker = "extra == 'envs'", editable = "deps/research-environments/environments/science_env" }, { name = "setproctitle", specifier = ">=1.3.0" }, { name = "simpleqa-verified", marker = "extra == 'envs'", editable = "deps/research-environments/environments/simpleqa_verified" }, + { name = "tasksets", marker = "extra == 'envs'", editable = "deps/verifiers/packages/tasksets" }, { name = "tau2-bench", marker = "extra == 'envs'", editable = "deps/research-environments/environments/tau2_bench" }, { name = "tenacity", specifier = ">=8.2.0" }, { name = "tilelang", specifier = ">=0.1.8" }, @@ -3722,6 +3809,7 @@ dependencies = [ { name = "renderers", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "tomli", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "tomli-w", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "verifiers", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, ] [package.metadata] @@ -3731,11 +3819,12 @@ requires-dist = [ { name = "renderers", specifier = ">=0.1.8.dev28" }, { name = "tomli", specifier = ">=2.2.1" }, { name = "tomli-w", specifier = ">=1.2.0" }, + { name = "verifiers" }, ] [[package]] name = "prime-sandboxes" -version = "0.2.26" +version = "0.2.27" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "aiofiles", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, @@ -3745,23 +3834,23 @@ dependencies = [ { name = "pydantic", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "tenacity", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/9b/98/7ecfdd389664c5112eda8d83e287b266f627fdbcefe5510b1b7346effa4b/prime_sandboxes-0.2.26.tar.gz", hash = "sha256:87b0aa2d5ee44f201cb4d0fc1bdf52efa617a7f3cbf9e5a5ee7bc2de0f45ac98", size = 68496, upload-time = "2026-05-12T22:17:47.801Z" } +sdist = { url = "https://files.pythonhosted.org/packages/9a/1f/b257b21f54e9b961bcac05335f6637c494ed0986fb4b5d2054f97f96faa0/prime_sandboxes-0.2.27.tar.gz", hash = "sha256:db4071387f4b2dc8bcd0c8916af4031f864dfd5dfd9818f56119347af20b1469", size = 68661, upload-time = "2026-06-05T21:55:35.175Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/e8/51/b5c282dffd4eba73953cdf86c34df36dd31264464489fd14a1a565776e03/prime_sandboxes-0.2.26-py3-none-any.whl", hash = "sha256:03226e4fa81a8008bbdb3046f295a56a43e74bb89b8e894b001547643deca918", size = 34234, upload-time = "2026-05-12T22:17:46.494Z" }, + { url = "https://files.pythonhosted.org/packages/94/bc/a9142d0ef67d92672469cff15cadc04eb46ca9957b00156445d14c14b04f/prime_sandboxes-0.2.27-py3-none-any.whl", hash = "sha256:3fb227cc909c15475fb2874974d166d857021ee6ed9a6e0d4482b1b5ebfc50e2", size = 34402, upload-time = "2026-06-05T21:55:33.451Z" }, ] [[package]] name = "prime-tunnel" -version = "0.1.7" +version = "0.1.8" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "httpx", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "pydantic", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "tenacity", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/91/63/900dbb168a0f49dd9b6c510520a8c9b2b7ecb2c5863caac7757f12ea16e5/prime_tunnel-0.1.7.tar.gz", hash = "sha256:0fd296a78645c95c4474cf418321bac2ad6ba6816701e8823baf6ab718a68a5a", size = 13024, upload-time = "2026-05-20T06:38:40.444Z" } +sdist = { url = "https://files.pythonhosted.org/packages/9c/e7/557c7623bd8c9a7a9127d880dce2a3d8612d3a7df02bdde6f01c958a20b8/prime_tunnel-0.1.8.tar.gz", hash = "sha256:07803d5d5c6ec83c260bbef0ecd340f4f6acab0f6a254d5d34a4c6ddfc777c0b", size = 14576, upload-time = "2026-06-02T06:39:36.56Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/60/32/3270dba4c7c402372d25faf8f633f571fe2be8692d720e0eb717ed6ab44e/prime_tunnel-0.1.7-py3-none-any.whl", hash = "sha256:226a08ef9cc4c13503f4e1a687d4fa733ab72626725538b822edbaf8064c40fd", size = 14912, upload-time = "2026-05-20T06:38:38.986Z" }, + { url = "https://files.pythonhosted.org/packages/f2/5f/181fe0ca1ca8507c0a373e55363b411bdeab9ad904c68edec7d7b1bdea74/prime_tunnel-0.1.8-py3-none-any.whl", hash = "sha256:fab081f566e0887ce50c7b5872169e9f356fc3263d5f0493cfabb3150eb73b0a", size = 15843, upload-time = "2026-06-02T06:39:34.043Z" }, ] [[package]] @@ -4034,6 +4123,17 @@ crypto = [ { name = "cryptography", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, ] +[[package]] +name = "pymupdf" +version = "1.27.2.3" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/22/32/708bedc9dde7b328d45abbc076091769d44f2f24ad151ad92d56a6ec142b/pymupdf-1.27.2.3.tar.gz", hash = "sha256:7a92faa25129e8bbec5e50eeb9214f187665428c31b05c4ef6e36c58c0b1c6d2", size = 85759618, upload-time = "2026-04-24T14:13:14.42Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c3/26/b7e5a70eb83bd189f8b5df87ec442746b992f2f632662839b288170d357d/pymupdf-1.27.2.3-cp310-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:1dd460a3ae4597a755f00a3bd9771f5ebf1531dc111f6a36bf05dd00a6b84425", size = 24333923, upload-time = "2026-04-24T14:09:47.341Z" }, + { url = "https://files.pythonhosted.org/packages/e4/a0/aa1ee2240f29481a04a827c313333b4ecd8a14d6ac3e15d3f41a30574781/pymupdf-1.27.2.3-cp310-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:857842b4888827bd6155a1131341b2822a7ebe9a8c15a975fd7d490d7a64a30c", size = 24963198, upload-time = "2026-04-24T14:10:07.408Z" }, + { url = "https://files.pythonhosted.org/packages/69/49/4f742451f980840829fc00ba158bebb25d389c846d8f4f8c65936ee55de8/pymupdf-1.27.2.3-cp310-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:580983849c64a08d08344ca3d1580e87c01f046a8392421797bc850efd72a5b6", size = 25184609, upload-time = "2026-04-24T14:10:22.911Z" }, +] + [[package]] name = "pynacl" version = "1.6.2" @@ -4222,13 +4322,24 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/38/e4/a6c3bbbe3d4242fa412454b8e8069a079e500be331aecf8f2aa666164e9c/quack_kernels-0.4.1-py3-none-any.whl", hash = "sha256:c1c8df2935bf5156ec47d2c5384ac08b411fd0ee702d80ae916dbf6d6f5ae813", size = 260827, upload-time = "2026-04-30T14:37:54.584Z" }, ] +[[package]] +name = "r2e-gym-v1" +version = "0.1.0" +source = { editable = "deps/verifiers/examples/tasksets/r2e_gym_v1" } +dependencies = [ + { name = "datasets", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, +] + +[package.metadata] +requires-dist = [{ name = "datasets" }] + [[package]] name = "redis" -version = "7.4.0" +version = "8.0.0" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/7b/7f/3759b1d0d72b7c92f0d70ffd9dc962b7b7b5ee74e135f9d7d8ab06b8a318/redis-7.4.0.tar.gz", hash = "sha256:64a6ea7bf567ad43c964d2c30d82853f8df927c5c9017766c55a1d1ed95d18ad", size = 4943913, upload-time = "2026-03-24T09:14:37.53Z" } +sdist = { url = "https://files.pythonhosted.org/packages/53/ae/ed461cca5780b5fc8b9fe8ca0ed98d89508645fb9d880c24cc42c087678f/redis-8.0.0.tar.gz", hash = "sha256:a00c5355432051ac14e593b8b197fc76c887ee12d55a0984f69328a1115fdc49", size = 5101591, upload-time = "2026-05-28T12:45:13.5Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/74/3a/95deec7db1eb53979973ebd156f3369a72732208d1391cd2e5d127062a32/redis-7.4.0-py3-none-any.whl", hash = "sha256:a9c74a5c893a5ef8455a5adb793a31bb70feb821c86eccb62eebef5a19c429ec", size = 409772, upload-time = "2026-03-24T09:14:35.968Z" }, + { url = "https://files.pythonhosted.org/packages/27/e3/b519734372d305bd547534a9f32e4ce9f98552af753dce72cf3483a0ff0b/redis-8.0.0-py3-none-any.whl", hash = "sha256:c938c18338585009f0bc310f4c7e4e4b4d37639356c4ac072cedf3af570c8dc7", size = 499870, upload-time = "2026-05-28T12:45:11.697Z" }, ] [[package]] @@ -4338,6 +4449,17 @@ requires-dist = [ { name = "verifiers", specifier = ">=0.1.5.post0" }, ] +[[package]] +name = "reverse-text-v1" +version = "0.1.0" +source = { editable = "deps/verifiers/examples/tasksets/reverse_text_v1" } +dependencies = [ + { name = "datasets", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, +] + +[package.metadata] +requires-dist = [{ name = "datasets" }] + [[package]] name = "rich" version = "15.0.0" @@ -4441,6 +4563,17 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/4a/d8/0c8a7dc9b41dcac53c4cbf9df2b9c83e0e0097203de8b37a712b345c0be5/safetensors-0.7.0-cp38-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:b0f6d66c1c538d5a94a73aa9ddca8ccc4227e6c9ff555322ea40bdd142391dd4", size = 677368, upload-time = "2025-11-19T15:18:41.627Z" }, ] +[[package]] +name = "scaleswe-v1" +version = "0.1.0" +source = { editable = "deps/verifiers/examples/tasksets/scaleswe_v1" } +dependencies = [ + { name = "datasets", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, +] + +[package.metadata] +requires-dist = [{ name = "datasets" }] + [[package]] name = "science-env" version = "0.1.4" @@ -4458,18 +4591,19 @@ requires-dist = [ [[package]] name = "scikit-learn" -version = "1.8.0" +version = "1.9.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "joblib", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "narwhals", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "numpy", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "scipy", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "threadpoolctl", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/0e/d4/40988bf3b8e34feec1d0e6a051446b1f66225f8529b9309becaeef62b6c4/scikit_learn-1.8.0.tar.gz", hash = "sha256:9bccbb3b40e3de10351f8f5068e105d0f4083b1a65fa07b6634fbc401a6287fd", size = 7335585, upload-time = "2025-12-10T07:08:53.618Z" } +sdist = { url = "https://files.pythonhosted.org/packages/fa/6f/37092bdb25f712817231799fc5674d8e704066a8a70c1d2d40517e18b4ab/scikit_learn-1.9.0.tar.gz", hash = "sha256:8833266989d3a5110178a9fae30783675460724d0e1efb13b14901d2c660c557", size = 7750767, upload-time = "2026-06-02T11:54:32.706Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/dd/47/f187b4636ff80cc63f21cd40b7b2d177134acaa10f6bb73746130ee8c2e5/scikit_learn-1.8.0-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:4496bb2cf7a43ce1a2d7524a79e40bc5da45cf598dbf9545b7e8316ccba47bb4", size = 8660492, upload-time = "2025-12-10T07:07:55.574Z" }, - { url = "https://files.pythonhosted.org/packages/97/74/b7a304feb2b49df9fafa9382d4d09061a96ee9a9449a7cbea7988dda0828/scikit_learn-1.8.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a0bcfe4d0d14aec44921545fd2af2338c7471de9cb701f1da4c9d85906ab847a", size = 8931904, upload-time = "2025-12-10T07:07:57.666Z" }, + { url = "https://files.pythonhosted.org/packages/a0/ee/5adbc77656b71f9456a2f5a7a9fdb4bcf9207a6b962889f1c2f9323afa4e/scikit_learn-1.9.0-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:5e50ed4da51974e86e940690e9a3d82e729b62b5a49f7c9bac534d515d39d86f", size = 8837603, upload-time = "2026-06-02T11:53:30.328Z" }, + { url = "https://files.pythonhosted.org/packages/6c/c2/63fdda36c56437eeb44aaf9493c8bcd62ce230ab1598924fc626ffbfa943/scikit_learn-1.9.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:056c92bb67ad4c28463c2f2653d9701449201e7e7a9e94e321be0f71c4fef2b8", size = 9132097, upload-time = "2026-06-02T11:53:33.456Z" }, ] [[package]] @@ -4708,14 +4842,14 @@ wheels = [ [[package]] name = "synchronicity" -version = "0.12.2" +version = "0.12.3" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "typing-extensions", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/f9/5e/50ea27817003665c7cc4f5bdad309f13d6329037f657848ee87fe06c3740/synchronicity-0.12.2.tar.gz", hash = "sha256:6fd605a5035d1ec74ce48fffaca80ea00345c84ca34223914e2436fb4f162ff9", size = 60018, upload-time = "2026-04-06T15:06:15.447Z" } +sdist = { url = "https://files.pythonhosted.org/packages/ec/d5/e96e6082790c92480380f28aa53e111844cdac7b0f75846f4772cb535a43/synchronicity-0.12.3.tar.gz", hash = "sha256:0d4228b85eaf2805f23b4615b2039a9d24ea811646e2d9f8d0c033094eb85841", size = 60261, upload-time = "2026-05-28T12:33:50.206Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/41/44/4f6ba4e2c171847e6f9a460213b196bbf26edea43d0e66889c7ccc55d368/synchronicity-0.12.2-py3-none-any.whl", hash = "sha256:9dbaca81fb7f2b57c6dea326e514e1c80e9ccfd9c9618515e84fa6091026273b", size = 41312, upload-time = "2026-04-06T15:06:14.459Z" }, + { url = "https://files.pythonhosted.org/packages/57/ea/531a6ea751cbd989da386144810b1b8f529b0aae8c1a9beda8b40966c9c2/synchronicity-0.12.3-py3-none-any.whl", hash = "sha256:e476818cd14102136f41622c619de548f0000c024485fc18521c8fe908ea7574", size = 40982, upload-time = "2026-05-28T12:33:49.125Z" }, ] [[package]] @@ -4727,6 +4861,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/99/55/db07de81b5c630da5cbf5c7df646580ca26dfaefa593667fc6f2fe016d2e/tabulate-0.10.0-py3-none-any.whl", hash = "sha256:f0b0622e567335c8fabaaa659f1b33bcb6ddfe2e496071b743aa113f8774f2d3", size = 39814, upload-time = "2026-03-04T18:55:31.284Z" }, ] +[[package]] +name = "tasksets" +version = "0.1.0" +source = { editable = "deps/verifiers/packages/tasksets" } + +[package.metadata] +requires-dist = [ + { name = "nltk", marker = "extra == 'textarena'", specifier = ">=3.9.2" }, + { name = "textarena", marker = "extra == 'textarena'", specifier = "==0.7.4" }, +] +provides-extras = ["textarena"] + [[package]] name = "tau2" version = "0.2.1.dev0" @@ -5319,51 +5465,60 @@ wheels = [ name = "verifiers" source = { editable = "deps/verifiers" } dependencies = [ + { name = "aiohttp", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "aiolimiter", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "anthropic", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "certifi", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "datasets", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "gepa", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "httpx", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "jinja2", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "loguru", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "math-verify", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "mcp", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "modal", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "msgpack", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "nest-asyncio", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "numpy", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "openai", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "openai-agents", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "pillow", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "prime-pydantic-config", extra = ["toml"], marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "prime-sandboxes", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "prime-tunnel", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "pydantic", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "pymupdf", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "pyzmq", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "regex", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "renderers", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "requests", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "rich", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "setproctitle", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "tenacity", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "textual", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "tomli-w", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "uvloop", marker = "(platform_machine == 'aarch64' and platform_python_implementation != 'PyPy' and sys_platform == 'linux') or (platform_machine == 'x86_64' and platform_python_implementation != 'PyPy' and sys_platform == 'linux')" }, ] [package.metadata] requires-dist = [ { name = "accelerate", marker = "extra == 'rl'", specifier = ">=1.4.0" }, + { name = "aiohttp", specifier = ">=3.9.0" }, { name = "aiohttp", marker = "extra == 'browser'", specifier = ">=3.9.0" }, { name = "aiolimiter", specifier = ">=1.2.1" }, { name = "anthropic", specifier = ">=0.78.0" }, + { name = "certifi" }, { name = "datasets", specifier = ">=3.0.0,<4.7.0" }, { name = "deepspeed", marker = "extra == 'rl'", specifier = ">=0.17.6" }, { name = "flash-attn", marker = "extra == 'rl'", specifier = ">=2.8.3" }, { name = "gepa" }, - { name = "harnesses", marker = "extra == 'harnesses'", editable = "deps/verifiers/packages/harnesses" }, - { name = "harnesses", marker = "extra == 'packages'", editable = "deps/verifiers/packages/harnesses" }, - { name = "harnesses", extras = ["nemogym"], marker = "extra == 'nemogym'", editable = "deps/verifiers/packages/harnesses" }, { name = "httpx", specifier = ">=0.27.0" }, { name = "jinja2", specifier = ">=3.1.6" }, { name = "liger-kernel", marker = "extra == 'rl'", specifier = ">=0.5.10" }, + { name = "loguru", specifier = ">=0.7.0" }, { name = "math-verify", specifier = ">=0.8.0" }, { name = "mcp", specifier = ">=1.14.1" }, + { name = "modal", specifier = ">=1.4.0" }, { name = "msgpack", specifier = ">=1.1.2" }, { name = "nest-asyncio", specifier = ">=1.6.0" }, { name = "nltk", marker = "extra == 'ta'" }, @@ -5371,29 +5526,28 @@ requires-dist = [ { name = "openai", specifier = ">=1.108.1" }, { name = "openai-agents", specifier = ">=0.0.7" }, { name = "peft", marker = "extra == 'rl'" }, - { name = "prime-pydantic-config", extras = ["toml"] }, - { name = "prime-sandboxes", specifier = ">=0.2.25" }, - { name = "prime-tunnel", specifier = ">=0.1.6" }, + { name = "pillow" }, + { name = "prime-pydantic-config", extras = ["toml"], specifier = ">=0.3.0.dev86" }, + { name = "prime-sandboxes", specifier = ">=0.2.27" }, + { name = "prime-tunnel", specifier = ">=0.1.8" }, { name = "pydantic", specifier = ">=2.11.9" }, + { name = "pymupdf" }, { name = "python-dotenv", marker = "extra == 'browser'", specifier = ">=1.0.0" }, { name = "pyzmq", specifier = ">=27.1.0" }, { name = "reasoning-gym", marker = "extra == 'rg'" }, { name = "regex", specifier = "<2026.4.4" }, - { name = "renderers", marker = "extra == 'renderers'", specifier = ">=0.1.8.dev28" }, + { name = "renderers", specifier = ">=0.1.8.dev40" }, + { name = "renderers", marker = "extra == 'renderers'", specifier = ">=0.1.8.dev40" }, { name = "requests" }, { name = "requests", marker = "extra == 'rl'" }, { name = "rich" }, { name = "setproctitle", specifier = ">=1.3.0" }, { name = "stagehand", marker = "extra == 'browser'", specifier = ">=3.0.0" }, - { name = "tasksets", extras = ["nemogym"], marker = "extra == 'nemogym'", editable = "deps/verifiers/packages/tasksets" }, - { name = "tasksets", extras = ["openenv"], marker = "extra == 'openenv'", editable = "deps/verifiers/packages/tasksets" }, - { name = "tasksets", extras = ["openenv", "openreward", "ta"], marker = "extra == 'packages'", editable = "deps/verifiers/packages/tasksets" }, - { name = "tasksets", extras = ["openenv", "openreward", "ta"], marker = "extra == 'tasksets'", editable = "deps/verifiers/packages/tasksets" }, - { name = "tasksets", extras = ["openreward"], marker = "extra == 'openreward'", editable = "deps/verifiers/packages/tasksets" }, { name = "tenacity", specifier = ">=8.5.0" }, { name = "textarena", marker = "extra == 'ta'" }, { name = "textual" }, { name = "tomli", marker = "python_full_version < '3.11'" }, + { name = "tomli-w", specifier = ">=1.0.0" }, { name = "torch", marker = "extra == 'rl'", specifier = ">=2.8.0,<2.9.0" }, { name = "transformers", marker = "extra == 'rl'", specifier = ">=4.56.2" }, { name = "typing-extensions", marker = "python_full_version < '3.12'" }, @@ -5401,14 +5555,14 @@ requires-dist = [ { name = "vllm", marker = "extra == 'rl'", specifier = ">=0.10.0,<0.11.0" }, { name = "wandb", marker = "extra == 'rl'" }, ] -provides-extras = ["browser", "harnesses", "nemogym", "openenv", "openreward", "packages", "renderers", "rg", "rl", "ta", "tasksets"] +provides-extras = ["browser", "renderers", "rg", "rl", "ta"] [package.metadata.requires-dev] dev = [ { name = "aiohttp", specifier = ">=3.9.0" }, - { name = "harnesses", editable = "deps/verifiers/packages/harnesses" }, { name = "ipykernel" }, { name = "ipywidgets" }, + { name = "nltk" }, { name = "pre-commit" }, { name = "pytest", specifier = ">=7.0.0" }, { name = "pytest-asyncio", specifier = ">=0.21.0" }, @@ -5416,13 +5570,28 @@ dev = [ { name = "pytest-xdist", specifier = ">=3.8.0" }, { name = "python-dotenv", specifier = ">=1.0.0" }, { name = "reasoning-gym" }, - { name = "renderers", specifier = ">=0.1.8.dev28" }, { name = "ruff" }, { name = "stagehand", specifier = ">=3.0.0" }, - { name = "tasksets", extras = ["openenv", "openreward", "ta"], editable = "deps/verifiers/packages/tasksets" }, + { name = "textarena" }, { name = "ty", specifier = ">=0.0.1a29,<0.0.22" }, ] -policy = [{ name = "semgrep", specifier = ">=1.150.0" }] +examples = [ + { name = "aime24-v1", editable = "deps/verifiers/examples/tasksets/aime24_v1" }, + { name = "alphabet-sort-v1", editable = "deps/verifiers/examples/tasksets/alphabet_sort_v1" }, + { name = "code-golf-v1", editable = "deps/verifiers/examples/tasksets/code_golf_v1" }, + { name = "compact", editable = "deps/verifiers/examples/harnesses/compact" }, + { name = "deepwiki-v1", editable = "deps/verifiers/examples/tasksets/deepwiki_v1" }, + { name = "glossary-v1", editable = "deps/verifiers/examples/tasksets/glossary_v1" }, + { name = "gsm8k-v1", editable = "deps/verifiers/examples/tasksets/gsm8k_v1" }, + { name = "math-env-v1", editable = "deps/verifiers/examples/tasksets/math_env_v1" }, + { name = "reverse-text-v1", editable = "deps/verifiers/examples/tasksets/reverse_text_v1" }, + { name = "terminal-bench-2-v1", editable = "deps/verifiers/examples/tasksets/terminal_bench_2_v1" }, + { name = "wiki-search-v1", editable = "deps/verifiers/examples/tasksets/wiki_search_v1" }, + { name = "wikispeedia-v1", editable = "deps/verifiers/examples/tasksets/wikispeedia_v1" }, + { name = "wordle-v1", editable = "deps/verifiers/examples/tasksets/wordle_v1" }, +] +harnesses = [{ name = "harnesses", editable = "deps/verifiers/packages/harnesses" }] +tasksets = [{ name = "tasksets", editable = "deps/verifiers/packages/tasksets" }] [[package]] name = "virtualenv"