diff --git a/packages/prime-rl-configs/src/prime_rl/configs/shared.py b/packages/prime-rl-configs/src/prime_rl/configs/shared.py index ff311f145d..9c16c410bd 100644 --- a/packages/prime-rl-configs/src/prime_rl/configs/shared.py +++ b/packages/prime-rl-configs/src/prime_rl/configs/shared.py @@ -136,6 +136,12 @@ class ClientConfig(BaseConfig): elastic: ElasticConfig | None = None """Elastic inference pool config for DNS-based service discovery. When set, ``base_url`` is ignored and inference servers are discovered dynamically via DNS.""" + rl_base_url: list[str] | None = None + """Dynamo RL worker discovery base URLs. Used only for backend='dynamo' when admin_base_url is unset. These URLs point at the Dynamo RL discovery listener (DYN_RL_PORT, default 8001), which serves GET /v1/rl/workers. If unset, prime-rl derives the discovery URL from base_url by replacing the port with DYN_RL_PORT or 8001.""" + + backend: Literal["vllm", "dynamo"] = "vllm" + """Inference backend selector. Picks the AdminAPI implementation used for pause/resume/update_weights/load_lora_adapter/list_models. Default 'vllm' matches prime-rl's bundled vLLM frontend. 'dynamo' targets NVIDIA Dynamo's worker /engine/* admin routes on admin_base_url and routes /v1/models to the OpenAI-compat base_url.""" + router_url: str | None = None """vllm-router URL for load-aware inference routing. With elastic mode, inference requests go through the router while admin ops still hit discovered pods directly.""" diff --git a/src/prime_rl/orchestrator/orchestrator.py b/src/prime_rl/orchestrator/orchestrator.py index cb8bbcf852..f7e3f098e2 100644 --- a/src/prime_rl/orchestrator/orchestrator.py +++ b/src/prime_rl/orchestrator/orchestrator.py @@ -318,6 +318,11 @@ async def setup(self) -> None: ) await self.inference_metrics.start() + # Propagate the configured broadcast type to the admin API. The Dynamo + # backend gates its NCCL init + weight-update path on ``_weight_broadcast_type``. + if hasattr(self.student_inference._admin_api, "_weight_broadcast_type"): + self.student_inference._admin_api._weight_broadcast_type = config.weight_broadcast.type + get_logger().info(f"Initializing weight broadcast ({config.weight_broadcast})") if config.weight_broadcast.type == "nccl": await init_nccl_broadcast( @@ -327,6 +332,7 @@ async def setup(self) -> None: config.weight_broadcast.timeout, inference_world_size=config.weight_broadcast.inference_world_size, quantize_in_weight_transfer=config.weight_broadcast.quantize_in_weight_transfer, + admin=self.student_inference._admin_api, ) get_logger().info(f"Initializing training batch sender ({config.rollout_transport})") diff --git a/src/prime_rl/orchestrator/utils.py b/src/prime_rl/orchestrator/utils.py index 5675ba3f34..082bc54c5a 100644 --- a/src/prime_rl/orchestrator/utils.py +++ b/src/prime_rl/orchestrator/utils.py @@ -4,6 +4,7 @@ from concurrent.futures import ThreadPoolExecutor from itertools import cycle from pathlib import Path +from typing import Any import orjson import verifiers as vf @@ -96,58 +97,129 @@ def set_default_executor(max_workers: int = 64) -> None: asyncio.get_event_loop().set_default_executor(ThreadPoolExecutor(max_workers=max_workers)) +def _flatten_prompt_logprobs(raw: list[Any] | None) -> list[float]: + """Shared flattener used by both transports. + + ``prompt_logprobs[i]`` is a ``{token_id: Logprob}`` dict for tokens the + engine could score, or ``None`` for the leading token which has no + preceding context. Flatten to ``list[float]`` with 0.0 in the unscored + slot. Accepts both vLLM's typed ``Logprob`` objects and dynamo's + ``PromptLogprobEntry`` dict shape (`{logprob, rank?, decoded_token?}`). + """ + flat: list[float] = [] + for entry in raw or []: + if not entry: + flat.append(0.0) + continue + first = next(iter(entry.values())) + lp = first.logprob if hasattr(first, "logprob") else first.get("logprob") + flat.append(float(lp) if lp is not None else 0.0) + return flat + + +async def _compute_teacher_logprobs_vllm( + client_config: vf.ClientConfig, model_name: str, sample: TrainingSample +) -> list[float]: + """Legacy path: prime-rl's vLLM sidecar ``/inference/v1/generate``.""" + import httpx + from vllm.entrypoints.serve.disagg.protocol import GenerateResponse + + client = setup_openai_client(client_config) + # Two escape hatches from ``AsyncOpenAI.post``: + # 1. URL — ``/inference/v1/generate`` is mounted at server root, not + # under ``/v1``. Pass an absolute URL so the SDK's ``_prepare_url`` + # skips the base-url merge. + # 2. Parse — vLLM's ``GenerateResponse`` isn't an ``openai.BaseModel``. + # Use ``cast_to=httpx.Response`` and validate the body ourselves. + base = str(client.base_url).rstrip("/").removesuffix("/v1") + http_response = await client.post( + f"{base}/inference/v1/generate", + cast_to=httpx.Response, + body={ + "model": model_name, + "token_ids": list(sample.prompt_ids) + list(sample.completion_ids), + "sampling_params": { + "max_tokens": 1, + "temperature": 1.0, + "top_p": 1.0, + "prompt_logprobs": 1, + }, + }, + ) + response = GenerateResponse.model_validate_json(http_response.content) + return _flatten_prompt_logprobs(response.prompt_logprobs) + + +async def _compute_teacher_logprobs_dynamo( + client_config: vf.ClientConfig, model_name: str, sample: TrainingSample +) -> list[float]: + """Dynamo path: ``/v1/chat/completions`` with an nvext envelope. + + Wire shape: + - top-level ``prompt_logprobs: 1`` (CommonExt sampling param) + - ``nvext.token_data`` carries the pre-tokenized prompt + - ``nvext.extra_fields = ["prompt_logprobs"]`` opts into the response + field; dynamo emits ``response.nvext.prompt_logprobs`` shaped as + ``[None | {token_id: {logprob, rank?, decoded_token?}}]``, which the + shared flattener consumes unchanged. + + Requires the vLLM worker to populate ``prompt_logprobs`` when + ``SamplingParams.prompt_logprobs`` is set; otherwise the field is None. + """ + client = setup_openai_client(client_config) + token_ids = list(sample.prompt_ids) + list(sample.completion_ids) + body = { + "model": model_name, + # Placeholder stub the OpenAI/Dynamo chat schema requires (rejected if + # empty); the authoritative prompt is carried in nvext.token_data and + # Dynamo ignores these messages. Matches renderers' dynamo client. + "messages": [{"role": "user", "content": ""}], + "max_completion_tokens": 1, + "temperature": 1.0, + "top_p": 1.0, + "prompt_logprobs": 1, + "nvext": { + "token_data": token_ids, + "extra_fields": ["prompt_logprobs"], + }, + } + # Dynamo's response is a standard chat-completion JSON with an extra + # ``nvext`` field. Use ``cast_to=httpx.Response`` so we can read the raw + # body and pluck ``nvext.prompt_logprobs`` — the OpenAI SDK response + # models drop unknown fields. + import httpx as _httpx + + http_response = await client.post( + "/chat/completions", + cast_to=_httpx.Response, + body=body, + ) + payload = http_response.json() + nvext_resp = (payload or {}).get("nvext") or {} + raw = nvext_resp.get("prompt_logprobs") + return _flatten_prompt_logprobs(raw) + + async def compute_teacher_logprobs( clients: list[vf.ClientConfig], model_name: str, samples: list[TrainingSample], ) -> list[list[float]]: - """Compute teacher model logprobs for a batch of training samples via prefill.""" - import httpx - from vllm.entrypoints.serve.disagg.protocol import GenerateResponse + """Compute teacher model logprobs for a batch of training samples via prefill. + + Dispatches to the vLLM-sidecar or dynamo-nvext path based on the per-client + ``renderer_transport``: + + - ``vllm`` (default): POST ``/inference/v1/generate`` + - ``dynamo``: POST ``/v1/chat/completions`` with nvext + + Both flatten to ``list[float]`` via the shared helper. + """ async def _compute_single(client_config: vf.ClientConfig, sample: TrainingSample) -> list[float]: - client = setup_openai_client(client_config) - - # Two escape hatches from ``AsyncOpenAI.post``: - # 1. URL — ``/inference/v1/generate`` is mounted at server root, not - # under ``/v1``. Pass an absolute URL so the SDK's - # ``_prepare_url`` skips the base-url merge (it short-circuits - # when the path passes ``httpx.URL.is_relative_url`` as False). - # 2. Parse — vLLM's ``GenerateResponse`` is a plain - # ``pydantic.BaseModel`` and the SDK's parse layer rejects any - # ``cast_to`` that doesn't subclass ``openai.BaseModel``. Use - # ``cast_to=httpx.Response`` so the SDK still builds the request - # (preserving ``auth_headers``, retries, timeouts, idempotency - # keys) and just hands us the raw response to validate ourselves. - base = str(client.base_url).rstrip("/").removesuffix("/v1") - http_response = await client.post( - f"{base}/inference/v1/generate", - cast_to=httpx.Response, - body={ - "model": model_name, - "token_ids": list(sample.prompt_ids) + list(sample.completion_ids), - "sampling_params": { - "max_tokens": 1, - "temperature": 1.0, - "top_p": 1.0, - "prompt_logprobs": 1, - }, - }, - ) - response = GenerateResponse.model_validate_json(http_response.content) - # ``prompt_logprobs[i]`` is a ``{token_id: Logprob}`` dict for tokens - # the engine could score, or ``None`` for the leading token which has - # no preceding context. Flatten to ``list[float]`` with 0.0 in the - # unscored slot. - flat: list[float] = [] - for entry in response.prompt_logprobs or []: - if not entry: - flat.append(0.0) - continue - first = next(iter(entry.values())) - lp = first.logprob if hasattr(first, "logprob") else first.get("logprob") - flat.append(float(lp) if lp is not None else 0.0) - return flat + if getattr(client_config, "renderer_transport", "vllm") == "dynamo": + return await _compute_teacher_logprobs_dynamo(client_config, model_name, sample) + return await _compute_teacher_logprobs_vllm(client_config, model_name, sample) return await asyncio.gather(*[_compute_single(client, sample) for client, sample in zip(cycle(clients), samples)]) diff --git a/src/prime_rl/utils/client.py b/src/prime_rl/utils/client.py index 533f6e2711..98e1c9d10d 100644 --- a/src/prime_rl/utils/client.py +++ b/src/prime_rl/utils/client.py @@ -6,6 +6,7 @@ from itertools import cycle from pathlib import Path from typing import Protocol, runtime_checkable +from urllib.parse import urlsplit, urlunsplit import httpx import verifiers as vf @@ -28,6 +29,296 @@ def client_identity(client: vf.ClientConfig) -> ClientIdentity: return (client.api_base_url, client.extra_headers.get("X-data-parallel-rank")) +def _is_retryable_admin_error(exception: BaseException) -> bool: + """Check if an exception should trigger a retry for an admin op (pause/resume/update_weights).""" + if isinstance(exception, httpx.HTTPStatusError): + # Retry on transient server errors (5xx, e.g. engine briefly unresponsive); + # client errors (4xx) won't fix themselves on retry. + return exception.response.status_code >= 500 + # Retry on transport-level failures (timeouts, connection resets, etc.) so the + # per-attempt read timeout below turns a stuck server into a bounded retry loop + # instead of hanging forever on the global timeout=None admin client. + if isinstance(exception, (httpx.TimeoutException, httpx.TransportError)): + return True + return False + + +# Per-attempt read timeout for admin ops, overridable per call. The admin +# AsyncClient uses `timeout=None`, so without this a stuck server would hang the +# weight update forever: the read timeout converts a hang into a TimeoutException +# that tenacity retries. Sized for `/pause`, which drains in-flight requests +# (mode="keep") and so can legitimately take a while. +ADMIN_TIMEOUT_S = 300.0 +# `/update_weights` runs a collective NCCL receive across all DP workers, which +# can take longer than the other admin ops. +UPDATE_WEIGHTS_TIMEOUT_S = 720.0 + + +async def _admin_post(client: AsyncClient, path: str, *, timeout_s: float = ADMIN_TIMEOUT_S, **kwargs) -> None: + """POST an admin op with a bounded per-attempt timeout, retrying transient errors. + + The total wall-clock budget across all retries is twice the per-attempt timeout. + """ + async for attempt in AsyncRetrying( + retry=retry_if_exception(_is_retryable_admin_error), + stop=stop_after_delay(2 * timeout_s) | stop_after_attempt(10), + wait=wait_exponential(multiplier=1, min=1, max=10), + reraise=True, + ): + with attempt: + response = await client.post( + path, + timeout=httpx.Timeout(connect=10.0, read=timeout_s, write=60.0, pool=10.0), + **kwargs, + ) + response.raise_for_status() + + +class AdminAPI(Protocol): + """Admin endpoints for an inference backend. + + Per-method: construct one HTTP call. Per-server parallelism, retry, and + raise-for-status policy live in the caller. + """ + + async def health(self, client: AsyncClient) -> None: ... + async def list_models(self, client: AsyncClient) -> list[dict]: ... + async def pause(self, client: AsyncClient) -> None: ... + async def resume(self, client: AsyncClient) -> None: ... + async def update_weights(self, client: AsyncClient, weight_dir: str | None) -> None: ... + async def load_lora_adapter( + self, + client: AsyncClient, + lora_name: str, + lora_path: str, + *, + timeout: httpx.Timeout, + ) -> None: ... + async def init_broadcaster( + self, + client: AsyncClient, + *, + host: str, + port: int, + rank_offset: int, + inference_world_size: int, + timeout: int, + quantize_in_weight_transfer: bool, + ) -> None: ... + + +class VLLMAdminAPI: + """vLLM admin endpoints.""" + + async def health(self, client: AsyncClient) -> None: + # No raise_for_status: any HTTP response means the server is up. + # Only transport errors mean "not ready yet" (caller retries). + await client.get("/health") + + async def list_models(self, client: AsyncClient) -> list[dict]: + response = await client.get("/v1/models") + return response.json()["data"] + + async def pause(self, client: AsyncClient) -> None: + await _admin_post(client, "/pause", params={"mode": "keep", "clear_cache": "false"}) + + async def resume(self, client: AsyncClient) -> None: + await _admin_post(client, "/resume") + + async def update_weights(self, client: AsyncClient, weight_dir: str | None) -> None: + await _admin_post( + client, "/update_weights", json={"weight_dir": weight_dir}, timeout_s=UPDATE_WEIGHTS_TIMEOUT_S + ) + + async def load_lora_adapter( + self, + client: AsyncClient, + lora_name: str, + lora_path: str, + *, + timeout: httpx.Timeout, + ) -> None: + response = await client.post( + "/load_lora_adapter", + json={"lora_name": lora_name, "lora_path": lora_path}, + timeout=timeout, + ) + response.raise_for_status() + + async def init_broadcaster( + self, + client: AsyncClient, + *, + host: str, + port: int, + rank_offset: int, + inference_world_size: int, + timeout: int, + quantize_in_weight_transfer: bool, + ) -> None: + response = await client.post( + "/init_broadcaster", + json={ + "host": host, + "port": port, + "rank_offset": rank_offset, + "inference_world_size": inference_world_size, + "timeout": timeout, + "quantize_in_weight_transfer": quantize_in_weight_transfer, + }, + ) + response.raise_for_status() + + +class DynamoAdminAPI(VLLMAdminAPI): + """NVIDIA Dynamo worker admin endpoints via ``POST /engine/``. + + Each Dynamo worker exposes engine routes on its system status server + (``DYN_SYSTEM_PORT``, default 8081). Multi-worker deployments are handled by + iterating over ``admin_clients``. + + Args: + engine_rpc: The ``collective_rpc`` target forwarded by + ``update_weights_from_disk``. Use ``"reload_weights"`` for plain + vLLM / dynamo.vllm without a worker extension (default). Use + ``"update_weights_from_path"`` only when + FileSystemWeightUpdateWorker / NCCLWeightUpdateWorker is loaded via + ``--worker-extension-cls``. + """ + + def __init__(self, engine_rpc: str = "reload_weights", weight_broadcast_type: str = "filesystem") -> None: + self._engine_rpc = engine_rpc + # Determines which engine method is called per step: "update_weights_from_distributed" + # for NCCL (trainer broadcasts; worker just needs to receive) vs + # "update_weights_from_disk" for filesystem. Set externally by the orchestrator + # once weight_broadcast config is resolved. Defaults to filesystem (run #35 behaviour). + self._weight_broadcast_type = weight_broadcast_type + + async def health(self, client: AsyncClient) -> None: + await client.get("/health") + + async def _post_engine( + self, + client: AsyncClient, + method: str, + body: dict | None = None, + *, + timeout_s: float = ADMIN_TIMEOUT_S, + ) -> dict: + # Mirror _admin_post: bounded per-attempt read timeout + retry on + # transient 5xx/transport errors. The admin AsyncClient uses + # timeout=None, so without this a stuck worker hangs the op forever. + async for attempt in AsyncRetrying( + retry=retry_if_exception(_is_retryable_admin_error), + stop=stop_after_delay(2 * timeout_s) | stop_after_attempt(10), + wait=wait_exponential(multiplier=1, min=1, max=10), + reraise=True, + ): + with attempt: + response = await client.post( + f"/engine/{method}", + json=body or {}, + timeout=httpx.Timeout(connect=10.0, read=timeout_s, write=60.0, pool=10.0), + ) + response.raise_for_status() + data = response.json() + if isinstance(data, dict) and data.get("status") == "error": + raise RuntimeError(data.get("message", f"Dynamo /engine/{method} failed")) + return data + raise AssertionError("unreachable: AsyncRetrying returns or raises") + + async def pause(self, client: AsyncClient) -> None: + await self._post_engine(client, "pause_generation", {"mode": "keep", "clear_cache": False}) + + async def resume(self, client: AsyncClient) -> None: + await self._post_engine(client, "resume_generation") + + async def update_weights(self, client: AsyncClient, weight_dir: str | None) -> None: + if weight_dir is None: + return + if self._weight_broadcast_type == "nccl": + # NCCL path: trainer has already broadcast weights via the NCCL group; + # this RPC tells the inference worker to call receive_state_dict(). + # NCCLWeightUpdateWorker exposes "update_weights_from_path", not "reload_weights". + await self._post_engine( + client, + "update_weights_from_distributed", + { + "weight_version": Path(weight_dir).name, + "weight_dir": weight_dir, + "engine_rpc": "update_weights_from_path", + }, + timeout_s=UPDATE_WEIGHTS_TIMEOUT_S, + ) + else: + # Resolve to absolute path so the inference worker (which may run in a + # different working directory) can find the checkpoint on the shared NFS. + abs_path = str(Path(weight_dir).resolve()) + await self._post_engine( + client, + "update_weights_from_disk", + { + "model_path": abs_path, + "weight_version": Path(weight_dir).name, + "engine_rpc": self._engine_rpc, + }, + timeout_s=UPDATE_WEIGHTS_TIMEOUT_S, + ) + + async def load_lora_adapter( + self, + client: AsyncClient, + lora_name: str, + lora_path: str, + *, + timeout: httpx.Timeout, + ) -> None: + await self._post_engine( + client, + "load_lora", + { + "lora_name": lora_name, + "source": {"uri": Path(lora_path).absolute().as_uri()}, + }, + timeout_s=timeout.read or ADMIN_TIMEOUT_S, + ) + + async def init_broadcaster( + self, + client: AsyncClient, + *, + host: str, + port: int, + rank_offset: int, + inference_world_size: int, + timeout: int, + quantize_in_weight_transfer: bool, + ) -> None: + await self._post_engine( + client, + "init_weights_update_group", + { + "host": host, + "port": port, + "rank_offset": rank_offset, + "inference_world_size": inference_world_size, + "timeout": timeout, + "quantize_in_weight_transfer": quantize_in_weight_transfer, + "engine_rpc": "init_broadcaster", + }, + ) + + +def setup_admin_api(client_config: ClientConfig) -> AdminAPI: + """Pick the AdminAPI implementation that matches ``client_config.backend``.""" + if client_config.backend == "dynamo": + return DynamoAdminAPI() + return VLLMAdminAPI() + + +_DEFAULT_ADMIN: AdminAPI = VLLMAdminAPI() + + @runtime_checkable class InferencePool(Protocol): """Protocol for inference pools (static or elastic).""" @@ -93,6 +384,7 @@ def __init__( renderer_config: RendererConfig | None = None, pool_size: int | None = None, ): + self._client_config = client_config renderer_model_name = model_name if train_client_type == "renderer" else None self._train_clients = setup_clients( client_config, @@ -102,7 +394,17 @@ def __init__( pool_size=pool_size, ) self._eval_clients = setup_clients(client_config, client_type=eval_client_type) - self._admin_clients = setup_admin_clients(client_config) + self._admin_clients = ( + [] + if client_config.backend == "dynamo" and not client_config.admin_base_url + else setup_admin_clients(client_config) + ) + self._model_clients = ( + setup_admin_clients(client_config, use_admin_base_url=False) + if client_config.backend == "dynamo" or client_config.admin_base_url + else self._admin_clients + ) + self._admin_api = setup_admin_api(client_config) self._skip_model_check = client_config.skip_model_check self._wait_for_ready_timeout = client_config.wait_for_ready_timeout self._eval_cycle = cycle(self._eval_clients) @@ -131,14 +433,43 @@ async def select_train_client(self, load: Mapping[ClientIdentity, int]) -> vf.Cl await asyncio.sleep(0.5) return min(self.train_clients, key=lambda c: load[client_identity(c)]) + async def _ensure_admin_clients(self, timeout: int) -> None: + if self._admin_clients: + return + if self._client_config.backend != "dynamo" or self._client_config.admin_base_url: + self._admin_clients = setup_admin_clients(self._client_config) + return + + logger = get_logger() + wait_time = 0 + while wait_time < timeout: + try: + self._admin_clients = await asyncio.to_thread(setup_admin_clients, self._client_config) + return + except Exception as e: + if wait_time % 10 == 0 and wait_time > 0: + logger.warning( + f"Dynamo worker admin URLs were not discovered after {wait_time} seconds (Error: {e})" + ) + await asyncio.sleep(1) + wait_time += 1 + raise TimeoutError(f"Dynamo worker admin URLs were not discovered after {wait_time} seconds") + async def wait_for_ready(self, model_name: str, timeout: int | None = None) -> None: + timeout = timeout if timeout is not None else self._wait_for_ready_timeout + await self._ensure_admin_clients(timeout) await check_health( - self._admin_clients, timeout=timeout if timeout is not None else self._wait_for_ready_timeout + self._admin_clients, + timeout=timeout, + admin=self._admin_api, + ) + await maybe_check_has_model( + self._model_clients, model_name, skip_model_check=self._skip_model_check, admin=self._admin_api ) - await maybe_check_has_model(self._admin_clients, model_name, skip_model_check=self._skip_model_check) async def update_weights(self, weight_dir: Path | None, lora_name: str | None = None, step: int = 0) -> None: - await update_weights(self._admin_clients, weight_dir, lora_name=lora_name, step=step) + await self._ensure_admin_clients(self._wait_for_ready_timeout) + await update_weights(self._admin_clients, weight_dir, lora_name=lora_name, step=step, admin=self._admin_api) def get_metrics(self) -> dict[str, float]: return {} @@ -185,6 +516,13 @@ def setup_clients( renderer_model_name: str | None = None, pool_size: int | None = None, ) -> list[vf.ClientConfig]: + # Pick the verifiers wire-shape selector based on client_config.backend. + # When backend == "dynamo", both RendererClient and + # OpenAIChatCompletionsTokenClient route through Dynamo's nvext path: + # - request: nvext.token_data carries pre-tokenized prompt + # - response: nvext.engine_data carries completion_token_ids + logprobs + # Default backend keeps the legacy vLLM TITO surface. + renderer_transport = "dynamo" if client_config.backend == "dynamo" else "vllm" clients = [] client_idx = 0 # Only forward the renderer config when the client actually uses a @@ -208,6 +546,9 @@ def setup_clients( vf.ClientConfig( client_idx=client_idx, client_type=client_type, + # Dynamo backend routes both renderer and token clients through + # the nvext path; default backend keeps the legacy vLLM TITO surface. + renderer_transport=renderer_transport, api_base_url=base_url, api_key_var=client_config.api_key_var, timeout=client_config.timeout, @@ -224,14 +565,21 @@ def setup_clients( return clients -def setup_admin_clients(client_config: ClientConfig) -> list[AsyncClient]: +def setup_admin_clients(client_config: ClientConfig, *, use_admin_base_url: bool = True) -> list[AsyncClient]: """Create dedicated admin clients for weight update operations. Uses a separate connection pool to avoid queueing behind streaming requests. - When admin_base_url is set, uses those URLs instead of base_url, allowing - weight updates to bypass routers in disaggregated P/D deployments. + When admin_base_url is set and use_admin_base_url is true, uses those URLs + instead of base_url, allowing weight updates to bypass routers in + disaggregated P/D deployments. For Dynamo, if admin_base_url is unset, + discover worker-advertised system URLs from GET /v1/rl/workers. """ - urls = client_config.admin_base_url if client_config.admin_base_url else client_config.base_url + if use_admin_base_url and client_config.admin_base_url: + urls = client_config.admin_base_url + elif use_admin_base_url and client_config.backend == "dynamo": + urls = discover_dynamo_admin_base_urls(client_config) + else: + urls = client_config.base_url def _setup_admin_client(base_url: str) -> httpx.AsyncClient: env_headers = { @@ -255,23 +603,90 @@ def _setup_admin_client(base_url: str) -> httpx.AsyncClient: return [_setup_admin_client(base_url) for base_url in urls] +def discover_dynamo_admin_base_urls(client_config: ClientConfig) -> list[str]: + urls: list[str] = [] + # Match the header set used by _setup_admin_client: static headers plus + # env-resolved headers, so discovery passes the same auth/routing headers + # that the admin clients themselves will carry. + env_headers = { + k: v for k, v in ((k, os.getenv(v)) for k, v in client_config.headers_from_env.items()) if v is not None + } + headers = {**client_config.headers, **env_headers} + api_key = os.getenv(client_config.api_key_var, "EMPTY") + if api_key and api_key != "EMPTY": + headers["Authorization"] = f"Bearer {api_key}" + + for base_url in _dynamo_rl_discovery_base_urls(client_config): + discovery_base = base_url.rstrip("/").removesuffix("/v1") + with httpx.Client( + base_url=discovery_base, + headers=headers, + timeout=httpx.Timeout(connect=client_config.connect_timeout, read=30.0, write=30.0, pool=10.0), + ) as client: + response = client.get("/v1/rl/workers") + response.raise_for_status() + for worker in response.json().get("workers", []): + system_url = worker.get("system_url") + if system_url: + urls.append(system_url) + + deduped = list(dict.fromkeys(urls)) + if not deduped: + raise ValueError( + "Dynamo backend did not discover any worker system URLs from /v1/rl/workers. " + "Set client.admin_base_url explicitly, set client.rl_base_url to the Dynamo " + "RL discovery listener, and make sure Dynamo workers run with DYN_ENABLE_RL " + "and a system status server enabled." + ) + return deduped + + +def _dynamo_rl_discovery_base_urls(client_config: ClientConfig) -> list[str]: + configured = getattr(client_config, "rl_base_url", None) + if configured: + return configured + + rl_port = int(os.getenv("DYN_RL_PORT", "8001")) + return [_replace_url_port(base_url, rl_port) for base_url in client_config.base_url] + + +def _replace_url_port(base_url: str, port: int) -> str: + parsed = urlsplit(base_url.rstrip("/").removesuffix("/v1")) + scheme = parsed.scheme or "http" + host = parsed.hostname or parsed.netloc + if not host: + raise ValueError(f"Cannot derive Dynamo RL discovery URL from base_url={base_url!r}") + if ":" in host and not host.startswith("["): + host = f"[{host}]" + netloc = f"{host}:{port}" + return urlunsplit((scheme, netloc, "", "", "")) + + async def maybe_check_has_model( - admin_clients: list[AsyncClient], model_name: str, skip_model_check: bool = False + admin_clients: list[AsyncClient], + model_name: str, + skip_model_check: bool = False, + *, + admin: AdminAPI = _DEFAULT_ADMIN, ) -> None: if skip_model_check: return logger = get_logger() logger.debug(f"Checking if model {model_name} is in the inference pool") - results = await asyncio.gather(*[admin_client.get("/v1/models") for admin_client in admin_clients]) - for admin_client, result in zip(admin_clients, results): - models = result.json()["data"] + results = await asyncio.gather(*[admin.list_models(admin_client) for admin_client in admin_clients]) + for admin_client, models in zip(admin_clients, results): if not any(model["id"] == model_name for model in models): raise ValueError(f"Model {model_name} was not found in the inference pool on {admin_client.base_url}") logger.debug(f"Model {model_name} was found in the inference pool") async def check_health( - admin_clients: list[AsyncClient], interval: int = 1, log_interval: int = 10, timeout: int = 1800 + admin_clients: list[AsyncClient], + interval: int = 1, + log_interval: int = 10, + timeout: int = 1800, + *, + admin: AdminAPI = _DEFAULT_ADMIN, ) -> None: logger = get_logger() @@ -280,7 +695,7 @@ async def _check_health(admin_client: AsyncClient) -> None: logger.debug("Starting pinging /health to check health") while wait_time < timeout: try: - await admin_client.get("/health") + await admin.health(admin_client) logger.debug(f"Inference pool is ready after {wait_time} seconds") return except NotFoundError: @@ -303,118 +718,44 @@ async def _check_health(admin_client: AsyncClient) -> None: NCCL_READY_MARKER = "NCCL_READY" -def _is_retryable_admin_error(exception: BaseException) -> bool: - """Check if an exception should trigger a retry for an admin op (pause/resume/update_weights).""" - if isinstance(exception, httpx.HTTPStatusError): - # Retry on transient server errors (5xx, e.g. engine briefly unresponsive); - # client errors (4xx) won't fix themselves on retry. - return exception.response.status_code >= 500 - # Retry on transport-level failures (timeouts, connection resets, etc.) so the - # per-attempt read timeout below turns a stuck server into a bounded retry loop - # instead of hanging forever on the global timeout=None admin client. - if isinstance(exception, (httpx.TimeoutException, httpx.TransportError)): - return True - return False - - -# Per-attempt read timeout for admin ops, overridable per call. The admin -# AsyncClient uses `timeout=None`, so without this a stuck server would hang the -# weight update forever: the read timeout converts a hang into a TimeoutException -# that tenacity retries. Sized for `/pause`, which drains in-flight requests -# (mode="keep") and so can legitimately take a while. -ADMIN_TIMEOUT_S = 300.0 -# `/update_weights` runs a collective NCCL receive across all DP workers, which -# can take longer than the other admin ops. -UPDATE_WEIGHTS_TIMEOUT_S = 720.0 - - -async def _admin_post(client: AsyncClient, path: str, *, timeout_s: float = ADMIN_TIMEOUT_S, **kwargs) -> None: - """POST an admin op with a bounded per-attempt timeout, retrying transient errors. - - The total wall-clock budget across all retries is twice the per-attempt timeout. - """ - async for attempt in AsyncRetrying( - retry=retry_if_exception(_is_retryable_admin_error), - stop=stop_after_delay(2 * timeout_s) | stop_after_attempt(10), - wait=wait_exponential(multiplier=1, min=1, max=10), - reraise=True, - ): - with attempt: - response = await client.post( - path, - timeout=httpx.Timeout(connect=10.0, read=timeout_s, write=60.0, pool=10.0), - **kwargs, - ) - response.raise_for_status() - - -async def _pause_engines(admin_clients: list[AsyncClient], *, step: int) -> None: - """Pause all inference engines, waiting for in-flight requests to drain.""" - logger = get_logger() - logger.info(f"Updating policy in-flight to v{step}") - await asyncio.gather( - *[_admin_post(client, "/pause", params={"mode": "keep", "clear_cache": "false"}) for client in admin_clients] - ) - logger.debug("All inference engines paused") - - -async def _resume_engines(admin_clients: list[AsyncClient]) -> None: - """Resume all inference engines after weight update. - - Resuming is idempotent (it just clears the paused flag), so retrying transient - failures is safe; a dropped /resume would leave engines paused indefinitely. - """ - logger = get_logger() - await asyncio.gather(*[_admin_post(client, "/resume") for client in admin_clients]) - logger.debug("All inference engines resumed") - - async def update_weights( admin_clients: list[AsyncClient], weight_dir: Path | None, lora_name: str | None = None, step: int = 0, + *, + admin: AdminAPI = _DEFAULT_ADMIN, ) -> None: """Update weights on static inference servers. - Pauses all engines first to drain in-flight requests, then performs the - weight update, then resumes. This ensures all DP workers are idle and can - participate in the collective weight transfer. - - Note: The server-side /update_weights endpoint automatically resets the prefix cache - to invalidate any cached KV states computed with the old weights. + Pauses all engines to drain in-flight requests, performs the weight update, + then resumes. Ensures all DP workers are idle and can participate in the + collective weight transfer. The server-side ``/update_weights`` endpoint + resets the prefix cache to invalidate any KV states computed with the old + weights. """ logger = get_logger() + if lora_name is not None and weight_dir is not None: + await load_lora_adapter(admin_clients, lora_name, weight_dir, admin=admin) + return + weight_dir_posix = weight_dir.as_posix() if weight_dir is not None else None - if lora_name is not None and weight_dir is not None: - await load_lora_adapter(admin_clients, lora_name, weight_dir) - else: - # Pause engines so all DP workers drain in-flight work and can join the NCCL broadcast - await _pause_engines(admin_clients, step=step) + logger.info("Pausing inference engines for weight update") + await asyncio.gather(*[admin.pause(c) for c in admin_clients]) + try: + # NCCL_READY marker is created before servers enter the receive path + if weight_dir is not None: + nccl_ready_file = weight_dir / NCCL_READY_MARKER + nccl_ready_file.parent.mkdir(parents=True, exist_ok=True) + nccl_ready_file.touch() + logger.debug(f"Created NCCL_READY marker at {nccl_ready_file}") - try: - # Create ready marker before servers enter receive path (used by NCCL broadcast) - if weight_dir is not None: - nccl_ready_file = weight_dir / NCCL_READY_MARKER - nccl_ready_file.parent.mkdir(parents=True, exist_ok=True) - nccl_ready_file.touch() - logger.debug(f"Created NCCL_READY marker at {nccl_ready_file}") - - await asyncio.gather( - *[ - _admin_post( - admin_client, - "/update_weights", - json={"weight_dir": weight_dir_posix}, - timeout_s=UPDATE_WEIGHTS_TIMEOUT_S, - ) - for admin_client in admin_clients - ] - ) - finally: - await _resume_engines(admin_clients) + await asyncio.gather(*[admin.update_weights(c, weight_dir_posix) for c in admin_clients]) + finally: + await asyncio.gather(*[admin.resume(c) for c in admin_clients]) + logger.info("Inference engines resumed") def _is_retryable_lora_error(exception: BaseException) -> bool: @@ -441,7 +782,13 @@ def _is_retryable_lora_error(exception: BaseException) -> bool: LORA_LOAD_TOTAL_TIMEOUT_S = 120.0 -async def load_lora_adapter(admin_clients: list[AsyncClient], lora_name: str, lora_path: Path) -> None: +async def load_lora_adapter( + admin_clients: list[AsyncClient], + lora_name: str, + lora_path: Path, + *, + admin: AdminAPI = _DEFAULT_ADMIN, +) -> None: """Make a HTTP post request to the vLLM server to load a LoRA adapter. Uses our wrapper endpoint that also resets the prefix cache to invalidate @@ -452,6 +799,7 @@ async def load_lora_adapter(admin_clients: list[AsyncClient], lora_name: str, lo """ logger = get_logger() lora_path_posix = lora_path.as_posix() + per_attempt_timeout = httpx.Timeout(connect=10.0, read=LORA_LOAD_READ_TIMEOUT_S, write=60.0, pool=10.0) @retry( retry=retry_if_exception(_is_retryable_lora_error), @@ -461,29 +809,11 @@ async def load_lora_adapter(admin_clients: list[AsyncClient], lora_name: str, lo ) async def _load_lora_adapter(admin_client: AsyncClient) -> None: logger.debug(f"Sending request to load LoRA adapter {lora_name} from {lora_path}") - response = await admin_client.post( - "/load_lora_adapter", - json={"lora_name": lora_name, "lora_path": lora_path_posix}, - timeout=httpx.Timeout(connect=10.0, read=LORA_LOAD_READ_TIMEOUT_S, write=60.0, pool=10.0), - ) - response.raise_for_status() + await admin.load_lora_adapter(admin_client, lora_name, lora_path_posix, timeout=per_attempt_timeout) await asyncio.gather(*[_load_lora_adapter(admin_client) for admin_client in admin_clients]) -async def unload_lora_adapter(admin_clients: list[AsyncClient], lora_name: str) -> None: - """Make a HTTP post request to the vLLM server to unload a LoRA adapter.""" - logger = get_logger() - - async def _unload_lora_adapter(admin_client: AsyncClient) -> None: - logger.debug(f"Sending request to unload LoRA adapter {lora_name}") - await admin_client.post("/v1/unload_lora_adapter", json={"lora_name": lora_name}) - # TODO: The first one can fail, but subsequent ones should succeed. - # response.raise_for_status() - - await asyncio.gather(*[_unload_lora_adapter(admin_client) for admin_client in admin_clients]) - - async def init_nccl_broadcast( admin_clients: list[AsyncClient], host: str, @@ -491,6 +821,8 @@ async def init_nccl_broadcast( timeout: int, inference_world_size: int | None = None, quantize_in_weight_transfer: bool = False, + *, + admin: AdminAPI = _DEFAULT_ADMIN, ) -> None: """Initialize NCCL broadcast on all inference servers. @@ -515,18 +847,15 @@ async def init_nccl_broadcast( async def _init_nccl_broadcast(admin_client: AsyncClient, rank_offset: int) -> None: try: - response = await admin_client.post( - "/init_broadcaster", - json={ - "host": host, - "port": port, - "rank_offset": rank_offset, - "inference_world_size": inference_world_size, - "timeout": timeout, - "quantize_in_weight_transfer": quantize_in_weight_transfer, - }, + await admin.init_broadcaster( + admin_client, + host=host, + port=port, + rank_offset=rank_offset, + inference_world_size=inference_world_size, + timeout=timeout, + quantize_in_weight_transfer=quantize_in_weight_transfer, ) - response.raise_for_status() except httpx.HTTPStatusError as e: if e.response.status_code == 404: logger.warning("The route /init_broadcaster does not exist. Skipping NCCL broadcast initialization.") diff --git a/src/prime_rl/utils/elastic.py b/src/prime_rl/utils/elastic.py index 951b3673c1..ff15a0ae9c 100644 --- a/src/prime_rl/utils/elastic.py +++ b/src/prime_rl/utils/elastic.py @@ -15,6 +15,7 @@ from dataclasses import dataclass from pathlib import Path from typing import Literal +from urllib.parse import urlsplit import httpx import verifiers as vf @@ -22,7 +23,15 @@ from renderers import RendererConfig from prime_rl.configs.shared import ClientConfig -from prime_rl.utils.client import ClientIdentity, client_identity, load_lora_adapter, setup_admin_clients, setup_clients +from prime_rl.utils.client import ( + ClientIdentity, + client_identity, + discover_dynamo_admin_base_urls, + load_lora_adapter, + setup_admin_api, + setup_admin_clients, + setup_clients, +) from prime_rl.utils.logger import get_logger # --- Shared discovery functions --- @@ -73,6 +82,26 @@ async def discover_ready_servers(hostname: str, port: int, model_name: str) -> l return sorted(with_model) +def _system_url_matches_ip(system_url: str, ip: str) -> bool: + host = urlsplit(system_url).hostname + if not host: + return False + if host == ip: + return True + try: + _, _, ips = socket.gethostbyname_ex(host) + except socket.gaierror: + return False + return ip in ips + + +def _find_system_url_for_ip(system_urls: list[str], ip: str) -> str | None: + for system_url in system_urls: + if _system_url_matches_ip(system_url, ip): + return system_url + return None + + @dataclass class AdapterState: """State of a LoRA adapter (loaded or desired).""" @@ -127,6 +156,8 @@ def __init__( self._servers: dict[str, ServerState] = {} self._admin_clients: dict[str, AsyncClient] = {} + self._model_clients: dict[str, AsyncClient] = {} + self._admin_api = setup_admin_api(client_config) self._lock = asyncio.Lock() self._desired: AdapterState = AdapterState() @@ -190,6 +221,9 @@ def _rebuild_clients(self) -> None: self._client_urls = urls self._eval_index = 0 + backend = getattr(self.client_config, "backend", "vllm") + if backend not in ("vllm", "dynamo"): + backend = "vllm" url_config = ClientConfig( timeout=self.client_config.timeout, connect_timeout=self.client_config.connect_timeout, @@ -199,6 +233,7 @@ def _rebuild_clients(self) -> None: headers_from_env=self.client_config.headers_from_env, dp_rank_count=self.client_config.dp_rank_count, extra_headers_from_state=self.client_config.extra_headers_from_state, + backend=backend, ) self._train_clients = ( setup_clients( @@ -248,28 +283,49 @@ def num_servers(self) -> int: def num_ready_servers(self) -> int: return sum(1 for s in self._servers.values() if s.status == "ready") - async def _create_admin_client(self, ip: str) -> AsyncClient: + def _server_client_config(self, ip: str) -> ClientConfig: url = self._build_url(ip) - config = ClientConfig( + return ClientConfig( timeout=self.client_config.timeout, base_url=[f"{url}/v1"], api_key_var=self.client_config.api_key_var, headers=self.client_config.headers, headers_from_env=self.client_config.headers_from_env, + # Propagate backend (and RL discovery URL) so the admin client matches + # the configured backend instead of silently defaulting to vLLM. + backend=self.client_config.backend, + rl_base_url=self.client_config.rl_base_url, ) + + async def _create_admin_client(self, ip: str) -> AsyncClient: + config = self._server_client_config(ip) + if config.backend == "dynamo" and not config.admin_base_url: + # Dynamo admin (/engine/*) lives on the worker system server, not the + # inference port. Resolve THIS pod's system URL from RL discovery + # (matched by raw host or DNS-resolved IP) and pin it, so admin ops + # don't hit the inference port or a different worker. + system_urls = await asyncio.to_thread(discover_dynamo_admin_base_urls, config) + match = await asyncio.to_thread(_find_system_url_for_ip, system_urls, ip) + if match is None: + raise ValueError( + f"Dynamo RL discovery did not return a worker system URL matching inference pod {ip}. " + f"Discovered system URLs: {system_urls}" + ) + config = config.model_copy(update={"admin_base_url": [match]}) return setup_admin_clients(config)[0] + def _create_model_client(self, ip: str) -> AsyncClient: + # Dynamo admin clients point at the worker system server. Model checks + # still need the OpenAI-compatible inference URL. + return setup_admin_clients(self._server_client_config(ip), use_admin_base_url=False)[0] + async def _get_loaded_adapter(self, ip: str) -> AdapterState | None: - if ip not in self._admin_clients: + model_client = self._model_clients.get(ip) or self._admin_clients.get(ip) + if model_client is None: return None try: - admin = self._admin_clients[ip] - response = await admin.get("/v1/models") - response.raise_for_status() - data = response.json() - - for model in data.get("data", []): + for model in await self._admin_api.list_models(model_client): parent = model.get("parent") model_id = model.get("id", "") @@ -334,7 +390,9 @@ async def _sync_server_adapter(self, ip: str) -> bool: if self._desired.name and self._desired.path: try: self.logger.debug(f"Loading adapter {self._desired.name} on {ip}") - await load_lora_adapter([self._admin_clients[ip]], self._desired.name, self._desired.path) + await load_lora_adapter( + [self._admin_clients[ip]], self._desired.name, self._desired.path, admin=self._admin_api + ) except Exception as e: server.status = "unhealthy" server.sync_failures += 1 @@ -360,21 +418,19 @@ async def _sync_server_adapter(self, ip: str) -> bool: server.sync_failures += 1 return False - async def _check_server_health(self, admin_client: AsyncClient, ip: str) -> bool: + async def _check_server_health( + self, admin_client: AsyncClient, ip: str, model_client: AsyncClient | None = None + ) -> bool: try: - response = await admin_client.get("/health") - response.raise_for_status() + await self._admin_api.health(admin_client) except Exception as e: self.logger.debug(f"Server {ip} health check failed: {e}") return False try: - response = await admin_client.get("/v1/models") - response.raise_for_status() - data = response.json() - models = [m.get("id") for m in data.get("data", [])] - - if self.base_model_name not in models: + model_client = model_client or self._model_clients.get(ip) or admin_client + models = await self._admin_api.list_models(model_client) + if self.base_model_name not in [m.get("id") for m in models]: self.logger.debug(f"Server {ip} does not have base model {self.base_model_name}") return False except Exception as e: @@ -384,18 +440,30 @@ async def _check_server_health(self, admin_client: AsyncClient, ip: str) -> bool return True async def _add_server(self, ip: str) -> bool: + admin_client: AsyncClient | None = None + model_client: AsyncClient | None = None try: admin_client = await self._create_admin_client(ip) + model_client = ( + self._create_model_client(ip) if self.client_config.backend == "dynamo" else admin_client + ) except Exception as e: self.logger.debug(f"Failed to create admin client for {ip}: {e}") + if admin_client is not None: + await admin_client.aclose() + if model_client is not None and model_client is not admin_client: + await model_client.aclose() return False - if not await self._check_server_health(admin_client, ip): + if not await self._check_server_health(admin_client, ip, model_client): await admin_client.aclose() + if model_client is not admin_client: + await model_client.aclose() return False self.logger.debug(f"Discovered new inference server: {ip}") self._admin_clients[ip] = admin_client + self._model_clients[ip] = model_client self._servers[ip] = ServerState(ip=ip, url=self._build_url(ip), status="discovering") await self._sync_server_adapter(ip) return True @@ -403,8 +471,12 @@ async def _add_server(self, ip: str) -> bool: async def _remove_server(self, ip: str) -> None: self.logger.debug(f"Inference server removed: {ip}") self._servers.pop(ip, None) - if ip in self._admin_clients: - await self._admin_clients.pop(ip).aclose() + admin_client = self._admin_clients.pop(ip, None) + model_client = self._model_clients.pop(ip, None) + if model_client is not None and model_client is not admin_client: + await model_client.aclose() + if admin_client is not None: + await admin_client.aclose() async def sync(self) -> tuple[int, int]: async with self._lock: @@ -427,7 +499,9 @@ async def sync(self) -> tuple[int, int]: for ip in list(self._servers.keys()): if ip not in self._admin_clients: continue - if not await self._check_server_health(self._admin_clients[ip], ip): + if not await self._check_server_health( + self._admin_clients[ip], ip, self._model_clients.get(ip) + ): self.logger.debug(f"Server {ip} failed health check, removing") await self._remove_server(ip) removed += 1 diff --git a/tests/unit/utils/test_client.py b/tests/unit/utils/test_client.py index 69325e6c4a..fc1336474a 100644 --- a/tests/unit/utils/test_client.py +++ b/tests/unit/utils/test_client.py @@ -1,12 +1,12 @@ import asyncio from pathlib import Path -from unittest.mock import AsyncMock, MagicMock +from unittest.mock import AsyncMock, MagicMock, patch import httpx import verifiers as vf from prime_rl.configs.shared import ClientConfig -from prime_rl.utils.client import _is_retryable_lora_error, load_lora_adapter, setup_clients +from prime_rl.utils.client import StaticInferencePool, _is_retryable_lora_error, load_lora_adapter, setup_clients def test_is_retryable_lora_error_returns_true_for_404(): @@ -117,3 +117,42 @@ def test_setup_clients_preserves_chat_client_defaults(): extra_headers_from_state={}, ) ] + + +def test_static_dynamo_admin_discovery_retries_in_wait_for_ready(): + client_config = ClientConfig( + base_url=["http://worker-a:8000/v1"], + api_key_var="PRIME_API_KEY", + backend="dynamo", + wait_for_ready_timeout=2, + ) + model_client = AsyncMock() + admin_client = AsyncMock() + admin_attempts = 0 + + def fake_setup_admin_clients(config, *, use_admin_base_url=True): + nonlocal admin_attempts + if not use_admin_base_url: + return [model_client] + admin_attempts += 1 + if admin_attempts == 1: + raise ValueError("workers not ready") + return [admin_client] + + with ( + patch("prime_rl.utils.client.setup_admin_clients", side_effect=fake_setup_admin_clients), + patch("prime_rl.utils.client.check_health", new=AsyncMock()) as mock_check_health, + patch("prime_rl.utils.client.maybe_check_has_model", new=AsyncMock()) as mock_check_has_model, + patch("prime_rl.utils.client.asyncio.sleep", new=AsyncMock()), + ): + pool = StaticInferencePool(client_config, model_name="test-model") + assert pool.admin_clients == [] + + asyncio.run(pool.wait_for_ready("test-model", timeout=2)) + + assert admin_attempts == 2 + assert pool.admin_clients == [admin_client] + mock_check_health.assert_awaited_once_with([admin_client], timeout=2, admin=pool._admin_api) + mock_check_has_model.assert_awaited_once_with( + [model_client], "test-model", skip_model_check=False, admin=pool._admin_api + ) diff --git a/tests/unit/utils/test_elastic.py b/tests/unit/utils/test_elastic.py index 21490497e7..cb2e541772 100644 --- a/tests/unit/utils/test_elastic.py +++ b/tests/unit/utils/test_elastic.py @@ -4,6 +4,7 @@ from unittest.mock import AsyncMock, MagicMock, patch import httpx +import pytest import verifiers as vf from prime_rl.utils.elastic import ( @@ -404,6 +405,127 @@ def test_get_loaded_adapter_handles_step_dash_format(): assert result.step == 99 +def test_get_loaded_adapter_uses_model_client_when_present(): + with patch("prime_rl.utils.elastic.get_logger"): + mock_config = MagicMock() + mock_config.elastic.hostname = "test.hostname" + mock_config.elastic.port = 8000 + mock_config.elastic.sync_interval = 5.0 + mock_config.router_url = None + pool = ElasticInferencePool(client_config=mock_config, model_name="base-model") + pool._desired.name = "my-lora" + + mock_admin = AsyncMock() + mock_model = AsyncMock() + mock_response = MagicMock() + mock_response.json.return_value = { + "data": [ + {"id": "my-lora", "parent": "base-model", "root": "/weights/step_7"}, + ] + } + mock_model.get.return_value = mock_response + pool._admin_clients["10.0.0.1"] = mock_admin + pool._model_clients["10.0.0.1"] = mock_model + + result = asyncio.run(pool._get_loaded_adapter("10.0.0.1")) + + assert result is not None + assert result.step == 7 + mock_admin.get.assert_not_called() + mock_model.get.assert_awaited_once_with("/v1/models") + + +def test_check_server_health_uses_model_client_for_model_list(): + with patch("prime_rl.utils.elastic.get_logger"): + mock_config = MagicMock() + mock_config.elastic.hostname = "test.hostname" + mock_config.elastic.port = 8000 + mock_config.elastic.sync_interval = 5.0 + mock_config.router_url = None + pool = ElasticInferencePool(client_config=mock_config, model_name="base-model") + + mock_admin = AsyncMock() + mock_model = AsyncMock() + mock_model_response = MagicMock() + mock_model_response.json.return_value = {"data": [{"id": "base-model"}]} + mock_model.get.return_value = mock_model_response + + result = asyncio.run(pool._check_server_health(mock_admin, "10.0.0.1", mock_model)) + + assert result is True + mock_admin.get.assert_awaited_once_with("/health") + mock_model.get.assert_awaited_once_with("/v1/models") + + +def test_create_admin_client_matches_dynamo_system_url_by_resolved_ip(): + with patch("prime_rl.utils.elastic.get_logger"): + client_config = MagicMock() + client_config.elastic.hostname = "test.hostname" + client_config.elastic.port = 8000 + client_config.elastic.sync_interval = 5.0 + client_config.router_url = None + client_config.timeout = 1200 + client_config.api_key_var = "PRIME_API_KEY" + client_config.headers = {} + client_config.headers_from_env = {} + client_config.backend = "dynamo" + client_config.rl_base_url = None + pool = ElasticInferencePool(client_config=client_config, model_name="base-model") + + captured = [] + + def fake_setup_admin_clients(config): + captured.append(config) + return ["admin-client"] + + def fake_gethostbyname_ex(hostname): + if hostname == "worker-b.local": + return hostname, [], ["10.0.0.2"] + return hostname, [], ["10.0.0.9"] + + with ( + patch( + "prime_rl.utils.elastic.discover_dynamo_admin_base_urls", + return_value=["http://worker-a.local:8081", "http://worker-b.local:8081"], + ), + patch("prime_rl.utils.elastic.socket.gethostbyname_ex", side_effect=fake_gethostbyname_ex), + patch("prime_rl.utils.elastic.setup_admin_clients", side_effect=fake_setup_admin_clients), + ): + result = asyncio.run(pool._create_admin_client("10.0.0.2")) + + assert result == "admin-client" + assert captured[0].admin_base_url == ["http://worker-b.local:8081"] + + +def test_create_admin_client_fails_without_matching_dynamo_system_url(): + with patch("prime_rl.utils.elastic.get_logger"): + client_config = MagicMock() + client_config.elastic.hostname = "test.hostname" + client_config.elastic.port = 8000 + client_config.elastic.sync_interval = 5.0 + client_config.router_url = None + client_config.timeout = 1200 + client_config.api_key_var = "PRIME_API_KEY" + client_config.headers = {} + client_config.headers_from_env = {} + client_config.backend = "dynamo" + client_config.rl_base_url = None + pool = ElasticInferencePool(client_config=client_config, model_name="base-model") + + with ( + patch( + "prime_rl.utils.elastic.discover_dynamo_admin_base_urls", + return_value=["http://worker-a.local:8081"], + ), + patch("prime_rl.utils.elastic.socket.gethostbyname_ex", return_value=("worker-a.local", [], ["10.0.0.9"])), + patch("prime_rl.utils.elastic.setup_admin_clients") as mock_setup_admin_clients, + pytest.raises(ValueError, match="matching inference pod 10.0.0.2"), + ): + asyncio.run(pool._create_admin_client("10.0.0.2")) + + mock_setup_admin_clients.assert_not_called() + + def test_elastic_clients_preserve_renderer_model_name_when_model_name_updates(): with patch("prime_rl.utils.elastic.get_logger"): client_config = MagicMock() @@ -452,3 +574,33 @@ def test_elastic_clients_preserve_renderer_model_name_when_model_name_updates(): extra_headers_from_state={}, ) ] + + +def test_elastic_clients_preserve_dynamo_backend_for_transport(): + with patch("prime_rl.utils.elastic.get_logger"): + client_config = MagicMock() + client_config.elastic.hostname = "test.hostname" + client_config.elastic.port = 8000 + client_config.elastic.sync_interval = 5.0 + client_config.router_url = None + client_config.timeout = 1200 + client_config.connect_timeout = 30.0 + client_config.api_key_var = "PRIME_API_KEY" + client_config.headers = {} + client_config.headers_from_env = {} + client_config.extra_headers_from_state = {} + client_config.dp_rank_count = 1 + client_config.backend = "dynamo" + + pool = ElasticInferencePool( + client_config=client_config, + model_name="test-model", + train_client_type="openai_chat_completions_token", + ) + pool._servers = { + "10.0.0.1": MagicMock(status="ready"), + } + + clients = pool.train_clients + + assert clients[0].renderer_transport == "dynamo"