Skip to content
6 changes: 6 additions & 0 deletions packages/prime-rl-configs/src/prime_rl/configs/shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,12 @@ class ClientConfig(BaseConfig):
elastic: ElasticConfig | None = None
"""Elastic inference pool config for DNS-based service discovery. When set, ``base_url`` is ignored and inference servers are discovered dynamically via DNS."""

rl_base_url: list[str] | None = None
"""Dynamo RL worker discovery base URLs. Used only for backend='dynamo' when admin_base_url is unset. These URLs point at the Dynamo RL discovery listener (DYN_RL_PORT, default 8001), which serves GET /v1/rl/workers. If unset, prime-rl derives the discovery URL from base_url by replacing the port with DYN_RL_PORT or 8001."""

backend: Literal["vllm", "dynamo"] = "vllm"
"""Inference backend selector. Picks the AdminAPI implementation used for pause/resume/update_weights/load_lora_adapter/list_models. Default 'vllm' matches prime-rl's bundled vLLM frontend. 'dynamo' targets NVIDIA Dynamo's worker /engine/* admin routes on admin_base_url and routes /v1/models to the OpenAI-compat base_url."""

router_url: str | None = None
"""vllm-router URL for load-aware inference routing. With elastic mode, inference requests go through the router while admin ops still hit discovered pods directly."""

Expand Down
6 changes: 6 additions & 0 deletions src/prime_rl/orchestrator/orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,11 @@ async def setup(self) -> None:
)
await self.inference_metrics.start()

# Propagate the configured broadcast type to the admin API. The Dynamo
# backend gates its NCCL init + weight-update path on ``_weight_broadcast_type``.
if hasattr(self.student_inference._admin_api, "_weight_broadcast_type"):
self.student_inference._admin_api._weight_broadcast_type = config.weight_broadcast.type

get_logger().info(f"Initializing weight broadcast ({config.weight_broadcast})")
if config.weight_broadcast.type == "nccl":
await init_nccl_broadcast(
Expand All @@ -327,6 +332,7 @@ async def setup(self) -> None:
config.weight_broadcast.timeout,
inference_world_size=config.weight_broadcast.inference_world_size,
quantize_in_weight_transfer=config.weight_broadcast.quantize_in_weight_transfer,
admin=self.student_inference._admin_api,
)

get_logger().info(f"Initializing training batch sender ({config.rollout_transport})")
Expand Down
162 changes: 117 additions & 45 deletions src/prime_rl/orchestrator/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from concurrent.futures import ThreadPoolExecutor
from itertools import cycle
from pathlib import Path
from typing import Any

import orjson
import verifiers as vf
Expand Down Expand Up @@ -96,58 +97,129 @@ def set_default_executor(max_workers: int = 64) -> None:
asyncio.get_event_loop().set_default_executor(ThreadPoolExecutor(max_workers=max_workers))


def _flatten_prompt_logprobs(raw: list[Any] | None) -> list[float]:
"""Shared flattener used by both transports.

``prompt_logprobs[i]`` is a ``{token_id: Logprob}`` dict for tokens the
engine could score, or ``None`` for the leading token which has no
preceding context. Flatten to ``list[float]`` with 0.0 in the unscored
slot. Accepts both vLLM's typed ``Logprob`` objects and dynamo's
``PromptLogprobEntry`` dict shape (`{logprob, rank?, decoded_token?}`).
"""
flat: list[float] = []
for entry in raw or []:
if not entry:
flat.append(0.0)
continue
first = next(iter(entry.values()))
lp = first.logprob if hasattr(first, "logprob") else first.get("logprob")
flat.append(float(lp) if lp is not None else 0.0)
return flat


async def _compute_teacher_logprobs_vllm(
    client_config: vf.ClientConfig, model_name: str, sample: TrainingSample
) -> list[float]:
    """Score one sample via prime-rl's vLLM sidecar (legacy transport).

    Posts the concatenated prompt and completion token ids to
    ``/inference/v1/generate`` with ``prompt_logprobs`` set to 1, then
    flattens the returned ``prompt_logprobs`` into a list of floats.

    Args:
        client_config: Client settings used to build the OpenAI client.
        model_name: Model identifier sent in the request body.
        sample: Training sample providing ``prompt_ids`` and ``completion_ids``.

    Returns:
        The flattened per-position logprobs.
    """
    import httpx
    from vllm.entrypoints.serve.disagg.protocol import GenerateResponse

    client = setup_openai_client(client_config)

    # The generate route is mounted at the server root, not under ``/v1``.
    # Building an absolute URL keeps the SDK from merging it with its base URL.
    server_root = str(client.base_url).rstrip("/").removesuffix("/v1")
    endpoint = f"{server_root}/inference/v1/generate"

    request_body = {
        "model": model_name,
        "token_ids": list(sample.prompt_ids) + list(sample.completion_ids),
        "sampling_params": {
            "max_tokens": 1,
            "temperature": 1.0,
            "top_p": 1.0,
            "prompt_logprobs": 1,
        },
    }

    # ``GenerateResponse`` is not an ``openai.BaseModel``, so request the raw
    # ``httpx.Response`` and validate the JSON body ourselves.
    raw_response = await client.post(endpoint, cast_to=httpx.Response, body=request_body)
    parsed = GenerateResponse.model_validate_json(raw_response.content)
    return _flatten_prompt_logprobs(parsed.prompt_logprobs)


