Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 76 additions & 0 deletions configs/debug/v1/multimodal.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
# v1 port of configs/debug/multimodal.toml — identical except the env block, which loads
# the native v1 `color-codeword-v1` taskset (images delivered by a colocated vf.User) on the
# subprocess runtime, instead of the v0 `color-codeword` env. Exercises native v1 image
# input → renderer → multi_modal_data → Qwen3-VL training end-to-end.

max_steps = 15
seq_len = 4096

[model]
name = "Qwen/Qwen3-VL-4B-Instruct"

[model.vlm]
vision_encoder_attr = "model.visual"
language_model_attr = "model.language_model"

[deployment]
num_train_gpus = 1
num_infer_gpus = 1
gpus_per_node = 2

[orchestrator]
batch_size = 256
group_size = 16
# Image processor is CPU-bound and dominates for VLMs; returns diminish past 4.
pool_size = 4

# Step 0 on Qwen3-VL-4B vs color-codeword can be uniform (all-correct or
# all-wrong), so the `zero_advantage` filter below sets `enforce = false`:
# enforcing zero-advantage dropping would crash training before any progress.
[[orchestrator.post_batch_filters]]
type = "gibberish"

[[orchestrator.post_batch_filters]]
type = "repetition"

[[orchestrator.post_batch_filters]]
type = "zero_advantage"
enforce = false

[orchestrator.train.sampling]
max_completion_tokens = 64

[[orchestrator.train.env]]
taskset = { id = "color-codeword-v1", images_per_turn = 1, max_turns = 3, num_examples = 1000, seed = 42 }
harness = { id = "default", enable_bash = false, runtime = { type = "subprocess" } }

# Default renderer (AutoRendererConfig) resolves Qwen3-VL-4B-Instruct from
# MODEL_RENDERER_MAP to Qwen3VLRenderer at runtime; no explicit name needed.

[trainer]

[trainer.model]
optimization_dtype = "bfloat16"
reduce_dtype = "bfloat16"

[trainer.optim]
lr = 3e-6

[inference]

[inference.model]
# Workaround for vLLM 0.20.1 Qwen3-VL deepstack buffer bug: when num_scheduled_tokens
# (188) gets padded up to the next cudagraph_capture_size (192), the model's
# _set_deepstack_input_embeds sizes the buffer to 188 but forward() runs with 192,
# triggering "Requested more deepstack tokens than available in buffer". Eager mode
# skips the padding so num_input_tokens == num_scheduled_tokens.
enforce_eager = true

[inference.parallel]
dp = 1
tp = 1

