Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions packages/prime-rl-configs/src/prime_rl/configs/shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,12 @@ class ClientConfig(BaseConfig):
elastic: ElasticConfig | None = None
"""Elastic inference pool config for DNS-based service discovery. When set, ``base_url`` is ignored and inference servers are discovered dynamically via DNS."""

rl_base_url: list[str] | None = None
"""Dynamo RL worker discovery base URLs. Used only for backend='dynamo' when admin_base_url is unset. These URLs point at the Dynamo RL discovery listener (DYN_RL_PORT, default 8001), which serves GET /v1/rl/workers. If unset, prime-rl derives the discovery URL from base_url by replacing the port with DYN_RL_PORT or 8001."""

backend: Literal["vllm", "dynamo"] = "vllm"
"""Inference backend selector. Picks the AdminAPI implementation used for pause/resume/update_weights/load_lora_adapter/list_models. Default 'vllm' matches prime-rl's bundled vLLM frontend. 'dynamo' targets NVIDIA Dynamo's worker /engine/* admin routes on admin_base_url and routes /v1/models to the OpenAI-compat base_url."""

router_url: str | None = None
"""vllm-router URL for load-aware inference routing. With elastic mode, inference requests go through the router while admin ops still hit discovered pods directly."""

Expand Down
6 changes: 6 additions & 0 deletions src/prime_rl/orchestrator/orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,11 @@ async def setup(self) -> None:
)
await self.inference_metrics.start()

# Propagate the configured broadcast type to the admin API. The Dynamo
# backend gates its NCCL init + weight-update path on ``_weight_broadcast_type``.
if hasattr(self.student_inference._admin_api, "_weight_broadcast_type"):
self.student_inference._admin_api._weight_broadcast_type = config.weight_broadcast.type

get_logger().info(f"Initializing weight broadcast ({config.weight_broadcast})")
if config.weight_broadcast.type == "nccl":
await init_nccl_broadcast(
Expand All @@ -327,6 +332,7 @@ async def setup(self) -> None:
config.weight_broadcast.timeout,
inference_world_size=config.weight_broadcast.inference_world_size,
quantize_in_weight_transfer=config.weight_broadcast.quantize_in_weight_transfer,
admin=self.student_inference._admin_api,
)

get_logger().info(f"Initializing training batch sender ({config.rollout_transport})")
Expand Down
162 changes: 117 additions & 45 deletions src/prime_rl/orchestrator/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from concurrent.futures import ThreadPoolExecutor
from itertools import cycle
from pathlib import Path
from typing import Any

import orjson
import verifiers as vf
Expand Down Expand Up @@ -96,58 +97,129 @@ def set_default_executor(max_workers: int = 64) -> None:
asyncio.get_event_loop().set_default_executor(ThreadPoolExecutor(max_workers=max_workers))


def _flatten_prompt_logprobs(raw: list[Any] | None) -> list[float]:
"""Shared flattener used by both transports.

``prompt_logprobs[i]`` is a ``{token_id: Logprob}`` dict for tokens the
engine could score, or ``None`` for the leading token which has no
preceding context. Flatten to ``list[float]`` with 0.0 in the unscored
slot. Accepts both vLLM's typed ``Logprob`` objects and dynamo's
``PromptLogprobEntry`` dict shape (`{logprob, rank?, decoded_token?}`).
"""
flat: list[float] = []
for entry in raw or []:
if not entry:
flat.append(0.0)
continue
first = next(iter(entry.values()))
lp = first.logprob if hasattr(first, "logprob") else first.get("logprob")
flat.append(float(lp) if lp is not None else 0.0)
return flat


async def _compute_teacher_logprobs_vllm(
client_config: vf.ClientConfig, model_name: str, sample: TrainingSample
) -> list[float]:
"""Legacy path: prime-rl's vLLM sidecar ``/inference/v1/generate``."""
import httpx
from vllm.entrypoints.serve.disagg.protocol import GenerateResponse

client = setup_openai_client(client_config)
# Two escape hatches from ``AsyncOpenAI.post``:
# 1. URL — ``/inference/v1/generate`` is mounted at server root, not
# under ``/v1``. Pass an absolute URL so the SDK's ``_prepare_url``
# skips the base-url merge.
# 2. Parse — vLLM's ``GenerateResponse`` isn't an ``openai.BaseModel``.
# Use ``cast_to=httpx.Response`` and validate the body ourselves.
base = str(client.base_url).rstrip("/").removesuffix("/v1")
http_response = await client.post(
f"{base}/inference/v1/generate",
cast_to=httpx.Response,
body={
"model": model_name,
"token_ids": list(sample.prompt_ids) + list(sample.completion_ids),
"sampling_params": {
"max_tokens": 1,
"temperature": 1.0,
"top_p": 1.0,
"prompt_logprobs": 1,
},
},
)
response = GenerateResponse.model_validate_json(http_response.content)
return _flatten_prompt_logprobs(response.prompt_logprobs)


