Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 68 additions & 0 deletions packages/prime-rl-configs/src/prime_rl/configs/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,69 @@ class MultiNodeInferenceDeploymentConfig(BaseInferenceDeploymentConfig):
] = "consistent_hash"


class NixlTransportConfig(BaseModel):
    """Configures NIXL KV transfer for disaggregated inference deployments.

    ``router_cache_ttl_seconds`` is optional: when left unset it is derived
    from ``abort_timeout_seconds`` by the ``validate_router_cache_ttl``
    validator, which also enforces that the TTL is strictly smaller than the
    abort timeout.
    """

    # Reject unknown keys so typos in config files fail loudly.
    model_config = ConfigDict(extra="forbid")

    # Discriminator for the transport kind; only "nixl" is accepted here.
    type: Literal["nixl"] = "nixl"

    enable_bidirectional: Annotated[
        bool,
        Field(
            description=(
                "Whether Prefill workers can pull Decode-side KV through NIXL for later requests "
                "in the same conversation."
            ),
        ),
    ] = False
    num_threads: Annotated[
        int,
        Field(ge=1, description="Number of NIXL connector threads."),
    ] = 1
    kv_recompute_threshold: Annotated[
        int,
        Field(
            ge=0,
            description=(
                "Minimum number of remote Decode-side KV tokens required before a Prefill worker pulls "
                "KV through NIXL instead of recomputing locally. Passed to NixlConnector extra config."
            ),
        ),
    ] = 64
    abort_timeout_seconds: Annotated[
        int,
        Field(
            gt=0,
            description=(
                "Seconds vLLM NIXL waits for the peer to fetch held KV blocks before aborting and freeing them. "
                "Exported as NIXL_ABORT_TIMEOUT and vLLM's VLLM_NIXL_ABORT_REQUEST_TIMEOUT."
            ),
        ),
    ] = 480
    router_cache_ttl_seconds: Annotated[
        int | None,
        Field(
            gt=0,
            description=(
                "Seconds vllm-router keeps Decode-side KV metadata for bidirectional P/D reuse. "
                "Defaults to 95% of abort_timeout_seconds."
            ),
        ),
    ] = None

    @model_validator(mode="after")
    def validate_router_cache_ttl(self):
        """Fill in and validate ``router_cache_ttl_seconds``.

        When the TTL is unset it defaults to 95% of ``abort_timeout_seconds``,
        rounded down to whole seconds.

        Returns:
            The validated model instance, with ``router_cache_ttl_seconds`` set.

        Raises:
            ValueError: If the derived default would be smaller than 1, or if
                the (explicit or derived) TTL is not strictly less than
                ``abort_timeout_seconds``.
        """
        if self.router_cache_ttl_seconds is None:
            derived_ttl = int(self.abort_timeout_seconds * 0.95)
            # The field-level ``gt=0`` check has already run and is not re-applied
            # to values assigned here, so enforce it by hand. Without this guard
            # abort_timeout_seconds=1 would silently produce a TTL of 0.
            if derived_ttl < 1:
                raise ValueError(
                    "abort_timeout_seconds is too small to derive router_cache_ttl_seconds: "
                    f"95% of {self.abort_timeout_seconds} rounds down to {derived_ttl}, "
                    "but router_cache_ttl_seconds must be greater than 0. "
                    "Use abort_timeout_seconds >= 2."
                )
            self.router_cache_ttl_seconds = derived_ttl
        if self.router_cache_ttl_seconds >= self.abort_timeout_seconds:
            raise ValueError(
                "router_cache_ttl_seconds must be less than abort_timeout_seconds "
                f"({self.router_cache_ttl_seconds} >= {self.abort_timeout_seconds})"
            )
        return self

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Auto-computed TTL can be zero for small timeouts

Low Severity

