From a296f0ea3210ee702863d55727deb0b15b239110 Mon Sep 17 00:00:00 2001 From: Joan <94200647+joanvelja@users.noreply.github.com> Date: Thu, 11 Jun 2026 07:28:31 +0000 Subject: [PATCH] fix(inference): patch vLLM 0.22 O(B*L) per-step sampler hot paths Two pure-Python per-step costs collapse decode ~5x on long generations with Qwen3.5 default sampling (presence_penalty=1.5) + thinking_token_budget: penalties rebuild a padded [B, out_len] tensor from lists every step, and the thinking-budget holder rescans the whole output for every step. Replace Sampler.apply_penalties with a vectorized numpy slice of InputBatch.token_ids_cpu (pinned double-buffered staging, identity-checked fallback to upstream for unrecognized rows), add async write-back so token_ids_cpu stays authoritative, and wrap _update_think_state with an incremental watermark scan (-2 sentinel skips upstream's full rescan). Equivalence-tested against upstream on randomized streams. --- src/prime_rl/inference/patches.py | 3 + src/prime_rl/inference/vllm/sampler_perf.py | 334 ++++++++++++++++++++ tests/unit/inference/test_sampler_perf.py | 184 +++++++++++ 3 files changed, 521 insertions(+) create mode 100644 src/prime_rl/inference/vllm/sampler_perf.py create mode 100644 tests/unit/inference/test_sampler_perf.py diff --git a/src/prime_rl/inference/patches.py b/src/prime_rl/inference/patches.py index 8190f210b0..ed913616a8 100644 --- a/src/prime_rl/inference/patches.py +++ b/src/prime_rl/inference/patches.py @@ -15,11 +15,14 @@ def transformers_v5_compat(): if not hasattr(Qwen3VLMoeTextConfig, "tie_word_embeddings"): Qwen3VLMoeTextConfig.tie_word_embeddings = False + from prime_rl.inference.vllm.sampler_perf import apply_sampler_perf_patches + _patch_qwen35_lora() _patch_lora_key_prefix() monkey_patch_deep_gemm_silu_mul_quant_int64() monkey_patch_vllm_padded_input_scrub() monkey_patch_return_routed_experts_with_nixl_connector() + apply_sampler_perf_patches() def 
monkey_patch_return_routed_experts_with_nixl_connector(): diff --git a/src/prime_rl/inference/vllm/sampler_perf.py b/src/prime_rl/inference/vllm/sampler_perf.py new file mode 100644 index 0000000000..4efd5e61bd --- /dev/null +++ b/src/prime_rl/inference/vllm/sampler_perf.py @@ -0,0 +1,334 @@ +"""Hot-path fixes for two O(batch x output_len) per-step CPU costs in vLLM 0.22.0. + +Both costs collapse decode throughput ~5x for long generations with the Qwen3.5 +recommended sampling defaults (presence_penalty=1.5) plus thinking_token_budget: + +1. Penalties: ``vllm/v1/sample/ops/penalties.py`` rebuilds a padded + ``[B, output_len]`` int64 tensor from Python lists via ``make_tensor_with_pad`` + on every decode step (~80ms/step at B=128, output 12k). Upstream's rework + lives only in Model Runner V2, which rejects ``reasoning_config`` and + ``enable_return_routed_experts``, so it is unusable here. We replace + ``Sampler.apply_penalties`` with a vectorized builder that slices the + already-materialized ``InputBatch.token_ids_cpu`` numpy buffer into a + reusable pinned staging tensor (no per-token Python iteration). Any state + we do not recognize (speculative-decode combined rows, foreign batches) + falls back to the upstream implementation, so semantics are identical by + construction. + + Under async scheduling, ``token_ids_cpu`` output positions are written as + ``-1`` placeholders and never repaired (only the Python lists are). We + vendor ``InputBatch.update_async_output_token_ids`` with a write-back so + the numpy buffer stays authoritative; residual ``-1`` rows (kv-load + discards, unrepaired rows) are masked to the pad bin exactly like + upstream's ``masked_fill_``. + +2. Thinking budget: ``ThinkingBudgetStateHolder._update_think_state`` rescans + the entire generated output for the think-end token ids on every step until + they appear (O(L) pure Python per request per step for the whole thinking + phase). 
class _PinnedStaging:
    """Pair of reusable flat int64 host buffers served as ``[rows, cols]`` views.

    The two buffers are handed out alternately. When CUDA is available they
    are allocated in pinned memory and each one is paired with a CUDA event;
    :meth:`record` marks the point after which a buffer may still be read by
    an in-flight copy, and :meth:`get` waits on that event before handing the
    same buffer out again. Without CUDA the events are ``None`` and no
    synchronization takes place.
    """

    def __init__(self, max_rows: int, max_cols: int):
        has_cuda = torch.cuda.is_available()
        capacity = max_rows * max_cols
        # Both buffers are sized for the largest view that will be requested.
        self._bufs = [
            torch.empty(capacity, dtype=torch.int64, pin_memory=has_cuda),
            torch.empty(capacity, dtype=torch.int64, pin_memory=has_cuda),
        ]
        self._events = [torch.cuda.Event() if has_cuda else None for _ in range(2)]
        self._recorded = [False, False]
        self._idx = 0

    def get(self, rows: int, cols: int) -> tuple[torch.Tensor, int]:
        """Return ``(view, slot)``: a ``[rows, cols]`` view of the next buffer
        and the slot index to pass to :meth:`record` after the copy is issued."""
        slot = self._idx
        self._idx = 1 - slot  # alternate between slot 0 and slot 1
        pending = self._events[slot]
        if pending is not None and self._recorded[slot]:
            # A non_blocking copy issued from this buffer may still be reading
            # it; wait for that copy before the caller overwrites the memory.
            pending.synchronize()
        flat = self._bufs[slot]
        return flat[: rows * cols].view(rows, cols), slot

    def record(self, i: int) -> None:
        """Record slot ``i``'s event (no-op when running without CUDA)."""
        pending = self._events[i]
        if pending is None:
            return
        pending.record()
        self._recorded[i] = True
num_to_replace = min(num_sampled_ids, num_placeholders) + del new_ids[num_to_replace:] + req_output_token_ids[first_placeholder:] = new_ids + # prime-rl addition: keep token_ids_cpu authoritative under async + # scheduling so the fast penalties path can slice it. + start = int(self.num_prompt_tokens[index]) + first_placeholder + self.token_ids_cpu[index, start : start + len(new_ids)] = new_ids + + InputBatch.update_async_output_token_ids = update_async_output_token_ids + + +def build_output_tokens_fast( + input_batch, + staging: "_PinnedStaging", + output_token_ids: list[list[int]], + vocab_size: int, + device: torch.device, +) -> torch.Tensor | None: + """Build the [B, max_out_len] padded output-token tensor without Python + per-token iteration. Returns None when the rows are not the live batch + rows (caller must fall back to the upstream implementation).""" + n = len(output_token_ids) + req_lists = input_batch.req_output_token_ids + if len(req_lists) < n: + return None + for i in range(n): + if output_token_ids[i] is not req_lists[i]: + return None + if n == 0: + return torch.empty(0, 0, dtype=torch.int64, device=device) + + out_lens = np.fromiter(map(len, output_token_ids), np.int64, n) + max_len = int(out_lens.max()) + if max_len == 0: + return torch.empty(n, 0, dtype=torch.int64, device=device) + + buf, buf_idx = staging.get(n, max_len) + dst = buf.numpy() + dst.fill(vocab_size) + token_ids_cpu = input_batch.token_ids_cpu + num_prompt = input_batch.num_prompt_tokens + for i in range(n): + length = out_lens[i] + if length: + start = num_prompt[i] + dst[i, :length] = token_ids_cpu[i, start : start + length] + # Unrepaired placeholders / discarded rows: same semantics as upstream's + # masked_fill_(output_tokens_t == -1, vocab_size). 
+ dst[dst == -1] = vocab_size + tensor = buf.to(device, non_blocking=True) + staging.record(buf_idx) + return tensor + + +def _patch_fast_penalties() -> None: + from vllm.model_executor.layers.utils import apply_penalties as gpu_apply_penalties + from vllm.v1.sample.sampler import Sampler + + orig_apply_penalties = Sampler.apply_penalties + + def apply_penalties(logits, sampling_metadata, output_token_ids): + if sampling_metadata.no_penalties: + return logits + input_batch = _INPUT_BATCH_REF() if _INPUT_BATCH_REF is not None else None + tensor = None + if input_batch is not None and _staging is not None and logits.shape[0] == len(output_token_ids): + tensor = build_output_tokens_fast(input_batch, _staging, output_token_ids, logits.shape[1], logits.device) + if tensor is None: + # Unrecognized rows (e.g. spec-decode combined lists): upstream path. + return orig_apply_penalties(logits, sampling_metadata, output_token_ids) + assert sampling_metadata.prompt_token_ids is not None + return gpu_apply_penalties( + logits, + sampling_metadata.prompt_token_ids, + tensor, + sampling_metadata.presence_penalties, + sampling_metadata.frequency_penalties, + sampling_metadata.repetition_penalties, + ) + + Sampler.apply_penalties = staticmethod(apply_penalties) + + +# --------------------------------------------------------------------------- +# Patch 2: thinking budget incremental scan +# --------------------------------------------------------------------------- + + +def find_last_in_window(lst: list[int], pattern: list[int], lo: int, hi: int) -> int: + """Last start index of `pattern` in lst[lo:hi], or -1. 
C-speed for the + common single-token pattern via list.index.""" + m = len(pattern) + if m == 0 or hi - lo < m: + return -1 + if m == 1: + target = pattern[0] + last = -1 + i = lo + while True: + try: + i = lst.index(target, i, hi) + except ValueError: + return last + last = i + i += 1 + last = -1 + for i in range(lo, hi - m + 1): + if lst[i : i + m] == pattern: + last = i + return last + + +def _patch_thinking_budget_scan() -> None: + from vllm.v1.sample.thinking_budget_state import ThinkingBudgetStateHolder + + orig_update = ThinkingBudgetStateHolder._update_think_state + + def _update_think_state(self, state) -> None: + if state.get("thinking_token_budget", -1) == -1 or not self.think_end_token_ids: + return orig_update(self, state) + + out = state.get("output_tok_ids") or [] + # Watermark excludes trailing async -1 placeholders: those positions + # are rewritten with real ids next step and must be rescanned then. + hi = len(out) + while hi > 0 and out[hi - 1] == -1: + hi -= 1 + pos = state.get("_prime_scan_pos", 0) + start_idx = state.get("start_thinking", -1) + end_idx = state.get("end_thinking", -1) + # Shrinkage (spec rejection / kv-load discard): rescan from scratch. + # Caution: in the continue_thinking case, start_thinking is a + # prompt-absolute index set at init — never treat it as shrunk. 
+ start_is_output_relative = not state.get("continue_thinking", False) + if ( + pos > hi + or (end_idx >= 0 and end_idx >= hi) + or (start_is_output_relative and 0 <= start_idx and start_idx >= hi) + ): + pos = 0 + if start_is_output_relative and 0 <= start_idx and start_idx >= hi: + start_idx = -1 + state["start_thinking"] = -1 + if end_idx >= 0 and end_idx >= hi: + end_idx = -1 + state["end_thinking"] = -1 + + if start_idx == -1: + m = len(self.think_start_token_ids) + lo = max(0, pos - (m - 1)) if m else 0 + idx = find_last_in_window(out, self.think_start_token_ids, lo, hi) if self.think_start_token_ids else -1 + if idx != -1: + state["start_thinking"] = idx + start_idx = idx + if end_idx < 0: # -1 (never scanned) or -2 (scanned, absent) + m = len(self.think_end_token_ids) + lo = max(0, pos - (m - 1)) + idx = find_last_in_window(out, self.think_end_token_ids, lo, hi) + # -2 sentinel: skips the original's full rescan (`== -1` guard) + # while behaving identically to -1 in every downstream comparison + # (`> -1`, `>= 0`). + state["end_thinking"] = idx if idx != -1 else -2 + state["_prime_scan_pos"] = hi + + if state.get("start_thinking", -1) == -1: + # Replicate the original's early return (scan start -> scan end -> + # `if start_thinking == -1: return`) without paying its scans. 
def apply_sampler_perf_patches() -> None:
    """Install the sampler hot-path monkey patches on vLLM, at most once per
    process.

    Does nothing (after logging a warning) when the environment variable
    ``PRIME_RL_DISABLE_SAMPLER_PERF_PATCH`` is set to ``"1"``, and does nothing
    when a previous call already installed the patches.

    Raises:
        RuntimeError: if the installed vLLM version differs from
            ``SUPPORTED_VLLM``. The error is raised on every call, not only
            the first one.
    """
    from vllm.logger import init_logger

    logger = init_logger(__name__)
    if os.environ.get("PRIME_RL_DISABLE_SAMPLER_PERF_PATCH", "0") == "1":
        logger.warning("Sampler perf patches disabled via PRIME_RL_DISABLE_SAMPLER_PERF_PATCH")
        return
    import vllm

    if getattr(vllm, "_prime_rl_sampler_perf_patched", False):
        return

    if vllm.__version__ != SUPPORTED_VLLM:
        raise RuntimeError(
            f"sampler_perf patches are pinned to vLLM {SUPPORTED_VLLM}, found "
            f"{vllm.__version__}. Re-validate the vendored code paths "
            "(Sampler.apply_penalties, InputBatch.update_async_output_token_ids, "
            "ThinkingBudgetStateHolder._update_think_state) before bumping."
        )

    # Set the idempotency flag only after the version gate has passed. Setting
    # it earlier would make a second call return silently, with nothing
    # patched, if a caller caught the first RuntimeError. It is still set
    # before the patch functions run so that a retry can never wrap
    # InputBatch.__init__ (or the other targets) twice.
    vllm._prime_rl_sampler_perf_patched = True
    _capture_input_batch()
    _patch_async_output_writeback()
    _patch_fast_penalties()
    _patch_thinking_budget_scan()
    logger.info("Applied sampler perf patches (fast penalties tensor build + incremental thinking-budget scan)")
class FakeInputBatch:
    """Minimal stand-in exposing the attributes the fast builder reads.

    Each row of ``token_ids_cpu`` holds the prompt ids followed directly by the
    output ids; every other cell keeps the filler value ``-7``.
    ``req_output_token_ids`` is the very ``outputs`` object passed in (not a
    copy), so identity checks against it succeed.
    """

    def __init__(self, prompts: list[list[int]], outputs: list[list[int]], max_model_len: int = 256):
        num_reqs = len(prompts)
        self.max_num_reqs = num_reqs
        self.max_model_len = max_model_len
        self.req_output_token_ids = outputs
        self.num_prompt_tokens = np.zeros(num_reqs, dtype=np.int32)
        self.token_ids_cpu = np.full((num_reqs, max_model_len), -7, dtype=np.int32)
        for row, (prompt, output) in enumerate(zip(prompts, outputs)):
            split = len(prompt)  # column where the output tokens begin
            self.num_prompt_tokens[row] = split
            self.token_ids_cpu[row, :split] = prompt
            self.token_ids_cpu[row, split : split + len(output)] = output
def test_staging_double_buffer_reuse():
    """Two consecutive builds return the expected rows from distinct storage.

    Comparing the two results only with each other would also pass if both
    were wrong in the same way, or if both calls handed back the same buffer,
    so the expected contents and the storage separation are pinned explicitly.
    """
    ib = FakeInputBatch([[1]], [[2, 3]])
    staging = _PinnedStaging(4, 32)
    cpu = torch.device("cpu")
    a = build_output_tokens_fast(ib, staging, ib.req_output_token_ids, VOCAB, cpu)
    b = build_output_tokens_fast(ib, staging, ib.req_output_token_ids, VOCAB, cpu)
    assert a is not None and b is not None
    assert a.tolist() == [[2, 3]]
    assert torch.equal(a, b)
    # Back-to-back results must not share memory: the second build must not
    # have written over the tensor returned by the first.
    assert a.data_ptr() != b.data_ptr()
@pytest.mark.parametrize("seed", list(range(8)))
def test_thinking_budget_incremental_scan_equivalence(seed, think_state_impls):
    """Feed one seeded random token stream, token by token, through the
    original and the patched ``_update_think_state`` and require the two
    states to match (after ``_normalize``) following every single step.

    NOTE: the order of the ``rng`` draws below determines the generated cases;
    reordering them changes what each seed covers.
    """
    orig_fn, patched_fn = think_state_impls

    rng = random.Random(seed)
    holder = _mk_holder()
    budget = rng.choice([3, 8, 20])
    prompt = [rng.randrange(5) for _ in range(rng.randint(1, 10))]
    if rng.random() < 0.3:
        prompt += [7]  # think starts in prompt (continue_thinking)

    # Two independent states built from equal prompt copies: `a` is driven by
    # the original implementation, `b` by the patched one.
    state_a = holder._init_state_entry(list(prompt), budget)
    state_b = holder._init_state_entry(list(prompt), budget)
    out_a: list[int] = []
    out_b: list[int] = []
    # Each state holds a reference to its output list, so the appends in the
    # loop below are visible through state["output_tok_ids"].
    state_a["output_tok_ids"] = out_a
    state_b["output_tok_ids"] = out_b

    # random token stream with think start/end injected at random points
    # (7 and 9 lie outside randrange(5), so they only appear when injected)
    stream = []
    for _ in range(rng.randint(5, 60)):
        r = rng.random()
        if r < 0.06:
            stream.append(7)
        elif r < 0.12:
            stream.append(9)
        else:
            stream.append(rng.randrange(5))

    for tok in stream:
        out_a.append(tok)
        out_b.append(tok)
        # NOTE(review): force_index is cleared before every step, presumably so
        # entries written by one call cannot leak into the next comparison —
        # confirm against the upstream state holder.
        state_a["force_index"] = []
        state_b["force_index"] = []
        orig_fn(holder, state_a)
        patched_fn(holder, state_b)
        assert _normalize(state_a) == _normalize(state_b), f"divergence after {len(out_a)} tokens (stream={stream})"