async def _compute_teacher_logprobs_dynamo(
client_config: vf.ClientConfig, model_name: str, sample: TrainingSample
) -> list[float]:
"""Dynamo path: ``/v1/chat/completions`` with an nvext envelope.

Wire shape:
- top-level ``prompt_logprobs: 1`` (CommonExt sampling param)
- ``nvext.token_data`` carries the pre-tokenized prompt
- ``nvext.extra_fields = ["prompt_logprobs"]`` opts into the response
field; dynamo emits ``response.nvext.prompt_logprobs`` shaped as
``[None | {token_id: {logprob, rank?, decoded_token?}}]``, which the
shared flattener consumes unchanged.

Requires the vLLM worker to populate ``prompt_logprobs`` when
``SamplingParams.prompt_logprobs`` is set; otherwise the field is None.
"""
client = setup_openai_client(client_config)
token_ids = list(sample.prompt_ids) + list(sample.completion_ids)
body = {
"model": model_name,
# Placeholder stub the OpenAI/Dynamo chat schema requires (rejected if
# empty); the authoritative prompt is carried in nvext.token_data and
# Dynamo ignores these messages. Matches renderers' dynamo_chat client.
"messages": [{"role": "user", "content": ""}],
"max_completion_tokens": 1,
"temperature": 1.0,
"top_p": 1.0,
"prompt_logprobs": 1,
"nvext": {
"token_data": token_ids,
"extra_fields": ["prompt_logprobs"],
},
}
# Dynamo's response is a standard chat-completion JSON with an extra
# ``nvext`` field. Use ``cast_to=httpx.Response`` so we can read the raw
# body and pluck ``nvext.prompt_logprobs`` — the OpenAI SDK response
# models drop unknown fields.
import httpx as _httpx

http_response = await client.post(
"/chat/completions",
cast_to=_httpx.Response,
body=body,
)
payload = http_response.json()
nvext_resp = (payload or {}).get("nvext") or {}
raw = nvext_resp.get("prompt_logprobs")
return _flatten_prompt_logprobs(raw)


async def compute_teacher_logprobs(
clients: list[vf.ClientConfig],
model_name: str,
samples: list[TrainingSample],
) -> list[list[float]]:
"""Compute teacher model logprobs for a batch of training samples via prefill."""
import httpx
from vllm.entrypoints.serve.disagg.protocol import GenerateResponse
"""Compute teacher model logprobs for a batch of training samples via prefill.

Dispatches to the vLLM-sidecar or dynamo-nvext path based on the per-client
``renderer_transport``:

- ``vllm_generate`` (default): POST ``/inference/v1/generate``
- ``dynamo_chat``: POST ``/v1/chat/completions`` with nvext

Both flatten to ``list[float]`` via the shared helper.
"""

async def _compute_single(client_config: vf.ClientConfig, sample: TrainingSample) -> list[float]:
client = setup_openai_client(client_config)

# Two escape hatches from ``AsyncOpenAI.post``:
# 1. URL — ``/inference/v1/generate`` is mounted at server root, not
# under ``/v1``. Pass an absolute URL so the SDK's
# ``_prepare_url`` skips the base-url merge (it short-circuits
# when the path passes ``httpx.URL.is_relative_url`` as False).
# 2. Parse — vLLM's ``GenerateResponse`` is a plain
# ``pydantic.BaseModel`` and the SDK's parse layer rejects any
# ``cast_to`` that doesn't subclass ``openai.BaseModel``. Use
# ``cast_to=httpx.Response`` so the SDK still builds the request
# (preserving ``auth_headers``, retries, timeouts, idempotency
# keys) and just hands us the raw response to validate ourselves.
base = str(client.base_url).rstrip("/").removesuffix("/v1")
http_response = await client.post(
f"{base}/inference/v1/generate",
cast_to=httpx.Response,
body={
"model": model_name,
"token_ids": list(sample.prompt_ids) + list(sample.completion_ids),
"sampling_params": {
"max_tokens": 1,
"temperature": 1.0,
"top_p": 1.0,
"prompt_logprobs": 1,
},
},
)
response = GenerateResponse.model_validate_json(http_response.content)
# ``prompt_logprobs[i]`` is a ``{token_id: Logprob}`` dict for tokens
# the engine could score, or ``None`` for the leading token which has
# no preceding context. Flatten to ``list[float]`` with 0.0 in the
# unscored slot.
flat: list[float] = []
for entry in response.prompt_logprobs or []:
if not entry:
flat.append(0.0)
continue
first = next(iter(entry.values()))
lp = first.logprob if hasattr(first, "logprob") else first.get("logprob")
flat.append(float(lp) if lp is not None else 0.0)
return flat
if getattr(client_config, "renderer_transport", "vllm_generate") == "dynamo_chat":
return await _compute_teacher_logprobs_dynamo(client_config, model_name, sample)
return await _compute_teacher_logprobs_vllm(client_config, model_name, sample)

return await asyncio.gather(*[_compute_single(client, sample) for client, sample in zip(cycle(clients), samples)])

Expand Down
Loading
Loading