When abort_timeout_seconds is 1 (the minimum allowed by gt=0), the auto-computed router_cache_ttl_seconds becomes int(1 * 0.95) = 0. This bypasses the field's gt=0 constraint since model validators run after field validation and don't re-validate. The resulting value 0 gets passed to --pd-kv-cache-ttl-secs 0, which may cause unexpected router behavior.

Fix in Cursor Fix in Web

Reviewed by Cursor Bugbot for commit a3f8a6d. Configure here.



class DisaggregatedInferenceDeploymentConfig(BaseInferenceDeploymentConfig):
"""Configures a disaggregated prefill/decode inference deployment.

Expand Down Expand Up @@ -211,6 +274,11 @@ class DisaggregatedInferenceDeploymentConfig(BaseInferenceDeploymentConfig):
str, Field(description="Routing policy for the vllm-router (e.g. 'consistent_hash', 'round_robin').")
] = "consistent_hash"

kv_transport_config: Annotated[
NixlTransportConfig,
Field(description="KV transport settings for disaggregated P/D deployments."),
] = NixlTransportConfig()

prefill_env_overrides: Annotated[
dict[str, str],
Field(description="Extra environment variables exported only on prefill nodes."),
Expand Down
3 changes: 2 additions & 1 deletion packages/prime-rl-configs/src/prime_rl/configs/rl.py
Original file line number Diff line number Diff line change
Expand Up @@ -770,8 +770,9 @@ def auto_setup_lora(self):

@model_validator(mode="after")
def auto_setup_session_headers(self):
"""Ensure X-Session-ID header is always set for sticky DP-aware routing at the inference router."""
"""Ensure stable routing headers are set for inference routers."""
self.orchestrator.client.extra_headers_from_state.setdefault("X-Session-ID", "example_id")
self.orchestrator.client.extra_headers_from_state.setdefault("X-Conversation-ID", "trajectory_id")

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nonexistent state field used for conversation header mapping

High Severity

The extra_headers_from_state dict maps header names to rollout state field names. The value "trajectory_id" doesn't appear to exist as a field in the rollout state dict anywhere in src/. Searching the codebase, "trajectory_id" only appears in test fixtures, never in the actual orchestrator state or rollout output dictionaries. By contrast, "example_id" (used for X-Session-ID) is a well-established state field found in buffer.py and envs.py. At runtime, the X-Conversation-ID header will likely be empty or cause an error when the framework tries to read "trajectory_id" from the state, breaking bidirectional KV routing which depends on this header.

Fix in Cursor Fix in Web

Reviewed by Cursor Bugbot for commit a3f8a6d. Configure here.

return self

@model_validator(mode="after")
Expand Down
9 changes: 5 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ dependencies = [
"torchaudio",
"torchdata>=0.11.0",
"transformers",
"vllm>=0.20.2",
"vllm==0.21.0",
"wandb>=0.26.1",
"ring-flash-attn>=0.1.8",
"prime>=0.6.4",
Expand Down Expand Up @@ -176,6 +176,7 @@ override-dependencies = [
[tool.uv.exclude-newer-package]
# we want latest vllm, remove next patch
vllm = false
tokenspeed-mla = false
flash_attn_3 = false
# PrimeIntellect-published on PyPI (trusted publisher)
prime = false
Expand Down Expand Up @@ -229,10 +230,10 @@ dion = { git = "https://github.com/samsja/dion.git", rev = "d891eeb" }
transformers = { git = "https://github.com/huggingface/transformers.git", rev = "c1c3424" }
flash-attn-4 = { git = "https://github.com/Dao-AILab/flash-attention.git", subdirectory = "flash_attn/cute", rev = "96bd151" }
pydantic-config = { git = "https://github.com/samsja/pydantic_config.git", branch = "main" }
vllm-router = { url = "https://github.com/PrimeIntellect-ai/router/releases/download/v0.1.22/vllm_router-0.1.22-cp38-abi3-manylinux_2_28_x86_64.whl" }
vllm-router = { git = "https://github.com/PrimeIntellect-ai/router.git", rev = "1a441d6" }
vllm = [
{ url = "https://github.com/vllm-project/vllm/releases/download/v0.20.2/vllm-0.20.2+cu129-cp38-abi3-manylinux_2_31_x86_64.whl", marker = "platform_machine == 'x86_64'" },
{ url = "https://github.com/vllm-project/vllm/releases/download/v0.20.2/vllm-0.20.2+cu129-cp38-abi3-manylinux_2_31_aarch64.whl", marker = "platform_machine == 'aarch64'" },
{ url = "https://files.pythonhosted.org/packages/73/6d/9b78990c9fabc70c7731de6af246a420156dc019f66b48da7c86f509c132/vllm-0.21.0-1-cp38-abi3-manylinux_2_24_x86_64.whl", marker = "platform_machine == 'x86_64'" },
{ url = "https://files.pythonhosted.org/packages/ac/58/564b64d17dde6dc31faae836f98313538c152edf88e2a4fb43b9d551a635/vllm-0.21.0-1-cp38-abi3-manylinux_2_24_aarch64.whl", marker = "platform_machine == 'aarch64'" },
]
deep-ep = { url = "https://github.com/PrimeIntellect-ai/prime-rl/releases/download/v0.5.0/deep_ep-1.2.1+29d31c0-cp312-cp312-linux_x86_64.whl" }
deep-gemm = { url = "https://github.com/PrimeIntellect-ai/prime-rl/releases/download/v0.5.0/deep_gemm-2.5.0+891d57b-cp312-cp312-linux_x86_64.whl" }
Expand Down
1 change: 1 addition & 0 deletions src/prime_rl/entrypoints/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ def write_slurm_script(config: InferenceConfig, config_path: Path, script_path:
decode_port=config.deployment.decode_port,
router_port=config.deployment.router_port,
router_policy=config.deployment.router_policy,
kv_transport_config=config.deployment.kv_transport_config,
data_parallel_rpc_port=config.data_parallel_rpc_port,
use_deep_gemm=config.use_deep_gemm,
prefill_env_overrides=config.deployment.prefill_env_overrides,
Expand Down
2 changes: 2 additions & 0 deletions src/prime_rl/entrypoints/rl.py
Original file line number Diff line number Diff line change
Expand Up @@ -437,8 +437,10 @@ def write_slurm_script(config: RLConfig, config_dir: Path, script_path: Path) ->
num_decode_replicas=infer_deploy.num_decode_replicas,
gpus_per_node=config.deployment.gpus_per_node,
router_port=infer_deploy.router_port,
router_policy=infer_deploy.router_policy,
prefill_port=infer_deploy.prefill_port,
decode_port=infer_deploy.decode_port,
kv_transport_config=infer_deploy.kv_transport_config,
inference_tp=config.inference.parallel.tp,
inference_data_parallel_rpc_port=config.inference.data_parallel_rpc_port,
use_deep_gemm=config.inference.use_deep_gemm,
Expand Down
152 changes: 13 additions & 139 deletions src/prime_rl/inference/patches.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ def transformers_v5_compat():
_patch_qwen35_lora()
_patch_lora_key_prefix()
monkey_patch_deep_gemm_ep_scatter()
monkey_patch_deep_gemm_silu_mul_quant_int64()
monkey_patch_dp_engine_core_pause_resume_deadlock()
monkey_patch_vllm_layerwise_reload_alias_buffers()

Expand Down Expand Up @@ -205,140 +204,6 @@ def monkey_patch_deep_gemm_ep_scatter():
logger.warning("Enabled int64-addressing Triton patch for vLLM DeepGEMM ep_scatter.")


@triton.jit
def _silu_mul_per_token_group_quant_fp8_colmajor_int64_kernel(
    y_ptr,
    y_q_ptr,
    y_s_ptr,
    M: tl.int64,
    N: tl.int64,
    y_s_col_stride: tl.int64,
    eps,
    fp8_min: tl.constexpr,
    fp8_max: tl.constexpr,
    use_ue8m0: tl.constexpr,
    GROUP_SIZE: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Fused SiLU * mul followed by per-row-group quantization, with int64 offsets.

    Each program handles one (BLOCK_M, BLOCK_N) tile. For every row it reads
    BLOCK_N values at the tile's columns and BLOCK_N values N // 2 columns
    further right, computes silu(first) * second, derives one scale per row of
    the tile from the row's absolute maximum, and stores the scaled, clamped
    result in the element type of ``y_q_ptr`` together with the scale.

    Args:
        y_ptr: Input pointer, addressed with a row stride of N elements.
        y_q_ptr: Quantized output pointer, addressed with a row stride of N // 2.
        y_s_ptr: Scale output pointer; consecutive rows are adjacent and
            consecutive groups are ``y_s_col_stride`` elements apart.
        M: Number of input rows.
        N: Number of input columns.
        y_s_col_stride: Element stride between scale groups.
        eps: Lower bound applied to the per-row absolute maximum.
        fp8_min: Lower clamp bound for the quantized values.
        fp8_max: Upper clamp bound; also the divisor used to derive the scale.
        use_ue8m0: When true, the scale is rounded up to a power of two.
        GROUP_SIZE: Column width used to turn a column offset into a group index.
        BLOCK_M: Rows per tile.
        BLOCK_N: Columns per tile.

    NOTE(review): loads and stores are unmasked, so M is presumably a multiple
    of BLOCK_M and N // 2 a multiple of BLOCK_N; the per-tile scale also looks
    like it assumes BLOCK_N == GROUP_SIZE -- confirm against the launcher.
    """
    half_n = N // 2

    # Tile origin, widened to int64 before it is multiplied by the row stride.
    row_start = (tl.program_id(0) * BLOCK_M).to(tl.int64)
    col_start = (tl.program_id(1) * BLOCK_N).to(tl.int64)
    if row_start >= M:
        return

    col_idx = tl.arange(0, BLOCK_N).to(tl.int64)
    row_idx = tl.arange(0, BLOCK_M).to(tl.int64)

    # Pointers to the tile of the first operand; the second operand sits
    # N // 2 columns to the right in the same rows.
    in_tile_base = y_ptr + row_start * N + col_start
    gate_ptrs = in_tile_base + row_idx[:, None] * N + col_idx[None, :]

    gate = tl.load(gate_ptrs)
    multiplier = tl.load(gate_ptrs + half_n)

    # SiLU is evaluated in float32, cast back to the input element type for the
    # product, and the product is widened to float32 again for quantization.
    gate = gate.to(tl.float32)
    unit = tl.cast(1, tl.float32)
    activated = (gate / (unit + tl.exp(-gate))).to(y_ptr.dtype.element_ty)
    prod = (activated * multiplier).to(tl.float32)

    # One scale per tile row, floored at eps.
    row_absmax = tl.maximum(tl.max(tl.abs(prod), axis=1), eps)
    linear_scale = row_absmax * (1.0 / fp8_max)
    scale = tl.math.exp2(tl.ceil(tl.log2(linear_scale))) if use_ue8m0 else linear_scale
    scale_col = tl.reshape(scale, (BLOCK_M, 1))
    quantized = tl.clamp(prod / scale_col, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)

    out_tile_base = y_q_ptr + row_start * half_n + col_start
    out_ptrs = out_tile_base + row_idx[:, None] * half_n + col_idx[None, :]
    tl.store(out_ptrs, quantized)

    # Scales: the group index selects the column, rows are adjacent in memory.
    group_idx = col_start // GROUP_SIZE
    scale_base = y_s_ptr + group_idx * y_s_col_stride + row_start
    scale_ptrs = scale_base + row_idx
    scale_flat = tl.reshape(scale_col, (BLOCK_M,))
    tl.store(scale_ptrs, scale_flat)


