From f32df985584ecacb320f13566b979729192cc622 Mon Sep 17 00:00:00 2001 From: Eli Gottlieb Date: Wed, 27 May 2026 01:59:40 +0000 Subject: [PATCH 01/31] feat(orchestrator): reconstruct mm pixels at training-sample build The env worker now ships descriptor-only mm_data, so re-derive pixel_values in the orchestrator when building training samples: reprocess the offloaded (file://) window images via the renderer, matched by content hash with a grid_thw assert, before packing mm_kwargs. Thread the rollout renderer into interleave_rollout; trainer unchanged. _pack_mm_kwargs_from_renderer now normalizes decoded payloads via torch.as_tensor so reconstructed numpy pixels batch alongside the existing encoded-wire path. Co-Authored-By: Claude Opus 4.7 (1M context) --- src/prime_rl/orchestrator/orchestrator.py | 7 ++- src/prime_rl/orchestrator/trajectories.py | 75 ++++++++++++++++++++++- 2 files changed, 80 insertions(+), 2 deletions(-) diff --git a/src/prime_rl/orchestrator/orchestrator.py b/src/prime_rl/orchestrator/orchestrator.py index 65f12ccf77..b46cc6e0c4 100644 --- a/src/prime_rl/orchestrator/orchestrator.py +++ b/src/prime_rl/orchestrator/orchestrator.py @@ -489,7 +489,12 @@ async def orchestrate(config: OrchestratorConfig): # Process rollouts in parallel results = await asyncio.gather( *( - asyncio.to_thread(interleave_rollout, r, mm_token_type_ids_mapping=mm_token_type_ids_mapping) + asyncio.to_thread( + interleave_rollout, + r, + mm_token_type_ids_mapping=mm_token_type_ids_mapping, + renderer=renderer, + ) for r in train_rollouts ) ) diff --git a/src/prime_rl/orchestrator/trajectories.py b/src/prime_rl/orchestrator/trajectories.py index 42af5ac042..e891b1ae67 100644 --- a/src/prime_rl/orchestrator/trajectories.py +++ b/src/prime_rl/orchestrator/trajectories.py @@ -204,6 +204,7 @@ def backfill_rollout_tokens( def interleave_rollout( output: vf.RolloutOutput, mm_token_type_ids_mapping: dict[int, int] | None = None, + renderer: Any = None, ) -> list[TrainingSample] | 
None: """ Convert vf.RolloutOutput to trainable rollouts by interleaving trajectory steps @@ -425,6 +426,15 @@ def extend_sample( for _, sample, step_indices in active_samples: renderer_mm = _union_step_mm_data(prepared_steps, step_indices) if renderer_mm is not None: + # The env worker ships descriptor-only mm_data (no pixel_values) to + # keep its memory flat. Re-derive the pixels here from the offloaded + # images referenced in this sample's messages, matched by content + # hash with a grid_thw assert. ``renderer`` is the multimodal pool + # used for rollouts; absent (or already-pixel-bearing in-process + # tests) → pass through unchanged. + if renderer is not None and _mm_needs_pixels(renderer_mm): + window_messages = _window_image_messages(trajectory, step_indices) + renderer_mm = _reconstruct_mm_pixels(renderer, renderer_mm, window_messages) mm_kwargs = _pack_mm_kwargs_from_renderer(renderer_mm) if mm_kwargs is not None: sample.mm_kwargs = mm_kwargs @@ -479,6 +489,65 @@ def _union_step_mm_data( return {"mm_items": union_items, "mm_hashes": union_hashes} +def _mm_needs_pixels(union_mm: dict[str, Any]) -> bool: + """True if any mm item lacks ``pixel_values`` (descriptor-only shape).""" + for items in (union_mm.get("mm_items") or {}).values(): + for item in items or []: + if isinstance(item, dict) and item.get("pixel_values") is None: + return True + return False + + +def _window_image_messages(trajectory: list[Any], step_indices: list[int]) -> list[Any]: + """Collect the messages from the steps this sample covers. + + The offloaded images (``file://`` after ``offload_images_to_disk``, or + inline base64 in-process) live in the step prompts, in conversation order. + Concatenating the window's prompts gives ``materialize_pixels`` every image + it needs to re-derive pixels by hash; duplicates across cumulative prompts + are harmless (the cache absorbs them and matching stops once resolved). 
+ """ + messages: list[Any] = [] + for si in step_indices: + if not (0 <= si < len(trajectory)): + continue + prompt = trajectory[si].get("prompt") + if isinstance(prompt, list): + messages.extend(prompt) + return messages + + +def _reconstruct_mm_pixels(renderer: Any, union_mm: dict[str, Any], messages: list[Any]) -> Any: + """Re-attach ``pixel_values`` to a descriptor-only union mm_data. + + Delegates to the renderer's ``materialize_pixels`` (hash-matched reprocess + of the window images, with a ``grid_thw`` assert). The descriptor's + ``image_grid_thw`` is decoded from its msgpack wire shape back to numpy + first, so the renderer's numpy-vs-numpy grid assert holds after transport. + """ + from renderers.base import MultiModalData + from verifiers.utils.serve_utils import decode_tensor_payload + + items = union_mm.get("mm_items") or {} + decoded_items: dict[str, list] = {} + for modality, lst in items.items(): + new_lst: list[dict[str, Any]] = [] + for item in lst or []: + item = dict(item) + grid = item.get("image_grid_thw") + if item.get("pixel_values") is None and grid is not None: + item["image_grid_thw"] = decode_tensor_payload(grid, to_torch=False) + new_lst.append(item) + decoded_items[modality] = new_lst + + md = MultiModalData( + mm_hashes=union_mm.get("mm_hashes") or {}, + mm_placeholders={}, + mm_items=decoded_items, + ) + return renderer.materialize_pixels(md, messages) + + def _pack_mm_kwargs_from_renderer(mm_data: Any) -> "dict[str, Any] | None": """Batch the renderer's per-image ``mm_items`` into model-agnostic forward kwargs. @@ -508,7 +577,11 @@ def _pack_mm_kwargs_from_renderer(mm_data: Any) -> "dict[str, Any] | None": for _modality, items in mm_items.items(): for item in items or []: for key, payload in item.items(): - per_kwarg.setdefault(key, []).append(decode_tensor_payload(payload)) + # ``decode_tensor_payload`` rehydrates the encoded wire shape to + # torch but passes already-rehydrated numpy through unchanged + # (e.g. 
pixels reconstructed in-process from disk). ``as_tensor`` + # normalizes both to torch so the ``torch.cat`` below is uniform. + per_kwarg.setdefault(key, []).append(torch.as_tensor(decode_tensor_payload(payload))) if not per_kwarg: return None out: dict[str, EncodedTensor] = {} From 1e89b7ac79cfffac4e52d8ab977487974087e439 Mon Sep 17 00:00:00 2001 From: Eli Gottlieb Date: Wed, 27 May 2026 02:00:42 +0000 Subject: [PATCH 02/31] chore: bump renderers + verifiers pins to ephemeral-mm-pixels Point the submodules at the descriptor-only mm_data work: - renderers: pixels ephemeral on the rollout path + bridge no-mutate fix - verifiers: descriptor-only delta regression test Co-Authored-By: Claude Opus 4.7 (1M context) --- deps/renderers | 2 +- deps/verifiers | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/deps/renderers b/deps/renderers index 2ec28a8543..1550bf02b8 160000 --- a/deps/renderers +++ b/deps/renderers @@ -1 +1 @@ -Subproject commit 2ec28a854388c985b95fd276d95651c7820d3df8 +Subproject commit 1550bf02b8e412833c4e9bdc694f1f65062a1fe2 diff --git a/deps/verifiers b/deps/verifiers index f9c68eb28c..8703573d4a 160000 --- a/deps/verifiers +++ b/deps/verifiers @@ -1 +1 @@ -Subproject commit f9c68eb28ccf0448d4573b7ca50f1163f81c5cfd +Subproject commit 8703573d4a8338d517901f55ae90accaa0aec76d From 7e67a088596634a5b5ff6efeb66bb64d87419478 Mon Sep 17 00:00:00 2001 From: Eli Gottlieb Date: Fri, 29 May 2026 04:48:11 +0000 Subject: [PATCH 03/31] chore(memory): cap native threads on orchestrator + env-worker spawn Apply verifiers' native-thread limits (OMP/MKL/BLAS=1, MALLOC_ARENA_MAX=2, tokenizers parallelism off) to the orchestrator subprocess env and around the env-worker `process.start()`, so high-core hosts don't explode native thread teams + glibc arenas during multimodal image processing. 
Co-Authored-By: Claude Opus 4.8 (1M context) --- src/prime_rl/entrypoints/rl.py | 21 +++++++++++++-------- src/prime_rl/orchestrator/envs.py | 5 ++++- 2 files changed, 17 insertions(+), 9 deletions(-) diff --git a/src/prime_rl/entrypoints/rl.py b/src/prime_rl/entrypoints/rl.py index 5b342c58a2..f1304b620b 100644 --- a/src/prime_rl/entrypoints/rl.py +++ b/src/prime_rl/entrypoints/rl.py @@ -200,19 +200,24 @@ def sigterm_handler(signum, frame): ] logger.info("Starting orchestrator process") logger.debug(f"Orchestrator start command: {' '.join(orchestrator_cmd)}") + from verifiers.utils.native_threads import native_thread_limited_env + + orchestrator_env = native_thread_limited_env( + { + **os.environ, + **wandb_shared_env, + "WANDB_SHARED_LABEL": "orchestrator", + "LOGURU_FORCE_COLORS": "1", + "WANDB_PROGRAM": "uv run rl", + "WANDB_ARGS": json.dumps(start_command), + } + ) with open(log_dir / "orchestrator.log", "w") as log_file: orchestrator_process = Popen( orchestrator_cmd, stdout=log_file, stderr=log_file, - env={ - **os.environ, - **wandb_shared_env, - "WANDB_SHARED_LABEL": "orchestrator", - "LOGURU_FORCE_COLORS": "1", - "WANDB_PROGRAM": "uv run rl", - "WANDB_ARGS": json.dumps(start_command), - }, + env=orchestrator_env, ) processes.append(orchestrator_process) diff --git a/src/prime_rl/orchestrator/envs.py b/src/prime_rl/orchestrator/envs.py index 03b1f59e90..d8931b1f49 100644 --- a/src/prime_rl/orchestrator/envs.py +++ b/src/prime_rl/orchestrator/envs.py @@ -100,7 +100,10 @@ def _spawn( ), daemon=False, ) - process.start() + from verifiers.utils.native_threads import scoped_native_thread_limits + + with scoped_native_thread_limits(): + process.start() self._env_server_process = process return address From c019044c29324948ef219ca6ee353ec31f4b921c Mon Sep 17 00:00:00 2001 From: Eli Gottlieb Date: Fri, 29 May 2026 04:48:38 +0000 Subject: [PATCH 04/31] feat(orchestrator): bound + trim multimodal pixel materialization; bump deps MIME-Version: 1.0 Content-Type: 
text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Reduce orchestrator memory on VLM RL batches when converting rollouts to training samples: - Bound the interleave fan-out with a semaphore (new `OrchestratorConfig.mm_materialize_concurrency`, default 4) so only N rollouts reconstruct pixels at once — caps the transient build spike that otherwise stacks across the whole batch. - Skip pixel materialization for filtered rollouts (their samples are dropped before the trainer), and auto-wire `VF_RENDERER_IMAGE_OFFLOAD_DIR` to the run's assets dir so the env-worker live offload needs no override. - `offload_images_to_disk` now also normalizes `file://` images (leave if already under the run dir, else copy in) with hashing aligned to the env-worker writer (sha256 of decoded bytes + media-type ext). - Drop the masked_scatter repro instrumentation. Bump deps/renderers -> 73345a2 (descriptor-aware hash-only/full serialization) and deps/verifiers -> bff4cb78 (lean ephemeral-MM policy + live image offload), so prime-rl imports the consolidated multimodal path with no PYTHONPATH override. 
Co-Authored-By: Claude Opus 4.8 (1M context) --- deps/renderers | 2 +- deps/verifiers | 2 +- .../src/prime_rl/configs/orchestrator.py | 3 + src/prime_rl/orchestrator/orchestrator.py | 33 +++-- src/prime_rl/orchestrator/trajectories.py | 133 +++++++++++++----- tests/unit/orchestrator/test_trajectories.py | 124 ++++++++++++++++ 6 files changed, 249 insertions(+), 48 deletions(-) diff --git a/deps/renderers b/deps/renderers index 1550bf02b8..73345a25a9 160000 --- a/deps/renderers +++ b/deps/renderers @@ -1 +1 @@ -Subproject commit 1550bf02b8e412833c4e9bdc694f1f65062a1fe2 +Subproject commit 73345a25a91c1e9d731fdccabe096492a2db6ea2 diff --git a/deps/verifiers b/deps/verifiers index 8703573d4a..bff4cb786f 160000 --- a/deps/verifiers +++ b/deps/verifiers @@ -1 +1 @@ -Subproject commit 8703573d4a8338d517901f55ae90accaa0aec76d +Subproject commit bff4cb786fe12b681067e57890699489786af111 diff --git a/packages/prime-rl-configs/src/prime_rl/configs/orchestrator.py b/packages/prime-rl-configs/src/prime_rl/configs/orchestrator.py index e20afd456b..cfa844b835 100644 --- a/packages/prime-rl-configs/src/prime_rl/configs/orchestrator.py +++ b/packages/prime-rl-configs/src/prime_rl/configs/orchestrator.py @@ -664,6 +664,9 @@ def _preserve_mito_renderer(self, handler: SerializerFunctionWrapHandler) -> dic max_off_policy_steps: int = Field(8, ge=0) """Maximum policies allowed to generate a single rollout. Rollouts generated more than ``max_off_policy_steps`` ahead of training are discarded. Higher values yield better throughput at the cost of off-policy noise.""" + mm_materialize_concurrency: int = Field(4, ge=1) + """Max rollouts whose multimodal pixels are reconstructed concurrently when converting rollouts to training samples. Bounds the transient build-time memory spike for VLM batches (each in-flight rollout holds live pixel tensors + packing copies). No effect on text-only runs.""" + bench: bool = False """Benchmark mode. 
Sets ``max_steps`` to 5 and disables W&B.""" diff --git a/src/prime_rl/orchestrator/orchestrator.py b/src/prime_rl/orchestrator/orchestrator.py index b46cc6e0c4..92d5bbcb39 100644 --- a/src/prime_rl/orchestrator/orchestrator.py +++ b/src/prime_rl/orchestrator/orchestrator.py @@ -203,6 +203,18 @@ async def orchestrate(config: OrchestratorConfig): env.sampling_args.pop("logprobs", None) logger.info(f"Loaded {len(train_envs)} training environment(s) ({', '.join(train_envs.names)})") + # Env workers (spawned by ``*.start`` → ``Env._spawn``, localhost subprocesses + # that inherit this process's ``os.environ``) offload screenshot base64 to + # disk during the live rollout to keep their RSS flat. Point them at this + # run's assets dir — the same one ``offload_images_to_disk`` writes to — so + # images land directly where the trainer reads them (no post-rollout copy) + # and are cleaned up with the run. ``setdefault`` lets an explicit + # ``VF_RENDERER_IMAGE_OFFLOAD_DIR`` override win. + os.environ.setdefault( + "VF_RENDERER_IMAGE_OFFLOAD_DIR", + str(config.output_dir / "assets" / "images"), + ) + await train_envs.start( log_dir=get_log_dir(config.output_dir.parent) / "envs" / "train", log_level=config.log.vf_level, @@ -486,18 +498,23 @@ async def orchestrate(config: OrchestratorConfig): ) ) - # Process rollouts in parallel - results = await asyncio.gather( - *( - asyncio.to_thread( + # Process rollouts in parallel, but bound how many materialize multimodal + # pixels at once: each in-flight VLM rollout transiently holds live pixel + # tensors + packing copies, so an unbounded fan-out stacks that overhead + # across the whole batch. The semaphore caps the transient peak; it does + # not change text-only runs (no pixels to materialize). 
+ interleave_sem = asyncio.Semaphore(config.mm_materialize_concurrency) + + async def _interleave_bounded(rollout: vf.RolloutOutput): + async with interleave_sem: + return await asyncio.to_thread( interleave_rollout, - r, + rollout, mm_token_type_ids_mapping=mm_token_type_ids_mapping, renderer=renderer, ) - for r in train_rollouts - ) - ) + + results = await asyncio.gather(*(_interleave_bounded(r) for r in train_rollouts)) # Collect results and assign advantages. Metrics are computed over all # rollouts; only non-filtered samples are sent to the trainer. diff --git a/src/prime_rl/orchestrator/trajectories.py b/src/prime_rl/orchestrator/trajectories.py index e891b1ae67..d0d0de7992 100644 --- a/src/prime_rl/orchestrator/trajectories.py +++ b/src/prime_rl/orchestrator/trajectories.py @@ -423,29 +423,37 @@ def extend_sample( # reading the last step alone would miss every earlier-turn image. # Concat in step order recovers the per-sample cumulative set; # deduping again here would drop legitimate duplicate placeholders. - for _, sample, step_indices in active_samples: - renderer_mm = _union_step_mm_data(prepared_steps, step_indices) - if renderer_mm is not None: - # The env worker ships descriptor-only mm_data (no pixel_values) to - # keep its memory flat. Re-derive the pixels here from the offloaded - # images referenced in this sample's messages, matched by content - # hash with a grid_thw assert. ``renderer`` is the multimodal pool - # used for rollouts; absent (or already-pixel-bearing in-process - # tests) → pass through unchanged. - if renderer is not None and _mm_needs_pixels(renderer_mm): - window_messages = _window_image_messages(trajectory, step_indices) - renderer_mm = _reconstruct_mm_pixels(renderer, renderer_mm, window_messages) - mm_kwargs = _pack_mm_kwargs_from_renderer(renderer_mm) - if mm_kwargs is not None: - sample.mm_kwargs = mm_kwargs - # ``mm_token_type_ids``: 1 for image-placeholder tokens, 2 - # for video, 0 otherwise. 
Renderer-supplied via - # ``mm_token_type_id_map`` (single source of truth). - if mm_token_type_ids_mapping is not None: - sample.mm_token_type_ids = [ - mm_token_type_ids_mapping.get(token_id, 0) - for token_id in sample.prompt_ids + sample.completion_ids - ] + # + # Skip pixel materialization for filtered rollouts: their samples are + # dropped before the trainer (the orchestrator collect loop only appends + # ``not is_filtered`` samples), so re-deriving their pixels is pure waste. + # The token bookkeeping above is kept — metrics are computed over all + # rollouts. ``is_filtered`` is absent for standalone callers (tests), which + # default to materializing. + if not output.get("is_filtered", False): + for _, sample, step_indices in active_samples: + renderer_mm = _union_step_mm_data(prepared_steps, step_indices) + if renderer_mm is not None: + # The env worker ships descriptor-only mm_data (no pixel_values) + # to keep its memory flat. Re-derive the pixels here from the + # offloaded images referenced in this sample's messages, matched + # by content hash with a grid_thw assert. ``renderer`` is the + # multimodal pool used for rollouts; absent (or already + # pixel-bearing in-process tests) → pass through unchanged. + if renderer is not None and _mm_needs_pixels(renderer_mm): + window_messages = _window_image_messages(trajectory, step_indices) + renderer_mm = _reconstruct_mm_pixels(renderer, renderer_mm, window_messages) + mm_kwargs = _pack_mm_kwargs_from_renderer(renderer_mm) + if mm_kwargs is not None: + sample.mm_kwargs = mm_kwargs + # ``mm_token_type_ids``: 1 for image-placeholder tokens, 2 + # for video, 0 otherwise. Renderer-supplied via + # ``mm_token_type_id_map`` (single source of truth). 
+ if mm_token_type_ids_mapping is not None: + sample.mm_token_type_ids = [ + mm_token_type_ids_mapping.get(token_id, 0) + for token_id in sample.prompt_ids + sample.completion_ids + ] return [sample for _, sample, _ in active_samples] @@ -599,15 +607,40 @@ def _pack_mm_kwargs_from_renderer(mm_data: Any) -> "dict[str, Any] | None": _FILE_URL_PREFIX = "file://" -def offload_images_to_disk(rollouts: list[vf.RolloutOutput], output_dir: Path) -> int: - """Replace base64 image data in rollout trajectories with file paths on disk. +_MEDIA_TYPE_EXT = {"jpeg": ".jpg", "jpg": ".jpg", "png": ".png", "webp": ".webp", "gif": ".gif"} + + +def _media_type_ext(media_type: str) -> str: + subtype = media_type.split("/", 1)[-1].split(";", 1)[0].strip().lower() + return _MEDIA_TYPE_EXT.get(subtype, ".img") + + +def _is_under(path: Path, parent: Path) -> bool: + """True if ``path`` lives directly in (or below) ``parent``.""" + try: + path, parent = path.resolve(), parent.resolve() + except OSError: + pass + return parent == path.parent or parent in path.parents - Scans all trajectory step prompts for data:image URLs, writes the decoded - image bytes to ``{output_dir}/assets/images/{hash}.png``, and replaces the - URL in-place with ``file://{path}``. Deduplicates by content hash so each - unique image is written only once. - Returns the number of unique images written to disk. +def offload_images_to_disk(rollouts: list[vf.RolloutOutput], output_dir: Path) -> int: + """Normalize trajectory prompt images to ``file://`` paths under the run dir. + + Scans all trajectory step prompts for image URLs and rewrites them in place: + + - ``data:image/...;base64,...`` → decode, write to + ``{output_dir}/assets/images/{sha256(decoded)[:16]}{ext}``, rewrite to + ``file://``. + - ``file://...`` already under ``{output_dir}/assets/images`` → left as-is + (the env worker offloaded it there during the live rollout). + - ``file://...`` elsewhere (e.g. 
a shared image cache the env worker wrote + to) → copied into the run's assets dir by content hash and rewritten, so + the rollout is self-contained for the trainer. + + Content-addressed by ``sha256(decoded_bytes)``, matching the env-worker live + offload, so an image offloaded during the rollout is recognized and never + re-decoded or duplicated. Returns the number of unique images written here. """ images_dir = output_dir / "assets" / "images" images_dir.mkdir(parents=True, exist_ok=True) @@ -624,18 +657,42 @@ def offload_images_to_disk(rollouts: list[vf.RolloutOutput], output_dir: Path) - if not isinstance(content, list): continue for item in content: - if item.get("type") != "image_url": + if not isinstance(item, dict) or item.get("type") != "image_url": + continue + image_url = item.get("image_url") + if not isinstance(image_url, dict): continue - url = item.get("image_url", {}).get("url", "") - if not url.startswith("data:image"): + url = image_url.get("url", "") + if not isinstance(url, str): continue - b64_data = url.split(",", 1)[1] - content_hash = hashlib.sha256(b64_data.encode()).hexdigest()[:16] - path = images_dir / f"{content_hash}.png" + + if url.startswith("data:image"): + header, _, b64_data = url.partition(",") + if not b64_data: + continue + try: + raw = base64.b64decode(b64_data) + except Exception: + continue + ext = _media_type_ext(header[len("data:") :]) + elif url.startswith(_FILE_URL_PREFIX): + src = Path(url[len(_FILE_URL_PREFIX) :]) + if _is_under(src, images_dir): + continue # already in the run's assets dir + try: + raw = src.read_bytes() + except OSError: + continue + ext = src.suffix or ".img" + else: + continue + + content_hash = hashlib.sha256(raw).hexdigest()[:16] + path = images_dir / f"{content_hash}{ext}" if content_hash not in written: if not path.exists(): - path.write_bytes(base64.b64decode(b64_data)) + path.write_bytes(raw) written.add(content_hash) - item["image_url"]["url"] = f"{_FILE_URL_PREFIX}{path}" + 
image_url["url"] = f"{_FILE_URL_PREFIX}{path}" return len(written) diff --git a/tests/unit/orchestrator/test_trajectories.py b/tests/unit/orchestrator/test_trajectories.py index 36c9ef1008..0ed7b16c41 100644 --- a/tests/unit/orchestrator/test_trajectories.py +++ b/tests/unit/orchestrator/test_trajectories.py @@ -1,3 +1,5 @@ +import base64 +from pathlib import Path from unittest.mock import MagicMock import numpy as np @@ -9,6 +11,7 @@ _deserialize_tool_calls, align_routed_experts, interleave_rollout, + offload_images_to_disk, ) _interleave_rollout = interleave_rollout @@ -1117,3 +1120,124 @@ def test_interleave_rollout_packs_pixels_from_renderer_mm_data(): assert _decode_mm_thw(sample) == [[1, 2, 3], [1, 4, 4]] # mm_token_type_ids: image at token 2, video at token 5, rest 0. assert sample.mm_token_type_ids == [0, 1, 0, 0, 2, 0, 0] + + +def test_interleave_rollout_skips_pixel_materialization_for_filtered_rollout(): + """A filtered rollout's samples are dropped before the trainer, so + ``interleave_rollout`` must skip the expensive pixel reconstruction — the + sample is still produced (for metrics) but carries no mm_kwargs.""" + import torch as _torch + from renderers.base import MultiModalData, PlaceholderRange + + mm = MultiModalData( + mm_hashes={"image": ["h1"]}, + mm_placeholders={"image": [PlaceholderRange(offset=1, length=1)]}, + mm_items={ + "image": [ + { + "pixel_values": _torch.tensor([[1.0, 2.0]], dtype=_torch.float32), + "image_grid_thw": _torch.tensor([[1, 2, 3]], dtype=_torch.int64), + } + ] + }, + ) + output = vf.RolloutOutput( + example_id=1, + is_filtered=True, + trajectory=[ + vf.TrajectoryStep( + prompt=[{"role": "user", "content": "Turn 1"}], + completion=[{"role": "assistant", "content": "Response 1"}], + response=MagicMock(), + tokens=vf.TrajectoryStepTokens( + prompt_ids=[1, 2], + prompt_mask=[0, 0], + completion_ids=[3, 4], + completion_mask=[1, 1], + completion_logprobs=[-0.1, -0.2], + overlong_prompt=False, + is_truncated=False, + 
multi_modal_data=mm, + ), + reward=None, + advantage=None, + is_truncated=False, + trajectory_id="1", + extras={}, + ), + ], + sampling_args={"temperature": 1.0}, + error=None, + ) + + rollouts = interleave_rollout(output, mm_token_type_ids_mapping={2: 1}) + + # Sample still produced (token bookkeeping for metrics) ... + assert rollouts is not None and len(rollouts) == 1 + # ... but no pixels materialized/packed, since it won't reach the trainer. + assert rollouts[0].mm_kwargs is None + assert rollouts[0].mm_token_type_ids is None + + +# ── offload_images_to_disk: data URIs + file:// normalization ───────── + + +def _image_rollout(url: str) -> dict: + """Minimal rollout whose single step prompt carries one image part.""" + return { + "trajectory": [{"prompt": [{"role": "user", "content": [{"type": "image_url", "image_url": {"url": url}}]}]}] + } + + +def _step_image_url(rollout: dict) -> str: + return rollout["trajectory"][0]["prompt"][0]["content"][0]["image_url"]["url"] + + +def test_offload_data_uri_writes_decoded_bytes(tmp_path): + raw = b"jpeg-bytes-abc" + uri = "data:image/jpeg;base64," + base64.b64encode(raw).decode("ascii") + rollout = _image_rollout(uri) + + n = offload_images_to_disk([rollout], tmp_path) + + assert n == 1 + url = _step_image_url(rollout) + assert url.startswith("file://") and url.endswith(".jpg") + path = Path(url[len("file://") :]) + assert path.parent == (tmp_path / "assets" / "images") + # The file holds the *decoded* bytes (content-addressed), not the base64. 
+ assert path.read_bytes() == raw + + +def test_offload_leaves_file_url_already_in_assets(tmp_path): + images_dir = tmp_path / "assets" / "images" + images_dir.mkdir(parents=True) + existing = images_dir / "deadbeefdeadbeef.jpg" + existing.write_bytes(b"already-here") + url = f"file://{existing}" + rollout = _image_rollout(url) + + n = offload_images_to_disk([rollout], tmp_path) + + assert n == 0 # nothing copied + assert _step_image_url(rollout) == url # url untouched + + +def test_offload_copies_foreign_file_url_into_assets(tmp_path): + # The env worker offloaded to a shared cache outside this run's dir. + cache = tmp_path / "renderer-image-cache" + cache.mkdir() + raw = b"png-bytes-from-cache" + src = cache / "abc123.png" + src.write_bytes(raw) + out_dir = tmp_path / "run" + rollout = _image_rollout(f"file://{src}") + + n = offload_images_to_disk([rollout], out_dir) + + assert n == 1 + url = _step_image_url(rollout) + new_path = Path(url[len("file://") :]) + assert new_path.parent == (out_dir / "assets" / "images") + assert new_path.suffix == ".png" + assert new_path.read_bytes() == raw From 2bc548f9c43e3422a5e323c0c93c6d1acf70ef47 Mon Sep 17 00:00:00 2001 From: Codex Date: Fri, 29 May 2026 06:04:43 +0000 Subject: [PATCH 05/31] fix(orchestrator): resolve image-offload dir to absolute; bump deps/verifiers The offloaded image path becomes a file:// URL; a relative --output-dir yielded a malformed URI (file://rel/...) that the renderer could not load, failing turn 0 of every multimodal rollout. Resolve to absolute in both offload-dir wiring spots. Bump deps/verifiers bff4cb78 -> 4112bc0a (same fix in _image_offload_dir + regression test). Validated end-to-end: orchestrator offloads 11 unique images and converts rollouts to training examples cleanly (smoke stage1c). 
Co-Authored-By: Claude Opus 4.8 (1M context) --- deps/verifiers | 2 +- src/prime_rl/orchestrator/orchestrator.py | 4 +++- src/prime_rl/orchestrator/trajectories.py | 4 +++- 3 files changed, 7 insertions(+), 3 deletions(-) diff --git a/deps/verifiers b/deps/verifiers index bff4cb786f..4112bc0a85 160000 --- a/deps/verifiers +++ b/deps/verifiers @@ -1 +1 @@ -Subproject commit bff4cb786fe12b681067e57890699489786af111 +Subproject commit 4112bc0a854518a42193bda83e65d950a2c8d010 diff --git a/src/prime_rl/orchestrator/orchestrator.py b/src/prime_rl/orchestrator/orchestrator.py index 92d5bbcb39..5d483426f6 100644 --- a/src/prime_rl/orchestrator/orchestrator.py +++ b/src/prime_rl/orchestrator/orchestrator.py @@ -210,9 +210,11 @@ async def orchestrate(config: OrchestratorConfig): # images land directly where the trainer reads them (no post-rollout copy) # and are cleaned up with the run. ``setdefault`` lets an explicit # ``VF_RENDERER_IMAGE_OFFLOAD_DIR`` override win. + # Absolute path: the env worker turns this into ``file://`` image URLs, and a + # relative path makes a malformed URI the renderer can't load. os.environ.setdefault( "VF_RENDERER_IMAGE_OFFLOAD_DIR", - str(config.output_dir / "assets" / "images"), + str((config.output_dir / "assets" / "images").resolve()), ) await train_envs.start( diff --git a/src/prime_rl/orchestrator/trajectories.py b/src/prime_rl/orchestrator/trajectories.py index d0d0de7992..bff9fd7566 100644 --- a/src/prime_rl/orchestrator/trajectories.py +++ b/src/prime_rl/orchestrator/trajectories.py @@ -642,7 +642,9 @@ def offload_images_to_disk(rollouts: list[vf.RolloutOutput], output_dir: Path) - offload, so an image offloaded during the rollout is recognized and never re-decoded or duplicated. Returns the number of unique images written here. """ - images_dir = output_dir / "assets" / "images" + # Absolute: paths become ``file://`` URLs; a relative path yields a malformed + # URI (``file://rel/...``) that the renderer can't load. 
+ images_dir = (output_dir / "assets" / "images").resolve() images_dir.mkdir(parents=True, exist_ok=True) written: set[str] = set() From 9a57eccd0ce24ad5350c90ff413d68761347a498 Mon Sep 17 00:00:00 2001 From: Codex Date: Fri, 29 May 2026 20:20:16 +0000 Subject: [PATCH 06/31] fix(orchestrator): free per-step pixel mm_kwargs; drop dead offload-dir setdefault; bump deps/verifiers - Memory leak: the end-of-step cleanup deleted train_examples/training_batch but not 'results', which holds the same samples' mm_kwargs (full pixel byte-copies). That left the batch's pixels pinned past gc.collect()+malloc_trim, defeating the trim until 'results' was rebound a step later (multimodal-specific RSS ratchet). Add results + samples (and filter_df/timing_df) to the del. - Remove the VF_RENDERER_IMAGE_OFFLOAD_DIR setdefault: dead now that verifiers derives the offload dir from RUN_ID, and useless in prod anyway (env workers are separate pods that don't inherit os.environ). - Bump deps/verifiers 4112bc0a -> 91555323 (RUN_ID-derived offload dir). 
Co-Authored-By: Claude Opus 4.8 (1M context) --- deps/verifiers | 2 +- src/prime_rl/orchestrator/orchestrator.py | 25 ++++++++--------------- 2 files changed, 9 insertions(+), 18 deletions(-) diff --git a/deps/verifiers b/deps/verifiers index 4112bc0a85..9155532361 160000 --- a/deps/verifiers +++ b/deps/verifiers @@ -1 +1 @@ -Subproject commit 4112bc0a854518a42193bda83e65d950a2c8d010 +Subproject commit 91555323618fbd3435caf3d64f54b5449aee9aeb diff --git a/src/prime_rl/orchestrator/orchestrator.py b/src/prime_rl/orchestrator/orchestrator.py index 5d483426f6..9488e07ef2 100644 --- a/src/prime_rl/orchestrator/orchestrator.py +++ b/src/prime_rl/orchestrator/orchestrator.py @@ -203,20 +203,6 @@ async def orchestrate(config: OrchestratorConfig): env.sampling_args.pop("logprobs", None) logger.info(f"Loaded {len(train_envs)} training environment(s) ({', '.join(train_envs.names)})") - # Env workers (spawned by ``*.start`` → ``Env._spawn``, localhost subprocesses - # that inherit this process's ``os.environ``) offload screenshot base64 to - # disk during the live rollout to keep their RSS flat. Point them at this - # run's assets dir — the same one ``offload_images_to_disk`` writes to — so - # images land directly where the trainer reads them (no post-rollout copy) - # and are cleaned up with the run. ``setdefault`` lets an explicit - # ``VF_RENDERER_IMAGE_OFFLOAD_DIR`` override win. - # Absolute path: the env worker turns this into ``file://`` image URLs, and a - # relative path makes a malformed URI the renderer can't load. 
- os.environ.setdefault( - "VF_RENDERER_IMAGE_OFFLOAD_DIR", - str((config.output_dir / "assets" / "images").resolve()), - ) - await train_envs.start( log_dir=get_log_dir(config.output_dir.parent) / "envs" / "train", log_level=config.log.vf_level, @@ -776,9 +762,14 @@ def compute_solve_rates(df): progress.step += 1 is_first_step = False - # Free large per-step objects to prevent memory accumulation - del train_rollouts, train_examples, training_batch - del results_df, metrics_df + # Free large per-step objects to prevent memory accumulation. ``results`` + # is load-bearing for multimodal runs: it holds every sample's + # ``mm_kwargs`` (full pixel byte-copies), and it references the same + # sample objects as ``train_examples``/``training_batch`` — so deleting + # only those leaves the pixels pinned and the malloc_trim below cannot + # reclaim them until ``results`` is rebound a step later. + del train_rollouts, train_examples, training_batch, results, samples + del results_df, metrics_df, filter_df, timing_df gc.collect() # Return free glibc heap pages to the OS. numpy/pandas allocate array data # via malloc (outside Python's allocator), so gc.collect() alone doesn't From fa757d97bb60e4918810e7d04b6c73a1ddfcbf0b Mon Sep 17 00:00:00 2001 From: Eli Gottlieb <78387377+eligotts@users.noreply.github.com> Date: Sat, 30 May 2026 22:38:12 +0000 Subject: [PATCH 07/31] chore: bump deps/renderers + deps/verifiers to ephemeral-mm heads - deps/renderers -> d6ed224 (shared mm_store module; feature offload default-on + placeholder_length self-repair). - deps/verifiers -> 7ec7169c (offload dir from mm_store, prompt-only offload; token-prefix delta-baseline for descriptor-only mm_data). 
Co-Authored-By: Codex Co-Authored-By: Claude Opus 4.8 (1M context) --- deps/renderers | 2 +- deps/verifiers | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/deps/renderers b/deps/renderers index 73345a25a9..d6ed224839 160000 --- a/deps/renderers +++ b/deps/renderers @@ -1 +1 @@ -Subproject commit 73345a25a91c1e9d731fdccabe096492a2db6ea2 +Subproject commit d6ed2248394cc41b5df9e74ad4c0f2b384601596 diff --git a/deps/verifiers b/deps/verifiers index 9155532361..7ec7169cd8 160000 --- a/deps/verifiers +++ b/deps/verifiers @@ -1 +1 @@ -Subproject commit 91555323618fbd3435caf3d64f54b5449aee9aeb +Subproject commit 7ec7169cd887cfe7fac9c5cae97b9547448d40d5 From 9558f954247f9a9368a135abee7bd80821bcc442 Mon Sep 17 00:00:00 2001 From: Eli Gottlieb <78387377+eligotts@users.noreply.github.com> Date: Sat, 30 May 2026 22:38:34 +0000 Subject: [PATCH 08/31] feat(mm): MMRefs transport + trainer-side deferred materialization + metric When defer_mm_materialization is on, the orchestrator ships lightweight image refs (MMRefs: descriptor + file:// uris) instead of heavy pixel mm_kwargs, and the trainer materializes pixels in its data loader: - transport: MMRefs struct (array_like, appended LAST), carried on TrainingSample + MicroBatch alongside mm_kwargs. - utils/mm.py: materialize_mm_refs (hash-deduped decode-once-populate-all-slots via the renderer), reconstruct/pack/encode helpers, defer-validation hook. - trainer: build the renderer once (VLM runs only), materialize in _micro_batch_to_tensor, carry mm_refs through packing/prepare_batch; always-on pre-forward all-reduce so a materialization error fails every rank before the forward collective; log time/mm_materialize + mm/images_materialized. 
Co-Authored-By: Codex Co-Authored-By: Claude Opus 4.8 (1M context) --- src/prime_rl/trainer/batch.py | 8 +- src/prime_rl/trainer/rl/data.py | 49 +++++ src/prime_rl/trainer/rl/train.py | 35 +++- src/prime_rl/transport/__init__.py | 3 +- src/prime_rl/transport/types.py | 20 ++ src/prime_rl/utils/mm.py | 166 ++++++++++++++++ tests/unit/train/rl/test_packer.py | 113 +++++++++++ tests/unit/trainer/test_mm_refs.py | 304 +++++++++++++++++++++++++++++ 8 files changed, 694 insertions(+), 4 deletions(-) create mode 100644 src/prime_rl/utils/mm.py create mode 100644 tests/unit/trainer/test_mm_refs.py diff --git a/src/prime_rl/trainer/batch.py b/src/prime_rl/trainer/batch.py index ea99859a35..cb1f3a100a 100644 --- a/src/prime_rl/trainer/batch.py +++ b/src/prime_rl/trainer/batch.py @@ -132,13 +132,17 @@ def prepare_sample(training_example: TrainingSample, seq_len: int) -> MicroBatch mm_token_type_ids=mm_token_type_ids, env_names=env_names, mm_kwargs=training_example.mm_kwargs, + mm_refs=training_example.mm_refs, training_mode=training_example.training_mode, ) def _is_multimodal_sample(sample: MicroBatch) -> bool: - """Check if a sample contains multimodal data (images).""" - return sample.mm_kwargs is not None + """Check if a sample contains multimodal data (images). 
A deferred sample + carries ``mm_refs`` and no ``mm_kwargs``; both count as multimodal so it is + not mis-packed as text (which would break the FSDP per-step modality + invariant).""" + return sample.mm_kwargs is not None or sample.mm_refs is not None def packed_samples_into_micro_bs( diff --git a/src/prime_rl/trainer/rl/data.py b/src/prime_rl/trainer/rl/data.py index 45acdcb5c0..79958b3acf 100644 --- a/src/prime_rl/trainer/rl/data.py +++ b/src/prime_rl/trainer/rl/data.py @@ -1,8 +1,10 @@ +import time from pathlib import Path from typing import TypedDict import torch from jaxtyping import Bool, Float, Int +from renderers import RendererConfig from torch import Tensor from transformers.tokenization_utils import PreTrainedTokenizer @@ -11,6 +13,7 @@ from prime_rl.trainer.runs import get_multi_run_manager from prime_rl.trainer.world import get_world from prime_rl.transport import MicroBatch, MicroBatchReceiver, TransportConfig, setup_micro_batch_receiver +from prime_rl.utils.logger import get_logger class TensorMicroBatch(TypedDict): @@ -163,6 +166,8 @@ def __init__( pad_to_multiple_of: int, tokenizer: PreTrainedTokenizer, config: TransportConfig, + defer_mm_materialization: bool = False, + renderer_config: RendererConfig | None = None, ): self.world = get_world() @@ -182,6 +187,21 @@ def __init__( self.receiver: MicroBatchReceiver = setup_micro_batch_receiver(output_dir, dp_rank, start_step, config) + # Deferred materialization: each rank builds its own renderer once and + # materializes pixels from the shipped image references in get_batch. + self.defer_mm_materialization = defer_mm_materialization + self._renderer = None + # Build the renderer only when one is configured. With default-on defer, + # text-only runs leave renderer_config None and never receive mm_refs, so + # the materialize path below is simply never hit. 
+ if defer_mm_materialization and renderer_config is not None: + from renderers.base import create_renderer + + self._renderer = create_renderer(tokenizer, renderer_config) + # Per-step materialization cost, surfaced as wandb time/mm_materialize. + self.last_mm_materialize_time = 0.0 + self.last_mm_images_materialized = 0 + def wait_for_batch(self) -> None: if self.world.is_master: self.packer._arm_watchdog() @@ -194,6 +214,8 @@ def wait_for_batch(self) -> None: def get_batch(self) -> list[TensorMicroBatch]: micro_batches = self.receiver.receive() + self.last_mm_materialize_time = 0.0 + self.last_mm_images_materialized = 0 return [self._micro_batch_to_tensor(mb) for mb in micro_batches] def _micro_batch_to_tensor(self, micro_batch: MicroBatch) -> TensorMicroBatch: @@ -210,6 +232,33 @@ def _micro_batch_to_tensor(self, micro_batch: MicroBatch) -> TensorMicroBatch: key: torch.frombuffer(bytearray(payload.data), dtype=_torch_dtype(payload.dtype)).reshape(payload.shape) for key, payload in micro_batch.mm_kwargs.items() } + elif micro_batch.mm_refs is not None: + # Deferred path: materialize pixels here from the shipped image + # references. Returns torch tensors directly (no decode needed). + # SCOPE (16a): this runs in every rank that holds the shard, so with + # TP/CP/EP the same images are read+processed non_dp_world_size times. + # Fine for DP-only; a per-DP-group materializer + broadcast is a 16b + # perf item for large model-parallel runs. + if self._renderer is None: + raise ValueError( + "Received mm_refs but the trainer has no renderer: orchestrator/trainer " + "defer_mm_materialization config mismatch (trainer flag is off)." 
+ ) + from prime_rl.utils.mm import materialize_mm_refs + + try: + materialize_start = time.perf_counter() + mm_kwargs = materialize_mm_refs(self._renderer, micro_batch.mm_refs) + self.last_mm_materialize_time += time.perf_counter() - materialize_start + self.last_mm_images_materialized += len(micro_batch.mm_refs.uris) + except Exception as exc: + # The pre-forward all-reduce will fail-fast every rank, so make the + # culprit obvious: which run (from lora_num_tokens) and which images. + run_idx = next((i for i, n in enumerate(micro_batch.lora_num_tokens or []) if n > 0), None) + get_logger().error( + f"mm materialization failed (run_idx={run_idx}, uris={micro_batch.mm_refs.uris}): {exc!r}" + ) + raise routed_experts = None packed_routed_experts = micro_batch.routed_experts if packed_routed_experts is not None: diff --git a/src/prime_rl/trainer/rl/train.py b/src/prime_rl/trainer/rl/train.py index c0b451586d..2e6d1dcfbb 100644 --- a/src/prime_rl/trainer/rl/train.py +++ b/src/prime_rl/trainer/rl/train.py @@ -121,6 +121,15 @@ def train(config: TrainerConfig): config.output_dir, config.max_concurrent_runs, torch.device("cuda", world.local_rank), config.model.lora ) + # Reject (at discovery) runs whose deferred-MM config is incompatible with the + # trainer — a misconfigured run shouldn't crash all ranks later in get_batch. + if world.is_master: + from prime_rl.utils.mm import make_defer_mm_validation_hook + + multi_run_manager.register_config_validation_hook( + make_defer_mm_validation_hook(config.defer_mm_materialization, config.renderer) + ) + # Initialize parallel dimensions parallel_dims = get_parallel_dims(config.model) @@ -238,6 +247,10 @@ def load_run_checkpoint(_optimizer, idx: int) -> None: config.model.cp, tokenizer, config.rollout_transport, + defer_mm_materialization=config.defer_mm_materialization, + # Only VLM runs materialize pixels; text-only runs leave this None so + # default-on defer never builds an unused renderer for them. 
+ renderer_config=config.renderer if config.model.vlm is not None else None, ) token_exporter = setup_token_exporter(config, parallel_dims, world, logger) @@ -336,7 +349,23 @@ def load_run_checkpoint(_optimizer, idx: int) -> None: # Load the training batch logger.debug("Loading batch") load_data_start_time = time.perf_counter() - micro_batches = dataloader.get_batch() + # A get_batch failure on one rank (deferred-MM materialization error, a + # missing file://, or an orchestrator/trainer flag mismatch) must fail all + # ranks before the forward collective, or survivors hang in NCCL. The + # per-step int all-reduce is negligible and protects every run, not just MM. + load_error: Exception | None = None + try: + micro_batches = dataloader.get_batch() + except Exception as exc: + load_error = exc + failed_flag = torch.tensor(1 if load_error else 0, dtype=torch.int64, device="cuda") + dist.all_reduce(failed_flag, op=dist.ReduceOp.MAX) + if failed_flag.item() != 0: + # Preserve the culprit rank's traceback; bystander ranks still raise + # so none proceeds into the forward collective alone. + if load_error is not None: + raise RuntimeError("Training-batch load failed on this rank; failing all ranks.") from load_error + raise RuntimeError("Training-batch load failed on another rank; failing all ranks.") load_data_time = time.perf_counter() - load_data_start_time logger.debug(f"Loaded batch in {load_data_time:.2f} seconds") @@ -627,6 +656,10 @@ def load_run_checkpoint(_optimizer, idx: int) -> None: "time/step": step_time, "time/wait_for_batch": wait_for_batch_time, "time/load_data": load_data_time, + # Synchronous trainer-side MM materialization (decode + vision + # preprocessing from shipped refs), a subset of time/load_data. 
+ "time/mm_materialize": getattr(dataloader, "last_mm_materialize_time", 0.0), + "mm/images_materialized": getattr(dataloader, "last_mm_images_materialized", 0), "time/broadcast_weights": broadcast_weights_time, "time/save_ckpt": save_ckpt_time, "time/forward_backward": forward_backward_time, diff --git a/src/prime_rl/transport/__init__.py b/src/prime_rl/transport/__init__.py index bad9d6c806..8773be02ab 100644 --- a/src/prime_rl/transport/__init__.py +++ b/src/prime_rl/transport/__init__.py @@ -8,7 +8,7 @@ FileSystemTrainingBatchReceiver, FileSystemTrainingBatchSender, ) -from prime_rl.transport.types import MicroBatch, RoutedExperts, TrainingBatch, TrainingSample +from prime_rl.transport.types import MicroBatch, MMRefs, RoutedExperts, TrainingBatch, TrainingSample from prime_rl.transport.zmq import ( ZMQMicroBatchReceiver, ZMQMicroBatchSender, @@ -67,6 +67,7 @@ def setup_micro_batch_receiver( "TrainingSample", "TrainingBatch", "MicroBatch", + "MMRefs", "RoutedExperts", "setup_training_batch_sender", "setup_training_batch_receiver", diff --git a/src/prime_rl/transport/types.py b/src/prime_rl/transport/types.py index 1bb31c9325..8b6ecb7953 100644 --- a/src/prime_rl/transport/types.py +++ b/src/prime_rl/transport/types.py @@ -14,6 +14,17 @@ class EncodedTensor(msgspec.Struct, array_like=True, gc=False): data: bytes +# Lightweight image references shipped instead of materialized pixels when +# defer_mm_materialization is on: the orchestrator emits these and the trainer +# materializes pixels in its data loader. +class MMRefs(msgspec.Struct, array_like=True, gc=False): + # Descriptor-only mm_data ({"mm_items": {...grid/placeholder...}, "mm_hashes": {...}}), + # transport-safe (grids arrive as msgpack wire payloads, hashes as str), plus the + # candidate file:// image URIs for this sample. The trainer materializes pixels from these. + descriptor: dict + uris: list[str] + + # Routed experts are large per-token arrays.
tolist() is too expensive, so we # send raw bytes through msgpack and carry the shape/dtype needed to rebuild. class RoutedExperts(msgspec.Struct, array_like=True, gc=False, omit_defaults=True): @@ -56,6 +67,12 @@ class TrainingSample(msgspec.Struct, array_like=True, gc=False, omit_defaults=Tr # taus), sft uses sft_loss_fn. Stamped by the orchestrator from training_mode. training_mode: TrainingMode = "rl" + # Lightweight image references (deferred materialization). Exactly one of + # {mm_kwargs, mm_refs} is populated per multimodal sample. APPENDED LAST: + # array_like=True structs encode positionally, so new fields must go at the + # end to preserve the wire positions of existing fields. + mm_refs: MMRefs | None = None + class TrainingBatch(msgspec.Struct, array_like=True, gc=False, omit_defaults=True): """A batch of training examples with metadata for transport.""" @@ -89,3 +106,6 @@ class MicroBatch(msgspec.Struct, array_like=True, gc=False, omit_defaults=True): # sft → sft loss). All samples packed into a micro batch share the same mode. training_mode: TrainingMode = "rl" rewards: list[float] | None = None + # See TrainingSample.mm_refs. APPENDED LAST — array_like=True is positional, so + # new fields go at the end to preserve existing field wire positions. + mm_refs: MMRefs | None = None diff --git a/src/prime_rl/utils/mm.py b/src/prime_rl/utils/mm.py new file mode 100644 index 0000000000..6c5d9e2207 --- /dev/null +++ b/src/prime_rl/utils/mm.py @@ -0,0 +1,166 @@ +"""Shared multimodal helpers used by both the orchestrator (flag-off path, +materialize pixels then ship heavy mm_kwargs) and the trainer (flag-on path, +ship lightweight mm_refs then materialize pixels in the data loader). + +Factoring these here keeps the two paths byte-identical and avoids a +trainer→orchestrator import dependency. 
+""" + +from typing import Any + +import torch + +from prime_rl.transport.types import EncodedTensor, MMRefs + + +def reconstruct_mm_pixels(renderer: Any, descriptor: dict, messages: list) -> Any: + """Re-attach ``pixel_values`` to a descriptor-only union mm_data. + + Delegates to the renderer's ``materialize_pixels`` (hash-matched reprocess + of the window images, with a ``grid_thw`` assert). The descriptor's + ``image_grid_thw`` is decoded from its msgpack wire shape back to numpy + first, so the renderer's numpy-vs-numpy grid assert holds after transport. + """ + from renderers.base import MultiModalData + from verifiers.utils.serve_utils import decode_tensor_payload + + items = descriptor.get("mm_items") or {} + decoded_items: dict[str, list] = {} + for modality, lst in items.items(): + new_lst: list[dict[str, Any]] = [] + for item in lst or []: + item = dict(item) + grid = item.get("image_grid_thw") + if item.get("pixel_values") is None and grid is not None: + item["image_grid_thw"] = decode_tensor_payload(grid, to_torch=False) + new_lst.append(item) + decoded_items[modality] = new_lst + + md = MultiModalData( + mm_hashes=descriptor.get("mm_hashes") or {}, + mm_placeholders={}, + mm_items=decoded_items, + ) + return renderer.materialize_pixels(md, messages) + + +def pack_mm_kwargs_tensors(mm_data: Any) -> "dict[str, torch.Tensor] | None": + """Batch the renderer's per-image ``mm_items`` into model-agnostic forward + kwargs, returning torch tensors (not encoded bytes). + + ``mm_data`` may arrive as a ``MultiModalData`` instance (in-process for + tests) or as a plain dict (after msgpack round-trip from the env-worker). + Each item is a dict keyed by the names the model's ``forward`` expects + (``pixel_values`` + ``image_grid_thw`` for Qwen3-VL, just ``pixel_values`` + for Gemma3-VL, etc.). 
We batch by ``torch.cat(..., dim=0)`` per key — + generic because every HF VLM processor emits a leading batch/patch + dimension, and the renderer always processes one image per call. + + Returns a dict of torch tensors keyed by kwarg name, or ``None`` when no + multimodal data is present. + """ + from verifiers.utils.serve_utils import decode_tensor_payload + + mm_items = mm_data.mm_items if hasattr(mm_data, "mm_items") else (mm_data or {}).get("mm_items") or {} + # Flatten across modalities into one kwarg dict — the model's forward + # signature is the schema. ``mm_items`` is typically ``{"image": [...], + # "video": [...]}`` but each modality's keys don't collide for any HF VLM + # we ship today. + per_kwarg: dict[str, list] = {} + for _modality, items in mm_items.items(): + for item in items or []: + for key, payload in item.items(): + # ``decode_tensor_payload`` rehydrates the encoded wire shape to + # torch but passes already-rehydrated numpy through unchanged. + # ``as_tensor`` normalizes both to torch so ``torch.cat`` below + # is uniform. + per_kwarg.setdefault(key, []).append(torch.as_tensor(decode_tensor_payload(payload))) + if not per_kwarg: + return None + out: dict[str, torch.Tensor] = {} + for key, tensors in per_kwarg.items(): + out[key] = torch.cat(tensors, dim=0).contiguous() + return out + + +def encode_mm_kwargs(tensors: dict[str, torch.Tensor]) -> dict[str, EncodedTensor]: + """Encode packed torch tensors into ``EncodedTensor`` wire payloads.""" + out: dict[str, EncodedTensor] = {} + for key, cat in tensors.items(): + arr = cat.detach().cpu().numpy() + out[key] = EncodedTensor(dtype=str(arr.dtype), shape=list(arr.shape), data=arr.tobytes()) + return out + + +def build_image_messages(uris: list[str]) -> list[dict]: + """Minimal messages ``materialize_pixels`` hash-matches against. 
Order and + duplicates are harmless — matching dedups by hash.""" + return [{"role": "user", "content": [{"type": "image_url", "image_url": {"url": u}} for u in uris]}] + + +def image_uris_from_messages(messages: list) -> list[str]: + """Collect every image URI from message ``content`` lists. Keeps order; + duplicates are fine. Accepts ``file://`` (offloaded) and ``data:image`` + (in-process/non-offloaded).""" + uris: list[str] = [] + for msg in messages: + if not isinstance(msg, dict): + continue + content = msg.get("content", []) + if not isinstance(content, list): + continue + for item in content: + if not isinstance(item, dict) or item.get("type") != "image_url": + continue + image_url = item.get("image_url") + if not isinstance(image_url, dict): + continue + url = image_url.get("url", "") + if isinstance(url, str) and (url.startswith("file://") or url.startswith("data:image")): + uris.append(url) + return uris + + +def materialize_mm_refs(renderer: Any, refs: MMRefs) -> "dict[str, torch.Tensor] | None": + """Trainer entry point: reconstruct pixels from refs and pack into forward + kwargs (torch tensors).""" + return pack_mm_kwargs_tensors(reconstruct_mm_pixels(renderer, refs.descriptor, build_image_messages(refs.uris))) + + +def make_defer_mm_validation_hook(trainer_defers: bool, trainer_renderer: Any): + """Build a MultiRunManager config-validation hook that vets each discovered + run's orchestrator config against the trainer for deferred materialization. + + On failure the run is rejected at discovery (``get_orchestrator_config`` writes + ``config_validation_error.txt`` and returns None, so the run is never registered + or packed — this is rejection, not ``evicted.txt`` eviction). That keeps one + misconfigured run from either crashing all ranks later inside ``get_batch`` or + silently materializing with the wrong image processor. 
+ """ + + def validate(orch_config: Any) -> "tuple[bool, str]": + if not getattr(orch_config, "defer_mm_materialization", False): + return True, "" # run ships pixels (mm_kwargs) — trainer handles regardless + if not trainer_defers: + return False, ( + "run sets defer_mm_materialization=true but the trainer does not — the trainer has no " + "renderer to materialize mm_refs and the run's batches would fail in get_batch. " + "Enable defer_mm_materialization and set [renderer] on the trainer." + ) + # Both defer: the trainer materializes every run's images with its single + # renderer, so the families must match. Auto resolves against the (shared) + # base model → compatible. Within-family processor drift is backstopped by + # the grid skew-assert in materialize_pixels. + orch_r = getattr(orch_config, "renderer", None) + if ( + orch_r is not None + and type(orch_r).__name__ != "AutoRendererConfig" + and type(orch_r) is not type(trainer_renderer) + ): + return False, ( + f"run renderer {type(orch_r).__name__} is a different family than the trainer renderer " + f"{type(trainer_renderer).__name__}; deferred materialization would use the wrong image processor." 
+ ) + return True, "" + + return validate diff --git a/tests/unit/train/rl/test_packer.py b/tests/unit/train/rl/test_packer.py index d1f1c3634f..e40e7e89ff 100644 --- a/tests/unit/train/rl/test_packer.py +++ b/tests/unit/train/rl/test_packer.py @@ -110,3 +110,116 @@ def fake_sender(_output_dir, _data_world_size, _current_step, _config): sender = sender_holder["sender"] assert len(sender.sent) == 1 assert len(sender.sent[0][0]) == 1 + + +def _mm_sample(uri: str) -> TrainingSample: + """Deferred-MM TrainingSample (carries mm_refs, no mm_kwargs).""" + from prime_rl.transport.types import MMRefs + + return TrainingSample( + prompt_ids=[1], + prompt_mask=[False], + completion_ids=[2], + completion_mask=[True], + completion_logprobs=[-0.1], + completion_temperatures=[1.0], + env_name="test-env", + mm_token_type_ids=[1, 1], + mm_refs=MMRefs( + descriptor={"mm_items": {"image": [{"image_grid_thw": [[1, 2, 3]]}]}, "mm_hashes": {"image": [uri]}}, + uris=[uri], + ), + ) + + +def _packer_with_two_runs(tmp_path, monkeypatch, dp_world_size, seq_len): + """Set up a MultiPacker over two discovered runs; capture sent grids.""" + reset_world() + runs._MULTI_RUN_MANAGER = None + manager = setup_multi_run_manager(output_dir=tmp_path, max_runs=2, device=torch.device("cpu")) + create_run_with_config(tmp_path, "run_a") + create_run_with_config(tmp_path, "run_b") + manager.discover_runs() + + sent: list = [] + + class DummyReceiver: + def receive(self): + return [] + + def reset_run(self, idx): + pass + + class DummySender: + def send(self, micro_batch_grid): + sent.append(micro_batch_grid) + + monkeypatch.setattr("prime_rl.trainer.rl.packer.setup_training_batch_receiver", lambda _c: DummyReceiver()) + monkeypatch.setattr("prime_rl.trainer.rl.packer.setup_micro_batch_sender", lambda *a, **k: DummySender()) + packer = MultiPacker( + dp_world_size=dp_world_size, + seq_len=seq_len, + pad_to_multiple_of=1, + tokenizer=None, + config=FileSystemTransportConfig(), + start_step=0, + ) + return 
manager, packer, sent + + +def test_multipacker_pack_preserves_mm_refs_modality_and_run_tagging(tmp_path, monkeypatch): + """The REAL MultiPacker.pack() path (per-run buffers → round-robin → prepare_batch + → merged rank grids) preserves deferred-MM correctness across runs.""" + from prime_rl.trainer.batch import _is_multimodal_sample + + manager, packer, sent = _packer_with_two_runs(tmp_path, monkeypatch, dp_world_size=2, seq_len=3) + a, b = manager.id_2_idx["run_a"], manager.id_2_idx["run_b"] + for idx, uri in ((a, "ha"), (b, "hb")): + packer.buffers[idx].append((_mm_sample(uri), 0)) + packer.buffers[idx].append((make_training_sample(), 0)) + + packer.pack() + assert sent, "pack() sent nothing" + grid = sent[-1] # list[per-rank list[MicroBatch]] + assert len(grid) == 2 + + # FSDP safety: every rank has the same modality at each micro-step index. + for step_mbs in zip(*grid): + assert len({_is_multimodal_sample(mb) for mb in step_mbs}) == 1 + + mm_mbs = [mb for rank in grid for mb in rank if _is_multimodal_sample(mb)] + assert mm_mbs, "no MM microbatches produced" + real_run_idxs = set() + for mb in mm_mbs: + # Every MM microbatch is a standalone deferred-refs sequence (never pixels, + # never packed with text). Dummies deep-copy the source so they keep mm_refs + # too — they're distinguished by an all-False loss_mask. 
+ assert mb.mm_kwargs is None and mb.mm_refs is not None + if any(mb.loss_mask): # real (loss-bearing) MM → tagged to exactly one run + tagged = [i for i, n in enumerate(mb.lora_num_tokens) if n > 0] + assert len(tagged) == 1 and mb.lora_num_tokens[tagged[0]] == len(mb.input_ids) + real_run_idxs.add(tagged[0]) + assert real_run_idxs == {a, b}, f"both runs' MM should be tagged; got {real_run_idxs}" + + +def test_multipacker_pack_mm_padding_is_zero_loss(tmp_path, monkeypatch): + """A lone MM sample forces a dummy MM microbatch for rank padding; it must be + zero-loss (and keep MM modality so all ranks still run the vision encoder).""" + from prime_rl.trainer.batch import _is_multimodal_sample + + manager, packer, sent = _packer_with_two_runs(tmp_path, monkeypatch, dp_world_size=2, seq_len=2) + a, b = manager.id_2_idx["run_a"], manager.id_2_idx["run_b"] + packer.buffers[a].append((_mm_sample("ha"), 0)) # one MM → needs a dummy to fill 2 ranks + packer.buffers[b].append((make_training_sample(), 0)) + + packer.pack() + assert sent + grid = sent[-1] + mm_mbs = [mb for rank in grid for mb in rank if _is_multimodal_sample(mb)] + # A dummy keeps MM modality (so all ranks run the vision encoder) but is + # zero-loss; it's identified by an all-False loss_mask, not by missing mm_refs + # (the dummy deep-copies the source, mm_refs included). + dummies = [mb for mb in mm_mbs if not any(mb.loss_mask)] + assert dummies, "expected a zero-loss dummy MM padding microbatch" + for d in dummies: + assert all(a == 0.0 for a in d.advantages) diff --git a/tests/unit/trainer/test_mm_refs.py b/tests/unit/trainer/test_mm_refs.py new file mode 100644 index 0000000000..ec3614591c --- /dev/null +++ b/tests/unit/trainer/test_mm_refs.py @@ -0,0 +1,304 @@ +"""Tests for deferred multimodal materialization (Phase 16a). 
+ +The orchestrator ships lightweight image references (``mm_refs``) and the +trainer materializes pixels from them via ``prime_rl.utils.mm``, reusing the +same materialize/pack code as the orchestrator's flag-off path so parity and +the duplicate-image guarantee hold by construction. +""" + +import hashlib +from dataclasses import replace + +import msgspec +import pytest +import torch +from renderers.base import MultiModalData + +from prime_rl.orchestrator.trajectories import _collect_mm_refs, _pack_mm_kwargs_from_renderer, _reconstruct_mm_pixels +from prime_rl.trainer.batch import _is_multimodal_sample +from prime_rl.transport.types import MicroBatch, MMRefs, TrainingSample +from prime_rl.utils.mm import ( + build_image_messages, + encode_mm_kwargs, + pack_mm_kwargs_tensors, + reconstruct_mm_pixels, +) + + +def _uri_hash(uri: str) -> str: + return hashlib.sha256(uri.encode()).hexdigest()[:16] + + +class _StubRenderer: + """Mirrors ``materialize_image_pixels``: resolves each URI to a deterministic + pixel tensor by content hash, decodes each unique hash once, and populates + every descriptor slot referencing that hash (the duplicate guarantee).""" + + def __init__(self, pixels_by_hash: dict[str, torch.Tensor]): + self._pixels_by_hash = pixels_by_hash + + def materialize_pixels(self, mm_data: MultiModalData, messages: list) -> MultiModalData: + image_items = mm_data.mm_items.get("image") or [] + hashes = mm_data.mm_hashes.get("image") or [] + # Decode each referenced URI once (dedup by hash), like the real renderer. 
+ decoded: dict[str, torch.Tensor] = {} + for msg in messages: + for part in msg.get("content", []): + url = part.get("image_url", {}).get("url", "") + h = _uri_hash(url) + if h not in decoded: + decoded[h] = self._pixels_by_hash[h] + new_items = [] + for i, item in enumerate(image_items): + h = hashes[i] + new_items.append({"pixel_values": decoded[h], "image_grid_thw": item["image_grid_thw"]}) + return replace(mm_data, mm_items={**mm_data.mm_items, "image": new_items}) + + +def _grid_payload(g: list[int]) -> dict: + # Wire shape as produced by verifiers' msgpack_encoder (env→orch hop). + arr = torch.tensor([g], dtype=torch.int64).numpy() + return {"__torch_tensor__": True, "dtype": "int64", "shape": list(arr.shape), "data": arr.tobytes()} + + +def _descriptor(uris: list[str], grids: list[list[int]]) -> dict: + """Descriptor-only mm_data keyed by URI content hash, grids as wire payloads.""" + return { + "mm_items": {"image": [{"image_grid_thw": _grid_payload(g)} for g in grids]}, + "mm_hashes": {"image": [_uri_hash(u) for u in uris]}, + } + + +def test_golden_parity_trainer_matches_orchestrator(): + """Trainer-side encode(pack(reconstruct)) is byte-identical to the + orchestrator's existing _pack_mm_kwargs_from_renderer(_reconstruct_mm_pixels).""" + uris = ["file:///a.jpg", "file:///b.jpg"] + grids = [[1, 2, 3], [1, 4, 4]] + pixels = {_uri_hash(uris[0]): torch.tensor([[1.0, 2.0]]), _uri_hash(uris[1]): torch.tensor([[3.0, 4.0]])} + renderer = _StubRenderer(pixels) + + descriptor = _descriptor(uris, grids) + messages = build_image_messages(uris) + + # Trainer path (utils.mm). + trainer_kwargs = encode_mm_kwargs(pack_mm_kwargs_tensors(reconstruct_mm_pixels(renderer, descriptor, messages))) + # Orchestrator path (trajectories delegates to the same code). 
+ orch_kwargs = _pack_mm_kwargs_from_renderer(_reconstruct_mm_pixels(renderer, _descriptor(uris, grids), messages)) + + assert trainer_kwargs.keys() == orch_kwargs.keys() + for key in trainer_kwargs: + assert trainer_kwargs[key].dtype == orch_kwargs[key].dtype + assert trainer_kwargs[key].shape == orch_kwargs[key].shape + assert trainer_kwargs[key].data == orch_kwargs[key].data + + +def test_duplicate_image_decoded_once_populated_per_slot(): + """A descriptor with the SAME hash in two slots + one URI → both slots get + identical pixel tensors (decoded once).""" + uri = "file:///dup.jpg" + h = _uri_hash(uri) + pixels = {h: torch.tensor([[7.0, 8.0]])} + renderer = _StubRenderer(pixels) + + # Two item slots, same hash, one URI. + descriptor = { + "mm_items": { + "image": [{"image_grid_thw": _grid_payload([1, 2, 3])}, {"image_grid_thw": _grid_payload([1, 2, 3])}] + }, + "mm_hashes": {"image": [h, h]}, + } + refs = MMRefs(descriptor=descriptor, uris=[uri]) + + from prime_rl.utils.mm import materialize_mm_refs + + kwargs = materialize_mm_refs(renderer, refs) + pv = kwargs["pixel_values"] + assert pv.shape[0] == 2 + assert torch.equal(pv[0], pv[1]) + assert torch.equal(pv[0], torch.tensor([7.0, 8.0])) + + +def test_mm_refs_msgpack_round_trip(): + """MMRefs (descriptor + uris) encodes+decodes through msgpack cleanly — + catches a stray tensor/numpy left in the descriptor.""" + descriptor = _descriptor(["file:///a.jpg"], [[1, 2, 3]]) + refs = MMRefs(descriptor=descriptor, uris=["file:///a.jpg"]) + + raw = msgspec.msgpack.encode(refs) + decoded = msgspec.msgpack.decode(raw, type=MMRefs) + + assert decoded.uris == refs.uris + assert decoded.descriptor["mm_hashes"] == descriptor["mm_hashes"] + + +def test_micro_batch_with_mm_refs_is_multimodal(): + """A MicroBatch carrying mm_refs and no mm_kwargs is classified multimodal.""" + refs = MMRefs(descriptor={"mm_items": {}, "mm_hashes": {}}, uris=["file:///a.jpg"]) + mb = MicroBatch( + input_ids=[1, 2], + loss_mask=[True, True], + 
advantages=[0.0, 0.0], + inference_logprobs=[0.0, 0.0], + position_ids=[0, 1], + temperatures=[1.0, 1.0], + env_names=["e", "e"], + mm_refs=refs, + ) + assert mb.mm_kwargs is None + assert _is_multimodal_sample(mb) + + +def test_collect_mm_refs_normalizes_raw_tensor_descriptor(): + """_collect_mm_refs must produce a transport-safe, descriptor-ONLY MMRefs even + when the union mm_data holds raw torch tensors (in-process path) — pixels + dropped, grids → nested lists (shape preserved), hashes → str, and msgpack-encodable. The bare + msgspec encoder rejects tensors, so this is the edge the old code missed.""" + # Real Qwen grids are 2-D (1, 3); the flat (3,) shape would mask a + # shape-preservation bug since _grids_equal compares [[t,h,w]] nested. + union_mm = { + "mm_items": {"image": [{"pixel_values": torch.zeros(4, 8), "image_grid_thw": torch.tensor([[1, 2, 3]])}]}, + "mm_hashes": {"image": ["abc123"]}, + } + trajectory = [ + {"prompt": [{"role": "user", "content": [{"type": "image_url", "image_url": {"url": "file:///a.jpg"}}]}]} + ] + refs = _collect_mm_refs(union_mm, trajectory, [0]) + + item = refs.descriptor["mm_items"]["image"][0] + assert "pixel_values" not in item # descriptor-only + assert item["image_grid_thw"] == [[1, 2, 3]] # (1, 3) nesting preserved, transport-safe lists + assert refs.descriptor["mm_hashes"]["image"] == ["abc123"] + assert refs.uris == ["file:///a.jpg"] + # Must survive the bare msgspec encoder used by the batch sender.
+ msgspec.msgpack.decode(msgspec.msgpack.encode(refs), type=MMRefs) + + +@pytest.mark.parametrize( + "union_mm", + [ + {"mm_items": {"video": [{"image_grid_thw": [[1, 2, 3]]}]}, "mm_hashes": {"video": ["v"]}}, + {"mm_items": {}, "mm_hashes": {"video": ["v"]}}, # non-image present only in mm_hashes + ], + ids=["in_mm_items", "in_mm_hashes_only"], +) +def test_collect_mm_refs_rejects_non_image_modality(union_mm): + """Image-only this iteration: a non-image modality in EITHER mm_items or + mm_hashes must fail loudly, not silently drop to empty uris.""" + with pytest.raises(ValueError, match="image modality only"): + _collect_mm_refs(union_mm, [], []) + + +def test_collect_mm_refs_rejects_non_qwen_image_descriptor(): + """Renderer-family guard: an image item without image_grid_thw (e.g. a + non-Qwen renderer keyed on grid_thws) must fail loudly, not ship a None grid.""" + union_mm = {"mm_items": {"image": [{"grid_thws": [[1, 2, 3]]}]}, "mm_hashes": {"image": ["h"]}} + with pytest.raises(ValueError, match="Qwen-style image descriptors"): + _collect_mm_refs(union_mm, [], []) + + +def _mm_sample(uri: str, env: str = "e", n_prompt: int = 4, n_comp: int = 4): + """A multimodal TrainingSample carrying deferred mm_refs (post-normalization: + plain-list grid, str hash keyed by URI content hash so the stub renderer matches).""" + refs = MMRefs( + descriptor={"mm_items": {"image": [{"image_grid_thw": [[1, 2, 3]]}]}, "mm_hashes": {"image": [_uri_hash(uri)]}}, + uris=[uri], + ) + return TrainingSample( + prompt_ids=list(range(n_prompt)), + prompt_mask=[False] * n_prompt, + completion_ids=list(range(n_comp)), + completion_mask=[True] * n_comp, + completion_logprobs=[0.0] * n_comp, + completion_temperatures=[1.0] * n_comp, + env_name=env, + advantage=0.0, + reward=0.0, + mm_token_type_ids=[1] * (n_prompt + n_comp), + mm_refs=refs, + ) + + +def _text_sample(env: str = "e", n_prompt: int = 4, n_comp: int = 4): + return TrainingSample( + prompt_ids=list(range(n_prompt)), + 
prompt_mask=[False] * n_prompt, + completion_ids=list(range(n_comp)), + completion_mask=[True] * n_comp, + completion_logprobs=[0.0] * n_comp, + completion_temperatures=[1.0] * n_comp, + env_name=env, + advantage=0.0, + reward=0.0, + ) + + +def test_multirun_packing_preserves_mm_refs_modality_and_run_tagging(): + """Multi-run: deferred mm_refs samples from 2 runs pack correctly through the + REAL prepare_batch — each MM sample is its own microbatch carrying its mm_refs, + tagged to exactly one run via lora_num_tokens, and the modality-separated + strided distribution keeps every rank on the same modality per step index + (the FSDP vision-encoder safety property).""" + from prime_rl.trainer.batch import prepare_batch + + uri0, uri1 = "file:///run0.jpg", "file:///run1.jpg" + rollouts = [_mm_sample(uri0), _text_sample(), _mm_sample(uri1), _text_sample()] + idxs = [0, 0, 1, 1] # run 0 and run 1 + grid = prepare_batch(rollouts, seq_len=64, num_train_workers=2, idxs=idxs, num_loras=2) + + assert len(grid) == 2 # 2 dp ranks + # FSDP safety: at each step index, both ranks see the SAME modality. + for step_mbs in zip(*grid): + modalities = {_is_multimodal_sample(mb) for mb in step_mbs} + assert len(modalities) == 1, "ranks diverge in modality at a step index → FSDP all-gather would hang" + + # Every MM microbatch carries its mm_refs (not pixels) and is tagged to one run. + mm_mbs = [mb for rank in grid for mb in rank if _is_multimodal_sample(mb) and mb.mm_refs is not None] + assert len(mm_mbs) == 2 # one per run (padding microbatches have no mm_refs) + for mb in mm_mbs: + assert mb.mm_kwargs is None + tagged = [i for i, n in enumerate(mb.lora_num_tokens) if n > 0] + assert len(tagged) == 1 and mb.lora_num_tokens[tagged[0]] == len(mb.input_ids) + + # Run-agnostic materialization: one shared renderer materializes either run's refs. 
+ renderer = _StubRenderer({_uri_hash(uri0): torch.tensor([[5.0, 6.0]]), _uri_hash(uri1): torch.tensor([[7.0, 8.0]])}) + from prime_rl.utils.mm import materialize_mm_refs + + for mb in mm_mbs: + kwargs = materialize_mm_refs(renderer, mb.mm_refs) + assert kwargs is not None and kwargs["pixel_values"].shape[0] == 1 + + +@pytest.mark.parametrize( + "trainer_defers, trainer_renderer_name, run_defers, run_renderer_name, expected_ok", + [ + # Run doesn't defer → always fine (ships pixels; trainer handles either way). + (True, "Qwen3VLRendererConfig", False, "Qwen3RendererConfig", True), + (False, None, False, None, True), + # Run defers but trainer doesn't → reject (no renderer to materialize). + (False, None, True, "Qwen3VLRendererConfig", False), + # Both defer, same renderer family → ok. + (True, "Qwen3VLRendererConfig", True, "Qwen3VLRendererConfig", True), + # Both defer, run uses Auto → ok (resolves against the shared base model). + (True, "Qwen3VLRendererConfig", True, "AutoRendererConfig", True), + # Both defer, different renderer family → reject (wrong image processor). 
+ (True, "Qwen3VLRendererConfig", True, "Qwen3RendererConfig", False), + ], +) +def test_defer_mm_validation_hook_matrix( + trainer_defers, trainer_renderer_name, run_defers, run_renderer_name, expected_ok +): + """Direct coverage of the discovery-time config rejection matrix.""" + from types import SimpleNamespace + + import renderers + + from prime_rl.utils.mm import make_defer_mm_validation_hook + + def _cfg(name): + return getattr(renderers, name)() if name else None + + hook = make_defer_mm_validation_hook(trainer_defers, _cfg(trainer_renderer_name)) + orch_config = SimpleNamespace(defer_mm_materialization=run_defers, renderer=_cfg(run_renderer_name)) + ok, msg = hook(orch_config) + assert ok is expected_ok + assert (msg == "") is expected_ok # rejection carries a non-empty reason From b0e1e4d5773473b6005795cddb9af70267568c9c Mon Sep 17 00:00:00 2001 From: Eli Gottlieb <78387377+eligotts@users.noreply.github.com> Date: Sat, 30 May 2026 22:38:58 +0000 Subject: [PATCH 09/31] feat(mm): orchestrator ships mm_refs + canonical run-scoped offload/sweep - trajectories: _collect_mm_refs normalizes the per-sample descriptor (strip pixel_values, grids -> nested lists, hashes -> str, image-only guard) and interleave_rollout(defer_materialization) ships MMRefs instead of materializing pixels; offload_images_to_disk writes under mm_store's assets/images subdir. - orchestrator: resolve a canonical asset_root = run_dir(RUN_ID) (falls back to output_dir locally) and use it for BOTH the image offload and the per-step TTL sweep, so the env worker, orchestrator, and sweeper share one /data/outputs/run_/assets tree (no re-copy, sweep covers features too). Also free per-step results/samples to drop the multimodal RSS ratchet. 
Co-Authored-By: Codex Co-Authored-By: Claude Opus 4.8 (1M context) --- src/prime_rl/orchestrator/orchestrator.py | 24 +++- src/prime_rl/orchestrator/trajectories.py | 161 +++++++++++----------- 2 files changed, 101 insertions(+), 84 deletions(-) diff --git a/src/prime_rl/orchestrator/orchestrator.py b/src/prime_rl/orchestrator/orchestrator.py index 9488e07ef2..3a810dbf94 100644 --- a/src/prime_rl/orchestrator/orchestrator.py +++ b/src/prime_rl/orchestrator/orchestrator.py @@ -31,6 +31,7 @@ import pandas as pd import verifiers as vf +from renderers import mm_store from renderers.base import create_renderer from prime_rl.configs.orchestrator import OrchestratorConfig @@ -451,15 +452,33 @@ async def orchestrate(config: OrchestratorConfig): save_rollouts, train_rollouts, step_path / "train_rollouts.jsonl", exclude_keys={"trajectory"} ) + # Canonical run asset root. The platform mounts RUN_ID into every pod, so + # the env worker (image + feature offload), the orchestrator offload, and + # the sweeper all resolve the SAME ``/data/outputs/run_`` dir — + # otherwise the orchestrator re-copies images the env worker already wrote + # (different subdir) and the sweeper misses the env-worker/feature assets. + # Falls back to ``output_dir`` for local runs without RUN_ID. + asset_root = mm_store.run_dir(run_id) if run_id else config.output_dir + # Offload base64 images to disk to free memory. No-op for text-only - # rollouts (no ``data:image`` URLs to find); cheap to call always. + # rollouts (no ``data:image`` URLs to find); cheap to call always. Images + # the env worker already offloaded here (same dir) are recognized and not + # re-copied. 
offload_start = time.perf_counter() - num_offloaded = offload_images_to_disk(train_rollouts, config.output_dir) + num_offloaded = offload_images_to_disk(train_rollouts, asset_root) if num_offloaded: logger.info( f"Offloaded {num_offloaded} unique images to disk in {time.perf_counter() - offload_start:.2f}s" ) + # Evict stale offloaded multimodal artifacts under this run's dir (images + + # mm_features). No-op for text-only runs (asset dirs absent). Content- + # addressed + re-writable, so over-eviction is safe; each run sweeps its + # own dir → multi-run safe. + num_swept = mm_store.sweep_stale_artifacts(asset_root, config.mm_artifact_ttl_seconds) + if num_swept: + logger.info(f"Swept {num_swept} stale multimodal artifacts (ttl={config.mm_artifact_ttl_seconds}s)") + # Convert rollouts to training samples parallel_preprocess_start = time.perf_counter() @@ -500,6 +519,7 @@ async def _interleave_bounded(rollout: vf.RolloutOutput): rollout, mm_token_type_ids_mapping=mm_token_type_ids_mapping, renderer=renderer, + defer_materialization=config.defer_mm_materialization, ) results = await asyncio.gather(*(_interleave_bounded(r) for r in train_rollouts)) diff --git a/src/prime_rl/orchestrator/trajectories.py b/src/prime_rl/orchestrator/trajectories.py index bff9fd7566..48e10aed8b 100644 --- a/src/prime_rl/orchestrator/trajectories.py +++ b/src/prime_rl/orchestrator/trajectories.py @@ -5,11 +5,12 @@ import numpy as np import pybase64 -import torch import verifiers as vf +from renderers import mm_store from transformers.tokenization_utils import PreTrainedTokenizer from prime_rl.transport import RoutedExperts, TrainingSample +from prime_rl.transport.types import MMRefs from prime_rl.utils.chat_template import ( common_prefix_len, deserialize_tool_calls, @@ -18,6 +19,12 @@ strip_message_content, ) from prime_rl.utils.logger import get_logger +from prime_rl.utils.mm import ( + encode_mm_kwargs, + image_uris_from_messages, + pack_mm_kwargs_tensors, + reconstruct_mm_pixels, +) # 
We use list() instead of deepcopy() for flat lists (token IDs, logprobs) - safe because # primitives are immutable. mm_kwargs payloads are not mutated after creation. @@ -205,6 +212,7 @@ def interleave_rollout( output: vf.RolloutOutput, mm_token_type_ids_mapping: dict[int, int] | None = None, renderer: Any = None, + defer_materialization: bool = False, ) -> list[TrainingSample] | None: """ Convert vf.RolloutOutput to trainable rollouts by interleaving trajectory steps @@ -433,7 +441,13 @@ def extend_sample( if not output.get("is_filtered", False): for _, sample, step_indices in active_samples: renderer_mm = _union_step_mm_data(prepared_steps, step_indices) - if renderer_mm is not None: + if renderer_mm is None: + continue + if defer_materialization: + # Ship lightweight refs; the trainer materializes pixels in its + # data loader from the descriptor + window image URIs. + sample.mm_refs = _collect_mm_refs(renderer_mm, trajectory, step_indices) + else: # The env worker ships descriptor-only mm_data (no pixel_values) # to keep its memory flat. Re-derive the pixels here from the # offloaded images referenced in this sample's messages, matched @@ -441,19 +455,19 @@ def extend_sample( # multimodal pool used for rollouts; absent (or already # pixel-bearing in-process tests) → pass through unchanged. if renderer is not None and _mm_needs_pixels(renderer_mm): - window_messages = _window_image_messages(trajectory, step_indices) - renderer_mm = _reconstruct_mm_pixels(renderer, renderer_mm, window_messages) + renderer_mm = _reconstruct_mm_pixels( + renderer, renderer_mm, _window_image_messages(trajectory, step_indices) + ) mm_kwargs = _pack_mm_kwargs_from_renderer(renderer_mm) if mm_kwargs is not None: sample.mm_kwargs = mm_kwargs - # ``mm_token_type_ids``: 1 for image-placeholder tokens, 2 - # for video, 0 otherwise. Renderer-supplied via - # ``mm_token_type_id_map`` (single source of truth). 
- if mm_token_type_ids_mapping is not None: - sample.mm_token_type_ids = [ - mm_token_type_ids_mapping.get(token_id, 0) - for token_id in sample.prompt_ids + sample.completion_ids - ] + # ``mm_token_type_ids``: 1 for image-placeholder tokens, 2 for + # video, 0 otherwise. Renderer-supplied via ``mm_token_type_id_map`` + # (single source of truth). Computed in both paths. + if (sample.mm_kwargs is not None or sample.mm_refs is not None) and mm_token_type_ids_mapping is not None: + sample.mm_token_type_ids = [ + mm_token_type_ids_mapping.get(token_id, 0) for token_id in sample.prompt_ids + sample.completion_ids + ] return [sample for _, sample, _ in active_samples] @@ -526,82 +540,65 @@ def _window_image_messages(trajectory: list[Any], step_indices: list[int]) -> li def _reconstruct_mm_pixels(renderer: Any, union_mm: dict[str, Any], messages: list[Any]) -> Any: - """Re-attach ``pixel_values`` to a descriptor-only union mm_data. - - Delegates to the renderer's ``materialize_pixels`` (hash-matched reprocess - of the window images, with a ``grid_thw`` assert). The descriptor's - ``image_grid_thw`` is decoded from its msgpack wire shape back to numpy - first, so the renderer's numpy-vs-numpy grid assert holds after transport. 
- """ - from renderers.base import MultiModalData - from verifiers.utils.serve_utils import decode_tensor_payload - - items = union_mm.get("mm_items") or {} - decoded_items: dict[str, list] = {} - for modality, lst in items.items(): - new_lst: list[dict[str, Any]] = [] - for item in lst or []: - item = dict(item) - grid = item.get("image_grid_thw") - if item.get("pixel_values") is None and grid is not None: - item["image_grid_thw"] = decode_tensor_payload(grid, to_torch=False) - new_lst.append(item) - decoded_items[modality] = new_lst - - md = MultiModalData( - mm_hashes=union_mm.get("mm_hashes") or {}, - mm_placeholders={}, - mm_items=decoded_items, - ) - return renderer.materialize_pixels(md, messages) + """Re-attach ``pixel_values`` to a descriptor-only union mm_data. Delegates + to ``utils.mm`` so the flag-off path matches the trainer's flag-on path.""" + return reconstruct_mm_pixels(renderer, union_mm, messages) def _pack_mm_kwargs_from_renderer(mm_data: Any) -> "dict[str, Any] | None": - """Batch the renderer's per-image ``mm_items`` into model-agnostic - forward kwargs. - - ``mm_data`` may arrive as a ``MultiModalData`` instance (in-process - for tests) or as a plain dict (after msgpack round-trip from the - env-worker). Each item is a dict keyed by the names the model's - ``forward`` expects (``pixel_values`` + ``image_grid_thw`` for - Qwen3-VL, just ``pixel_values`` for Gemma3-VL, etc.). We batch by - ``torch.cat(..., dim=0)`` per key — generic because every HF VLM - processor emits a leading batch/patch dimension, and the renderer - always processes one image per call. - - Returns a dict of ``EncodedTensor`` payloads keyed by kwarg name, - or ``None`` when no multimodal data is present. - """ - from verifiers.utils.serve_utils import decode_tensor_payload + """Batch the renderer's per-image ``mm_items`` into ``EncodedTensor`` + forward kwargs. 
Delegates to ``utils.mm`` (pack then encode).""" + tensors = pack_mm_kwargs_tensors(mm_data) + if tensors is None: + return None + return encode_mm_kwargs(tensors) - from prime_rl.transport.types import EncodedTensor - mm_items = mm_data.mm_items if hasattr(mm_data, "mm_items") else (mm_data or {}).get("mm_items") or {} - # Flatten across modalities into one kwarg dict — the model's - # forward signature is the schema. ``mm_items`` is typically - # ``{"image": [...], "video": [...]}`` but each modality's keys - # don't collide for any HF VLM we ship today. - per_kwarg: dict[str, list] = {} - for _modality, items in mm_items.items(): - for item in items or []: - for key, payload in item.items(): - # ``decode_tensor_payload`` rehydrates the encoded wire shape to - # torch but passes already-rehydrated numpy through unchanged - # (e.g. pixels reconstructed in-process from disk). ``as_tensor`` - # normalizes both to torch so the ``torch.cat`` below is uniform. - per_kwarg.setdefault(key, []).append(torch.as_tensor(decode_tensor_payload(payload))) - if not per_kwarg: +def _grid_to_list(grid: Any) -> "list | None": + """Normalize an ``image_grid_thw`` (torch/numpy tensor, msgpack wire payload, + or list) to a plain nested list — the only transport-safe form. 
Shape is + preserved (real Qwen grids are 2-D ``(n, 3)`` → ``list[list[int]]``), since + the renderer's ``_grids_equal`` compares the nested form.""" + if grid is None: return None - out: dict[str, EncodedTensor] = {} - for key, tensors in per_kwarg.items(): - cat = torch.cat(tensors, dim=0).contiguous() - arr = cat.detach().cpu().numpy() - out[key] = EncodedTensor( - dtype=str(arr.dtype), - shape=list(arr.shape), - data=arr.tobytes(), + if isinstance(grid, dict): # msgpack wire payload from the env→orch hop + from verifiers.utils.serve_utils import decode_tensor_payload + + grid = decode_tensor_payload(grid, to_torch=False) + return grid.tolist() if hasattr(grid, "tolist") else list(grid) + + +def _collect_mm_refs(union_mm: dict[str, Any], trajectory: list[Any], step_indices: list[int]) -> MMRefs | None: + """Build lightweight image references (descriptor + window URIs) for the + deferred-materialization path. + + The descriptor must be transport-safe: the batch sender uses a bare msgspec + encoder that rejects tensors/ndarrays. So we ship a normalized, descriptor-ONLY + copy — pixel_values dropped, grids → nested lists (shape preserved), hashes → ``str`` — rather + than the raw union (which on in-process paths still holds torch tensors). + """ + items = union_mm.get("mm_items") or {} + hashes = union_mm.get("mm_hashes") or {} + non_image = sorted( + {m for m, lst in items.items() if lst and m != "image"} | {m for m, hl in hashes.items() if hl and m != "image"} + ) + if non_image: + raise ValueError(f"defer_mm_materialization supports the image modality only this iteration; got {non_image}") + # Renderer-family guard: this path is Qwen-style (descriptors keyed on + # ``image_grid_thw``). Other families (e.g. Kimi uses ``grid_thws``) would + # silently ship a None grid — fail loudly instead.
+ if any(item.get("image_grid_thw") is None for item in (items.get("image") or [])): + raise ValueError( + "defer_mm_materialization currently supports Qwen-style image descriptors only " + "(items must carry image_grid_thw); got an image item without it — unsupported renderer family." ) - return out + norm_items = { + modality: [{"image_grid_thw": _grid_to_list(item.get("image_grid_thw"))} for item in (lst or [])] + for modality, lst in items.items() + } + norm_hashes = {m: [str(h) if h is not None else h for h in (hl or [])] for m, hl in hashes.items()} + uris = image_uris_from_messages(_window_image_messages(trajectory, step_indices)) + return MMRefs(descriptor={"mm_items": norm_items, "mm_hashes": norm_hashes}, uris=uris) _FILE_URL_PREFIX = "file://" @@ -644,7 +641,7 @@ def offload_images_to_disk(rollouts: list[vf.RolloutOutput], output_dir: Path) - """ # Absolute: paths become ``file://`` URLs; a relative path yields a malformed # URI (``file://rel/...``) that the renderer can't load. - images_dir = (output_dir / "assets" / "images").resolve() + images_dir = (output_dir / mm_store.IMAGE_ASSET_SUBDIR).resolve() images_dir.mkdir(parents=True, exist_ok=True) written: set[str] = set() From be11a90567752fbdde13b2c21faba57c49af0cf4 Mon Sep 17 00:00:00 2001 From: Eli Gottlieb <78387377+eligotts@users.noreply.github.com> Date: Sat, 30 May 2026 22:38:58 +0000 Subject: [PATCH 10/31] feat(mm): vLLM mmfile feature reader serving_tokens parses mmfile:v1 refs (via mm_store.split_mmfile_ref), validates the artifact (safe-id regexes, slot modality/hash match, version-pinned fingerprint compat) and loads the processed MultiModalKwargsItem from /data/outputs/run_/assets/mm_features, with structured errors on a missing/invalid artifact. Shares the format + envelope helpers with the writer via renderers.mm_store. 
Co-Authored-By: Codex Co-Authored-By: Claude Opus 4.8 (1M context) --- src/prime_rl/inference/vllm/serving_tokens.py | 411 ++++++++++++++++-- tests/unit/inference/test_serving_tokens.py | 84 ++++ 2 files changed, 467 insertions(+), 28 deletions(-) diff --git a/src/prime_rl/inference/vllm/serving_tokens.py b/src/prime_rl/inference/vllm/serving_tokens.py index afaabef0e6..629d5debc1 100644 --- a/src/prime_rl/inference/vllm/serving_tokens.py +++ b/src/prime_rl/inference/vllm/serving_tokens.py @@ -30,11 +30,30 @@ from __future__ import annotations +import asyncio +import concurrent.futures +import json +import logging +import os +import re +import time from collections.abc import AsyncGenerator from functools import cached_property +from http import HTTPStatus +from pathlib import Path from typing import Any from fastapi import Request +from renderers.mm_store import ( + _SAFE_FINGERPRINT_RE, + _SAFE_MM_HASH_RE, + _SAFE_RUN_ID_RE, + MMFILE_PREFIX, + mm_feature_envelope_matches, + mm_feature_fingerprint, + mm_feature_path, + split_mmfile_ref, +) from vllm.entrypoints.openai.engine.protocol import ErrorResponse, RequestResponseMetadata from vllm.entrypoints.serve.disagg.protocol import ( GenerateRequest, @@ -48,6 +67,13 @@ from prime_rl.inference.vllm.routed_experts import RoutedExpertsCapture +logger = logging.getLogger(__name__) + +_MM_FEATURE_LOAD_WORKERS_ENV = "PRIME_RL_MM_FEATURE_LOAD_WORKERS" +_MM_FEATURE_LOAD_RETRIES = 3 +_MM_FEATURE_LOAD_BACKOFF_S = 0.02 +_mm_feature_executor: concurrent.futures.ThreadPoolExecutor | None = None + class PrimeRlGenerateResponseChoice(GenerateResponseChoice): routed_experts: dict[str, Any] | None = None @@ -97,6 +123,269 @@ async def _client_set_max_tokens(raw_request: Request | None) -> bool: return isinstance(sp, dict) and "max_tokens" in sp +class _MMFeatureArtifactError(Exception): + def __init__( + self, + *, + error_type: str, + message: str, + missing: list[dict[str, str]] | None = None, + status_code: HTTPStatus = 
HTTPStatus.INTERNAL_SERVER_ERROR, + ) -> None: + super().__init__(message) + self.error_type = error_type + self.missing = missing or [] + self.status_code = status_code + + def response_message(self) -> str: + return json.dumps( + { + "error_type": self.error_type, + "message": str(self), + "missing": self.missing, + }, + separators=(",", ":"), + ) + + +def _mm_feature_load_workers() -> int: + raw = os.getenv(_MM_FEATURE_LOAD_WORKERS_ENV, "8").strip() + try: + return max(1, int(raw)) + except ValueError: + return 8 + + +def _get_mm_feature_executor() -> concurrent.futures.ThreadPoolExecutor: + global _mm_feature_executor + if _mm_feature_executor is None: + _mm_feature_executor = concurrent.futures.ThreadPoolExecutor( + max_workers=_mm_feature_load_workers(), + thread_name_prefix="prime-mmfile", + ) + return _mm_feature_executor + + +def _mm_feature_env_run_id() -> str: + run_id = os.environ.get("RUN_ID", "").strip() + if not run_id or not _SAFE_RUN_ID_RE.fullmatch(run_id): + raise _MMFeatureArtifactError( + error_type="invalid_mm_feature_store", + message="RUN_ID must be set to a safe run id for legacy mmfile refs.", + status_code=HTTPStatus.BAD_REQUEST, + ) + return run_id + + +def _parse_mmfile_ref(ref: str, *, expected_modality: str, expected_hash: str) -> tuple[str, str, str, str]: + try: + run_id, fingerprint, modality, mm_hash = split_mmfile_ref(ref) + except ValueError as exc: + raise _MMFeatureArtifactError( + error_type="invalid_mm_feature_ref", + message=f"Invalid mmfile ref shape for {expected_modality}.", + status_code=HTTPStatus.BAD_REQUEST, + ) from exc + if run_id is None: # legacy 5-part ref: run_id comes from this process's env + run_id = _mm_feature_env_run_id() + if not _SAFE_RUN_ID_RE.fullmatch(run_id): + raise _MMFeatureArtifactError( + error_type="invalid_mm_feature_ref", + message="mmfile run_id contains unsafe characters.", + status_code=HTTPStatus.BAD_REQUEST, + ) + if modality != expected_modality: + raise _MMFeatureArtifactError( + 
error_type="invalid_mm_feature_ref", + message=(f"mmfile modality {modality!r} does not match slot modality {expected_modality!r}."), + status_code=HTTPStatus.BAD_REQUEST, + ) + if mm_hash != expected_hash: + raise _MMFeatureArtifactError( + error_type="invalid_mm_feature_ref", + message="mmfile hash does not match the slot mm_hash.", + status_code=HTTPStatus.BAD_REQUEST, + ) + if not _SAFE_FINGERPRINT_RE.fullmatch(fingerprint): + raise _MMFeatureArtifactError( + error_type="invalid_mm_feature_ref", + message="mmfile fingerprint contains unsafe characters.", + status_code=HTTPStatus.BAD_REQUEST, + ) + if modality != "image": + raise _MMFeatureArtifactError( + error_type="invalid_mm_feature_ref", + message=f"Unsupported mmfile modality: {modality!r}.", + status_code=HTTPStatus.BAD_REQUEST, + ) + if not _SAFE_MM_HASH_RE.fullmatch(mm_hash): + raise _MMFeatureArtifactError( + error_type="invalid_mm_feature_ref", + message="mmfile hash contains unsafe characters.", + status_code=HTTPStatus.BAD_REQUEST, + ) + expected_fingerprint = mm_feature_fingerprint(family="qwen_vl", spatial_merge_size=2) + if fingerprint != expected_fingerprint: + raise _MMFeatureArtifactError( + error_type="incompatible_mm_feature_artifact", + message=( + "mmfile fingerprint is not compatible with this vLLM process " + f"(got {fingerprint}, expected {expected_fingerprint})." + ), + status_code=HTTPStatus.BAD_REQUEST, + ) + return run_id, fingerprint, modality, mm_hash + + +def _mm_feature_path(*, run_id: str, fingerprint: str, modality: str, mm_hash: str) -> Path: + # ``_parse_mmfile_ref`` validates run_id/fingerprint/modality/mm_hash and the + # traversal guard before we reach here, so ``mm_store.mm_feature_path``'s + # ValueError paths are unreachable; surface any as the reader's domain error. 
+ try: + return mm_feature_path(run_id=run_id, fingerprint=fingerprint, modality=modality, mm_hash=mm_hash) + except ValueError as exc: + raise _MMFeatureArtifactError( + error_type="invalid_mm_feature_ref", + message=str(exc), + status_code=HTTPStatus.BAD_REQUEST, + ) from exc + + +def _decoded_image_placeholder_length(item: Any, *, spatial_merge_size: int) -> int: + elem = item.get("image_grid_thw") + data = getattr(elem, "data", elem) + if hasattr(data, "detach"): + data = data.detach().cpu() + if hasattr(data, "tolist"): + data = data.tolist() + grid = data[0] if isinstance(data, list) and data and isinstance(data[0], list) else data + if not isinstance(grid, list) or len(grid) != 3: + raise ValueError("decoded image_grid_thw does not have shape [T,H,W]") + return int(grid[0]) * int(grid[1]) * int(grid[2]) // (spatial_merge_size**2) + + +def _load_mmfile_ref_sync( + ref: str, + *, + expected_modality: str, + expected_hash: str, + expected_placeholder_length: int, +): + import msgpack + from vllm.multimodal.inputs import MultiModalKwargsItem + from vllm.v1.serial_utils import MsgpackDecoder + + run_id, fingerprint, modality, mm_hash = _parse_mmfile_ref( + ref, expected_modality=expected_modality, expected_hash=expected_hash + ) + path = _mm_feature_path(run_id=run_id, fingerprint=fingerprint, modality=modality, mm_hash=mm_hash) + missing = [ + { + "run_id": run_id, + "modality": modality, + "mm_hash": mm_hash, + "fingerprint": fingerprint, + } + ] + + packed: bytes | None = None + for attempt in range(_MM_FEATURE_LOAD_RETRIES): + try: + packed = path.read_bytes() + break + except FileNotFoundError: + if attempt + 1 == _MM_FEATURE_LOAD_RETRIES: + raise _MMFeatureArtifactError( + error_type="missing_mm_feature_artifact", + message=f"Missing mmfile artifact: {path}", + missing=missing, + ) from None + time.sleep(_MM_FEATURE_LOAD_BACKOFF_S * (attempt + 1)) + + try: + artifact = msgpack.unpackb(packed, raw=False) + envelope = artifact.get("envelope") if 
isinstance(artifact, dict) else None + payload = artifact.get("payload") if isinstance(artifact, dict) else None + if not isinstance(envelope, dict) or not isinstance(payload, bytes): + raise ValueError("artifact must contain envelope and binary payload") + if not mm_feature_envelope_matches( + envelope, + run_id=run_id, + fingerprint=fingerprint, + modality=modality, + mm_hash=mm_hash, + payload=payload, + require_run_id=False, + ): + raise ValueError("artifact envelope does not match requested mmfile") + + decoder = MsgpackDecoder(t=MultiModalKwargsItem) + item = decoder.decode(payload) + placeholder_length = _decoded_image_placeholder_length(item, spatial_merge_size=2) + if int(envelope.get("placeholder_length", -1)) != expected_placeholder_length: + raise ValueError("artifact placeholder length does not match envelope") + if placeholder_length != expected_placeholder_length: + raise ValueError("decoded image_grid_thw does not match placeholder length") + return item + except _MMFeatureArtifactError: + raise + except Exception as exc: + raise _MMFeatureArtifactError( + error_type="corrupt_mm_feature_artifact", + message=f"Corrupt mmfile artifact for {modality}:{mm_hash}: {exc}", + missing=missing, + ) from exc + + +async def _load_mmfile_ref( + ref: str, + *, + expected_modality: str, + expected_hash: str, + expected_placeholder_length: int, +): + loop = asyncio.get_running_loop() + return await loop.run_in_executor( + _get_mm_feature_executor(), + lambda: _load_mmfile_ref_sync( + ref, + expected_modality=expected_modality, + expected_hash=expected_hash, + expected_placeholder_length=expected_placeholder_length, + ), + ) + + +def _missing_cache_error_from_exception(exc: Exception, features: Any) -> _MMFeatureArtifactError | None: + text = repr(exc) + if "Expected a cached item" not in text: + return None + + missing_hashes = set(re.findall(r"mm_hash=['\"]([^'\"]+)['\"]", text)) + missing: list[dict[str, str]] = [] + kwargs_data = getattr(features, "kwargs_data", 
None) + hashes_by_modality = getattr(features, "mm_hashes", {}) or {} + if isinstance(kwargs_data, dict): + for modality, items in kwargs_data.items(): + hashes = hashes_by_modality.get(modality) or [] + for idx, item in enumerate(items): + if item is not None or idx >= len(hashes): + continue + mm_hash = hashes[idx] + if missing_hashes and mm_hash not in missing_hashes: + continue + missing.append({"modality": modality, "mm_hash": mm_hash}) + + if not missing and missing_hashes: + missing = [{"modality": "unknown", "mm_hash": h} for h in missing_hashes] + + return _MMFeatureArtifactError( + error_type="missing_mm_cache_item", + message=f"vLLM multimodal cache miss for cache-only slot: {exc}", + missing=missing, + ) + + class PrimeRlServingTokens(ServingTokens): """ServingTokens + DP-rank routing + compact routed experts + max_tokens defaulting.""" @@ -150,31 +439,86 @@ async def serve_tokens( # Build the engine input — features-aware (MM) or text-only fallback. # Identical to upstream so we keep tracking it. 
if features := request.features: - from vllm.entrypoints.serve.disagg.mm_serde import decode_mm_kwargs_item - from vllm.inputs import mm_input - from vllm.multimodal.inputs import ( - MultiModalKwargsItem, - PlaceholderRange, - ) - - mm_placeholders = { - modality: [PlaceholderRange(offset=p.offset, length=p.length) for p in ranges] - for modality, ranges in features.mm_placeholders.items() - } - mm_kwargs: dict[str, list[MultiModalKwargsItem | None]] = {} - if features.kwargs_data is not None: - for modality, items in features.kwargs_data.items(): - mm_kwargs[modality] = [decode_mm_kwargs_item(item) if item is not None else None for item in items] - else: - for modality, hashes in features.mm_hashes.items(): - mm_kwargs[modality] = [None] * len(hashes) - engine_input = mm_input( - prompt_token_ids=request.token_ids, - mm_kwargs=mm_kwargs, # type: ignore[arg-type] - mm_hashes=features.mm_hashes, - mm_placeholders=mm_placeholders, - cache_salt=request.cache_salt, - ) + try: + from vllm.entrypoints.serve.disagg.mm_serde import decode_mm_kwargs_item + from vllm.inputs import mm_input + from vllm.multimodal.inputs import ( + MultiModalKwargsItem, + PlaceholderRange, + ) + + mm_placeholders = { + modality: [PlaceholderRange(offset=p.offset, length=p.length) for p in ranges] + for modality, ranges in features.mm_placeholders.items() + } + mm_kwargs: dict[str, list[MultiModalKwargsItem | None]] = {} + slot_counts = {"none": 0, "inline": 0, "mmfile": 0} + load_start = time.monotonic() + + async def decode_slot(modality: str, idx: int, item: str | None) -> MultiModalKwargsItem | None: + if item is None: + slot_counts["none"] += 1 + return None + if item.startswith(f"{MMFILE_PREFIX}:"): + slot_counts["mmfile"] += 1 + hashes = features.mm_hashes.get(modality) or [] + placeholders = features.mm_placeholders.get(modality) or [] + if idx >= len(hashes) or idx >= len(placeholders): + raise _MMFeatureArtifactError( + error_type="invalid_mm_feature_ref", + message=("mmfile slot has 
no matching hash or placeholder entry."), + status_code=HTTPStatus.BAD_REQUEST, + ) + return await _load_mmfile_ref( + item, + expected_modality=modality, + expected_hash=hashes[idx], + expected_placeholder_length=placeholders[idx].length, + ) + slot_counts["inline"] += 1 + return decode_mm_kwargs_item(item) + + if features.kwargs_data is not None: + for modality, items in features.kwargs_data.items(): + hashes = features.mm_hashes.get(modality) or [] + if len(items) != len(hashes): + raise _MMFeatureArtifactError( + error_type="invalid_mm_feature_ref", + message=( + f"kwargs_data[{modality!r}] has {len(items)} items but mm_hashes has {len(hashes)}." + ), + status_code=HTTPStatus.BAD_REQUEST, + ) + mm_kwargs[modality] = list( + await asyncio.gather(*(decode_slot(modality, idx, item) for idx, item in enumerate(items))) + ) + else: + for modality, hashes in features.mm_hashes.items(): + slot_counts["none"] += len(hashes) + mm_kwargs[modality] = [None] * len(hashes) + + if any(slot_counts.values()): + logger.debug( + "decoded multimodal feature slots none=%d inline=%d mmfile=%d disk_load_ms=%.2f", + slot_counts["none"], + slot_counts["inline"], + slot_counts["mmfile"], + (time.monotonic() - load_start) * 1000.0, + ) + + engine_input = mm_input( + prompt_token_ids=request.token_ids, + mm_kwargs=mm_kwargs, # type: ignore[arg-type] + mm_hashes=features.mm_hashes, + mm_placeholders=mm_placeholders, + cache_salt=request.cache_salt, + ) + except _MMFeatureArtifactError as exc: + return self.create_error_response( + exc.response_message(), + err_type=exc.error_type, + status_code=exc.status_code, + ) else: (engine_input,) = await self.openai_serving_render.preprocess_completion( request, @@ -269,9 +613,20 @@ async def serve_tokens_full_generator( # type: ignore[override] capture = _GenerateRoutedExpertsCapture(result_generator) result_generator = capture - response = await super().serve_tokens_full_generator( - request, result_generator, request_id, model_name, 
request_metadata - ) + try: + response = await super().serve_tokens_full_generator( + request, result_generator, request_id, model_name, request_metadata + ) + except Exception as exc: + if request.features is not None: + mm_error = _missing_cache_error_from_exception(exc, request.features) + if mm_error is not None: + return self.create_error_response( + mm_error.response_message(), + err_type=mm_error.error_type, + status_code=mm_error.status_code, + ) + raise if capture is not None and isinstance(response, GenerateResponse): response = capture.post_process(response) diff --git a/tests/unit/inference/test_serving_tokens.py b/tests/unit/inference/test_serving_tokens.py index 1882e57e55..3446a62f7e 100644 --- a/tests/unit/inference/test_serving_tokens.py +++ b/tests/unit/inference/test_serving_tokens.py @@ -14,6 +14,8 @@ import numpy as np import pybase64 +import pytest +from renderers.mm_store import mm_feature_fingerprint as _mm_feature_fingerprint from vllm.entrypoints.serve.disagg.protocol import GenerateResponse, GenerateResponseChoice from prime_rl.inference.vllm.routed_experts import serialize_routed_experts @@ -21,6 +23,9 @@ PrimeRlServingTokens, _client_set_max_tokens, _GenerateRoutedExpertsCapture, + _load_mmfile_ref_sync, + _missing_cache_error_from_exception, + _MMFeatureArtifactError, ) @@ -115,3 +120,82 @@ def test_client_set_max_tokens_assumes_set_when_body_unreadable(): # non-dict body → can't tell, don't override. 
assert asyncio.run(_client_set_max_tokens(_FakeRawRequest([1, 2, 3]))) is True + + +def test_missing_cache_error_is_typed_for_cache_only_slots(): + class _Features: + kwargs_data = {"image": [None, "mmfile:v1:run-a:fp:image:def"]} + mm_hashes = {"image": ["abc", "def"]} + + err = AssertionError("Expected a cached item for mm_hash='abc'") + + typed = _missing_cache_error_from_exception(err, _Features()) + + assert typed is not None + assert typed.error_type == "missing_mm_cache_item" + assert typed.missing == [{"modality": "image", "mm_hash": "abc"}] + + +def test_missing_mmfile_artifact_is_typed(tmp_path, monkeypatch): + monkeypatch.setenv("PRIME_RL_MM_FEATURE_ROOT", str(tmp_path)) + monkeypatch.delenv("RUN_ID", raising=False) + run_id = "testrun" + mm_hash = "a" * 32 + fingerprint = _mm_feature_fingerprint(family="qwen_vl", spatial_merge_size=2) + ref = f"mmfile:v1:{run_id}:{fingerprint}:image:{mm_hash}" + + with pytest.raises(_MMFeatureArtifactError) as exc_info: + _load_mmfile_ref_sync( + ref, + expected_modality="image", + expected_hash=mm_hash, + expected_placeholder_length=1, + ) + + assert exc_info.value.error_type == "missing_mm_feature_artifact" + assert exc_info.value.missing == [ + { + "run_id": run_id, + "modality": "image", + "mm_hash": mm_hash, + "fingerprint": fingerprint, + } + ] + + +def test_mmfile_artifact_round_trips_vllm_serde(tmp_path, monkeypatch): + torch = pytest.importorskip("torch") + pytest.importorskip("vllm") + + from renderers.base import MultiModalData, PlaceholderRange + from renderers.client import _build_qwen_vl_features + + monkeypatch.setenv("RENDERERS_MM_FEATURE_STORE_MODE", "on") + monkeypatch.setenv("PRIME_RL_MM_FEATURE_ROOT", str(tmp_path)) + monkeypatch.setenv("RUN_ID", "roundtrip") + mm_hash = "a" * 32 + mm_data = MultiModalData( + mm_hashes={"image": [mm_hash]}, + mm_placeholders={"image": [PlaceholderRange(offset=5, length=1)]}, + mm_items={ + "image": [ + { + "pixel_values": torch.zeros(4, 8, dtype=torch.float32), + 
"image_grid_thw": torch.tensor([[1, 2, 2]], dtype=torch.int64), + } + ] + }, + ) + + features = _build_qwen_vl_features(mm_data, spatial_merge_size=2) + ref = features["kwargs_data"]["image"][0] + assert ref.startswith("mmfile:v1:roundtrip:") + monkeypatch.setenv("RUN_ID", "different-reader-run") + item = _load_mmfile_ref_sync( + ref, + expected_modality="image", + expected_hash=mm_hash, + expected_placeholder_length=1, + ) + + assert "image_grid_thw" in item From c0ef88b6e9e62e3da9ec72d57666a9862307e891 Mon Sep 17 00:00:00 2001 From: Eli Gottlieb <78387377+eligotts@users.noreply.github.com> Date: Sat, 30 May 2026 22:38:58 +0000 Subject: [PATCH 11/31] feat(mm): default deferred materialization + feature offload on (VLM-scoped) - defer_mm_materialization defaults to True on both orchestrator + trainer; the renderer requirement is scoped to VLM runs (model.vlm is not None) so text-only runs are a no-op. Trainer renderer defaults to AutoRendererConfig() (mirrors the orchestrator); SFT force-sets defer off. - mm_artifact_ttl_seconds (default 3600s) drives the orchestrator's per-step sweep. Co-Authored-By: Codex Co-Authored-By: Claude Opus 4.8 (1M context) --- .../src/prime_rl/configs/orchestrator.py | 22 ++++++++++++++++ .../src/prime_rl/configs/trainer.py | 26 +++++++++++++++++++ tests/unit/test_configs.py | 23 ++++++++++++++-- 3 files changed, 69 insertions(+), 2 deletions(-) diff --git a/packages/prime-rl-configs/src/prime_rl/configs/orchestrator.py b/packages/prime-rl-configs/src/prime_rl/configs/orchestrator.py index cfa844b835..9be54a2472 100644 --- a/packages/prime-rl-configs/src/prime_rl/configs/orchestrator.py +++ b/packages/prime-rl-configs/src/prime_rl/configs/orchestrator.py @@ -633,6 +633,9 @@ def _preserve_mito_renderer(self, handler: SerializerFunctionWrapHandler) -> dic output_dir: Path = Path("outputs/run_default") """Directory to write outputs to — checkpoints, weights, rollouts, and logs are written as subdirectories. 
Should be a persistent directory with enough disk space and unique per experiment running on a single node.""" + mm_artifact_ttl_seconds: float = 3600.0 + """TTL (seconds) for offloaded multimodal artifacts under ``output_dir/assets/{images,mm_features}``. Once per step the orchestrator deletes artifact files older than this. Artifacts are content-addressed and re-materializable, so over-eviction is safe (triggers a re-write) while under-eviction only wastes disk — bias large. Defaults to 1 hour.""" + tasks_per_minute: int | None = Field(None, ge=1) """Rate limit per environment worker, in tasks per minute. Recommended for sandbox-backed environments to prevent sandbox-not-ready errors during autoscaling. With multiple workers, the effective total rate is ``workers × this value``. None disables rate limiting.""" @@ -667,6 +670,9 @@ def _preserve_mito_renderer(self, handler: SerializerFunctionWrapHandler) -> dic mm_materialize_concurrency: int = Field(4, ge=1) """Max rollouts whose multimodal pixels are reconstructed concurrently when converting rollouts to training samples. Bounds the transient build-time memory spike for VLM batches (each in-flight rollout holds live pixel tensors + packing copies). No effect on text-only runs.""" + defer_mm_materialization: bool = True + """Defer multimodal pixel materialization to the trainer. When True, the orchestrator ships lightweight image references (``mm_refs``) instead of materializing pixels and shipping heavy ``mm_kwargs``. Must match the trainer's setting. A no-op for text-only runs; forced off for SFT.""" + bench: bool = False """Benchmark mode. Sets ``max_steps`` to 5 and disables W&B.""" @@ -808,6 +814,9 @@ def _force_no_renderer_for_sft(self): validators below so they see the corrected value.""" if self.training_mode == "sft": self.renderer = None + # SFT has no renderer, so it can't defer materialization; keep the + # default-on flag from tripping the renderer-required validator. 
+ self.defer_mm_materialization = False return self @model_validator(mode="after") @@ -884,6 +893,19 @@ def validate_renderer_auto_resolves(self): f"client entirely (MITO)." ) + @model_validator(mode="after") + def validate_defer_mm_materialization(self): + """Deferred materialization needs a renderer so the descriptor it ships + in ``mm_refs`` is reproducible by the trainer's identical renderer.""" + # Only VLM runs emit mm_refs; text-only runs never do, so default-on is + # a harmless no-op for them even if the renderer is opted out. + if self.defer_mm_materialization and self.renderer is None and self.student.model.vlm is not None: + raise ValueError( + "orchestrator.defer_mm_materialization requires a renderer so the trainer can " + "materialize pixels identically from the shipped image references." + ) + return self + @model_validator(mode="after") def resolve_batching(self): has_rollout_batch = self.batch_size is not None diff --git a/packages/prime-rl-configs/src/prime_rl/configs/trainer.py b/packages/prime-rl-configs/src/prime_rl/configs/trainer.py index 00f4e07deb..ce434ebaf6 100644 --- a/packages/prime-rl-configs/src/prime_rl/configs/trainer.py +++ b/packages/prime-rl-configs/src/prime_rl/configs/trainer.py @@ -3,6 +3,7 @@ from typing import Annotated, Any, Literal, TypeAlias from pydantic import Field, model_validator +from renderers import AutoRendererConfig, RendererConfig from prime_rl.configs.shared import ( BaseModelConfig, @@ -562,6 +563,12 @@ class TrainerConfig(BaseConfig): max_concurrent_runs: int = Field(1, ge=1) """Maximum number of concurrent runs to allow. If 1, only one run may run at a time.""" + defer_mm_materialization: bool = True + """Defer multimodal pixel materialization from the orchestrator to the trainer. When True, the orchestrator ships lightweight image references (``mm_refs``) and the trainer materializes pixels in its data loader. Must match the orchestrator's setting; requires ``renderer`` to be set for VLM runs. 
A no-op for text-only runs (no ``mm_refs`` ever arrive).""" + + renderer: RendererConfig | None = AutoRendererConfig() + """Typed renderer config (``renderers.RendererConfig`` discriminated union), mirroring the orchestrator's. Auto-resolves from the model by default so VLM defer runs work without restating it; only used by VLM runs (text-only ignores it).""" + experimental: TrainerExperimentalConfig = TrainerExperimentalConfig() @model_validator(mode="after") @@ -673,3 +680,22 @@ def router_replay_only_with_custom_impl(self): raise ValueError("Router replay is only supported with the custom implementation or auto mode") return self + + @model_validator(mode="after") + def validate_defer_mm_materialization(self): + if not self.defer_mm_materialization: + return self + # Multi-run IS supported: synchronous trainer-side materialization is + # run-agnostic (all concurrent runs are LoRA adapters on the same base + # model → same image processor; mm_refs are self-contained per sample), + # and it does NOT touch the per-run ready_to_update/progress machinery in + # the packer. (A future prefetch/late-commit path WOULD need the multi-run + # ready_to_update state split — guard that there, not on the flag.) + # Only VLM runs materialize pixels; text-only runs never receive + # ``mm_refs``, so default-on is a harmless no-op for them. + if self.renderer is None and self.model.vlm is not None: + raise ValueError( + "defer_mm_materialization requires a renderer config so the trainer can " + "materialize pixels identically to the orchestrator. Set [renderer]." + ) + return self diff --git a/tests/unit/test_configs.py b/tests/unit/test_configs.py index 0044c0da5f..a4117b927e 100644 --- a/tests/unit/test_configs.py +++ b/tests/unit/test_configs.py @@ -25,13 +25,32 @@ ] +def _git_tracked_files() -> "set[Path] | None": + """Resolved paths of git-tracked tomls under configs/ and examples/, or None + if git is unavailable. 
Used to skip untracked scratch/experiment configs a dev + drops in locally — those shouldn't break the shipped-config validation.""" + import subprocess + + try: + out = subprocess.run( + ["git", "ls-files", "configs", "examples"], capture_output=True, text=True, check=True + ).stdout + except Exception: + return None + return {Path(line).resolve() for line in out.splitlines() if line.strip().endswith(".toml")} + + def get_config_files() -> list[Path]: - """Any TOML file inside `configs/` or `examples/` (skips the configs/private/ submodule).""" + """Any tracked TOML file inside `configs/` or `examples/` (skips the configs/private/ submodule).""" private = Path("configs/private") config_files = [p for p in Path("configs").rglob("*.toml") if private not in p.parents] example_files = list(Path("examples").rglob("*.toml")) + candidates = config_files + example_files - return config_files + example_files + tracked = _git_tracked_files() + if tracked is not None: + candidates = [p for p in candidates if p.resolve() in tracked] + return candidates def is_eval_config(path: Path) -> bool: From 545c8c9b51335e0b7ad2b7d114eb8e6790a12040 Mon Sep 17 00:00:00 2001 From: Eli Gottlieb <78387377+eligotts@users.noreply.github.com> Date: Sun, 31 May 2026 01:24:34 +0000 Subject: [PATCH 12/31] fix(mm): build trainer renderer whenever configured, not gated on model.vlm MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The orchestrator ships mm_refs based on whether rollouts have images, not on the trainer's optional [model.vlm] block — which hosted VLM configs leave unset (model.vlm=None). The previous gate (renderer only when model.vlm is not None) discarded the auto-resolved renderer, so the trainer built none and crashed on the first mm_refs ("trainer has no renderer"). 
Pass config.renderer unconditionally; data.py still builds only when defer is on and a renderer is configured (default AutoRendererConfig; explicit renderer=None still opts out). Co-Authored-By: Codex Co-Authored-By: Claude Opus 4.8 (1M context) --- src/prime_rl/trainer/rl/train.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/src/prime_rl/trainer/rl/train.py b/src/prime_rl/trainer/rl/train.py index 2e6d1dcfbb..cd0b05c574 100644 --- a/src/prime_rl/trainer/rl/train.py +++ b/src/prime_rl/trainer/rl/train.py @@ -248,9 +248,13 @@ def load_run_checkpoint(_optimizer, idx: int) -> None: tokenizer, config.rollout_transport, defer_mm_materialization=config.defer_mm_materialization, - # Only VLM runs materialize pixels; text-only runs leave this None so - # default-on defer never builds an unused renderer for them. - renderer_config=config.renderer if config.model.vlm is not None else None, + # Pass the configured renderer (defaults to AutoRendererConfig). The + # orchestrator ships mm_refs based on whether rollouts have images, NOT + # on the trainer's model.vlm block (prod VLM configs may leave it None), + # so do NOT gate on model.vlm. data.py builds the renderer only when + # defer is on and renderer_config is not None, so an explicit + # renderer=None still opts out (and a text-only run never gets mm_refs). + renderer_config=config.renderer, ) token_exporter = setup_token_exporter(config, parallel_dims, world, logger) From 64b33b0a8d064ac2e2de836bb796850a4632ecbd Mon Sep 17 00:00:00 2001 From: Eli Gottlieb <78387377+eligotts@users.noreply.github.com> Date: Sun, 31 May 2026 06:02:25 +0000 Subject: [PATCH 13/31] fix(docker): pin runtime base to python:3.12-slim-bookworm (glibc 2.36) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The bare `python:3.12-slim` tag tracks Debian stable, which moved to trixie (glibc ~2.41). 
A rebuild drifted the runtime glibc, and at training time the FLA gated-delta-rule backward JIT-compiles a TileLang CUDA kernel via the hostPath-mounted node nvcc — which then fails with `bits/mathcalls.h: exception specification is incompatible ... cospi/sinpi` (new glibc's noexcept decls vs the CUDA host math headers). Pinning bookworm (glibc 2.36) matches the ubuntu22.04 builder's glibc 2.35 and the known-good CUDA 12.x combo. Runtime-only change; deps unchanged. Co-Authored-By: Codex Co-Authored-By: Claude Opus 4.8 (1M context) --- Dockerfile.cuda | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/Dockerfile.cuda b/Dockerfile.cuda index 35e8d7f283..9a824d2ab2 100644 --- a/Dockerfile.cuda +++ b/Dockerfile.cuda @@ -86,7 +86,11 @@ ARG TARGETARCH COPY scripts/docker-arm64-post-install.sh /app/scripts/docker-arm64-post-install.sh RUN if [ "$TARGETARCH" = "arm64" ]; then /app/scripts/docker-arm64-post-install.sh; fi -FROM python:3.12-slim +# Pin Debian 12 (bookworm, glibc 2.36) — the bare `python:3.12-slim` tag tracks +# Debian stable, which moved to trixie (glibc ~2.41) and broke the runtime FLA +# TileLang nvcc JIT (`bits/mathcalls.h: cospi/sinpi noexcept` conflict vs the +# mounted CUDA host headers). bookworm's glibc matches the ubuntu22.04 builder. +FROM python:3.12-slim-bookworm RUN apt-get update && apt-get install -y \ --no-install-recommends \ From dbdbeea0530bd6adff85225705a7a86657432f0b Mon Sep 17 00:00:00 2001 From: Eli Gottlieb <78387377+eligotts@users.noreply.github.com> Date: Mon, 1 Jun 2026 00:50:10 +0000 Subject: [PATCH 14/31] feat(mm): features-only artifact eviction (30m TTL) + last-use mtime Stop sweeping raw images; only evict the expensive mm_features cache. Images are the run's recoverable source of truth (terminal, non-regenerable) and the trainer rebuilds pixels from them, so they are kept; features are regenerable and disposable. 
Drop mm_artifact_ttl_seconds 3600 -> 1800: the horizon only needs to exceed the write->vLLM-admit window (seconds), so 30 min is a large safety margin against racing in-flight reads while keeping disk bounded. offload_images_to_disk now refreshes mtime on the skip-if-exists path so all content-addressed writers (images + features) share last-use semantics, making a future image sweep safe by default. Bumps deps/renderers (features-only sweep + feature-writer mtime) and deps/verifiers (image-writer mtime). Co-Authored-By: Codex Co-Authored-By: Claude Opus 4.8 (1M context) --- deps/renderers | 2 +- deps/verifiers | 2 +- .../src/prime_rl/configs/orchestrator.py | 4 ++-- src/prime_rl/orchestrator/orchestrator.py | 12 +++++++----- src/prime_rl/orchestrator/trajectories.py | 9 +++++++++ 5 files changed, 20 insertions(+), 9 deletions(-) diff --git a/deps/renderers b/deps/renderers index d6ed224839..10b71d6271 160000 --- a/deps/renderers +++ b/deps/renderers @@ -1 +1 @@ -Subproject commit d6ed2248394cc41b5df9e74ad4c0f2b384601596 +Subproject commit 10b71d627184d4db5448fb12e2941e42b32b07b4 diff --git a/deps/verifiers b/deps/verifiers index 7ec7169cd8..8085c6ceb6 160000 --- a/deps/verifiers +++ b/deps/verifiers @@ -1 +1 @@ -Subproject commit 7ec7169cd887cfe7fac9c5cae97b9547448d40d5 +Subproject commit 8085c6ceb65e116d5eec9b4db3f7835f34f265af diff --git a/packages/prime-rl-configs/src/prime_rl/configs/orchestrator.py b/packages/prime-rl-configs/src/prime_rl/configs/orchestrator.py index 9be54a2472..10a7553802 100644 --- a/packages/prime-rl-configs/src/prime_rl/configs/orchestrator.py +++ b/packages/prime-rl-configs/src/prime_rl/configs/orchestrator.py @@ -633,8 +633,8 @@ def _preserve_mito_renderer(self, handler: SerializerFunctionWrapHandler) -> dic output_dir: Path = Path("outputs/run_default") """Directory to write outputs to — checkpoints, weights, rollouts, and logs are written as subdirectories. 
Should be a persistent directory with enough disk space and unique per experiment running on a single node.""" - mm_artifact_ttl_seconds: float = 3600.0 - """TTL (seconds) for offloaded multimodal artifacts under ``output_dir/assets/{images,mm_features}``. Once per step the orchestrator deletes artifact files older than this. Artifacts are content-addressed and re-materializable, so over-eviction is safe (triggers a re-write) while under-eviction only wastes disk — bias large. Defaults to 1 hour.""" + mm_artifact_ttl_seconds: float = 1800.0 + """TTL (seconds) for offloaded multimodal ``mm_features`` artifacts under ``output_dir/assets/mm_features``. Once per step the orchestrator deletes feature files older than this. Features ONLY: source images under ``assets/images`` are never swept (they are terminal browser output with no regeneration path and are kept for the whole run as the recoverable source). Features are a regenerable cache (trainer rebuilds pixels from the image; env-worker rewrites missing features on demand), so over-eviction just forces a reprocess. The TTL only needs to exceed the write→vLLM-admit window (seconds), so minutes leave a large safety margin against racing in-flight reads. Defaults to 30 minutes.""" tasks_per_minute: int | None = Field(None, ge=1) """Rate limit per environment worker, in tasks per minute. Recommended for sandbox-backed environments to prevent sandbox-not-ready errors during autoscaling. With multiple workers, the effective total rate is ``workers × this value``. 
None disables rate limiting.""" diff --git a/src/prime_rl/orchestrator/orchestrator.py b/src/prime_rl/orchestrator/orchestrator.py index 3a810dbf94..9d1524bf49 100644 --- a/src/prime_rl/orchestrator/orchestrator.py +++ b/src/prime_rl/orchestrator/orchestrator.py @@ -471,13 +471,15 @@ async def orchestrate(config: OrchestratorConfig): f"Offloaded {num_offloaded} unique images to disk in {time.perf_counter() - offload_start:.2f}s" ) - # Evict stale offloaded multimodal artifacts under this run's dir (images + - # mm_features). No-op for text-only runs (asset dirs absent). Content- - # addressed + re-writable, so over-eviction is safe; each run sweeps its - # own dir → multi-run safe. + # Evict stale offloaded mm_features (the expensive processed payloads) under + # this run's dir. Features ONLY — source images are kept for the whole run + # (terminal output, no regeneration path); features are a regenerable cache + # the trainer never reads, so over-eviction just forces a reprocess. No-op + # for text-only runs (feature dir absent); each run sweeps its own dir → + # multi-run safe. 
num_swept = mm_store.sweep_stale_artifacts(asset_root, config.mm_artifact_ttl_seconds) if num_swept: - logger.info(f"Swept {num_swept} stale multimodal artifacts (ttl={config.mm_artifact_ttl_seconds}s)") + logger.info(f"Swept {num_swept} stale mm_feature artifacts (ttl={config.mm_artifact_ttl_seconds}s)") # Convert rollouts to training samples parallel_preprocess_start = time.perf_counter() diff --git a/src/prime_rl/orchestrator/trajectories.py b/src/prime_rl/orchestrator/trajectories.py index 48e10aed8b..ea39a352bf 100644 --- a/src/prime_rl/orchestrator/trajectories.py +++ b/src/prime_rl/orchestrator/trajectories.py @@ -691,6 +691,15 @@ def offload_images_to_disk(rollouts: list[vf.RolloutOutput], output_dir: Path) - if content_hash not in written: if not path.exists(): path.write_bytes(raw) + else: + # Recurring image already on disk: refresh its mtime so a + # future last-use sweep treats it as hot. Images aren't + # evicted today, but this keeps the practice consistent + # with the mm_feature writer. Best-effort on a sweep race. + try: + path.touch() + except OSError: + pass written.add(content_hash) image_url["url"] = f"{_FILE_URL_PREFIX}{path}" From d859ac6042a1f3b9cfcf29df960d00c200956762 Mon Sep 17 00:00:00 2001 From: Eli Gottlieb <78387377+eligotts@users.noreply.github.com> Date: Mon, 1 Jun 2026 00:50:18 +0000 Subject: [PATCH 15/31] fix(qwen-vlm): raise a clear error on image token/feature mismatch Before the image-embed masked_scatter, assert that the image-token count equals the image-feature count and raise a descriptive ValueError (token id, both counts, input_ids/pixel_values/grid shapes) when they differ. A mismatch otherwise surfaces as an opaque CUDA masked_scatter device-side assert that is near-impossible to diagnose. This is a loud tripwire, not a masking guard: it fails fast with the context needed to find the bad mm sidecar. 
Co-Authored-By: Codex Co-Authored-By: Claude Opus 4.8 (1M context) --- .../models/qwen3_5_moe/modeling_qwen3_5_moe.py | 11 +++++++++++ tests/unit/train/models/test_qwen3_5_moe_vlm.py | 16 ++++++++++++++++ 2 files changed, 27 insertions(+) diff --git a/src/prime_rl/trainer/models/qwen3_5_moe/modeling_qwen3_5_moe.py b/src/prime_rl/trainer/models/qwen3_5_moe/modeling_qwen3_5_moe.py index a03713a3af..49c0944f42 100644 --- a/src/prime_rl/trainer/models/qwen3_5_moe/modeling_qwen3_5_moe.py +++ b/src/prime_rl/trainer/models/qwen3_5_moe/modeling_qwen3_5_moe.py @@ -841,6 +841,17 @@ def forward( image_embeds = vision_output.pooler_output.to(inputs_embeds.device, inputs_embeds.dtype) image_mask = input_ids == self.config.image_token_id + image_token_count = int(image_mask.sum().item()) + image_feature_count = int(image_embeds.shape[0]) + if image_token_count != image_feature_count: + raise ValueError( + "Qwen VLM image token/feature mismatch before scatter: " + f"image_token_id={self.config.image_token_id}, " + f"image_tokens={image_token_count}, image_features={image_feature_count}, " + f"input_ids_shape={tuple(input_ids.shape)}, " + f"pixel_values_shape={tuple(pixel_values.shape)}, " + f"image_grid_thw_shape={tuple(image_grid_thw.shape) if image_grid_thw is not None else None}" + ) image_mask = image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) diff --git a/tests/unit/train/models/test_qwen3_5_moe_vlm.py b/tests/unit/train/models/test_qwen3_5_moe_vlm.py index c629f6be25..4210e3fdf0 100644 --- a/tests/unit/train/models/test_qwen3_5_moe_vlm.py +++ b/tests/unit/train/models/test_qwen3_5_moe_vlm.py @@ -87,6 +87,22 @@ def test_vlm_forward(): assert out_mm["logits"].shape == (1, input_ids_mm.shape[1], vocab) +def test_vlm_forward_rejects_image_token_feature_mismatch(): + """Bad MM sidecars should fail before CUDA masked_scatter asserts.""" + config = _tiny_vlm_config() + with 
torch.device("cuda"), default_dtype(torch.float32): + model = Qwen3_5MoeForCausalLM(config) + inject_prime_lm_head(model) + + pixel_values, image_grid_thw, n_img_tokens = _make_image_inputs(config) + text_part = torch.randint(0, 200, (1, 10), device="cuda") + img_part = torch.full((1, n_img_tokens + 1), config.image_token_id, device="cuda") + input_ids = torch.cat([text_part[:, :5], img_part, text_part[:, 5:]], dim=1) + + with pytest.raises(ValueError, match="image token/feature mismatch"): + model(input_ids=input_ids, pixel_values=pixel_values, image_grid_thw=image_grid_thw) + + def test_vlm_backward(): """Gradients flow through both vision scatter and text model.""" config = _tiny_vlm_config() From 1b24bd72c5f6d3ef4b0d7f8638cf69ef45d75bf7 Mon Sep 17 00:00:00 2001 From: Eli Gottlieb <78387377+eligotts@users.noreply.github.com> Date: Mon, 1 Jun 2026 00:50:28 +0000 Subject: [PATCH 16/31] test(orchestrator): assert step-back delta sample is self-contained Cover the cross-repo contract with the verifiers trajectory delta encoder: when a stale-prefix step-back starts a new TrainingSample, prime-rl will not merge it into the prior sample, so its mm_data delta must carry the full cumulative window. Guards against a partial delta that would drop images on step-back. 
Co-Authored-By: Codex Co-Authored-By: Claude Opus 4.8 (1M context) --- tests/unit/orchestrator/test_trajectories.py | 108 +++++++++++++++++++ 1 file changed, 108 insertions(+) diff --git a/tests/unit/orchestrator/test_trajectories.py b/tests/unit/orchestrator/test_trajectories.py index 0ed7b16c41..d2826ae220 100644 --- a/tests/unit/orchestrator/test_trajectories.py +++ b/tests/unit/orchestrator/test_trajectories.py @@ -1122,6 +1122,114 @@ def test_interleave_rollout_packs_pixels_from_renderer_mm_data(): assert sample.mm_token_type_ids == [0, 1, 0, 0, 2, 0, 0] +def test_interleave_rollout_step_back_delta_sample_is_self_contained(): + """A stale-prefix step-back starts a new TrainingSample, so its mm_data + delta must be full cumulative for that new window. + + This models the cross-repo contract with verifiers' trajectory delta + encoder: after step 1 advances the active sample from step 0, step 2 steps + back to step 0's historical prefix. PrimeRL will not merge it into sample + 0, so step 2 must carry both A and C. 
+ """ + import torch as _torch + from renderers.base import MultiModalData, PlaceholderRange + from verifiers.utils.save_utils import _delta_intermediate_mm_data + + image_token = 2 + + def mm(*hashes: str) -> MultiModalData: + return MultiModalData( + mm_hashes={"image": list(hashes)}, + mm_placeholders={"image": [PlaceholderRange(offset=i * 10, length=1) for i, _ in enumerate(hashes)]}, + mm_items={ + "image": [ + { + "pixel_values": _torch.tensor([[float(i)]], dtype=_torch.float32), + "image_grid_thw": _torch.tensor([[1, 2, 2]], dtype=_torch.int64), + } + for i, _ in enumerate(hashes) + ] + }, + ) + + raw_steps = [ + { + "tokens": { + "prompt_ids": [10, image_token, 11], + "prompt_mask": [False, False, False], + "completion_ids": [12], + "completion_mask": [True], + "completion_logprobs": [-0.1], + "multi_modal_data": mm("A"), + } + }, + { + "tokens": { + "prompt_ids": [10, image_token, 11, 12, 13, image_token], + "prompt_mask": [False] * 6, + "completion_ids": [14], + "completion_mask": [True], + "completion_logprobs": [-0.2], + "multi_modal_data": mm("A", "B"), + } + }, + { + "tokens": { + # Extends step 0's historical prefix, not step 1's active prefix. 
+ "prompt_ids": [10, image_token, 11, 12, 15, image_token], + "prompt_mask": [False] * 6, + "completion_ids": [16], + "completion_mask": [True], + "completion_logprobs": [-0.3], + "multi_modal_data": mm("A", "C"), + } + }, + ] + delta_steps = _delta_intermediate_mm_data(raw_steps) + + trajectory = [] + for idx, raw_step in enumerate(delta_steps): + t = raw_step["tokens"] + trajectory.append( + vf.TrajectoryStep( + prompt=[{"role": "user", "content": f"turn {idx}"}], + completion=[{"role": "assistant", "content": f"response {idx}"}], + response=MagicMock(), + tokens=vf.TrajectoryStepTokens( + prompt_ids=t["prompt_ids"], + prompt_mask=t["prompt_mask"], + completion_ids=t["completion_ids"], + completion_mask=t["completion_mask"], + completion_logprobs=t["completion_logprobs"], + overlong_prompt=False, + is_truncated=False, + multi_modal_data=t["multi_modal_data"], + ), + reward=None, + advantage=None, + is_truncated=False, + trajectory_id="1", + extras={}, + ) + ) + + rollouts = interleave_rollout( + vf.RolloutOutput( + example_id=1, + trajectory=trajectory, + sampling_args={"temperature": 1.0}, + error=None, + ), + mm_token_type_ids_mapping={image_token: 1}, + ) + + assert rollouts is not None and len(rollouts) == 2 + assert sum(x == image_token for x in rollouts[0].prompt_ids + rollouts[0].completion_ids) == 2 + assert _decode_mm_thw(rollouts[0]) == [[1, 2, 2], [1, 2, 2]] + assert sum(x == image_token for x in rollouts[1].prompt_ids + rollouts[1].completion_ids) == 2 + assert _decode_mm_thw(rollouts[1]) == [[1, 2, 2], [1, 2, 2]] + + def test_interleave_rollout_skips_pixel_materialization_for_filtered_rollout(): """A filtered rollout's samples are dropped before the trainer, so ``interleave_rollout`` must skip the expensive pixel reconstruction — the From a2e89e36a8047fff19671e76c588cd4158ed4e7c Mon Sep 17 00:00:00 2001 From: Eli Gottlieb <78387377+eligotts@users.noreply.github.com> Date: Wed, 3 Jun 2026 23:23:56 +0000 Subject: [PATCH 17/31] feat(trainer): pack 
multimodal samples into microbatches Allow multimodal samples to share a packed microbatch (gated by trainer.pack_multimodal) instead of each MM sample getting its own. Packing is enabled only where it is provably correct: the model advertises the pass_1d position strategy, flash-attn varlen is active (block-diagonal cu_seqlens from per-sample-reset positions isolates samples), context parallelism is off, and a single run per microbatch (the MoE LoRA path applies one adapter per microbatch). MM samples still never pack with text (FSDP per-step modality invariant) nor mix refs/kwargs sidecars; sidecars are concatenated in sample order. Cleanups to the packing primitive: - drop the redundant EncodedTensor payload-length validators (a malformed payload already fails loudly downstream at frombuffer(...).reshape); keep the dtype/shape concat precondition. - dedup the sidecar dispatch in _append_micro_batch (the kind match is already guaranteed by _can_pack_sample). Bumps deps/renderers and deps/verifiers to their OOM-hunt-cleanup heads. 
Co-Authored-By: Codex Co-Authored-By: Claude Opus 4.8 (1M context) --- deps/renderers | 2 +- deps/verifiers | 2 +- .../src/prime_rl/configs/trainer.py | 3 + src/prime_rl/trainer/batch.py | 204 +++++++++++---- src/prime_rl/trainer/model.py | 17 +- .../qwen3_5_moe/modeling_qwen3_5_moe.py | 1 + src/prime_rl/trainer/rl/data.py | 2 + src/prime_rl/trainer/rl/packer.py | 31 ++- src/prime_rl/trainer/rl/train.py | 16 ++ src/prime_rl/utils/vlm.py | 58 +++++ tests/unit/train/rl/test_packer.py | 30 ++- tests/unit/train/test_model_forward.py | 19 ++ tests/unit/trainer/test_mm_refs.py | 239 +++++++++++++++++- tests/unit/utils/test_vlm.py | 44 ++++ 14 files changed, 608 insertions(+), 60 deletions(-) create mode 100644 tests/unit/utils/test_vlm.py diff --git a/deps/renderers b/deps/renderers index 10b71d6271..a8f874c416 160000 --- a/deps/renderers +++ b/deps/renderers @@ -1 +1 @@ -Subproject commit 10b71d627184d4db5448fb12e2941e42b32b07b4 +Subproject commit a8f874c416ecb155250db4ca1b732384018289af diff --git a/deps/verifiers b/deps/verifiers index 8085c6ceb6..a7fc7431b4 160000 --- a/deps/verifiers +++ b/deps/verifiers @@ -1 +1 @@ -Subproject commit 8085c6ceb65e116d5eec9b4db3f7835f34f265af +Subproject commit a7fc7431b41fb14bb069a5e0f68be24d402a11de diff --git a/packages/prime-rl-configs/src/prime_rl/configs/trainer.py b/packages/prime-rl-configs/src/prime_rl/configs/trainer.py index ce434ebaf6..82ff3eb73f 100644 --- a/packages/prime-rl-configs/src/prime_rl/configs/trainer.py +++ b/packages/prime-rl-configs/src/prime_rl/configs/trainer.py @@ -566,6 +566,9 @@ class TrainerConfig(BaseConfig): defer_mm_materialization: bool = True """Defer multimodal pixel materialization from the orchestrator to the trainer. When True, the orchestrator ships lightweight image references (``mm_refs``) and the trainer materializes pixels in its data loader. Must match the orchestrator's setting; requires ``renderer`` to be set for VLM runs. 
A no-op for text-only runs (no ``mm_refs`` ever arrive).""" + pack_multimodal: bool = True + """Pack multimodal samples together when the active model path supports packed multimodal position boundaries. Default-on, but the trainer gates it off for unsupported VLM/HF MRoPE paths, non-varlen attention, or context parallelism.""" + renderer: RendererConfig | None = AutoRendererConfig() """Typed renderer config (``renderers.RendererConfig`` discriminated union), mirroring the orchestrator's. Auto-resolves from the model by default so VLM defer runs work without restating it; only used by VLM runs (text-only ignores it).""" diff --git a/src/prime_rl/trainer/batch.py b/src/prime_rl/trainer/batch.py index cb1f3a100a..9cb2e4adc5 100644 --- a/src/prime_rl/trainer/batch.py +++ b/src/prime_rl/trainer/batch.py @@ -1,6 +1,6 @@ import copy -from prime_rl.transport.types import MicroBatch, RoutedExperts, TrainingSample +from prime_rl.transport.types import EncodedTensor, MicroBatch, MMRefs, RoutedExperts, TrainingSample ROUTED_EXPERTS_DTYPE_ITEMSIZE = { "uint8": 1, @@ -49,6 +49,135 @@ def _pad_routed_experts(micro_batch: MicroBatch, padding_size: int) -> None: routed_experts.shape[0] += padding_size +def _append_encoded_tensor(dst: EncodedTensor, src: EncodedTensor, key: str) -> None: + # Concatenate along dim 0; dtype and trailing dims must match. A malformed + # payload (data length inconsistent with shape) surfaces loudly downstream + # when the trainer does frombuffer(...).reshape(shape), so we don't re-check it here. 
+ if dst.dtype != src.dtype: + raise ValueError(f"Cannot pack mm_kwargs[{key!r}] with different dtypes: {dst.dtype} vs {src.dtype}") + if len(dst.shape) != len(src.shape) or dst.shape[1:] != src.shape[1:]: + raise ValueError(f"Cannot pack mm_kwargs[{key!r}] with incompatible shapes: {dst.shape} vs {src.shape}") + dst.data += src.data + dst.shape[0] += src.shape[0] + + +def _append_mm_kwargs(dst: dict[str, EncodedTensor], src: dict[str, EncodedTensor]) -> None: + if set(dst) != set(src): + raise ValueError(f"Cannot pack mm_kwargs with different keys: {sorted(dst)} vs {sorted(src)}") + for key in dst: + _append_encoded_tensor(dst[key], src[key], key) + + +def _append_mm_ref_descriptor_list(dst_map: dict, src_map: dict, field: str) -> None: + if set(dst_map) != set(src_map): + raise ValueError(f"Cannot pack mm_refs descriptor {field} with different modalities") + for modality, src_items in src_map.items(): + dst_items = dst_map[modality] + if not isinstance(dst_items, list) or not isinstance(src_items, list): + raise ValueError(f"mm_refs descriptor {field}[{modality!r}] must be a list to pack") + dst_items.extend(copy.deepcopy(src_items)) + + +def _append_mm_refs(dst: MMRefs, src: MMRefs) -> None: + dst_items = dst.descriptor.get("mm_items") or {} + src_items = src.descriptor.get("mm_items") or {} + dst_hashes = dst.descriptor.get("mm_hashes") or {} + src_hashes = src.descriptor.get("mm_hashes") or {} + + _append_mm_ref_descriptor_list(dst_items, src_items, "mm_items") + _append_mm_ref_descriptor_list(dst_hashes, src_hashes, "mm_hashes") + dst.descriptor["mm_items"] = dst_items + dst.descriptor["mm_hashes"] = dst_hashes + dst.uris.extend(src.uris) + + +def _mm_sidecar_kind(sample: MicroBatch) -> str | None: + if sample.mm_kwargs is not None and sample.mm_refs is not None: + raise ValueError("A multimodal sample cannot carry both mm_kwargs and mm_refs") + if sample.mm_refs is not None: + return "refs" + if sample.mm_kwargs is not None: + return "kwargs" + return None + 
+ +def _single_lora_idx(sample: MicroBatch) -> int | None: + if sample.lora_num_tokens is None: + return None + active = [idx for idx, tokens in enumerate(sample.lora_num_tokens) if tokens > 0] + return active[0] if len(active) == 1 else None + + +def _can_pack_sample( + bin_content: MicroBatch, + sample: MicroBatch, + *, + idx: int, + max_seq_len: int, + pack_multimodal: bool, +) -> bool: + if len(bin_content.input_ids) + len(sample.input_ids) > max_seq_len: + return False + if bin_content.training_mode != sample.training_mode: + return False + + bin_mm_kind = _mm_sidecar_kind(bin_content) + sample_mm_kind = _mm_sidecar_kind(sample) + if bin_mm_kind is None and sample_mm_kind is None: + return True + if not pack_multimodal or bin_mm_kind != sample_mm_kind: + return False + # Multimodal samples only pack with the same run: a multi-run microbatch would + # break the MoE LoRA path (one adapter per microbatch). prepare_batch may be + # called with multi-run input, so this guard is load-bearing, not redundant. 
+ return _single_lora_idx(bin_content) == idx + + +def _append_micro_batch(bin_content: MicroBatch, sample: MicroBatch, idx: int) -> None: + existing_len = len(bin_content.input_ids) + sample_len = len(sample.input_ids) + + bin_content.input_ids.extend(sample.input_ids) + bin_content.loss_mask.extend(sample.loss_mask) + bin_content.advantages.extend(sample.advantages) + if sample.rewards is not None: + if bin_content.rewards is None: + bin_content.rewards = [float("nan")] * existing_len + bin_content.rewards.extend(sample.rewards) + elif bin_content.rewards is not None: + bin_content.rewards.extend([float("nan")] * sample_len) + bin_content.inference_logprobs.extend(sample.inference_logprobs) + bin_content.temperatures.extend(sample.temperatures) + if bin_content.teacher_logprobs is not None or sample.teacher_logprobs is not None: + if bin_content.teacher_logprobs is None: + bin_content.teacher_logprobs = [0.0] * existing_len + bin_content.teacher_logprobs.extend(sample.teacher_logprobs or [0.0] * sample_len) + + assert (bin_content.routed_experts is None) == (sample.routed_experts is None) + if sample.routed_experts is not None: + if bin_content.routed_experts is None: + bin_content.routed_experts = _copy_routed_experts(sample.routed_experts) + else: + _append_routed_experts(bin_content, sample) + + if bin_content.mm_token_type_ids is not None or sample.mm_token_type_ids is not None: + if bin_content.mm_token_type_ids is None: + bin_content.mm_token_type_ids = [0] * existing_len + bin_content.mm_token_type_ids.extend(sample.mm_token_type_ids or [0] * sample_len) + + bin_content.env_names.extend(sample.env_names) + bin_content.position_ids.extend(sample.position_ids) + assert bin_content.lora_num_tokens is not None + bin_content.lora_num_tokens[idx] += sample_len + + # Concatenate the multimodal sidecar. _can_pack_sample already guaranteed the + # bin and sample share the same kind, so dispatch on whichever is present. 
+ if bin_content.mm_refs is not None: + _append_mm_refs(bin_content.mm_refs, sample.mm_refs) + elif bin_content.mm_kwargs is not None: + _append_mm_kwargs(bin_content.mm_kwargs, sample.mm_kwargs) + + def prepare_sample(training_example: TrainingSample, seq_len: int) -> MicroBatch: """ Prepare a problem for sequence packing training. @@ -77,6 +206,12 @@ def prepare_sample(training_example: TrainingSample, seq_len: int) -> MicroBatch _copy_routed_experts(training_example.routed_experts) if training_example.routed_experts is not None else None ) + if (training_example.mm_kwargs is not None or training_example.mm_refs is not None) and len(input_ids) > seq_len: + raise ValueError( + "Cannot truncate multimodal training sample without also truncating its multimodal sidecars: " + f"sample_len={len(input_ids)}, seq_len={seq_len}" + ) + if len(input_ids) > seq_len: input_ids = input_ids[:seq_len] loss_mask = loss_mask[:seq_len] @@ -131,8 +266,8 @@ def prepare_sample(training_example: TrainingSample, seq_len: int) -> MicroBatch routed_experts=routed_experts, mm_token_type_ids=mm_token_type_ids, env_names=env_names, - mm_kwargs=training_example.mm_kwargs, - mm_refs=training_example.mm_refs, + mm_kwargs=copy.deepcopy(training_example.mm_kwargs), + mm_refs=copy.deepcopy(training_example.mm_refs), training_mode=training_example.training_mode, ) @@ -146,15 +281,18 @@ def _is_multimodal_sample(sample: MicroBatch) -> bool: def packed_samples_into_micro_bs( - samples: list[tuple[int, MicroBatch]], max_seq_len: int, num_loras: int + samples: list[tuple[int, MicroBatch]], max_seq_len: int, num_loras: int, pack_multimodal: bool = False ) -> list[MicroBatch]: """ Pack samples into micro_batch efficiently. We follow the First Fit Decreasing algorithm to pack the samples into bins and minimize potential padding while never truncating. With per-token temperatures, samples can be packed together regardless of their temperature values. 
- NOTE: Multimodal samples (with mm_kwargs) are NOT packed together as they have variable-sized - vision data that doesn't pack well. Each multimodal sample becomes its own micro batch. + Multimodal samples are only packed when ``pack_multimodal`` is true. They + pack with other multimodal samples of the same sidecar representation + (deferred ``mm_refs`` or eager ``mm_kwargs``), never with text-only samples. + The caller is responsible for enabling this only for model paths whose + position handling supports packed multimodal boundaries. """ # Sort by (lora_idx, -length) for packing efficiency samples.sort(key=lambda x: (x[0], -len(x[1].input_ids))) @@ -163,52 +301,25 @@ def packed_samples_into_micro_bs( micro_batches: list[MicroBatch] = [] for idx, sample in samples: - # Multimodal samples cannot be packed - each becomes its own micro batch - if _is_multimodal_sample(sample): + # Unsupported multimodal samples remain standalone. Supported multimodal + # samples use the same token-side first-fit packing as text, with strict + # sidecar concatenation in the same sample order. + if _is_multimodal_sample(sample) and not pack_multimodal: sample.lora_num_tokens = [0] * num_loras sample.lora_num_tokens[idx] = len(sample.input_ids) micro_batches.append(sample) continue - # Try to find a bin that can fit this sequence (only pack text-only samples) + # Try to find a bin that can fit this sequence. 
for bin_content in micro_batches: - # Don't pack into multimodal micro batches - if _is_multimodal_sample(bin_content): - continue - # Check if sequence fits in this bin - if ( - len(bin_content.input_ids) + len(sample.input_ids) <= max_seq_len - and bin_content.training_mode == sample.training_mode + if _can_pack_sample( + bin_content, + sample, + idx=idx, + max_seq_len=max_seq_len, + pack_multimodal=pack_multimodal, ): - existing_len = len(bin_content.input_ids) - bin_content.input_ids.extend(sample.input_ids) - bin_content.loss_mask.extend(sample.loss_mask) - bin_content.advantages.extend(sample.advantages) - if sample.rewards is not None: - if bin_content.rewards is None: - bin_content.rewards = [float("nan")] * existing_len - bin_content.rewards.extend(sample.rewards) - elif bin_content.rewards is not None: - bin_content.rewards.extend([float("nan")] * len(sample.input_ids)) - bin_content.inference_logprobs.extend(sample.inference_logprobs) - bin_content.temperatures.extend(sample.temperatures) - if sample.teacher_logprobs is not None: - if bin_content.teacher_logprobs is None: - bin_content.teacher_logprobs = [] - bin_content.teacher_logprobs.extend(sample.teacher_logprobs) - assert (bin_content.routed_experts is None) == (sample.routed_experts is None) - if sample.routed_experts is not None: - if bin_content.routed_experts is None: - bin_content.routed_experts = _copy_routed_experts(sample.routed_experts) - else: - _append_routed_experts(bin_content, sample) - if sample.mm_token_type_ids is not None: - if bin_content.mm_token_type_ids is None: - bin_content.mm_token_type_ids = [] - bin_content.mm_token_type_ids.extend(sample.mm_token_type_ids) - bin_content.env_names.extend(sample.env_names) - bin_content.position_ids.extend(sample.position_ids) - bin_content.lora_num_tokens[idx] += len(sample.input_ids) + _append_micro_batch(bin_content, sample, idx) break else: sample.lora_num_tokens = [0] * num_loras @@ -287,6 +398,7 @@ def prepare_batch( idxs: list[int], 
num_loras: int, pad_to_multiple_of: int = 1, + pack_multimodal: bool = False, ) -> list[list[MicroBatch]]: """ Prepare a batch of problems for each GPU. Each batch is a list of micro batches. @@ -299,7 +411,7 @@ def prepare_batch( """ all_samples = [(idx, prepare_sample(rollout, seq_len)) for idx, rollout in zip(idxs, rollouts)] - micro_batches = packed_samples_into_micro_bs(all_samples, seq_len, num_loras) + micro_batches = packed_samples_into_micro_bs(all_samples, seq_len, num_loras, pack_multimodal=pack_multimodal) micro_batches = [pad_micro_batch(micro_batch, pad_to_multiple_of) for micro_batch in micro_batches] # Separate by modality so each step index has uniform modality across all ranks diff --git a/src/prime_rl/trainer/model.py b/src/prime_rl/trainer/model.py index 6e82d6d8d4..2fdb232f97 100644 --- a/src/prime_rl/trainer/model.py +++ b/src/prime_rl/trainer/model.py @@ -55,7 +55,12 @@ from prime_rl.utils.logger import get_logger from prime_rl.utils.sequence import get_cu_seqlens_from_position_ids from prime_rl.utils.utils import format_time -from prime_rl.utils.vlm import get_language_model, get_vision_encoder, is_vlm_architecture +from prime_rl.utils.vlm import ( + get_language_model, + get_packed_mm_position_strategy, + get_vision_encoder, + is_vlm_architecture, +) def pre_download_model(model_name: str) -> None: @@ -1144,11 +1149,11 @@ def forward( kwargs.update(mm_kwargs) if mm_token_type_ids is not None: kwargs["mm_token_type_ids"] = mm_token_type_ids - # ``position_ids`` for MRoPE families: Qwen3-VL's HF forward - # recomputes 3D positions from ``image_grid_thw`` and breaks if - # given the trainer's pre-computed 1D ``position_ids``. Detect - # via the mm_kwargs shape so we don't enumerate model_types. - if "image_grid_thw" not in mm_kwargs: + # HF Qwen-style MRoPE models must compute 3D/4-row multimodal + # positions from ``image_grid_thw`` internally. 
Custom Prime VLMs with + # ``pass_1d`` consume reset 1D positions and derive packed boundaries + # from them, so they must receive trainer ``position_ids``. + if "image_grid_thw" not in mm_kwargs or get_packed_mm_position_strategy(model) == "pass_1d": kwargs["position_ids"] = position_ids else: kwargs["position_ids"] = position_ids diff --git a/src/prime_rl/trainer/models/qwen3_5_moe/modeling_qwen3_5_moe.py b/src/prime_rl/trainer/models/qwen3_5_moe/modeling_qwen3_5_moe.py index 49c0944f42..a726621006 100644 --- a/src/prime_rl/trainer/models/qwen3_5_moe/modeling_qwen3_5_moe.py +++ b/src/prime_rl/trainer/models/qwen3_5_moe/modeling_qwen3_5_moe.py @@ -903,6 +903,7 @@ class Qwen3_5MoeForCausalLM(Qwen3_5MoePreTrainedModel, GenerationMixin): def __init__(self, config, **kwargs): super().__init__(config, **kwargs) self._is_vlm = hasattr(config, "vision_config") + self.packed_mm_position_strategy = "pass_1d" if self._is_vlm else "none" if self._is_vlm: self.model = Qwen3_5MoeVLMModel(config) diff --git a/src/prime_rl/trainer/rl/data.py b/src/prime_rl/trainer/rl/data.py index 79958b3acf..7c3d763518 100644 --- a/src/prime_rl/trainer/rl/data.py +++ b/src/prime_rl/trainer/rl/data.py @@ -168,6 +168,7 @@ def __init__( config: TransportConfig, defer_mm_materialization: bool = False, renderer_config: RendererConfig | None = None, + pack_multimodal: bool = False, ): self.world = get_world() @@ -179,6 +180,7 @@ def __init__( transport_config=config, pad_to_multiple_of=pad_to_multiple_of, start_step=start_step, + pack_multimodal=pack_multimodal, ) non_dp_world_size = self.world.world_size // dp_world_size diff --git a/src/prime_rl/trainer/rl/packer.py b/src/prime_rl/trainer/rl/packer.py index cf9dcfa02e..eac150c0e1 100644 --- a/src/prime_rl/trainer/rl/packer.py +++ b/src/prime_rl/trainer/rl/packer.py @@ -33,6 +33,7 @@ def __init__( tokenizer: PreTrainedTokenizer, config: TransportConfig, start_step: int = 0, + pack_multimodal: bool = False, ): self.logger = get_logger() 
self.multi_run_manager = get_multi_run_manager() @@ -40,6 +41,7 @@ def __init__( self.seq_len = seq_len self.pad_to_multiple_of = pad_to_multiple_of self.tokenizer = tokenizer + self.pack_multimodal = pack_multimodal self.receiver = setup_training_batch_receiver(config) shutil.rmtree(get_rollout_dir(self.multi_run_manager.output_dir), ignore_errors=True) self.sender: MicroBatchSender = setup_micro_batch_sender( @@ -85,8 +87,9 @@ def __init__( tokenizer: PreTrainedTokenizer, config: TransportConfig, start_step: int = 0, + pack_multimodal: bool = False, ): - super().__init__(dp_world_size, seq_len, pad_to_multiple_of, tokenizer, config, start_step) + super().__init__(dp_world_size, seq_len, pad_to_multiple_of, tokenizer, config, start_step, pack_multimodal) assert self.multi_run_manager.max_runs == 1, "SinglePacker only supports one run" def pack(self): @@ -110,6 +113,7 @@ def pack(self): num_train_workers=self.dp_world_size, idxs=[0] * len(batch.examples), num_loras=self.multi_run_manager.max_runs, + pack_multimodal=self.pack_multimodal, ) self.sender.send(micro_batch_grid) @@ -124,8 +128,9 @@ def __init__( tokenizer: PreTrainedTokenizer, config: TransportConfig, start_step: int = 0, + pack_multimodal: bool = False, ): - super().__init__(dp_world_size, seq_len, pad_to_multiple_of, tokenizer, config, start_step) + super().__init__(dp_world_size, seq_len, pad_to_multiple_of, tokenizer, config, start_step, pack_multimodal) # Per-run buffer: stores (TrainingSample, step) tuples self.buffers: list[deque[tuple[TrainingSample, int]]] = [ deque() for _ in range(self.multi_run_manager.max_runs) @@ -327,6 +332,7 @@ def pack(self): num_train_workers=self.dp_world_size, idxs=[run_idx] * len(run_samples), num_loras=self.multi_run_manager.max_runs, + pack_multimodal=self.pack_multimodal, ) # Merge into combined grid for worker_idx, worker_batches in enumerate(run_micro_batch_grid): @@ -342,9 +348,26 @@ def setup_packer( tokenizer: PreTrainedTokenizer, transport_config: 
TransportConfig, start_step: int = 0, + pack_multimodal: bool = False, ) -> BasePacker: multi_run_manager = get_multi_run_manager() if multi_run_manager.max_runs == 1: - return SinglePacker(dp_world_size, seq_len, pad_to_multiple_of, tokenizer, transport_config, start_step) + return SinglePacker( + dp_world_size, + seq_len, + pad_to_multiple_of, + tokenizer, + transport_config, + start_step, + pack_multimodal, + ) else: - return MultiPacker(dp_world_size, seq_len, pad_to_multiple_of, tokenizer, transport_config, start_step) + return MultiPacker( + dp_world_size, + seq_len, + pad_to_multiple_of, + tokenizer, + transport_config, + start_step, + pack_multimodal, + ) diff --git a/src/prime_rl/trainer/rl/train.py b/src/prime_rl/trainer/rl/train.py index cd0b05c574..56c66937c7 100644 --- a/src/prime_rl/trainer/rl/train.py +++ b/src/prime_rl/trainer/rl/train.py @@ -26,6 +26,7 @@ shard_for_cp, ) from prime_rl.utils.logger import setup_logger +from prime_rl.utils.vlm import get_packed_mm_disabled_reasons, get_packed_mm_position_strategy from prime_rl.trainer.rl.loss import ( compute_entropy, compute_loss, @@ -158,6 +159,20 @@ def train(config: TrainerConfig): logger.info(f"Initializing tokenizer ({config.tokenizer})") tokenizer = setup_tokenizer(config.tokenizer) + mm_position_strategy = get_packed_mm_position_strategy(model) + mm_pack_reasons = get_packed_mm_disabled_reasons( + model, + enabled=config.pack_multimodal, + attn_impl=config.model.attn, + cp_enabled=parallel_dims.cp_enabled, + cp_size=config.model.cp, + ) + pack_multimodal = not mm_pack_reasons + if pack_multimodal: + logger.info("Multimodal packing enabled (position_strategy=pass_1d)") + elif config.model.vlm is not None or mm_position_strategy != "none": + logger.info(f"Multimodal packing disabled ({', '.join(mm_pack_reasons)})") + # Set up the loss function logger.info(f"Setting up loss function ({config.loss})") loss_fns = setup_loss_fns(config.loss) @@ -255,6 +270,7 @@ def load_run_checkpoint(_optimizer, 
idx: int) -> None: # defer is on and renderer_config is not None, so an explicit # renderer=None still opts out (and a text-only run never gets mm_refs). renderer_config=config.renderer, + pack_multimodal=pack_multimodal, ) token_exporter = setup_token_exporter(config, parallel_dims, world, logger) diff --git a/src/prime_rl/utils/vlm.py b/src/prime_rl/utils/vlm.py index a099c19925..abb8d983c3 100644 --- a/src/prime_rl/utils/vlm.py +++ b/src/prime_rl/utils/vlm.py @@ -10,6 +10,7 @@ """ from dataclasses import dataclass +from typing import Literal, TypeAlias import torch.nn as nn from transformers.configuration_utils import PretrainedConfig @@ -23,6 +24,10 @@ class VLMModelInfo: language_model_attr: str +PackedMMPositionStrategy: TypeAlias = Literal["none", "pass_1d"] +PACKED_MM_ATTN_IMPLS = ("flash_attention_2", "flash_attention_3", "fa4") + + # Central registry: model_type -> architecture info. VLM_REGISTRY: dict[str, VLMModelInfo] = { "qwen3_vl": VLMModelInfo(vision_encoder_attr="model.visual", language_model_attr="model.language_model"), @@ -86,6 +91,49 @@ def is_vlm_architecture(model_config: PretrainedConfig) -> bool: return _get_model_info_from_config(model_config) is not None +def get_packed_mm_position_strategy(model: nn.Module) -> PackedMMPositionStrategy: + """Return the model's packed multimodal position strategy. + + ``pass_1d`` is intentionally narrow: it means the VLM's language model + consumes reset 1D ``position_ids`` and derives packed attention boundaries + from them. HF Qwen-style MRoPE models need model-computed 3D/4-row positions + and therefore remain ``none`` until a dedicated builder exists. 
+ """ + for candidate in _iter_wrapped_modules(model): + strategy = getattr(candidate, "packed_mm_position_strategy", None) + if strategy in ("none", "pass_1d"): + return strategy + + model_type = getattr(getattr(candidate, "config", None), "model_type", None) + if model_type == "qwen3_5_moe" and getattr(candidate, "_is_vlm", False): + return "pass_1d" + + return "none" + + +def get_packed_mm_disabled_reasons( + model: nn.Module, + *, + enabled: bool, + attn_impl: str, + cp_enabled: bool, + cp_size: int | None = None, +) -> list[str]: + """Return reasons multimodal packing should be disabled for this runtime.""" + strategy = get_packed_mm_position_strategy(model) + reasons = [] + if not enabled: + reasons.append("trainer.pack_multimodal=false") + if strategy != "pass_1d": + reasons.append(f"position_strategy={strategy}") + if attn_impl not in PACKED_MM_ATTN_IMPLS: + reasons.append(f"attn={attn_impl}") + if cp_enabled: + cp_label = cp_size if cp_size is not None else "enabled" + reasons.append(f"cp={cp_label}") + return reasons + + def get_layer_prefix(model_config: PretrainedConfig, override: str | None = None) -> str: """Return the weight key prefix for language model layers. 
@@ -122,3 +170,13 @@ def _resolve_attr(obj, dotted_path: str): if obj is None: return None return obj + + +def _iter_wrapped_modules(model: nn.Module): + """Yield a module and common wrapper inners without depending on wrapper types.""" + seen: set[int] = set() + current = model + while current is not None and id(current) not in seen: + seen.add(id(current)) + yield current + current = getattr(current, "module", None) diff --git a/tests/unit/train/rl/test_packer.py b/tests/unit/train/rl/test_packer.py index e40e7e89ff..8ba9b0f55a 100644 --- a/tests/unit/train/rl/test_packer.py +++ b/tests/unit/train/rl/test_packer.py @@ -132,7 +132,7 @@ def _mm_sample(uri: str) -> TrainingSample: ) -def _packer_with_two_runs(tmp_path, monkeypatch, dp_world_size, seq_len): +def _packer_with_two_runs(tmp_path, monkeypatch, dp_world_size, seq_len, pack_multimodal: bool = False): """Set up a MultiPacker over two discovered runs; capture sent grids.""" reset_world() runs._MULTI_RUN_MANAGER = None @@ -163,6 +163,7 @@ def send(self, micro_batch_grid): tokenizer=None, config=FileSystemTransportConfig(), start_step=0, + pack_multimodal=pack_multimodal, ) return manager, packer, sent @@ -223,3 +224,30 @@ def test_multipacker_pack_mm_padding_is_zero_loss(tmp_path, monkeypatch): assert dummies, "expected a zero-loss dummy MM padding microbatch" for d in dummies: assert all(a == 0.0 for a in d.advantages) + + +def test_multipacker_pack_packs_mm_refs_within_each_run_when_enabled(tmp_path, monkeypatch): + """MultiPacker threads the MM-packing capability into per-run prepare_batch + calls, so refs pack within a run but never across runs.""" + from prime_rl.trainer.batch import _is_multimodal_sample + + manager, packer, sent = _packer_with_two_runs( + tmp_path, monkeypatch, dp_world_size=2, seq_len=4, pack_multimodal=True + ) + a, b = manager.id_2_idx["run_a"], manager.id_2_idx["run_b"] + for idx, prefix in ((a, "a"), (b, "b")): + packer.buffers[idx].append((_mm_sample(f"{prefix}0"), 0)) + 
packer.buffers[idx].append((_mm_sample(f"{prefix}1"), 0)) + + packer.pack() + assert sent + grid = sent[-1] + real_mm_mbs = [mb for rank in grid for mb in rank if _is_multimodal_sample(mb) and any(mb.loss_mask)] + + assert len(real_mm_mbs) == 2 + assert sorted(mb.mm_refs.uris for mb in real_mm_mbs) == [["a0", "a1"], ["b0", "b1"]] + for mb in real_mm_mbs: + assert len(mb.input_ids) == 4 + assert mb.position_ids == [0, 1, 0, 1] + tagged = [i for i, n in enumerate(mb.lora_num_tokens) if n > 0] + assert len(tagged) == 1 diff --git a/tests/unit/train/test_model_forward.py b/tests/unit/train/test_model_forward.py index 7baf7760a8..184b7d4989 100644 --- a/tests/unit/train/test_model_forward.py +++ b/tests/unit/train/test_model_forward.py @@ -81,3 +81,22 @@ def test_forward_keeps_position_ids_for_non_mrope_vlm(): assert model.kwargs is not None torch.testing.assert_close(model.kwargs["position_ids"], position_ids) + + +def test_forward_keeps_position_ids_for_pass_1d_mrope_vlm(): + """Custom Prime VLMs with packed 1D position support keep position_ids even + when Qwen-style image_grid_thw is present.""" + model = _CaptureModel(SimpleNamespace(model_type="qwen3_5_moe")) + model.packed_mm_position_strategy = "pass_1d" + input_ids = torch.tensor([[1, 10, 10, 2, 20, 20]]) + position_ids = torch.tensor([[0, 1, 2, 0, 1, 2]]) + + forward( + model, + input_ids, + position_ids, + mm_kwargs={"pixel_values": torch.ones(4, 3), "image_grid_thw": torch.tensor([[1, 1, 2], [1, 1, 2]])}, + ) + + assert model.kwargs is not None + torch.testing.assert_close(model.kwargs["position_ids"], position_ids) diff --git a/tests/unit/trainer/test_mm_refs.py b/tests/unit/trainer/test_mm_refs.py index ec3614591c..c6db1d8e8d 100644 --- a/tests/unit/trainer/test_mm_refs.py +++ b/tests/unit/trainer/test_mm_refs.py @@ -16,7 +16,7 @@ from prime_rl.orchestrator.trajectories import _collect_mm_refs, _pack_mm_kwargs_from_renderer, _reconstruct_mm_pixels from prime_rl.trainer.batch import _is_multimodal_sample 
-from prime_rl.transport.types import MicroBatch, MMRefs, TrainingSample +from prime_rl.transport.types import EncodedTensor, MicroBatch, MMRefs, TrainingSample from prime_rl.utils.mm import ( build_image_messages, encode_mm_kwargs, @@ -232,6 +232,243 @@ def _text_sample(env: str = "e", n_prompt: int = 4, n_comp: int = 4): ) +def _encoded_tensor(tensor: torch.Tensor) -> EncodedTensor: + arr = tensor.detach().cpu().numpy() + return EncodedTensor(dtype=str(arr.dtype), shape=list(arr.shape), data=arr.tobytes()) + + +def _mm_kwargs_sample(pixel_values: torch.Tensor, env: str = "e", n_prompt: int = 2, n_comp: int = 2): + return TrainingSample( + prompt_ids=list(range(n_prompt)), + prompt_mask=[False] * n_prompt, + completion_ids=list(range(n_comp)), + completion_mask=[True] * n_comp, + completion_logprobs=[0.0] * n_comp, + completion_temperatures=[1.0] * n_comp, + env_name=env, + advantage=0.0, + reward=0.0, + mm_token_type_ids=[1] * (n_prompt + n_comp), + mm_kwargs={ + "pixel_values": _encoded_tensor(pixel_values), + "image_grid_thw": _encoded_tensor(torch.tensor([[1, 1, pixel_values.shape[0]]], dtype=torch.int64)), + }, + ) + + +def test_prepare_batch_packs_mm_refs_when_enabled_preserving_order_and_boundaries(): + """Deferred refs pack by token length, while descriptors/uris concatenate in + the same order and position_ids keep per-sample resets.""" + from prime_rl.trainer.batch import prepare_batch + + uri = "file:///dup.jpg" + rollouts = [ + _mm_sample(uri, n_prompt=2, n_comp=2), + _mm_sample(uri, n_prompt=1, n_comp=3), + _text_sample(n_prompt=2, n_comp=2), + ] + + grid = prepare_batch( + rollouts, + seq_len=16, + num_train_workers=1, + idxs=[0, 0, 0], + num_loras=1, + pack_multimodal=True, + ) + flat = grid[0] + mm_mbs = [mb for mb in flat if _is_multimodal_sample(mb)] + text_mbs = [mb for mb in flat if not _is_multimodal_sample(mb)] + + assert len(mm_mbs) == 1 + assert len(text_mbs) == 1 + mb = mm_mbs[0] + assert mb.mm_refs is not None and mb.mm_kwargs is None + 
assert mb.input_ids == [0, 1, 0, 1, 0, 0, 1, 2] + assert mb.position_ids == [0, 1, 2, 3, 0, 1, 2, 3] + assert mb.mm_token_type_ids == [1] * len(mb.input_ids) + assert mb.mm_refs.uris == [uri, uri] + assert mb.mm_refs.descriptor["mm_hashes"]["image"] == [_uri_hash(uri), _uri_hash(uri)] + assert len(mb.mm_refs.descriptor["mm_items"]["image"]) == 2 + assert mb.lora_num_tokens == [len(mb.input_ids)] + + +def test_packed_mm_refs_filesystem_transport_materializes_stitched_tensors(tmp_path): + """End-to-end trainer mechanics: prepare packed MM refs, write/read them + through the real filesystem microbatch transport, then materialize them in + DataLoader into the model kwargs consumed by forward.""" + from types import SimpleNamespace + + from prime_rl.trainer.batch import prepare_batch + from prime_rl.trainer.rl.data import DataLoader + from prime_rl.transport.filesystem import FileSystemMicroBatchReceiver, FileSystemMicroBatchSender + + uri0, uri1 = "file:///packed-a.jpg", "file:///packed-b.jpg" + rollouts = [ + _mm_sample(uri0, n_prompt=2, n_comp=2), + _mm_sample(uri1, n_prompt=2, n_comp=2), + ] + grid = prepare_batch( + rollouts, + seq_len=16, + num_train_workers=2, + idxs=[0, 0], + num_loras=1, + pack_multimodal=True, + ) + + assert len(grid) == 2 + assert len(grid[0]) == len(grid[1]) == 1 + assert any(grid[0][0].loss_mask) + assert not any(grid[1][0].loss_mask) # modality-preserving dummy for rank alignment + + sender = FileSystemMicroBatchSender(tmp_path, data_world_size=2, current_step=0) + sender.send(grid) + rank0_mb = FileSystemMicroBatchReceiver(tmp_path, data_rank=0, current_step=0).receive()[0] + rank1_mb = FileSystemMicroBatchReceiver(tmp_path, data_rank=1, current_step=0).receive()[0] + + for mb, has_loss in ((rank0_mb, True), (rank1_mb, False)): + assert mb.mm_refs is not None and mb.mm_kwargs is None + assert mb.mm_refs.uris == [uri0, uri1] + assert mb.mm_refs.descriptor["mm_hashes"]["image"] == [_uri_hash(uri0), _uri_hash(uri1)] + assert 
len(mb.mm_refs.descriptor["mm_items"]["image"]) == 2 + assert mb.position_ids == [0, 1, 2, 3, 0, 1, 2, 3] + assert any(mb.loss_mask) is has_loss + + renderer = _StubRenderer( + { + _uri_hash(uri0): torch.tensor([[10.0, 11.0]], dtype=torch.float32), + _uri_hash(uri1): torch.tensor([[20.0, 21.0]], dtype=torch.float32), + } + ) + + loader = DataLoader.__new__(DataLoader) + loader.multi_run_manager = SimpleNamespace(max_runs=1) + loader._renderer = renderer + loader.last_mm_materialize_time = 0.0 + loader.last_mm_images_materialized = 0 + + rank0 = DataLoader._micro_batch_to_tensor(loader, rank0_mb) + rank1 = DataLoader._micro_batch_to_tensor(loader, rank1_mb) + + for tensor_batch, has_loss in ((rank0, True), (rank1, False)): + torch.testing.assert_close(tensor_batch["position_ids"], torch.tensor([[0, 1, 2, 3, 0, 1, 2, 3]])) + torch.testing.assert_close(tensor_batch["mm_token_type_ids"], torch.ones((1, 8), dtype=torch.long)) + assert bool(tensor_batch["loss_mask"].any().item()) is has_loss + assert tensor_batch["mm_kwargs"] is not None + torch.testing.assert_close( + tensor_batch["mm_kwargs"]["pixel_values"], + torch.tensor([[10.0, 11.0], [20.0, 21.0]], dtype=torch.float32), + ) + torch.testing.assert_close( + tensor_batch["mm_kwargs"]["image_grid_thw"], + torch.tensor([[1, 2, 3], [1, 2, 3]], dtype=torch.int64), + ) + + assert loader.last_mm_images_materialized == 4 # both aligned ranks materialized two image refs + + +def test_prepare_batch_does_not_pack_mm_refs_with_text_or_other_lora(): + from prime_rl.trainer.batch import prepare_batch + + rollouts = [ + _mm_sample("file:///run0.jpg", n_prompt=2, n_comp=2), + _mm_sample("file:///run1.jpg", n_prompt=2, n_comp=2), + _text_sample(n_prompt=2, n_comp=2), + ] + grid = prepare_batch( + rollouts, + seq_len=16, + num_train_workers=1, + idxs=[0, 1, 0], + num_loras=2, + pack_multimodal=True, + ) + + mm_mbs = [mb for mb in grid[0] if _is_multimodal_sample(mb)] + text_mbs = [mb for mb in grid[0] if not 
_is_multimodal_sample(mb)] + assert len(mm_mbs) == 2 + assert len(text_mbs) == 1 + assert [mb.lora_num_tokens for mb in mm_mbs] == [[4, 0], [0, 4]] + + +def test_prepare_batch_does_not_pack_mm_refs_with_eager_mm_kwargs(): + from prime_rl.trainer.batch import prepare_batch + + rollouts = [ + _mm_sample("file:///refs.jpg", n_prompt=2, n_comp=2), + _mm_kwargs_sample(torch.tensor([[1.0, 2.0]], dtype=torch.float32)), + ] + grid = prepare_batch( + rollouts, + seq_len=16, + num_train_workers=1, + idxs=[0, 0], + num_loras=1, + pack_multimodal=True, + ) + + mm_mbs = [mb for mb in grid[0] if _is_multimodal_sample(mb)] + assert len(mm_mbs) == 2 + assert {("refs" if mb.mm_refs is not None else "kwargs") for mb in mm_mbs} == {"refs", "kwargs"} + + +def test_prepare_batch_packs_eager_mm_kwargs_when_enabled(): + from prime_rl.trainer.batch import prepare_batch + + rollouts = [ + _mm_kwargs_sample(torch.tensor([[1.0, 2.0]], dtype=torch.float32)), + _mm_kwargs_sample(torch.tensor([[3.0, 4.0]], dtype=torch.float32)), + ] + grid = prepare_batch( + rollouts, + seq_len=16, + num_train_workers=1, + idxs=[0, 0], + num_loras=1, + pack_multimodal=True, + ) + + mb = grid[0][0] + assert mb.mm_kwargs is not None and mb.mm_refs is None + pv = mb.mm_kwargs["pixel_values"] + assert pv.shape == [2, 2] + pixel_values = torch.frombuffer(bytearray(pv.data), dtype=torch.float32).reshape(pv.shape) + torch.testing.assert_close(pixel_values, torch.tensor([[1.0, 2.0], [3.0, 4.0]])) + assert mb.position_ids == [0, 1, 2, 3, 0, 1, 2, 3] + + +def test_prepare_batch_rejects_incompatible_eager_mm_kwargs(): + from prime_rl.trainer.batch import prepare_batch + + with pytest.raises(ValueError, match="incompatible shapes"): + prepare_batch( + [ + _mm_kwargs_sample(torch.tensor([[1.0, 2.0]], dtype=torch.float32)), + _mm_kwargs_sample(torch.tensor([[3.0, 4.0, 5.0]], dtype=torch.float32)), + ], + seq_len=16, + num_train_workers=1, + idxs=[0, 0], + num_loras=1, + pack_multimodal=True, + ) + + +def 
test_prepare_batch_rejects_truncated_multimodal_sample(): + from prime_rl.trainer.batch import prepare_batch + + with pytest.raises(ValueError, match="Cannot truncate multimodal"): + prepare_batch( + [_mm_sample("file:///too-long.jpg", n_prompt=2, n_comp=2)], + seq_len=3, + num_train_workers=1, + idxs=[0], + num_loras=1, + pack_multimodal=True, + ) + + def test_multirun_packing_preserves_mm_refs_modality_and_run_tagging(): """Multi-run: deferred mm_refs samples from 2 runs pack correctly through the REAL prepare_batch — each MM sample is its own microbatch carrying its mm_refs, diff --git a/tests/unit/utils/test_vlm.py b/tests/unit/utils/test_vlm.py new file mode 100644 index 0000000000..3630455fd4 --- /dev/null +++ b/tests/unit/utils/test_vlm.py @@ -0,0 +1,44 @@ +from types import SimpleNamespace + +import torch.nn as nn + +from prime_rl.utils.vlm import get_packed_mm_disabled_reasons, get_packed_mm_position_strategy + + +class _Model(nn.Module): + def __init__(self, *, strategy=None, model_type="qwen3_5_moe", is_vlm=True): + super().__init__() + self.config = SimpleNamespace(model_type=model_type) + self._is_vlm = is_vlm + if strategy is not None: + self.packed_mm_position_strategy = strategy + + +def test_packed_mm_strategy_is_pass_1d_for_custom_qwen35_vlm(): + assert get_packed_mm_position_strategy(_Model()) == "pass_1d" + + +def test_packed_mm_strategy_none_for_text_only_custom_qwen35(): + assert get_packed_mm_position_strategy(_Model(is_vlm=False)) == "none" + + +def test_packed_mm_gate_allows_only_supported_runtime(): + model = _Model(strategy="pass_1d") + + assert ( + get_packed_mm_disabled_reasons(model, enabled=True, attn_impl="flash_attention_2", cp_enabled=False, cp_size=1) + == [] + ) + assert get_packed_mm_disabled_reasons(model, enabled=True, attn_impl="sdpa", cp_enabled=False) == ["attn=sdpa"] + assert get_packed_mm_disabled_reasons(model, enabled=True, attn_impl="fa4", cp_enabled=True, cp_size=2) == ["cp=2"] + assert 
get_packed_mm_disabled_reasons(model, enabled=False, attn_impl="fa4", cp_enabled=False) == [ + "trainer.pack_multimodal=false" + ] + + +def test_packed_mm_gate_rejects_hf_mrope_default_strategy(): + model = _Model(strategy="none", model_type="qwen3_vl") + + assert get_packed_mm_disabled_reasons(model, enabled=True, attn_impl="flash_attention_2", cp_enabled=False) == [ + "position_strategy=none" + ] From 5bd07f7102899e4ef32b8dc018f366bda4f5dbc0 Mon Sep 17 00:00:00 2001 From: Eli Gottlieb <78387377+eligotts@users.noreply.github.com> Date: Thu, 4 Jun 2026 03:33:13 +0000 Subject: [PATCH 18/31] feat(transport): default micro-batches to ZMQ with multi-node binding + coordinated fail-fast Ship packed per-rank micro-batches (trainer master -> data ranks) over ZMQ instead of the shared filesystem. Hardened for multi-node and made fail-fast: - bind 0.0.0.0 / connect MASTER_ADDR (or configured host) so PUB/SUB + the READY barrier work across nodes; port+1 = data, port+2 = startup READY barrier. - bounded timeouts (recv/ready/publish) replace infinite blocking, with a publish_grace to reduce PUB/SUB slow-joiner races at startup. - per-message step tag + receiver mismatch assertion: a dropped message becomes a loud crash, never silent training on misaligned data. - torch.distributed Store publish-status barrier (master -> ranks) plus an all-reduce(MAX) fail-fast (any rank -> all): a master pack failure or a single rank's transport failure crashes all ranks coordinated instead of hanging on a collective. synchronize_state runs only after that barrier. micro_batch_transport defaults to ZMQ; rollouts (orchestrator -> trainer) stay filesystem. Adds unit fan-out/timeout/step-mismatch tests and a torchrun 2-rank integration smoke. Residual risk is multi-node PUB/SUB startup; validate the slow-joiner grace on the real topology. 
Co-Authored-By: Codex Co-Authored-By: Claude Opus 4.8 (1M context) --- .../src/prime_rl/configs/shared.py | 25 ++- .../src/prime_rl/configs/trainer.py | 4 + src/prime_rl/trainer/rl/data.py | 55 +++++- src/prime_rl/trainer/rl/packer.py | 31 +++- src/prime_rl/trainer/rl/train.py | 30 +++- src/prime_rl/transport/zmq.py | 91 ++++++++-- .../integration/test_zmq_microbatch_smoke.py | 168 ++++++++++++++++++ tests/integration/zmq_microbatch_smoke.py | 78 ++++++++ tests/unit/test_configs.py | 8 + tests/unit/train/rl/test_packer.py | 18 ++ tests/unit/transport/test_zmq.py | 101 +++++++++++ 11 files changed, 574 insertions(+), 35 deletions(-) create mode 100644 tests/integration/test_zmq_microbatch_smoke.py create mode 100644 tests/integration/zmq_microbatch_smoke.py create mode 100644 tests/unit/transport/test_zmq.py diff --git a/packages/prime-rl-configs/src/prime_rl/configs/shared.py b/packages/prime-rl-configs/src/prime_rl/configs/shared.py index 8c1b3884be..5a4c6e1e12 100644 --- a/packages/prime-rl-configs/src/prime_rl/configs/shared.py +++ b/packages/prime-rl-configs/src/prime_rl/configs/shared.py @@ -242,16 +242,35 @@ class FileSystemTransportConfig(BaseTransportConfig): class ZMQTransportConfig(BaseTransportConfig): + """ + ZMQ binds on all local interfaces and connects to ``host`` (or ``MASTER_ADDR`` when unset). + Base ``port`` is used for training batches if that hop uses ZMQ; micro-batches use + ``port + 1`` for PUB/SUB data and ``port + 2`` for the startup READY barrier. + This assumes a trusted trainer network; ZMQ messages are not authenticated. + """ + type: Literal["zmq"] = "zmq" - host: str = "localhost" - """Host address for ZMQ transport.""" + host: str | None = None + """Host address receivers/senders connect to. 
When unset or ``0.0.0.0``, resolves to ``MASTER_ADDR`` or ``localhost``.""" port: int = 5555 """Base port for ZMQ transport.""" - hwm: int = 10 + hwm: int = Field(64, ge=1) """High-water mark (max in-flight messages per ZMQ socket).""" + recv_timeout_seconds: int = Field(300, ge=1) + """Seconds a micro-batch receiver waits after the master has published a step before failing fast.""" + + ready_timeout_seconds: int = Field(300, ge=1) + """Seconds the micro-batch sender waits at startup for rank READY messages before failing fast.""" + + publish_timeout_seconds: int = Field(1800, ge=1) + """Seconds ranks wait for the master to publish/fail a packed micro-batch step.""" + + publish_grace_ms: int = Field(100, ge=0) + """Small startup grace after all READY messages arrive, reducing PUB/SUB slow-joiner races.""" + TransportConfig: TypeAlias = Annotated[FileSystemTransportConfig | ZMQTransportConfig, Field(discriminator="type")] diff --git a/packages/prime-rl-configs/src/prime_rl/configs/trainer.py b/packages/prime-rl-configs/src/prime_rl/configs/trainer.py index 82ff3eb73f..eb86ba9e6d 100644 --- a/packages/prime-rl-configs/src/prime_rl/configs/trainer.py +++ b/packages/prime-rl-configs/src/prime_rl/configs/trainer.py @@ -13,6 +13,7 @@ TrainerLogConfig, TransportConfig, WandbConfig, + ZMQTransportConfig, ) from prime_rl.utils.config import BaseConfig @@ -523,6 +524,9 @@ class TrainerConfig(BaseConfig): rollout_transport: TransportConfig = FileSystemTransportConfig() """Transport used to ship rollouts from orchestrator to trainer.""" + micro_batch_transport: TransportConfig = ZMQTransportConfig() + """Transport used to ship packed per-rank micro-batches from the trainer master to data ranks.""" + log: TrainerLogConfig = TrainerLogConfig() wandb: WandbConfig | None = None diff --git a/src/prime_rl/trainer/rl/data.py b/src/prime_rl/trainer/rl/data.py index 7c3d763518..73d1c1f11d 100644 --- a/src/prime_rl/trainer/rl/data.py +++ b/src/prime_rl/trainer/rl/data.py @@ -1,8 +1,11 
@@ +import pickle import time +from datetime import timedelta from pathlib import Path from typing import TypedDict import torch +import torch.distributed.distributed_c10d as c10d from jaxtyping import Bool, Float, Int from renderers import RendererConfig from torch import Tensor @@ -15,6 +18,8 @@ from prime_rl.transport import MicroBatch, MicroBatchReceiver, TransportConfig, setup_micro_batch_receiver from prime_rl.utils.logger import get_logger +DEFAULT_MICRO_BATCH_PUBLISH_TIMEOUT_SECONDS = 1800 + class TensorMicroBatch(TypedDict): """A micro batch of data for training.""" @@ -67,6 +72,9 @@ def __init__(self, config: FakeDataLoaderConfig, seq_len: int, dp_world_size: in def wait_for_batch(self) -> None: return + def synchronize_state(self) -> None: + return + def get_batch(self) -> list[TensorMicroBatch]: if not self.generate_samples: get_micro_batch_fn = self._get_micro_batch @@ -169,8 +177,17 @@ def __init__( defer_mm_materialization: bool = False, renderer_config: RendererConfig | None = None, pack_multimodal: bool = False, + micro_batch_transport_config: TransportConfig | None = None, ): self.world = get_world() + self._current_step = start_step + self._micro_batch_transport_config = micro_batch_transport_config or config + self._publish_timeout_seconds = getattr( + self._micro_batch_transport_config, + "publish_timeout_seconds", + DEFAULT_MICRO_BATCH_PUBLISH_TIMEOUT_SECONDS, + ) + self._store = c10d._get_default_store() if self.world.is_master: self.packer: BasePacker = setup_packer( @@ -181,13 +198,16 @@ def __init__( pad_to_multiple_of=pad_to_multiple_of, start_step=start_step, pack_multimodal=pack_multimodal, + micro_batch_transport_config=self._micro_batch_transport_config, ) non_dp_world_size = self.world.world_size // dp_world_size dp_rank = self.world.rank // non_dp_world_size self.multi_run_manager = get_multi_run_manager() - self.receiver: MicroBatchReceiver = setup_micro_batch_receiver(output_dir, dp_rank, start_step, config) + self.receiver: 
MicroBatchReceiver = setup_micro_batch_receiver( + output_dir, dp_rank, start_step, self._micro_batch_transport_config + ) # Deferred materialization: each rank builds its own renderer once and # materializes pixels from the shipped image references in get_batch. @@ -204,21 +224,52 @@ def __init__( self.last_mm_materialize_time = 0.0 self.last_mm_images_materialized = 0 + def _publish_status_key(self) -> str: + return f"micro_batch_publish/{self._current_step}" + + def _publish_micro_batch_status(self, *, ok: bool, error: str = "") -> None: + self._store.set(self._publish_status_key(), pickle.dumps({"ok": ok, "error": error})) + + def _wait_for_micro_batch_status(self) -> None: + key = self._publish_status_key() + try: + self._store.wait([key], timedelta(seconds=self._publish_timeout_seconds)) + except Exception as exc: + raise TimeoutError( + f"Timed out waiting for trainer master to publish micro-batch step {self._current_step} " + f"after {self._publish_timeout_seconds}s" + ) from exc + + status = pickle.loads(self._store.get(key)) + if not status.get("ok", False): + error = status.get("error") or "unknown error" + raise RuntimeError(f"Trainer master failed to pack micro-batch step {self._current_step}: {error}") + def wait_for_batch(self) -> None: if self.world.is_master: self.packer._arm_watchdog() try: self.packer.pack() + self._publish_micro_batch_status(ok=True) + except Exception as exc: + self._publish_micro_batch_status(ok=False, error=repr(exc)) + raise finally: self.packer._disarm_watchdog() + + self._wait_for_micro_batch_status() self.receiver.wait() + + def synchronize_state(self) -> None: self.multi_run_manager.synchronize_state() def get_batch(self) -> list[TensorMicroBatch]: micro_batches = self.receiver.receive() self.last_mm_materialize_time = 0.0 self.last_mm_images_materialized = 0 - return [self._micro_batch_to_tensor(mb) for mb in micro_batches] + tensor_batches = [self._micro_batch_to_tensor(mb) for mb in micro_batches] + self._current_step 
+= 1 + return tensor_batches def _micro_batch_to_tensor(self, micro_batch: MicroBatch) -> TensorMicroBatch: """Convert a MicroBatch (msgspec struct with lists) to a TensorMicroBatch (dict with tensors).""" diff --git a/src/prime_rl/trainer/rl/packer.py b/src/prime_rl/trainer/rl/packer.py index eac150c0e1..885b866414 100644 --- a/src/prime_rl/trainer/rl/packer.py +++ b/src/prime_rl/trainer/rl/packer.py @@ -34,6 +34,7 @@ def __init__( config: TransportConfig, start_step: int = 0, pack_multimodal: bool = False, + micro_batch_transport_config: TransportConfig | None = None, ): self.logger = get_logger() self.multi_run_manager = get_multi_run_manager() @@ -43,9 +44,10 @@ def __init__( self.tokenizer = tokenizer self.pack_multimodal = pack_multimodal self.receiver = setup_training_batch_receiver(config) + micro_batch_transport_config = micro_batch_transport_config or config shutil.rmtree(get_rollout_dir(self.multi_run_manager.output_dir), ignore_errors=True) self.sender: MicroBatchSender = setup_micro_batch_sender( - self.multi_run_manager.output_dir, dp_world_size, start_step, config + self.multi_run_manager.output_dir, dp_world_size, start_step, micro_batch_transport_config ) self._last_heartbeat = time.monotonic() self._watchdog_armed = threading.Event() @@ -88,8 +90,18 @@ def __init__( config: TransportConfig, start_step: int = 0, pack_multimodal: bool = False, + micro_batch_transport_config: TransportConfig | None = None, ): - super().__init__(dp_world_size, seq_len, pad_to_multiple_of, tokenizer, config, start_step, pack_multimodal) + super().__init__( + dp_world_size, + seq_len, + pad_to_multiple_of, + tokenizer, + config, + start_step, + pack_multimodal, + micro_batch_transport_config, + ) assert self.multi_run_manager.max_runs == 1, "SinglePacker only supports one run" def pack(self): @@ -129,8 +141,18 @@ def __init__( config: TransportConfig, start_step: int = 0, pack_multimodal: bool = False, + micro_batch_transport_config: TransportConfig | None = None, ): - 
super().__init__(dp_world_size, seq_len, pad_to_multiple_of, tokenizer, config, start_step, pack_multimodal) + super().__init__( + dp_world_size, + seq_len, + pad_to_multiple_of, + tokenizer, + config, + start_step, + pack_multimodal, + micro_batch_transport_config, + ) # Per-run buffer: stores (TrainingSample, step) tuples self.buffers: list[deque[tuple[TrainingSample, int]]] = [ deque() for _ in range(self.multi_run_manager.max_runs) @@ -349,6 +371,7 @@ def setup_packer( transport_config: TransportConfig, start_step: int = 0, pack_multimodal: bool = False, + micro_batch_transport_config: TransportConfig | None = None, ) -> BasePacker: multi_run_manager = get_multi_run_manager() if multi_run_manager.max_runs == 1: @@ -360,6 +383,7 @@ def setup_packer( transport_config, start_step, pack_multimodal, + micro_batch_transport_config, ) else: return MultiPacker( @@ -370,4 +394,5 @@ def setup_packer( transport_config, start_step, pack_multimodal, + micro_batch_transport_config, ) diff --git a/src/prime_rl/trainer/rl/train.py b/src/prime_rl/trainer/rl/train.py index 56c66937c7..a35959b722 100644 --- a/src/prime_rl/trainer/rl/train.py +++ b/src/prime_rl/trainer/rl/train.py @@ -71,6 +71,16 @@ from torchtitan.distributed.utils import clip_grad_norm_ +def _raise_if_any_rank_failed(local_error: Exception | None, message: str) -> None: + failed_flag = torch.tensor(1 if local_error else 0, dtype=torch.int64, device="cuda") + dist.all_reduce(failed_flag, op=dist.ReduceOp.MAX) + if failed_flag.item() == 0: + return + if local_error is not None: + raise RuntimeError(f"{message} on this rank; failing all ranks.") from local_error + raise RuntimeError(f"{message} on another rank; failing all ranks.") + + @clean_exit def train(config: TrainerConfig): # Setup world and logger @@ -271,6 +281,7 @@ def load_run_checkpoint(_optimizer, idx: int) -> None: # renderer=None still opts out (and a text-only run never gets mm_refs). 
renderer_config=config.renderer, pack_multimodal=pack_multimodal, + micro_batch_transport_config=config.micro_batch_transport, ) token_exporter = setup_token_exporter(config, parallel_dims, world, logger) @@ -362,7 +373,13 @@ def load_run_checkpoint(_optimizer, idx: int) -> None: # Wait for the batch to be available logger.debug("Waiting for training batch to arrive") wait_for_batch_start_time = time.perf_counter() - dataloader.wait_for_batch() + wait_error: Exception | None = None + try: + dataloader.wait_for_batch() + except Exception as exc: + wait_error = exc + _raise_if_any_rank_failed(wait_error, "Training-batch wait failed") + dataloader.synchronize_state() wait_for_batch_time = time.perf_counter() - wait_for_batch_start_time logger.debug(f"Waited for batch to arrive for {wait_for_batch_time:.2f} seconds") @@ -378,14 +395,9 @@ def load_run_checkpoint(_optimizer, idx: int) -> None: micro_batches = dataloader.get_batch() except Exception as exc: load_error = exc - failed_flag = torch.tensor(1 if load_error else 0, dtype=torch.int64, device="cuda") - dist.all_reduce(failed_flag, op=dist.ReduceOp.MAX) - if failed_flag.item() != 0: - # Preserve the culprit rank's traceback; bystander ranks still raise - # so none proceeds into the forward collective alone. - if load_error is not None: - raise RuntimeError("Training-batch load failed on this rank; failing all ranks.") from load_error - raise RuntimeError("Training-batch load failed on another rank; failing all ranks.") + # Preserve the culprit rank's traceback; bystander ranks still raise so + # none proceeds into the forward collective alone. 
+ _raise_if_any_rank_failed(load_error, "Training-batch load failed") load_data_time = time.perf_counter() - load_data_start_time logger.debug(f"Loaded batch in {load_data_time:.2f} seconds") diff --git a/src/prime_rl/transport/zmq.py b/src/prime_rl/transport/zmq.py index 5577b11e50..7b973ccc6e 100644 --- a/src/prime_rl/transport/zmq.py +++ b/src/prime_rl/transport/zmq.py @@ -1,5 +1,6 @@ +import os from pathlib import Path -from time import time +from time import monotonic, sleep, time import zmq @@ -11,6 +12,20 @@ LOG_FREQ_SECONDS = 10 +def _connect_host(transport: ZMQTransportConfig) -> str: + if transport.host and transport.host != "0.0.0.0": + return transport.host + return os.environ.get("MASTER_ADDR", "localhost") + + +def _bind_host() -> str: + return "0.0.0.0" + + +def _timeout_ms(seconds: int) -> int: + return int(seconds * 1000) + + class ZMQTrainingBatchSender(TrainingBatchSender): """ One PUSH socket; each message is multipart: [sender_id, payload] @@ -23,13 +38,14 @@ def __init__(self, output_dir: Path, transport: ZMQTransportConfig): self.context = zmq.Context.instance() self.socket: zmq.Socket = self.context.socket(zmq.PUSH) self.socket.setsockopt(zmq.SNDHWM, transport.hwm) - self.socket.connect(f"tcp://{transport.host}:{transport.port}") + connect_host = _connect_host(transport) + self.socket.connect(f"tcp://{connect_host}:{transport.port}") self.sender_id = output_dir.stem.encode("utf-8") self.logger.info( f"ZMQ training batch sender initialized: output_dir={output_dir} " - f"endpoint=tcp://{transport.host}:{transport.port} hwm={transport.hwm}" + f"endpoint=tcp://{connect_host}:{transport.port} hwm={transport.hwm}" ) async def send(self, batch: TrainingBatch) -> None: @@ -65,7 +81,8 @@ def __init__(self, transport: ZMQTransportConfig): self.context = zmq.Context.instance() self.socket: zmq.Socket = self.context.socket(zmq.PULL) self.socket.setsockopt(zmq.RCVHWM, transport.hwm) - self.socket.bind(f"tcp://{transport.host}:{transport.port}") + 
bind_host = _bind_host() + self.socket.bind(f"tcp://{bind_host}:{transport.port}") self.poller = zmq.Poller() self.poller.register(self.socket, zmq.POLLIN) @@ -75,7 +92,7 @@ def __init__(self, transport: ZMQTransportConfig): self._pending: dict[bytes, dict[int, TrainingBatch]] = {} self.logger.info( - f"ZMQ training batch receiver initialized: endpoint=tcp://{transport.host}:{transport.port} hwm={transport.hwm}" + f"ZMQ training batch receiver initialized: endpoint=tcp://{bind_host}:{transport.port} hwm={transport.hwm}" ) def can_receive(self) -> bool: @@ -175,22 +192,27 @@ def __init__(self, output_dir: Path, data_world_size: int, current_step: int, tr """ZMQ micro batch sender that sends micro batches to the trainers through ZMQ transport. There is one sender for the entire data world.""" super().__init__(output_dir, data_world_size) self.context = zmq.Context.instance() + self._ready_timeout_ms = _timeout_ms(transport.ready_timeout_seconds) + self._publish_grace_seconds = transport.publish_grace_ms / 1000.0 # Data channel (PUB) self.socket: zmq.Socket = self.context.socket(zmq.PUB) self.socket.setsockopt(zmq.SNDHWM, transport.hwm) - self.socket.bind(f"tcp://{transport.host}:{transport.port + 1}") + bind_host = _bind_host() + self.socket.bind(f"tcp://{bind_host}:{transport.port + 1}") # ready barrier socket, to avoid slow joiners dropping for step 0 (and generally at startup) self.ready_socket: zmq.Socket = self.context.socket(zmq.PULL) self.ready_socket.setsockopt(zmq.RCVHWM, transport.hwm) - self.ready_socket.bind(f"tcp://{transport.host}:{transport.port + 2}") + self.ready_socket.bind(f"tcp://{bind_host}:{transport.port + 2}") + self.ready_poller = zmq.Poller() + self.ready_poller.register(self.ready_socket, zmq.POLLIN) self._ready = False self.logger.info( - f"ZMQ micro batch sender initialized: endpoint=tcp://{transport.host}:{transport.port + 1} " - f"ready_endpoint=tcp://{transport.host}:{transport.port + 2} hwm={transport.hwm}" + f"ZMQ micro batch sender 
initialized: endpoint=tcp://{bind_host}:{transport.port + 1} " + f"ready_endpoint=tcp://{bind_host}:{transport.port + 2} hwm={transport.hwm}" ) self._topic_prefix = b"data_rank|" @@ -204,9 +226,19 @@ def _wait_for_ready(self) -> None: self.logger.debug(f"Waiting for {self.data_world_size} READY messages") ready_ranks: set[int] = set() - # Block until all ranks have announced readiness + deadline = monotonic() + self._ready_timeout_ms / 1000.0 while len(ready_ranks) < self.data_world_size: - msg = self.ready_socket.recv() # blocks + remaining_ms = max(0, int((deadline - monotonic()) * 1000)) + if remaining_ms == 0: + missing = sorted(set(range(self.data_world_size)) - ready_ranks) + raise TimeoutError( + f"Timed out waiting for ZMQ micro-batch READY messages from ranks {missing} " + f"after {self._ready_timeout_ms / 1000.0:.0f}s" + ) + events = dict(self.ready_poller.poll(timeout=remaining_ms)) + if self.ready_socket not in events: + continue + msg = self.ready_socket.recv(flags=zmq.NOBLOCK) try: rank = int(msg.decode("utf-8")) except Exception: @@ -214,6 +246,8 @@ def _wait_for_ready(self) -> None: ready_ranks.add(rank) self.logger.debug(f"All {self.data_world_size} ranks READY, starting PUB") + if self._publish_grace_seconds > 0: + sleep(self._publish_grace_seconds) self._ready = True def send(self, micro_batch_grid: list[list[MicroBatch]]) -> None: @@ -229,7 +263,8 @@ def send(self, micro_batch_grid: list[list[MicroBatch]]) -> None: for data_rank in range(self.data_world_size): buffer = self.encoder.encode(micro_batch_grid[data_rank]) topic = self._topic_prefix + str(data_rank).encode("utf-8") + b"|" - self.socket.send_multipart([topic, buffer], copy=False) + step = str(self._current_step).encode("utf-8") + self.socket.send_multipart([topic, step, buffer], copy=False) self._current_step += 1 def close(self) -> None: @@ -245,10 +280,13 @@ def __init__(self, output_dir: Path, data_rank: int, current_step: int, transpor """ZMQ micro batch receiver that receives 
micro batches from the sender. There is one receiver per data rank.""" super().__init__(output_dir, data_rank) self.context = zmq.Context.instance() + self._recv_timeout_ms = _timeout_ms(transport.recv_timeout_seconds) self.socket: zmq.Socket = self.context.socket(zmq.SUB) self.socket.setsockopt(zmq.RCVHWM, transport.hwm) - self.socket.connect(f"tcp://{transport.host}:{transport.port + 1}") + self.socket.setsockopt(zmq.RCVTIMEO, self._recv_timeout_ms) + connect_host = _connect_host(transport) + self.socket.connect(f"tcp://{connect_host}:{transport.port + 1}") self._topic = b"data_rank|" + str(data_rank).encode("utf-8") + b"|" self.socket.setsockopt(zmq.SUBSCRIBE, self._topic) @@ -259,20 +297,26 @@ def __init__(self, output_dir: Path, data_rank: int, current_step: int, transpor # ready barrier socket, to avoid slow joiners dropping for step 0 (and generally at startup) self.ready_socket: zmq.Socket = self.context.socket(zmq.PUSH) self.ready_socket.setsockopt(zmq.SNDHWM, transport.hwm) - self.ready_socket.connect(f"tcp://{transport.host}:{transport.port + 2}") + self.ready_socket.connect(f"tcp://{connect_host}:{transport.port + 2}") # Announce readiness after connect+subscribe are set self.ready_socket.send(str(data_rank).encode("utf-8")) self.logger.info( - f"ZMQ micro batch receiver initialized: endpoint=tcp://{transport.host}:{transport.port + 1} " - f"ready_endpoint=tcp://{transport.host}:{transport.port + 2} hwm={transport.hwm}" + f"ZMQ micro batch receiver initialized: endpoint=tcp://{connect_host}:{transport.port + 1} " + f"ready_endpoint=tcp://{connect_host}:{transport.port + 2} hwm={transport.hwm} " + f"recv_timeout_seconds={transport.recv_timeout_seconds}" ) self._current_step = current_step def wait(self) -> None: - self.poller.poll(timeout=None) + events = dict(self.poller.poll(timeout=self._recv_timeout_ms)) + if self.socket not in events: + raise TimeoutError( + f"Timed out waiting for ZMQ micro-batch for data_rank={self.data_rank} " + 
f"step={self._current_step} after {self._recv_timeout_ms / 1000.0:.0f}s" + ) def can_receive(self) -> bool: events = dict(self.poller.poll(timeout=0)) @@ -280,7 +324,18 @@ def can_receive(self) -> bool: def receive(self) -> list[MicroBatch]: """Receive a micro batch from the trainer.""" - _, payload = self.socket.recv_multipart(copy=False) + try: + _, step_raw, payload = self.socket.recv_multipart(copy=False) + except zmq.Again as exc: + raise TimeoutError( + f"Timed out receiving ZMQ micro-batch payload for data_rank={self.data_rank} " + f"step={self._current_step} after {self._recv_timeout_ms / 1000.0:.0f}s" + ) from exc + step = int(bytes(step_raw).decode("utf-8")) + if step != self._current_step: + raise ValueError( + f"Received ZMQ micro-batch for step {step}, expected {self._current_step} (data_rank={self.data_rank})" + ) micro_batches: list[MicroBatch] = self.decoder.decode(payload) self.logger.debug(f"Received {len(micro_batches)} micro batches for step {self._current_step}") self._current_step += 1 diff --git a/tests/integration/test_zmq_microbatch_smoke.py b/tests/integration/test_zmq_microbatch_smoke.py new file mode 100644 index 0000000000..e0f0fc025c --- /dev/null +++ b/tests/integration/test_zmq_microbatch_smoke.py @@ -0,0 +1,168 @@ +from __future__ import annotations + +import asyncio +import json +import os +import random +import socket +import subprocess +import sys +from pathlib import Path + +import pytest +import tomli_w + +from prime_rl.transport.filesystem import FileSystemTrainingBatchSender +from prime_rl.transport.types import TrainingBatch, TrainingSample + + +def _flatten(values: list[list[float]]) -> list[float]: + return [item for row in values for item in row] + + +def _free_tcp_port() -> int: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: + sock.bind(("127.0.0.1", 0)) + return sock.getsockname()[1] + + +def _free_zmq_base_port() -> int: + for _ in range(100): + base = random.randint(30_000, 60_000) + sockets: 
list[socket.socket] = [] + try: + for port in (base + 1, base + 2): + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.bind(("0.0.0.0", port)) + sockets.append(sock) + except OSError: + pass + else: + return base + finally: + for sock in sockets: + sock.close() + raise RuntimeError("Could not find free ZMQ base port") + + +def _create_run(output_dir: Path) -> Path: + run_dir = output_dir / "run_zmq_smoke" + control_dir = run_dir / "control" + control_dir.mkdir(parents=True) + config = { + "model": {"name": "test-model"}, + "batch_size": 2, + "group_size": 1, + "env": [{"id": "test-env"}], + "sampling": {"temperature": 1.0}, + # test-model is synthetic; bypass model->renderer validation. + "renderer": "None", + } + with open(control_dir / "orch.toml", "wb") as f: + tomli_w.dump(config, f) + return run_dir + + +def _training_sample(prompt_id: int, completion_id: int, advantage: float) -> TrainingSample: + return TrainingSample( + prompt_ids=[prompt_id], + prompt_mask=[False], + completion_ids=[completion_id], + completion_mask=[True], + completion_logprobs=[-0.1], + completion_temperatures=[0.7], + env_name="test-env", + advantage=advantage, + reward=1.0, + ) + + +def test_dataloader_splits_filesystem_rollouts_and_zmq_micro_batches(tmp_path: Path) -> None: + run_dir = _create_run(tmp_path) + asyncio.run( + FileSystemTrainingBatchSender(run_dir).send( + TrainingBatch( + examples=[ + _training_sample(10, 11, 1.0), + _training_sample(20, 21, 2.0), + ], + step=0, + ) + ) + ) + + dist_port = _free_tcp_port() + zmq_base_port = _free_zmq_base_port() + script = Path(__file__).with_name("zmq_microbatch_smoke.py") + env = { + **os.environ, + "MASTER_ADDR": "127.0.0.1", + } + result = subprocess.run( + [ + sys.executable, + "-m", + "torch.distributed.run", + "--nproc_per_node=2", + "--master_addr=127.0.0.1", + "--master_port", + str(dist_port), + script.as_posix(), + tmp_path.as_posix(), + str(zmq_base_port), + ], + cwd=Path(__file__).parents[2], + env=env, + 
capture_output=True, + text=True, + timeout=60, + check=False, + ) + + assert result.returncode == 0, f"stdout:\n{result.stdout}\nstderr:\n{result.stderr}" + + outputs = [json.loads((tmp_path / "rank_outputs" / f"rank_{rank}.json").read_text()) for rank in range(2)] + expected = [ + { + "input_ids": [[10, 11]], + "position_ids": [[0, 1]], + "loss_mask": [[False, True]], + "advantages": [[1.0, 1.0]], + "rewards": [[1.0, 1.0]], + "inference_logprobs": [[0.0, -0.1]], + "temperatures": [[0.7, 0.7]], + "env_names": ["test-env", "test-env"], + "lora_num_tokens": [2], + "training_mode": "rl", + "mm_kwargs": None, + "mm_token_type_ids": None, + }, + { + "input_ids": [[20, 21]], + "position_ids": [[0, 1]], + "loss_mask": [[False, True]], + "advantages": [[2.0, 2.0]], + "rewards": [[1.0, 1.0]], + "inference_logprobs": [[0.0, -0.1]], + "temperatures": [[0.7, 0.7]], + "env_names": ["test-env", "test-env"], + "lora_num_tokens": [2], + "training_mode": "rl", + "mm_kwargs": None, + "mm_token_type_ids": None, + }, + ] + for actual, expected_rank in zip(outputs, expected, strict=True): + for key in ( + "input_ids", + "position_ids", + "loss_mask", + "env_names", + "lora_num_tokens", + "training_mode", + "mm_kwargs", + "mm_token_type_ids", + ): + assert actual[key] == expected_rank[key] + for key in ("advantages", "rewards", "inference_logprobs", "temperatures"): + assert _flatten(actual[key]) == pytest.approx(_flatten(expected_rank[key])) diff --git a/tests/integration/zmq_microbatch_smoke.py b/tests/integration/zmq_microbatch_smoke.py new file mode 100644 index 0000000000..b3f22d7c48 --- /dev/null +++ b/tests/integration/zmq_microbatch_smoke.py @@ -0,0 +1,78 @@ +from __future__ import annotations + +import json +import sys +from pathlib import Path + +import torch +import torch.distributed as dist + +import prime_rl.trainer.runs as runs +from prime_rl.configs.shared import FileSystemTransportConfig, ZMQTransportConfig +from prime_rl.trainer.rl.data import DataLoader +from 
prime_rl.trainer.runs import setup_multi_run_manager +from prime_rl.trainer.world import get_world, reset_world + + +def main() -> None: + output_dir = Path(sys.argv[1]) + zmq_base_port = int(sys.argv[2]) + + dist.init_process_group("gloo", init_method="env://") + reset_world() + runs._MULTI_RUN_MANAGER = None + + world = get_world() + setup_multi_run_manager(output_dir=output_dir, max_runs=1, device=torch.device("cpu")) + + loader = DataLoader( + output_dir=output_dir, + start_step=0, + dp_world_size=world.world_size, + seq_len=2, + pad_to_multiple_of=1, + tokenizer=None, + config=FileSystemTransportConfig(), + micro_batch_transport_config=ZMQTransportConfig( + host="127.0.0.1", + port=zmq_base_port, + recv_timeout_seconds=5, + ready_timeout_seconds=5, + publish_timeout_seconds=20, + publish_grace_ms=0, + ), + ) + + loader.wait_for_batch() + loader.synchronize_state() + batch = loader.get_batch() + assert len(batch) == 1 + micro_batch = batch[0] + + rank_output_dir = output_dir / "rank_outputs" + rank_output_dir.mkdir(exist_ok=True) + with open(rank_output_dir / f"rank_{world.rank}.json", "w") as f: + json.dump( + { + "input_ids": micro_batch["input_ids"].tolist(), + "position_ids": micro_batch["position_ids"].tolist(), + "loss_mask": micro_batch["loss_mask"].tolist(), + "advantages": micro_batch["advantages"].tolist(), + "rewards": micro_batch["rewards"].tolist() if micro_batch["rewards"] is not None else None, + "inference_logprobs": micro_batch["inference_logprobs"].tolist(), + "temperatures": micro_batch["temperatures"].tolist(), + "env_names": micro_batch["env_names"], + "lora_num_tokens": micro_batch["lora_num_tokens"].tolist(), + "training_mode": micro_batch["training_mode"], + "mm_kwargs": micro_batch["mm_kwargs"], + "mm_token_type_ids": micro_batch["mm_token_type_ids"], + }, + f, + ) + + dist.barrier() + dist.destroy_process_group() + + +if __name__ == "__main__": + main() diff --git a/tests/unit/test_configs.py b/tests/unit/test_configs.py index 
a4117b927e..1cabbe9d26 100644 --- a/tests/unit/test_configs.py +++ b/tests/unit/test_configs.py @@ -187,6 +187,14 @@ def test_removed_fused_lm_head_chunk_size_field_is_rejected(): TrainerModelConfig.model_validate({"fused_lm_head_chunk_size": "auto"}) +def test_trainer_splits_rollout_and_micro_batch_transport_defaults(): + config = TrainerConfig() + + assert config.rollout_transport.type == "filesystem" + assert config.micro_batch_transport.type == "zmq" + assert config.micro_batch_transport.hwm == 64 + + def test_orchestrator_vlm_requires_renderer(): with pytest.raises(ValidationError, match="orchestrator.renderer must be set when model.vlm is set"): OrchestratorConfig.model_validate( diff --git a/tests/unit/train/rl/test_packer.py b/tests/unit/train/rl/test_packer.py index 8ba9b0f55a..cddbb0288b 100644 --- a/tests/unit/train/rl/test_packer.py +++ b/tests/unit/train/rl/test_packer.py @@ -5,6 +5,7 @@ import tomli_w import torch import torch.distributed as dist +import torch.distributed.distributed_c10d as c10d import prime_rl.trainer.runs as runs from prime_rl.configs.shared import FileSystemTransportConfig @@ -251,3 +252,20 @@ def test_multipacker_pack_packs_mm_refs_within_each_run_when_enabled(tmp_path, m assert mb.position_ids == [0, 1, 0, 1] tagged = [i for i, n in enumerate(mb.lora_num_tokens) if n > 0] assert len(tagged) == 1 + + +def test_micro_batch_publish_status_round_trip(): + from prime_rl.trainer.rl.data import DataLoader + + loader = DataLoader.__new__(DataLoader) + loader._store = c10d._get_default_store() + loader._publish_timeout_seconds = 1 + loader._current_step = 987654 + + loader._publish_micro_batch_status(ok=True) + loader._wait_for_micro_batch_status() + + loader._current_step += 1 + loader._publish_micro_batch_status(ok=False, error="boom") + with pytest.raises(RuntimeError, match="boom"): + loader._wait_for_micro_batch_status() diff --git a/tests/unit/transport/test_zmq.py b/tests/unit/transport/test_zmq.py new file mode 100644 index 
0000000000..23cb29e3f5 --- /dev/null +++ b/tests/unit/transport/test_zmq.py @@ -0,0 +1,101 @@ +import random +import socket +from pathlib import Path + +import pytest + +from prime_rl.configs.shared import ZMQTransportConfig +from prime_rl.transport.types import MicroBatch +from prime_rl.transport.zmq import ZMQMicroBatchReceiver, ZMQMicroBatchSender + + +def _free_base_port() -> int: + for _ in range(100): + base = random.randint(30_000, 60_000) + sockets = [] + try: + for port in (base + 1, base + 2): + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.bind(("0.0.0.0", port)) + sockets.append(sock) + except OSError: + pass + else: + return base + finally: + for sock in sockets: + sock.close() + raise RuntimeError("Could not find free ZMQ base port") + + +def _micro_batch(token: int = 1) -> MicroBatch: + return MicroBatch( + input_ids=[token], + loss_mask=[True], + advantages=[1.0], + inference_logprobs=[0.0], + position_ids=[0], + temperatures=[1.0], + env_names=["test"], + lora_num_tokens=[1], + ) + + +def _transport(**overrides) -> ZMQTransportConfig: + config = {"host": "127.0.0.1", "port": _free_base_port()} + config.update(overrides) + return ZMQTransportConfig(**config) + + +def test_zmq_micro_batch_routes_each_rank_topic(tmp_path: Path): + transport = _transport(ready_timeout_seconds=2, recv_timeout_seconds=2) + sender = ZMQMicroBatchSender(tmp_path, data_world_size=2, current_step=7, transport=transport) + receiver_0 = ZMQMicroBatchReceiver(tmp_path, data_rank=0, current_step=7, transport=transport) + receiver_1 = ZMQMicroBatchReceiver(tmp_path, data_rank=1, current_step=7, transport=transport) + try: + sender.send([[_micro_batch(10), _micro_batch(11)], [_micro_batch(20), _micro_batch(21)]]) + receiver_0.wait() + receiver_1.wait() + out_0 = receiver_0.receive() + out_1 = receiver_1.receive() + finally: + receiver_0.close() + receiver_1.close() + sender.close() + + assert [micro_batch.input_ids for micro_batch in out_0] == [[10], [11]] + 
assert [micro_batch.input_ids for micro_batch in out_1] == [[20], [21]] + + +def test_zmq_micro_batch_receive_timeout(tmp_path: Path): + transport = _transport(recv_timeout_seconds=1) + receiver = ZMQMicroBatchReceiver(tmp_path, data_rank=0, current_step=0, transport=transport) + try: + with pytest.raises(TimeoutError, match="Timed out waiting for ZMQ micro-batch"): + receiver.wait() + finally: + receiver.close() + + +def test_zmq_micro_batch_ready_timeout(tmp_path: Path): + transport = _transport(ready_timeout_seconds=1) + sender = ZMQMicroBatchSender(tmp_path, data_world_size=1, current_step=0, transport=transport) + try: + with pytest.raises(TimeoutError, match="READY messages"): + sender.send([[_micro_batch()]]) + finally: + sender.close() + + +def test_zmq_micro_batch_step_mismatch_fails_fast(tmp_path: Path): + transport = _transport(ready_timeout_seconds=2, recv_timeout_seconds=2) + sender = ZMQMicroBatchSender(tmp_path, data_world_size=1, current_step=5, transport=transport) + receiver = ZMQMicroBatchReceiver(tmp_path, data_rank=0, current_step=6, transport=transport) + try: + sender.send([[_micro_batch()]]) + receiver.wait() + with pytest.raises(ValueError, match="expected 6"): + receiver.receive() + finally: + receiver.close() + sender.close() From 765fbb51a8f1a3951f63f3a6d866cd4b484c6210 Mon Sep 17 00:00:00 2001 From: Eli Gottlieb <78387377+eligotts@users.noreply.github.com> Date: Thu, 4 Jun 2026 03:33:24 +0000 Subject: [PATCH 19/31] fix(monitor): inline offloaded images for platform sample upload After image offload, logged samples carry file:// URIs pointing at local trainer/orchestrator disk, which the Prime platform can't resolve (broken images in the dashboard). When building the sample parquet, re-inline local image files as data: URLs so the dashboard can display them. Purely the platform sample-logging path; does not touch training, offload, or the materialize path. 
Guarded: file://+image-mime only, 2 MB cap, OSError -> skip (handles a swept/missing file), per-call dedup cache, and the original rollout is not mutated. Pairs path.as_uri() (offload) with urlparse/unquote (monitor) so file:// URIs round-trip. Co-Authored-By: Codex Co-Authored-By: Claude Opus 4.8 (1M context) --- src/prime_rl/orchestrator/trajectories.py | 8 +-- src/prime_rl/utils/monitor/prime.py | 69 +++++++++++++++++++- tests/unit/orchestrator/test_trajectories.py | 17 +++++ tests/unit/utils/test_prime_monitor.py | 42 ++++++++++++ 4 files changed, 129 insertions(+), 7 deletions(-) diff --git a/src/prime_rl/orchestrator/trajectories.py b/src/prime_rl/orchestrator/trajectories.py index ea39a352bf..6d49a83110 100644 --- a/src/prime_rl/orchestrator/trajectories.py +++ b/src/prime_rl/orchestrator/trajectories.py @@ -644,7 +644,7 @@ def offload_images_to_disk(rollouts: list[vf.RolloutOutput], output_dir: Path) - images_dir = (output_dir / mm_store.IMAGE_ASSET_SUBDIR).resolve() images_dir.mkdir(parents=True, exist_ok=True) - written: set[str] = set() + written: set[Path] = set() for output in rollouts: for step in output.get("trajectory", []): @@ -688,7 +688,7 @@ def offload_images_to_disk(rollouts: list[vf.RolloutOutput], output_dir: Path) - content_hash = hashlib.sha256(raw).hexdigest()[:16] path = images_dir / f"{content_hash}{ext}" - if content_hash not in written: + if path not in written: if not path.exists(): path.write_bytes(raw) else: @@ -700,7 +700,7 @@ def offload_images_to_disk(rollouts: list[vf.RolloutOutput], output_dir: Path) - path.touch() except OSError: pass - written.add(content_hash) - image_url["url"] = f"{_FILE_URL_PREFIX}{path}" + written.add(path) + image_url["url"] = path.as_uri() return len(written) diff --git a/src/prime_rl/utils/monitor/prime.py b/src/prime_rl/utils/monitor/prime.py index 657037c6f7..96933925b5 100644 --- a/src/prime_rl/utils/monitor/prime.py +++ b/src/prime_rl/utils/monitor/prime.py @@ -1,13 +1,16 @@ import asyncio +import 
base64 import io import json import math +import mimetypes import os import time from datetime import datetime, timezone from pathlib import Path from threading import Thread from typing import Any +from urllib.parse import unquote, urlparse import httpx import pyarrow as pa @@ -57,6 +60,8 @@ def _json(val: Any) -> str: _DROPPED_JSON_VALUE = object() +_FILE_URL_SCHEME = "file" +_MAX_INLINE_SAMPLE_IMAGE_BYTES = 2 * 1024 * 1024 def _drop_non_finite_json_values(value: Any, dropped_paths: list[str], path: str = "") -> Any: @@ -89,6 +94,63 @@ def _drop_non_finite_json_values(value: Any, dropped_paths: list[str], path: str return value +def _local_image_file_to_data_url( + url: str, + cache: dict[str, str | None], + max_bytes: int = _MAX_INLINE_SAMPLE_IMAGE_BYTES, +) -> str | None: + if url in cache: + return cache[url] + + parsed = urlparse(url) + if parsed.scheme != _FILE_URL_SCHEME or parsed.netloc not in ("", "localhost"): + cache[url] = None + return None + + path = Path(unquote(parsed.path)) + media_type = mimetypes.guess_type(path.name)[0] + if media_type is None or not media_type.startswith("image/"): + cache[url] = None + return None + + try: + if path.stat().st_size > max_bytes: + cache[url] = None + return None + encoded = base64.b64encode(path.read_bytes()).decode("ascii") + except OSError: + cache[url] = None + return None + + data_url = f"data:{media_type};base64,{encoded}" + cache[url] = data_url + return data_url + + +def _inline_local_image_urls(value: Any, cache: dict[str, str | None]) -> Any: + if isinstance(value, list): + return [_inline_local_image_urls(item, cache) for item in value] + + if not isinstance(value, dict): + return value + + inlined = {key: _inline_local_image_urls(item, cache) for key, item in value.items()} + image_url = inlined.get("image_url") + if not isinstance(image_url, dict): + return inlined + + url = image_url.get("url") + if not isinstance(url, str): + return inlined + + data_url = _local_image_file_to_data_url(url, cache) 
+ if data_url is None: + return inlined + + inlined["image_url"] = {**image_url, "url": data_url} + return inlined + + class PrimeMonitor(Monitor): """Logs to Prime Intellect API.""" @@ -332,6 +394,7 @@ def _rollouts_to_parquet_bytes(self, rollouts: list[vf.RolloutOutput], step: int """Convert rollouts directly to Parquet bytes for upload.""" now = datetime.now(timezone.utc) rows = [] + image_data_url_cache: dict[str, str | None] = {} for sample_id, rollout in enumerate(rollouts): prompt = rollout.get("prompt") @@ -366,9 +429,9 @@ def _rollouts_to_parquet_bytes(self, rollouts: list[vf.RolloutOutput], step: int "tag": "", "problem_id": problem_id, "sample_id": sample_id, - "prompt": json.dumps(prompt), - "completion": json.dumps(completion), - "trajectory": json.dumps(trajectory_data), + "prompt": json.dumps(_inline_local_image_urls(prompt, image_data_url_cache)), + "completion": json.dumps(_inline_local_image_urls(completion, image_data_url_cache)), + "trajectory": json.dumps(_inline_local_image_urls(trajectory_data, image_data_url_cache)), "answer": rollout.get("answer") or "", "env_name": rollout.get("env_name") or "", "task": rollout.get("task") or "", diff --git a/tests/unit/orchestrator/test_trajectories.py b/tests/unit/orchestrator/test_trajectories.py index d2826ae220..98b5b402cc 100644 --- a/tests/unit/orchestrator/test_trajectories.py +++ b/tests/unit/orchestrator/test_trajectories.py @@ -1317,6 +1317,23 @@ def test_offload_data_uri_writes_decoded_bytes(tmp_path): assert path.read_bytes() == raw +def test_offload_same_bytes_with_different_media_types_writes_both_files(tmp_path): + raw = b"same-image-bytes" + b64 = base64.b64encode(raw).decode("ascii") + png_rollout = _image_rollout(f"data:image/png;base64,{b64}") + jpg_rollout = _image_rollout(f"data:image/jpeg;base64,{b64}") + + n = offload_images_to_disk([png_rollout, jpg_rollout], tmp_path) + + assert n == 2 + png_path = Path(_step_image_url(png_rollout)[len("file://") :]) + jpg_path = 
Path(_step_image_url(jpg_rollout)[len("file://") :]) + assert png_path.suffix == ".png" + assert jpg_path.suffix == ".jpg" + assert png_path.read_bytes() == raw + assert jpg_path.read_bytes() == raw + + def test_offload_leaves_file_url_already_in_assets(tmp_path): images_dir = tmp_path / "assets" / "images" images_dir.mkdir(parents=True) diff --git a/tests/unit/utils/test_prime_monitor.py b/tests/unit/utils/test_prime_monitor.py index f44065a7c6..2f5c138ed4 100644 --- a/tests/unit/utils/test_prime_monitor.py +++ b/tests/unit/utils/test_prime_monitor.py @@ -1,5 +1,7 @@ +import base64 import io import json +from pathlib import Path from unittest.mock import Mock import pyarrow.parquet as pq @@ -94,6 +96,46 @@ def test_rollouts_to_parquet_bytes_skips_rollouts_without_trajectory(): assert rows[0]["sample_id"] == 0 +def test_rollouts_to_parquet_bytes_inlines_local_image_urls_without_mutating(tmp_path: Path): + monitor = _new_monitor() + monitor.run_id = "run-images" + image_path = tmp_path / "sample.jpg" + image_bytes = b"jpeg-bytes" + image_path.write_bytes(image_bytes) + file_url = image_path.as_uri() + rollout = _build_rollout(example_id=1, reward=1.0, task="image-task") + image_part = {"type": "image_url", "image_url": {"url": file_url}} + rollout["prompt"] = [{"role": "user", "content": [image_part]}] + rollout["completion"] = [{"role": "assistant", "content": [image_part]}] + rollout["trajectory"][0]["prompt"] = [{"role": "user", "content": [image_part]}] + + parquet_bytes = monitor._rollouts_to_parquet_bytes([rollout], step=9) + + assert parquet_bytes is not None + row = pq.read_table(io.BytesIO(parquet_bytes)).to_pylist()[0] + expected_url = "data:image/jpeg;base64," + base64.b64encode(image_bytes).decode("ascii") + assert json.loads(row["prompt"])[0]["content"][0]["image_url"]["url"] == expected_url + assert json.loads(row["completion"])[0]["content"][0]["image_url"]["url"] == expected_url + assert 
json.loads(row["trajectory"])[0]["prompt"][0]["content"][0]["image_url"]["url"] == expected_url + assert rollout["prompt"][0]["content"][0]["image_url"]["url"] == file_url + + +def test_rollouts_to_parquet_bytes_leaves_large_local_image_urls(tmp_path: Path): + monitor = _new_monitor() + monitor.run_id = "run-large-image" + image_path = tmp_path / "large.png" + image_path.write_bytes(b"x" * (2 * 1024 * 1024 + 1)) + file_url = image_path.as_uri() + rollout = _build_rollout(example_id=1, reward=1.0, task="image-task") + rollout["prompt"] = [{"role": "user", "content": [{"type": "image_url", "image_url": {"url": file_url}}]}] + + parquet_bytes = monitor._rollouts_to_parquet_bytes([rollout], step=9) + + assert parquet_bytes is not None + row = pq.read_table(io.BytesIO(parquet_bytes)).to_pylist()[0] + assert json.loads(row["prompt"])[0]["content"][0]["image_url"]["url"] == file_url + + def test_sanitize_json_payload_drops_non_finite_values_and_logs_paths(): monitor = _new_monitor() monitor.logger = Mock() From d9bac3bba85e9d65c64020048cc6eeab39d03c17 Mon Sep 17 00:00:00 2001 From: Eli Gottlieb <78387377+eligotts@users.noreply.github.com> Date: Thu, 4 Jun 2026 03:48:43 +0000 Subject: [PATCH 20/31] chore(transport): raise default publish_grace_ms 100 -> 1000 Conservative one-time startup grace for the first prod runs: gives PUB/SUB subscriptions ample time to propagate before the first publish, avoiding step-0 slow-joiner drops (which would crash-loop at startup). It is a one-time cost at job start, so erring large is essentially free; dial down once a topology is observed to start cleanly. 
Co-Authored-By: Codex Co-Authored-By: Claude Opus 4.8 (1M context) --- packages/prime-rl-configs/src/prime_rl/configs/shared.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/packages/prime-rl-configs/src/prime_rl/configs/shared.py b/packages/prime-rl-configs/src/prime_rl/configs/shared.py index 5a4c6e1e12..50a8811f32 100644 --- a/packages/prime-rl-configs/src/prime_rl/configs/shared.py +++ b/packages/prime-rl-configs/src/prime_rl/configs/shared.py @@ -269,8 +269,10 @@ class ZMQTransportConfig(BaseTransportConfig): publish_timeout_seconds: int = Field(1800, ge=1) """Seconds ranks wait for the master to publish/fail a packed micro-batch step.""" - publish_grace_ms: int = Field(100, ge=0) - """Small startup grace after all READY messages arrive, reducing PUB/SUB slow-joiner races.""" + publish_grace_ms: int = Field(1000, ge=0) + """One-time startup grace after all READY messages arrive, before the first publish, to let + PUB/SUB subscriptions propagate and avoid step-0 slow-joiner drops. Conservative by default + since it is a one-time cost; lower it once a topology is observed to start cleanly.""" TransportConfig: TypeAlias = Annotated[FileSystemTransportConfig | ZMQTransportConfig, Field(discriminator="type")] From f7a7fbc687a0cf5f9fc7ef89c349d9f00c29c361 Mon Sep 17 00:00:00 2001 From: Eli Gottlieb <78387377+eligotts@users.noreply.github.com> Date: Thu, 4 Jun 2026 18:24:13 +0000 Subject: [PATCH 21/31] fix(transport): remove generation-bound publish timeout from micro-batch wait MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The worker wait on the master's micro-batch publish status was bounded by a fixed timeout (publish_timeout_seconds=1800), but the master sets that status only after pack() — which blocks on the orchestrator for generation. 
Slow multimodal steps (observed 1178-3137s) legitimately exceed 1800s, so every rank timed out -> coordinated crash -> deterministic crash-loop on the slow step. The old filesystem path had no such timeout and simply idled. Wait on the publish key with no deadline instead. Liveness is already covered: a wedged master is killed by the packer watchdog (-> torchrun tears down the group) and a master pack error sets ok=False for an immediate coordinated fail. Genuine ZMQ delivery stays bounded by the receiver's recv_timeout once published. Net deletes the publish_timeout_seconds knob and its plumbing. Co-Authored-By: Codex Co-Authored-By: Claude Opus 4.8 (1M context) --- .../src/prime_rl/configs/shared.py | 3 --- src/prime_rl/trainer/rl/data.py | 25 ++++++++----------- tests/integration/zmq_microbatch_smoke.py | 1 - tests/unit/train/rl/test_packer.py | 1 - 4 files changed, 11 insertions(+), 19 deletions(-) diff --git a/packages/prime-rl-configs/src/prime_rl/configs/shared.py b/packages/prime-rl-configs/src/prime_rl/configs/shared.py index 50a8811f32..c1f096dc89 100644 --- a/packages/prime-rl-configs/src/prime_rl/configs/shared.py +++ b/packages/prime-rl-configs/src/prime_rl/configs/shared.py @@ -266,9 +266,6 @@ class ZMQTransportConfig(BaseTransportConfig): ready_timeout_seconds: int = Field(300, ge=1) """Seconds the micro-batch sender waits at startup for rank READY messages before failing fast.""" - publish_timeout_seconds: int = Field(1800, ge=1) - """Seconds ranks wait for the master to publish/fail a packed micro-batch step.""" - publish_grace_ms: int = Field(1000, ge=0) """One-time startup grace after all READY messages arrive, before the first publish, to let PUB/SUB subscriptions propagate and avoid step-0 slow-joiner drops. 
Conservative by default diff --git a/src/prime_rl/trainer/rl/data.py b/src/prime_rl/trainer/rl/data.py index 73d1c1f11d..f7067e9f32 100644 --- a/src/prime_rl/trainer/rl/data.py +++ b/src/prime_rl/trainer/rl/data.py @@ -1,6 +1,5 @@ import pickle import time -from datetime import timedelta from pathlib import Path from typing import TypedDict @@ -18,7 +17,8 @@ from prime_rl.transport import MicroBatch, MicroBatchReceiver, TransportConfig, setup_micro_batch_receiver from prime_rl.utils.logger import get_logger -DEFAULT_MICRO_BATCH_PUBLISH_TIMEOUT_SECONDS = 1800 +# Poll interval for the worker wait on the master's micro-batch publish status. +_PUBLISH_POLL_SECONDS = 1.0 class TensorMicroBatch(TypedDict): @@ -182,11 +182,6 @@ def __init__( self.world = get_world() self._current_step = start_step self._micro_batch_transport_config = micro_batch_transport_config or config - self._publish_timeout_seconds = getattr( - self._micro_batch_transport_config, - "publish_timeout_seconds", - DEFAULT_MICRO_BATCH_PUBLISH_TIMEOUT_SECONDS, - ) self._store = c10d._get_default_store() if self.world.is_master: @@ -231,14 +226,16 @@ def _publish_micro_batch_status(self, *, ok: bool, error: str = "") -> None: self._store.set(self._publish_status_key(), pickle.dumps({"ok": ok, "error": error})) def _wait_for_micro_batch_status(self) -> None: + # No deadline here: packing is generation-bound (the master blocks on the + # orchestrator) and can legitimately take arbitrarily long, so a fixed + # timeout would crash the run on slow generation rather than a real fault. + # Liveness is already covered: a wedged master is killed by the packer + # watchdog -> torchrun tears down the group, and a master pack error sets + # ok=False below for an immediate coordinated fail. Genuine ZMQ delivery + # is still bounded by the receiver's recv_timeout once published. 
key = self._publish_status_key() - try: - self._store.wait([key], timedelta(seconds=self._publish_timeout_seconds)) - except Exception as exc: - raise TimeoutError( - f"Timed out waiting for trainer master to publish micro-batch step {self._current_step} " - f"after {self._publish_timeout_seconds}s" - ) from exc + while not self._store.check([key]): + time.sleep(_PUBLISH_POLL_SECONDS) status = pickle.loads(self._store.get(key)) if not status.get("ok", False): diff --git a/tests/integration/zmq_microbatch_smoke.py b/tests/integration/zmq_microbatch_smoke.py index b3f22d7c48..d6061c5ccf 100644 --- a/tests/integration/zmq_microbatch_smoke.py +++ b/tests/integration/zmq_microbatch_smoke.py @@ -38,7 +38,6 @@ def main() -> None: port=zmq_base_port, recv_timeout_seconds=5, ready_timeout_seconds=5, - publish_timeout_seconds=20, publish_grace_ms=0, ), ) diff --git a/tests/unit/train/rl/test_packer.py b/tests/unit/train/rl/test_packer.py index cddbb0288b..35623daab1 100644 --- a/tests/unit/train/rl/test_packer.py +++ b/tests/unit/train/rl/test_packer.py @@ -259,7 +259,6 @@ def test_micro_batch_publish_status_round_trip(): loader = DataLoader.__new__(DataLoader) loader._store = c10d._get_default_store() - loader._publish_timeout_seconds = 1 loader._current_step = 987654 loader._publish_micro_batch_status(ok=True) From 52f4bd8603b0d7995b224bfdfd96dfc96fb9aef5 Mon Sep 17 00:00:00 2001 From: Eli Gottlieb <78387377+eligotts@users.noreply.github.com> Date: Fri, 5 Jun 2026 02:12:38 +0000 Subject: [PATCH 22/31] fix(monitor): stream multimodal sample uploads from disk Large multimodal sample parquets (inlined base64 images) were PUT to R2 as one in-RAM body, failing on a memory-tight orchestrator with WriteTimeout / OpenSSL [BUF] malloc failure. 
Gated on VLM runs (run_config model/student.model is_vlm), write the parquet straight to a disk-backed temp file (co-located with output_dir to avoid tmpfs) and stream it: an async byte iterator with an explicit Content-Length keeps the presigned PUT on Content-Length (httpx drops chunked transfer-encoding) and bounds each TLS write; a generous write timeout replaces the flat 30s; concurrent streamed uploads are serialized; the temp file is unlinked on every exit. Text runs keep the inline byte PUT unchanged. Co-Authored-By: Codex Co-Authored-By: Claude Opus 4.8 (1M context) --- src/prime_rl/utils/monitor/prime.py | 145 ++++++++++++++++++------- tests/unit/utils/test_prime_monitor.py | 42 +++++++ 2 files changed, 150 insertions(+), 37 deletions(-) diff --git a/src/prime_rl/utils/monitor/prime.py b/src/prime_rl/utils/monitor/prime.py index 96933925b5..a62133da25 100644 --- a/src/prime_rl/utils/monitor/prime.py +++ b/src/prime_rl/utils/monitor/prime.py @@ -1,10 +1,12 @@ import asyncio import base64 +import contextlib import io import json import math import mimetypes import os +import tempfile import time from datetime import datetime, timezone from pathlib import Path @@ -63,6 +65,33 @@ def _json(val: Any) -> str: _FILE_URL_SCHEME = "file" _MAX_INLINE_SAMPLE_IMAGE_BYTES = 2 * 1024 * 1024 +# Multimodal sample parquets carry inlined base64 images and can be large; stream them +# from disk with a generous write window instead of buffering the whole body in RAM. +_R2_UPLOAD_TIMEOUT = httpx.Timeout(connect=30.0, write=600.0, read=300.0, pool=30.0) + + +async def _aiter_handle(f, chunk_size: int = 65536): + """Async byte iterator over an open binary file. 
Async (not a sync file object) so an + AsyncClient can send it; small chunks bound each TLS write to avoid large allocations.""" + while chunk := f.read(chunk_size): + yield chunk + + +def _table_to_parquet_bytes(table: "pa.Table") -> bytes: + buf = io.BytesIO() + pq.write_table(table, buf, compression="snappy", use_dictionary=True, write_statistics=True) + return buf.getvalue() + + +def _run_config_is_multimodal(run_config: Any) -> bool: + """True when the run trains a VLM. Handles trainer (``model``) and orchestrator + (``student.model``) config shapes; defaults False for anything else.""" + candidates = ( + getattr(run_config, "model", None), + getattr(getattr(run_config, "student", None), "model", None), + ) + return any(getattr(model, "is_vlm", False) for model in candidates if model is not None) + def _drop_non_finite_json_values(value: Any, dropped_paths: list[str], path: str = "") -> Any: if isinstance(value, float) and not math.isfinite(value): @@ -182,6 +211,7 @@ def __init__( self.history: list[dict[str, Any]] = [] self._keep_full_history = keep_full_history self.output_dir = output_dir + self._is_multimodal = _run_config_is_multimodal(run_config) self._registered = False self._finalized = False self._closed = False @@ -375,23 +405,30 @@ def log_samples(self, rollouts: list[vf.RolloutOutput], step: int) -> None: self.logger.info(f"Logging {len(rollouts)} samples to Prime Intellect API at step {step}") start_time = time.perf_counter() - parquet_bytes = self._rollouts_to_parquet_bytes(rollouts, step) - - if not parquet_bytes: + table = self._rollouts_to_parquet_table(rollouts, step) + if table is None: self.logger.warning(f"No samples to log at step {step}") return - self._pending_sample_steps.add(step) + # Multimodal parquets inline base64 images and can be large: write straight to a + # disk-backed temp file and stream the upload, so the serialized body is never held + # in RAM as bytes and nothing is retained during the async upload. 
Text runs stay inline. + if self._is_multimodal: + body: bytes | Path = self._spill_table_to_tempfile(table, step) + else: + body = _table_to_parquet_bytes(table) + del table - # Use presigned URL flow for uploading samples - self._upload_samples_via_presigned_url(parquet_bytes, step) + # Mark pending only once the body is built and ready to schedule. + self._pending_sample_steps.add(step) + self._upload_samples_via_presigned_url(body, step) self.logger.debug( f"Initiated samples upload at step {step} to Prime Intellect API in {time.perf_counter() - start_time:.2f}s" ) - def _rollouts_to_parquet_bytes(self, rollouts: list[vf.RolloutOutput], step: int) -> bytes | None: - """Convert rollouts directly to Parquet bytes for upload.""" + def _rollouts_to_parquet_table(self, rollouts: list[vf.RolloutOutput], step: int) -> "pa.Table | None": + """Build the sample Parquet table from rollouts (None when there is nothing to log).""" now = datetime.now(timezone.utc) rows = [] image_data_url_cache: dict[str, str | None] = {} @@ -449,15 +486,26 @@ def _rollouts_to_parquet_bytes(self, rollouts: list[vf.RolloutOutput], step: int if not rows: return None - table = pa.Table.from_pylist(rows, schema=_SAMPLE_SCHEMA) - buf = io.BytesIO() - pq.write_table(table, buf, compression="snappy", use_dictionary=True, write_statistics=True) - return buf.getvalue() + return pa.Table.from_pylist(rows, schema=_SAMPLE_SCHEMA) - def _upload_samples_via_presigned_url(self, parquet_bytes: bytes, step: int) -> None: + def _rollouts_to_parquet_bytes(self, rollouts: list[vf.RolloutOutput], step: int) -> bytes | None: + """Convert rollouts to in-memory Parquet bytes (inline upload path).""" + table = self._rollouts_to_parquet_table(rollouts, step) + return None if table is None else _table_to_parquet_bytes(table) + + def _spill_table_to_tempfile(self, table: "pa.Table", step: int) -> Path: + """Write the sample Parquet straight to a disk-backed temp file, co-located with the + run's output dir to avoid a 
tmpfs ``/tmp``. The upload coroutine unlinks it when done.""" + directory = str(self.output_dir) if self.output_dir else None + fd, name = tempfile.mkstemp(prefix=f"prime_samples_step{step}_", suffix=".parquet", dir=directory) + with os.fdopen(fd, "wb") as f: + pq.write_table(table, f, compression="snappy", use_dictionary=True, write_statistics=True) + return Path(name) + + def _upload_samples_via_presigned_url(self, body: bytes | Path, step: int) -> None: """Upload Parquet samples using presigned URL flow (fire-and-forget).""" future = asyncio.run_coroutine_threadsafe( - self._upload_samples_via_presigned_url_async(parquet_bytes, step), + self._upload_samples_via_presigned_url_async(body, step), self._loop, ) self._pending_futures.append(future) @@ -465,36 +513,42 @@ def _upload_samples_via_presigned_url(self, parquet_bytes: bytes, step: int) -> self._pending_futures = [f for f in self._pending_futures if not f.done()] async def _upload_samples_via_presigned_url_async( - self, parquet_bytes: bytes, step: int, max_retries: int = 3 + self, body: bytes | Path, step: int, max_retries: int = 3 ) -> None: - """Upload Parquet bytes via presigned URL flow.""" + """Upload Parquet samples via presigned URL flow. A ``Path`` body (multimodal runs) + is streamed from disk under a concurrency gate; a ``bytes`` body is sent inline.""" + # Streamed disk uploads are serialized; inline byte uploads keep their prior behavior. 
+ gate = self._upload_semaphore if isinstance(body, Path) else contextlib.nullcontext() try: - presign_data = await self._request_presigned_url(step) - if not presign_data: - self.logger.warning(f"Failed to get presigned URL for samples at step {step}") - return + async with gate: + presign_data = await self._request_presigned_url(step) + if not presign_data: + self.logger.warning(f"Failed to get presigned URL for samples at step {step}") + return - presigned_url = presign_data["presigned_url"] - s3_key = presign_data["s3_key"] + presigned_url = presign_data["presigned_url"] + s3_key = presign_data["s3_key"] - upload_success = await self._upload_to_r2( - presigned_url, parquet_bytes, content_type="application/parquet", max_retries=max_retries - ) - if not upload_success: - self.logger.warning(f"Failed to upload samples to R2 at step {step}") - return + upload_success = await self._upload_to_r2( + presigned_url, body, content_type="application/parquet", max_retries=max_retries + ) + if not upload_success: + self.logger.warning(f"Failed to upload samples to R2 at step {step}") + return - confirm_success = await self._confirm_samples_upload(step, s3_key) - if not confirm_success: - self.logger.warning(f"Failed to confirm samples upload at step {step}") - return + confirm_success = await self._confirm_samples_upload(step, s3_key) + if not confirm_success: + self.logger.warning(f"Failed to confirm samples upload at step {step}") + return - self.last_log_samples_step = step - self.logger.debug(f"Successfully completed samples upload at step {step}") + self.last_log_samples_step = step + self.logger.debug(f"Successfully completed samples upload at step {step}") except Exception as e: self.logger.warning(f"Failed to upload samples via presigned URL at step {step}: {type(e).__name__}: {e}") finally: + if isinstance(body, Path): + body.unlink(missing_ok=True) self._pending_sample_steps.discard(step) async def _request_presigned_url(self, step: int) -> dict[str, Any] | None: 
@@ -516,12 +570,24 @@ async def _request_presigned_url(self, step: int) -> dict[str, Any] | None: return None async def _upload_to_r2( - self, presigned_url: str, data: bytes, content_type: str = "application/json", max_retries: int = 3 + self, presigned_url: str, body: bytes | Path, content_type: str = "application/json", max_retries: int = 3 ) -> bool: - """Upload data to R2 using presigned URL.""" + """Upload data to R2 using presigned URL. A ``Path`` is streamed from disk in chunks + (an explicit Content-Length keeps the presigned PUT on Content-Length rather than + chunked transfer-encoding, and bounds each TLS write); ``bytes`` is sent inline.""" for attempt in range(max_retries): + f = None try: - response = await self._client.put(presigned_url, content=data, headers={"Content-Type": content_type}) + if isinstance(body, Path): + f = open(body, "rb") # closed in finally; fresh handle rewinds on retry + headers = {"Content-Type": content_type, "Content-Length": str(os.fstat(f.fileno()).st_size)} + response = await self._client.put( + presigned_url, content=_aiter_handle(f), headers=headers, timeout=_R2_UPLOAD_TIMEOUT + ) + else: + response = await self._client.put( + presigned_url, content=body, headers={"Content-Type": content_type} + ) response.raise_for_status() return True except Exception as e: @@ -531,6 +597,9 @@ async def _upload_to_r2( delay = 2**attempt self.logger.debug(f"Retrying R2 upload in {delay}s (attempt {attempt + 1}/{max_retries})") await asyncio.sleep(delay) + finally: + if f is not None: + f.close() async def _confirm_samples_upload(self, step: int, s3_key: str, max_retries: int = 3) -> bool: """Confirm samples upload with the backend. Returns True on success.""" @@ -677,6 +746,8 @@ def _init_async_client(self) -> None: self._thread.start() self._client = httpx.AsyncClient(timeout=30) self._pending_futures: list[asyncio.Future] = [] + # Serialize multimodal sample uploads so large disk-streamed bodies don't pile up. 
+ self._upload_semaphore = asyncio.Semaphore(1) if hasattr(self, "_pending_sample_steps") and self._pending_sample_steps: self._pending_sample_steps.clear() diff --git a/tests/unit/utils/test_prime_monitor.py b/tests/unit/utils/test_prime_monitor.py index 2f5c138ed4..e5793b4a9f 100644 --- a/tests/unit/utils/test_prime_monitor.py +++ b/tests/unit/utils/test_prime_monitor.py @@ -1,9 +1,11 @@ +import asyncio import base64 import io import json from pathlib import Path from unittest.mock import Mock +import httpx import pyarrow.parquet as pq from prime_rl.utils.monitor.prime import PrimeMonitor @@ -136,6 +138,46 @@ def test_rollouts_to_parquet_bytes_leaves_large_local_image_urls(tmp_path: Path) assert json.loads(row["prompt"])[0]["content"][0]["image_url"]["url"] == file_url +def test_upload_to_r2_streams_path_with_content_length_and_retry_rewind(tmp_path: Path, monkeypatch): + """The multimodal upload path: a Path body is streamed via an AsyncClient with an explicit + Content-Length (no chunked transfer-encoding), and a retry re-sends the full body.""" + monitor = _new_monitor() + monitor.logger = Mock() + + payload = b"PARQUET-BODY" * 50_000 # ~600 KB, spans many 64 KB stream chunks + path = tmp_path / "samples.parquet" + path.write_bytes(payload) + + seen: list[tuple[dict, int]] = [] + + async def handler(request: httpx.Request) -> httpx.Response: + await request.aread() # consume the streamed body + seen.append((dict(request.headers), len(request.content))) + if len(seen) == 1: + raise httpx.ConnectError("transient") # force a retry to exercise rewind + return httpx.Response(200) + + async def _no_sleep(*_args, **_kwargs): + return None + + monkeypatch.setattr(asyncio, "sleep", _no_sleep) # skip retry backoff + monitor._client = httpx.AsyncClient(transport=httpx.MockTransport(handler)) + + async def run() -> bool: + try: + return await monitor._upload_to_r2("https://r2.example/key", path, content_type="application/parquet") + finally: + await 
monitor._client.aclose() + + assert asyncio.run(run()) is True + assert len(seen) == 2 # first attempt failed, second succeeded + + headers, body_len = seen[-1] + assert body_len == len(payload) # full body re-sent after rewind, not truncated + assert headers["content-length"] == str(len(payload)) + assert "transfer-encoding" not in headers # presigned PUT must stay on Content-Length + + def test_sanitize_json_payload_drops_non_finite_values_and_logs_paths(): monitor = _new_monitor() monitor.logger = Mock() From 468a09aad6c05ca3e7caa2e8ab9628c851310079 Mon Sep 17 00:00:00 2001 From: Eli Gottlieb <78387377+eligotts@users.noreply.github.com> Date: Fri, 5 Jun 2026 06:23:50 +0000 Subject: [PATCH 23/31] chore(deps): bump renderers + verifiers to merged-main submodule commits Both submodules merged origin/main on their feat/ephemeral-mm-pixels branches (renderers f7696cd, verifiers 00d83204); these supersede main's pins and carry our mm work on top. Co-Authored-By: Claude Opus 4.8 (1M context) --- deps/renderers | 2 +- deps/verifiers | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/deps/renderers b/deps/renderers index a8f874c416..f7696cd91e 160000 --- a/deps/renderers +++ b/deps/renderers @@ -1 +1 @@ -Subproject commit a8f874c416ecb155250db4ca1b732384018289af +Subproject commit f7696cd91e6279aa9bee9d9212534cc8e627ed06 diff --git a/deps/verifiers b/deps/verifiers index a7fc7431b4..00d8320433 160000 --- a/deps/verifiers +++ b/deps/verifiers @@ -1 +1 @@ -Subproject commit a7fc7431b41fb14bb069a5e0f68be24d402a11de +Subproject commit 00d832043321aad9e85e2a70143d9a670e0fee7b From e6998718ea6ceb8d4677e108dbef37965a733659 Mon Sep 17 00:00:00 2001 From: Eli Gottlieb <78387377+eligotts@users.noreply.github.com> Date: Fri, 5 Jun 2026 22:57:30 +0000 Subject: [PATCH 24/31] chore(deps): bump verifiers to c33261b9 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Picks up: no-VCS `fallback-version` so the Docker editable build 
resolves (fixes `uv sync` LookupError on verifiers), plus worldsims #1557 — multimodal tool content + reasoning_content passthrough in the in-sandbox base runner. Co-Authored-By: Claude Opus 4.8 (1M context) --- deps/verifiers | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deps/verifiers b/deps/verifiers index 00d8320433..c33261b959 160000 --- a/deps/verifiers +++ b/deps/verifiers @@ -1 +1 @@ -Subproject commit 00d832043321aad9e85e2a70143d9a670e0fee7b +Subproject commit c33261b95996754d44992889e1e466fec920f900 From 987d5ecd1f22dda29391337ae3d6db32f6795d5e Mon Sep 17 00:00:00 2001 From: Eli Gottlieb <78387377+eligotts@users.noreply.github.com> Date: Tue, 9 Jun 2026 05:01:32 +0000 Subject: [PATCH 25/31] feat: vLLM serving accepts + materializes mmraw raw-image refs Adds raw-image ref handling to the token serving route: parse mmraw refs, load the raw image from shared disk, reprocess via the HF image processor + vLLM field factory, and validate hash/fingerprint/grid/placeholder before caching. Bumps the renderers submodule to the matching emit-side commit. 
Co-Authored-By: Codex Co-Authored-By: Claude Opus 4.8 (1M context) --- deps/renderers | 2 +- src/prime_rl/inference/vllm/serving_tokens.py | 252 +++++++++++++++++- tests/unit/inference/test_serving_tokens.py | 53 +++- 3 files changed, 302 insertions(+), 5 deletions(-) diff --git a/deps/renderers b/deps/renderers index f7696cd91e..db5058e8e8 160000 --- a/deps/renderers +++ b/deps/renderers @@ -1 +1 @@ -Subproject commit f7696cd91e6279aa9bee9d9212534cc8e627ed06 +Subproject commit db5058e8e84ee716f662438bb8451367187b3463 diff --git a/src/prime_rl/inference/vllm/serving_tokens.py b/src/prime_rl/inference/vllm/serving_tokens.py index f92f296faf..a09956e656 100644 --- a/src/prime_rl/inference/vllm/serving_tokens.py +++ b/src/prime_rl/inference/vllm/serving_tokens.py @@ -36,6 +36,7 @@ import logging import os import re +import threading import time from collections.abc import AsyncGenerator from functools import cached_property @@ -49,10 +50,14 @@ _SAFE_MM_HASH_RE, _SAFE_RUN_ID_RE, MMFILE_PREFIX, + MMRAW_PREFIX, mm_feature_envelope_matches, mm_feature_fingerprint, mm_feature_path, + mm_processor_fingerprint, + raw_image_path, split_mmfile_ref, + split_mmraw_ref, ) from vllm.entrypoints.openai.engine.protocol import ErrorResponse, RequestResponseMetadata from vllm.entrypoints.serve.disagg.protocol import ( @@ -73,6 +78,8 @@ _MM_FEATURE_LOAD_RETRIES = 3 _MM_FEATURE_LOAD_BACKOFF_S = 0.02 _mm_feature_executor: concurrent.futures.ThreadPoolExecutor | None = None +_mm_raw_processors: dict[str, Any] = {} +_mm_raw_processors_lock = threading.Lock() class PrimeRlGenerateResponseChoice(GenerateResponseChoice): @@ -237,6 +244,56 @@ def _parse_mmfile_ref(ref: str, *, expected_modality: str, expected_hash: str) - return run_id, fingerprint, modality, mm_hash +def _parse_mmraw_ref( + ref: str, *, expected_modality: str, expected_hash: str +) -> tuple[str, str, str, str, str, list[int]]: + try: + run_id, fingerprint, modality, mm_hash, raw_image_id, grid_thw = split_mmraw_ref(ref) + 
except ValueError as exc: + raise _MMFeatureArtifactError( + error_type="invalid_mm_raw_ref", + message=f"Invalid mmraw ref shape for {expected_modality}.", + status_code=HTTPStatus.BAD_REQUEST, + ) from exc + if not _SAFE_RUN_ID_RE.fullmatch(run_id): + raise _MMFeatureArtifactError( + error_type="invalid_mm_raw_ref", + message="mmraw run_id contains unsafe characters.", + status_code=HTTPStatus.BAD_REQUEST, + ) + if modality != expected_modality: + raise _MMFeatureArtifactError( + error_type="invalid_mm_raw_ref", + message=(f"mmraw modality {modality!r} does not match slot modality {expected_modality!r}."), + status_code=HTTPStatus.BAD_REQUEST, + ) + if mm_hash != expected_hash: + raise _MMFeatureArtifactError( + error_type="invalid_mm_raw_ref", + message="mmraw hash does not match the slot mm_hash.", + status_code=HTTPStatus.BAD_REQUEST, + ) + if not _SAFE_FINGERPRINT_RE.fullmatch(fingerprint): + raise _MMFeatureArtifactError( + error_type="invalid_mm_raw_ref", + message="mmraw fingerprint contains unsafe characters.", + status_code=HTTPStatus.BAD_REQUEST, + ) + if modality != "image": + raise _MMFeatureArtifactError( + error_type="invalid_mm_raw_ref", + message=f"Unsupported mmraw modality: {modality!r}.", + status_code=HTTPStatus.BAD_REQUEST, + ) + if not _SAFE_MM_HASH_RE.fullmatch(mm_hash): + raise _MMFeatureArtifactError( + error_type="invalid_mm_raw_ref", + message="mmraw hash contains unsafe characters.", + status_code=HTTPStatus.BAD_REQUEST, + ) + return run_id, fingerprint, modality, mm_hash, raw_image_id, grid_thw + + def _mm_feature_path(*, run_id: str, fingerprint: str, modality: str, mm_hash: str) -> Path: # ``_parse_mmfile_ref`` validates run_id/fingerprint/modality/mm_hash and the # traversal guard before we reach here, so ``mm_store.mm_feature_path``'s @@ -251,7 +308,7 @@ def _mm_feature_path(*, run_id: str, fingerprint: str, modality: str, mm_hash: s ) from exc -def _decoded_image_placeholder_length(item: Any, *, spatial_merge_size: int) -> int: 
+def _decoded_image_grid_thw(item: Any) -> list[int]: elem = item.get("image_grid_thw") data = getattr(elem, "data", elem) if hasattr(data, "detach"): @@ -261,9 +318,49 @@ def _decoded_image_placeholder_length(item: Any, *, spatial_merge_size: int) -> grid = data[0] if isinstance(data, list) and data and isinstance(data[0], list) else data if not isinstance(grid, list) or len(grid) != 3: raise ValueError("decoded image_grid_thw does not have shape [T,H,W]") + return [int(grid[0]), int(grid[1]), int(grid[2])] + + +def _decoded_image_placeholder_length(item: Any, *, spatial_merge_size: int) -> int: + grid = _decoded_image_grid_thw(item) return int(grid[0]) * int(grid[1]) * int(grid[2]) // (spatial_merge_size**2) +def _processor_size_value(size: Any, key: str) -> int: + value = getattr(size, key, None) + if value is None and isinstance(size, dict): + value = size.get(key) + if value is None: + raise ValueError(f"image processor size missing {key!r}") + return int(value) + + +def _processor_fingerprint(processor: Any) -> tuple[str, int]: + image_processor = processor.image_processor + merge_size = int(getattr(image_processor, "merge_size")) + fingerprint = mm_processor_fingerprint( + family="qwen_vl", + patch_size=int(getattr(image_processor, "patch_size")), + merge_size=merge_size, + temporal_patch_size=int(getattr(image_processor, "temporal_patch_size")), + min_pixels=_processor_size_value(image_processor.size, "shortest_edge"), + max_pixels=_processor_size_value(image_processor.size, "longest_edge"), + ) + return fingerprint, merge_size + + +def _get_mm_raw_processor(model_name: str) -> Any: + with _mm_raw_processors_lock: + processor = _mm_raw_processors.get(model_name) + if processor is not None: + return processor + from transformers import AutoProcessor + + processor = AutoProcessor.from_pretrained(model_name) + with _mm_raw_processors_lock: + return _mm_raw_processors.setdefault(model_name, processor) + + def _load_mmfile_ref_sync( ref: str, *, @@ -337,6 
+434,115 @@ def _load_mmfile_ref_sync( ) from exc +def _load_mmraw_ref_sync( + ref: str, + *, + expected_modality: str, + expected_hash: str, + expected_placeholder_length: int, + processor_model_name: str, +): + from PIL import Image + from renderers.qwen3_vl import _image_hash + from vllm.model_executor.models.qwen2_vl import _create_qwen2vl_field_factory + from vllm.multimodal.inputs import MultiModalKwargsItems + + run_id, fingerprint, modality, mm_hash, raw_image_id, expected_grid = _parse_mmraw_ref( + ref, + expected_modality=expected_modality, + expected_hash=expected_hash, + ) + missing = [ + { + "run_id": run_id, + "modality": modality, + "mm_hash": mm_hash, + "fingerprint": fingerprint, + "raw_image_id": raw_image_id, + } + ] + try: + path = raw_image_path(run_id=run_id, raw_image_id=raw_image_id) + except ValueError as exc: + raise _MMFeatureArtifactError( + error_type="invalid_mm_raw_ref", + message=str(exc), + status_code=HTTPStatus.BAD_REQUEST, + ) from exc + + pil = None + for attempt in range(_MM_FEATURE_LOAD_RETRIES): + try: + with Image.open(path) as img: + pil = img.convert("RGB") + break + except FileNotFoundError: + if attempt + 1 == _MM_FEATURE_LOAD_RETRIES: + raise _MMFeatureArtifactError( + error_type="missing_mm_raw_image", + message=f"Missing mmraw image: {path}", + missing=missing, + ) from None + time.sleep(_MM_FEATURE_LOAD_BACKOFF_S * (attempt + 1)) + except Exception as exc: + raise _MMFeatureArtifactError( + error_type="corrupt_mm_raw_image", + message=f"Corrupt mmraw image for {modality}:{mm_hash}: {exc}", + missing=missing, + ) from exc + if pil is None: + raise _MMFeatureArtifactError( + error_type="missing_mm_raw_image", + message=f"Missing mmraw image: {path}", + missing=missing, + ) + + actual_hash = _image_hash(pil) + if actual_hash != mm_hash: + raise _MMFeatureArtifactError( + error_type="raw_mm_hash_mismatch", + message=f"mmraw image hash mismatch for {modality}:{mm_hash}; got {actual_hash}", + missing=missing, + 
status_code=HTTPStatus.BAD_REQUEST, + ) + + try: + processor = _get_mm_raw_processor(processor_model_name) + expected_fingerprint, merge_size = _processor_fingerprint(processor) + if fingerprint != expected_fingerprint: + raise _MMFeatureArtifactError( + error_type="incompatible_mm_raw_fingerprint", + message=( + "mmraw fingerprint is not compatible with this vLLM process " + f"(got {fingerprint}, expected {expected_fingerprint})." + ), + status_code=HTTPStatus.BAD_REQUEST, + ) + + hf_inputs = processor.image_processor(images=[pil], return_tensors="pt") + config = _create_qwen2vl_field_factory(merge_size)(hf_inputs) + item = MultiModalKwargsItems.from_hf_inputs(hf_inputs, config)["image"][0] + actual_grid = _decoded_image_grid_thw(item) + actual_placeholder_length = _decoded_image_placeholder_length(item, spatial_merge_size=merge_size) + if actual_grid != expected_grid: + raise ValueError(f"processed image_grid_thw {actual_grid!r} != ref {expected_grid!r}") + if actual_placeholder_length != expected_placeholder_length: + raise ValueError( + "processed image_grid_thw does not match placeholder length " + f"({actual_placeholder_length} != {expected_placeholder_length})" + ) + return item + except _MMFeatureArtifactError: + raise + except Exception as exc: + raise _MMFeatureArtifactError( + error_type="raw_mm_grid_mismatch", + message=f"mmraw materialization failed for {modality}:{mm_hash}: {exc}", + missing=missing, + status_code=HTTPStatus.BAD_REQUEST, + ) from exc + + async def _load_mmfile_ref( ref: str, *, @@ -356,6 +562,27 @@ async def _load_mmfile_ref( ) +async def _load_mmraw_ref( + ref: str, + *, + expected_modality: str, + expected_hash: str, + expected_placeholder_length: int, + processor_model_name: str, +): + loop = asyncio.get_running_loop() + return await loop.run_in_executor( + _get_mm_feature_executor(), + lambda: _load_mmraw_ref_sync( + ref, + expected_modality=expected_modality, + expected_hash=expected_hash, + 
expected_placeholder_length=expected_placeholder_length, + processor_model_name=processor_model_name, + ), + ) + + def _missing_cache_error_from_exception(exc: Exception, features: Any) -> _MMFeatureArtifactError | None: text = repr(exc) if "Expected a cached item" not in text: @@ -452,13 +679,31 @@ async def serve_tokens( for modality, ranges in features.mm_placeholders.items() } mm_kwargs: dict[str, list[MultiModalKwargsItem | None]] = {} - slot_counts = {"none": 0, "inline": 0, "mmfile": 0} + slot_counts = {"none": 0, "inline": 0, "mmfile": 0, "mmraw": 0} load_start = time.monotonic() + processor_model_name = str(getattr(self.model_config, "model", None) or model_name) async def decode_slot(modality: str, idx: int, item: str | None) -> MultiModalKwargsItem | None: if item is None: slot_counts["none"] += 1 return None + if item.startswith(f"{MMRAW_PREFIX}:"): + slot_counts["mmraw"] += 1 + hashes = features.mm_hashes.get(modality) or [] + placeholders = features.mm_placeholders.get(modality) or [] + if idx >= len(hashes) or idx >= len(placeholders): + raise _MMFeatureArtifactError( + error_type="invalid_mm_raw_ref", + message=("mmraw slot has no matching hash or placeholder entry."), + status_code=HTTPStatus.BAD_REQUEST, + ) + return await _load_mmraw_ref( + item, + expected_modality=modality, + expected_hash=hashes[idx], + expected_placeholder_length=placeholders[idx].length, + processor_model_name=processor_model_name, + ) if item.startswith(f"{MMFILE_PREFIX}:"): slot_counts["mmfile"] += 1 hashes = features.mm_hashes.get(modality) or [] @@ -499,10 +744,11 @@ async def decode_slot(modality: str, idx: int, item: str | None) -> MultiModalKw if any(slot_counts.values()): logger.debug( - "decoded multimodal feature slots none=%d inline=%d mmfile=%d disk_load_ms=%.2f", + "decoded multimodal feature slots none=%d inline=%d mmfile=%d mmraw=%d load_ms=%.2f", slot_counts["none"], slot_counts["inline"], slot_counts["mmfile"], + slot_counts["mmraw"], (time.monotonic() - 
load_start) * 1000.0, ) diff --git a/tests/unit/inference/test_serving_tokens.py b/tests/unit/inference/test_serving_tokens.py index 9ba720c106..3132ce77d8 100644 --- a/tests/unit/inference/test_serving_tokens.py +++ b/tests/unit/inference/test_serving_tokens.py @@ -15,7 +15,15 @@ import numpy as np import pybase64 import pytest -from renderers.mm_store import mm_feature_fingerprint as _mm_feature_fingerprint +from renderers.mm_store import ( + mm_feature_fingerprint as _mm_feature_fingerprint, +) +from renderers.mm_store import ( + mm_processor_fingerprint as _mm_processor_fingerprint, +) +from renderers.mm_store import ( + mmraw_ref, +) from vllm.entrypoints.serve.disagg.protocol import GenerateResponse, GenerateResponseChoice from prime_rl.inference.vllm.routed_experts import serialize_routed_experts @@ -24,6 +32,7 @@ _client_set_max_tokens, _GenerateRoutedExpertsCapture, _load_mmfile_ref_sync, + _load_mmraw_ref_sync, _missing_cache_error_from_exception, _MMFeatureArtifactError, ) @@ -163,6 +172,48 @@ def test_missing_mmfile_artifact_is_typed(tmp_path, monkeypatch): ] +def test_missing_mmraw_image_is_typed(tmp_path, monkeypatch): + monkeypatch.setenv("PRIME_RL_MM_FEATURE_ROOT", str(tmp_path)) + run_id = "testrun" + mm_hash = "a" * 32 + fingerprint = _mm_processor_fingerprint( + family="qwen_vl", + patch_size=14, + merge_size=2, + temporal_patch_size=2, + min_pixels=56 * 56, + max_pixels=14 * 14 * 4 * 1280, + ) + ref = mmraw_ref( + run_id=run_id, + fingerprint=fingerprint, + modality="image", + mm_hash=mm_hash, + raw_image_id="missing.png", + grid_thw=[[1, 2, 2]], + ) + + with pytest.raises(_MMFeatureArtifactError) as exc_info: + _load_mmraw_ref_sync( + ref, + expected_modality="image", + expected_hash=mm_hash, + expected_placeholder_length=1, + processor_model_name="unused-on-missing-file", + ) + + assert exc_info.value.error_type == "missing_mm_raw_image" + assert exc_info.value.missing == [ + { + "run_id": run_id, + "modality": "image", + "mm_hash": mm_hash, + 
"fingerprint": fingerprint, + "raw_image_id": "missing.png", + } + ] + + def test_mmfile_artifact_round_trips_vllm_serde(tmp_path, monkeypatch): torch = pytest.importorskip("torch") pytest.importorskip("vllm") From 7414cd7b675a74629e33904ce220a5406fd69d43 Mon Sep 17 00:00:00 2001 From: Eli Gottlieb <78387377+eligotts@users.noreply.github.com> Date: Tue, 9 Jun 2026 05:29:25 +0000 Subject: [PATCH 26/31] fix(orchestrator): default train env num_workers to 32 The auto formula (max(1, ceil(max_inflight_rollouts/256))) gave a single env worker for any max_inflight <= 256, so one worker absorbed all in-flight rollouts; its event loop saturated during concurrent sandbox setup and missed the 30s heartbeat -> EnvRouter restart loop (surfaces as 'env server unhealthy'). Fix the train resolver to a fixed pool of 32 (explicit per-env/group values still win) so setup load spreads across workers. Requires bumping the orchestrator pod's memory/CPU to host 32 workers. Co-Authored-By: Codex Co-Authored-By: Claude Opus 4.8 (1M context) --- .../src/prime_rl/configs/orchestrator.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/packages/prime-rl-configs/src/prime_rl/configs/orchestrator.py b/packages/prime-rl-configs/src/prime_rl/configs/orchestrator.py index c6fd6da534..fc92b8b3e5 100644 --- a/packages/prime-rl-configs/src/prime_rl/configs/orchestrator.py +++ b/packages/prime-rl-configs/src/prime_rl/configs/orchestrator.py @@ -159,8 +159,8 @@ class EnvConfig(BaseConfig): address: str | None = None """ZMQ address of an external env server (e.g. ``tcp://host:5000``). When set, the orchestrator connects to this server instead of spawning one; when None, a subprocess env server is spawned automatically.""" - num_workers: int | Literal["auto"] = "auto" - """Worker processes for the spawned env server. ``auto`` scales to 1 worker per 256 concurrent rollouts. 
Ignored when ``address`` is set.""" + num_workers: int | Literal["auto"] = 32 + """Worker processes for the spawned env server. ``auto`` is resolved per group (train envs use a fixed pool). Ignored when ``address`` is set.""" ratio: float | None = Field(None, gt=0) """Sampling weight for this environment in the buffer. When None for all envs, samples uniformly across all available problems. When set, must be set on all envs — values are relative weights normalized to probabilities (e.g. [1, 1] and [0.5, 0.5] are equivalent).""" @@ -898,11 +898,12 @@ def resolve_batching(self): if "group_size" not in env_cfg.model_fields_set: env_cfg.group_size = self.group_size - # Resolve train env num_workers from max_inflight_rollouts + # Fixed train env worker pool: a single env worker can't absorb all + # in-flight rollouts (event-loop saturation during sandbox setup -> + # heartbeat timeout -> restart loop). Explicit per-env values win. for env_cfg in self.train.env: if env_cfg.num_workers == "auto": - assert self.max_inflight_rollouts is not None - env_cfg.num_workers = max(1, math.ceil(self.max_inflight_rollouts / 256)) + env_cfg.num_workers = 32 return self From 44561979e08c528e73d70dae5447a10d272ab141 Mon Sep 17 00:00:00 2001 From: Eli Gottlieb <78387377+eligotts@users.noreply.github.com> Date: Tue, 9 Jun 2026 06:32:01 +0000 Subject: [PATCH 27/31] fix(monitor): stream sample parquet to disk incrementally, drop VLM gate R2 sample uploads still hit OpenSSL [BUF] malloc failure (OOM) because the streaming path was gated on a manual config flag (`student.model.vlm`/`is_vlm`) that is silently off when `[model.vlm]` isn't set, and because the parquet was built as a full in-RAM Arrow table before spilling. Always build straight to a disk-backed temp file via ParquetWriter, one rollout at a time, with a per-rollout image cache so base64-inlined screenshots are never all held in RAM; always stream the upload (Content-Length, chunked TLS writes). 
Removes the fragile config gate and the in-RAM table/bytes build. Co-Authored-By: Codex Co-Authored-By: Claude Opus 4.8 (1M context) --- src/prime_rl/utils/monitor/prime.py | 185 +++++++++++++--------------- 1 file changed, 87 insertions(+), 98 deletions(-) diff --git a/src/prime_rl/utils/monitor/prime.py b/src/prime_rl/utils/monitor/prime.py index a62133da25..5674368514 100644 --- a/src/prime_rl/utils/monitor/prime.py +++ b/src/prime_rl/utils/monitor/prime.py @@ -1,7 +1,6 @@ import asyncio import base64 import contextlib -import io import json import math import mimetypes @@ -77,22 +76,6 @@ async def _aiter_handle(f, chunk_size: int = 65536): yield chunk -def _table_to_parquet_bytes(table: "pa.Table") -> bytes: - buf = io.BytesIO() - pq.write_table(table, buf, compression="snappy", use_dictionary=True, write_statistics=True) - return buf.getvalue() - - -def _run_config_is_multimodal(run_config: Any) -> bool: - """True when the run trains a VLM. Handles trainer (``model``) and orchestrator - (``student.model``) config shapes; defaults False for anything else.""" - candidates = ( - getattr(run_config, "model", None), - getattr(getattr(run_config, "student", None), "model", None), - ) - return any(getattr(model, "is_vlm", False) for model in candidates if model is not None) - - def _drop_non_finite_json_values(value: Any, dropped_paths: list[str], path: str = "") -> Any: if isinstance(value, float) and not math.isfinite(value): dropped_paths.append(path) @@ -211,7 +194,6 @@ def __init__( self.history: list[dict[str, Any]] = [] self._keep_full_history = keep_full_history self.output_dir = output_dir - self._is_multimodal = _run_config_is_multimodal(run_config) self._registered = False self._finalized = False self._closed = False @@ -405,20 +387,15 @@ def log_samples(self, rollouts: list[vf.RolloutOutput], step: int) -> None: self.logger.info(f"Logging {len(rollouts)} samples to Prime Intellect API at step {step}") start_time = time.perf_counter() - table = 
self._rollouts_to_parquet_table(rollouts, step) - if table is None: + # Always build straight to a disk-backed temp file and stream the upload. Samples can + # inline base64 images and be large; gating on a config flag was fragile (silently + # off when ``[model.vlm]`` wasn't set). Streaming from disk is safe for any size and + # never holds the serialized body in RAM. + body = self._write_samples_parquet(rollouts, step) + if body is None: self.logger.warning(f"No samples to log at step {step}") return - # Multimodal parquets inline base64 images and can be large: write straight to a - # disk-backed temp file and stream the upload, so the serialized body is never held - # in RAM as bytes and nothing is retained during the async upload. Text runs stay inline. - if self._is_multimodal: - body: bytes | Path = self._spill_table_to_tempfile(table, step) - else: - body = _table_to_parquet_bytes(table) - del table - # Mark pending only once the body is built and ready to schedule. self._pending_sample_steps.add(step) self._upload_samples_via_presigned_url(body, step) @@ -427,80 +404,92 @@ def log_samples(self, rollouts: list[vf.RolloutOutput], step: int) -> None: f"Initiated samples upload at step {step} to Prime Intellect API in {time.perf_counter() - start_time:.2f}s" ) - def _rollouts_to_parquet_table(self, rollouts: list[vf.RolloutOutput], step: int) -> "pa.Table | None": - """Build the sample Parquet table from rollouts (None when there is nothing to log).""" + def _write_samples_parquet(self, rollouts: list[vf.RolloutOutput], step: int) -> "Path | None": + """Stream the sample Parquet straight to a disk-backed temp file (co-located with the + run's output dir to avoid a tmpfs ``/tmp``), one rollout at a time, so base64-inlined + images are never all held in RAM. Returns the temp file path, or None when there is + nothing to log. 
The upload coroutine unlinks the file when done.""" now = datetime.now(timezone.utc) - rows = [] - image_data_url_cache: dict[str, str | None] = {} - - for sample_id, rollout in enumerate(rollouts): - prompt = rollout.get("prompt") - completion = rollout.get("completion") - trajectory = rollout.get("trajectory") or [] - if prompt is None or completion is None or not trajectory: - continue - - example_id = rollout.get("example_id") - try: - problem_id = int(example_id) if example_id is not None else sample_id - except (TypeError, ValueError): - problem_id = sample_id - - trajectory_data = [ - { - "prompt": ts["prompt"], - "completion": ts["completion"], - "reward": ts.get("reward"), - "advantage": ts.get("advantage"), - "extras": ts.get("extras", {}), - "num_input_tokens": len(ts["tokens"]["prompt_ids"]) if ts.get("tokens") else None, - "num_output_tokens": len(ts["tokens"]["completion_ids"]) if ts.get("tokens") else None, - } - for ts in trajectory - ] - - rows.append( - { - "run_id": self.run_id, - "step": step, - "tag": "", - "problem_id": problem_id, - "sample_id": sample_id, - "prompt": json.dumps(_inline_local_image_urls(prompt, image_data_url_cache)), - "completion": json.dumps(_inline_local_image_urls(completion, image_data_url_cache)), - "trajectory": json.dumps(_inline_local_image_urls(trajectory_data, image_data_url_cache)), - "answer": rollout.get("answer") or "", - "env_name": rollout.get("env_name") or "", - "task": rollout.get("task") or "", - "info": _json(rollout.get("info")), - "reward": rollout.get("reward"), - "advantage": rollout.get("advantage"), - "metrics": _json(rollout.get("metrics")), - "timing": _json(rollout.get("timing")), - "num_input_tokens": 0, - "num_output_tokens": 0, - "created_at": now, - } - ) - - if not rows: + directory = str(self.output_dir) if self.output_dir else None + fd, name = tempfile.mkstemp(prefix=f"prime_samples_step{step}_", suffix=".parquet", dir=directory) + path = Path(name) + writer: "pq.ParquetWriter | None" = 
None + wrote = False + try: + with os.fdopen(fd, "wb") as f: + writer = pq.ParquetWriter( + f, _SAMPLE_SCHEMA, compression="snappy", use_dictionary=True, write_statistics=True + ) + for sample_id, rollout in enumerate(rollouts): + row = self._rollout_to_sample_row(rollout, sample_id, step, now) + if row is None: + continue + writer.write_table(pa.Table.from_pylist([row], schema=_SAMPLE_SCHEMA)) + wrote = True + writer.close() + writer = None + except BaseException: + if writer is not None: + writer.close() + path.unlink(missing_ok=True) + raise + if not wrote: + path.unlink(missing_ok=True) + return None + return path + + def _rollout_to_sample_row( + self, rollout: vf.RolloutOutput, sample_id: int, step: int, now: datetime + ) -> "dict[str, Any] | None": + """Build one sample row, or None to skip. A fresh per-rollout image cache bounds peak + RAM to a single rollout's inlined base64 (images repeated across its turns still dedup).""" + prompt = rollout.get("prompt") + completion = rollout.get("completion") + trajectory = rollout.get("trajectory") or [] + if prompt is None or completion is None or not trajectory: return None - return pa.Table.from_pylist(rows, schema=_SAMPLE_SCHEMA) + example_id = rollout.get("example_id") + try: + problem_id = int(example_id) if example_id is not None else sample_id + except (TypeError, ValueError): + problem_id = sample_id - def _rollouts_to_parquet_bytes(self, rollouts: list[vf.RolloutOutput], step: int) -> bytes | None: - """Convert rollouts to in-memory Parquet bytes (inline upload path).""" - table = self._rollouts_to_parquet_table(rollouts, step) - return None if table is None else _table_to_parquet_bytes(table) + trajectory_data = [ + { + "prompt": ts["prompt"], + "completion": ts["completion"], + "reward": ts.get("reward"), + "advantage": ts.get("advantage"), + "extras": ts.get("extras", {}), + "num_input_tokens": len(ts["tokens"]["prompt_ids"]) if ts.get("tokens") else None, + "num_output_tokens": 
len(ts["tokens"]["completion_ids"]) if ts.get("tokens") else None, + } + for ts in trajectory + ] - def _spill_table_to_tempfile(self, table: "pa.Table", step: int) -> Path: - """Write the sample Parquet straight to a disk-backed temp file, co-located with the - run's output dir to avoid a tmpfs ``/tmp``. The upload coroutine unlinks it when done.""" - directory = str(self.output_dir) if self.output_dir else None - fd, name = tempfile.mkstemp(prefix=f"prime_samples_step{step}_", suffix=".parquet", dir=directory) - with os.fdopen(fd, "wb") as f: - pq.write_table(table, f, compression="snappy", use_dictionary=True, write_statistics=True) - return Path(name) + image_cache: dict[str, str | None] = {} + return { + "run_id": self.run_id, + "step": step, + "tag": "", + "problem_id": problem_id, + "sample_id": sample_id, + "prompt": json.dumps(_inline_local_image_urls(prompt, image_cache)), + "completion": json.dumps(_inline_local_image_urls(completion, image_cache)), + "trajectory": json.dumps(_inline_local_image_urls(trajectory_data, image_cache)), + "answer": rollout.get("answer") or "", + "env_name": rollout.get("env_name") or "", + "task": rollout.get("task") or "", + "info": _json(rollout.get("info")), + "reward": rollout.get("reward"), + "advantage": rollout.get("advantage"), + "metrics": _json(rollout.get("metrics")), + "timing": _json(rollout.get("timing")), + "num_input_tokens": 0, + "num_output_tokens": 0, + "created_at": now, + } def _upload_samples_via_presigned_url(self, body: bytes | Path, step: int) -> None: """Upload Parquet samples using presigned URL flow (fire-and-forget).""" From b35bce8b55018cf9bce84257b3f661326fe901cf Mon Sep 17 00:00:00 2001 From: hubert-marek Date: Tue, 9 Jun 2026 20:34:58 +0000 Subject: [PATCH 28/31] Revert "fix(orchestrator): default train env num_workers to 32" This reverts commit 7414cd7b675a74629e33904ce220a5406fd69d43. 
The auto->32 train default was dead code: the earlier num_workers resolver (num_examples == -1 -> 4) sets the value before the auto->32 check runs, so it never applied (train envs still got ~4 workers -> heartbeat timeouts). The correct lever is now per-run run_config.env_server.num_workers (platform #2838, bounds 1-64), which sets it explicitly and bypasses both auto resolvers. Co-Authored-By: Claude Opus 4.8 (1M context) --- .../src/prime_rl/configs/orchestrator.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/packages/prime-rl-configs/src/prime_rl/configs/orchestrator.py b/packages/prime-rl-configs/src/prime_rl/configs/orchestrator.py index fc92b8b3e5..c6fd6da534 100644 --- a/packages/prime-rl-configs/src/prime_rl/configs/orchestrator.py +++ b/packages/prime-rl-configs/src/prime_rl/configs/orchestrator.py @@ -159,8 +159,8 @@ class EnvConfig(BaseConfig): address: str | None = None """ZMQ address of an external env server (e.g. ``tcp://host:5000``). When set, the orchestrator connects to this server instead of spawning one; when None, a subprocess env server is spawned automatically.""" - num_workers: int | Literal["auto"] = 32 - """Worker processes for the spawned env server. ``auto`` is resolved per group (train envs use a fixed pool). Ignored when ``address`` is set.""" + num_workers: int | Literal["auto"] = "auto" + """Worker processes for the spawned env server. ``auto`` scales to 1 worker per 256 concurrent rollouts. Ignored when ``address`` is set.""" ratio: float | None = Field(None, gt=0) """Sampling weight for this environment in the buffer. When None for all envs, samples uniformly across all available problems. When set, must be set on all envs — values are relative weights normalized to probabilities (e.g. 
[1, 1] and [0.5, 0.5] are equivalent).""" @@ -898,12 +898,11 @@ def resolve_batching(self): if "group_size" not in env_cfg.model_fields_set: env_cfg.group_size = self.group_size - # Fixed train env worker pool: a single env worker can't absorb all - # in-flight rollouts (event-loop saturation during sandbox setup -> - # heartbeat timeout -> restart loop). Explicit per-env values win. + # Resolve train env num_workers from max_inflight_rollouts for env_cfg in self.train.env: if env_cfg.num_workers == "auto": - env_cfg.num_workers = 32 + assert self.max_inflight_rollouts is not None + env_cfg.num_workers = max(1, math.ceil(self.max_inflight_rollouts / 256)) return self From bcadf8c0f6e5908b577683775a581451f4c16cc0 Mon Sep 17 00:00:00 2001 From: Hubert <163992334+hubert-marek@users.noreply.github.com> Date: Tue, 9 Jun 2026 21:49:01 -0700 Subject: [PATCH 29/31] fix(orchestrator): drop per-step prompt arrays from buffered rollouts after tokenization (#2752) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix(orchestrator): drop per-step prompt arrays from buffered rollouts after tokenization Every trajectory step ships prompt_ids/prompt_mask for its ENTIRE prompt prefix — O(turns x context) boxed ints, 100-370MB per long browser rollout — but their only readers run at arrival (backfill_rollout_tokens / interleave_rollout). The rollout then buffers in pending_groups (sibling wait) and pending_batch until its batch ships, so at the ~250 buffered rollouts observed on worldsims gflights runs this dead weight alone overflows a 128GB orchestrator. 
Strip the prompt arrays at the arrival boundary in TrainSink.add, mirroring the offload_images_to_disk treatment of image bytes: - keep a num_prompt_tokens count (the prime monitor's per-step num_input_tokens stat reads it, falling back to len(prompt_ids)) - keep completion arrays on every step (entropy/rare-token filters scan them) - keep the final step whole (the wandb sample table decodes it) save_rollouts already excludes the trajectory from disk artifacts on both the train and eval paths, so on-disk output is unchanged. Co-Authored-By: Claude Fable 5 * style: ruff format monitor num_input_tokens expression Co-Authored-By: Claude Fable 5 --------- Co-authored-by: Claude Fable 5 --- src/prime_rl/orchestrator/train_sink.py | 14 ++++++++++++++ src/prime_rl/utils/monitor/prime.py | 6 +++++- 2 files changed, 19 insertions(+), 1 deletion(-) diff --git a/src/prime_rl/orchestrator/train_sink.py b/src/prime_rl/orchestrator/train_sink.py index 221302a1f6..f24ae0c383 100644 --- a/src/prime_rl/orchestrator/train_sink.py +++ b/src/prime_rl/orchestrator/train_sink.py @@ -142,6 +142,20 @@ async def add(self, rollout: TrainRollout) -> TrainBatch | None: """Process one arrival; finalize the group on the ``group_size``-th arrival; return a ``TrainBatch`` if the batch threshold is met.""" await self.process_rollout(rollout) + # Per-step prompt arrays carry the FULL prompt prefix per step — + # O(turns x context) boxed ints — and have no readers past the + # tokenization above; drop them before the rollout buffers, like + # ``offload_images_to_disk`` does for image bytes. Keep the count + # (sample monitors' ``num_input_tokens``), the completion arrays + # (post-batch filters scan them), and the final step whole (the + # wandb sample table decodes it); ``save_rollouts`` already excludes + # the trajectory from disk artifacts. 
+ for step in (rollout.raw.get("trajectory") or [])[:-1]: + tokens = step.get("tokens") + if tokens: + tokens["num_prompt_tokens"] = len(tokens["prompt_ids"]) + tokens["prompt_ids"] = [] + tokens["prompt_mask"] = [] env_name = rollout.env_name self.arrivals_by_env[env_name] += 1 if rollout.error is not None: diff --git a/src/prime_rl/utils/monitor/prime.py b/src/prime_rl/utils/monitor/prime.py index 5674368514..f8d6d48b27 100644 --- a/src/prime_rl/utils/monitor/prime.py +++ b/src/prime_rl/utils/monitor/prime.py @@ -462,7 +462,11 @@ def _rollout_to_sample_row( "reward": ts.get("reward"), "advantage": ts.get("advantage"), "extras": ts.get("extras", {}), - "num_input_tokens": len(ts["tokens"]["prompt_ids"]) if ts.get("tokens") else None, + # Stripped steps (strip_trajectory_prompt_arrays) carry the + # count instead of the array; fresh steps still have the ids. + "num_input_tokens": ( + ts["tokens"].get("num_prompt_tokens", len(ts["tokens"]["prompt_ids"])) if ts.get("tokens") else None + ), "num_output_tokens": len(ts["tokens"]["completion_ids"]) if ts.get("tokens") else None, } for ts in trajectory From 35456a75cd0d6bf46d4b83628a2d328a25155314 Mon Sep 17 00:00:00 2001 From: Hubert <163992334+hubert-marek@users.noreply.github.com> Date: Tue, 9 Jun 2026 22:05:19 -0700 Subject: [PATCH 30/31] chore: bump renderers to 462149b (mmraw preprocessor_config hub-download fallback) (#2753) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit renderers#82: raw-image layout no longer hard-fails on hosted workers when the model id misses the local HF cache — falls back to hf_hub_download for preprocessor_config.json. Explicit image_* renderer-config overrides remain the hermetic path; this is the safety net for configs without them. 
Co-authored-by: Claude Fable 5 --- deps/renderers | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deps/renderers b/deps/renderers index db5058e8e8..462149b974 160000 --- a/deps/renderers +++ b/deps/renderers @@ -1 +1 @@ -Subproject commit db5058e8e84ee716f662438bb8451367187b3463 +Subproject commit 462149b974bbad8010896ba26aa3de811486345a From b101d22a6a213b426a16d055a42172d2596dfac2 Mon Sep 17 00:00:00 2001 From: Hubert <163992334+hubert-marek@users.noreply.github.com> Date: Wed, 10 Jun 2026 01:30:52 -0700 Subject: [PATCH 31/31] explicit del and malloc (#2757) (cherry picked from commit 77b85673702b8ac939a5c2e105de54a676276dbe) Co-authored-by: Christian --- src/prime_rl/orchestrator/orchestrator.py | 50 +++++++++++++++-------- 1 file changed, 33 insertions(+), 17 deletions(-) diff --git a/src/prime_rl/orchestrator/orchestrator.py b/src/prime_rl/orchestrator/orchestrator.py index 902c8b963b..75bea40f34 100644 --- a/src/prime_rl/orchestrator/orchestrator.py +++ b/src/prime_rl/orchestrator/orchestrator.py @@ -22,6 +22,7 @@ import asyncio import ctypes +import gc import logging import os import time @@ -109,6 +110,14 @@ TARGET_LAG = 1 +def _release_unused_memory() -> None: + gc.collect() + try: + ctypes.CDLL("libc.so.6").malloc_trim(0) + except (OSError, AttributeError) as e: + get_logger().debug(f"malloc_trim(0) unavailable: {e}") + + class Orchestrator: # Set in ``__init__`` config: OrchestratorConfig @@ -475,10 +484,7 @@ async def start(self) -> None: get_logger().success("Orchestrator finished.") else: get_logger().warning("Orchestrator cleanup complete (forced).") - try: - ctypes.CDLL("libc.so.6").malloc_trim(0) - except Exception as e: - get_logger().debug(f"malloc_trim(0) failed: {e}") + _release_unused_memory() async def main_loop(self) -> None: """Consume ``FinishedRollout``\\ s from the dispatcher and route them @@ -495,19 +501,29 @@ async def main_loop(self) -> None: except asyncio.TimeoutError: continue - if isinstance(rollout, 
EvalRollout): - assert self.eval_sink is not None # eval rollouts only emitted when eval is configured - eval_batch = self.eval_sink.add(rollout) - if eval_batch is not None: - self.finalize_eval_batch(eval_batch) - continue - - assert isinstance(rollout, TrainRollout) - train_batch = await self.train_sink.add(rollout) - # In drain mode any late-arriving train batch is dropped — we - # don't want to ship past ``max_steps`` - if train_batch is not None and not self.draining and not self.stopped.is_set(): - await self.finalize_train_batch(train_batch) + batch = None + should_release_memory = False + try: + if isinstance(rollout, EvalRollout): + assert self.eval_sink is not None # eval rollouts only emitted when eval is configured + batch = self.eval_sink.add(rollout) + if batch is not None: + should_release_memory = True + self.finalize_eval_batch(batch) + continue + + assert isinstance(rollout, TrainRollout) + batch = await self.train_sink.add(rollout) + # In drain mode any late-arriving train batch is dropped — we + # don't want to ship past ``max_steps`` + if batch is not None: + should_release_memory = True + if batch is not None and not self.draining and not self.stopped.is_set(): + await self.finalize_train_batch(batch) + finally: + del batch, rollout + if should_release_memory: + _release_unused_memory() async def finalize_train_batch(self, batch: TrainBatch) -> None: """Ship one ``TrainBatch`` out to the trainer and handle the I/O