async def _compute_teacher_logprobs_dynamo(
    client_config: vf.ClientConfig, model_name: str, sample: TrainingSample
) -> list[float]:
    """Score one sample via Dynamo's ``/v1/chat/completions`` with an nvext envelope.

    Request shape:
      - top-level ``prompt_logprobs: 1`` (CommonExt sampling param)
      - ``nvext.token_data`` holds the pre-tokenized prompt + completion ids
      - ``nvext.extra_fields = ["prompt_logprobs"]`` opts into the response
        field; dynamo emits ``response.nvext.prompt_logprobs`` shaped as
        ``[None | {token_id: {logprob, rank?, decoded_token?}}]``, which the
        shared flattener consumes unchanged.

    The vLLM worker must populate ``prompt_logprobs`` when
    ``SamplingParams.prompt_logprobs`` is set; otherwise the field is None
    and the result is an empty list.

    Args:
        client_config: Client settings used to build the OpenAI client.
        model_name: Model identifier sent in the request body.
        sample: Training sample providing ``prompt_ids`` and ``completion_ids``.

    Returns:
        The flattened per-position logprobs.
    """
    import httpx

    client = setup_openai_client(client_config)
    request_body = {
        "model": model_name,
        # Placeholder stub the OpenAI/Dynamo chat schema requires (rejected if
        # empty); the authoritative prompt travels in nvext.token_data and
        # Dynamo ignores these messages. Matches renderers' dynamo_chat client.
        "messages": [{"role": "user", "content": ""}],
        "max_completion_tokens": 1,
        "temperature": 1.0,
        "top_p": 1.0,
        "prompt_logprobs": 1,
        "nvext": {
            "token_data": list(sample.prompt_ids) + list(sample.completion_ids),
            "extra_fields": ["prompt_logprobs"],
        },
    }

    # The reply is a standard chat-completion JSON plus an extra ``nvext``
    # field. The OpenAI SDK response models drop unknown fields, so ask for
    # the raw ``httpx.Response`` and read ``nvext.prompt_logprobs`` directly.
    raw_response = await client.post(
        "/chat/completions",
        cast_to=httpx.Response,
        body=request_body,
    )
    decoded = raw_response.json() or {}
    nvext_section = decoded.get("nvext") or {}
    return _flatten_prompt_logprobs(nvext_section.get("prompt_logprobs"))


async def compute_teacher_logprobs(
    clients: list[vf.ClientConfig],
    model_name: str,
    samples: list[TrainingSample],
) -> list[list[float]]:
    """Compute teacher model logprobs for a batch of training samples via prefill.

    Samples are assigned to ``clients`` round-robin and scored concurrently.
    Each request is dispatched on the client's ``renderer_transport``
    attribute (treated as ``"vllm_generate"`` when the attribute is absent):

    - ``vllm_generate`` (default): POST ``/inference/v1/generate``
    - ``dynamo_chat``: POST ``/v1/chat/completions`` with nvext

    Both paths flatten their response to ``list[float]`` via the shared helper.

    Args:
        clients: Client configs to spread the samples across.
        model_name: Model identifier forwarded to the inference server.
        samples: Training samples to score.

    Returns:
        One list of per-position logprobs per sample, in the order of
        ``samples``.
    """

    async def _compute_single(client_config: vf.ClientConfig, sample: TrainingSample) -> list[float]:
        # The transport-specific request logic lives in the module-level
        # helpers; this wrapper only selects which one to call. Keeping a
        # second inline copy of the vLLM request here would return before the
        # dispatch and make the dynamo path unreachable.
        if getattr(client_config, "renderer_transport", "vllm_generate") == "dynamo_chat":
            return await _compute_teacher_logprobs_dynamo(client_config, model_name, sample)
        return await _compute_teacher_logprobs_vllm(client_config, model_name, sample)

    # ``cycle`` pairs samples with clients round-robin; ``gather`` returns the
    # results in submission order.
    return await asyncio.gather(*[_compute_single(client, sample) for client, sample in zip(cycle(clients), samples)])

Expand Down
Loading