def _silu_mul_per_token_group_quant_fp8_colmajor_int64(
    input: torch.Tensor,
    output: torch.Tensor | None = None,
    use_ue8m0: bool | None = None,
    eps: float = 1e-10,
):
    """Launch the int64-addressing SiLU * mul + FP8 group-quantization kernel.

    Args:
        input: 2-D tensor whose row count is a multiple of 128 and whose
            column count is a multiple of 256.
        output: Optional 2-D destination for the quantized values. When
            omitted, an ``(M, N // 2)`` tensor in the platform FP8 dtype is
            allocated on ``input``'s device.
        use_ue8m0: Whether scales are rounded up to powers of two. ``None``
            defers to ``is_deep_gemm_e8m0_used()``.
        eps: Lower bound applied to each group's absolute maximum.

    Returns:
        ``(output, output_scales)``. ``output_scales`` is a float32 view of
        shape ``(M, (N // 2) // 128)`` obtained by transposing a
        ``(groups, M)`` buffer.
    """
    from vllm.platforms import current_platform
    from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used

    group_size = 128

    # Shape preconditions; the kernel performs unmasked loads and stores.
    assert input.ndim == 2
    if output is not None:
        assert output.ndim == 2
    assert input.size(0) % group_size == 0
    assert input.size(1) % (group_size * 2) == 0

    if use_ue8m0 is None:
        use_ue8m0 = is_deep_gemm_e8m0_used()

    num_rows, num_cols = input.size()
    half_cols = num_cols // 2

    fp8_dtype = current_platform.fp8_dtype()
    if output is None:
        output = torch.empty((M_N := (num_rows, half_cols)), dtype=fp8_dtype, device=input.device) if False else torch.empty(
            (num_rows, half_cols), dtype=fp8_dtype, device=input.device
        )

    # Allocate as (groups, rows) and hand out the transposed view.
    scales_buffer = torch.empty(
        (half_cols // group_size, num_rows),
        dtype=torch.float32,
        device=input.device,
    )
    output_scales = scales_buffer.transpose(0, 1)

    block_m = 8
    block_n = group_size
    assert num_rows % block_m == 0
    assert half_cols % block_n == 0

    # fnuz platforms use fixed +/-224 bounds; otherwise use the dtype's limits.
    if current_platform.is_fp8_fnuz():
        fp8_min, fp8_max = -224.0, 224.0
    else:
        dtype_limits = torch.finfo(fp8_dtype)
        fp8_min, fp8_max = dtype_limits.min, dtype_limits.max

    launch_grid = (num_rows // block_m, half_cols // block_n)
    _silu_mul_per_token_group_quant_fp8_colmajor_int64_kernel[launch_grid](
        input,
        output,
        output_scales,
        num_rows,
        num_cols,
        output_scales.stride(-1),
        eps,
        fp8_min,
        fp8_max,
        use_ue8m0,
        group_size,
        block_m,
        block_n,
    )

    return output, output_scales


def monkey_patch_deep_gemm_silu_mul_quant_int64():
    """Swap in the int64-addressing SiLU/mul FP8 quant function for vLLM.

    Temporary local carry for large DeepGEMM profile shapes whose row offsets
    exceed signed int32 address arithmetic in vLLM's Triton kernel. The
    replacement is installed on ``fp8_utils`` and, if the DeepGEMM MoE experts
    module has already been imported, on that module's own reference as well.
    """
    import sys

    from vllm.logger import init_logger
    from vllm.model_executor.layers.quantization.utils import fp8_utils

    logger = init_logger(__name__)
    replacement = _silu_mul_per_token_group_quant_fp8_colmajor_int64

    fp8_utils.silu_mul_per_token_group_quant_fp8_colmajor = replacement

    # Only touch the MoE module when it is already loaded; it holds its own
    # reference to the function under the same name.
    moe_module_name = "vllm.model_executor.layers.fused_moe.experts.deep_gemm_moe"
    loaded_moe_module = sys.modules.get(moe_module_name)
    if loaded_moe_module is not None:
        loaded_moe_module.silu_mul_per_token_group_quant_fp8_colmajor = replacement

    logger.warning("Enabled int64-addressing Triton patch for vLLM DeepGEMM SiLU/mul FP8 quant.")


def _patch_qwen35_lora():
"""Fix Qwen3.5 LoRA: align packed_modules_mapping with output_sizes.

Expand Down Expand Up @@ -897,9 +762,9 @@ def monkey_patch_dp_engine_core_pause_resume_deadlock():
- on resume, wake every DP rank and force an immediate global unfinished
sync instead of waiting for the normal 32-step cadence

This keeps the upstream pause-side fix from
https://github.com/vllm-project/vllm/pull/37024 and extends it with the
resume-side wave-state fix.
This also bypasses vLLM's two-phase DP pause implementation
(https://github.com/vllm-project/vllm/pull/39366), which makes resume
reject states that our weight-update flow can validly hit.
"""
from vllm.config import ParallelConfig
from vllm.v1.core.sched.interface import PauseState
Expand All @@ -909,7 +774,8 @@ def monkey_patch_dp_engine_core_pause_resume_deadlock():

_base_add_request = EngineCore.add_request
_base_handle_client_request = EngineCoreProc._handle_client_request
_base_resume_scheduler = DPEngineCoreProc.resume_scheduler
_base_pause_complete = EngineCoreProc._pause_complete
_base_resume_scheduler = EngineCoreProc.resume_scheduler

def _patched_add_request(self, request: Request, request_wave: int = 0):
_base_add_request(self, request, request_wave)
Expand All @@ -930,8 +796,15 @@ def _patched_handle_client_request(self, request_type, request):
else:
_base_handle_client_request(self, request_type, request)

def _patched_pause_complete(self) -> bool:
self.pending_pause = False
self.ignore_start_dp_wave = False
return _base_pause_complete(self)

def _patched_resume_scheduler(self):
was_paused = self.scheduler.pause_state != PauseState.UNPAUSED
self.pending_pause = False
self.ignore_start_dp_wave = False
_base_resume_scheduler(self)
if was_paused:
self.engines_running = True
Expand All @@ -948,6 +821,7 @@ def _patched_has_global_unfinished_reqs(self, local_unfinished: bool) -> bool:

DPEngineCoreProc.add_request = _patched_add_request
DPEngineCoreProc._handle_client_request = _patched_handle_client_request
DPEngineCoreProc._pause_complete = _patched_pause_complete
DPEngineCoreProc.resume_scheduler = _patched_resume_scheduler
DPEngineCoreProc._has_global_unfinished_reqs = _patched_has_global_unfinished_reqs

Expand Down
11 changes: 7 additions & 4 deletions src/prime_rl/templates/inference.sbatch.j2
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ export PREFILL_PORT={{ prefill_port }}
export DECODE_PORT={{ decode_port }}
export ROUTER_PORT={{ router_port }}
export RPC_PORT={{ data_parallel_rpc_port }}
export NIXL_ABORT_TIMEOUT={{ kv_transport_config.abort_timeout_seconds }}
export VLLM_NIXL_ABORT_REQUEST_TIMEOUT={{ kv_transport_config.abort_timeout_seconds }}
{%- elif num_nodes > 1 %}
export ROUTER_PORT={{ router_port }}
export BACKEND_PORT={{ backend_port }}
Expand Down Expand Up @@ -171,15 +173,15 @@ srun bash -c '
export VLLM_NIXL_SIDE_CHANNEL_PORT=5600

{%- if kv_offload %}
PREFILL_KV_CFG='"'"'{"kv_connector":"MultiConnector","kv_role":"kv_both","kv_connector_extra_config":{"connectors":[{"kv_connector":"NixlConnector","kv_role":"kv_both","kv_connector_extra_config":{"num_threads":1}},{"kv_connector":"OffloadingConnector","kv_role":"kv_both","kv_connector_extra_config":{"cpu_bytes_to_use":{{ kv_offload_cpu_bytes }}}}]}}'"'"'
PREFILL_KV_CFG='"'"'{"kv_connector":"MultiConnector","kv_role":"kv_both","kv_connector_extra_config":{"connectors":[{"kv_connector":"NixlConnector","kv_role":"kv_both","kv_connector_extra_config":{"num_threads":{{ kv_transport_config.num_threads }},"bidirectional_kv_xfer":{{ "true" if kv_transport_config.enable_bidirectional else "false" }},"kv_recompute_threshold":{{ kv_transport_config.kv_recompute_threshold }}}},{"kv_connector":"OffloadingConnector","kv_role":"kv_both","kv_connector_extra_config":{"cpu_bytes_to_use":{{ kv_offload_cpu_bytes }}}}]}}'"'"'
{%- else %}
PREFILL_KV_CFG='"'"'{"kv_connector":"NixlConnector","kv_role":"kv_both","kv_connector_extra_config":{"num_threads":1}}'"'"'
PREFILL_KV_CFG='"'"'{"kv_connector":"NixlConnector","kv_role":"kv_both","kv_connector_extra_config":{"num_threads":{{ kv_transport_config.num_threads }},"bidirectional_kv_xfer":{{ "true" if kv_transport_config.enable_bidirectional else "false" }},"kv_recompute_threshold":{{ kv_transport_config.kv_recompute_threshold }}}}'"'"'
{%- endif %}

{%- if kv_offload %}
DECODE_KV_CFG='"'"'{"kv_connector":"MultiConnector","kv_role":"kv_both","kv_connector_extra_config":{"connectors":[{"kv_connector":"NixlConnector","kv_role":"kv_both","kv_connector_extra_config":{"num_threads":1}},{"kv_connector":"OffloadingConnector","kv_role":"kv_both","kv_connector_extra_config":{"cpu_bytes_to_use":{{ kv_offload_cpu_bytes }}}}]}}'"'"'
DECODE_KV_CFG='"'"'{"kv_connector":"MultiConnector","kv_role":"kv_both","kv_connector_extra_config":{"connectors":[{"kv_connector":"NixlConnector","kv_role":"kv_both","kv_connector_extra_config":{"num_threads":{{ kv_transport_config.num_threads }},"bidirectional_kv_xfer":{{ "true" if kv_transport_config.enable_bidirectional else "false" }},"kv_recompute_threshold":{{ kv_transport_config.kv_recompute_threshold }}}},{"kv_connector":"OffloadingConnector","kv_role":"kv_both","kv_connector_extra_config":{"cpu_bytes_to_use":{{ kv_offload_cpu_bytes }}}}]}}'"'"'
{%- else %}
DECODE_KV_CFG='"'"'{"kv_connector":"NixlConnector","kv_role":"kv_both","kv_connector_extra_config":{"num_threads":1}}'"'"'
DECODE_KV_CFG='"'"'{"kv_connector":"NixlConnector","kv_role":"kv_both","kv_connector_extra_config":{"num_threads":{{ kv_transport_config.num_threads }},"bidirectional_kv_xfer":{{ "true" if kv_transport_config.enable_bidirectional else "false" }},"kv_recompute_threshold":{{ kv_transport_config.kv_recompute_threshold }}}}'"'"'
{%- endif %}

DECODE_COMPILE_CFG='"'"'{"cudagraph_mode":"FULL_DECODE_ONLY"}'"'"'
Expand Down Expand Up @@ -250,6 +252,7 @@ srun bash -c '
--host 0.0.0.0 \
--port $ROUTER_PORT \
--intra-node-data-parallel-size {{ dp_per_node }} \
--pd-kv-cache-ttl-secs {{ kv_transport_config.router_cache_ttl_seconds if kv_transport_config.enable_bidirectional else 0 }} \
--worker-startup-timeout-secs 4200 \
--log-level debug \
>> $ROUTER_LOG 2>&1 &
Expand Down
Loading
Loading