[wandb]
project = "debug"
name = "multimodal-v1"
tags = ["qwen3vl-4b", "color-codeword-v1", "renderer", "v1"]
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ envs = [
"gsm8k-v1",
"math-env-v1",
"aime24-v1",
"color-codeword-v1",
"harnesses",
]
disagg = [
Expand Down Expand Up @@ -246,6 +247,7 @@ reverse-text-v1 = { path = "deps/verifiers/examples/tasksets/reverse_text_v1", e
gsm8k-v1 = { path = "deps/verifiers/examples/tasksets/gsm8k_v1", editable = true }
math-env-v1 = { path = "deps/verifiers/examples/tasksets/math_env_v1", editable = true }
aime24-v1 = { path = "deps/verifiers/examples/tasksets/aime24_v1", editable = true }
color-codeword-v1 = { path = "deps/verifiers/examples/tasksets/color_codeword_v1", editable = true }
harnesses = { path = "deps/verifiers/packages/harnesses", editable = true }
renderers = { path = "deps/renderers", editable = true }
prime-pydantic-config = { path = "deps/pydantic-config", editable = true }
Expand Down
9 changes: 4 additions & 5 deletions src/prime_rl/orchestrator/orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@
compute_teacher_logprobs,
get_weight_dir,
intercept_vf_logging,
mm_token_type_id_map,
save_rollouts,
set_default_executor,
setup_student_inference_pool,
Expand Down Expand Up @@ -215,11 +216,9 @@ async def setup(self) -> None:
self.renderer, self.student_inference = await setup_student_inference_pool(
config=config, tokenizer=self.tokenizer
)
self.mm_token_type_ids_mapping = (
getattr(self.renderer, "mm_token_type_id_map", None) if self.renderer is not None else None
)
if self.mm_token_type_ids_mapping == {}:
self.mm_token_type_ids_mapping = None
# The renderer lives in the env server, so derive the VLM type-id map from the
# renderer config directly (None for text models).
self.mm_token_type_ids_mapping = mm_token_type_id_map(config)

if config.teacher is not None:
get_logger().info(
Expand Down
7 changes: 6 additions & 1 deletion src/prime_rl/orchestrator/train_sink.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,12 @@ async def process_rollout(self, rollout: TrainRollout) -> None:
return
if any(turn.tokens is None for turn in rollout.trace.trajectory):
await asyncio.to_thread(backfill_rollout_tokens, rollout.trace, self.tokenizer)
samples = await asyncio.to_thread(trace_to_samples, rollout.trace, env_name=rollout.env_name)
samples = await asyncio.to_thread(
trace_to_samples,
rollout.trace,
env_name=rollout.env_name,
mm_token_type_ids_mapping=self.mm_token_type_ids_mapping,
)
rollout.samples = samples or []

def process_group(self, group_id: uuid.UUID) -> None:
Expand Down
62 changes: 61 additions & 1 deletion src/prime_rl/orchestrator/trajectories.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from verifiers.v1.clients.openai import message_to_wire

from prime_rl.transport import TrainingSample
from prime_rl.transport.types import EncodedTensor
from prime_rl.utils.chat_template import (
common_prefix_len,
deserialize_tool_calls,
Expand Down Expand Up @@ -50,7 +51,57 @@ def backfill_rollout_tokens(trace: vf.Trace, tokenizer) -> None:
)


def trace_to_samples(trace: vf.Trace, *, env_name: str = "") -> list[TrainingSample]:
def _decode_wire_tensor(wt: vf.WireTensor):
import base64

import numpy as np
import torch

arr = np.frombuffer(base64.b64decode(wt.data), dtype=np.dtype(wt.dtype)).reshape(wt.shape)
return torch.from_numpy(arr.copy())


def _pack_mm_kwargs(mm_list: list[vf.MMData]) -> dict[str, EncodedTensor] | None:
"""Union each turn's *new* images into model-agnostic `mm_kwargs`: concat each
HF-processor kwarg (e.g. `pixel_values`, `image_grid_thw`) in turn order. The model's
`forward` signature is the schema, so image/video keys don't clash.

The stitched ids carry each image's placeholder tokens once (in the turn it's
introduced), so we contribute each image once too. A turn's `multi_modal_data` may be
*cumulative* (the renderer re-rendered the whole prompt → every image so far, native v1)
or *delta* (only the turn's new images, the v0 bridge); a turn is cumulative iff its
hashes restate everything taken so far, so we take only the appended tail. Identical
images in different turns (e.g. two squares of the same color) keep distinct slots —
matched by position, not deduped by hash."""
import torch

per_kwarg: dict[str, list] = {}
taken: dict[str, list] = {} # modality -> hashes contributed so far, in order
for mm in mm_list:
for modality, items in mm.mm_items.items():
hashes = list(mm.mm_hashes.get(modality) or [None] * len(items))
acc = taken.setdefault(modality, [])
if None not in hashes and hashes[: len(acc)] == acc:
start = len(acc) # cumulative turn: skip the restated prefix
acc[:] = hashes
else:
start = 0 # delta turn: all images are new
acc.extend(hashes)
for item in items[start:]:
for key, wt in item.items():
per_kwarg.setdefault(key, []).append(_decode_wire_tensor(wt))
if not per_kwarg:
return None
out: dict[str, EncodedTensor] = {}
for key, tensors in per_kwarg.items():
arr = torch.cat(tensors, dim=0).contiguous().numpy()
out[key] = EncodedTensor(dtype=str(arr.dtype), shape=list(arr.shape), data=arr.tobytes())
return out


def trace_to_samples(
trace: vf.Trace, *, env_name: str = "", mm_token_type_ids_mapping: dict[int, int] | None = None
) -> list[TrainingSample]:
"""Convert a v1 `Trace` into `TrainingSample`s — one per branch.

Stitch each branch's turns into one token sequence: the branch's first-turn
Expand Down Expand Up @@ -91,6 +142,13 @@ def trace_to_samples(trace: vf.Trace, *, env_name: str = "") -> list[TrainingSam

if not completion_ids:
continue
# Multimodal: union each turn's image/video tensors into mm_kwargs, and tag every
# token with its modality (image-placeholder vs text) via the renderer's map.
mm = [t.tokens.multi_modal_data for t in turns if t.tokens and t.tokens.multi_modal_data]
mm_kwargs = _pack_mm_kwargs(mm) if mm else None
mm_token_type_ids = None
if mm_kwargs is not None and mm_token_type_ids_mapping:
mm_token_type_ids = [mm_token_type_ids_mapping.get(tid, 0) for tid in prompt_ids + completion_ids]
samples.append(
TrainingSample(
prompt_ids=prompt_ids,
Expand All @@ -102,6 +160,8 @@ def trace_to_samples(trace: vf.Trace, *, env_name: str = "") -> list[TrainingSam
teacher_logprobs=None,
advantage=None,
env_name=env_name,
mm_kwargs=mm_kwargs,
mm_token_type_ids=mm_token_type_ids,
)
)
if not samples:
Expand Down
13 changes: 13 additions & 0 deletions src/prime_rl/orchestrator/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,19 @@ async def setup_student_inference_pool(*, config: OrchestratorConfig, tokenizer)
return None, inference_pool


def mm_token_type_id_map(config: OrchestratorConfig) -> dict[int, int] | None:
    """Return the image/video placeholder-token -> type-id map for a VLM, else None.

    The VLM trainer needs this map to build `mm_token_type_ids` (M-RoPE). The
    orchestrator keeps no renderer of its own, so a size-1 renderer pool is created
    from the renderer config just to read its `mm_token_type_id_map` attribute; a VLM
    is expected to always have a renderer configured (`vlm_requires_renderer`).

    Returns None when the student model is not a VLM, when no renderer is configured,
    or when the pool exposes no map or an empty one.
    """
    student_model = config.student.model
    if not student_model.is_vlm:
        return None
    renderer_config = config.renderer
    if renderer_config is None:
        return None

    from renderers import create_renderer_pool

    pool = create_renderer_pool(student_model.name, renderer_config, size=1)
    type_id_map = getattr(pool, "mm_token_type_id_map", None)
    # Collapse a missing or empty map to None so callers only need a truthiness check.
    return type_id_map if type_id_map else None


def get_model_completion_len(output: vf.Trace) -> int:
"""Sum of model-generated completion tokens across all turns (excludes
environment-injected tokens between turns)."""
Expand Down
13 changes: 13 additions & 0 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading