diff --git a/bench/bench_compare.py b/bench/bench_compare.py new file mode 100644 index 0000000000..aedc208f74 --- /dev/null +++ b/bench/bench_compare.py @@ -0,0 +1,164 @@ +"""Drafter vs no-drafter A/B bench for the asymmetric cluster. + +For each length in --lengths, runs the same prompt twice: once with +``use_drafter=True``, once with ``use_drafter=False``. Reports per-run +TPS, drafter telemetry, and the speedup ratio. + +Sleeps briefly between runs so the model isn't warm-cache for one and +cold for the other; first run of each length pair is the +"throw-away" warmup, subsequent are timed (when --warmup is set). +""" + +from __future__ import annotations + +import argparse +import json +import time +import urllib.request +from typing import Final + +API_URL: Final[str] = "http://192.168.1.224:52415/v1/chat/completions" +MODEL: Final[str] = "mlx-community/gemma-4-31b-it-bf16" + +PROMPT: Final[str] = ( + "Write a detailed, comprehensive technical reference on distributed " + "speculative decoding for large language models. Cover the following " + "topics in depth, with examples, equations, and pseudocode where " + "relevant: (1) architectural foundations of speculative decoding, " + "(2) the role of drafter vs target models and how acceptance/rejection " + "is computed, (3) multi-token prediction (MTP) heads vs separate drafter " + "models, (4) tensor-parallel verification and KV cache rollback semantics, " + "(5) asymmetric placement on heterogeneous clusters, (6) wire-protocol " + "design for drafter/target IPC, (7) failure modes (drafter death, target " + "rank crashes, partitions) and recovery strategies, (8) tuning K (draft " + "depth) for different workloads, (9) integration with continuous batching " + "and paged attention, (10) practical performance results from real " + "deployments. Use markdown headings and detailed prose. Begin now." +) + + +def run_once(max_tokens: int, use_drafter: bool, timeout: int) -> dict[str, object]: + body: dict[str, object] = { + "model": MODEL, + "messages": [{"role": "user", "content": PROMPT}], + "max_tokens": max_tokens, + "temperature": 0.0, + "stream": False, + "use_drafter": use_drafter, + } + payload = json.dumps(body).encode("utf-8") + request = urllib.request.Request( + API_URL, + data=payload, + headers={"Content-Type": "application/json"}, + method="POST", + ) + started = time.monotonic() + with urllib.request.urlopen(request, timeout=timeout) as resp: # noqa: S310 - lan + raw = resp.read().decode("utf-8") + wall = time.monotonic() - started + parsed = json.loads(raw) + usage = parsed.get("usage") or {} + completion = int(usage.get("completion_tokens", 0)) + stats = parsed.get("generation_stats") or {} + return { + "max_tokens": max_tokens, + "use_drafter": use_drafter, + "wall_s": round(wall, 2), + "completion_tokens": completion, + "tps_total": round(completion / wall, 2) if wall > 0 else 0.0, + "drafter_model_id": stats.get("drafter_model_id"), + "draft_mode": stats.get("draft_mode"), + "num_draft_tokens": stats.get("num_draft_tokens"), + "accepted_draft_tokens": stats.get("accepted_draft_tokens"), + "proposed_draft_tokens": stats.get("proposed_draft_tokens"), + "spec_decode_rounds": stats.get("spec_decode_rounds"), + "acceptance_rate": ( + round(stats["accepted_draft_tokens"] / stats["proposed_draft_tokens"], 3) + if stats.get("proposed_draft_tokens") + else None + ), + "fraction_from_drafter": ( + round(stats["accepted_draft_tokens"] / completion, 3) + if stats.get("accepted_draft_tokens") and completion + else None + ), + "finish_reason": parsed.get("choices", [{}])[0].get("finish_reason"), + } + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--lengths", type=int, nargs="+", default=[256, 1024, 2048]) + parser.add_argument("--timeout", type=int, default=900) + parser.add_argument("--out", type=str, default="/tmp/bench_compare.json") + parser.add_argument( + "--sleep-between", + type=float, + default=2.0, + help="Seconds to sleep between runs to let the master settle.", + ) + args = parser.parse_args() + + results: list[dict[str, object]] = [] + summary: list[dict[str, object]] = [] + for length in args.lengths: + print(f"\n=== max_tokens={length} ===", flush=True) + # Run no-drafter first; the drafter run inherits a warm prompt cache + # via prefix-cache-hit, but max_tokens drives the bulk of the + # measured time so this is fine for steady-state TPS comparison. + for use_drafter in (False, True): + try: + r = run_once(length, use_drafter, args.timeout) + except Exception as exc: # noqa: BLE001 - report bench failure + r = { + "max_tokens": length, + "use_drafter": use_drafter, + "error": f"{type(exc).__name__}: {exc}", + } + results.append(r) + print(json.dumps(r, indent=2), flush=True) + time.sleep(args.sleep_between) + + # Speedup summary for this length pair + no_draft = next( + ( + r + for r in results + if r.get("max_tokens") == length and not r.get("use_drafter") + ), + None, + ) + draft = next( + ( + r + for r in results + if r.get("max_tokens") == length and r.get("use_drafter") + ), + None, + ) + if no_draft and draft and "error" not in no_draft and "error" not in draft: + tps_no = float(no_draft.get("tps_total", 0.0) or 0) + tps_yes = float(draft.get("tps_total", 0.0) or 0) + speedup = round(tps_yes / tps_no, 3) if tps_no > 0 else None + row = { + "max_tokens": length, + "tps_no_drafter": tps_no, + "tps_drafter": tps_yes, + "speedup_x": speedup, + "acceptance_rate": draft.get("acceptance_rate"), + "fraction_from_drafter": draft.get("fraction_from_drafter"), + } + print(f"\n>>> speedup at {length}: {json.dumps(row)}", flush=True) + summary.append(row) + + out = {"summary": summary, "raw": results} + with open(args.out, "w", encoding="utf-8") as fh: + json.dump(out, fh, indent=2) + print("\n=== overall summary ===", flush=True) + print(json.dumps(summary, indent=2), flush=True) + print(f"Saved: {args.out}", flush=True) + + +if __name__ == "__main__": + main() diff --git a/bench/bench_concurrent.py b/bench/bench_concurrent.py new file mode 100644 index 0000000000..b8209d08da --- /dev/null +++ b/bench/bench_concurrent.py @@ -0,0 +1,150 @@ +"""Concurrent overlapping spec-decode bench for the asymmetric cluster. + +Fires N parallel chat-completions requests (each with the drafter +enabled) at the master and measures: + - Per-request wall time, completion tokens, individual TPS + - Aggregate cluster TPS (sum of per-request tokens / max wall) + - Time-to-first-token spread + +The point: validate that EXO_MAX_CONCURRENT_REQUESTS > 1 actually +overlaps spec-decode sessions correctly. Single-rank-target placements +trivially share KV; multi-rank tensor-parallel placements with an +asymmetric drafter are the interesting case here. +""" + +from __future__ import annotations + +import argparse +import json +import threading +import time +import urllib.request +from typing import Final + +API_URL: Final[str] = "http://192.168.1.224:52415/v1/chat/completions" +MODEL: Final[str] = "mlx-community/gemma-4-31b-it-bf16" + +PROMPTS: Final[list[str]] = [ + "Explain the architecture of distributed speculative decoding in " + "one paragraph, then list six common failure modes with mitigations.", + "Write a 400-word technical brief on tensor-parallel KV cache " + "rollback semantics, including pseudocode for accept/reject.", + "Outline the difference between MTP heads and external drafter " + "models, and discuss when each is preferable for low-latency serving.", + "Describe how an n-gram drafter integrates with a transformer " + "target model, with attention to stateful processors and RNG.", + "Summarize the trade-offs of pipelined vs synchronous spec-decode, " + "including their interaction with continuous batching.", + "Walk through the wire protocol of a drafter-target IPC channel " + "designed for sub-millisecond round-trip on local sockets.", + "Explain how acceptance probability is computed for vanilla " + "speculative decoding and when greedy acceptance is sound.", + "Discuss the engineering trade-offs between more drafter heads " + "and more drafter depth for a fixed quality bar.", +] + + +def run_one( + idx: int, + prompt: str, + max_tokens: int, + timeout: int, + results: list[dict[str, object]], + started_at: float, +) -> None: + body = { + "model": MODEL, + "messages": [{"role": "user", "content": prompt}], + "max_tokens": max_tokens, + "temperature": 0.0, + "stream": False, + "use_drafter": True, + } + payload = json.dumps(body).encode("utf-8") + req = urllib.request.Request( + API_URL, + data=payload, + headers={"Content-Type": "application/json"}, + method="POST", + ) + relative_start = time.monotonic() - started_at + t0 = time.monotonic() + try: + with urllib.request.urlopen(req, timeout=timeout) as resp: # noqa: S310 - lan + raw = resp.read().decode("utf-8") + wall = time.monotonic() - t0 + parsed = json.loads(raw) + usage = parsed.get("usage", {}) + completion = int(usage.get("completion_tokens", 0)) + result: dict[str, object] = { + "idx": idx, + "relative_start_s": round(relative_start, 2), + "wall_s": round(wall, 2), + "completion_tokens": completion, + "tps": round(completion / wall, 2) if wall > 0 else 0.0, + "finish_reason": parsed.get("choices", [{}])[0].get("finish_reason"), + "first_64": ( + parsed.get("choices", [{}])[0] + .get("message", {}) + .get("content", "")[:64] + ), + } + except Exception as exc: # noqa: BLE001 - report bench failure + wall = time.monotonic() - t0 + result = { + "idx": idx, + "relative_start_s": round(relative_start, 2), + "wall_s": round(wall, 2), + "error": f"{type(exc).__name__}: {exc}", + } + results.append(result) + print(json.dumps(result), flush=True) + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--concurrency", type=int, default=4) + parser.add_argument("--max-tokens", type=int, default=512) + parser.add_argument("--timeout", type=int, default=600) + parser.add_argument("--out", type=str, default="/tmp/bench_concurrent.json") + args = parser.parse_args() + + n = args.concurrency + prompts = (PROMPTS * ((n + len(PROMPTS) - 1) // len(PROMPTS)))[:n] + + results: list[dict[str, object]] = [] + started_at = time.monotonic() + threads: list[threading.Thread] = [] + for i, p in enumerate(prompts): + thread = threading.Thread( + target=run_one, + args=(i, p, args.max_tokens, args.timeout, results, started_at), + ) + threads.append(thread) + thread.start() + for thread in threads: + thread.join() + total_wall = time.monotonic() - started_at + + completed = [r for r in results if "error" not in r] + total_tokens = sum(int(r.get("completion_tokens", 0)) for r in completed) + aggregate_tps = round(total_tokens / total_wall, 2) if total_wall > 0 else 0.0 + summary = { + "concurrency": n, + "max_tokens": args.max_tokens, + "total_wall_s": round(total_wall, 2), + "total_tokens_completed": total_tokens, + "aggregate_tps": aggregate_tps, + "successful": len(completed), + "failed": n - len(completed), + "individual": sorted(results, key=lambda r: int(r["idx"])), + } + print(json.dumps(summary, indent=2), flush=True) + + with open(args.out, "w", encoding="utf-8") as fh: + json.dump(summary, fh, indent=2) + print(f"Saved: {args.out}", flush=True) + + +if __name__ == "__main__": + main() diff --git a/bench/drafter_bench.py b/bench/drafter_bench.py new file mode 100644 index 0000000000..21c0b0c22d --- /dev/null +++ b/bench/drafter_bench.py @@ -0,0 +1,565 @@ +# type: ignore +#!/usr/bin/env python3 +"""Drafter A/B benchmark for exo. + +Hits a running exo cluster's OpenAI-compatible API with a fixed prompt set and +captures per-request stats (prompt tps, generation tps, TTFT, drafter +acceptance). Used to compare drafter modes (`none`, `model`, `ngram`, +`pipelined`) and deployment topologies (single-host in-process vs. +asymmetric N+1 with the drafter on a separate node over jaccl/ring). + +Usage: + uv run python bench/drafter_bench.py \ + --host 127.0.0.1 --port 52415 \ + --model mlx-community/gemma-4-26b-a4b-it-bf16 \ + --label local-none --use-drafter false \ + --runs 3 --max-tokens 256 \ + --out /tmp/drafter_bench/local-none.json + +The script does NOT manage exo lifecycle or placements -- start exo with the +desired EXO_DRAFT_MODE / drafter_eligible_nodes / backend config first, wait +for the model instance to be Ready, then point this at the API. Output is +JSON written to ``--out``. +""" + +from __future__ import annotations + +import argparse +import concurrent.futures +import contextlib +import http.client +import json +import statistics +import sys +import time +from dataclasses import asdict, dataclass, field +from pathlib import Path +from typing import Any, Literal + +PROMPTS: dict[str, dict[str, str | int]] = { + "short_repetitive": { + "system": "You are a careful, concise assistant.", + "user": ( + "Write a numbered list of 20 increasingly detailed bullet points " + "describing the daily routine of a software engineer who works " + "from home. Each bullet should reuse the phrase 'they then' to " + "stitch ideas together so the output is highly repetitive." + ), + "max_tokens": 256, + }, + "code_completion": { + "system": "You are an expert Python typist.", + "user": ( + "Implement an in-memory LRU cache class in Python with the " + "following methods: __init__(capacity: int), get(key: str) -> " + "Optional[Any], put(key: str, value: Any) -> None. Use only the " + "standard library, include type hints, raise ValueError on " + "non-positive capacity, and add a one-paragraph docstring " + "explaining the cache invariants. Do NOT include tests." + ), + "max_tokens": 384, + }, + "creative_prose": { + "system": "You are a literary writer.", + "user": ( + "Write a 350-400 word atmospheric short story set in an " + "abandoned lighthouse on the night a comet passes Earth. Use " + "vivid sensory detail and avoid clich\u00e9s. End on an " + "unresolved image." + ), + "max_tokens": 512, + }, + "factual_qa": { + "system": "You are a precise factual assistant.", + "user": ( + "Explain how Apple's unified memory architecture differs from " + "discrete-GPU systems for large language model inference. " + "Include three concrete numbers (memory bandwidth, capacity " + "ranges, typical latency) with sources cited inline. Keep it " + "under 250 words." + ), + "max_tokens": 384, + }, + "long_context_summary": { + "system": "You are a careful research assistant.", + "user": ( + "Below is a long technical document about distributed LLM " + "inference systems. Read it carefully and produce a 600-800 word " + "structured summary that covers: (1) the systems mentioned and " + "what each is for, (2) the network fabrics they use, (3) the " + "trade-offs between tensor and pipeline parallelism, (4) where " + "speculative decoding helps and where it doesn't, (5) what is " + "missing from the document. End with a numbered list of three " + "follow-up research questions.\n\n" + "DOCUMENT:\n\n" + + ( + "Distributed inference of large language models has become a " + "central topic in 2024-2026 because frontier model sizes have " + "outpaced the memory of any single accelerator. The landscape " + "now includes pipeline-parallel systems such as Petals and " + "exo, tensor-parallel systems like NVIDIA's Megatron-LM and " + "MLX-distributed, hybrid approaches such as DeepSpeed-Inference " + "and vLLM's tensor+pipeline split, and speculative-decoding " + "frameworks including Medusa, EAGLE, EAGLE-2, lookahead " + "decoding, and Google's MTP drafter family. Each system makes " + "a different trade-off between bandwidth, latency, and the " + "operational complexity of coordinating multiple devices.\n\n" + "Networking fabrics in this space split into three broad " + "categories: NVLink/NVSwitch on a single host, RDMA-style " + "fabrics such as InfiniBand, RoCE, and Apple's Thunderbolt + " + "JACCL, and ordinary TCP/IP over Ethernet. Bandwidth ranges " + "span four orders of magnitude (NVLink at 900 GB/s, " + "Thunderbolt 4 RDMA around 40 Gbps effective, 100 Gbps RoCE " + "in datacenters, and 1-10 Gbps Ethernet for hobbyist clusters). " + "Latency follows a similar spread: NVLink under 1 microsecond, " + "RDMA fabrics in the 1-10 microsecond range, TCP/IP at 50-500 " + "microseconds depending on switching topology.\n\n" + "Tensor parallelism splits each weight matrix across devices " + "and aggregates partial outputs with all-reduce. It is " + "bandwidth-heavy because every layer round-trips activations. " + "Pipeline parallelism instead splits the model by layer " + "depth and pipelines micro-batches; it has smaller per-step " + "communication but introduces pipeline bubbles when " + "micro-batch counts are low. Hybrid 2D parallelism combines " + "both, which is standard at datacenter scale but rarely seen " + "on edge clusters because the latency budget for tensor " + "parallel reductions over commodity network links is too " + "tight.\n\n" + "Speculative decoding uses a small draft model to propose " + "tokens and the large target model to verify them in " + "parallel. The expected speed-up depends on draft acceptance " + "rate, the relative cost of draft and target forwards, and " + "the round-trip latency between draft and verify steps. On " + "fast hardware with a small target model the overhead of " + "drafting can exceed the savings, while on slow targets and " + "long contexts the savings dominate. Variants such as Medusa " + "add multiple speculative heads to the target itself, EAGLE " + "uses a tiny auxiliary network conditioned on the target's " + "hidden states, lookahead decoding generates n-grams from the " + "target's own forward pass with no auxiliary network, and " + "MTP teaches the target model to predict multiple future " + "tokens directly.\n\n" + "Cluster-style speculative decoding, where the draft and " + "target live on different hosts connected by RDMA or TCP, is " + "less explored. The relevant questions are how the wire " + "protocol carries draft tokens, draft logits, and acceptance " + "decisions; how the KV caches on draft and target are kept " + "consistent under partial acceptance; and how to pipeline " + "the next draft round behind the current verify step. Apple " + "Silicon clusters such as exo are particularly interesting " + "because unified memory and Thunderbolt RDMA give them a " + "latency profile closer to a single-host NVLink setup than " + "to a typical Ethernet GPU cluster.\n\n" + ) * 3 + ), + "max_tokens": 1024, + }, +} + + +DraftModeArg = Literal["none", "model", "ngram", "pipelined", "auto"] + + +@dataclass +class RequestStats: + prompt_id: str + run_index: int + label: str + use_drafter: bool | None + num_draft_tokens: int | None + draft_mode: str | None + concurrency_slot: int = 0 + prompt_tokens: int = 0 + generation_tokens: int = 0 + prompt_tps: float = 0.0 + generation_tps: float = 0.0 + accepted_draft_tokens: int = 0 + drafter_model_id: str | None = None + response_draft_mode: str | None = None + accept_fraction: float | None = None + ttft_ms: float = 0.0 + wall_seconds: float = 0.0 + error: str | None = None + + +@dataclass +class BenchmarkResult: + label: str + host: str + port: int + model: str + use_drafter: bool | None + num_draft_tokens: int | None + draft_mode: str | None + concurrency: int + runs: int + requests: list[RequestStats] = field(default_factory=list) + + +def _now() -> float: + return time.perf_counter() + + +def _post_chat( + host: str, + port: int, + body: dict[str, Any], + *, + timeout: float, +) -> tuple[dict[str, Any], float, float]: + """Send a streaming chat completion. Returns (final_payload, ttft_ms, wall_s). + + The exo bench endpoint ``/v1/chat/completions`` already returns the + enriched ``BenchChatCompletionResponse`` with ``generation_stats`` in the + final stream chunk. We use streaming purely so we can timestamp the + first token; we still parse the final non-empty SSE event for stats. + """ + body = dict(body) + body["stream"] = True + body["stream_options"] = {"include_usage": True} + # ``/bench/chat/completions`` returns ``BenchChatCompletionResponse``, + # which carries the ``generation_stats`` block we read for tps / accept + # numbers. The standard ``/v1/chat/completions`` does not include them. + payload = json.dumps(body).encode("utf-8") + conn = http.client.HTTPConnection(host, port, timeout=timeout) + conn.request( + "POST", + "/bench/chat/completions", + body=payload, + headers={ + "Content-Type": "application/json", + "Accept": "text/event-stream", + }, + ) + start = _now() + ttft: float | None = None + generation_stats: dict[str, Any] = {} + last_chat_chunk: dict[str, Any] = {} + try: + resp = conn.getresponse() + if resp.status >= 400: + raise RuntimeError( + f"HTTP {resp.status} {resp.reason}: {resp.read(200)!r}" + ) + # ``HTTPResponse.fp.readline`` reads one chunk-encoded line at a + # time and flushes immediately, so we get true streaming SSE + # without urllib's read-ahead buffering. Each event ends in + # ``\n\n`` so the body stream contains both header lines (data: + # / : comment) and blank separator lines we ignore. + while True: + raw = resp.fp.readline() + if not raw: + break + line = raw.decode("utf-8", errors="replace").rstrip("\r\n") + if not line: + continue + # ``generate_chat_stream`` emits ``: generation_stats {json}`` + # as an SSE comment immediately before the terminal ``[DONE]``. + # Parse the comment so we can capture stats without falling + # back to the non-streaming endpoint (which would lose TTFT). + if line.startswith(": generation_stats"): + payload_str = line[len(": generation_stats"):].strip() + with contextlib.suppress(json.JSONDecodeError): + generation_stats = json.loads(payload_str) + continue + if not line.startswith("data:"): + continue + payload_str = line[len("data:"):].strip() + if payload_str == "[DONE]": + break + try: + chunk = json.loads(payload_str) + except json.JSONDecodeError: + continue + choices = chunk.get("choices") or [] + if ttft is None and any( + (c.get("delta") or {}).get("content") for c in choices + ): + ttft = _now() + last_chat_chunk = chunk + finally: + conn.close() + wall = _now() - start + ttft_ms = ((ttft - start) * 1000.0) if ttft is not None else 0.0 + out: dict[str, Any] = dict(last_chat_chunk) + if generation_stats: + out["generation_stats"] = generation_stats + return out, ttft_ms, wall + + +def _run_one( + *, + host: str, + port: int, + model: str, + prompt_id: str, + prompt: dict[str, str | int], + label: str, + use_drafter: bool | None, + num_draft_tokens: int | None, + draft_mode: str | None, + run_index: int, + concurrency_slot: int, + timeout: float, + max_tokens_override: int | None, +) -> RequestStats: + body: dict[str, Any] = { + "model": model, + "messages": [ + {"role": "system", "content": prompt["system"]}, + {"role": "user", "content": prompt["user"]}, + ], + "max_tokens": max_tokens_override or int(prompt["max_tokens"]), + "temperature": 0.0, + } + if use_drafter is not None: + body["use_drafter"] = use_drafter + if num_draft_tokens is not None: + body["num_draft_tokens"] = num_draft_tokens + if draft_mode is not None: + body["draft_mode"] = draft_mode + stats = RequestStats( + prompt_id=prompt_id, + run_index=run_index, + label=label, + use_drafter=use_drafter, + num_draft_tokens=num_draft_tokens, + draft_mode=draft_mode, + concurrency_slot=concurrency_slot, + ) + try: + chunk, ttft_ms, wall = _post_chat(host, port, body, timeout=timeout) + except Exception as exc: + stats.error = f"{type(exc).__name__}: {exc}" + return stats + stats.ttft_ms = ttft_ms + stats.wall_seconds = wall + gen = (chunk.get("generation_stats") or {}) + stats.prompt_tokens = int(gen.get("prompt_tokens") or 0) + stats.generation_tokens = int(gen.get("generation_tokens") or 0) + stats.prompt_tps = float(gen.get("prompt_tps") or 0.0) + stats.generation_tps = float(gen.get("generation_tps") or 0.0) + stats.accepted_draft_tokens = int(gen.get("accepted_draft_tokens") or 0) + stats.drafter_model_id = gen.get("drafter_model_id") + stats.response_draft_mode = gen.get("draft_mode") + if stats.drafter_model_id and stats.generation_tokens: + stats.accept_fraction = ( + stats.accepted_draft_tokens / stats.generation_tokens + ) + return stats + + +def main() -> int: + p = argparse.ArgumentParser() + p.add_argument("--host", default="127.0.0.1") + p.add_argument("--port", type=int, default=52415) + p.add_argument("--model", required=True) + p.add_argument("--label", required=True, help="run label (used in output)") + p.add_argument("--runs", type=int, default=3) + p.add_argument("--max-tokens", type=int, default=None) + p.add_argument("--timeout", type=float, default=600.0) + p.add_argument( + "--use-drafter", + choices=["true", "false", "auto"], + default="auto", + help=( + "force per-request use_drafter override; 'auto' omits the field " + "(model-card default applies)" + ), + ) + p.add_argument( + "--num-draft-tokens", + type=int, + default=None, + help="per-request K override (default: runner config)", + ) + p.add_argument( + "--draft-mode", + choices=["none", "model", "ngram", "pipelined", "auto"], + default="auto", + help=( + "per-request draft_mode override; 'auto' omits the field " + "(model-card / runner default applies). 'none' disables, 'model' " + "uses the model drafter, 'ngram' uses the n-gram drafter, " + "'pipelined' uses pipelined+remote model drafter." + ), + ) + p.add_argument( + "--concurrency", + type=int, + default=1, + help=( + "issue this many requests in parallel against the same instance. " + "Each parallel slot runs the full prompt set sequentially; " + "throughput is reported as the sum across slots over wall time." + ), + ) + p.add_argument( + "--prompts", + nargs="*", + default=list(PROMPTS.keys()), + choices=list(PROMPTS.keys()), + ) + p.add_argument("--warmup", action="store_true", help="run one extra warm-up request before timed runs") + p.add_argument("--out", required=True) + args = p.parse_args() + + use_drafter: bool | None + if args.use_drafter == "true": + use_drafter = True + elif args.use_drafter == "false": + use_drafter = False + else: + use_drafter = None + + draft_mode: str | None = None if args.draft_mode == "auto" else args.draft_mode + + result = BenchmarkResult( + label=args.label, + host=args.host, + port=args.port, + model=args.model, + use_drafter=use_drafter, + num_draft_tokens=args.num_draft_tokens, + draft_mode=draft_mode, + concurrency=args.concurrency, + runs=args.runs, + ) + + if args.warmup: + prompt_id = args.prompts[0] + prompt = PROMPTS[prompt_id] + print(f"[warmup] {prompt_id}", file=sys.stderr) + _run_one( + host=args.host, + port=args.port, + model=args.model, + prompt_id=prompt_id, + prompt=prompt, + label=args.label, + use_drafter=use_drafter, + num_draft_tokens=args.num_draft_tokens, + draft_mode=draft_mode, + run_index=-1, + concurrency_slot=0, + timeout=args.timeout, + max_tokens_override=args.max_tokens, + ) + + work: list[tuple[str, int, int]] = [] + for prompt_id in args.prompts: + for run_index in range(args.runs): + for slot in range(args.concurrency): + work.append((prompt_id, run_index, slot)) + + if args.concurrency <= 1: + for prompt_id, run_index, slot in work: + prompt = PROMPTS[prompt_id] + print( + f"[{args.label}] {prompt_id} run={run_index + 1}/{args.runs}", + file=sys.stderr, + ) + stats = _run_one( + host=args.host, + port=args.port, + model=args.model, + prompt_id=prompt_id, + prompt=prompt, + label=args.label, + use_drafter=use_drafter, + num_draft_tokens=args.num_draft_tokens, + draft_mode=draft_mode, + run_index=run_index, + concurrency_slot=slot, + timeout=args.timeout, + max_tokens_override=args.max_tokens, + ) + result.requests.append(stats) + _print_stats(stats) + else: + # Concurrency mode: dispatch each (prompt, run, slot) tuple into its + # own thread so the server sees ``concurrency`` overlapping requests. + # Threads are fine here because each request is just an HTTP call. + with concurrent.futures.ThreadPoolExecutor( + max_workers=args.concurrency * len(args.prompts) + ) as ex: + futures: list[concurrent.futures.Future[RequestStats]] = [] + for prompt_id, run_index, slot in work: + prompt = PROMPTS[prompt_id] + futures.append( + ex.submit( + _run_one, + host=args.host, + port=args.port, + model=args.model, + prompt_id=prompt_id, + prompt=prompt, + label=args.label, + use_drafter=use_drafter, + num_draft_tokens=args.num_draft_tokens, + draft_mode=draft_mode, + run_index=run_index, + concurrency_slot=slot, + timeout=args.timeout, + max_tokens_override=args.max_tokens, + ) + ) + for fut in concurrent.futures.as_completed(futures): + stats = fut.result() + result.requests.append(stats) + _print_stats(stats) + + out_path = Path(args.out) + out_path.parent.mkdir(parents=True, exist_ok=True) + with out_path.open("w") as f: + json.dump(asdict(result), f, indent=2) + + successful = [r for r in result.requests if r.error is None and r.generation_tokens > 0] + if successful: + gen_tps_values = [r.generation_tps for r in successful] + ttft_values = [r.ttft_ms for r in successful] + wall_values = [r.wall_seconds for r in successful] + # Aggregate throughput sums per-request tps across overlapping slots, + # which is what serving operators actually care about under + # concurrency. Single-slot runs collapse to the per-request tps. + if args.concurrency > 1: + aggregate = sum( + r.generation_tokens for r in successful + ) / max(wall_values) + print( + f"[{args.label}] aggregate gen_tps=" + f"{aggregate:.2f} (concurrency={args.concurrency}) " + f"median per-request gen_tps=" + f"{statistics.median(gen_tps_values):.2f} " + f"median ttft={statistics.median(ttft_values):.1f}ms " + f"runs={len(successful)}", + file=sys.stderr, + ) + else: + print( + f"[{args.label}] median gen_tps=" + f"{statistics.median(gen_tps_values):.2f} " + f"median ttft={statistics.median(ttft_values):.1f}ms " + f"runs={len(successful)}", + file=sys.stderr, + ) + + return 0 if all(r.error is None for r in result.requests) else 1 + + +def _print_stats(stats: RequestStats) -> None: + if stats.error: + print(f" [{stats.prompt_id}/run{stats.run_index}/slot{stats.concurrency_slot}] ERROR: {stats.error}", file=sys.stderr) + else: + print( + f" [{stats.prompt_id}/run{stats.run_index}/slot{stats.concurrency_slot}] " + f"gen={stats.generation_tokens}t @ {stats.generation_tps:.2f}t/s " + f"ttft={stats.ttft_ms:.1f}ms " + f"draft_mode={stats.response_draft_mode} " + f"accept={stats.accept_fraction}", + file=sys.stderr, + ) + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/bench/run_drafter_sweep.sh b/bench/run_drafter_sweep.sh new file mode 100755 index 0000000000..ec8ca67ec6 --- /dev/null +++ b/bench/run_drafter_sweep.sh @@ -0,0 +1,140 @@ +#!/bin/bash +# Orchestration helper for the drafter benchmark sweep. +# +# Runs on the operator's workstation; SSH's into the target host(s) to +# control exo lifecycle so we can sweep ``EXO_DRAFT_MODE`` across runs +# (the env var is read once per process). +# +# Usage (local single-host): +# bash bench/run_drafter_sweep.sh local wc-smbp +# +# Usage (asymmetric twin): +# bash bench/run_drafter_sweep.sh twin-tcp wc-smbp wc-smbpt +# bash bench/run_drafter_sweep.sh twin-rdma wc-smbp wc-smbpt +# +# Output: bench/results/drafter//.json +set -euo pipefail + +SCENARIO="${1:?usage: run_drafter_sweep.sh [drafter_host]}" +TARGET_HOST="${2:?missing target host}" +DRAFTER_HOST="${3:-}" + +MODEL="mlx-community/gemma-4-26b-a4b-it-bf16" +RESULTS_DIR="$(cd "$(dirname "$0")/.." && pwd)/bench/results/drafter/${SCENARIO}" +mkdir -p "${RESULTS_DIR}" + +REMOTE_REPO="/Users/JJ/Development/Tooling/exo" +UV_BIN_SMBP="/opt/homebrew/bin/uv" +UV_BIN_SMBPT="/Users/JJ/.local/bin/uv" + +uv_for() { + case "$1" in + wc-smbp) echo "${UV_BIN_SMBP}" ;; + wc-smbpt) echo "${UV_BIN_SMBPT}" ;; + *) echo "uv" ;; + esac +} + +# Kill any stale exo processes on a host. +exo_kill() { + local host="$1" + ssh "${host}" "pkill -f 'exo.main' 2>/dev/null; pkill -f 'uv run exo' 2>/dev/null; sleep 2; pkill -9 -f 'exo.main' 2>/dev/null; true" +} + +# Start exo in the background; returns when API is up. +exo_start() { + local host="$1" + local mode="$2" + local extra_env="${3:-}" + local uv_bin + uv_bin="$(uv_for "${host}")" + echo "[${host}] starting exo (EXO_DRAFT_MODE=${mode}, extra_env=${extra_env})..." >&2 + ssh "${host}" "cd ${REMOTE_REPO} && rm -f /tmp/exo-${mode}.log && nohup env EXO_DRAFT_MODE=${mode} ${extra_env} ${uv_bin} run exo -v >/tmp/exo-${mode}.log 2>&1 & disown; sleep 1; echo started" + # Wait for API + local tries=0 + until ssh "${host}" "curl -sf http://127.0.0.1:52415/v1/models >/dev/null 2>&1"; do + tries=$((tries + 1)) + if [ "${tries}" -gt 60 ]; then + echo "[${host}] exo failed to start within 120s; tail of log:" >&2 + ssh "${host}" "tail -40 /tmp/exo-${mode}.log" >&2 || true + return 1 + fi + sleep 2 + done + echo "[${host}] exo API ready" >&2 +} + +# Place a single-host instance and wait for Ready. +place_instance() { + local host="$1" + local meta="${2:-MlxRing}" + echo "[${host}] placing instance (meta=${meta}) for ${MODEL}..." >&2 + ssh "${host}" "curl -sf -X POST http://127.0.0.1:52415/place_instance \ + -H 'Content-Type: application/json' \ + -d '{\"model_id\":\"${MODEL}\",\"instance_meta\":\"${meta}\",\"min_nodes\":1}' \ + | head -c 400 ; echo" + # Poll for instance ready + local tries=0 + until ssh "${host}" "curl -sf http://127.0.0.1:52415/instance/placement 2>/dev/null | python3 -c 'import json,sys; d=json.load(sys.stdin); print(any((p.get(\"instance\",{}).get(\"shard_assignments\",{}).get(\"model_id\")==\"${MODEL}\") and p.get(\"phase\",{}).get(\"variant\")==\"Ready\" for p in d.get(\"placements\",[])))' 2>/dev/null | grep -q True"; do + tries=$((tries + 1)) + if [ "${tries}" -gt 240 ]; then + echo "[${host}] instance failed to reach Ready within 240s" >&2 + ssh "${host}" "tail -40 /tmp/exo-*.log" >&2 || true + return 1 + fi + sleep 5 + done + echo "[${host}] instance Ready" >&2 +} + +run_bench() { + local label="$1" + local use_drafter="${2:-auto}" + local out="${RESULTS_DIR}/${label}.json" + echo "[bench] ${label} -> ${out}" >&2 + cd "$(dirname "$0")/.." + /opt/homebrew/bin/uv run python bench/drafter_bench.py \ + --host "${TARGET_HOST_IP:-${TARGET_HOST}}" \ + --port 52415 \ + --model "${MODEL}" \ + --label "${label}" \ + --runs 2 \ + --warmup \ + --use-drafter "${use_drafter}" \ + --out "${out}" || true +} + +case "${SCENARIO}" in + local) + # Resolve TARGET_HOST -> IP because curl/requests on this box go via LAN. + TARGET_HOST_IP="$(ssh "${TARGET_HOST}" "ipconfig getifaddr en0 2>/dev/null || ipconfig getifaddr en16 2>/dev/null" | head -n1)" + : "${TARGET_HOST_IP:?could not resolve LAN IP for ${TARGET_HOST}}" + export TARGET_HOST_IP + echo "[local] sweeping draft modes on ${TARGET_HOST} (${TARGET_HOST_IP})" >&2 + + for mode in none model ngram pipelined; do + exo_kill "${TARGET_HOST}" + sleep 3 + exo_start "${TARGET_HOST}" "${mode}" + place_instance "${TARGET_HOST}" "MlxRing" + run_bench "local-${mode}" auto + exo_kill "${TARGET_HOST}" + sleep 3 + done + ;; + twin-tcp|twin-rdma) + : "${DRAFTER_HOST:?twin scenarios need a drafter host as third arg}" + # Custom card with drafter_eligible_nodes installed *before* exo start. + # Friendly names + asymmetric placement code use NodeId; we look up + # the drafter NodeId via the master's REST API the first time exo + # is up. Two-stage: bring up exo plain to read node IDs, then + # rewrite the card and bring up exo with the asymmetric env. + echo "twin scenarios are wired but require operator-driven node-id" + echo "discovery; see bench/results/drafter/${SCENARIO}/README.md" + echo "for the manual recipe." + ;; + *) + echo "unknown scenario: ${SCENARIO}" >&2 + exit 2 + ;; +esac diff --git a/dashboard/src/lib/components/ChatMessages.svelte b/dashboard/src/lib/components/ChatMessages.svelte index 46c20694c1..a8e173c642 100644 --- a/dashboard/src/lib/components/ChatMessages.svelte +++ b/dashboard/src/lib/components/ChatMessages.svelte @@ -298,6 +298,39 @@ tok/s{/if} {/if} + {#if message.drafterStats} + + {@const modeLabel = (() => { + switch (message.drafterStats.draftMode) { + case "pipelined": + return "PIPELINED"; + case "model": + return "MODEL"; + case "ngram": + return "NGRAM"; + case "eagle": + return "EAGLE"; + case "lookahead": + return "LOOKAHEAD"; + default: + return "SPEC"; + } + })()} + + {modeLabel} + {(message.drafterStats.acceptanceFraction * 100).toFixed(0)}%{#if message.drafterStats.numDraftTokens !== null}K={message.drafterStats.numDraftTokens}{/if} + + {/if} {:else} diff --git a/dashboard/src/lib/stores/app.svelte.ts b/dashboard/src/lib/stores/app.svelte.ts index 0df4ab3f9b..6905ee470c 100644 --- a/dashboard/src/lib/stores/app.svelte.ts +++ b/dashboard/src/lib/stores/app.svelte.ts @@ -292,6 +292,22 @@ export interface PrefillProgress { startedAt: number; } +export interface DrafterStats { + modelId: string; + acceptedDraftTokens: number; + generationTokens: number; + numDraftTokens: number | null; + acceptanceFraction: number; // accepted / generation_tokens + /** Drafting strategy that actually ran: + * - "model": local in-process drafter via mlx_lm.speculative_generate_step + * - "pipelined": asymmetric remote drafter on a separate node (V2 socket wire) + * - "ngram": in-context suffix lookup (no extra weights) + * - "eagle" / "lookahead": reserved + * ``null`` when the engine doesn't surface a mode (older runner builds); + * UI should fall back to the generic "SPEC" pill in that case. */ + draftMode: "model" | "pipelined" | "ngram" | "eagle" | "lookahead" | null; +} + export interface Message { id: string; role: "user" | "assistant" | "system"; @@ -301,6 +317,7 @@ export interface Message { attachments?: MessageAttachment[]; ttftMs?: number; // Time to first token in ms (for assistant messages) tps?: number; // Tokens per second (for assistant messages) + drafterStats?: DrafterStats; // Speculative-decoding telemetry for this turn requestType?: "chat" | "image-generation" | "image-editing"; sourceImageDataUrl?: string; // For image editing regeneration tokens?: TokenData[]; @@ -538,6 +555,7 @@ class AppStore { // Performance metrics ttftMs = $state(null); // Time to first token in ms tps = $state(null); // Tokens per second + drafterStats = $state(null); totalTokens = $state(0); // Total tokens in current response prefillProgress = $state(null); @@ -1676,6 +1694,7 @@ class AppStore { this.currentResponse = prefixText; this.ttftMs = null; this.tps = null; + this.drafterStats = null; this.totalTokens = tokensToKeep.length; try { @@ -1831,10 +1850,39 @@ class AppStore { }, { generation_stats: (data) => { - const stats = data as { generation_tps: number }; + const stats = data as { + generation_tps: number; + generation_tokens?: number; + drafter_model_id?: string | null; + accepted_draft_tokens?: number; + num_draft_tokens?: number | null; + draft_mode?: + | "model" + | "pipelined" + | "ngram" + | "eagle" + | "lookahead" + | "none" + | null; + }; if (stats.generation_tps > 0) { this.tps = stats.generation_tps; } + if (stats.drafter_model_id && stats.generation_tokens && stats.generation_tokens > 0) { + this.drafterStats = { + modelId: stats.drafter_model_id, + acceptedDraftTokens: stats.accepted_draft_tokens ?? 0, + generationTokens: stats.generation_tokens, + numDraftTokens: stats.num_draft_tokens ?? null, + acceptanceFraction: + (stats.accepted_draft_tokens ?? 0) / stats.generation_tokens, + // ``"none"`` from the API maps to ``null`` here because the + // dashboard treats it as "no spec ran" -- the type contract + // for ``DrafterStats`` already gates on a non-null + // ``modelId``, so a "none" mode would be inconsistent. + draftMode: stats.draft_mode === "none" ? null : (stats.draft_mode ?? null), + }; + } }, }, ); @@ -1852,6 +1900,7 @@ class AppStore { m.tokens = [...collectedTokens]; if (this.ttftMs !== null) m.ttftMs = this.ttftMs; if (this.tps !== null) m.tps = this.tps; + if (this.drafterStats !== null) m.drafterStats = this.drafterStats; }); this.syncActiveMessagesIfNeeded(targetConversationId); this.persistConversation(targetConversationId); @@ -2044,10 +2093,39 @@ class AppStore { }, { generation_stats: (data) => { - const stats = data as { generation_tps: number }; + const stats = data as { + generation_tps: number; + generation_tokens?: number; + drafter_model_id?: string | null; + accepted_draft_tokens?: number; + num_draft_tokens?: number | null; + draft_mode?: + | "model" + | "pipelined" + | "ngram" + | "eagle" + | "lookahead" + | "none" + | null; + }; if (stats.generation_tps > 0) { this.tps = stats.generation_tps; } + if (stats.drafter_model_id && stats.generation_tokens && stats.generation_tokens > 0) { + this.drafterStats = { + modelId: stats.drafter_model_id, + acceptedDraftTokens: stats.accepted_draft_tokens ?? 0, + generationTokens: stats.generation_tokens, + numDraftTokens: stats.num_draft_tokens ?? null, + acceptanceFraction: + (stats.accepted_draft_tokens ?? 0) / stats.generation_tokens, + // ``"none"`` from the API maps to ``null`` here because the + // dashboard treats it as "no spec ran" -- the type contract + // for ``DrafterStats`` already gates on a non-null + // ``modelId``, so a "none" mode would be inconsistent. + draftMode: stats.draft_mode === "none" ? null : (stats.draft_mode ?? null), + }; + } }, }, ); @@ -2103,6 +2181,7 @@ class AppStore { // Clear stats when model changes this.ttftMs = null; this.tps = null; + this.drafterStats = null; } /** @@ -2305,6 +2384,7 @@ class AppStore { this.currentResponse = ""; this.ttftMs = null; this.tps = null; + this.drafterStats = null; this.totalTokens = 0; // Build attachments from files @@ -2655,12 +2735,45 @@ class AppStore { }; }, generation_stats: (data) => { - const stats = data as { generation_tps: number }; + const stats = data as { + generation_tps: number; + generation_tokens?: number; + drafter_model_id?: string | null; + accepted_draft_tokens?: number; + num_draft_tokens?: number | null; + draft_mode?: + | "model" + | "pipelined" + | "ngram" + | "eagle" + | "lookahead" + | "none" + | null; + }; if (stats.generation_tps > 0) { this.tps = stats.generation_tps; serverTpsReceived = true; } + if ( + stats.drafter_model_id && + stats.generation_tokens && + stats.generation_tokens > 0 + ) { + this.drafterStats = { + modelId: stats.drafter_model_id, + acceptedDraftTokens: stats.accepted_draft_tokens ?? 0, + generationTokens: stats.generation_tokens, + numDraftTokens: stats.num_draft_tokens ?? null, + acceptanceFraction: + (stats.accepted_draft_tokens ?? 0) / stats.generation_tokens, + // ``"none"`` from the API maps to ``null`` here because the + // dashboard treats it as "no spec ran" -- the type contract + // for ``DrafterStats`` already gates on a non-null + // ``modelId``, so a "none" mode would be inconsistent. + draftMode: stats.draft_mode === "none" ? null : (stats.draft_mode ?? null), + }; + } }, }, ); @@ -2695,6 +2808,9 @@ class AppStore { if (this.tps !== null) { msg.tps = this.tps; } + if (this.drafterStats !== null) { + msg.drafterStats = this.drafterStats; + } }, ); this.syncActiveMessagesIfNeeded(targetConversationId); @@ -3260,6 +3376,7 @@ class AppStore { // Clear performance stats this.ttftMs = null; this.tps = null; + this.drafterStats = null; } /** diff --git a/resources/inference_model_cards/mlx-community--gemma-4-26b-a4b-it-4bit.toml b/resources/inference_model_cards/mlx-community--gemma-4-26b-a4b-it-4bit.toml index 863203b743..2d69f800ec 100644 --- a/resources/inference_model_cards/mlx-community--gemma-4-26b-a4b-it-4bit.toml +++ b/resources/inference_model_cards/mlx-community--gemma-4-26b-a4b-it-4bit.toml @@ -8,7 +8,7 @@ family = "gemma" quantization = "4bit" base_model = "Gemma 4 26B A4B" capabilities = ["text", "vision"] -drafter_model_id = "mlx-community/gemma-4-e2b-it-4bit" +drafter_model_ids = ["mlx-community/gemma-4-e2b-it-4bit", "mlx-community/gemma-4-e4b-it-4bit"] context_length = 262144 diff --git a/resources/inference_model_cards/mlx-community--gemma-4-26b-a4b-it-6bit.toml b/resources/inference_model_cards/mlx-community--gemma-4-26b-a4b-it-6bit.toml index 32a0a84d56..c5fc08ea9d 100644 --- a/resources/inference_model_cards/mlx-community--gemma-4-26b-a4b-it-6bit.toml +++ b/resources/inference_model_cards/mlx-community--gemma-4-26b-a4b-it-6bit.toml @@ -8,7 +8,7 @@ family = "gemma" quantization = "6bit" base_model = "Gemma 4 26B A4B" capabilities = ["text", "vision"] -drafter_model_id = "mlx-community/gemma-4-e2b-it-6bit" +drafter_model_ids = ["mlx-community/gemma-4-e2b-it-6bit", "mlx-community/gemma-4-e4b-it-6bit"] context_length = 262144 diff --git a/resources/inference_model_cards/mlx-community--gemma-4-26b-a4b-it-8bit.toml b/resources/inference_model_cards/mlx-community--gemma-4-26b-a4b-it-8bit.toml index 3201ec8283..e6b564f941 100644 --- a/resources/inference_model_cards/mlx-community--gemma-4-26b-a4b-it-8bit.toml +++ b/resources/inference_model_cards/mlx-community--gemma-4-26b-a4b-it-8bit.toml @@ -8,7 +8,7 @@ family = "gemma" quantization = "8bit" base_model = "Gemma 4 26B A4B" capabilities = ["text", "vision"] -drafter_model_id = "mlx-community/gemma-4-e2b-it-8bit" +drafter_model_ids = ["mlx-community/gemma-4-e2b-it-8bit", "mlx-community/gemma-4-e4b-it-8bit"] context_length = 262144 diff --git a/resources/inference_model_cards/mlx-community--gemma-4-26b-a4b-it-bf16.toml b/resources/inference_model_cards/mlx-community--gemma-4-26b-a4b-it-bf16.toml index 39ea210a64..aabf3c10e6 100644 --- a/resources/inference_model_cards/mlx-community--gemma-4-26b-a4b-it-bf16.toml +++ b/resources/inference_model_cards/mlx-community--gemma-4-26b-a4b-it-bf16.toml @@ -8,7 +8,7 @@ family = "gemma" quantization = "bf16" base_model = "Gemma 4 26B A4B" capabilities = ["text", "vision"] -drafter_model_id = "mlx-community/gemma-4-e2b-it-bf16" +drafter_model_ids = ["mlx-community/gemma-4-e2b-it-bf16", "mlx-community/gemma-4-e4b-it-bf16"] context_length = 262144 diff --git a/resources/inference_model_cards/mlx-community--gemma-4-31b-it-4bit.toml b/resources/inference_model_cards/mlx-community--gemma-4-31b-it-4bit.toml index 87a7584cbb..a4031e6aef 100644 --- a/resources/inference_model_cards/mlx-community--gemma-4-31b-it-4bit.toml +++ b/resources/inference_model_cards/mlx-community--gemma-4-31b-it-4bit.toml @@ -8,7 +8,7 @@ family = "gemma" quantization = "4bit" base_model = "Gemma 4 31B" capabilities = ["text", "vision"] -drafter_model_id = "mlx-community/gemma-4-e2b-it-4bit" +drafter_model_ids = ["mlx-community/gemma-4-e2b-it-4bit", "mlx-community/gemma-4-e4b-it-4bit"] context_length = 262144 diff --git a/resources/inference_model_cards/mlx-community--gemma-4-31b-it-6bit.toml b/resources/inference_model_cards/mlx-community--gemma-4-31b-it-6bit.toml index 0e0314e119..af47f9c508 100644 --- a/resources/inference_model_cards/mlx-community--gemma-4-31b-it-6bit.toml +++ b/resources/inference_model_cards/mlx-community--gemma-4-31b-it-6bit.toml @@ -8,7 +8,7 @@ family = "gemma" quantization = "6bit" base_model = "Gemma 4 31B" capabilities = ["text", "vision"] -drafter_model_id = "mlx-community/gemma-4-e2b-it-6bit" +drafter_model_ids = ["mlx-community/gemma-4-e2b-it-6bit", "mlx-community/gemma-4-e4b-it-6bit"] context_length = 262144 diff --git a/resources/inference_model_cards/mlx-community--gemma-4-31b-it-8bit.toml b/resources/inference_model_cards/mlx-community--gemma-4-31b-it-8bit.toml index 0e33f6ff58..63e205b30d 100644 --- a/resources/inference_model_cards/mlx-community--gemma-4-31b-it-8bit.toml +++ b/resources/inference_model_cards/mlx-community--gemma-4-31b-it-8bit.toml @@ -8,7 +8,7 @@ family = "gemma" quantization = "8bit" base_model = "Gemma 4 31B" capabilities = ["text", "vision"] -drafter_model_id = "mlx-community/gemma-4-e2b-it-8bit" +drafter_model_ids = ["mlx-community/gemma-4-e2b-it-8bit", "mlx-community/gemma-4-e4b-it-8bit"] context_length = 262144 diff --git a/resources/inference_model_cards/mlx-community--gemma-4-31b-it-bf16.toml b/resources/inference_model_cards/mlx-community--gemma-4-31b-it-bf16.toml index 1da7e56e9d..e16e8be3dc 100644 --- a/resources/inference_model_cards/mlx-community--gemma-4-31b-it-bf16.toml +++ b/resources/inference_model_cards/mlx-community--gemma-4-31b-it-bf16.toml @@ -8,7 +8,7 @@ family = "gemma" quantization = "bf16" base_model = "Gemma 4 31B" capabilities = ["text", "vision"] -drafter_model_id = "mlx-community/gemma-4-e2b-it-bf16" +drafter_model_ids = ["mlx-community/gemma-4-e2b-it-bf16", "mlx-community/gemma-4-e4b-it-bf16"] context_length = 262144 diff --git a/scripts/convert_eagle3_to_mlx.py b/scripts/convert_eagle3_to_mlx.py new file mode 100644 index 0000000000..ed1a40d1e7 --- /dev/null +++ b/scripts/convert_eagle3_to_mlx.py @@ -0,0 +1,251 @@ +#!/usr/bin/env python3 +"""Convert an EAGLE-3 PyTorch draft head checkpoint to MLX safetensors. + +This is the *offline* half of the EAGLE-3 integration. It downloads a +pre-trained EAGLE-3 head from HuggingFace (e.g. ``RedHatAI/gemma-4-26B +-A4B-it-speculator.eagle3``) and rewrites the weights in MLX's expected +layout so the runtime side can load them with ``mlx.load`` once the +``EagleDrafter`` runtime lands. + +Why this lives in ``scripts/`` and not in the runtime +----------------------------------------------------- +The runtime EAGLE drafter is currently a scaffolding stub +(``src/exo/worker/engines/mlx/generator/drafter.py::EagleDrafter``) +because Apple Silicon EAGLE wins are gated on ``mlx_lm`` adding +``position_ids`` support for tree-attention verify (open issues +``ml-explore/mlx-lm#846`` and ``#250``). The converter is the durable +half of the work: the artifact it produces sits on disk, and the day +the upstream blocker lifts the runtime can load it without rerunning +this script. Running this script today is safe; consuming the output +just doesn't beat ``DraftMode = "none"`` yet. + +Usage:: + + # convert the head for our gemma-4-26b target + uv run python scripts/convert_eagle3_to_mlx.py \ + --source-repo RedHatAI/gemma-4-26B-A4B-it-speculator.eagle3 \ + --output ~/.exo/eagle_heads/gemma-4-26b-a4b-it-eagle3 \ + --target-num-layers 30 + +The ``--target-num-layers`` flag drives EAGLE's layer-tap selection: +the head fuses pre-layer hidden states from ``{2, N//2, N-3}`` +following the EAGLE-3 reference (Li et al. 2025). For Gemma-4-26b +(N=30) that means layers ``{2, 15, 27}`` -- the value is recorded in +``eagle_config.json`` next to the safetensors so the runtime doesn't +have to recompute it. + +References +---------- +* Li et al., "EAGLE-3: Scaling up Inference Acceleration of Large + Language Models via Training-Free Token-Level Blending," + NeurIPS 2025. https://arxiv.org/abs/2503.01840 +* RedHat draft head for our exact target: + https://huggingface.co/RedHatAI/gemma-4-26B-A4B-it-speculator.eagle3 +* Reference MLX port of EAGLE-3 (Llama-3.1 only, no Gemma-4 adapter, + no tree verify): mlx-lm Discussion #890. The ``eagle_convert.py`` + there is the spiritual ancestor of this script; we keep the layout + compatible so a future Gemma-4 head shape change shows up as a + fail-loudly assert here rather than a silent runtime miscompare. +""" + +from __future__ import annotations + +import argparse +import json +import sys +from dataclasses import dataclass +from pathlib import Path + +# These are heavy imports (torch + safetensors); we defer them to +# ``main`` so ``--help`` works in a clean venv without pulling them in. + + +def eagle3_layer_taps(target_num_layers: int) -> tuple[int, int, int]: + """Compute the {2, N//2, N-3} tap indices for an N-layer target. + + Reference impl: + https://github.com/SafeAILab/EAGLE/blob/main/eagle/traineagle3/modeling_llama_kv.py + Indices are taken against the *target* layer count, not the head's. + """ + return (2, target_num_layers // 2, target_num_layers - 3) + + +@dataclass(frozen=True, slots=True) +class ConvertConfig: + source_repo: str + output_dir: Path + target_num_layers: int + quantize_bits: int | None + dry_run: bool + + +def parse_args() -> ConvertConfig: + parser = argparse.ArgumentParser( + description=__doc__, + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + parser.add_argument( + "--source-repo", + required=True, + help="HuggingFace repo of the PyTorch EAGLE-3 head, e.g. " + "'RedHatAI/gemma-4-26B-A4B-it-speculator.eagle3'.", + ) + parser.add_argument( + "--output", + required=True, + help="Output directory for the MLX-format head.", + ) + parser.add_argument( + "--target-num-layers", + type=int, + required=True, + help="Layer count of the target model. Used to compute the " + "EAGLE layer-tap indices ({2, N//2, N-3}). For Gemma-4-26b " + "this is 30.", + ) + parser.add_argument( + "--quantize-bits", + type=int, + choices=[2, 3, 4, 8], + default=None, + help="If set, run mx.quantize on the head weights at this bit " + "depth before writing. The community prototype reports identical " + "EAGLE acceptance with 4-bit head quantization while halving " + "the head forward cost (mlx-lm discussion #890).", + ) + parser.add_argument( + "--dry-run", + action="store_true", + help="Print the layer mapping and target-layer taps without " + "downloading or writing anything.", + ) + args = parser.parse_args() + return ConvertConfig( + source_repo=args.source_repo, + output_dir=Path(args.output).expanduser().resolve(), + target_num_layers=args.target_num_layers, + quantize_bits=args.quantize_bits, + dry_run=args.dry_run, + ) + + +def main() -> int: + config = parse_args() + layer_taps = eagle3_layer_taps(config.target_num_layers) + + print(f"source_repo: {config.source_repo}") + print(f"output_dir: {config.output_dir}") + print(f"target_num_layers: {config.target_num_layers}") + print(f"layer_taps: {layer_taps}") + print(f"quantize_bits: {config.quantize_bits}") + print(f"dry_run: {config.dry_run}") + + if config.dry_run: + print("\n[dry-run] not downloading or writing any files.") + return 0 + + # Defer heavy imports past --help / --dry-run so they fail loudly + # only when actually needed. The runtime side of EAGLE doesn't need + # any of these; this is a one-time conversion utility. + try: + import mlx.core as mx + import torch # noqa: F401 -- needed by safetensors below + from huggingface_hub import snapshot_download + from safetensors.torch import load_file as load_torch_safetensors + except ImportError as e: + print( + "Missing optional dependency for EAGLE-3 conversion. " + "Install with: uv add --dev torch safetensors huggingface_hub", + file=sys.stderr, + ) + raise SystemExit(1) from e + + print(f"\n[1/4] downloading {config.source_repo} ...") + src_dir = Path(snapshot_download(config.source_repo)) + print(f" -> {src_dir}") + + src_safetensors = src_dir / "model.safetensors" + src_config = src_dir / "config.json" + if not src_safetensors.exists(): + # Some EAGLE-3 releases shard weights; fall back to the index. + print( + f"FATAL: {src_safetensors} not found. Sharded EAGLE-3 heads " + "aren't supported by this converter yet. Open an issue with " + "the repo path and we'll add the shard merge.", + file=sys.stderr, + ) + return 2 + if not src_config.exists(): + print(f"FATAL: missing config.json in {src_dir}", file=sys.stderr) + return 2 + + print(f"[2/4] loading torch weights from {src_safetensors}") + torch_weights = load_torch_safetensors(str(src_safetensors)) + head_config = json.loads(src_config.read_text()) + + # Convert torch -> numpy -> mx.array. We keep the EAGLE key names + # verbatim so the runtime side can reuse the reference loader logic + # without an additional rename map. The fuse layer (``embed_layernorm``, + # ``fc``, ``midlayer.*``) and the reduced-vocab ``lm_head`` are all + # the EAGLE-3 spec defines. + print(f"[3/4] converting {len(torch_weights)} tensors to MLX") + mx_weights: dict[str, mx.array] = {} + for key, tensor in torch_weights.items(): + # bf16 -> float32 -> mx (mlx-core 0.x has limited bf16 ingest; + # the runtime can re-cast at load time if it wants bf16 storage). + np_array = tensor.to(dtype=torch.float32).cpu().numpy() + mx_weights[key] = mx.array(np_array) + + if config.quantize_bits is not None: + print(f" quantizing weights to {config.quantize_bits}-bit") + # mx.quantize lives at module-level, group_size matches mlx_lm + # default (64) which is what RedHat's q4 export expects. + quantized: dict[str, mx.array] = {} + for key, value in mx_weights.items(): + # Quantize linear-layer weights only (2-D). Embeddings and + # layer norms stay full precision; matches mlx_lm's default + # quantization predicate. + if value.ndim == 2 and "norm" not in key.lower(): + w_q, scales, biases = mx.quantize( + value, group_size=64, bits=config.quantize_bits + ) + quantized[key] = w_q + quantized[f"{key}.scales"] = scales + quantized[f"{key}.biases"] = biases + else: + quantized[key] = value + mx_weights = quantized + + print(f"[4/4] writing MLX head to {config.output_dir}") + config.output_dir.mkdir(parents=True, exist_ok=True) + mx.save_safetensors( + str(config.output_dir / "model.safetensors"), + mx_weights, + ) + + # Persist EAGLE-specific metadata next to the weights so the runtime + # doesn't have to recompute layer taps or reach into the source repo. + eagle_meta = { + "source_repo": config.source_repo, + "target_num_layers": config.target_num_layers, + "layer_taps": list(layer_taps), + "quantize_bits": config.quantize_bits, + "head_config": head_config, + } + (config.output_dir / "eagle_config.json").write_text( + json.dumps(eagle_meta, indent=2) + "\n" + ) + + print("\nDone.") + print( + "Note: the runtime EagleDrafter is a NotImplementedError stub " + "today (gated on mlx-lm position_ids upstream). The artifact you " + "just produced is durable -- when the runtime lands, point it at " + f"{config.output_dir} via ModelCard.eagle_head_repo and you're " + "done." + ) + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/src/exo/api/adapters/chat_completions.py b/src/exo/api/adapters/chat_completions.py index d10cfb618a..8600595e10 100644 --- a/src/exo/api/adapters/chat_completions.py +++ b/src/exo/api/adapters/chat_completions.py @@ -16,6 +16,7 @@ ErrorInfo, ErrorResponse, FinishReason, + GenerationStats, Logprobs, LogprobsContentItem, StreamingChoiceResponse, @@ -175,6 +176,9 @@ async def chat_request_to_text_generation( presence_penalty=request.presence_penalty, frequency_penalty=request.frequency_penalty, images=images, + use_drafter=request.use_drafter, + num_draft_tokens=request.num_draft_tokens, + draft_mode=request.draft_mode, ) @@ -309,6 +313,7 @@ async def collect_chat_response( finish_reason: FinishReason | None = None error_message: str | None = None last_usage: Usage | None = None + last_stats: GenerationStats | None = None async for chunk in chunk_stream: match chunk: @@ -323,6 +328,12 @@ async def collect_chat_response( if model is None: model = chunk.model last_usage = chunk.usage or last_usage + # ``stats`` is only populated on the final TokenChunk + # (when ``finish_reason`` is set); accumulate so the + # caller's response surfaces drafter telemetry. Earlier + # chunks have ``stats=None``; only the terminal one + # carries the GenerationStats value. + last_stats = chunk.stats or last_stats if chunk.is_thinking: thinking_parts.append(chunk.text) else: @@ -342,6 +353,7 @@ async def collect_chat_response( if model is None: model = chunk.model last_usage = chunk.usage or last_usage + last_stats = chunk.stats or last_stats tool_calls.extend( ToolCall( id=tool.id, @@ -379,5 +391,6 @@ async def collect_chat_response( ) ], usage=last_usage, + generation_stats=last_stats, ).model_dump_json() return diff --git a/src/exo/api/main.py b/src/exo/api/main.py index a4debfb65e..e7b61bf0fb 100644 --- a/src/exo/api/main.py +++ b/src/exo/api/main.py @@ -486,7 +486,7 @@ def _model_list_model_from_card(card: ModelCard) -> ModelListModel: base_model=card.base_model, capabilities=card.capabilities, reasoning_dialect=card.reasoning_dialect, - drafter_model_id=card.drafter_model_id, + drafter_model_ids=list(card.drafter_model_ids), context_length=card.context_length, ) diff --git a/src/exo/api/tests/test_agent_endpoints.py b/src/exo/api/tests/test_agent_endpoints.py index 205f8082cf..fabb271e21 100644 --- a/src/exo/api/tests/test_agent_endpoints.py +++ b/src/exo/api/tests/test_agent_endpoints.py @@ -352,7 +352,7 @@ async def _fail_load(_: ModelId) -> ModelCard: "quantization": "", "base_model": "", "capabilities": [], - "drafter_model_id": None, + "drafter_model_ids": [], } ] diff --git a/src/exo/api/tests/test_chat_completion_request_validation.py b/src/exo/api/tests/test_chat_completion_request_validation.py new file mode 100644 index 0000000000..5244614ad2 --- /dev/null +++ b/src/exo/api/tests/test_chat_completion_request_validation.py @@ -0,0 +1,85 @@ +"""Validation tests for ``ChatCompletionRequest``. + +These tests pin the API-level bounds on the speculative-decoding overrides +exposed via the OpenAI-compatible chat endpoint. The runner allocates a +fixed ``num_draft_tokens`` budget at warmup (``EXO_NUM_DRAFT_TOKENS``); a +per-request override above the budget would historically crash the runner +subprocess via an unhandled ``ValueError`` in ``PipelinedModelDrafter.__init__`` +(regression: aborted K=8 sweep at 14:35:05 took the target rank with it, +leaving the drafter peer wedged in ``RunnerRunning`` while the respawned +target was stuck in ``RunnerIdle``). + +The clamp inside ``generate.py`` defends the runner; the API bound here +defends against pathological values up-front so callers see a structured +422 instead of an opaque mid-stream error. +""" + +import pytest +from pydantic import ValidationError + +from exo.api.types.api import ( + MAX_NUM_DRAFT_TOKENS_PER_REQUEST, + ChatCompletionRequest, +) + + +def _minimal_payload(**overrides: object) -> dict[str, object]: + payload: dict[str, object] = { + "model": "test-model", + "messages": [{"role": "user", "content": "hello"}], + } + payload.update(overrides) + return payload + + +def test_num_draft_tokens_default_is_none() -> None: + request = ChatCompletionRequest.model_validate(_minimal_payload()) + assert request.num_draft_tokens is None + + +def test_num_draft_tokens_within_bounds_is_accepted() -> None: + request = ChatCompletionRequest.model_validate(_minimal_payload(num_draft_tokens=4)) + assert request.num_draft_tokens == 4 + + +def test_num_draft_tokens_at_upper_bound_is_accepted() -> None: + request = ChatCompletionRequest.model_validate( + _minimal_payload(num_draft_tokens=MAX_NUM_DRAFT_TOKENS_PER_REQUEST) + ) + assert request.num_draft_tokens == MAX_NUM_DRAFT_TOKENS_PER_REQUEST + + +def test_num_draft_tokens_above_upper_bound_rejected() -> None: + with pytest.raises(ValidationError) as exc_info: + ChatCompletionRequest.model_validate( + _minimal_payload(num_draft_tokens=MAX_NUM_DRAFT_TOKENS_PER_REQUEST + 1) + ) + + errors = exc_info.value.errors() + assert any( + err["loc"] == ("num_draft_tokens",) and err["type"] == "less_than_equal" + for err in errors + ) + + +def test_num_draft_tokens_pathological_value_rejected() -> None: + """The K=8 case that killed the runner is well below the cap, but the + cap exists to reject genuinely malformed values like 1024.""" + with pytest.raises(ValidationError): + ChatCompletionRequest.model_validate(_minimal_payload(num_draft_tokens=1024)) + + +def test_num_draft_tokens_zero_rejected() -> None: + with pytest.raises(ValidationError) as exc_info: + ChatCompletionRequest.model_validate(_minimal_payload(num_draft_tokens=0)) + + errors = exc_info.value.errors() + assert any( + err["loc"] == ("num_draft_tokens",) and err["type"] == "greater_than_equal" + for err in errors + ) + + +def test_num_draft_tokens_negative_rejected() -> None: + with pytest.raises(ValidationError): + ChatCompletionRequest.model_validate(_minimal_payload(num_draft_tokens=-3)) diff --git a/src/exo/api/types/api.py b/src/exo/api/types/api.py index 597793af72..326c76a020 100644 --- a/src/exo/api/types/api.py +++ b/src/exo/api/types/api.py @@ -1,6 +1,6 @@ import time from collections.abc import Generator -from typing import Annotated, Any, Literal, get_args +from typing import Annotated, Any, Final, Literal, get_args from uuid import uuid4 from pydantic import BaseModel, Field, field_validator @@ -17,6 +17,15 @@ "stop", "length", "tool_calls", "content_filter", "function_call", "error" ] +# Upper bound for the per-request ``num_draft_tokens`` override. The runner +# allocates a fixed wire-protocol budget at warmup (``EXO_NUM_DRAFT_TOKENS``, +# default in ``defaults.py``), and per-request K is clamped to that budget +# inside ``generate.py``. The API-level cap exists to reject pathological +# values up-front (which previously crashed the runner subprocess via an +# unhandled ``ValueError``); 32 is well above any realistic drafter K and +# any sane wire-protocol budget on Apple Silicon. +MAX_NUM_DRAFT_TOKENS_PER_REQUEST: Final[int] = 32 + class ErrorInfo(BaseModel): message: str @@ -49,10 +58,11 @@ class ModelListModel(BaseModel): base_model: str = Field(default="") capabilities: list[str] = Field(default_factory=list) reasoning_dialect: ReasoningDialect = "none" - # When set, identifies a smaller draft model that runners can load - # alongside this target for speculative decoding. Surfaced so dashboards - # and clients can pre-download the drafter. - drafter_model_id: str | None = None + # Smaller draft models the runner can load alongside this target for + # speculative decoding. Listed in preference order (`fastest` first). + # Surfaced so dashboards and clients can pre-download a drafter and + # pick which one to use at request time. + drafter_model_ids: list[str] = Field(default_factory=list) class ModelList(BaseModel): @@ -170,6 +180,79 @@ class ChatCompletionChoice(BaseModel): finish_reason: FinishReason | None = None +class GenerationStats(BaseModel): + prompt_tps: float + generation_tps: float + prompt_tokens: int + generation_tokens: int + peak_memory_usage: Memory + prefix_cache_hit: Literal["none", "partial", "exact"] = "none" + # Speculative-decoding telemetry. ``drafter_model_id`` is set whenever + # speculative decoding actually ran for this request (drafter loaded *and* + # not short-circuited by the short-skip threshold). ``accepted_draft_tokens`` + # counts ``stream_generate`` outputs with ``from_draft=True``: those are + # tokens the drafter proposed *and* the target accepted. The user-facing + # speedup is approximately ``accepted_draft_tokens / generation_tokens``. + drafter_model_id: str | None = None + accepted_draft_tokens: int = 0 + # Total drafts the drafter proposed across all spec-decode rounds. + # ``0`` means either the drafter didn't run or the drafter implementation + # doesn't surface proposal counts (currently only the pipelined drafter + # does). The classical per-position acceptance rate is + # ``accepted_draft_tokens / proposed_draft_tokens``; ``0`` here makes + # that property return ``None`` rather than divide-by-zero. ``mlx_lm``'s + # built-in ``stream_generate(draft_model=...)`` does not expose proposal + # counts at all, so external-model-drafter requests will leave this at 0 + # while still populating ``accepted_draft_tokens``. + proposed_draft_tokens: int = 0 + # Number of speculative-decoding rounds that actually ran. Each round + # proposes ``num_draft_tokens`` drafts (truncated near max_tokens). + # Useful for computing per-round latency in dashboards. ``0`` when the + # drafter didn't run or doesn't surface round counts. + spec_decode_rounds: int = 0 + # K used for speculative_generate_step (None when drafter didn't run). + num_draft_tokens: int | None = None + # Drafting strategy that actually ran for this request: "model" for + # external-drafter spec decoding, "pipelined" for the pipelined+ + # remote drafter, "ngram" for in-context suffix lookup, "eagle" / + # "lookahead" reserved for the upcoming auxiliary-head + Jacobi + # drafters, "none" for non-speculative. None when the engine doesn't + # surface drafting (e.g. image gen). Useful for telemetry dashboards + # to attribute throughput wins to a specific strategy when running + # mixed-mode A/B tests. + draft_mode: ( + Literal["model", "pipelined", "ngram", "eagle", "lookahead", "none"] | None + ) = None + + @property + def drafter_acceptance_fraction(self) -> float | None: + """Fraction of *generated* tokens that came from the drafter. + + Maps directly to wall-clock speedup. ``None`` when no drafter ran. + Always populated when speculative decoding runs (every drafter + tracks ``from_draft`` per emitted token). + """ + if self.drafter_model_id is None or self.generation_tokens == 0: + return None + return self.accepted_draft_tokens / self.generation_tokens + + @property + def drafter_acceptance_rate(self) -> float | None: + """Classical acceptance rate: accepted / proposed (per-position). + + ``None`` when the drafter didn't run *or* when it doesn't track + proposal counts (e.g. external-model drafter via mlx_lm). The + pipelined drafter tracks this. Differs from + :attr:`drafter_acceptance_fraction`: this divides by total drafts + proposed (the standard literature metric for drafter quality); + ``drafter_acceptance_fraction`` divides by total emitted tokens + (the metric for end-to-end speedup). + """ + if self.drafter_model_id is None or self.proposed_draft_tokens == 0: + return None + return self.accepted_draft_tokens / self.proposed_draft_tokens + + class ChatCompletionResponse(BaseModel): id: str object: Literal["chat.completion"] = "chat.completion" @@ -178,15 +261,14 @@ class ChatCompletionResponse(BaseModel): choices: list[ChatCompletionChoice | StreamingChoiceResponse] usage: Usage | None = None service_tier: str | None = None - - -class GenerationStats(BaseModel): - prompt_tps: float - generation_tps: float - prompt_tokens: int - generation_tokens: int - peak_memory_usage: Memory - prefix_cache_hit: Literal["none", "partial", "exact"] = "none" + # Non-OpenAI extension: full generation stats for the request, + # including spec-decode telemetry (drafter id, mode, K, accepted / + # proposed draft tokens, spec rounds, peak memory, prefill TPS). + # Standard OpenAI clients ignore unknown fields; exo's own benches + # and dashboards read this for drafter-effectiveness reporting. + # ``None`` for endpoints that don't run a generation pipeline (e.g. + # tool-call-only completions). + generation_stats: GenerationStats | None = None class ImageGenerationStats(BaseModel): @@ -251,6 +333,26 @@ class ChatCompletionRequest(BaseModel): tool_choice: str | dict[str, Any] | None = None parallel_tool_calls: bool | None = None user: str | None = None + # Speculative-decoding per-request overrides (item 9). These are exo + # extensions to the OpenAI Chat Completions schema -- standard clients + # ignore unknown fields and get the runner's defaults. + use_drafter: bool | None = None + num_draft_tokens: int | None = Field( + default=None, + ge=1, + le=MAX_NUM_DRAFT_TOKENS_PER_REQUEST, + description=( + "Per-request override for the number of speculative draft tokens " + "per round (K). Bounded to " + f"[1, {MAX_NUM_DRAFT_TOKENS_PER_REQUEST}] to prevent malformed " + "requests from triggering wire-protocol failures in the runner." + ), + ) + # Per-request draft-strategy override. ``"model"`` uses the external + # drafter, ``"pipelined"`` uses the pipelined+remote drafter, ``"ngram"`` + # uses CPU n-gram tables, ``"none"`` disables speculation. ``None`` defers + # to the model card / runner default. Mirrors ``draft_mode`` on the task. + draft_mode: Literal["model", "pipelined", "ngram", "none"] | None = None class BenchChatCompletionRequest(ChatCompletionRequest): diff --git a/src/exo/download/coordinator.py b/src/exo/download/coordinator.py index 2028912a78..e915a8f1fd 100644 --- a/src/exo/download/coordinator.py +++ b/src/exo/download/coordinator.py @@ -279,53 +279,60 @@ async def _start_download(self, shard: ShardMetadata) -> None: await self._maybe_chain_drafter_download(shard) async def _maybe_chain_drafter_download(self, target_shard: ShardMetadata) -> None: - """Enqueue a download for the drafter declared on ``target_shard``'s - model card, if any. + """Enqueue downloads for every drafter declared on ``target_shard``'s + model card. - Drafter downloads are silent: anything that fails (no card, env - opt-out, HF unreachable, drafter already tracked) is logged and - swallowed. The target download is the source of truth for the user's - intent; speculative decoding is best-effort. + We download *all* candidate drafters so the runner can switch between + them at startup time via ``EXO_DRAFTER_PREFERENCE`` without an + on-demand fetch. Drafters are small (typically <2GB) so the storage + overhead is fine. - The drafter is downloaded as a single ``PipelineShardMetadata`` for + Drafter downloads are silent best-effort: anything that fails (no + cards, env opt-out, HF unreachable, drafter already tracked) is + logged and swallowed. The target download is the source of truth for + the user's intent; speculative decoding is best-effort. + + Each drafter is downloaded as a single ``PipelineShardMetadata`` for the entire model. Speculative decoding is single-device today (see ``mlx_generate``), so we never need a sharded drafter. """ - drafter_id = target_shard.model_card.drafter_model_id - if drafter_id is None: + drafter_ids = list(target_shard.model_card.drafter_model_ids) + if not drafter_ids: return if _drafter_disabled_by_env(): logger.debug( - f"EXO_DISABLE_DRAFTER set; skipping drafter download " - f"{drafter_id} for {target_shard.model_card.model_id}" + f"EXO_DISABLE_DRAFTER set; skipping drafter downloads " + f"{drafter_ids} for {target_shard.model_card.model_id}" ) return - if drafter_id in self.download_status: - return # already tracked - try: - drafter_card = await ModelCard.load(drafter_id) - except Exception as exc: - logger.warning( - f"Could not load drafter card {drafter_id} for " - f"{target_shard.model_card.model_id}; skipping drafter " - f"download: {exc}" - ) - return + for drafter_id in drafter_ids: + if drafter_id in self.download_status: + continue # already tracked - drafter_shard = PipelineShardMetadata( - model_card=drafter_card, - device_rank=0, - world_size=1, - start_layer=0, - end_layer=drafter_card.n_layers, - n_layers=drafter_card.n_layers, - ) - logger.info( - f"Chaining drafter download {drafter_id} for " - f"{target_shard.model_card.model_id}" - ) - await self._start_download(drafter_shard) + try: + drafter_card = await ModelCard.load(drafter_id) + except Exception as exc: + logger.warning( + f"Could not load drafter card {drafter_id} for " + f"{target_shard.model_card.model_id}; skipping drafter " + f"download: {exc}" + ) + continue + + drafter_shard = PipelineShardMetadata( + model_card=drafter_card, + device_rank=0, + world_size=1, + start_layer=0, + end_layer=drafter_card.n_layers, + n_layers=drafter_card.n_layers, + ) + logger.info( + f"Chaining drafter download {drafter_id} for " + f"{target_shard.model_card.model_id}" + ) + await self._start_download(drafter_shard) def _start_download_task( self, shard: ShardMetadata, initial_progress: RepoDownloadProgress diff --git a/src/exo/download/tests/test_drafter_download.py b/src/exo/download/tests/test_drafter_download.py index 276690998f..58bf79422b 100644 --- a/src/exo/download/tests/test_drafter_download.py +++ b/src/exo/download/tests/test_drafter_download.py @@ -38,7 +38,7 @@ DRAFTER_ID = ModelId("test-org/test-drafter") -def _make_target_card(drafter: ModelId | None) -> ModelCard: +def _make_target_card(drafters: list[ModelId]) -> ModelCard: return ModelCard( model_id=TARGET_ID, storage_size=Memory.from_mb(500), @@ -46,7 +46,7 @@ def _make_target_card(drafter: ModelId | None) -> ModelCard: hidden_size=2048, supports_tensor=False, tasks=[ModelTask.TextGeneration], - drafter_model_id=drafter, + drafter_model_ids=drafters, ) @@ -200,7 +200,7 @@ async def _running_coordinator( async def test_target_with_drafter_chains_drafter_download() -> None: - target_shard = _make_shard(_make_target_card(DRAFTER_ID)) + target_shard = _make_shard(_make_target_card([DRAFTER_ID])) drafter_card = _make_drafter_card() async def fake_load(model_id: ModelId) -> ModelCard: @@ -227,7 +227,7 @@ async def fake_load(model_id: ModelId) -> ModelCard: async def test_target_without_drafter_does_not_chain() -> None: - target_shard = _make_shard(_make_target_card(None)) + target_shard = _make_shard(_make_target_card([])) async def fail_load(_: ModelId) -> ModelCard: raise AssertionError("ModelCard.load should not be called when no drafter") @@ -253,7 +253,7 @@ async def test_drafter_chain_skipped_when_disabled_by_env( monkeypatch: pytest.MonkeyPatch, ) -> None: monkeypatch.setenv("EXO_DISABLE_DRAFTER", "1") - target_shard = _make_shard(_make_target_card(DRAFTER_ID)) + target_shard = _make_shard(_make_target_card([DRAFTER_ID])) async def fail_load(_: ModelId) -> ModelCard: raise AssertionError( @@ -281,7 +281,7 @@ async def test_drafter_chain_swallows_card_load_error() -> None: """If the drafter's ModelCard cannot be loaded (e.g. HF unreachable, card not in repo), the target download must still complete and the coordinator must not crash.""" - target_shard = _make_shard(_make_target_card(DRAFTER_ID)) + target_shard = _make_shard(_make_target_card([DRAFTER_ID])) async def boom(_: ModelId) -> ModelCard: raise RuntimeError("simulated card load failure") diff --git a/src/exo/main.py b/src/exo/main.py index 9edee42096..e308afc879 100644 --- a/src/exo/main.py +++ b/src/exo/main.py @@ -1,9 +1,11 @@ import argparse +import ipaddress import multiprocessing as mp import os import resource import signal import subprocess +import sys from dataclasses import dataclass, field from datetime import datetime, timezone from typing import Self @@ -44,6 +46,7 @@ class Node: node_id: NodeId offline: bool _api_port: int + _libp2p_port: int _tg: TaskGroup = field(init=False, default_factory=TaskGroup) @classmethod @@ -144,6 +147,7 @@ async def create(cls, args: "Args") -> Self: node_id, args.offline, args.api_port, + args.libp2p_port, ) logger_set_context( node_id=node_id, role="master" if args.force_master else "node" @@ -169,6 +173,12 @@ async def run(self): tg.start_soon(self.master.run) if self.api: tg.start_soon(self.api.run) + if sys.platform == "darwin" and self._libp2p_port != 0: + tg.start_soon( + _darwin_mdns_broadcast_announcer, + self.node_id, + self._libp2p_port, + ) tg.start_soon(self._elect_loop) tg.start_soon(self._diagnostic_snapshot_loop) @@ -361,6 +371,77 @@ def _last_seen_ages(self, state: State) -> dict[str, float]: return ages +def _darwin_en0_ip_address() -> str | None: + try: + return subprocess.check_output( + ["ipconfig", "getifaddr", "en0"], + text=True, + stderr=subprocess.DEVNULL, + ).strip() + except (OSError, subprocess.CalledProcessError): + return None + + +def _darwin_en0_broadcast_address(ip_address: str) -> str | None: + try: + subnet_mask = subprocess.check_output( + ["ipconfig", "getoption", "en0", "subnet_mask"], + text=True, + stderr=subprocess.DEVNULL, + ).strip() + interface = ipaddress.IPv4Interface(f"{ip_address}/{subnet_mask}") + return str(interface.network.broadcast_address) + except (OSError, ValueError, subprocess.CalledProcessError): + return None + + +async def _darwin_mdns_broadcast_announcer( + node_id: NodeId, libp2p_port: int +) -> None: + ip_address = _darwin_en0_ip_address() + if not ip_address: + logger.debug("Darwin mDNS broadcast announcer disabled: no en0 IPv4 address") + return + + broadcast_address = _darwin_en0_broadcast_address(ip_address) + logger.debug( + f"Darwin mDNS announcer advertising {node_id} at {ip_address}:{libp2p_port}" + ) + command = [ + sys.executable, + "-m", + "exo.routing.mdns_announcer", + "--node-id", + str(node_id), + "--ip-address", + ip_address, + "--libp2p-port", + str(libp2p_port), + ] + if broadcast_address is not None: + command.extend(["--broadcast-address", broadcast_address]) + process = subprocess.Popen( + command, + start_new_session=True, + stdout=subprocess.DEVNULL, + ) + try: + while process.poll() is None: + await anyio.sleep(60) + logger.debug( + f"Darwin mDNS announcer subprocess exited with {process.returncode}" + ) + finally: + if process.poll() is None: + process.terminate() + with anyio.move_on_after(2): + while process.poll() is None: + await anyio.sleep(0.1) + if process.poll() is None: + process.kill() + await anyio.sleep(0) + + def main(): args = Args.parse() soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE) diff --git a/src/exo/master/main.py b/src/exo/master/main.py index 888c39e4c8..026d953d75 100644 --- a/src/exo/master/main.py +++ b/src/exo/master/main.py @@ -7,6 +7,7 @@ from exo.master.placement import ( add_instance_to_placements, + auto_place_prefill_siblings, cancel_unnecessary_downloads, delete_instance, get_transition_events, @@ -38,6 +39,7 @@ from exo.shared.types.events import ( CustomModelCardAdded, CustomModelCardDeleted, + DrafterPlacementDegraded, Event, GlobalForwarderEvent, IndexedEvent, @@ -55,7 +57,7 @@ TracesCollected, TracesMerged, ) -from exo.shared.types.instance_link import InstanceLink +from exo.shared.types.instance_link import InstanceLink, InstanceLinkId from exo.shared.types.state import State from exo.shared.types.tasks import ( ImageEdits as ImageEditsTask, @@ -385,6 +387,9 @@ async def _command_processor(self) -> None: ) generated_events.extend(transition_events) case PlaceInstance(): + drafter_degradation_events: list[ + DrafterPlacementDegraded + ] = [] placement = place_instance( command, self.state.topology, @@ -392,11 +397,60 @@ async def _command_processor(self) -> None: self.state.node_memory, self.state.node_network, download_status=self.state.downloads, + on_drafter_placement_degraded=drafter_degradation_events.append, ) + + # Auto-place prefill-only siblings on operator- + # designated nodes, then link them to each newly- + # created decode instance. The link tells + # ``_prefill_endpoint_for`` to spread incoming + # requests' prefill traffic across the linked + # nodes, which is the only architecturally + # honest way to keep slot N's TTFT independent + # of slot 0's prefill: dispatch them to + # different GPUs in the cluster instead of + # serialising on the target's single forward. + if command.model_card.prefill_eligible_nodes: + new_decode_ids = [ + iid + for iid in placement + if iid not in self.state.instances + ] + for decode_id in new_decode_ids: + decode_inst = placement[decode_id] + ( + new_prefill_instances, + new_prefill_ids, + ) = auto_place_prefill_siblings( + decode_instance_id=decode_id, + decode_instance=decode_inst, + model_card=command.model_card, + topology=self.state.topology, + current_instances=placement, + node_memory=self.state.node_memory, + node_network=self.state.node_network, + download_status=self.state.downloads, + ) + placement = { + **placement, + **new_prefill_instances, + } + if new_prefill_ids: + generated_events.append( + InstanceLinkCreated( + link=InstanceLink( + link_id=InstanceLinkId(), + prefill_instances=new_prefill_ids, + decode_instances=[decode_id], + ) + ) + ) + transition_events = get_transition_events( self.state.instances, placement, self.state.tasks ) generated_events.extend(transition_events) + generated_events.extend(drafter_degradation_events) case CreateInstance(): placement = add_instance_to_placements( command, @@ -481,14 +535,34 @@ async def _command_processor(self) -> None: # These plan loops are the cracks showing in our event sourcing architecture - more things could be commands async def _plan(self) -> None: + node_inactivity_timeout = timedelta(seconds=5) + tick_interval_seconds = 1.0 + while True: # kill broken instances connected_node_ids = set(self.state.topology.list_nodes()) for instance_id, instance in self.state.instances.items(): - for node_id in instance.shard_assignments.node_to_runner: + # ``all_node_to_runner`` includes the drafter node for + # asymmetric placements, so a drafter-node disconnect + # tears the instance down on the same path as a target + # rank disconnect. Without this, the surviving target + # ranks would keep the instance alive but block on + # ``transport.forward()`` against a dead socket -- the + # drafter rank will not come back without a full + # placement rebuild, so deletion is the only consistent + # recovery path. ``shard_assignments.node_to_runner`` is + # a strict subset, so the symmetric (drafter-less) path + # behaves identically. + for node_id in instance.all_node_to_runner: if node_id not in connected_node_ids: + is_drafter_node = ( + instance.drafter_placement is not None + and node_id == instance.drafter_placement.drafter_node_id + ) + node_role = "drafter" if is_drafter_node else "shard" logger.warning( - "Deleting instance because a shard node is disconnected " + f"Deleting instance because a {node_role} " + f"node is disconnected " f"instance_id={instance_id} " f"model_id={instance.shard_assignments.model_id} " f"missing_node={node_id} " @@ -503,7 +577,7 @@ async def _plan(self) -> None: # time out dead nodes for node_id, time in self.state.last_seen.items(): now = datetime.now(tz=timezone.utc) - if now - time > timedelta(seconds=30): + if now - time > node_inactivity_timeout: impacted_instances = [ str(instance_id) for instance_id, instance in self.state.instances.items() @@ -520,7 +594,7 @@ async def _plan(self) -> None: ) await self.event_sender.send(NodeTimedOut(node_id=node_id)) - await anyio.sleep(10) + await anyio.sleep(tick_interval_seconds) async def _event_processor(self) -> None: with self.local_event_receiver as local_events: diff --git a/src/exo/master/placement.py b/src/exo/master/placement.py index f665208777..4885437fd3 100644 --- a/src/exo/master/placement.py +++ b/src/exo/master/placement.py @@ -1,4 +1,4 @@ -from collections.abc import Mapping +from collections.abc import Callable, Mapping from copy import deepcopy from os import environ from typing import Sequence @@ -8,6 +8,7 @@ from exo.master.placement_utils import ( Cycle, filter_cycles_by_memory, + find_ip_prioritised, get_mlx_jaccl_coordinators, get_mlx_jaccl_devices_matrix, get_mlx_ring_hosts_by_node, @@ -25,6 +26,8 @@ ) from exo.shared.types.common import NodeId from exo.shared.types.events import ( + DrafterPlacementDegradationReason, + DrafterPlacementDegraded, Event, InstanceCreated, InstanceDeleted, @@ -42,12 +45,14 @@ DownloadProgress, ) from exo.shared.types.worker.instances import ( + DrafterPlacement, Instance, InstanceId, InstanceMeta, MlxJacclInstance, MlxRingInstance, ) +from exo.shared.types.worker.runners import RunnerId from exo.shared.types.worker.shards import Sharding from exo.utils.ports import random_ephemeral_port @@ -131,6 +136,9 @@ def place_instance( allowed_nodes: set[NodeId] | None = None, allow_single_node_total_memory: bool = False, download_status: Mapping[NodeId, Sequence[DownloadProgress]] | None = None, + on_drafter_placement_degraded: ( + Callable[[DrafterPlacementDegraded], None] | None + ) = None, ) -> dict[InstanceId, Instance]: sharding = command.sharding instance_meta = command.instance_meta @@ -150,6 +158,23 @@ def place_instance( for cycle in candidate_cycles if set(cycle.node_ids).issubset(allowed_nodes) ] + + # Reserve drafter-eligible nodes for the drafter rank when possible, so + # the placement layer doesn't accidentally pull a drafter-eligible node + # into the target cycle and then degrade because no eligible host + # remains. If filtering them out leaves zero cycles, fall back to the + # unfiltered set -- the user gets target placement at the cost of the + # asymmetric drafter, and `_select_drafter_placement` emits a + # ``AllEligibleNodesInTargetCycle`` degradation downstream. + eligible_drafter_set = set(command.model_card.drafter_eligible_nodes) + if eligible_drafter_set and command.model_card.drafter_model_ids: + cycles_excluding_drafters = [ + cycle + for cycle in candidate_cycles + if not (set(cycle.node_ids) & eligible_drafter_set) + ] + if cycles_excluding_drafters: + candidate_cycles = cycles_excluding_drafters cycles_with_sufficient_memory = filter_cycles_by_memory( candidate_cycles, node_memory, @@ -293,10 +318,34 @@ def place_instance( topology=topology, ) - # Single-node: force Pipeline/Ring (Tensor and Jaccl require multi-node) + # Single-node target cycle requires Pipeline sharding (PP=1). The + # backend choice depends on whether an asymmetric drafter rank will + # extend the parent ``mx.distributed`` group beyond size 1: ring lacks + # ``Group.split`` / ``send/recv`` so an N+1=2 parent group cannot use + # it; jaccl supports both. We therefore peek at drafter eligibility + # before locking the backend, then re-run the full drafter selection + # below with the (possibly upgraded) ``instance_meta``. if len(selected_cycle) == 1: - instance_meta = InstanceMeta.MlxRing sharding = Sharding.Pipeline + will_attempt_asymmetric_drafter = ( + bool(command.model_card.drafter_eligible_nodes) + and bool(command.model_card.drafter_model_ids) + and any( + node_id in topology.list_nodes() + and node_id not in selected_cycle.node_ids + for node_id in command.model_card.drafter_eligible_nodes + ) + ) + if not will_attempt_asymmetric_drafter: + instance_meta = InstanceMeta.MlxRing + elif instance_meta == InstanceMeta.MlxRing: + # User asked for ring but the model declares an asymmetric + # drafter on a separate node. Auto-upgrade to jaccl since ring + # cannot support the parent group's split + send/recv path. + # If jaccl reachability fails downstream, drafter selection + # emits a degradation event and target falls back to ring + # symmetric (no drafter), restoring V1 ring behavior. + instance_meta = InstanceMeta.MlxJaccl placement_node_memory = ( _node_memory_with_total_capacity(selected_cycle, node_memory) @@ -307,9 +356,42 @@ def place_instance( command.model_card, selected_cycle, sharding, placement_node_memory ) - cycle_digraph: Topology = topology.get_subgraph_from_nodes(selected_cycle.node_ids) - instance_id = InstanceId() + drafter_placement = _select_drafter_placement( + command=command, + selected_cycle=selected_cycle, + instance_meta=instance_meta, + topology=topology, + node_memory=node_memory, + node_network=node_network, + instance_id=instance_id, + on_drafter_placement_degraded=on_drafter_placement_degraded, + ) + + # If the auto-upgrade to MlxJaccl above didn't yield a drafter (e.g. + # no RDMA path to the eligible node), revert to MlxRing for the + # symmetric single-rank target. The degradation event was already + # emitted by ``_select_drafter_placement``; the user's instance + # still completes, just without speculative decoding. + if ( + len(selected_cycle) == 1 + and instance_meta == InstanceMeta.MlxJaccl + and drafter_placement is None + ): + instance_meta = InstanceMeta.MlxRing + + # Asymmetric placement (``drafter_placement is not None``) keeps the + # drafter rank OUT of the parent ``mx.distributed`` group: the + # drafter talks to target rank 0 over a direct TCP socket + # (``DrafterPlacement.drafter_socket_host``/``port``). Subgraph + + # connectivity tables (``hosts_by_node`` / ``jaccl_devices``) + # therefore cover only target nodes -- this lets target ranks of + # any size run TP/PP collectives without requiring + # ``Group.split`` (jaccl/ring backends do not implement split on + # Apple Silicon). + nodes_for_group = list(selected_cycle.node_ids) + cycle_digraph: Topology = topology.get_subgraph_from_nodes(nodes_for_group) + target_instances = dict(deepcopy(current_instances)) match instance_meta: @@ -330,7 +412,7 @@ def get_device_rank(node_id: NodeId) -> int: coordinator_node_id = zero_node_ids[0] mlx_jaccl_devices = get_mlx_jaccl_devices_matrix( - [node_id for node_id in selected_cycle], + nodes_for_group, cycle_digraph, ) mlx_jaccl_coordinators = get_mlx_jaccl_coordinators( @@ -344,11 +426,12 @@ def get_device_rank(node_id: NodeId) -> int: shard_assignments=shard_assignments, jaccl_devices=mlx_jaccl_devices, jaccl_coordinators=mlx_jaccl_coordinators, + drafter_placement=drafter_placement, ) case InstanceMeta.MlxRing: ephemeral_port = random_ephemeral_port() hosts_by_node = get_mlx_ring_hosts_by_node( - selected_cycle=selected_cycle, + selected_cycle=Cycle(node_ids=nodes_for_group), cycle_digraph=cycle_digraph, ephemeral_port=ephemeral_port, node_network=node_network, @@ -358,11 +441,367 @@ def get_device_rank(node_id: NodeId) -> int: shard_assignments=shard_assignments, hosts_by_node=hosts_by_node, ephemeral_port=ephemeral_port, + drafter_placement=drafter_placement, ) + # Multi-node placement WITHOUT an asymmetric drafter rank still loses + # speculative decoding (mlx_lm doesn't run draft_model on TP/PP target + # ranks today). Degrade-loud so operators see it without crawling logs; + # the user's request still completes. + if ( + len(selected_cycle) > 1 + and command.model_card.drafter_model_ids + and drafter_placement is None + ): + logger.warning( + f"Model {command.model_card.model_id} declares drafters " + f"{list(command.model_card.drafter_model_ids)} but is being " + f"placed across {len(selected_cycle)} nodes WITHOUT an asymmetric " + "drafter rank. Speculative decoding is single-device only and " + "will be disabled for this instance. To get the drafter speedup, " + "either place a smaller quant on a single node OR list a separate " + "drafter-eligible node in the model card's `drafter_eligible_nodes`." + ) + return target_instances +def _select_drafter_placement( + *, + command: PlaceInstance, + selected_cycle: Cycle, + instance_meta: InstanceMeta, + topology: Topology, + node_memory: Mapping[NodeId, MemoryUsage], + node_network: Mapping[NodeId, NodeNetworkInfo], + instance_id: InstanceId, + on_drafter_placement_degraded: (Callable[[DrafterPlacementDegraded], None] | None), +) -> DrafterPlacement | None: + """Pick a drafter-eligible node for asymmetric drafter placement. + + A drafter rank is appended to the parent ``mx.distributed`` group when + *all* of the following hold: + + * The model card lists ``drafter_eligible_nodes``. + * The card lists ``drafter_model_ids`` (otherwise there's nothing to + run on the drafter rank). + * At least one eligible node is alive in topology, NOT already a + target rank, AND reachable from target rank 0 over the right + transport (RDMA for ``MlxJaccl``; socket for ``MlxRing``). + + The fallback is loud-but-graceful: when none of the eligible nodes + satisfies the constraints, the function emits a + :class:`DrafterPlacementDegraded` event via + ``on_drafter_placement_degraded`` and returns ``None``. The caller + proceeds with the legacy symmetric topology, the user's request still + completes, and the operator sees the degradation event surfaced in + the dashboard / API stats so they know to fix the cluster (bring an + eligible node online, free RAM, repair the network edge). + + The drafter is always assigned the **last rank** in the parent group + (``len(selected_cycle)``). Target ranks split off into a subgroup at + runtime via ``mx.distributed.Group.split``. + """ + eligible_nodes = list(command.model_card.drafter_eligible_nodes) + drafter_candidates = list(command.model_card.drafter_model_ids) + if not eligible_nodes or not drafter_candidates: + return None + + target_node_ids = list(selected_cycle.node_ids) + fallback = _drafter_fallback(target_node_ids) + + alive_in_topology = set(topology.list_nodes()) + alive_eligible = [n for n in eligible_nodes if n in alive_in_topology] + if not alive_eligible: + _emit_drafter_degraded( + on_drafter_placement_degraded, + command=command, + instance_id=instance_id, + target_node_ids=target_node_ids, + eligible_nodes=eligible_nodes, + reason=DrafterPlacementDegradationReason.NoEligibleNodeAvailable, + fallback=fallback, + detail=( + f"None of {eligible_nodes} are present in topology " + f"(known nodes: {sorted(alive_in_topology)})" + ), + ) + return None + + not_in_target = [n for n in alive_eligible if n not in target_node_ids] + if not not_in_target: + _emit_drafter_degraded( + on_drafter_placement_degraded, + command=command, + instance_id=instance_id, + target_node_ids=target_node_ids, + eligible_nodes=eligible_nodes, + reason=DrafterPlacementDegradationReason.AllEligibleNodesInTargetCycle, + fallback=fallback, + detail=( + f"All eligible nodes {alive_eligible} are already target " + f"ranks ({target_node_ids}); no spare host available" + ), + ) + return None + + requires_rdma = instance_meta == InstanceMeta.MlxJaccl + reachable: list[NodeId] = [] + for candidate in not_in_target: + if _drafter_node_is_reachable( + target_node_ids=target_node_ids, + drafter_node=candidate, + topology=topology, + requires_rdma=requires_rdma, + ): + reachable.append(candidate) + + if not reachable: + _emit_drafter_degraded( + on_drafter_placement_degraded, + command=command, + instance_id=instance_id, + target_node_ids=target_node_ids, + eligible_nodes=eligible_nodes, + reason=DrafterPlacementDegradationReason.NoReachablePathFromTargetRankZero, + fallback=fallback, + detail=( + f"No {'RDMA' if requires_rdma else 'socket'} path from target " + f"ranks {target_node_ids} to any of {not_in_target}" + ), + ) + return None + + drafter_node_id = reachable[0] + if not _node_has_drafter_memory( + drafter_node=drafter_node_id, + node_memory=node_memory, + target_card=command.model_card, + ): + _emit_drafter_degraded( + on_drafter_placement_degraded, + command=command, + instance_id=instance_id, + target_node_ids=target_node_ids, + eligible_nodes=eligible_nodes, + reason=DrafterPlacementDegradationReason.InsufficientDrafterMemory, + fallback=fallback, + detail=( + f"Drafter node {drafter_node_id} has " + f"{node_memory[drafter_node_id].ram_available.in_gb:.1f}GB " + f"available; conservative drafter estimate is " + f"{_DRAFTER_MEMORY_FLOOR.in_gb:.1f}GB" + ), + ) + return None + + drafter_model_id = drafter_candidates[0] + drafter_runner_id = RunnerId() + drafter_rank = len(selected_cycle) + + # Resolve target rank 0's IP from the drafter's perspective. Target + # rank 0 == selected_cycle.node_ids[0] by construction (every shard + # assigner enumerates the cycle in order; ``device_rank`` is the + # enumeration index). We pick the same priority order ``ring`` uses + # (Thunderbolt-bridge first, then ethernet, then wifi) because the + # drafter wire is small fixed-size frames where TCP latency over a + # direct cable beats RDMA setup latency every time. + # + # ``find_ip_prioritised`` returns the SINK end of connections going + # ``node_id -> other_node_id``: i.e. the address ``other_node_id`` + # advertises for that direction. We want the address target rank 0 + # advertises *to the drafter*, so ``other_node_id`` is the target + # and ``node_id`` is the drafter. + target_rank_zero = selected_cycle.node_ids[0] + drafter_socket_host = find_ip_prioritised( + drafter_node_id, + target_rank_zero, + topology, + node_network, + ring=True, + ) + if drafter_socket_host is None: + # ``_drafter_node_is_reachable`` already checked the directional + # edge; if topology says reachable but no IP is exposed, the + # node is misconfigured. Bail out loudly via degradation rather + # than picking ``0.0.0.0`` (which the drafter cannot dial). + _emit_drafter_degraded( + on_drafter_placement_degraded, + command=command, + instance_id=instance_id, + target_node_ids=target_node_ids, + eligible_nodes=eligible_nodes, + reason=DrafterPlacementDegradationReason.NoReachablePathFromTargetRankZero, + fallback=fallback, + detail=( + f"Target rank 0 ({target_rank_zero}) has no IP address " + f"reachable from drafter node {drafter_node_id} in topology" + ), + ) + return None + drafter_socket_port = random_ephemeral_port() + # Inter-target-peer wire: target rank 0 binds a separate ephemeral + # port for the spec-decode int-broadcast fanout (drafts in / sampled + # tokens out). Decoupled from the drafter port because both bind on + # rank 0 and a single port can only accept one connection class + # cleanly. Each non-zero target rank dials the IP rank 0 advertises + # *to that peer* -- different peers may reach rank 0 over different + # interfaces (e.g. a Thunderbolt /30 mesh exposes a unique IP per + # node pair). The map below resolves those per-peer IPs once at + # placement time so workers don't re-do the topology dance at + # bootstrap. + target_peer_socket_port = random_ephemeral_port() + # Keys stored as strings so the dict round-trips through the + # event-router JSON wire (JSON has no int dict keys, and pydantic + # strict mode rejects str keys for a ``dict[int, _]`` field at + # re-validation). Consumers stringify the rank before lookup. + target_peer_hosts_by_rank: dict[str, str] = {} + for peer_rank, peer_node_id in enumerate(selected_cycle.node_ids): + if peer_rank == 0: + continue + peer_view_of_rank_zero = find_ip_prioritised( + peer_node_id, + target_rank_zero, + topology, + node_network, + ring=True, + ) + if peer_view_of_rank_zero is None: + # Same fail-loud rationale as the drafter IP: target rank 0 + # is unreachable from a peer in topology, so the spec-decode + # int-broadcast wire cannot be brought up. Falling back to + # the legacy ``mx.distributed`` broadcast would re-introduce + # the JACCL int/float wire-conflation bug. Degrade to no + # drafter so the user still gets generation, just at + # standard (non-speculative) speed. + _emit_drafter_degraded( + on_drafter_placement_degraded, + command=command, + instance_id=instance_id, + target_node_ids=target_node_ids, + eligible_nodes=eligible_nodes, + reason=DrafterPlacementDegradationReason.NoReachablePathFromTargetRankZero, + fallback=fallback, + detail=( + f"Target rank 0 ({target_rank_zero}) has no IP address " + f"reachable from peer target rank {peer_rank} " + f"(node {peer_node_id}) in topology" + ), + ) + return None + target_peer_hosts_by_rank[str(peer_rank)] = peer_view_of_rank_zero + return DrafterPlacement( + drafter_node_id=drafter_node_id, + drafter_runner_id=drafter_runner_id, + drafter_model_id=drafter_model_id, + drafter_rank=drafter_rank, + drafter_socket_host=drafter_socket_host, + drafter_socket_port=drafter_socket_port, + target_peer_socket_port=target_peer_socket_port, + target_peer_hosts_by_rank=target_peer_hosts_by_rank, + ) + + +def _drafter_fallback(target_node_ids: list[NodeId]) -> str: + """``single_device_drafter`` when target is single-node, else ``no_drafter``. + + Multi-node target with no asymmetric drafter rank can't host the + drafter at all (mlx_lm spec decode is single-device); single-node + target falls back to in-process drafter as before. + """ + return "single_device_drafter" if len(target_node_ids) == 1 else "no_drafter" + + +def _emit_drafter_degraded( + callback: Callable[[DrafterPlacementDegraded], None] | None, + *, + command: PlaceInstance, + instance_id: InstanceId, + target_node_ids: list[NodeId], + eligible_nodes: list[NodeId], + reason: DrafterPlacementDegradationReason, + fallback: str, + detail: str, +) -> None: + logger.error( + f"Drafter placement degraded for {command.model_card.model_id} " + f"({reason.value}): {detail}; falling back to {fallback}" + ) + if callback is None: + return + assert fallback in ("single_device_drafter", "no_drafter") + callback( + DrafterPlacementDegraded( + model_id=command.model_card.model_id, + instance_id=instance_id, + target_node_ids=target_node_ids, + eligible_nodes=eligible_nodes, + reason=reason, + fallback=fallback, + detail=detail, + ) + ) + + +def _drafter_node_is_reachable( + *, + target_node_ids: list[NodeId], + drafter_node: NodeId, + topology: Topology, + requires_rdma: bool, # retained for ABI parity; unused under v3+ wire +) -> bool: + """Drafter must be socket-reachable from target rank 0 only. + + Under the v3+ asymmetric wire (this module's :class:`DrafterPlacement` + + ``RemoteTransport``) the drafter is NOT a member of the target + ranks' ``mx.distributed.Group``. The only edge the wire actually + needs is a TCP socket between target rank 0 and the drafter node. + Every other "all target ranks must reach drafter" requirement from + the v2 wire (where drafter was an mx.distributed peer) is gone. + + ``requires_rdma`` is accepted but ignored: the drafter wire is plain + TCP regardless of whether the target ranks talk to each other over + JACCL/RDMA or ring/TCP. The argument is retained so callers don't + need to rev simultaneously with this module. + """ + del requires_rdma # documented above; the v3 wire is socket-only + if not target_node_ids: + return False + target_rank_zero = target_node_ids[0] + socket_check: Callable[[object], bool] = lambda c: isinstance( # noqa: E731 + c, SocketConnection + ) + forward = list(topology.get_all_connections_between(target_rank_zero, drafter_node)) + backward = list( + topology.get_all_connections_between(drafter_node, target_rank_zero) + ) + return any(socket_check(c) for c in forward) and any( + socket_check(c) for c in backward + ) + + +# Conservative floor for the drafter's wired-memory bump. The drafter +# weights are usually 1-5GB (e.g. gemma-4-e2b @ 8-bit ~ 2GB), but during +# load the runner may briefly hold the safetensors mmap + decompression +# buffers; bake in headroom so placement doesn't pick a node that will +# OOM at warmup. If the drafter on disk is larger than this floor the +# runner's own ``set_wired_limit_for_model`` will catch it; this is just +# a placement-time sanity check. +_DRAFTER_MEMORY_FLOOR = Memory.from_gb(6.0) + + +def _node_has_drafter_memory( + *, + drafter_node: NodeId, + node_memory: Mapping[NodeId, MemoryUsage], + target_card: ModelCard, +) -> bool: + del target_card # reserved for future per-drafter sizing + if drafter_node not in node_memory: + return False + return node_memory[drafter_node].ram_available >= _DRAFTER_MEMORY_FLOOR + + def _prefer_socket_reachable_rank_zero(cycle: Cycle, topology: Topology) -> Cycle: """Rotate multi-node placements so rank 0 is easiest for peers to reach. @@ -442,6 +881,118 @@ def _asymmetric_tensor_rank_zero_is_socket_reachable( return True +def auto_place_prefill_siblings( + *, + decode_instance_id: InstanceId, + decode_instance: Instance, + model_card: ModelCard, + topology: Topology, + current_instances: Mapping[InstanceId, Instance], + node_memory: Mapping[NodeId, MemoryUsage], + node_network: Mapping[NodeId, NodeNetworkInfo], + download_status: Mapping[NodeId, Sequence[DownloadProgress]] | None = None, +) -> tuple[dict[InstanceId, Instance], list[InstanceId]]: + """Place single-rank prefill-only siblings on each viable eligible node. + + Returns a tuple of ``(new_instances, new_prefill_instance_ids)`` where + ``new_instances`` maps newly-created prefill ``InstanceId`` to its + placement and ``new_prefill_instance_ids`` preserves placement order. + Both are empty when ``model_card.prefill_eligible_nodes`` is empty, + when no candidate is alive in topology, or when every candidate fails + feasibility (insufficient RAM, no socket reachability, etc.) -- the + decode instance still comes up; the caller emits no + ``InstanceLinkCreated`` and the user's traffic prefills locally on + the target rank. + + The recursive ``place_instance`` call is invoked with a sanitised + model card (drafter and prefill eligibility cleared) and + ``allowed_nodes={candidate}`` to force a single-node Pipeline / PP=1 + placement. We do NOT inherit drafter placement onto prefill siblings: + the prefill role is a pure remote-prefill server (TCP-only via + :class:`~exo.worker.disaggregated.server.PrefillServer`), so it + needs the target weights but not the drafter pair. + """ + eligible = list(dict.fromkeys(model_card.prefill_eligible_nodes)) + if not eligible: + return {}, [] + + decode_nodes: set[NodeId] = set( + decode_instance.shard_assignments.node_to_runner.keys() + ) + if decode_instance.drafter_placement is not None: + decode_nodes.add(decode_instance.drafter_placement.drafter_node_id) + + alive = set(topology.list_nodes()) + + candidates = [ + node_id + for node_id in eligible + if node_id in alive and node_id not in decode_nodes + ] + if not candidates: + logger.warning( + f"Auto-prefill placement skipped for decode {decode_instance_id}: " + f"no eligible node alive AND outside the decode cycle. " + f"eligible={eligible} decode_nodes={sorted(decode_nodes)} " + f"alive={sorted(alive)}" + ) + return {}, [] + + # Sanitise the recursive card so the prefill-only sibling does not + # itself recursively spawn drafters or further prefill siblings. + prefill_card = model_card.model_copy( + update={ + "drafter_eligible_nodes": [], + "drafter_model_ids": [], + "prefill_eligible_nodes": [], + } + ) + + placed: dict[InstanceId, Instance] = {} + placed_ids: list[InstanceId] = [] + accumulating_instances: dict[InstanceId, Instance] = dict(current_instances) + + for candidate_node in candidates: + sub_command = PlaceInstance( + model_card=prefill_card, + sharding=Sharding.Pipeline, + instance_meta=InstanceMeta.MlxRing, + min_nodes=1, + ) + try: + sub_placement = place_instance( + sub_command, + topology, + accumulating_instances, + node_memory, + node_network, + allowed_nodes={candidate_node}, + download_status=download_status, + ) + except ValueError as err: + logger.warning( + f"Auto-prefill skip {candidate_node} for decode " + f"{decode_instance_id}: {err}" + ) + continue + + new_ids_this_round = [ + iid for iid in sub_placement if iid not in accumulating_instances + ] + if not new_ids_this_round: + logger.warning( + f"Auto-prefill on {candidate_node} returned no new " + f"instance for decode {decode_instance_id}; skipping" + ) + continue + for iid in new_ids_this_round: + placed[iid] = sub_placement[iid] + placed_ids.append(iid) + accumulating_instances[iid] = sub_placement[iid] + + return placed, placed_ids + + def delete_instance( command: DeleteInstance, current_instances: Mapping[InstanceId, Instance], diff --git a/src/exo/master/tests/test_placement_auto_prefill.py b/src/exo/master/tests/test_placement_auto_prefill.py new file mode 100644 index 0000000000..2e49c0daef --- /dev/null +++ b/src/exo/master/tests/test_placement_auto_prefill.py @@ -0,0 +1,490 @@ +"""Tests for auto-prefill placement (multi-GPU prefill spread). + +When ``ModelCard.prefill_eligible_nodes`` is non-empty, placement +auto-creates a single-rank prefill-only sibling instance on each viable +node and the master emits an ``InstanceLinkCreated`` linking them to +the decode instance. The link tells ``_prefill_endpoint_for`` to +spread incoming requests' prefill traffic across the linked nodes, +so slot N's TTFT is decoupled from slot 0's prefill (different GPUs, +not different time slots on the same one). + +Coverage: +- Sibling placed on a viable eligible node distinct from the decode + cycle (and distinct from the asymmetric drafter rank when present). +- Drafter and prefill overlap is excluded automatically (chosen drafter + node is removed from prefill candidates). +- Eligible node not alive in topology -> skipped, no exception. +- Eligible node has insufficient RAM -> skipped, decode still placed, + no link emitted. +- Empty ``prefill_eligible_nodes`` -> legacy single-instance behaviour + (backwards compat). +- Recursive sanitisation: the sibling card has no drafter / no further + prefill spawn (so we don't recurse forever). +""" + +from collections.abc import Iterator + +import pytest +from loguru import logger as loguru_logger + +from exo.master.placement import auto_place_prefill_siblings, place_instance +from exo.master.tests.conftest import ( + create_node_memory, + create_node_network, + create_rdma_connection, + create_socket_connection, +) +from exo.shared.models.model_cards import ModelCard, ModelId, ModelTask +from exo.shared.topology import Topology +from exo.shared.types.commands import PlaceInstance +from exo.shared.types.common import CommandId, NodeId +from exo.shared.types.memory import Memory +from exo.shared.types.topology import Connection +from exo.shared.types.worker.instances import InstanceMeta +from exo.shared.types.worker.shards import Sharding + + +@pytest.fixture +def loguru_capture() -> Iterator[list[str]]: + captured: list[str] = [] + sink_id = loguru_logger.add( + lambda message: captured.append(str(message)), level="WARNING" + ) + try: + yield captured + finally: + loguru_logger.remove(sink_id) + + +def _prefill_aware_card( + *, + storage_bytes: int, + prefill_eligible: list[NodeId], + drafter_eligible: list[NodeId] | None = None, + drafter_models: list[ModelId] | None = None, +) -> ModelCard: + return ModelCard( + model_id=ModelId("mlx-community/gemma-4-26b-a4b-it-4bit"), + storage_size=Memory.from_bytes(storage_bytes), + n_layers=60, + hidden_size=5376, + num_key_value_heads=16, + supports_tensor=True, + tasks=[ModelTask.TextGeneration], + family="gemma", + base_model="Gemma 4 26B", + drafter_model_ids=drafter_models or [], + drafter_eligible_nodes=drafter_eligible or [], + prefill_eligible_nodes=prefill_eligible, + ) + + +def _bidi_socket(topology: Topology, a: NodeId, b: NodeId, ip: int) -> None: + topology.add_connection( + Connection(source=a, sink=b, edge=create_socket_connection(ip)) + ) + topology.add_connection( + Connection(source=b, sink=a, edge=create_socket_connection(ip + 1)) + ) + + +def _bidi_rdma(topology: Topology, a: NodeId, b: NodeId, iface: int) -> None: + topology.add_connection( + Connection(source=a, sink=b, edge=create_rdma_connection(iface)) + ) + topology.add_connection( + Connection(source=b, sink=a, edge=create_rdma_connection(iface)) + ) + + +def test_prefill_sibling_placed_on_eligible_idle_node() -> None: + """Decode on smbp + prefill sibling on bmbp -> 2 instances, 1 link. + + The decode instance is single-rank (PP=1) on smbp; bmbp is + declared as a prefill-eligible idle node. Auto-prefill places a + single-rank prefill-only sibling on bmbp and the master will + emit ``InstanceLinkCreated`` linking them. + """ + smbp = NodeId("smbp") + bmbp = NodeId("bmbp") + topology = Topology() + topology.add_node(smbp) + topology.add_node(bmbp) + _bidi_socket(topology, smbp, bmbp, ip=10) + _bidi_rdma(topology, smbp, bmbp, iface=1) + + node_memory = { + smbp: create_node_memory(Memory.from_gb(120).in_bytes), + bmbp: create_node_memory(Memory.from_gb(40).in_bytes), + } + node_network = { + smbp: create_node_network(), + bmbp: create_node_network(), + } + card = _prefill_aware_card( + storage_bytes=Memory.from_gb(13).in_bytes, + prefill_eligible=[bmbp], + ) + + decode_placement = place_instance( + PlaceInstance( + command_id=CommandId(), + model_card=card, + sharding=Sharding.Pipeline, + instance_meta=InstanceMeta.MlxRing, + min_nodes=1, + ), + topology, + {}, + node_memory, + node_network, + required_nodes={smbp}, + ) + assert len(decode_placement) == 1 + decode_id, decode_inst = next(iter(decode_placement.items())) + + siblings, sibling_ids = auto_place_prefill_siblings( + decode_instance_id=decode_id, + decode_instance=decode_inst, + model_card=card, + topology=topology, + current_instances=decode_placement, + node_memory=node_memory, + node_network=node_network, + ) + assert len(siblings) == 1 + assert len(sibling_ids) == 1 + sibling = siblings[sibling_ids[0]] + assert bmbp in sibling.shard_assignments.node_to_runner + assert smbp not in sibling.shard_assignments.node_to_runner + + +def test_prefill_excludes_chosen_drafter_node() -> None: + """Asymmetric decode (smbp+smbpt) + drafter on bmbp -> studio left for prefill. + + With drafter_eligible=[bmbp] and prefill_eligible=[bmbp,studio], + bmbp gets used as the drafter rank and studio is the only viable + prefill candidate. + """ + smbp = NodeId("smbp") + smbpt = NodeId("smbpt") + bmbp = NodeId("bmbp") + studio = NodeId("studio") + topology = Topology() + for n in (smbp, smbpt, bmbp, studio): + topology.add_node(n) + for a, b, ip in [ + (smbp, smbpt, 10), + (smbp, bmbp, 12), + (smbp, studio, 14), + (smbpt, bmbp, 16), + (smbpt, studio, 18), + (bmbp, studio, 20), + ]: + _bidi_socket(topology, a, b, ip=ip) + for a, b, iface in [ + (smbp, smbpt, 1), + (smbp, bmbp, 2), + (smbpt, bmbp, 3), + ]: + _bidi_rdma(topology, a, b, iface=iface) + + node_memory = { + smbp: create_node_memory(Memory.from_gb(120).in_bytes), + smbpt: create_node_memory(Memory.from_gb(120).in_bytes), + bmbp: create_node_memory(Memory.from_gb(40).in_bytes), + studio: create_node_memory(Memory.from_gb(120).in_bytes), + } + node_network = {n: create_node_network() for n in (smbp, smbpt, bmbp, studio)} + + card = _prefill_aware_card( + storage_bytes=Memory.from_gb(13).in_bytes, + prefill_eligible=[bmbp, studio], + drafter_eligible=[bmbp], + drafter_models=[ModelId("mlx-community/gemma-4-e2b-it-4bit")], + ) + + decode_placement = place_instance( + PlaceInstance( + command_id=CommandId(), + model_card=card, + sharding=Sharding.Pipeline, + instance_meta=InstanceMeta.MlxRing, + min_nodes=1, + ), + topology, + {}, + node_memory, + node_network, + required_nodes={smbp}, + ) + decode_id, decode_inst = next(iter(decode_placement.items())) + assert decode_inst.drafter_placement is not None + assert decode_inst.drafter_placement.drafter_node_id == bmbp + + siblings, sibling_ids = auto_place_prefill_siblings( + decode_instance_id=decode_id, + decode_instance=decode_inst, + model_card=card, + topology=topology, + current_instances=decode_placement, + node_memory=node_memory, + node_network=node_network, + ) + assert len(siblings) == 1 + sibling = siblings[sibling_ids[0]] + sibling_nodes = set(sibling.shard_assignments.node_to_runner.keys()) + assert sibling_nodes == {studio}, ( + f"prefill sibling should land on studio (not the drafter node bmbp); " + f"got nodes={sibling_nodes}" + ) + + +def test_prefill_skipped_when_eligible_node_offline( + loguru_capture: list[str], +) -> None: + """Eligible node not in topology -> no sibling, no exception.""" + smbp = NodeId("smbp") + ghost = NodeId("ghost-not-in-topology") + topology = Topology() + topology.add_node(smbp) + node_memory = {smbp: create_node_memory(Memory.from_gb(120).in_bytes)} + node_network = {smbp: create_node_network()} + + card = _prefill_aware_card( + storage_bytes=Memory.from_gb(13).in_bytes, + prefill_eligible=[ghost], + ) + decode_placement = place_instance( + PlaceInstance( + command_id=CommandId(), + model_card=card, + sharding=Sharding.Pipeline, + instance_meta=InstanceMeta.MlxRing, + min_nodes=1, + ), + topology, + {}, + node_memory, + node_network, + required_nodes={smbp}, + ) + decode_id, decode_inst = next(iter(decode_placement.items())) + + siblings, sibling_ids = auto_place_prefill_siblings( + decode_instance_id=decode_id, + decode_instance=decode_inst, + model_card=card, + topology=topology, + current_instances=decode_placement, + node_memory=node_memory, + node_network=node_network, + ) + assert siblings == {} + assert sibling_ids == [] + assert any("Auto-prefill placement skipped" in m for m in loguru_capture), ( + loguru_capture + ) + + +def test_prefill_skipped_when_eligible_node_oom(loguru_capture: list[str]) -> None: + """Eligible node lacks RAM -> placement raises and is logged-and-skipped.""" + smbp = NodeId("smbp") + tiny = NodeId("tiny") + topology = Topology() + topology.add_node(smbp) + topology.add_node(tiny) + _bidi_socket(topology, smbp, tiny, ip=10) + node_memory = { + smbp: create_node_memory(Memory.from_gb(120).in_bytes), + tiny: create_node_memory(Memory.from_gb(2).in_bytes), + } + node_network = {smbp: create_node_network(), tiny: create_node_network()} + + card = _prefill_aware_card( + storage_bytes=Memory.from_gb(13).in_bytes, + prefill_eligible=[tiny], + ) + decode_placement = place_instance( + PlaceInstance( + command_id=CommandId(), + model_card=card, + sharding=Sharding.Pipeline, + instance_meta=InstanceMeta.MlxRing, + min_nodes=1, + ), + topology, + {}, + node_memory, + node_network, + required_nodes={smbp}, + ) + decode_id, decode_inst = next(iter(decode_placement.items())) + + siblings, sibling_ids = auto_place_prefill_siblings( + decode_instance_id=decode_id, + decode_instance=decode_inst, + model_card=card, + topology=topology, + current_instances=decode_placement, + node_memory=node_memory, + node_network=node_network, + ) + assert siblings == {} + assert sibling_ids == [] + assert any("Auto-prefill skip" in m for m in loguru_capture), loguru_capture + + +def test_empty_prefill_eligible_preserves_legacy_path() -> None: + """No ``prefill_eligible_nodes`` -> auto-prefill is a no-op.""" + smbp = NodeId("smbp") + bmbp = NodeId("bmbp") + topology = Topology() + topology.add_node(smbp) + topology.add_node(bmbp) + _bidi_socket(topology, smbp, bmbp, ip=10) + node_memory = { + smbp: create_node_memory(Memory.from_gb(120).in_bytes), + bmbp: create_node_memory(Memory.from_gb(40).in_bytes), + } + node_network = {smbp: create_node_network(), bmbp: create_node_network()} + + card = _prefill_aware_card( + storage_bytes=Memory.from_gb(13).in_bytes, + prefill_eligible=[], + ) + decode_placement = place_instance( + PlaceInstance( + command_id=CommandId(), + model_card=card, + sharding=Sharding.Pipeline, + instance_meta=InstanceMeta.MlxRing, + min_nodes=1, + ), + topology, + {}, + node_memory, + node_network, + required_nodes={smbp}, + ) + decode_id, decode_inst = next(iter(decode_placement.items())) + + siblings, sibling_ids = auto_place_prefill_siblings( + decode_instance_id=decode_id, + decode_instance=decode_inst, + model_card=card, + topology=topology, + current_instances=decode_placement, + node_memory=node_memory, + node_network=node_network, + ) + assert siblings == {} + assert sibling_ids == [] + + +def test_prefill_sibling_does_not_carry_drafter() -> None: + """The recursive sub-placement uses a drafter-cleared card. + + Even though the model card declares a drafter, the prefill sibling + has ``drafter_placement is None`` (it's a TCP prefill server, not + a decode instance, so it has no use for a drafter). + """ + smbp = NodeId("smbp") + bmbp = NodeId("bmbp") + studio = NodeId("studio") + topology = Topology() + for n in (smbp, bmbp, studio): + topology.add_node(n) + for a, b, ip in [(smbp, bmbp, 10), (smbp, studio, 12), (bmbp, studio, 14)]: + _bidi_socket(topology, a, b, ip=ip) + + node_memory = { + smbp: create_node_memory(Memory.from_gb(120).in_bytes), + bmbp: create_node_memory(Memory.from_gb(40).in_bytes), + studio: create_node_memory(Memory.from_gb(120).in_bytes), + } + node_network = {n: create_node_network() for n in (smbp, bmbp, studio)} + + card = _prefill_aware_card( + storage_bytes=Memory.from_gb(13).in_bytes, + prefill_eligible=[studio], + drafter_eligible=[bmbp], + drafter_models=[ModelId("mlx-community/gemma-4-e2b-it-4bit")], + ) + decode_placement = place_instance( + PlaceInstance( + command_id=CommandId(), + model_card=card, + sharding=Sharding.Pipeline, + instance_meta=InstanceMeta.MlxRing, + min_nodes=1, + ), + topology, + {}, + node_memory, + node_network, + required_nodes={smbp}, + ) + decode_id, decode_inst = next(iter(decode_placement.items())) + + siblings, sibling_ids = auto_place_prefill_siblings( + decode_instance_id=decode_id, + decode_instance=decode_inst, + model_card=card, + topology=topology, + current_instances=decode_placement, + node_memory=node_memory, + node_network=node_network, + ) + assert len(siblings) == 1 + sibling = siblings[sibling_ids[0]] + assert sibling.drafter_placement is None, ( + "prefill sibling must not own a drafter -- only the decode does" + ) + + +def test_eligible_duplicates_are_deduped() -> None: + """``prefill_eligible_nodes=[bmbp, bmbp]`` -> one sibling, not two.""" + smbp = NodeId("smbp") + bmbp = NodeId("bmbp") + topology = Topology() + topology.add_node(smbp) + topology.add_node(bmbp) + _bidi_socket(topology, smbp, bmbp, ip=10) + node_memory = { + smbp: create_node_memory(Memory.from_gb(120).in_bytes), + bmbp: create_node_memory(Memory.from_gb(40).in_bytes), + } + node_network = {smbp: create_node_network(), bmbp: create_node_network()} + + card = _prefill_aware_card( + storage_bytes=Memory.from_gb(13).in_bytes, + prefill_eligible=[bmbp, bmbp], + ) + decode_placement = place_instance( + PlaceInstance( + command_id=CommandId(), + model_card=card, + sharding=Sharding.Pipeline, + instance_meta=InstanceMeta.MlxRing, + min_nodes=1, + ), + topology, + {}, + node_memory, + node_network, + required_nodes={smbp}, + ) + decode_id, decode_inst = next(iter(decode_placement.items())) + + siblings, sibling_ids = auto_place_prefill_siblings( + decode_instance_id=decode_id, + decode_instance=decode_inst, + model_card=card, + topology=topology, + current_instances=decode_placement, + node_memory=node_memory, + node_network=node_network, + ) + assert len(siblings) == 1 + assert len(sibling_ids) == 1 diff --git a/src/exo/master/tests/test_placement_drafter_asymmetric.py b/src/exo/master/tests/test_placement_drafter_asymmetric.py new file mode 100644 index 0000000000..bd2fd23765 --- /dev/null +++ b/src/exo/master/tests/test_placement_drafter_asymmetric.py @@ -0,0 +1,711 @@ +"""Tests for asymmetric drafter placement (Layer B). + +When a model card declares ``drafter_eligible_nodes`` AND the cluster +has at least one such node alive, reachable from every target rank, and +with sufficient memory, placement appends a *drafter rank* to the +parent ``mx.distributed`` group on a separate node. Target ranks split +off into a target subgroup at runtime; the parent group is reserved for +``RemoteTransport`` send/recv between target rank 0 and the drafter +rank. + +Coverage: +- Asymmetric placement is constructed when an eligible node is reachable + with both backends (``MlxRing`` over socket, ``MlxJaccl`` over RDMA). +- Placement degrades loudly when no eligible node is alive, when every + eligible node is already a target rank, or when the only eligible + candidate has no reachable transport. The user's request still + completes (placement returns *something*), and a + ``DrafterPlacementDegraded`` event is emitted with the reason. +- Empty ``drafter_eligible_nodes`` preserves legacy behaviour. +- The drafter rank is always the LAST rank in the parent group. +""" + +from collections.abc import Iterator + +import pytest +from loguru import logger as loguru_logger + +from exo.master.placement import place_instance +from exo.master.tests.conftest import ( + create_node_memory, + create_node_network, + create_rdma_connection, + create_socket_connection, +) +from exo.shared.models.model_cards import ModelCard, ModelId, ModelTask +from exo.shared.topology import Topology +from exo.shared.types.commands import PlaceInstance +from exo.shared.types.common import CommandId, NodeId +from exo.shared.types.events import ( + DrafterPlacementDegradationReason, + DrafterPlacementDegraded, +) +from exo.shared.types.memory import Memory +from exo.shared.types.topology import Connection +from exo.shared.types.worker.instances import ( + InstanceMeta, + MlxJacclInstance, + MlxRingInstance, +) +from exo.shared.types.worker.shards import Sharding + + +@pytest.fixture +def loguru_capture() -> Iterator[list[str]]: + captured: list[str] = [] + sink_id = loguru_logger.add( + lambda message: captured.append(str(message)), level="ERROR" + ) + try: + yield captured + finally: + loguru_logger.remove(sink_id) + + +def _drafter_aware_card( + *, + storage_bytes: int, + eligible_nodes: list[NodeId], + family: str = "gemma", + base_model: str = "Gemma 4 31B", + model_id: str = "mlx-community/gemma-4-31b-it-8bit", +) -> ModelCard: + return ModelCard( + model_id=ModelId(model_id), + storage_size=Memory.from_bytes(storage_bytes), + n_layers=60, + hidden_size=5376, + num_key_value_heads=16, + supports_tensor=True, + tasks=[ModelTask.TextGeneration], + family=family, + base_model=base_model, + drafter_model_ids=[ + ModelId("mlx-community/gemma-4-e2b-it-8bit"), + ModelId("mlx-community/gemma-4-e4b-it-8bit"), + ], + drafter_eligible_nodes=eligible_nodes, + ) + + +def _bidi_socket(topology: Topology, a: NodeId, b: NodeId, ip: int) -> None: + topology.add_connection( + Connection(source=a, sink=b, edge=create_socket_connection(ip)) + ) + topology.add_connection( + Connection(source=b, sink=a, edge=create_socket_connection(ip + 1)) + ) + + +def _bidi_rdma(topology: Topology, a: NodeId, b: NodeId, iface: int) -> None: + topology.add_connection( + Connection(source=a, sink=b, edge=create_rdma_connection(iface)) + ) + topology.add_connection( + Connection(source=b, sink=a, edge=create_rdma_connection(iface + 1)) + ) + + +def test_asymmetric_single_node_target_auto_upgrades_to_jaccl() -> None: + """Single-node target + RDMA-reachable drafter => asymmetric jaccl. + + Single-rank target requires Pipeline sharding, but the parent group + spans 2 ranks (1 target + 1 drafter) and ring backend lacks + ``Group.split`` / ``send/recv``. Placement therefore auto-upgrades + ``MlxRing`` -> ``MlxJaccl`` whenever asymmetric drafter placement + will succeed, so the parent group can use jaccl for the drafter + transport. + """ + target_node, drafter_node = NodeId(), NodeId() + topology = Topology() + topology.add_node(target_node) + topology.add_node(drafter_node) + _bidi_socket(topology, target_node, drafter_node, ip=2) + _bidi_rdma(topology, target_node, drafter_node, iface=4) + + card = _drafter_aware_card( + storage_bytes=20_000_000_000, eligible_nodes=[drafter_node] + ) + command = PlaceInstance( + sharding=Sharding.Pipeline, + instance_meta=InstanceMeta.MlxRing, + command_id=CommandId(), + model_card=card, + min_nodes=1, + ) + degradations: list[DrafterPlacementDegraded] = [] + + placements = place_instance( + command, + topology, + {}, + { + target_node: create_node_memory(64_000_000_000), + drafter_node: create_node_memory(32_000_000_000), + }, + { + target_node: create_node_network(), + drafter_node: create_node_network(), + }, + on_drafter_placement_degraded=degradations.append, + ) + + assert len(placements) == 1 + assert not degradations + instance = next(iter(placements.values())) + assert isinstance(instance, MlxJacclInstance) + assert instance.drafter_placement is not None + placement = instance.drafter_placement + assert placement.drafter_node_id == drafter_node + assert placement.drafter_model_id == ModelId("mlx-community/gemma-4-e2b-it-8bit") + assert placement.drafter_rank == 1 # target=1 rank, drafter is last (rank 1) + # v3+ wire: drafter does not join mx.distributed -> parent_group_size + # is the target-only rank count. + assert instance.parent_group_size == 1 + assert len(instance.shard_assignments.runner_to_shard) == 1 + + +def test_asymmetric_ring_socket_only_places_drafter_over_socket() -> None: + """Single-node ring target + socket-only drafter places drafter over TCP. + + v3+ wire decoupled the drafter from ``mx.distributed`` -- the wire + runs over a plain TCP socket. RDMA is therefore no longer required + for asymmetric placement; a socket-only path between target rank 0 + and the drafter node is sufficient. The instance still upgrades + ``MlxRing -> MlxJaccl`` because the target ranks (1 here) are + fine to leave on jaccl, but the drafter wire itself runs over TCP + regardless of the target backend. + """ + target_node, drafter_node = NodeId(), NodeId() + topology = Topology() + topology.add_node(target_node) + topology.add_node(drafter_node) + _bidi_socket(topology, target_node, drafter_node, ip=2) + + card = _drafter_aware_card( + storage_bytes=20_000_000_000, eligible_nodes=[drafter_node] + ) + command = PlaceInstance( + sharding=Sharding.Pipeline, + instance_meta=InstanceMeta.MlxRing, + command_id=CommandId(), + model_card=card, + min_nodes=1, + ) + degradations: list[DrafterPlacementDegraded] = [] + + placements = place_instance( + command, + topology, + {}, + { + target_node: create_node_memory(64_000_000_000), + drafter_node: create_node_memory(32_000_000_000), + }, + { + target_node: create_node_network(), + drafter_node: create_node_network(), + }, + on_drafter_placement_degraded=degradations.append, + ) + + assert len(placements) == 1 + instance = next(iter(placements.values())) + assert instance.drafter_placement is not None + assert instance.drafter_placement.drafter_node_id == drafter_node + # Target stays single-rank; drafter rides TCP regardless. + assert instance.parent_group_size == 1 + assert not degradations + + +def test_asymmetric_jaccl_places_drafter_with_rdma_reachability() -> None: + """Two-node target (RDMA cycle) + RDMA-reachable drafter => asymmetric jaccl. + + Single-node target gets downgraded MlxJaccl -> MlxRing by the legacy + ``len(selected_cycle) == 1 -> InstanceMeta.MlxRing`` rewrite, so to + exercise asymmetric jaccl we need the target to span 2 RDMA-connected + nodes + a 3rd drafter node with RDMA edges to both. + """ + target_a, target_b, drafter_node = NodeId(), NodeId(), NodeId() + topology = Topology() + for n in (target_a, target_b, drafter_node): + topology.add_node(n) + # Target cycle has bidirectional RDMA between target_a and target_b + _bidi_rdma(topology, target_a, target_b, iface=10) + _bidi_socket(topology, target_a, target_b, ip=12) + # Drafter has bidirectional RDMA + socket to both target ranks. + _bidi_rdma(topology, target_a, drafter_node, iface=20) + _bidi_rdma(topology, target_b, drafter_node, iface=22) + _bidi_socket(topology, target_a, drafter_node, ip=14) + _bidi_socket(topology, target_b, drafter_node, ip=16) + + # Use a Qwen-family card so the test isn't subject to Gemma 4's + # "no multi-node Pipeline" restriction. Tensor sharding works across + # 2 RDMA-connected nodes when hidden_size is divisible by world_size. + card = _drafter_aware_card( + storage_bytes=40_000_000_000, + eligible_nodes=[drafter_node], + family="qwen", + base_model="Qwen3 30B", + model_id="mlx-community/Qwen3-30B-A3B-4bit", + ) + command = PlaceInstance( + sharding=Sharding.Tensor, + instance_meta=InstanceMeta.MlxJaccl, + command_id=CommandId(), + model_card=card, + # min_nodes=2 forces multi-node target so the placement layer + # keeps MlxJaccl instead of rewriting to MlxRing. + min_nodes=2, + ) + degradations: list[DrafterPlacementDegraded] = [] + + placements = place_instance( + command, + topology, + {}, + { + target_a: create_node_memory(32_000_000_000), + target_b: create_node_memory(32_000_000_000), + drafter_node: create_node_memory(32_000_000_000), + }, + { + target_a: create_node_network(), + target_b: create_node_network(), + drafter_node: create_node_network(), + }, + on_drafter_placement_degraded=degradations.append, + ) + + assert len(placements) == 1 + assert not degradations, [(e.reason, e.detail) for e in degradations] + instance = next(iter(placements.values())) + assert isinstance(instance, MlxJacclInstance) + assert instance.drafter_placement is not None + placement = instance.drafter_placement + assert placement.drafter_node_id == drafter_node + assert placement.drafter_rank == 2 # logical telemetry index past target ranks + # v3+ wire: drafter is on a TCP socket, not in mx.distributed. + # parent_group_size and jaccl_devices cover only the 2 target ranks. + assert instance.parent_group_size == 2 + assert len(instance.jaccl_devices) == 2 + assert len(instance.jaccl_devices[0]) == 2 + # Drafter node does not coordinate the target's mx.distributed group. + assert drafter_node not in instance.jaccl_coordinators + + +def test_asymmetric_jaccl_socket_only_drafter_succeeds( + loguru_capture: list[str], +) -> None: + """Two-node jaccl target + socket-only drafter places successfully. + + v3+ wire: drafter IPC runs over a plain TCP socket independent of + the target's ``mx.distributed`` group. So a socket-only path from + target rank 0 to the drafter node is sufficient even when the + target ranks themselves are coordinating over jaccl/RDMA. No + degradation event should fire. + """ + target_a, target_b, drafter_node = NodeId(), NodeId(), NodeId() + topology = Topology() + for n in (target_a, target_b, drafter_node): + topology.add_node(n) + # Target cycle has bidirectional RDMA; drafter only has socket edges. + _bidi_rdma(topology, target_a, target_b, iface=30) + _bidi_socket(topology, target_a, target_b, ip=32) + _bidi_socket(topology, target_a, drafter_node, ip=34) + _bidi_socket(topology, target_b, drafter_node, ip=36) + + card = _drafter_aware_card( + storage_bytes=40_000_000_000, + eligible_nodes=[drafter_node], + family="qwen", + base_model="Qwen3 30B", + model_id="mlx-community/Qwen3-30B-A3B-4bit", + ) + command = PlaceInstance( + sharding=Sharding.Tensor, + instance_meta=InstanceMeta.MlxJaccl, + command_id=CommandId(), + model_card=card, + min_nodes=2, + ) + degradations: list[DrafterPlacementDegraded] = [] + + placements = place_instance( + command, + topology, + {}, + { + target_a: create_node_memory(32_000_000_000), + target_b: create_node_memory(32_000_000_000), + drafter_node: create_node_memory(32_000_000_000), + }, + { + target_a: create_node_network(), + target_b: create_node_network(), + drafter_node: create_node_network(), + }, + on_drafter_placement_degraded=degradations.append, + ) + + assert len(placements) == 1 + instance = next(iter(placements.values())) + assert isinstance(instance, MlxJacclInstance) + assert instance.drafter_placement is not None + assert instance.drafter_placement.drafter_node_id == drafter_node + # 2 target ranks + drafter on socket; mx.distributed is target-only. + assert instance.parent_group_size == 2 + assert not degradations + # No degradation log line either. + joined = "\n".join(loguru_capture) + assert "Drafter placement degraded" not in joined + + +def test_asymmetric_degrades_when_eligible_node_missing_from_topology( + loguru_capture: list[str], +) -> None: + """Eligible node id refers to a node not present in topology.""" + target_node = NodeId() + missing_drafter_node = NodeId() # Never added to topology. + topology = Topology() + topology.add_node(target_node) + + card = _drafter_aware_card( + storage_bytes=20_000_000_000, eligible_nodes=[missing_drafter_node] + ) + command = PlaceInstance( + sharding=Sharding.Pipeline, + instance_meta=InstanceMeta.MlxRing, + command_id=CommandId(), + model_card=card, + min_nodes=1, + ) + degradations: list[DrafterPlacementDegraded] = [] + + placements = place_instance( + command, + topology, + {}, + {target_node: create_node_memory(64_000_000_000)}, + {target_node: create_node_network()}, + on_drafter_placement_degraded=degradations.append, + ) + + assert len(placements) == 1 + instance = next(iter(placements.values())) + assert instance.drafter_placement is None + assert len(degradations) == 1 + assert ( + degradations[0].reason + == DrafterPlacementDegradationReason.NoEligibleNodeAvailable + ) + assert degradations[0].fallback == "single_device_drafter" + joined = "\n".join(loguru_capture).lower() + assert "drafter placement degraded" in joined + + +def test_asymmetric_degrades_when_eligible_node_in_target_cycle( + loguru_capture: list[str], +) -> None: + """Listing the target node itself as eligible is a misconfig => degrade.""" + target_node = NodeId() + topology = Topology() + topology.add_node(target_node) + + card = _drafter_aware_card( + storage_bytes=20_000_000_000, eligible_nodes=[target_node] + ) + command = PlaceInstance( + sharding=Sharding.Pipeline, + instance_meta=InstanceMeta.MlxRing, + command_id=CommandId(), + model_card=card, + min_nodes=1, + ) + degradations: list[DrafterPlacementDegraded] = [] + + placements = place_instance( + command, + topology, + {}, + {target_node: create_node_memory(64_000_000_000)}, + {target_node: create_node_network()}, + on_drafter_placement_degraded=degradations.append, + ) + + assert len(placements) == 1 + instance = next(iter(placements.values())) + assert instance.drafter_placement is None + assert len(degradations) == 1 + assert ( + degradations[0].reason + == DrafterPlacementDegradationReason.AllEligibleNodesInTargetCycle + ) + del loguru_capture # captured but content irrelevant beyond emission + + +def test_asymmetric_degrades_when_drafter_node_lacks_memory() -> None: + """Drafter node reachable but below memory floor (~6GB) => degrade. + + RDMA-reachable so jaccl auto-upgrade is viable, but memory check + rejects the candidate. Single-node target therefore reverts to + symmetric MlxRing without drafter. + """ + target_node, drafter_node = NodeId(), NodeId() + topology = Topology() + topology.add_node(target_node) + topology.add_node(drafter_node) + _bidi_socket(topology, target_node, drafter_node, ip=8) + _bidi_rdma(topology, target_node, drafter_node, iface=40) + + card = _drafter_aware_card( + storage_bytes=20_000_000_000, eligible_nodes=[drafter_node] + ) + command = PlaceInstance( + sharding=Sharding.Pipeline, + instance_meta=InstanceMeta.MlxRing, + command_id=CommandId(), + model_card=card, + min_nodes=1, + ) + degradations: list[DrafterPlacementDegraded] = [] + + placements = place_instance( + command, + topology, + {}, + { + target_node: create_node_memory(64_000_000_000), + drafter_node: create_node_memory(2_000_000_000), # 2GB is below floor + }, + { + target_node: create_node_network(), + drafter_node: create_node_network(), + }, + on_drafter_placement_degraded=degradations.append, + ) + + instance = next(iter(placements.values())) + assert isinstance(instance, MlxRingInstance) + assert instance.drafter_placement is None + assert len(degradations) == 1 + assert ( + degradations[0].reason + == DrafterPlacementDegradationReason.InsufficientDrafterMemory + ) + + +def test_empty_drafter_eligible_nodes_preserves_legacy_behaviour() -> None: + """No eligible list => no asymmetric attempt, no degradation events.""" + target_node = NodeId() + topology = Topology() + topology.add_node(target_node) + + card = ModelCard( + model_id=ModelId("mlx-community/gemma-4-31b-it-8bit"), + storage_size=Memory.from_bytes(20_000_000_000), + n_layers=60, + hidden_size=5376, + num_key_value_heads=16, + supports_tensor=True, + tasks=[ModelTask.TextGeneration], + family="gemma", + base_model="Gemma 4 31B", + drafter_model_ids=[ModelId("mlx-community/gemma-4-e2b-it-8bit")], + drafter_eligible_nodes=[], + ) + command = PlaceInstance( + sharding=Sharding.Pipeline, + instance_meta=InstanceMeta.MlxRing, + command_id=CommandId(), + model_card=card, + min_nodes=1, + ) + degradations: list[DrafterPlacementDegraded] = [] + + placements = place_instance( + command, + topology, + {}, + {target_node: create_node_memory(64_000_000_000)}, + {target_node: create_node_network()}, + on_drafter_placement_degraded=degradations.append, + ) + + instance = next(iter(placements.values())) + assert instance.drafter_placement is None + assert not degradations # no asymmetric attempt was made + + +def test_asymmetric_with_multiple_eligible_nodes_picks_first_reachable() -> None: + """When multiple eligible nodes are listed, placement picks the first + reachable (in card order). Earlier candidates that fail reachability + are skipped silently (the search is best-effort, not first-fail). + + Single-node target auto-upgrades to jaccl, so the reachable drafter + needs an RDMA edge (not just a socket edge); the unreachable drafter + has no edges at all. + """ + target_node = NodeId() + unreachable_drafter = NodeId() + reachable_drafter = NodeId() + topology = Topology() + topology.add_node(target_node) + topology.add_node(unreachable_drafter) + topology.add_node(reachable_drafter) + # Only reachable_drafter has socket + RDMA edges to target. + _bidi_socket(topology, target_node, reachable_drafter, ip=20) + _bidi_rdma(topology, target_node, reachable_drafter, iface=50) + + card = _drafter_aware_card( + storage_bytes=20_000_000_000, + eligible_nodes=[unreachable_drafter, reachable_drafter], + ) + command = PlaceInstance( + sharding=Sharding.Pipeline, + instance_meta=InstanceMeta.MlxRing, + command_id=CommandId(), + model_card=card, + min_nodes=1, + ) + degradations: list[DrafterPlacementDegraded] = [] + + placements = place_instance( + command, + topology, + {}, + { + target_node: create_node_memory(64_000_000_000), + unreachable_drafter: create_node_memory(32_000_000_000), + reachable_drafter: create_node_memory(32_000_000_000), + }, + { + target_node: create_node_network(), + unreachable_drafter: create_node_network(), + reachable_drafter: create_node_network(), + }, + on_drafter_placement_degraded=degradations.append, + ) + + instance = next(iter(placements.values())) + assert instance.drafter_placement is not None + assert instance.drafter_placement.drafter_node_id == reachable_drafter + assert not degradations # successful placement, no degradation + + +def test_asymmetric_round_trip_serialization() -> None: + """An asymmetric instance round-trips through pydantic serialisation. + + Single-node target auto-upgrades to ``MlxJaccl`` for the asymmetric + parent group (ring lacks split + send/recv), so the round-trip is + exercised on ``MlxJacclInstance`` here. RDMA edges to the drafter + node make the auto-upgrade viable. + """ + target_node, drafter_node = NodeId(), NodeId() + topology = Topology() + topology.add_node(target_node) + topology.add_node(drafter_node) + _bidi_socket(topology, target_node, drafter_node, ip=30) + _bidi_rdma(topology, target_node, drafter_node, iface=60) + + card = _drafter_aware_card( + storage_bytes=20_000_000_000, eligible_nodes=[drafter_node] + ) + command = PlaceInstance( + sharding=Sharding.Pipeline, + instance_meta=InstanceMeta.MlxRing, + command_id=CommandId(), + model_card=card, + min_nodes=1, + ) + + placements = place_instance( + command, + topology, + {}, + { + target_node: create_node_memory(64_000_000_000), + drafter_node: create_node_memory(32_000_000_000), + }, + { + target_node: create_node_network(), + drafter_node: create_node_network(), + }, + ) + instance = next(iter(placements.values())) + assert isinstance(instance, MlxJacclInstance) + assert instance.drafter_placement is not None + + dumped = instance.model_dump() + rehydrated = MlxJacclInstance.model_validate(dumped) + assert rehydrated == instance + assert rehydrated.drafter_placement is not None + assert ( + rehydrated.drafter_placement.drafter_node_id + == instance.drafter_placement.drafter_node_id + ) + + +def test_asymmetric_all_node_to_runner_includes_drafter_for_disconnect_check() -> None: + """``all_node_to_runner`` must list the drafter node so the master's + instance-deletion loop tears the placement down when the drafter node + leaves the topology. + + This pins the contract that the master's ``connected_node_ids`` + check at ``master/main.py`` relies on. Iterating + ``shard_assignments.node_to_runner`` (target ranks only) would + leave the surviving target runners blocked indefinitely on + ``transport.forward`` against a dead socket when the drafter node + disconnects -- the dead-wire ``RemoteTransport.is_failed`` flag + is set on root only, and non-root has no out-of-band signal that + the spec loop should abort. Tearing the instance down on drafter- + node disconnect is the only consistent recovery path. + """ + target_node, drafter_node = NodeId(), NodeId() + topology = Topology() + topology.add_node(target_node) + topology.add_node(drafter_node) + _bidi_socket(topology, target_node, drafter_node, ip=2) + _bidi_rdma(topology, target_node, drafter_node, iface=4) + + card = _drafter_aware_card( + storage_bytes=20_000_000_000, eligible_nodes=[drafter_node] + ) + command = PlaceInstance( + sharding=Sharding.Pipeline, + instance_meta=InstanceMeta.MlxRing, + command_id=CommandId(), + model_card=card, + min_nodes=1, + ) + placements = place_instance( + command, + topology, + {}, + { + target_node: create_node_memory(64_000_000_000), + drafter_node: create_node_memory(32_000_000_000), + }, + { + target_node: create_node_network(), + drafter_node: create_node_network(), + }, + ) + assert len(placements) == 1 + instance = next(iter(placements.values())) + assert instance.drafter_placement is not None + + # Both nodes must appear in ``all_node_to_runner`` so the master's + # disconnect check fires for either one. + assert target_node in instance.all_node_to_runner + assert drafter_node in instance.all_node_to_runner + assert ( + instance.all_node_to_runner[drafter_node] + == instance.drafter_placement.drafter_runner_id + ) + + # The legacy mapping (target shards only) intentionally excludes + # the drafter; this is the bug the master fix addresses by + # iterating ``all_node_to_runner`` instead. + assert target_node in instance.shard_assignments.node_to_runner + assert drafter_node not in instance.shard_assignments.node_to_runner diff --git a/src/exo/master/tests/test_placement_drafter_warning.py b/src/exo/master/tests/test_placement_drafter_warning.py new file mode 100644 index 0000000000..13389b021b --- /dev/null +++ b/src/exo/master/tests/test_placement_drafter_warning.py @@ -0,0 +1,141 @@ +"""Tests for the drafter-aware placement warning (item 10). + +When a model card declares `drafter_model_ids`, the placement engine still +prefers single-node (via the existing smallest-cycle-first logic). When +single-node placement is impossible because no single node has enough RAM +for the requested quant, placement falls back to multi-node and emits a +clear warning so the operator knows speculative decoding has been silently +disabled and can re-place a smaller-quant variant. +""" + +from collections.abc import Iterator + +import pytest +from loguru import logger as loguru_logger + +from exo.master.placement import place_instance +from exo.master.tests.conftest import ( + create_node_memory, + create_node_network, + create_socket_connection, +) +from exo.shared.models.model_cards import ModelCard, ModelId, ModelTask +from exo.shared.topology import Topology +from exo.shared.types.commands import PlaceInstance +from exo.shared.types.common import CommandId, NodeId +from exo.shared.types.memory import Memory +from exo.shared.types.topology import Connection +from exo.shared.types.worker.instances import InstanceMeta +from exo.shared.types.worker.shards import Sharding + + +@pytest.fixture +def loguru_capture() -> Iterator[list[str]]: + """Capture loguru WARNING+ messages into a list (caplog doesn't see loguru).""" + captured: list[str] = [] + sink_id = loguru_logger.add( + lambda message: captured.append(str(message)), level="WARNING" + ) + try: + yield captured + finally: + loguru_logger.remove(sink_id) + + +def _drafter_aware_card(storage_bytes: int) -> ModelCard: + return ModelCard( + model_id=ModelId("mlx-community/gemma-4-31b-it-8bit"), + storage_size=Memory.from_bytes(storage_bytes), + n_layers=60, + hidden_size=5376, + num_key_value_heads=16, + supports_tensor=True, + tasks=[ModelTask.TextGeneration], + family="gemma", + base_model="Gemma 4 31B", + drafter_model_ids=[ + ModelId("mlx-community/gemma-4-e2b-it-8bit"), + ModelId("mlx-community/gemma-4-e4b-it-8bit"), + ], + ) + + +def test_drafter_aware_card_placed_single_node_when_fits( + loguru_capture: list[str], +) -> None: + """When a single node has enough RAM, the model lands on that node and + no warning is emitted -- speculative decoding is preserved.""" + big_node = NodeId() + topology = Topology() + topology.add_node(big_node) + + card = _drafter_aware_card(20_000_000_000) + command = PlaceInstance( + sharding=Sharding.Pipeline, + instance_meta=InstanceMeta.MlxRing, + command_id=CommandId(), + model_card=card, + min_nodes=1, + ) + + placements = place_instance( + command, + topology, + {}, + {big_node: create_node_memory(64_000_000_000)}, + {big_node: create_node_network()}, + ) + assert len(placements) == 1 + instance = next(iter(placements.values())) + assert len(instance.shard_assignments.node_to_runner) == 1 + joined = "\n".join(loguru_capture).lower() + assert "speculative decoding is single-device only" not in joined + + +def test_drafter_aware_card_warns_when_only_multi_node_fits( + loguru_capture: list[str], +) -> None: + """When no single node has enough RAM, placement falls back to multi-node + and warns the operator that the drafter will be silently disabled.""" + node_a, node_b = NodeId(), NodeId() + topology = Topology() + topology.add_node(node_a) + topology.add_node(node_b) + topology.add_connection( + Connection(source=node_a, sink=node_b, edge=create_socket_connection(2)) + ) + topology.add_connection( + Connection(source=node_b, sink=node_a, edge=create_socket_connection(2)) + ) + + # 20 GB target with hidden_size divisible by 2 nodes; only multi-node + # fits (16 GB each). Use Tensor sharding because Gemma 4 doesn't allow + # multi-node Pipeline. + card = _drafter_aware_card(20_000_000_000) + command = PlaceInstance( + sharding=Sharding.Tensor, + instance_meta=InstanceMeta.MlxRing, + command_id=CommandId(), + model_card=card, + min_nodes=1, + ) + + placements = place_instance( + command, + topology, + {}, + { + node_a: create_node_memory(16_000_000_000), + node_b: create_node_memory(16_000_000_000), + }, + { + node_a: create_node_network(), + node_b: create_node_network(), + }, + ) + assert len(placements) == 1 + instance = next(iter(placements.values())) + assert len(instance.shard_assignments.node_to_runner) == 2 + joined = "\n".join(loguru_capture).lower() + assert "speculative decoding is single-device only" in joined + assert "smaller quant" in joined diff --git a/src/exo/routing/mdns_announcer.py b/src/exo/routing/mdns_announcer.py new file mode 100644 index 0000000000..cafd1d3acc --- /dev/null +++ b/src/exo/routing/mdns_announcer.py @@ -0,0 +1,97 @@ +import argparse +import contextlib +import random +import socket +import string +import struct +import sys +import time +from typing import final + + +def _dns_qname(name: bytes) -> bytes: + return b"".join(bytes([len(part)]) + part for part in name.split(b".")) + b"\0" + + +def _build_response_packet(node_id: str, ip_address: str, libp2p_port: int) -> bytes: + service_name = b"_p2p._udp.local" + peer_name = ( + "".join(random.choice(string.ascii_letters + string.digits) for _ in range(32)) + + "._p2p._udp.local" + ).encode() + txt_record = f"dnsaddr=/ip4/{ip_address}/tcp/{libp2p_port}/p2p/{node_id}".encode() + + peer_qname = _dns_qname(peer_name) + packet = bytearray() + packet += struct.pack("!HHHHHH", 0, 0x8400, 0, 1, 0, 1) + packet += _dns_qname(service_name) + packet += struct.pack("!HHI", 12, 1, 120) + packet += struct.pack("!H", len(peer_qname)) + packet += peer_qname + packet += peer_qname + packet += struct.pack("!HHI", 16, 1, 120) + packet += struct.pack("!H", len(txt_record) + 1) + packet += bytes([len(txt_record)]) + packet += txt_record + return bytes(packet) + + +@final +class Args(argparse.Namespace): + node_id: str + ip_address: str + libp2p_port: int + broadcast_address: str | None + count: int + + @staticmethod + def parse() -> "Args": + parser = argparse.ArgumentParser() + parser.add_argument("--node-id", required=True) + parser.add_argument("--ip-address", required=True) + parser.add_argument("--libp2p-port", required=True, type=int) + parser.add_argument("--broadcast-address") + parser.add_argument("--count", default=0, type=int) + return parser.parse_args(namespace=Args()) + + +def main() -> None: + args = Args.parse() + sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + with contextlib.suppress(OSError): + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1) + sock.bind((args.ip_address, 0)) + + sent_count = 0 + while True: + packet = _build_response_packet( + args.node_id, args.ip_address, args.libp2p_port + ) + errors: list[str] = [] + destinations: list[tuple[str, int]] = [] + if args.broadcast_address is not None: + destinations.append((args.broadcast_address, 5353)) + destinations.extend([("255.255.255.255", 5353), ("224.0.0.251", 5353)]) + sent = False + for destination in destinations: + try: + sock.sendto(packet, destination) + sent = True + except OSError as err: + errors.append(f"{destination}: {err}") + if not sent: + print( + f"mDNS announcer send failed: {'; '.join(errors)}", + file=sys.stderr, + flush=True, + ) + sent_count += 1 + if args.count > 0 and sent_count >= args.count: + return + time.sleep(1.0 if sent_count < 60 else 10.0) + + +if __name__ == "__main__": + main() diff --git a/src/exo/routing/router.py b/src/exo/routing/router.py index ebe0ea8d90..5b679fe192 100644 --- a/src/exo/routing/router.py +++ b/src/exo/routing/router.py @@ -298,8 +298,6 @@ def get_node_id_keypair( Obtains the :class:`Keypair` associated with this node-ID. Obtain the :class:`PeerId` by from it. """ - # TODO(evan): bring back node id persistence once we figure out how to deal with duplicates - return Keypair.generate() def lock_path(path: str | bytes | PathLike[str] | PathLike[bytes]) -> Path: return Path(str(path) + ".lock") diff --git a/src/exo/shared/apply.py b/src/exo/shared/apply.py index ce3f503537..3f1c99eefc 100644 --- a/src/exo/shared/apply.py +++ b/src/exo/shared/apply.py @@ -9,6 +9,7 @@ ChunkGenerated, CustomModelCardAdded, CustomModelCardDeleted, + DrafterPlacementDegraded, Event, IndexedEvent, InputChunkReceived, @@ -77,6 +78,7 @@ def event_apply(event: Event, state: State) -> State: | TracesMerged() | CustomModelCardAdded() | CustomModelCardDeleted() + | DrafterPlacementDegraded() ): # Pass-through events that don't modify state return state case InstanceCreated(): diff --git a/src/exo/shared/constants.py b/src/exo/shared/constants.py index d79354184b..ad76a88ffb 100644 --- a/src/exo/shared/constants.py +++ b/src/exo/shared/constants.py @@ -8,12 +8,12 @@ def _get_xdg_dir(env_var: str, fallback: str) -> Path: - """Get XDG directory, prioritising EXO_HOME environment variable if its set. On non-Linux platforms, default to ~/.exo.""" + """Get XDG directory, prioritising EXO_HOME environment variable if its set. On non-Linux platforms, default to ~/.exo. Cache home always prefers .cache/exo""" if _EXO_HOME_ENV is not None: return Path.home() / _EXO_HOME_ENV - if sys.platform != "linux": + if sys.platform != "linux" and env_var != "XDG_CACHE_HOME": return Path.home() / ".exo" xdg_value = os.environ.get(env_var, None) @@ -68,10 +68,9 @@ def _parse_colon_dirs(env_var: str) -> tuple[Path, ...]: # Log files (data/logs or cache) EXO_LOG_DIR = EXO_CACHE_HOME / "exo_log" EXO_LOG = EXO_LOG_DIR / "exo.log" -EXO_TEST_LOG = EXO_CACHE_HOME / "exo_test.log" # Identity (config) -EXO_NODE_ID_KEYPAIR = EXO_CONFIG_HOME / "node_id.keypair" +EXO_NODE_ID_KEYPAIR = EXO_CACHE_HOME / "node_id.keypair" EXO_CONFIG_FILE = EXO_CONFIG_HOME / "config.toml" # libp2p topics for event forwarding diff --git a/src/exo/shared/logging.py b/src/exo/shared/logging.py index 039656b9d3..98bfae3d41 100644 --- a/src/exo/shared/logging.py +++ b/src/exo/shared/logging.py @@ -84,6 +84,19 @@ def logger_setup(log_file: Path | None, verbosity: int = 0): logging.basicConfig(handlers=[_InterceptHandler()], level=0) console_level = "INFO" if verbosity == 0 else "DEBUG" + # ``diagnose=False`` disables loguru's "better exceptions" frame-locals + # repr. Leaving it on (the default) means any ``logger.opt(exception=e)`` + # call walks every frame of the traceback and ``repr()``s every local + # variable -- catastrophic when an exception is raised from a frame + # that holds large structured data (e.g. ``mx_all_gather_tasks`` keeps + # a ``padded`` list of per-rank UUID buffers as a local; if a JACCL + # collective corruption blows ``max_tasks`` up to ~1B that ``padded`` + # local becomes a ~1B-element nested int list whose ``list_repr`` is + # *the* hot loop the runner gets stuck in -- 100% CPU on one core, + # ~300 GB peak physical footprint, and every subsequent crash log + # restarts the storm). The compact traceback we still emit (file, + # line, exception message) is enough for diagnosis without ever + # touching frame locals. if verbosity == 0: logger.add( sys.__stderr__, # type: ignore @@ -91,6 +104,7 @@ def logger_setup(log_file: Path | None, verbosity: int = 0): level=console_level, colorize=True, enqueue=True, + diagnose=False, ) else: logger.add( @@ -99,6 +113,7 @@ def logger_setup(log_file: Path | None, verbosity: int = 0): level=console_level, colorize=True, enqueue=True, + diagnose=False, ) if log_file: log_file.parent.mkdir(parents=True, exist_ok=True) @@ -123,6 +138,7 @@ def logger_setup(log_file: Path | None, verbosity: int = 0): rotation=rotation, retention=_MAX_LOG_ARCHIVES, compression=_zstd_compress, + diagnose=False, ) for destination, serialize in ((run_text_log, False), (run_json_log, True)): logger.add( @@ -135,6 +151,7 @@ def logger_setup(log_file: Path | None, verbosity: int = 0): retention=retention, compression=_zstd_compress, serialize=serialize, + diagnose=False, ) logger.info( f"Per-run logs enabled text_log={run_text_log} json_log={run_json_log} " diff --git a/src/exo/shared/models/model_cards.py b/src/exo/shared/models/model_cards.py index 2c379f578c..917b297238 100644 --- a/src/exo/shared/models/model_cards.py +++ b/src/exo/shared/models/model_cards.py @@ -26,7 +26,7 @@ EXO_MODELS_DIRS, RESOURCES_DIR, ) -from exo.shared.types.common import ModelId +from exo.shared.types.common import ModelId, NodeId from exo.shared.types.memory import Memory from exo.shared.types.text_generation import ReasoningDialect from exo.utils.pydantic_ext import FrozenModel @@ -159,11 +159,50 @@ class ModelCard(FrozenModel): is_custom: bool = False vision: VisionCardConfig | None = None sampling_defaults: SamplingDefaults = Field(default_factory=SamplingDefaults) - # Optional speculative-decoding draft model. When set, runners will load the - # named model alongside the target and pass it as `draft_model` to mlx_lm's - # `stream_generate`, enabling MLX-side speculative decoding. The drafter MUST - # share a tokenizer with the target. - drafter_model_id: ModelId | None = None + # Optional speculative-decoding draft models. Listed in *preference order*: + # the first entry is treated as the default ("fastest") choice. Runners pick + # one based on `EXO_DRAFTER_PREFERENCE` (`fastest` / `highest_acceptance` / + # `auto`), falling back to whichever weights are already on disk. All + # listed drafters MUST share a tokenizer with the target. Conventionally + # the list is quant-aligned with the target (e.g. `gemma-4-31b-it-4bit` + # declares `[gemma-4-e2b-it-4bit, gemma-4-e4b-it-4bit]`), but cross-quant + # drafters are allowed for advanced tuning. + drafter_model_ids: list[ModelId] = Field(default_factory=list) + # Nodes the operator has designated as eligible drafter hosts. When this + # list is non-empty AND the model has at least one declared drafter, the + # placement layer attempts asymmetric placement: target ranks land on the + # selected target cycle, the drafter is loaded on the first eligible node + # reachable from target rank 0 (RDMA for `MlxJaccl`, socket for `MlxRing`), + # and the parent `mx.distributed` group spans both. Eligibility is + # *operator-controlled*, not auto-discovered: the operator opts a node in + # by listing its `NodeId` here (typically in a custom card under + # `~/.exo/custom_model_cards/`). If no listed node is reachable, placement + # emits a `DrafterPlacementDegraded` event and falls back -- the user's + # request still completes, the operator just doesn't get the asymmetric + # speedup until they fix the eligibility list. Empty (the default) preserves + # legacy single-device drafter behaviour. + drafter_eligible_nodes: list[NodeId] = Field(default_factory=list) + # Nodes the operator has designated as eligible *prefill-only* hosts for + # this model. When non-empty, placement auto-creates a single-rank + # prefill-only sibling instance on each viable node (sufficient RAM, + # alive in topology, not already a target/drafter rank) and emits an + # ``InstanceLinkCreated`` linking them to the decode instance. The + # master then routes incoming requests' prefill traffic across the + # linked prefill instances by in-flight task count, giving the + # decode instance multi-GPU prefill parallelism for free. + # + # This is the right lever for "I have spare nodes in my cluster -- + # use them for prefill so slot N's TTFT doesn't queue behind slot 0's + # prefill on a single GPU." It composes orthogonally with + # ``drafter_eligible_nodes``: the chosen drafter node is excluded + # from prefill candidates automatically. + # + # Failure modes are loud-but-graceful: if a candidate fails RAM + # feasibility or is unreachable the placement layer skips it and + # logs; the decode instance still comes up. If *no* candidate + # succeeds, no link is emitted and the user's traffic prefills + # locally on the target rank as before. + prefill_eligible_nodes: list[NodeId] = Field(default_factory=list) @model_validator(mode="after") def _autodetect_vision(self) -> "ModelCard": diff --git a/src/exo/shared/tests/test_model_cards_drafter.py b/src/exo/shared/tests/test_model_cards_drafter.py index 302bcd3368..1ac7e027de 100644 --- a/src/exo/shared/tests/test_model_cards_drafter.py +++ b/src/exo/shared/tests/test_model_cards_drafter.py @@ -1,63 +1,124 @@ -"""Tests for the optional `drafter_model_id` field on ModelCard. +"""Tests for the optional `drafter_model_ids` field on ModelCard. -The field declares a speculative-decoding draft model that runners may load -alongside the target. Coverage: +The field declares a preference-ordered list of speculative-decoding draft +models that runners may load alongside the target. Coverage: - ModelCard accepts and serialises the field. -- Cards with no drafter declared default to `None`. -- The Gemma 4 large-instruct cards point to the e2b drafter. +- Cards with no drafters declared default to an empty list. +- Gemma 4 large-instruct cards declare both e2b and e4b drafters at matching + quantisation, in fastest-first order. + +Also covers the asymmetric placement opt-in field +``drafter_eligible_nodes``: empty by default (legacy in-process drafter), +populated to designate per-deployment hosts for drafter-only ranks. The +field round-trips through Pydantic serialisation. """ +from pathlib import Path + import pytest +from anyio import Path as AsyncPath +from exo.shared.models import model_cards from exo.shared.models.model_cards import ModelCard, ModelId, get_model_cards +from exo.shared.types.common import NodeId from exo.shared.types.memory import Memory +@pytest.fixture(autouse=True) +def _isolate_custom_cards( # pyright: ignore[reportUnusedFunction] + tmp_path: Path, monkeypatch: pytest.MonkeyPatch +) -> None: + """Insulate these tests from operator-local custom card overrides. + + ``_custom_cards_dir`` resolves to ``$EXO_DATA_HOME/custom_model_cards``, + which on dev workstations holds operator-edited cards (e.g. trimmed + drafter lists for memory-constrained clusters). Those overrides are + layered on top of the shipped TOML, so without isolation the assertions + below describe whatever the operator last wrote, not the shipped data + the gate is supposed to protect. Reset the in-memory cache too so the + next test refreshes from the now-empty custom dir. + """ + custom_dir = tmp_path / "custom_model_cards" + custom_dir.mkdir() + monkeypatch.setattr(model_cards, "_custom_cards_dir", AsyncPath(custom_dir)) + monkeypatch.setattr(model_cards, "_card_cache", {}) + + @pytest.mark.asyncio -async def test_drafter_model_id_defaults_to_none() -> None: +async def test_drafter_model_ids_defaults_to_empty_list() -> None: cards = {card.model_id: card for card in await get_model_cards()} qwen_id = ModelId("mlx-community/Qwen3-30B-A3B-4bit") if qwen_id in cards: - assert cards[qwen_id].drafter_model_id is None + assert cards[qwen_id].drafter_model_ids == [] + + +def _gemma4_31b_expectations() -> dict[str, list[str]]: + return { + "mlx-community/gemma-4-31b-it-4bit": [ + "mlx-community/gemma-4-e2b-it-4bit", + "mlx-community/gemma-4-e4b-it-4bit", + ], + "mlx-community/gemma-4-31b-it-6bit": [ + "mlx-community/gemma-4-e2b-it-6bit", + "mlx-community/gemma-4-e4b-it-6bit", + ], + "mlx-community/gemma-4-31b-it-8bit": [ + "mlx-community/gemma-4-e2b-it-8bit", + "mlx-community/gemma-4-e4b-it-8bit", + ], + "mlx-community/gemma-4-31b-it-bf16": [ + "mlx-community/gemma-4-e2b-it-bf16", + "mlx-community/gemma-4-e4b-it-bf16", + ], + } + + +def _gemma4_26b_expectations() -> dict[str, list[str]]: + return { + "mlx-community/gemma-4-26b-a4b-it-4bit": [ + "mlx-community/gemma-4-e2b-it-4bit", + "mlx-community/gemma-4-e4b-it-4bit", + ], + "mlx-community/gemma-4-26b-a4b-it-6bit": [ + "mlx-community/gemma-4-e2b-it-6bit", + "mlx-community/gemma-4-e4b-it-6bit", + ], + "mlx-community/gemma-4-26b-a4b-it-8bit": [ + "mlx-community/gemma-4-e2b-it-8bit", + "mlx-community/gemma-4-e4b-it-8bit", + ], + "mlx-community/gemma-4-26b-a4b-it-bf16": [ + "mlx-community/gemma-4-e2b-it-bf16", + "mlx-community/gemma-4-e4b-it-bf16", + ], + } @pytest.mark.asyncio -async def test_gemma4_31b_cards_declare_e2b_drafter() -> None: +async def test_gemma4_31b_cards_declare_e2b_then_e4b_drafters() -> None: cards = {card.model_id: card for card in await get_model_cards()} - expectations = { - "mlx-community/gemma-4-31b-it-4bit": "mlx-community/gemma-4-e2b-it-4bit", - "mlx-community/gemma-4-31b-it-6bit": "mlx-community/gemma-4-e2b-it-6bit", - "mlx-community/gemma-4-31b-it-8bit": "mlx-community/gemma-4-e2b-it-8bit", - "mlx-community/gemma-4-31b-it-bf16": "mlx-community/gemma-4-e2b-it-bf16", - } - for target_str, expected_drafter_str in expectations.items(): + for target_str, expected_drafters in _gemma4_31b_expectations().items(): target_id = ModelId(target_str) assert target_id in cards, f"{target_id} card missing" card = cards[target_id] - assert card.drafter_model_id == ModelId(expected_drafter_str), ( - f"{target_id} drafter mismatch: got {card.drafter_model_id!r}" + assert card.drafter_model_ids == [ModelId(d) for d in expected_drafters], ( + f"{target_id} drafter mismatch: got {card.drafter_model_ids!r}" ) @pytest.mark.asyncio -async def test_gemma4_26b_cards_declare_e2b_drafter() -> None: +async def test_gemma4_26b_cards_declare_e2b_then_e4b_drafters() -> None: cards = {card.model_id: card for card in await get_model_cards()} - expectations = { - "mlx-community/gemma-4-26b-a4b-it-4bit": "mlx-community/gemma-4-e2b-it-4bit", - "mlx-community/gemma-4-26b-a4b-it-6bit": "mlx-community/gemma-4-e2b-it-6bit", - "mlx-community/gemma-4-26b-a4b-it-8bit": "mlx-community/gemma-4-e2b-it-8bit", - "mlx-community/gemma-4-26b-a4b-it-bf16": "mlx-community/gemma-4-e2b-it-bf16", - } - for target_str, expected_drafter_str in expectations.items(): + for target_str, expected_drafters in _gemma4_26b_expectations().items(): target_id = ModelId(target_str) assert target_id in cards, f"{target_id} card missing" card = cards[target_id] - assert card.drafter_model_id == ModelId(expected_drafter_str), ( - f"{target_id} drafter mismatch: got {card.drafter_model_id!r}" + assert card.drafter_model_ids == [ModelId(d) for d in expected_drafters], ( + f"{target_id} drafter mismatch: got {card.drafter_model_ids!r}" ) -def test_model_card_explicit_drafter_round_trip() -> None: +def test_model_card_explicit_drafters_round_trip() -> None: card = ModelCard( model_id=ModelId("mlx-community/test-target"), storage_size=Memory.from_gb(1.0), @@ -65,8 +126,50 @@ def test_model_card_explicit_drafter_round_trip() -> None: hidden_size=768, supports_tensor=True, tasks=["TextGeneration"], # pyright: ignore[reportArgumentType] - drafter_model_id=ModelId("mlx-community/test-drafter"), + drafter_model_ids=[ + ModelId("mlx-community/test-drafter-fast"), + ModelId("mlx-community/test-drafter-accurate"), + ], + ) + assert card.drafter_model_ids == [ + ModelId("mlx-community/test-drafter-fast"), + ModelId("mlx-community/test-drafter-accurate"), + ] + dump = card.model_dump(exclude_none=True) + assert dump["drafter_model_ids"] == [ + "mlx-community/test-drafter-fast", + "mlx-community/test-drafter-accurate", + ] + + +def test_drafter_eligible_nodes_defaults_to_empty() -> None: + card = ModelCard( + model_id=ModelId("mlx-community/test-target-2"), + storage_size=Memory.from_gb(1.0), + n_layers=12, + hidden_size=768, + supports_tensor=True, + tasks=["TextGeneration"], # pyright: ignore[reportArgumentType] + ) + assert card.drafter_eligible_nodes == [] + dump = card.model_dump(exclude_none=True) + assert dump["drafter_eligible_nodes"] == [] + + +def test_drafter_eligible_nodes_round_trip() -> None: + eligible = [NodeId(), NodeId()] + card = ModelCard( + model_id=ModelId("mlx-community/test-target-3"), + storage_size=Memory.from_gb(1.0), + n_layers=12, + hidden_size=768, + supports_tensor=True, + tasks=["TextGeneration"], # pyright: ignore[reportArgumentType] + drafter_model_ids=[ModelId("mlx-community/test-drafter")], + drafter_eligible_nodes=eligible, ) - assert card.drafter_model_id == ModelId("mlx-community/test-drafter") + assert card.drafter_eligible_nodes == eligible dump = card.model_dump(exclude_none=True) - assert dump["drafter_model_id"] == "mlx-community/test-drafter" + assert dump["drafter_eligible_nodes"] == eligible + rehydrated = ModelCard.model_validate(dump) + assert rehydrated.drafter_eligible_nodes == eligible diff --git a/src/exo/shared/tests/test_xdg_paths.py b/src/exo/shared/tests/test_xdg_paths.py index f3b82ebffd..dce2c7d7c1 100644 --- a/src/exo/shared/tests/test_xdg_paths.py +++ b/src/exo/shared/tests/test_xdg_paths.py @@ -94,7 +94,27 @@ def test_macos_uses_traditional_paths(): home = Path.home() assert home / ".exo" == constants.EXO_CONFIG_HOME assert home / ".exo" == constants.EXO_DATA_HOME - assert home / ".exo" == constants.EXO_CACHE_HOME + assert home / ".cache" / "exo" == constants.EXO_CACHE_HOME + + +def test_exo_home_env(): + """Test that macOS uses traditional ~/.exo directory.""" + # Remove EXO_HOME to ensure we test the default behavior + env = {k: v for k, v in os.environ.items() if k != "EXO_HOME"} + env["EXO_HOME"] = "/exo" + with ( + mock.patch.dict(os.environ, env, clear=True), + mock.patch.object(sys, "platform", "darwin"), + ): + import importlib + + import exo.shared.constants as constants + + importlib.reload(constants) + + assert Path("/exo") == constants.EXO_CONFIG_HOME + assert Path("/exo") == constants.EXO_DATA_HOME + assert Path("/exo") == constants.EXO_CACHE_HOME def test_node_id_in_config_dir(): diff --git a/src/exo/shared/types/events.py b/src/exo/shared/types/events.py index 01aa0ce5dc..9b4af266f1 100644 --- a/src/exo/shared/types/events.py +++ b/src/exo/shared/types/events.py @@ -1,5 +1,6 @@ from datetime import datetime -from typing import final +from enum import Enum +from typing import Literal, final from pydantic import Field @@ -146,6 +147,58 @@ class InstanceLinkDeleted(BaseEvent): link_id: InstanceLinkId +class DrafterPlacementDegradationReason(str, Enum): + """Why placement could not honour a model's ``drafter_eligible_nodes``. + + Surfaced on :class:`DrafterPlacementDegraded` so the operator can see + *why* their asymmetric drafter placement was downgraded to legacy + single-device (or no) drafter, without crawling worker logs. + """ + + NoEligibleNodeAvailable = "NoEligibleNodeAvailable" + """No eligible node is alive in the topology (eligibility list refers + to nodes that are missing/timed-out).""" + + AllEligibleNodesInTargetCycle = "AllEligibleNodesInTargetCycle" + """Every listed eligible node is already a target rank, so there's no + spare host to land the drafter on.""" + + NoReachablePathFromTargetRankZero = "NoReachablePathFromTargetRankZero" + """``MlxRing`` requires a socket connection from target rank 0 to the + drafter node; ``MlxJaccl`` requires an RDMA edge. None of the + eligible nodes provided one.""" + + InsufficientDrafterMemory = "InsufficientDrafterMemory" + """The first reachable eligible node lacks enough RAM for the chosen + drafter weights.""" + + +@final +class DrafterPlacementDegraded(BaseEvent): + """Loud-but-graceful telemetry: asymmetric drafter requested, denied. + + Emitted by the master when a model card declares + ``drafter_eligible_nodes`` but the placement layer cannot satisfy + the asymmetric topology. The corresponding ``InstanceCreated`` is + still emitted in the same step -- the user's request still + completes, just without the asymmetric speedup -- so the operator + sees both events and knows their cluster needs adjusting (e.g. + bring an eligible node online, free its RAM, fix the network + edge). + + State transition: pass-through. No state mutation; this exists + purely for dashboard/CLI surfacing. + """ + + model_id: ModelId + instance_id: InstanceId | None = None + target_node_ids: list[NodeId] + eligible_nodes: list[NodeId] + reason: DrafterPlacementDegradationReason + fallback: Literal["single_device_drafter", "no_drafter"] + detail: str = "" + + Event = ( TestEvent | TaskCreated @@ -169,6 +222,7 @@ class InstanceLinkDeleted(BaseEvent): | CustomModelCardDeleted | InstanceLinkCreated | InstanceLinkDeleted + | DrafterPlacementDegraded ) diff --git a/src/exo/shared/types/text_generation.py b/src/exo/shared/types/text_generation.py index ccb2512a53..cc1cd6e5d3 100644 --- a/src/exo/shared/types/text_generation.py +++ b/src/exo/shared/types/text_generation.py @@ -134,6 +134,40 @@ class TextGenerationTaskParams(BaseModel, frozen=True): prefill_endpoint: str | None = None + # Speculative-decoding per-request overrides. All default to `None`, + # meaning "use the runner's configured defaults". + # + # ``use_drafter=False`` forces non-speculative decoding for this + # request only -- useful for latency-sensitive paths where the + # drafter's prefill overhead isn't worth the throughput win. + # Equivalent to ``draft_mode="none"``; provided as a convenience for + # callers that don't want to think about drafter modes. + # + # ``num_draft_tokens`` lets the client tune K per-request (e.g. raise + # K for long completions, lower for short structured outputs). + # + # ``draft_mode`` selects between speculative-decoding strategies: + # - ``"model"``: external drafter model (Gemma-4 e2b/e4b style) + # via ``mlx_lm.speculative_generate_step``. Best for slow / + # distributed targets; usually a net loss for fast + # single-device targets. + # - ``"pipelined"``: same drafter, but routed through exo's + # custom :class:`PipelinedModelDrafter` with cross-round + # speculation. Transport (in-process or remote drafter rank + # via ``mx.distributed.send/recv`` over JACCL/RDMA or + # ring/TCP) is selected by ``EXO_DRAFTER_TRANSPORT``. The + # remote-transport case is the regime where the gain unlocks. + # - ``"ngram"``: in-context suffix lookup (no drafter model). + # Wins on RAG, summarisation, structured/code output where the + # model echoes prompt content. Cost ~0 when no match is found, + # so worst-case = baseline. + # - ``"none"``: skip speculation entirely. + # If both ``draft_mode`` and ``use_drafter=False`` are set, the + # explicit ``draft_mode`` wins. + use_drafter: bool | None = None + num_draft_tokens: int | None = None + draft_mode: Literal["model", "pipelined", "ngram", "none"] | None = None + def with_card_sampling_defaults(self) -> "TextGenerationTaskParams": from exo.shared.models.model_cards import get_card diff --git a/src/exo/shared/types/worker/instances.py b/src/exo/shared/types/worker/instances.py index 16233f3f05..568923ecaf 100644 --- a/src/exo/shared/types/worker/instances.py +++ b/src/exo/shared/types/worker/instances.py @@ -1,9 +1,10 @@ from enum import Enum +from typing import final -from pydantic import model_validator +from pydantic import Field, model_validator from exo.shared.models.model_cards import ModelTask -from exo.shared.types.common import Host, Id, NodeId +from exo.shared.types.common import Host, Id, ModelId, NodeId from exo.shared.types.worker.runners import RunnerId, ShardAssignments, ShardMetadata from exo.utils.pydantic_ext import FrozenModel, TaggedModel @@ -17,13 +18,150 @@ class InstanceMeta(str, Enum): MlxJaccl = "MlxJaccl" +@final +class DrafterPlacement(FrozenModel): + """Locator for an asymmetric drafter rank inside an :class:`Instance`. + + The drafter runs on a separate node from the target ranks. It is + intentionally NOT a member of the target ranks' + ``mx.distributed.Group``: the target group is target-only, and + drafter <-> target IPC flows over a direct TCP socket established + at instance bootstrap. Decoupling the drafter from + ``mx.distributed`` lets target ranks of any size use TP/PP + collectives without requiring ``Group.split`` (which jaccl/ring + backends do not implement on Apple Silicon). + + Convention: ``drafter_rank`` is preserved as a logical placement + index (always equal to ``len(target_ranks)``) for telemetry and + tests, but no longer corresponds to a rank inside an + ``mx.distributed.Group``. The drafter dials + ``drafter_socket_host:drafter_socket_port`` to reach target rank 0; + target rank 0 binds and listens on that endpoint at instance + bootstrap. + + Fields: + drafter_node_id: Where the drafter runner lives. + drafter_runner_id: Identifies the drafter runner; the bootstrap + checks ``bound_runner_id == drafter_runner_id`` + to switch into drafter-only loading mode and + enter the drafter serve loop instead of the + normal generation engine. + drafter_model_id: Which drafter weights to load. Must be one + of the entries in the target's + ``ModelCard.drafter_model_ids`` list + (placement enforces this invariant). + drafter_rank: Logical placement index of the drafter + inside the conceptual parent group + (target_world_size). Retained for + placement bookkeeping; not a real + ``mx.distributed`` rank in the v3+ wire. + drafter_socket_host: Host (LAN/Thunderbolt-bridge IP or + hostname) target rank 0 advertises for + the drafter wire. The drafter dials this + host to reach target rank 0. + drafter_socket_port: TCP port target rank 0 binds on for + drafter wire ops. Allocated at placement + time; the runner bootstrap binds that + specific port (failure is a hard error). + target_peer_socket_port: TCP port target rank 0 binds on for + *inter-target-rank* spec-decode int + broadcasts. Distinct from + ``drafter_socket_port`` because the drafter + dials in over a different IP than the + other target ranks; sharing a port would + collide. Empty for single-target instances + (no peer to broadcast to). + target_peer_hosts_by_rank: For each non-zero target rank, + the IP that rank uses to dial target rank + 0 over the inter-target socket wire. + Resolved at placement time via + :func:`find_ip_prioritised`; differs + per peer because Thunderbolt /30 meshes + expose a unique IP per node pair. Keys + are device ranks **stored as strings** + so the type round-trips cleanly through + JSON (the wire format used by + :mod:`event_router`); ``dict[int, str]`` + would fail strict re-validation because + JSON has no int dict keys. Convert to + int at the consumer (see + :func:`_maybe_setup_target_peer_fanout`). + """ + + drafter_node_id: NodeId + drafter_runner_id: RunnerId + drafter_model_id: ModelId + drafter_rank: int = Field(ge=0) + drafter_socket_host: str + drafter_socket_port: int = Field(ge=1, le=65535) + target_peer_socket_port: int = Field(ge=1, le=65535) + target_peer_hosts_by_rank: dict[str, str] = Field(default_factory=dict) + + class BaseInstance(TaggedModel): instance_id: InstanceId shard_assignments: ShardAssignments + # When set, this instance places the drafter on a separate node from + # the target ranks and routes drafter/verify IPC over a direct TCP + # socket (see :class:`DrafterPlacement`). ``None`` (the default) + # preserves legacy symmetric placement: every rank in + # ``shard_assignments`` runs a target shard, and any drafter + # declared on the model card is loaded in-process alongside the + # target on the single-device cycle. + drafter_placement: DrafterPlacement | None = None def shard(self, runner_id: RunnerId) -> ShardMetadata | None: return self.shard_assignments.runner_to_shard.get(runner_id, None) + @property + def parent_group_size(self) -> int: + """Size of the target ranks' ``mx.distributed`` group. + + Always equals ``len(shard_assignments.runner_to_shard)``: in + the v3+ asymmetric wire the drafter rank does NOT join the + target ``mx.distributed.Group`` (it talks to target rank 0 via + a direct TCP socket). Symmetric and asymmetric placement + therefore both report the same size here, equal to the number + of target shards. + """ + return len(self.shard_assignments.runner_to_shard) + + def is_drafter_runner(self, runner_id: RunnerId) -> bool: + return ( + self.drafter_placement is not None + and self.drafter_placement.drafter_runner_id == runner_id + ) + + @property + def all_runner_ids(self) -> list[RunnerId]: + """Every runner id participating in this instance, target + drafter. + + Lifecycle barriers (ConnectToGroup, LoadModel, StartWarmup, + Ready) wait on the *whole* parent group, so plan-time readiness + checks iterate this list. Generation tasks themselves are + target-only and iterate ``shard_assignments.runner_to_shard`` + directly. + """ + runners = list(self.shard_assignments.runner_to_shard.keys()) + if self.drafter_placement is not None: + runners.append(self.drafter_placement.drafter_runner_id) + return runners + + @property + def all_node_to_runner(self) -> dict[NodeId, RunnerId]: + """Per-node runner id including the drafter rank when asymmetric. + + Worker plan iterates this when deciding which node should spawn + which runner. Symmetric placement returns the legacy + ``shard_assignments.node_to_runner`` mapping unchanged. + """ + result = dict(self.shard_assignments.node_to_runner) + if self.drafter_placement is not None: + result[self.drafter_placement.drafter_node_id] = ( + self.drafter_placement.drafter_runner_id + ) + return result + class MlxRingInstance(BaseInstance): hosts_by_node: dict[NodeId, list[Host]] @@ -44,24 +182,67 @@ class BoundInstance(FrozenModel): bound_runner_id: RunnerId bound_node_id: NodeId + @property + def is_drafter_rank(self) -> bool: + """``True`` when this runner serves the drafter, not a target shard. + + Callers that read ``bound_shard``, ``is_image_model``, or any + target-shard-derived property MUST branch on this first; those + properties raise on a drafter-rank bound instance because the + drafter has no target shard. + """ + return self.instance.is_drafter_runner(self.bound_runner_id) + + @property + def parent_rank(self) -> int: + """This runner's rank inside the parent ``mx.distributed`` group. + + Target ranks read it from their bound shard's ``device_rank``; + the drafter rank reads it from + ``DrafterPlacement.drafter_rank``. Plan-time connect/warmup + ordering checks use this so the same predicate works for both + symmetric (drafter rank doesn't exist) and asymmetric (drafter + is rank ``parent_group_size - 1``) placement. + """ + if self.is_drafter_rank: + placement = self.instance.drafter_placement + assert placement is not None # type narrowed by is_drafter_rank + return placement.drafter_rank + return self.bound_shard.device_rank + @property def bound_shard(self) -> ShardMetadata: shard = self.instance.shard(self.bound_runner_id) - assert shard is not None + assert shard is not None, ( + "bound_shard is only defined for target ranks; " + "check `is_drafter_rank` before reading it" + ) return shard @property def is_image_model(self) -> bool: + if self.is_drafter_rank: + return False return ( ModelTask.TextToImage in self.bound_shard.model_card.tasks or ModelTask.ImageToImage in self.bound_shard.model_card.tasks ) @model_validator(mode="after") - def validate_shard_exists(self) -> "BoundInstance": - assert ( - self.bound_runner_id in self.instance.shard_assignments.runner_to_shard - ), ( - "Bound Instance must be constructed with a runner_id that is in the instances assigned shards" + def validate_runner_known(self) -> "BoundInstance": + if self.bound_runner_id in self.instance.shard_assignments.runner_to_shard: + return self + if self.instance.is_drafter_runner(self.bound_runner_id): + placement = self.instance.drafter_placement + assert placement is not None # type narrowed by is_drafter_runner + assert self.bound_node_id == placement.drafter_node_id, ( + f"Drafter runner {self.bound_runner_id} bound to node " + f"{self.bound_node_id}, but DrafterPlacement points to " + f"{placement.drafter_node_id}" + ) + return self + raise AssertionError( + f"bound_runner_id {self.bound_runner_id} is neither a target rank " + f"in shard_assignments nor the drafter rank declared by " + f"instance.drafter_placement" ) - return self diff --git a/src/exo/utils/keyed_backoff.py b/src/exo/utils/keyed_backoff.py index 4d7c9a66ed..a95fe5c5f7 100644 --- a/src/exo/utils/keyed_backoff.py +++ b/src/exo/utils/keyed_backoff.py @@ -29,6 +29,10 @@ def attempts(self, key: K) -> int: """Return the number of recorded attempts for a key.""" return self._attempts.get(key, 0) + def tracked_keys(self) -> set[K]: + """Return keys that currently have recorded backoff state.""" + return set(self._attempts) | set(self._last_time) + def reset(self, key: K) -> None: """Reset backoff state for a key (e.g., on success).""" self._attempts.pop(key, None) diff --git a/src/exo/utils/tests/test_keyed_backoff.py b/src/exo/utils/tests/test_keyed_backoff.py new file mode 100644 index 0000000000..b592a4fabd --- /dev/null +++ b/src/exo/utils/tests/test_keyed_backoff.py @@ -0,0 +1,13 @@ +from exo.utils.keyed_backoff import KeyedBackoff + + +def test_tracked_keys_reports_and_resets_backoff_state() -> None: + backoff = KeyedBackoff[str]() + + backoff.record_attempt("instance-a") + + assert backoff.tracked_keys() == {"instance-a"} + + backoff.reset("instance-a") + + assert backoff.tracked_keys() == set() diff --git a/src/exo/worker/engines/image/builder.py b/src/exo/worker/engines/image/builder.py index 4d20fd887f..c75f49c6c7 100644 --- a/src/exo/worker/engines/image/builder.py +++ b/src/exo/worker/engines/image/builder.py @@ -104,7 +104,9 @@ class MfluxBuilder(Builder): group: mx.distributed.Group | None = None def connect(self, bound_instance: BoundInstance) -> None: - self.group = initialize_mlx(bound_instance) + # Image generation models never declare a drafter, so target + # subgroup == parent group; the symmetric case of MlxGroupSplit. + self.group = initialize_mlx(bound_instance).target_subgroup def load(self, bound_instance: BoundInstance) -> Generator[ModelLoadingResponse]: self.shard_metadata = bound_instance.bound_shard diff --git a/src/exo/worker/engines/mlx/builder.py b/src/exo/worker/engines/mlx/builder.py index d2bc588f16..b2548d7269 100644 --- a/src/exo/worker/engines/mlx/builder.py +++ b/src/exo/worker/engines/mlx/builder.py @@ -1,11 +1,14 @@ import contextlib import os +import socket from collections.abc import Generator from dataclasses import dataclass +from typing import cast import mlx.core as mx from mlx_lm.tokenizer_utils import TokenizerWrapper +from exo.shared.constants import EXO_MAX_CONCURRENT_REQUESTS from exo.shared.types.common import ModelId from exo.shared.types.events import Event from exo.shared.types.tasks import TaskId @@ -15,12 +18,19 @@ from exo.worker.engines.base import Builder, Engine from exo.worker.runner.bootstrap import logger from exo.worker.runner.llm_inference.batch_generator import ( + DEFAULT_DRAFTER_MIN_OUTPUT_TOKENS, + DEFAULT_NUM_DRAFT_TOKENS, + EXO_ADAPTIVE_DRAFT_TOKENS, + EXO_DRAFTER_MIN_OUTPUT_TOKENS, + EXO_NUM_DRAFT_TOKENS, BatchGenerator, SequentialGenerator, + parse_env_int, ) from exo.worker.runner.llm_inference.tool_parsers import make_mlx_parser from .cache import KVPrefixCache +from .generator.drafter import EXO_DRAFT_MODE_ENV, parse_draft_mode from .types import Model from .utils_mlx import ( initialize_mlx, @@ -36,12 +46,42 @@ class MlxBuilder(Builder): cancel_receiver: MpReceiver[TaskId] inference_model: Model | None = None tokenizer: TokenizerWrapper | None = None + # ``group`` is the target ranks' ``mx.distributed.Group``: pipeline + # / tensor / batch collectives all run on it. Under the v3+ wire + # the drafter is NOT a member of this group (asymmetric drafters + # talk to target rank 0 over a TCP socket; see ``drafter_socket`` + # below). group: mx.distributed.Group | None = None + # Connected TCP socket from target rank 0 to the drafter rank. + # Set ONLY on target rank 0 of an asymmetric placement; ``None`` + # everywhere else (other target ranks don't drive drafter IPC, and + # single-device / symmetric multi-rank builds have no drafter + # wire at all). + drafter_socket: socket.socket | None = None + drafter_rank_in_parent: int | None = None + # Inter-target-rank TCP fanout for the spec-decode int-broadcast + # wire. Allocated by :func:`initialize_mlx` on multi-target + # asymmetric placements; ``None`` for single-target / symmetric + # builds. See :class:`TargetPeerFanout`. + target_peer_fanout: object | None = None vision_processor: VisionProcessor | None = None draft_model: Model | None = None + draft_model_id: ModelId | None = None def connect(self, bound_instance: BoundInstance) -> None: - self.group = initialize_mlx(bound_instance) + split = initialize_mlx(bound_instance) + self.group = split.target_subgroup + # Only target rank 0 in an asymmetric placement holds a drafter + # socket; every other rank sees ``None`` here. ``MlxGroupSplit`` + # types it as ``object | None`` to keep the dataclass importable + # without ``socket``; cast back to the concrete type for + # consumers. + if split.drafter_socket is not None: + self.drafter_socket = cast(socket.socket, split.drafter_socket) + else: + self.drafter_socket = None + self.drafter_rank_in_parent = split.drafter_rank_in_parent + self.target_peer_fanout = split.target_peer_fanout def load(self, bound_instance: BoundInstance) -> Generator[ModelLoadingResponse]: ( @@ -49,6 +89,7 @@ def load(self, bound_instance: BoundInstance) -> Generator[ModelLoadingResponse] self.tokenizer, self.vision_processor, self.draft_model, + self.draft_model_id, ) = yield from load_mlx_items(bound_instance, self.group) def close(self) -> None: @@ -58,6 +99,10 @@ def close(self) -> None: del self.tokenizer with contextlib.suppress(NameError, AttributeError): del self.group + if self.drafter_socket is not None: + with contextlib.suppress(OSError): + self.drafter_socket.close() + self.drafter_socket = None with contextlib.suppress(NameError, AttributeError): del self.draft_model @@ -85,25 +130,146 @@ def build( ) kv_prefix_cache = KVPrefixCache(self.group) + # Item 6: dedicated KVPrefixCache for the drafter so multi-turn + # workloads don't repeatedly prefill the drafter on the same prefix. + # Allocated only when a drafter is actually loaded; None means + # mlx_generate falls back to the per-request drafter prefill. + drafter_kv_prefix_cache: KVPrefixCache | None = ( + KVPrefixCache(self.group) if self.draft_model is not None else None + ) device_rank = 0 if self.group is None else self.group.rank() - # Speculative decoding currently only flows through `mlx_generate` -> - # `stream_generate(draft_model=...)`, which is the SequentialGenerator - # path. Upstream `mlx_lm.generate.BatchGenerator` does not accept a - # draft model. Force the sequential path when a drafter is loaded so - # the user actually gets speculative decoding instead of silently - # falling through to non-speculative batching. - force_sequential_for_drafter = self.draft_model is not None + # Speculative decoding (model or n-gram) currently flows only through + # SequentialGenerator -> mlx_generate. Upstream BatchGenerator does + # not accept a draft model and has no hook for n-gram drafting, so + # force the sequential path whenever speculative decoding could + # plausibly run for any request: a drafter model is loaded *or* + # ``EXO_DRAFT_MODE=ngram`` is set process-wide. Per-request + # overrides (``TaskParams.draft_mode``) only apply within the + # surface that the chosen generator exposes. + configured_draft_mode = parse_draft_mode( + os.environ.get(EXO_DRAFT_MODE_ENV), + default="model" if self.draft_model is not None else "none", + ) + force_sequential_for_drafter = ( + self.draft_model is not None + or configured_draft_mode in ("ngram", "pipelined") + ) + + # Asymmetric placement: drafter lives on a separate node; only + # target rank 0 owns the drafter wire (``drafter_socket``). + # Force the SequentialGenerator path (BatchGenerator has no + # spec-decoding hook) and build a long-lived RemoteTransport + # that the spec loop reuses across requests. + # + # Other target ranks in an asymmetric placement (rank >= 1) see + # ``drafter_socket is None`` and treat their build the same as + # symmetric multi-rank: they participate in target collectives + # but never call drafter ops directly. The spec loop's + # rank-0-only sampling decision keeps that invariant. + is_asymmetric_target_rank_zero = self.drafter_socket is not None + # Long-lived ``RemoteTransport`` (NOT a per-task DrafterTransport). + # Each in-flight request opens its own session via + # :meth:`RemoteTransport.open_session`; the session handle is the + # actual DrafterTransport consumed by the spec loop. See + # ``remote_drafter.py`` module docstring for the wire-protocol + # session multiplexing rationale. + from exo.worker.engines.mlx.generator.remote_drafter import RemoteTransport + + remote_drafter_transport: RemoteTransport | None = None + if is_asymmetric_target_rank_zero: + assert self.drafter_socket is not None + from exo.worker.engines.mlx.generator.remote_drafter import ( + make_remote_transport, + ) + + num_draft_tokens_remote = parse_env_int( + EXO_NUM_DRAFT_TOKENS, DEFAULT_NUM_DRAFT_TOKENS + ) + target_world_size = self.group.size() if self.group is not None else 1 + logger.info( + "Allocating long-lived RemoteTransport: " + f"target_world_size={target_world_size} " + f"drafter_rank={self.drafter_rank_in_parent} " + f"K={num_draft_tokens_remote} " + f"transport=tcp_socket" + ) + remote_drafter_transport = make_remote_transport( + draft_model=None, + draft_cache=None, + num_draft_tokens=num_draft_tokens_remote, + sock=self.drafter_socket, + ) + + # Asymmetric "is the cluster speculative-decoding-aware" check. + # Used below to force ``SequentialGenerator`` and to log mode + # selection. Non-zero ranks of an asymmetric instance do NOT + # set this flag (they don't own the wire) but they still enter + # the same generator path because the placement-time decision + # to enable the drafter is uniform across target ranks. + is_asymmetric = ( + is_asymmetric_target_rank_zero or self.drafter_rank_in_parent is not None + ) - if os.environ.get("EXO_NO_BATCH") or force_sequential_for_drafter: - if force_sequential_for_drafter: + if ( + os.environ.get("EXO_NO_BATCH") + or force_sequential_for_drafter + or is_asymmetric + ): + if is_asymmetric: + logger.info( + "using SequentialGenerator (asymmetric placement: " + "drafter lives on a separate MLX rank, pipelined+remote spec)" + ) + elif force_sequential_for_drafter: logger.info( - "using SequentialGenerator (drafter loaded; " - "BatchGenerator does not support speculative decoding)" + f"using SequentialGenerator (draft_mode={configured_draft_mode!r}; " + f"BatchGenerator has no spec-decoding hook)" ) else: logger.info("using SequentialGenerator (batching disabled)") + + num_draft_tokens = parse_env_int( + EXO_NUM_DRAFT_TOKENS, DEFAULT_NUM_DRAFT_TOKENS + ) + drafter_min_output_tokens = parse_env_int( + EXO_DRAFTER_MIN_OUTPUT_TOKENS, + DEFAULT_DRAFTER_MIN_OUTPUT_TOKENS, + minimum=0, + ) + adaptive_draft_tokens = os.environ.get( + EXO_ADAPTIVE_DRAFT_TOKENS, "" + ).lower() in {"1", "true", "yes"} + if force_sequential_for_drafter or is_asymmetric: + logger.info( + f"speculative decoding: mode={'pipelined+remote' if is_asymmetric else configured_draft_mode}, " + f"K={num_draft_tokens} (adaptive={adaptive_draft_tokens}), " + f"skip_drafter_when_max_tokens<={drafter_min_output_tokens}" + ) + + # Concurrent in-flight tasks. Asymmetric pipelined+remote + # rides the same ``EXO_MAX_CONCURRENT_REQUESTS`` cap as every + # other config now that the wire protocol carries a + # ``session_id`` slot: each in-flight target request opens + # its own ``_SessionHandle`` via + # ``RemoteTransport.open_session()`` and the drafter rank + # multiplexes per-session KV caches. The wire stays serial + # (single ``ThreadPoolExecutor`` on the target, single recv + # loop on the drafter) so ``mx.distributed.send/recv`` + # ordering is preserved; concurrency comes from interleaving + # forward / verify rounds across sessions, which is the + # whole point of asymmetric placement -- keep the drafter + # rank busy serving session A while the target verifies + # session B's drafts. + max_concurrent_tasks = EXO_MAX_CONCURRENT_REQUESTS + if max_concurrent_tasks > 1: + logger.info( + f"SequentialGenerator round-robin concurrency: " + f"max_concurrent_tasks={max_concurrent_tasks} " + f"(EXO_MAX_CONCURRENT_REQUESTS)" + ) + return SequentialGenerator( model=self.inference_model, tokenizer=self.tokenizer, @@ -116,6 +282,15 @@ def build( event_sender=self.event_sender, vision_processor=vision_processor, draft_model=self.draft_model, + draft_model_id=self.draft_model_id, + drafter_kv_prefix_cache=drafter_kv_prefix_cache, + num_draft_tokens=num_draft_tokens, + drafter_min_output_tokens=drafter_min_output_tokens, + adaptive_draft_tokens=adaptive_draft_tokens, + drafter_rank_in_parent=self.drafter_rank_in_parent, + remote_drafter_transport=remote_drafter_transport, + target_peer_fanout=self.target_peer_fanout, + max_concurrent_tasks=max_concurrent_tasks, ) else: logger.info("using BatchGenerator") diff --git a/src/exo/worker/engines/mlx/cache.py b/src/exo/worker/engines/mlx/cache.py index 7cdcc77fbe..df8d7c5dc4 100644 --- a/src/exo/worker/engines/mlx/cache.py +++ b/src/exo/worker/engines/mlx/cache.py @@ -1,7 +1,7 @@ import gc import os from copy import deepcopy -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, cast import mlx.core as mx import numpy as np @@ -358,7 +358,12 @@ def get_kv_cache( best_index, best_length = i, length if best_index is None: - return make_kv_cache(model), prompt_tokens, None, False + return ( + make_kv_cache(model), + prompt_tokens, + None, + False, + ) # For exact match: trim to max_length-1 so remaining has the last token # For partial match: trim to best_length, remaining has suffix to prefill @@ -374,7 +379,12 @@ def get_kv_cache( # No usable snapshot — need fresh cache if restore_snap is None and has_ssm: - return make_kv_cache(model), prompt_tokens, None, False + return ( + make_kv_cache(model), + prompt_tokens, + None, + False, + ) prompt_cache = deepcopy(self.caches[best_index]) tokens_to_trim = cached_length - restore_pos @@ -558,13 +568,26 @@ def get_memory_used_percentage() -> float: def make_kv_cache( - model: Model, max_kv_size: int | None = None, keep: int = 0 + model: Model, + max_kv_size: int | None = None, + keep: int = 0, ) -> KVCacheType: + """Build a KV cache for ``model``. + + Honors the model's own ``make_cache()`` factory when available so each + architecture gets the cache layout it was designed for (e.g. Gemma 4 + returns a mix of ``RotatingKVCache`` for sliding-window layers and + ``KVCache`` for global-attention layers). This is exactly what + ``mlx_lm.speculative_generate_step`` expects when ``draft_model`` is + supplied -- it slices the supplied ``prompt_cache`` into target/drafter + halves of native shape and uses each model's own attention masks. + """ assert hasattr(model, "layers") if hasattr(model, "make_cache"): + native = cast(list[object], model.make_cache()) # type: ignore[reportAttributeAccessIssue] logger.info("Using MLX LM's make cache") - return model.make_cache() # type: ignore + return cast(KVCacheType, native) if max_kv_size is None: if KV_CACHE_BITS is None: diff --git a/src/exo/worker/engines/mlx/generator/drafter.py b/src/exo/worker/engines/mlx/generator/drafter.py new file mode 100644 index 0000000000..69c246f989 --- /dev/null +++ b/src/exo/worker/engines/mlx/generator/drafter.py @@ -0,0 +1,1174 @@ +"""Drafting strategies for speculative decoding. + +The mlx engine has historically supported one drafting mode: a smaller +"drafter" model paired with the target via +``mlx_lm.speculative_generate_step``. That mode (``DraftMode = "model"``) +is the right call for distributed pipeline-parallel runs, where every +generated token pays cross-device communication latency that the +drafter - sitting on a single device - amortises across many tokens. +On fast single-device inference (e.g. Mac Studio M3 Ultra + 4-bit 26B +target at ~76 tok/s), generation is memory-bandwidth-bound and the +``K + 1``-token verify forward costs nearly ``K + 1`` times a +single-token forward; speculative decoding only wins when the +acceptance fraction clears ``K / (K + 1)``, which most workloads don't. +Empirical measurements on that hardware show: + + * model-drafter spec is a net loss across every workload class + (-25% to -45% tps), even at 65-75% acceptance. + * n-gram spec is roughly parity on echo-shaped prompts (-0.5%) and + a 20-30% loss on novel content where suffix matches are weak. + +Asymmetric (drafter on a separate node via RDMA/TCP) and EAGLE / lookahead +hit the same wall on Apple Silicon for a structural reason: ``mlx_lm`` +derives every position's RoPE id from ``KVCache.offset`` (a single +``int``), so the multi-position-per-step verify that gives speculative +decoding its CUDA wins (3.3-6.5x for EAGLE-3, 1.5-2.5x for lookahead) +collapses to a *linear* verify on Metal. Track upstream +`ml-explore/mlx-lm#846 `_ +and `ml-explore/mlx-lm#250 +`_ before investing +in EAGLE / lookahead runtime work; the scaffolding lives here so the +seam is ready when the upstream blocker lifts. A community MLX EAGLE-3 +prototype on M3 Ultra confirms the ceiling at 1.05x today (mlx-lm +discussion #890). + +The right call there is ``DraftMode = "none"`` (the default). +``"ngram"`` and ``"model"`` are exposed for slower-target regimes +(distributed inference, larger FP16 models, ASIC-bound targets) where +their economics flip: opt-in via ``EXO_DRAFT_MODE`` env var or per- +request ``TaskParams.draft_mode``. + +This module exposes a small ``Drafter`` protocol so ``mlx_generate`` can +dispatch on mode without sprouting branches everywhere, plus three +concrete implementations: + +* :class:`NoSpecDrafter` — pass-through to ``mlx_lm.stream_generate``. +* :class:`ModelDrafter` — wraps ``mlx_lm.stream_generate(draft_model=...)``. +* :class:`NgramDrafter` — owns its own spec loop; proposes draft tokens + by suffix-matching the running context against itself. + +The protocol intentionally lives at the *stream factory* level (not at a +finer-grained ``propose / accept`` level), so the well-tested upstream +spec loop keeps owning the model-drafter path. Future additions +(EAGLE/Medusa heads, lookahead with n-gram + Jacobi, drafter-on-other- +device) plug in by adding a new concrete drafter that yields +``GenerationResponse`` the same way ``stream_generate`` does. +""" + +from __future__ import annotations + +import os +import time +from typing import ( + Callable, + Final, + Generator, + Literal, + Protocol, + Sequence, + cast, + final, + runtime_checkable, +) + +import mlx.core as mx +from mlx_lm.generate import GenerationResponse, stream_generate +from mlx_lm.models.cache import trim_prompt_cache as mlx_trim_prompt_cache +from mlx_lm.tokenizer_utils import TokenizerWrapper + +from exo.worker.engines.mlx.constants import KV_BITS, KV_GROUP_SIZE +from exo.worker.engines.mlx.types import KVCacheType, Model +from exo.worker.runner.bootstrap import logger + + +def _get_eos_ids(tokenizer: TokenizerWrapper) -> list[int]: + """Tokenizer-agnostic EOS lookup matching ``eos_ids_from_tokenizer``.""" + eos: list[int] | None = getattr(tokenizer, "eos_token_ids", None) + if eos is None: + return [] + return eos + + +DraftMode = Literal["model", "pipelined", "ngram", "eagle", "lookahead", "none"] +"""How to source draft tokens for speculative decoding. + +* ``"model"``: small distilled drafter (e.g. Gemma-4 e2b/e4b) via + ``mlx_lm.speculative_generate_step``. Best for slow targets and + distributed pipeline-parallel where token latency is dominated by + cross-device communication. On fast single-device inference this is + frequently a net loss; benchmark before defaulting to it. +* ``"pipelined"``: same drafter model as ``"model"``, but routed + through :class:`exo.worker.engines.mlx.generator.pipelined_drafter + .PipelinedModelDrafter` -- a custom spec loop with cross-round + speculation (drafter forward for round ``t + 1`` overlaps target + verify of round ``t``). The transport layer (in-process or remote) + is selected by ``EXO_DRAFTER_TRANSPORT``; remote (RDMA/TCP via + ``mx.distributed.send/recv``) is the regime where the pipelining + win is unambiguous. +* ``"ngram"``: propose drafts by matching the longest suffix of the + running token context against earlier positions in the same context. + Zero drafter compute, no extra KV cache, no warmup. Wins on prompts + the model echoes (RAG, summarisation, structured/code output); + gracefully degrades to baseline when no match is found. +* ``"eagle"``: tiny auxiliary network conditioned on the target's + hidden states (EAGLE / EAGLE-2). Reuses the target's KV cache, + no second model load. Reported 2-3x wins in the literature versus + bare model-drafter on dense targets. **NOT YET IMPLEMENTED** -- the + scaffolding (factory dispatch, ``EagleDrafter`` shell) ships in this + PR so a follow-up only has to fill in the auxiliary head + tree + decoding loop. See :class:`EagleDrafter` for the integration seam. +* ``"lookahead"``: lookahead decoding (Fu et al. 2024). Uses the + target's own forward pass at multiple time-steps to produce n-gram + candidates via Jacobi iteration, no auxiliary model and no extra + weights. Composable with ``"ngram"`` -- the lookahead lookup table + acts as a richer source for the n-gram drafter. **NOT YET + IMPLEMENTED** -- the scaffolding ships in this PR; see + :class:`LookaheadDrafter`. +* ``"none"``: standard non-speculative generation. +""" + +ALL_DRAFT_MODES: Final[tuple[DraftMode, ...]] = ( + "model", + "pipelined", + "ngram", + "eagle", + "lookahead", + "none", +) + +EXO_DRAFT_MODE_ENV: Final[str] = "EXO_DRAFT_MODE" +"""Process-wide default mode. Per-request ``TaskParams`` overrides take precedence.""" + + +def parse_draft_mode(raw: str | None, default: DraftMode) -> DraftMode: + """Parse an ``EXO_DRAFT_MODE`` value, falling back on unknown values.""" + if raw is None: + return default + candidate = raw.strip().lower() + if candidate == "model": + return "model" + if candidate == "pipelined": + return "pipelined" + if candidate == "ngram": + return "ngram" + if candidate == "eagle": + return "eagle" + if candidate == "lookahead": + return "lookahead" + if candidate == "none": + return "none" + logger.warning( + f"{EXO_DRAFT_MODE_ENV}={raw!r} not in {ALL_DRAFT_MODES}; falling back to {default!r}" + ) + return default + + +def resolve_draft_mode( + *, + has_drafter_model: bool, + request_use_drafter: bool | None, + request_draft_mode: DraftMode | None, +) -> DraftMode: + """Compute the effective drafting mode for one request. + + Precedence (highest first): + 1. ``request_draft_mode`` — explicit per-request mode override. + 2. ``request_use_drafter is False`` — opt-out shortcut maps to ``"none"``. + 3. ``EXO_DRAFT_MODE`` env var if recognised. + 4. Implicit default: ``"model"`` if a drafter model was loaded, + else ``"none"``. ``"ngram"`` and ``"pipelined"`` are opt-in; + we don't auto-promote because their wins are topology-dependent + (``"pipelined"``'s gain unlocks at remote-transport scale). + + A ``"model"`` or ``"pipelined"`` mode without a loaded drafter + degrades to ``"none"`` with a warning, so misconfiguration fails + loudly instead of silently producing the wrong throughput. + """ + if request_draft_mode is not None: + chosen: DraftMode = request_draft_mode + elif request_use_drafter is False: + chosen = "none" + else: + env_default: DraftMode = "model" if has_drafter_model else "none" + chosen = parse_draft_mode(os.environ.get(EXO_DRAFT_MODE_ENV), env_default) + + if chosen in ("model", "pipelined") and not has_drafter_model: + logger.warning( + f"draft_mode={chosen!r} requested but no drafter model is " + "loaded; falling back to 'none'." + ) + return "none" + return chosen + + +@runtime_checkable +class Drafter(Protocol): + """Stream factory that runs one generation with a chosen drafting strategy. + + Concrete drafters yield :class:`mlx_lm.generate.GenerationResponse` + identically to ``mlx_lm.stream_generate``, so the call site in + ``mlx_generate`` doesn't change shape across modes. + """ + + @property + def mode(self) -> DraftMode: + """The mode this drafter implements (matches :data:`DraftMode`).""" + ... + + def stream( + self, + *, + model: Model, + tokenizer: TokenizerWrapper, + prompt: mx.array, + context_tokens: Sequence[int], + prompt_cache: KVCacheType, + max_tokens: int, + sampler: Callable[[mx.array], mx.array], + logits_processors: Sequence[Callable[[mx.array, mx.array], mx.array]], + prefill_step_size: int = 1, + ) -> Generator[GenerationResponse, None, None]: + """Generate tokens against ``model``. + + Args: + prompt: Prefill-tail (the last 2 prompt tokens). The caller + has pre-aligned ``prompt_cache`` to ``full_prompt[:-2]`` + via ``exo.prefill`` + ``trim(2)``; ``mlx_lm``'s + internal ``_prefill`` advances the cache by one more + token, and the drafter's spec loop seeds from the last. + context_tokens: Full prompt as a list of token ids. Used by + drafters that need the complete history for proposals + (``NgramDrafter``); other drafters ignore it. + prompt_cache: Target KV cache, pre-aligned per ``prompt`` above. + max_tokens: Maximum tokens to generate (including drafter- + accepted tokens). + sampler: ``logprobs -> token`` sampler. + logits_processors: Per-position logits processors (repetition + penalty, etc.). The drafter applies them before sampling. + prefill_step_size: Forwarded to ``mlx_lm._prefill``. + """ + ... + + +@final +class NoSpecDrafter: + """Standard non-speculative decoding via ``mlx_lm.stream_generate``.""" + + @property + def mode(self) -> DraftMode: + return "none" + + def stream( + self, + *, + model: Model, + tokenizer: TokenizerWrapper, + prompt: mx.array, + context_tokens: Sequence[int], + prompt_cache: KVCacheType, + max_tokens: int, + sampler: Callable[[mx.array], mx.array], + logits_processors: Sequence[Callable[[mx.array, mx.array], mx.array]], + prefill_step_size: int = 1, + ) -> Generator[GenerationResponse, None, None]: + del context_tokens # only the n-gram drafter needs it + yield from stream_generate( + model=model, + tokenizer=tokenizer, + prompt=prompt, + max_tokens=max_tokens, + sampler=sampler, + logits_processors=list(logits_processors), + prompt_cache=list(prompt_cache), + prefill_step_size=prefill_step_size, + kv_group_size=KV_GROUP_SIZE, + kv_bits=KV_BITS, + ) + + +@final +class ModelDrafter: + """Speculative decoding via a smaller distilled drafter model. + + Delegates to ``mlx_lm.stream_generate(draft_model=...)`` so the + well-tested upstream spec loop owns the rejection sampling, cache + trimming, and bonus-token bookkeeping. The target and drafter caches + must already be aligned to the same offset (handled by + ``mlx_generate`` via ``exo.prefill`` + ``_spec_drafter_prefill``). + """ + + def __init__( + self, + *, + draft_model: Model, + draft_cache: KVCacheType, + num_draft_tokens: int, + ) -> None: + if num_draft_tokens < 1: + raise ValueError(f"num_draft_tokens must be >= 1, got {num_draft_tokens}") + self._draft_model = draft_model + self._draft_cache = draft_cache + self._num_draft_tokens = num_draft_tokens + + @property + def mode(self) -> DraftMode: + return "model" + + @property + def num_draft_tokens(self) -> int: + return self._num_draft_tokens + + def stream( + self, + *, + model: Model, + tokenizer: TokenizerWrapper, + prompt: mx.array, + context_tokens: Sequence[int], + prompt_cache: KVCacheType, + max_tokens: int, + sampler: Callable[[mx.array], mx.array], + logits_processors: Sequence[Callable[[mx.array, mx.array], mx.array]], + prefill_step_size: int = 1, + ) -> Generator[GenerationResponse, None, None]: + del context_tokens # mlx_lm spec_step manages its own context + # mlx_lm splits prompt_cache as ``[: len(model.layers)]`` for the + # target and ``[len(model.layers) :]`` for the drafter, so we just + # concatenate native cache lists here. + decode_cache = list(prompt_cache) + list(self._draft_cache) + yield from stream_generate( + model=model, + tokenizer=tokenizer, + prompt=prompt, + max_tokens=max_tokens, + sampler=sampler, + logits_processors=list(logits_processors), + prompt_cache=decode_cache, + prefill_step_size=prefill_step_size, + kv_group_size=KV_GROUP_SIZE, + kv_bits=KV_BITS, + draft_model=self._draft_model, + num_draft_tokens=self._num_draft_tokens, + ) + + +@final +class NgramDrafter: + """Speculative decoding using in-context n-gram lookup. + + Each spec round looks for the longest suffix (length in + ``[min_match, max_match]``) of the running token context that + appeared earlier in the same context, and proposes a continuation + drawn from the tokens that followed it last time. This is the + classic "prompt-suffix lookup drafter" used by vLLM + (``--speculative-model='[ngram]'``) and SGLang + (``--draft-model n-gram``). + + Match-strength-adaptive K + ------------------------- + A short (length-``min_match``) match is weak evidence that the + *next* ``num_draft_tokens`` tokens repeat - it's just two tokens of + overlap, often coincidental. A long match (length ``max_match``+) + is strong evidence: the model is genuinely re-emitting a prior + span. We bias proposal length to match strength via + ``K_eff = min(num_draft_tokens, match_length)``; that way short + matches propose few drafts (cheap verify), long matches propose + many (worth the verify cost). Disable by setting + ``adaptive_k=False`` to always issue ``num_draft_tokens`` drafts + when any match is found. + + Cost model: O(context * max_match) per proposal in pure Python - + microseconds for chats up to a few thousand tokens, zero MLX work, + zero KV cache, zero warmup. When no match is found we fall through + to a single-token target step, so worst-case throughput equals the + no-drafter baseline. + """ + + def __init__( + self, + *, + num_draft_tokens: int, + max_match: int = 4, + min_match: int = 2, + adaptive_k: bool = True, + ) -> None: + if num_draft_tokens < 1: + raise ValueError(f"num_draft_tokens must be >= 1, got {num_draft_tokens}") + if min_match < 1: + raise ValueError(f"min_match must be >= 1, got {min_match}") + if max_match < min_match: + raise ValueError( + f"max_match ({max_match}) must be >= min_match ({min_match})" + ) + self._num_draft_tokens = num_draft_tokens + self._max_match = max_match + self._min_match = min_match + self._adaptive_k = adaptive_k + + @property + def mode(self) -> DraftMode: + return "ngram" + + @property + def num_draft_tokens(self) -> int: + return self._num_draft_tokens + + def propose(self, context: Sequence[int], k: int) -> list[int]: + """Return up to ``k`` candidate continuations of ``context``. + + Returns an empty list if no suffix of length ``>= min_match`` + appears earlier in ``context``. The match is right-anchored at + ``context[-n:]`` (we don't search inside the suffix itself, to + avoid trivial self-overlap). When ``adaptive_k`` is enabled, + the proposal length is capped at the match length so weak + (short) matches don't trigger expensive K-token verifies. + """ + if k < 1 or len(context) < self._min_match + 1: + return [] + # Walk match length from longest to shortest, biasing toward + # stronger matches (and earlier exit on the first match). + upper = min(self._max_match, len(context) - 1) + for n in range(upper, self._min_match - 1, -1): + suffix = list(context[-n:]) + # Search backwards (most-recent match wins) through earlier + # positions; locality of reference means the model is most + # likely to repeat its recent self. + for start in range(len(context) - n - 1, -1, -1): + if list(context[start : start + n]) == suffix: + # Adaptive K: cap proposal length to match strength. + # Match length n -> at most n drafts (a 2-gram match + # gets 2 drafts; a 4-gram match gets up to 4). + cap = min(k, n) if self._adaptive_k else k + proposal = list(context[start + n : start + n + cap]) + return proposal + return [] + + def stream( + self, + *, + model: Model, + tokenizer: TokenizerWrapper, + prompt: mx.array, + context_tokens: Sequence[int], + prompt_cache: KVCacheType, + max_tokens: int, + sampler: Callable[[mx.array], mx.array], + logits_processors: Sequence[Callable[[mx.array, mx.array], mx.array]], + prefill_step_size: int = 1, + ) -> Generator[GenerationResponse, None, None]: + yield from _ngram_stream_generate( + model=model, + tokenizer=tokenizer, + prompt=prompt, + context_tokens=list(context_tokens), + prompt_cache=prompt_cache, + max_tokens=max_tokens, + sampler=sampler, + logits_processors=list(logits_processors), + drafter=self, + prefill_step_size=prefill_step_size, + ) + + +@final +class EagleDrafter: + """EAGLE / EAGLE-2 speculative decoder using a tiny auxiliary head. + + **Status: scaffolding only.** This class ships an explicit integration + seam so EAGLE can plug into the existing :class:`Drafter` factory + without churning call-sites in :mod:`generate` / :mod:`builder`. The + actual auxiliary-head load + tree decoding loop is intentionally + deferred -- a follow-up PR fills this in once we pick which EAGLE + variant to support (vanilla EAGLE, EAGLE-2 with dynamic tree, or + Hydra heads). + + Why this is an *additional* drafter and not a flag on + :class:`ModelDrafter`: + + * EAGLE's drafter needs the target's *last hidden state*, not just + the sampled token. The :class:`Drafter.stream` signature already + lets us read ``model``'s forward output, but EAGLE additionally + requires plumbing the hidden state out of the target's forward + pass. That's a target-engine change, not a drafter change. + * EAGLE-2 uses a tree of draft tokens rather than a single chain; + verifying a tree requires ``mlx_lm.tree_verify_step`` (does not + yet exist) or an in-house verifier. Plug-in point: a new method + on this class that returns ``(token_tree, parent_indices)``, + consumed by a tree-aware verify loop in :func:`stream`. + * Tree verification is also what Medusa needs, so factoring the + verifier into a separate ``TreeVerifier`` class lets EAGLE + + Medusa share it. + + Recommended config surface (when filling this in): + + * ``eagle_head_repo``: HuggingFace repo for the auxiliary head + weights, surfaced in :class:`exo.shared.models.types.ModelCard` + alongside ``drafter_model_ids`` (probably a new + ``eagle_head_ids: list[str]``). + * ``num_draft_tokens``: tree depth for EAGLE-2 (vanilla EAGLE is + depth-only and can reuse the existing ``K`` knob). + * ``tree_branching``: per-level branching for EAGLE-2 (e.g. + ``[4, 2, 2, 2]``); ignored by vanilla EAGLE. + + Until the implementation lands, ``stream`` raises + :class:`NotImplementedError` so misconfiguration fails loudly. The + factory in :func:`make_drafter` checks for the head being loaded; + if not, it logs and falls back to ``"none"`` (mirrors the + ``"model"`` -> ``"none"`` degradation when no drafter is loaded). + + **Apple Silicon ceiling (read before implementing).** The CUDA + literature (3.3-6.5x on the EAGLE-3 paper; 1.72x on the RedHat + Gemma-4-31B EAGLE3 head on 8x H200) gets its win from *tree* + verification: dozens of candidate continuations verified in a + single batched forward where memory bandwidth, not arithmetic, + sets the cost. On Apple Silicon a single sibling-position verify + is *not* free because Metal's command queue serialises GPU work + per device and ``mlx_lm`` derives every position's RoPE id from + ``KVCache.offset`` (a single ``int``), so two siblings at the + same depth cannot get different RoPE positions in the same + forward. Until ``mlx_lm`` accepts ``position_ids`` (open issues + `ml-explore/mlx-lm#846 `_, + `ml-explore/mlx-lm#250 `_), + a faithful EAGLE port collapses to a *linear* verify, which a + community prototype (`mlx-lm discussion #890 + `_) measured + at **1.05x** on LLaMA-3.1-8B-4bit on an M3 Ultra -- inside the + noise of our own n-gram K-sweep on this hardware. Don't ship the + EAGLE runtime until the position-id seam lands upstream; the + converter (offline tool) is fine to ship now since the artifact + is durable. + + Concrete artifacts to consume when picking this up: + + * Pre-trained head for our exact target: + ``RedHatAI/gemma-4-26B-A4B-it-speculator.eagle3`` + (released 2026-04-13, ~277 MB). + * Reference MLX port (Llama-3.1 only, no Gemma-4 architecture + adapter, no tree verify): the gist linked from mlx-lm + discussion #890 above. ``eagle_convert.py`` is reusable; + ``eagle_generate.py`` is the loop to fork. + * For Gemma-4 specifically the EAGLE head shape is + ``num_hidden_layers=1`` with ``input_size = 2 * hidden_size`` + (Q/K/V take ``[token_embedding, fused_features]`` concatenated) + and a reduced 32k draft vocabulary -- same as the Llama variant, + so the Gemma adaptation is mostly the layer-tap indices + (Gemma-4-26b is N=30 layers, so taps go at ``{2, 15, 27}`` + following EAGLE's ``{2, N//2, N-3}`` heuristic). + """ + + def __init__( + self, + *, + eagle_head: object | None, + num_draft_tokens: int, + tree_branching: tuple[int, ...] | None = None, + ) -> None: + if num_draft_tokens < 1: + raise ValueError(f"num_draft_tokens must be >= 1, got {num_draft_tokens}") + self._eagle_head = eagle_head + self._num_draft_tokens = num_draft_tokens + self._tree_branching = tree_branching + + @property + def mode(self) -> DraftMode: + return "eagle" + + @property + def num_draft_tokens(self) -> int: + return self._num_draft_tokens + + @property + def tree_branching(self) -> tuple[int, ...] | None: + return self._tree_branching + + def stream( + self, + *, + model: Model, + tokenizer: TokenizerWrapper, + prompt: mx.array, + context_tokens: Sequence[int], + prompt_cache: KVCacheType, + max_tokens: int, + sampler: Callable[[mx.array], mx.array], + logits_processors: Sequence[Callable[[mx.array, mx.array], mx.array]], + prefill_step_size: int = 1, + ) -> Generator[GenerationResponse, None, None]: + del ( + model, + tokenizer, + prompt, + context_tokens, + prompt_cache, + max_tokens, + sampler, + logits_processors, + prefill_step_size, + ) + raise NotImplementedError( + "EagleDrafter is a scaffolding stub. Implement the auxiliary-" + "head forward + tree verify loop here. The Drafter protocol " + "and factory dispatch are in place; the missing pieces are " + "(1) loading EAGLE head weights (probably a new " + "ModelCard.eagle_head_ids field), (2) plumbing the target's " + "last hidden state out of the verify forward, and (3) a tree-" + "aware verifier (shareable with future Medusa support). See " + "the class docstring for the recommended factoring." + ) + yield # pragma: no cover -- keeps the function a generator. + + +@final +class LookaheadDrafter: + """Lookahead decoding (Fu et al. 2024) using the target's own forward. + + **Status: scaffolding only.** Plug-in point shipped so a follow-up + can fill in the Jacobi iteration loop without changing call sites. + + Lookahead decoding builds an n-gram candidate pool from intermediate + Jacobi-iteration outputs of the target itself: each generation step + runs the target on a window of ``window_size`` positions and seeds + an n-gram lookup table from the result. The next step queries the + table for candidates, verifies them in parallel via a single target + forward, and updates the table. No auxiliary model, no extra + weights. + + Composability with :class:`NgramDrafter`: the lookahead lookup + table is the same shape as the n-gram drafter's suffix lookup, + just populated by Jacobi rather than context history. A natural + factoring is to share the ``propose(context, k)`` interface with + :class:`NgramDrafter` and have :class:`LookaheadDrafter` swap the + proposal source at runtime; that lets ``"ngram"`` and + ``"lookahead"`` share the verify loop. Recommended seam: + + * Extract :meth:`NgramDrafter.propose` to a shared + ``NgramProposer`` Protocol with two impls (``SuffixProposer``, + ``LookaheadProposer``). + * :func:`_ngram_speculative_step` takes the proposer rather than + the concrete :class:`NgramDrafter`, picks one based on + :data:`DraftMode`. + + Config surface: + + * ``num_draft_tokens``: K (max chain length per round). + * ``window_size``: Jacobi window width per step. Larger windows + seed more n-grams but cost a wider verify forward. + * ``ngram_size``: size of the n-grams stored in the lookup table + (typically 2-4). + + Until implemented, ``stream`` raises :class:`NotImplementedError`. + + **Same Apple Silicon ceiling as :class:`EagleDrafter`.** Lookahead + decoding's win comes from verifying *multiple* Jacobi-seeded + candidate continuations in parallel, which collapses to linear + verify under the same ``KVCache.offset`` / ``position_ids`` + constraint described on :class:`EagleDrafter`. On the + ``gemma-4-26b-a4b-it-4bit`` target measured here (119 t/s + baseline), the n-gram drafter -- which shares the linear-verify + cost model lookahead would inherit -- lands at 92-102 t/s + across K=2..8 (a 14-23% net loss). Implementing lookahead before + ``position_ids`` lands upstream is unlikely to flip that sign. + Track the same upstream issues + (`ml-explore/mlx-lm#846 `_, + `ml-explore/mlx-lm#250 `_) + before investing in the implementation. + """ + + def __init__( + self, + *, + num_draft_tokens: int, + window_size: int = 5, + ngram_size: int = 3, + ) -> None: + if num_draft_tokens < 1: + raise ValueError(f"num_draft_tokens must be >= 1, got {num_draft_tokens}") + if window_size < 1: + raise ValueError(f"window_size must be >= 1, got {window_size}") + if ngram_size < 2: + raise ValueError(f"ngram_size must be >= 2, got {ngram_size}") + self._num_draft_tokens = num_draft_tokens + self._window_size = window_size + self._ngram_size = ngram_size + + @property + def mode(self) -> DraftMode: + return "lookahead" + + @property + def num_draft_tokens(self) -> int: + return self._num_draft_tokens + + @property + def window_size(self) -> int: + return self._window_size + + @property + def ngram_size(self) -> int: + return self._ngram_size + + def stream( + self, + *, + model: Model, + tokenizer: TokenizerWrapper, + prompt: mx.array, + context_tokens: Sequence[int], + prompt_cache: KVCacheType, + max_tokens: int, + sampler: Callable[[mx.array], mx.array], + logits_processors: Sequence[Callable[[mx.array, mx.array], mx.array]], + prefill_step_size: int = 1, + ) -> Generator[GenerationResponse, None, None]: + del ( + model, + tokenizer, + prompt, + context_tokens, + prompt_cache, + max_tokens, + sampler, + logits_processors, + prefill_step_size, + ) + raise NotImplementedError( + "LookaheadDrafter is a scaffolding stub. Implement the Jacobi " + "iteration + n-gram lookup table here. Recommended factoring: " + "extract NgramDrafter.propose into a shared NgramProposer " + "Protocol with SuffixProposer and LookaheadProposer impls so " + "this drafter and NgramDrafter share the verify loop. See " + "the class docstring." + ) + yield # pragma: no cover -- keeps the function a generator. + + +def make_drafter( + *, + mode: DraftMode, + num_draft_tokens: int, + draft_model: Model | None, + draft_cache: KVCacheType | None, + target_subgroup_size: int = 1, + pipelined_transport: object | None = None, + target_group: object | None = None, + target_peer_fanout: object | None = None, + is_target_root: bool = True, +) -> Drafter: + """Build a :class:`Drafter` for the resolved mode. + + Raises ``ValueError`` if ``mode in ("model", "pipelined")`` is + requested without a loaded drafter; callers should resolve that via + :func:`resolve_draft_mode` (which downgrades silently). + + For ``mode == "pipelined"`` the transport is selected as: + + * The supplied ``pipelined_transport`` (asymmetric placement: + the runner bootstrap allocates a long-lived ``RemoteTransport`` + bound to the drafter socket and the spec loop opens a + per-request session view of it). ``draft_model`` / + ``draft_cache`` are ignored on the target rank in this path. + * Otherwise an in-process transport built from the supplied + ``draft_model`` / ``draft_cache`` (single-process pipelining + win, no remote IPC). + + Multi-target asymmetric (``target_subgroup_size > 1``) is V2: only + the target root (``is_target_root``) holds the transport; non-root + target ranks construct a transport-less :class:`PipelinedModelDrafter` + and consume each round's drafts via a rank-0 broadcast on + ``target_group``. Both ranks then run the verify forward in TP + lockstep. Requires the caller to pass ``target_group`` (the + target-only :class:`mx.distributed.Group`) and the rank's + ``is_target_root`` flag. + """ + if mode == "none": + return NoSpecDrafter() + if mode == "ngram": + return NgramDrafter(num_draft_tokens=num_draft_tokens) + if mode == "eagle": + # Scaffold path; the runner-side bootstrap doesn't load EAGLE heads + # yet, so the head is always None today and the constructor builds + # a stub that raises on ``stream``. ``resolve_draft_mode`` should + # downgrade to ``"none"`` once an analogous ``has_eagle_head`` flag + # is wired through; until then the stub error makes misuse obvious. + return EagleDrafter(eagle_head=None, num_draft_tokens=num_draft_tokens) + if mode == "lookahead": + # Scaffold path; uses target weights only, no extra load needed. + # Stub raises on ``stream`` until the Jacobi loop lands. + return LookaheadDrafter(num_draft_tokens=num_draft_tokens) + if mode == "model": + if draft_model is None or draft_cache is None: + raise ValueError( + "draft_mode='model' requires both draft_model and draft_cache" + ) + return ModelDrafter( + draft_model=draft_model, + draft_cache=draft_cache, + num_draft_tokens=num_draft_tokens, + ) + if mode == "pipelined": + # Imported here to keep the module's import surface minimal in + # the common (model/ngram/none) paths. + from exo.worker.engines.mlx.generator.drafter_transport import ( + DrafterTransport, + make_inprocess_transport, + ) + from exo.worker.engines.mlx.generator.pipelined_drafter import ( + PipelinedModelDrafter, + ) + from exo.worker.engines.mlx.utils_mlx import TargetPeerFanout + + # Validate target_peer_fanout shape early so a malformed + # caller fails here, not deep inside the spec loop's broadcast + # helpers. ``None`` is fine on every path (single-rank / + # symmetric / test fakes); the broadcast helpers fall back to + # ``mx_broadcast_int_list`` in that case. + if target_peer_fanout is not None and not isinstance( + target_peer_fanout, TargetPeerFanout + ): + raise TypeError( + "target_peer_fanout must be TargetPeerFanout | None; " + f"got {type(target_peer_fanout).__name__}" + ) + + # Multi-target asymmetric: non-root target ranks have no + # transport (the socket is rank-0-only) but must still drive the + # verify forward in TP lockstep. They construct a + # transport-less drafter that pulls each round's drafts from a + # rank-0 broadcast on ``target_group``. ``target_group`` is + # required when ``target_subgroup_size > 1`` so the broadcast + # reaches every rank; raising here is a louder failure than + # silently falling through to an in-process drafter on rank 1 + # (which would load the drafter weights twice and never agree + # on tokens with rank 0). + if pipelined_transport is None and target_subgroup_size > 1: + if target_group is None: + raise ValueError( + "draft_mode='pipelined' on a multi-target rank " + f"(target_subgroup_size={target_subgroup_size}) without " + "pipelined_transport requires target_group for the " + "draft broadcast (this rank is the consumer)" + ) + if is_target_root: + raise ValueError( + "is_target_root=True implies this rank owns the " + "drafter socket; pipelined_transport must be supplied" + ) + return PipelinedModelDrafter( + transport=None, + num_draft_tokens=num_draft_tokens, + target_group=cast("mx.distributed.Group | None", target_group), + target_peer_fanout=target_peer_fanout, + is_target_root=False, + ) + + if pipelined_transport is not None: + # Caller supplied a long-lived transport (asymmetric path: + # SequentialGenerator allocates the RemoteTransport once at + # build time and reuses it across requests). Validate it + # implements the protocol and skip the factory dance below. + if not isinstance(pipelined_transport, DrafterTransport): + raise TypeError( + "pipelined_transport must implement DrafterTransport; " + f"got {type(pipelined_transport).__name__}" + ) + if target_subgroup_size > 1 and target_group is None: + raise ValueError( + "Asymmetric drafter with target_subgroup_size=" + f"{target_subgroup_size} requires target_group for " + "the rank-0 -> peer-target broadcast of drafts each " + "round; V1 single-target paths can pass target_group=None" + ) + return PipelinedModelDrafter( + transport=pipelined_transport, + num_draft_tokens=num_draft_tokens, + target_group=cast("mx.distributed.Group | None", target_group) + if target_subgroup_size > 1 + else None, + target_peer_fanout=target_peer_fanout + if target_subgroup_size > 1 + else None, + is_target_root=True, + ) + + # No builder-supplied transport, single target rank: in-process + # is the only sensible default. Asymmetric multi-target was + # handled above (consumer rank). Reaching here means a single- + # process pipelined drafter (no distributed group, drafter + # weights live in this same process). + if draft_model is None or draft_cache is None: + raise ValueError( + "draft_mode='pipelined' without a builder-supplied " + "transport requires both draft_model and draft_cache" + ) + constructed = make_inprocess_transport( + draft_model=draft_model, + draft_cache=draft_cache, + num_draft_tokens=num_draft_tokens, + ) + return PipelinedModelDrafter( + transport=constructed, + num_draft_tokens=num_draft_tokens, + ) + # Exhaustiveness: DraftMode is a closed Literal. Any other value is a + # programming error at the call site, so raise loudly. + raise ValueError(f"Unknown DraftMode: {mode!r}") + + +def _ngram_stream_generate( + *, + model: Model, + tokenizer: TokenizerWrapper, + prompt: mx.array, + context_tokens: list[int], + prompt_cache: KVCacheType, + max_tokens: int, + sampler: Callable[[mx.array], mx.array], + logits_processors: list[Callable[[mx.array, mx.array], mx.array]], + drafter: NgramDrafter, + prefill_step_size: int, +) -> Generator[GenerationResponse, None, None]: + """Mirror of ``mlx_lm.stream_generate`` for the n-gram drafter. + + Replicates only the framing (detokenisation, tps tracking, finish + reasons) that ``mlx_lm.stream_generate`` does for the model-drafter + path; the actual spec loop is :func:`_ngram_speculative_step`. + ``prompt`` is the prefill-tail (size 2 in production, but any size + >=1 works); ``context_tokens`` is the full prompt as a Python list + (used for n-gram lookups, not fed to the model). + """ + detokenizer = tokenizer.detokenizer + detokenizer.reset() # type: ignore[reportUnknownMemberType] + eos_ids = _get_eos_ids(tokenizer) + + token_iter = _ngram_speculative_step( + prompt=prompt, + context_tokens=context_tokens, + model=model, + drafter=drafter, + prompt_cache=prompt_cache, + max_tokens=max_tokens, + sampler=sampler, + logits_processors=logits_processors, + prefill_step_size=prefill_step_size, + ) + + # Telemetry: report the *full* prompt size (which is what the user + # paid prefill on), not the prefill-tail we were handed. + prompt_size = len(context_tokens) + + tic = time.perf_counter() + prompt_tps = 0.0 + n = -1 + token = 0 + logprobs = mx.zeros((1,)) + from_draft = False + finish_reason: str | None = None + for n, (token, logprobs, from_draft) in enumerate(token_iter): + if n == 0: + prompt_time = time.perf_counter() - tic + prompt_tps = prompt_size / prompt_time if prompt_time > 0 else 0.0 + tic = time.perf_counter() + if token in eos_ids: + finish_reason = "stop" + break + detokenizer.add_token(token) # type: ignore[reportUnknownMemberType] + if (n + 1) == max_tokens: + finish_reason = "length" + break + elapsed = time.perf_counter() - tic + yield GenerationResponse( + text=detokenizer.last_segment, + token=token, + logprobs=logprobs, + from_draft=from_draft, + prompt_tokens=prompt_size, + prompt_tps=prompt_tps, + generation_tokens=n + 1, + generation_tps=(n + 1) / elapsed if elapsed > 0 else 0.0, + peak_memory=mx.get_peak_memory() / 1e9, + finish_reason=None, + ) + + detokenizer.finalize() # type: ignore[reportUnknownMemberType] + elapsed = time.perf_counter() - tic + yield GenerationResponse( + text=detokenizer.last_segment, + token=token, + logprobs=logprobs, + from_draft=from_draft, + prompt_tokens=prompt_size, + prompt_tps=prompt_tps, + generation_tokens=n + 1 if n >= 0 else 0, + generation_tps=(n + 1) / elapsed if elapsed > 0 and n >= 0 else 0.0, + peak_memory=mx.get_peak_memory() / 1e9, + finish_reason=finish_reason or ("stop" if token in eos_ids else "length"), + ) + + +def _process_logits_for_position( + raw_logits: mx.array, + prev_tokens: mx.array, + logits_processors: list[Callable[[mx.array, mx.array], mx.array]], +) -> mx.array: + """Apply logits processors and convert to logprobs (single position). + + ``raw_logits`` has shape ``(vocab,)`` (already squeezed from a + ``(1, vocab)`` per-position slice). ``prev_tokens`` is the running + sequence of tokens emitted so far, used by repetition-penalty etc. + """ + out = raw_logits + for proc in logits_processors: + out = proc(prev_tokens, out) + return out - mx.logsumexp(out, axis=-1, keepdims=True) + + +def _ngram_speculative_step( + *, + prompt: mx.array, + context_tokens: list[int], + model: Model, + drafter: NgramDrafter, + prompt_cache: KVCacheType, + max_tokens: int, + sampler: Callable[[mx.array], mx.array], + logits_processors: list[Callable[[mx.array, mx.array], mx.array]], + prefill_step_size: int, +) -> Generator[tuple[int, mx.array, bool], None, None]: + """Custom speculative-decoding loop using an :class:`NgramDrafter`. + + Yields ``(token, logprobs, from_draft)`` tuples to match the shape + ``mlx_lm.stream_generate`` expects from its inner token generator. + + Algorithm (greedy accept; matches the temperature-0 case our warmup + and most code paths use): + + 1. Prefill: feed ``prompt[:-1]`` to ``model`` so the cache covers + the prompt minus its last token. + 2. Each round, ask the drafter for up to ``num_draft_tokens`` + candidates given the running context. + 3. Build a verify input ``[y, *drafts]`` (y = the last emitted + token) and run ``model`` on it once. The cache extends by + ``len(drafts) + 1``. + 4. Sample target's preferred token at each position. Walk the + drafts and accept any that match the target's choice; on the + first mismatch, also emit the target's choice at that position + and stop. If all drafts match, emit the bonus token from the + final position. + 5. Trim the cache by ``len(drafts) - num_accepted`` so its offset + lines up with the emitted tokens. + 6. If the drafter declined to propose, fall back to a single- + token target step (cost identical to non-spec generation). + """ + y = prompt.astype(mx.uint32) + + # Mirror mlx_lm._prefill: the caller has aligned ``prompt_cache`` to + # ``context_tokens[:-2]`` via ``exo.prefill`` + ``trim(2)``; this loop + # advances the cache by one more token (offset N-1), leaving ``y`` + # as the seed for the spec loop. + while y.size > 1: + n_to_process = min(prefill_step_size, y.size - 1) + model(y[:n_to_process][None], cache=prompt_cache) + mx.eval([c.state for c in prompt_cache]) # type: ignore[reportArgumentType] + y = y[n_to_process:] + mx.clear_cache() + + # Running context for n-gram lookup and logits processors. We start + # from the full prompt (so the n-gram drafter can match against + # prefix-cached portions) and append every emitted token. + running_context: list[int] = list(context_tokens) + prev_tokens = mx.array(running_context, dtype=mx.uint32) + ntoks = 0 + + while ntoks < max_tokens: + # ``num_draft_tokens`` is the upper bound; cap to remaining budget + # so the verify forward never overruns ``max_tokens``. + num_drafts = min(max_tokens - ntoks, drafter.num_draft_tokens) + if num_drafts < 1: + break + + drafts = drafter.propose(running_context, num_drafts) + + if not drafts: + # Single-token fallback: identical to non-spec generation. + logits = model(y[None], cache=prompt_cache) + logprobs = _process_logits_for_position( + logits[:, -1, :].squeeze(0), prev_tokens, logits_processors + ) + sampled = sampler(logprobs) + mx.eval(sampled) + sampled_token = int(sampled.item()) + yield sampled_token, logprobs, False + running_context.append(sampled_token) + prev_tokens = mx.concatenate( + [prev_tokens, mx.array([sampled_token], dtype=mx.uint32)] + ) + y = mx.array([sampled_token], dtype=mx.uint32) + ntoks += 1 + continue + + # The proposer's contract is *up to* ``num_drafts`` tokens; the + # rest of the loop is sized off the actual proposal length so we + # never index past the verify forward's output. + actual_drafts = len(drafts) + + # Verify pass: target forward on [y, *drafts] + draft_arr = mx.array(drafts, dtype=mx.uint32) + verify_input = mx.concatenate([y, draft_arr]) + logits = model(verify_input[None], cache=prompt_cache) + # logits shape: (1, actual_drafts + 1, vocab) + + target_logprobs: list[mx.array] = [] + target_tokens: list[int] = [] + running_prev = prev_tokens + for i in range(actual_drafts + 1): + position_logits = logits[:, i, :].squeeze(0) + position_logprobs = _process_logits_for_position( + position_logits, running_prev, logits_processors + ) + sampled = sampler(position_logprobs) + mx.eval(sampled) + sampled_token = int(sampled.item()) + target_logprobs.append(position_logprobs) + target_tokens.append(sampled_token) + # Speculatively assume position i was kept for the next + # logits-processor call; this matches what + # ``speculative_generate_step`` does internally. + running_prev = mx.concatenate( + [running_prev, mx.array([sampled_token], dtype=mx.uint32)] + ) + + # Greedy accept + num_accepted = 0 + for i in range(actual_drafts): + if target_tokens[i] == drafts[i]: + num_accepted += 1 + else: + break + + # Emit accepted drafts + 1 (target's choice at first mismatch + # or bonus token after a full accept). + emit_count = num_accepted + 1 + trim = actual_drafts - num_accepted + + for j in range(emit_count): + tok = drafts[j] if j < num_accepted else target_tokens[j] + from_draft = j < num_accepted + yield tok, target_logprobs[j], from_draft + running_context.append(tok) + prev_tokens = mx.concatenate( + [prev_tokens, mx.array([tok], dtype=mx.uint32)] + ) + ntoks += 1 + if ntoks >= max_tokens: + break + + # Cache cleanup: we appended ``actual_drafts + 1`` tokens (the seed + # plus the proposed drafts); only the first ``num_accepted + 1`` + # of those are correct, so trim the rest. + if trim > 0: + # mlx_lm types the cache as ``List[Cache]``; exo's ``KVCacheType`` + # is a structural subset, so the cast + ignore mirrors the + # pattern used in ``mlx_generate``'s drafter cache trimming. + mlx_trim_prompt_cache(cast(list[object], prompt_cache), trim) # type: ignore[reportArgumentType] + + y = mx.array([running_context[-1]], dtype=mx.uint32) diff --git a/src/exo/worker/engines/mlx/generator/drafter_socket.py b/src/exo/worker/engines/mlx/generator/drafter_socket.py new file mode 100644 index 0000000000..4c37b79f9e --- /dev/null +++ b/src/exo/worker/engines/mlx/generator/drafter_socket.py @@ -0,0 +1,180 @@ +"""Direct TCP socket transport for the asymmetric drafter wire. + +The original drafter wire (:mod:`remote_drafter`) carries small uint32 +arrays via ``mx.distributed.send/recv`` over the parent +``mx.distributed.Group``. That design forces the drafter rank to be a +member of the parent group, which in turn requires +``mx.distributed.Group.split`` so target ranks can run TP/PP collectives +without dragging the drafter in. JACCL and ring backends do not +implement ``split`` on Apple Silicon, so the V1 asymmetric path was +limited to a single target rank. + +This module breaks that coupling. The drafter rank no longer joins +``mx.distributed`` at all. Instead, target rank 0 binds a TCP server +socket at instance bootstrap time, the drafter dials it, and the same +wire frames flow over that connection. The target's +``mx.distributed.Group`` therefore contains only target ranks and is +free to do whatever TP/PP work it needs without ``Group.split``. + +Wire frames are length-implicit (every op type has a known fixed shape; +``OP_PREFILL`` carries a variable-length token array whose length is +announced in the preceding command frame's ``num_forwards`` slot). Each +uint32 is serialised little-endian, matching mlx_lm's on-device layout +for ``mx.uint32``. + +Threading model: both the target rank's ``RemoteTransport`` and the +drafter rank's serve loop run wire ops serially on a single thread (the +target uses a single-worker ``ThreadPoolExecutor``; the drafter loops +synchronously). Concurrency is multiplexed via session ids, not via +multiple sockets, so a single TCP connection per asymmetric instance is +sufficient and avoids mid-flight reordering. +""" + +from __future__ import annotations + +import socket +import struct +import time +from typing import Final + +_HEADER_FORMAT: Final[str] = " None: + """Send a fixed-length uint32 frame over ``sock``. + + Caller must guarantee both peers know the frame length statically; + no length prefix is sent. Suitable for command/ack/drafts frames. + """ + if not all(0 <= v <= 0xFFFFFFFF for v in values): + raise ValueError(f"frame contains non-uint32 values: {values}") + payload = struct.pack(f"<{len(values)}I", *values) + sock.sendall(payload) + + +def recv_uint32_frame(sock: socket.socket, count: int) -> list[int]: + """Receive ``count`` uint32 ints over ``sock`` (no length prefix). + + Blocks until ``count * 4`` bytes have been received, raising + :class:`ConnectionError` if the peer closes mid-frame. + """ + if count <= 0: + raise ValueError(f"count must be > 0, got {count}") + needed = count * 4 + buf = bytearray(needed) + view = memoryview(buf) + received = 0 + while received < needed: + chunk = sock.recv_into(view[received:], needed - received) + if chunk == 0: + raise ConnectionError( + f"drafter wire closed mid-frame " + f"(received {received}/{needed} bytes)" + ) + received += chunk + unpacked = struct.unpack(f"<{count}I", bytes(buf)) + return list(unpacked) + + +def send_variable_uint32_payload(sock: socket.socket, values: list[int]) -> None: + """Send a length-prefixed uint32 payload (4-byte header + values). + + Used for OP_PREFILL's prompt-token tail when the size isn't carried + in the preceding command frame's slot. + """ + if not all(0 <= v <= 0xFFFFFFFF for v in values): + raise ValueError("variable payload contains non-uint32 values") + header = struct.pack(_HEADER_FORMAT, len(values)) + sock.sendall(header) + if values: + sock.sendall(struct.pack(f"<{len(values)}I", *values)) + + +def bind_target_listener(host: str, port: int, *, backlog: int = 1) -> socket.socket: + """Open and listen on ``(host, port)`` for the drafter's incoming dial. + + Bound with ``SO_REUSEADDR`` so a previous instance teardown that + left the port in TIME_WAIT does not block reclaim. Caller is + responsible for ``accept()`` and ``close()``. + """ + listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + listener.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + listener.bind((host, port)) + listener.listen(backlog) + return listener + + +def accept_drafter( + listener: socket.socket, + *, + timeout_seconds: float = 60.0, +) -> socket.socket: + """Block on ``listener.accept`` for the drafter's incoming connection. + + The drafter dials soon after target rank 0 reaches its + ``ConnectToGroup`` step, so a generous default timeout (60s) covers + drafter-side weight loading and warmup without spinning. ``TCP_NODELAY`` + is set on the accepted socket because every wire op is a small + request/reply round trip; Nagle would add ~40ms of latency per op + while batching tiny frames. + """ + listener.settimeout(timeout_seconds) + try: + accepted = listener.accept() + finally: + listener.settimeout(None) + conn: socket.socket = accepted[0] + conn.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) + return conn + + +def dial_target( + host: str, + port: int, + *, + total_timeout_seconds: float = 120.0, + initial_backoff_seconds: float = 0.5, +) -> socket.socket: + """Dial ``(host, port)`` with exponential backoff until connected. + + Used by the drafter rank to reach target rank 0's listener. Target + rank 0 binds inside its ``ConnectToGroup`` step, which races with + the drafter rank's bootstrap; the drafter therefore retries until + the listener is up or the deadline expires. Backoff caps at 5s + between attempts so we don't sleep through a transient binding + hiccup. + """ + deadline = time.monotonic() + total_timeout_seconds + backoff = initial_backoff_seconds + last_error: BaseException | None = None + while time.monotonic() < deadline: + try: + conn = socket.create_connection( + (host, port), timeout=min(10.0, total_timeout_seconds) + ) + conn.settimeout(None) + conn.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) + return conn + except (ConnectionRefusedError, OSError, TimeoutError) as exc: + last_error = exc + time.sleep(backoff) + backoff = min(backoff * 2.0, 5.0) + raise ConnectionError( + f"drafter could not reach target rank 0 at {host}:{port} " + f"within {total_timeout_seconds:.0f}s " + f"(last error: {last_error!r})" + ) + + +__all__ = [ + "accept_drafter", + "bind_target_listener", + "dial_target", + "recv_uint32_frame", + "send_uint32_frame", + "send_variable_uint32_payload", +] diff --git a/src/exo/worker/engines/mlx/generator/drafter_transport.py b/src/exo/worker/engines/mlx/generator/drafter_transport.py new file mode 100644 index 0000000000..f772cd29cc --- /dev/null +++ b/src/exo/worker/engines/mlx/generator/drafter_transport.py @@ -0,0 +1,410 @@ +"""Transport-agnostic interface to a speculative-decoding drafter. + +The pipelined spec loop in :mod:`pipelined_drafter` orchestrates rounds +against a drafter through this Protocol, so the same loop drives an +in-process drafter (this module's :class:`InProcessTransport`) or a +drafter on a different MLX rank (:mod:`remote_drafter`'s +``RemoteTransport``, communicating via ``mx.distributed.send/recv`` over +the existing JACCL/ring backend -- RDMA over Thunderbolt-bridge between +twin Macs is just a backend choice the same call site honours). + +API surface (kept as small as possible -- the spec loop owns +high-level accept/reject logic; the transport just implements the +mechanical primitives): + + * ``forward(inputs, num_forwards)`` -> ``Future[list[int]]`` -- run + ``num_forwards`` drafter forwards. Forward 0 consumes ``inputs`` + (length 1 for partial-accept seeds, length 2 for full-accept + seeds matching mlx_lm's ``draft_y = [drafts[-1], bonus]`` + convention); forwards 1..N-1 consume the previous forward's + sampled output. Returns immediately so the caller can dispatch + target verify in parallel. + + The spec loop uses this in two patterns: + * Standard round: ``forward([seed], K)`` -> K drafts. + * Speculative round (bonus prediction + round-ahead): ``forward([drafts[-1]], K+1)`` + -> ``[d_K, d^spec_0, ..., d^spec_{K-1}]`` where ``d_K`` is the + drafter's prediction for the bonus position (compared against + the actual ``bonus_t`` to detect speculation hit). + * ``trim_cache(n)`` -- trim ``n`` positions from the drafter's KV + cache. Used after partial accept (trim rejected drafts) and after + speculation miss (rollback the speculative forward). + * ``shutdown()`` -- release transport resources. No-op for the + in-process transport. + +The Future returned by ``propose`` is a synchronous +:class:`concurrent.futures.Future`, not :mod:`asyncio`. The spec loop +is a synchronous generator; blocking on a sync Future from a generator +is natural, whereas threading asyncio through the generator would be +invasive. The remote transport's IPC thread sets the Future from +outside the calling thread, which ``concurrent.futures.Future`` +supports. +""" + +from __future__ import annotations + +from concurrent.futures import Future +from typing import Callable, Final, Protocol, final, runtime_checkable + +import mlx.core as mx +from mlx_lm.models.cache import trim_prompt_cache as mlx_trim_prompt_cache + +from exo.worker.engines.mlx.types import KVCacheType, Model + +# Returned by ``propose``; the spec loop blocks on ``.result()`` once it +# has dispatched target verify. +DraftFuture = Future[list[int]] + + +@runtime_checkable +class DrafterTransport(Protocol): + """Async access to a speculative-decoding drafter. + + Implementations MUST be safe under the call sequence + :func:`pipelined_speculative_step` issues: + + 1. ``forward([seed], K)`` -> ``future`` + 2. (caller dispatches target verify in parallel) + 3. ``future.result()`` + 4. either: + a. partial accept: ``trim_cache(K - num_accepted - 1)`` then + ``forward([target_correction], K)`` for next round, or + b. full accept: no trim, then ``forward([drafts[-1], bonus], K)`` + for next round. + + For cross-round speculation an additional ``forward([drafts[-1]], K+1)`` + is issued in step 2 (parallel with verify); the first of the K+1 + returned tokens is the drafter's predicted bonus, which is checked + against the actual ``bonus_t``. On hit, the remaining K outputs are + used as round t+1's drafts. On miss, ``trim_cache(K + 1)`` rolls + back the speculative work. + + Behaviour is undefined if more than one un-resolved Future is in + flight without an intervening ``trim_cache`` or ``.result()`` call. + """ + + @property + def num_draft_tokens(self) -> int: + """``K`` -- the typical number of drafts per round. + + Remote transports use this to pre-allocate fixed-size receive + buffers (sized for ``K + 1`` to cover the speculative forward). + ``forward()`` accepts ``num_forwards`` up to ``K + 1``. + """ + ... + + def forward(self, inputs: list[int], num_forwards: int) -> DraftFuture: + """Run ``num_forwards`` drafter forwards starting from ``inputs``. + + Args: + inputs: First-forward input. Length 1 for partial-accept + seeds (``[seed]``); length 2 for full-accept seeds + (``[drafts[-1], bonus]`` matching mlx_lm's + ``_draft_generate`` ``draft_y`` convention). + Subsequent forwards consume the previous forward's + output, so they are always length-1. + num_forwards: Number of forwards (and number of returned + sampled tokens). Must satisfy + ``1 <= num_forwards <= self.num_draft_tokens + 1``; + the ``+ 1`` covers the speculative bonus-prediction + forward. + + Cache effect: extends the drafter's KV cache by + ``len(inputs) + num_forwards - 1`` positions. + + Returns: + A Future resolving to ``num_forwards`` sampled token ids. + """ + ... + + def trim_cache(self, n_positions: int) -> None: + """Trim ``n_positions`` from the drafter's KV cache. + + Used after partial accept (``n_positions = K - num_accepted - 1``) + and after speculation miss (``n_positions = positions added by + the speculative forward``). + + ``n_positions == 0`` is a valid no-op so callers don't have to + guard against the trivial case. Negative values raise + ``ValueError``. + """ + ... + + def reset_and_prefill(self, prompt_tokens: list[int]) -> None: + """Reset the drafter cache and prefill it with ``prompt_tokens``. + + Issued once at the start of every request so the drafter cache + is aligned with the target's cache before the spec loop starts. + ``prompt_tokens`` is the prompt minus the last 2 tokens (matching + the in-process path's ``_spec_drafter_prefill`` invariant); + the spec loop seeds from the last prompt token internally. + + Empty ``prompt_tokens`` is valid (very short prompts) and only + resets the cache. + + For the in-process transport this is a no-op when the caller + owns drafter cache prefill externally (the legacy mlx_generate + path). Implementations that own the drafter cache fully (e.g. + the remote transport) handle reset + prefill internally here. + """ + ... + + def shutdown(self) -> None: + """Release transport resources. Idempotent. + + In-process transport: no-op (the drafter model and cache are + owned by the caller). Remote transport: terminates the drafter + rank's serve loop, drains pending IPC. + """ + ... + + +# --------------------------------------------------------------------------- +# In-process transport +# --------------------------------------------------------------------------- + + +@final +class InProcessTransport: + """Drafter model + cache live in the calling process on the same MLX device. + + All MLX work happens on the calling thread; ``propose`` runs the K + drafter forwards inline and returns an immediately-resolved Future + so the call site is uniform with the remote transport. Any + pipelining win at this transport comes from MLX's intra-forward + async dispatch (``mx.async_eval`` between drafter forwards) and + the cross-round speculation in :func:`pipelined_speculative_step`. + + Apple Silicon's unified-memory single GPU bounds the gain because + drafter and target target compete for the same memory bandwidth on + the same Metal command queue; on multi-machine deployments the + same call site runs against :class:`RemoteTransport` instead and + the gain unlocks. + """ + + def __init__( + self, + *, + draft_model: Model, + draft_cache: KVCacheType, + num_draft_tokens: int, + ) -> None: + if num_draft_tokens < 1: + raise ValueError(f"num_draft_tokens must be >= 1, got {num_draft_tokens}") + self._draft_model = draft_model + self._draft_cache = draft_cache + self._num_draft_tokens = num_draft_tokens + + @property + def num_draft_tokens(self) -> int: + return self._num_draft_tokens + + def forward(self, inputs: list[int], num_forwards: int) -> DraftFuture: + # ``+ 1`` upper bound covers the speculative bonus-prediction + # forward; see DrafterTransport docstring. + upper = self._num_draft_tokens + 1 + if not 1 <= num_forwards <= upper: + raise ValueError( + f"num_forwards must be in [1, {upper}], got {num_forwards}" + ) + if not 1 <= len(inputs) <= 2: + # Length 1 = partial-accept seed; length 2 = full-accept + # ``[drafts[-1], bonus]`` shape. No other shape is meaningful + # for spec decoding and accepting it would mask bookkeeping + # bugs in the spec loop. + raise ValueError(f"inputs must have length 1 or 2, got {len(inputs)}") + + future: DraftFuture = Future() + try: + outputs = self._run_drafter_forwards(inputs, num_forwards) + future.set_result(outputs) + except Exception as exc: + future.set_exception(exc) + return future + + def trim_cache(self, n_positions: int) -> None: + if n_positions < 0: + raise ValueError(f"n_positions must be >= 0, got {n_positions}") + if n_positions == 0: + return + # mlx_lm types ``trim_prompt_cache`` against ``List[Cache]``; + # exo's ``KVCacheType`` is a structural superset, hence the + # cast + ignore (same pattern used in ``mlx_generate`` and the + # n-gram spec loop). + from typing import cast as _cast + + mlx_trim_prompt_cache(_cast(list[object], self._draft_cache), n_positions) # type: ignore[reportArgumentType] + + def reset_and_prefill(self, prompt_tokens: list[int]) -> None: + """No-op: the legacy in-process path manages drafter cache externally. + + ``mlx_generate`` allocates the drafter cache, runs + :func:`exo.worker.engines.mlx.generator.generate._spec_drafter_prefill`, + and only then constructs this transport. Re-running prefill + here would double-fill the cache. The Protocol method exists + for symmetry with :class:`RemoteTransport`, where the drafter + cache lives on the drafter rank and the transport owns its + per-request reset/prefill. + """ + del prompt_tokens + + def shutdown(self) -> None: + return + + # -- internals -------------------------------------------------------- + + def _run_drafter_forwards(self, inputs: list[int], num_forwards: int) -> list[int]: + """Mirror of mlx_lm's ``_draft_generate`` semantics. + + Forward 0 consumes ``inputs`` (length 1 or 2); forwards 1..N-1 + consume the previous forward's sampled output. ``mx.async_eval`` + between forwards lets the GPU pipeline the dispatches. + """ + ys: list[mx.array] = [] + y = mx.array(inputs, dtype=mx.uint32) + for _ in range(num_forwards): + logits = self._draft_model(y[None], cache=self._draft_cache) + sampled = mx.argmax(logits[:, -1, :], axis=-1).astype(mx.uint32) + mx.async_eval(sampled) + ys.append(sampled) + y = sampled + # Force a sync at the end so the cache state is realised before + # the spec loop dispatches target verify on top of these outputs. + mx.eval(ys + [c.state for c in self._draft_cache]) # type: ignore[reportArgumentType] + return [int(t.item()) for t in ys] + + +# --------------------------------------------------------------------------- +# Transport kind selection +# --------------------------------------------------------------------------- + + +ALL_TRANSPORT_KINDS: Final[tuple[str, ...]] = ("inprocess",) +"""Recognised values of ``EXO_DRAFTER_TRANSPORT``. + +The ``"remote"`` option was retired alongside the ``mx.distributed``- +backed drafter wire (the v3+ asymmetric path uses a builder-supplied +:class:`RemoteTransport` bound to a TCP socket; it cannot be +constructed via this env-var factory because the socket comes from +target rank 0's listener and isn't available at process startup). +""" + +EXO_DRAFTER_TRANSPORT_ENV: Final[str] = "EXO_DRAFTER_TRANSPORT" + + +def clamp_num_draft_tokens_to_transport( + requested_num_draft_tokens: int, + transport: DrafterTransport, +) -> tuple[int, bool]: + """Clamp a per-request K against the transport's wire-protocol budget. + + Asymmetric placement allocates a long-lived ``RemoteTransport`` at + builder time with a fixed ``num_draft_tokens`` budget (see + ``builder.py``). A per-request ``num_draft_tokens`` override above + the budget would otherwise raise ``ValueError`` deep inside + :class:`PipelinedModelDrafter`, killing the runner subprocess and + leaving the peer rank wedged (regression: aborted K=8 sweep at + 14:35:05 took the target rank with it). Clamping silently to the + transport max is the only safe behaviour: the wire-protocol budget + is a startup-time setting (``EXO_NUM_DRAFT_TOKENS``) and cannot be + widened mid-flight without re-warmup. + + Returns the (possibly clamped) K and a flag indicating whether + clamping was applied so callers can emit a structured warning. + + :raises ValueError: if ``requested_num_draft_tokens`` is < 1. The + spec loop never proposes zero or negative drafts, so this would + be a programmer error rather than a malformed request. + """ + if requested_num_draft_tokens < 1: + raise ValueError( + f"requested_num_draft_tokens must be >= 1, got {requested_num_draft_tokens}" + ) + transport_max = transport.num_draft_tokens + if requested_num_draft_tokens > transport_max: + return transport_max, True + return requested_num_draft_tokens, False + + +def parse_transport_kind(raw: str | None, default: str) -> str: + """Parse the ``EXO_DRAFTER_TRANSPORT`` env var, warning on unknown values.""" + if raw is None: + return default + candidate = raw.strip().lower() + if candidate in ALL_TRANSPORT_KINDS: + return candidate + # Imported lazily so this module is importable without the runner + # bootstrap (used by tests that exercise the parser in isolation). + from exo.worker.runner.bootstrap import logger + + logger.warning( + f"{EXO_DRAFTER_TRANSPORT_ENV}={raw!r} not in {ALL_TRANSPORT_KINDS}; " + f"falling back to {default!r}" + ) + return default + + +def make_inprocess_transport( + *, + draft_model: Model | None, + draft_cache: KVCacheType | None, + num_draft_tokens: int, + group: mx.distributed.Group | None = None, + drafter_rank: int | None = None, + target_rank: int | None = None, +) -> DrafterTransport: + """Build an :class:`InProcessTransport`. + + Wrapped in a factory so callers don't import the concrete class; + keeps the spec loop coupled only to the Protocol. The ``group`` / + ``drafter_rank`` / ``target_rank`` kwargs are accepted (ignored) for + parity with :func:`make_remote_transport`, so :func:`make_drafter` + can dispatch to either factory with one call shape. + """ + del group, drafter_rank, target_rank # remote-only knobs + if draft_model is None or draft_cache is None: + raise ValueError( + "InProcessTransport requires draft_model and draft_cache; " + "remote transport is the only path that runs without them" + ) + return InProcessTransport( + draft_model=draft_model, + draft_cache=draft_cache, + num_draft_tokens=num_draft_tokens, + ) + + +# The dispatch table returns either a :class:`DrafterTransport` (in-process, +# directly consumable by the spec loop) or a :class:`RemoteTransport` +# (the wire owner; callers must call ``open_session()`` to obtain a +# :class:`DrafterTransport` view per request). Callers route on the +# concrete return type rather than relying on a single Protocol. +_TransportFactory = Callable[..., object] + + +def transport_factory_for(kind: str) -> _TransportFactory: + """Return the factory for the requested transport kind. + + Only ``"inprocess"`` is constructible via this factory; the + asymmetric remote transport (``RemoteTransport``) is built + directly from the runner bootstrap with a connected socket from + target rank 0's drafter listener. + + Raises: + ValueError: ``kind`` is not in :data:`ALL_TRANSPORT_KINDS`. + """ + if kind == "inprocess": + return make_inprocess_transport + raise ValueError(f"Unknown drafter transport kind: {kind!r}") + + +__all__ = [ + "ALL_TRANSPORT_KINDS", + "DraftFuture", + "DrafterTransport", + "EXO_DRAFTER_TRANSPORT_ENV", + "InProcessTransport", + "make_inprocess_transport", + "parse_transport_kind", + "transport_factory_for", +] diff --git a/src/exo/worker/engines/mlx/generator/generate.py b/src/exo/worker/engines/mlx/generator/generate.py index c7a7612693..718e72e43c 100644 --- a/src/exo/worker/engines/mlx/generator/generate.py +++ b/src/exo/worker/engines/mlx/generator/generate.py @@ -7,9 +7,11 @@ import mlx.core as mx from mlx_lm.generate import ( + PromptProcessingBatch, maybe_quantize_kv_cache, stream_generate, ) +from mlx_lm.models.cache import trim_prompt_cache as mlx_trim_prompt_cache from mlx_lm.sample_utils import make_logits_processors, make_sampler from mlx_lm.tokenizer_utils import TokenizerWrapper @@ -55,6 +57,12 @@ KV_GROUP_SIZE, MAX_TOKENS, ) +from exo.worker.engines.mlx.generator.drafter import ( + Drafter, + DraftMode, + make_drafter, + resolve_draft_mode, +) from exo.worker.engines.mlx.generator.remote_prefill import remote_prefill from exo.worker.engines.mlx.types import KVCacheType, Model from exo.worker.engines.mlx.utils_mlx import ( @@ -396,12 +404,244 @@ def combined_progress_callback(processed: int, total: int) -> None: return tokens_per_sec, num_tokens, snapshots[:-1] if snapshots else [] +class BatchedPrefillUnsupportedError(Exception): + """Raised when ``batched_prefill`` cannot run for the requested batch. + + The caller is expected to recover by falling back to per-slot + :func:`prefill`. Reasons include cache types that do not implement + ``merge``/``extract`` (e.g. ``DeepseekV4Cache``), pipeline-parallel + targets where collective semantics differ, or any prompt being too + short to leave a decode-seed token after slicing. + """ + + +def batched_prefill( + *, + model: Model, + prompt_tokens_list: list[mx.array], + caches_list: list[KVCacheType], + on_progress: Callable[[int, int], None] | None = None, + prefill_step_size: int = 4096, +) -> tuple[float, int]: + """Run K prefills in a single batched forward pass. + + Wraps :class:`mlx_lm.generate.PromptProcessingBatch`. After return, each cache in + ``caches_list`` is filled in-place to offset ``len(prompt_tokens_list[i]) - 1`` + so the decode loop can seed from the last prompt token (matching the + exact-prefix-hit shape ``mlx_generate`` already handles via + ``kv_prefix_cache.get_kv_cache``). + + The K prompts are right-padded to the longest length; per-cache + ``prepare(lengths=, right_padding=)`` + ``finalize()`` remove the + padding from the cache state. Total wall-clock cost is roughly the + cost of one prefill at the longest prompt's length, amortising weight + loads across the batch — which is the whole point on a single GPU + where matmul throughput is otherwise weight-bandwidth-bound for the + sequential per-slot path. + + Args: + model: target model. Must produce caches whose layers support + ``merge``/``extract`` (e.g. ``KVCache`` + ``RotatingKVCache`` for + Gemma-4; ``DeepseekV4Cache`` is not supported and raises + :class:`BatchedPrefillUnsupportedError`). + prompt_tokens_list: per-slot full prompt tokens. Each prompt is + sliced to ``prompt[:-1]`` internally so the decode seed + (``prompt[-1]``) is left out of the cache. + caches_list: per-slot fresh caches (offset 0). Mutated in place; + on return each cache's layers point at the extracted + per-sequence state from the batched forward. + on_progress: aggregate ``(processed_max_seq, total_max_seq)`` + callback fired once per ``prefill_step_size`` chunk. The + ``processed`` count is the per-slot maximum (longest prompt's + chunk count) so progress monotonically increases even when + slots have unequal lengths. + prefill_step_size: chunk size for the prefill loop. + + Returns: + ``(aggregate_tps, total_tokens)``: sum of per-slot tokens divided + by batched wall-clock time, useful for telemetry / bench output. + + Raises: + BatchedPrefillUnsupportedError: cache layers do not implement + ``merge``/``extract`` (caller should fall back to per-slot + :func:`prefill`). + ValueError: ``len(prompt_tokens_list) != len(caches_list)`` or any + prompt has fewer than 2 tokens (need at least 1 prefill + + 1 seed token). + """ + if len(prompt_tokens_list) != len(caches_list): + raise ValueError( + f"prompt_tokens_list ({len(prompt_tokens_list)}) and caches_list " + f"({len(caches_list)}) must have the same length" + ) + if not prompt_tokens_list: + return 0.0, 0 + if any(int(p.size) < 2 for p in prompt_tokens_list): + raise ValueError( + "batched_prefill requires every prompt to have length >= 2 " + "(1 token to prefill + 1 token for the decode seed)" + ) + + # Slice off the decode seed so the post-prefill cache offset lands at + # ``len(prompt) - 1`` per slot — same invariant ``mlx_generate``'s + # exact-prefix-hit branch produces. + prefill_tokens: list[list[int]] = [ + [int(t) for t in cast(list[int], p[:-1].tolist())] for p in prompt_tokens_list + ] + total_tokens = sum(len(p) for p in prefill_tokens) + if total_tokens == 0: + return 0.0, 0 + + batch_size = len(prefill_tokens) + uids = list(range(batch_size)) + + start_time = time.perf_counter() + + try: + batch: object = PromptProcessingBatch( + model=model, + uids=uids, + caches=[list(c) for c in caches_list], + prefill_step_size=prefill_step_size, + ) + except ValueError as e: + # ``_merge_caches`` raises ``ValueError`` for cache types without + # a ``merge`` method. Surface as a typed unsupported error so the + # caller can fall back cleanly. + raise BatchedPrefillUnsupportedError( + f"cache layer does not support batching: {e}" + ) from e + + logger.debug( + f"Batched prefill: {batch_size} slots, " + f"lengths={[len(p) for p in prefill_tokens]}, total={total_tokens}" + ) + try: + # ``PromptProcessingBatch.prompt`` does the right-padding + + # chunked forward internally; one call processes all K + # sequences in lock-step. + batch.prompt(prefill_tokens) # type: ignore[reportUnknownMemberType, reportAttributeAccessIssue] + except Exception as e: + # Convert mlx-internal failures (e.g. shape mismatches between + # ``prepare(right_padding=...)`` and the model's attention + # implementation) into the typed unsupported error so the + # caller falls back to per-slot prefill instead of taking the + # whole runner down. + raise BatchedPrefillUnsupportedError( + f"PromptProcessingBatch.prompt() raised during batched prefill: {e!r}" + ) from e + + if on_progress is not None: + max_len = max(len(p) for p in prefill_tokens) + on_progress(max_len, max_len) + + # Re-extract per-sequence caches and update the original cache lists + # in place. Each ``extract_cache(idx)`` produces fresh per-layer + # cache objects of the original (non-batched) type with the + # post-prefill state for sequence ``idx``; we overwrite the + # caller-supplied list contents so any references the caller still + # holds (e.g. the SequentialGenerator's per-slot ``caches`` ref) + # see the new state. + for idx, original_cache in enumerate(caches_list): + extracted = cast(list[object], batch.extract_cache(idx)) + if len(extracted) != len(original_cache): + raise BatchedPrefillUnsupportedError( + f"extract_cache({idx}) returned {len(extracted)} layers, " + f"original cache has {len(original_cache)}" + ) + for i, layer in enumerate(extracted): + original_cache[i] = layer # type: ignore[index] + + elapsed = time.perf_counter() - start_time + aggregate_tps = total_tokens / elapsed if elapsed > 0 else 0.0 + logger.debug( + f"Batched prefill complete: {batch_size} slots, " + f"{total_tokens} tokens in {elapsed:.2f}s " + f"({aggregate_tps:.1f} tok/s aggregate)" + ) + return aggregate_tps, total_tokens + + +def resolve_speculative_decoding( + draft_model: Model | None, + group: mx.distributed.Group | None, + max_tokens: int, + num_draft_tokens: int | None, + drafter_min_output_tokens: int | None, +) -> tuple[Model | None, dict[str, object]]: + """Decide whether to actually use speculative decoding for this request. + + Pure helper so we can unit-test the policy without spinning up MLX. Returns + ``(effective_draft_model, spec_kwargs)`` for forwarding to + ``stream_generate``. + + Policy: + - Distributed runs: drafter is dropped (mlx_lm does not pipe the drafter + through the multi-device path yet). + - Single-device + drafter + ``max_tokens <= drafter_min_output_tokens``: + drafter is dropped (item 8 -- short outputs don't amortise the prefill + cost). + - Single-device + drafter active: forward ``num_draft_tokens`` (item 1) + via kwargs so ``speculative_generate_step`` honors it. + """ + if group is not None or draft_model is None: + return None, {} + + if ( + drafter_min_output_tokens is not None + and max_tokens <= drafter_min_output_tokens + ): + logger.debug( + f"Short generation (max_tokens={max_tokens} <= " + f"{drafter_min_output_tokens}); skipping drafter for this request." + ) + return None, {} + + spec_kwargs: dict[str, object] = {} + if num_draft_tokens is not None: + spec_kwargs["num_draft_tokens"] = num_draft_tokens + return draft_model, spec_kwargs + + +def _spec_drafter_prefill( + drafter: Model, + drafter_cache: KVCacheType, + tokens: mx.array, + step: int = 4096, +) -> None: + """Advance ``drafter_cache`` by running ``drafter`` on ``tokens``. + + Used on the speculative-decoding path to bring the drafter cache to the + same offset as the target cache before stream_generate's + ``speculative_generate_step._prefill`` ingests the final two prompt + tokens. Without this, the drafter cache would be empty (or stuck at a + prefix-cache hit boundary) while the target cache is at ``prompt - 2``, + desyncing mlx_lm's spec bookkeeping. + """ + if tokens.size == 0: + return + y = tokens + while y.size > 0: + n = min(step, y.size) + drafter(y[:n][None], cache=drafter_cache) + mx.eval([c.state for c in drafter_cache]) # type: ignore[reportArgumentType] + y = y[n:] + + def warmup_inference( model: Model, tokenizer: TokenizerWrapper, group: mx.distributed.Group | None, model_id: ModelId, + draft_model: Model | None = None, ) -> int: + """Run a throwaway generation to JIT-compile kernels and prime caches. + + When ``draft_model`` is supplied (single-device only), the drafter + participates in the warmup so the *first real request* doesn't pay a + cold-cache penalty on the drafter's first speculative step. This is + item 3 from the drafter tuning plan. + """ logger.info(f"warming up inference for instance: {model_id}") content = InputMessageContent( @@ -424,7 +664,10 @@ def warmup_inference( mx_barrier(group) - logger.info("Generating warmup tokens") + logger.info( + "Generating warmup tokens" + + (" (with drafter)" if draft_model is not None else "") + ) t = time.monotonic() @@ -435,6 +678,7 @@ def warmup_inference( prompt=warmup_prompt, kv_prefix_cache=None, group=group, + draft_model=draft_model, ): tokens_generated += 1 @@ -470,6 +714,13 @@ def proc(_history: mx.array, logits: mx.array) -> mx.array: logits[..., tid] = -1e9 return logits + # Marks the processor as not dependent on the running token history, + # so the speculative-decoding verify loop can apply it once to a + # batched ``(K+1, vocab)`` logits tensor and sample all positions + # in a single host-device sync. Stateful processors (e.g. repetition + # penalty) leave this attribute unset and force the per-position + # path. + proc.position_independent = True # type: ignore[reportAttributeAccessIssue] return proc @@ -541,7 +792,33 @@ def mlx_generate( on_generation_token: Callable[[], None] | None = None, vision_processor: VisionProcessor | None = None, draft_model: Model | None = None, + drafter_kv_prefix_cache: KVPrefixCache | None = None, + drafter_model_id: ModelId | None = None, + num_draft_tokens: int | None = None, + drafter_min_output_tokens: int | None = None, + asymmetric_drafter_rank: int | None = None, + asymmetric_drafter_transport: object | None = None, + target_peer_fanout: object | None = None, + precomputed_target_cache: KVCacheType | None = None, ) -> Generator[GenerationResponse]: + """Generate tokens for ``task``. + + The ``precomputed_target_cache`` argument is the seam used by + :class:`SequentialGenerator._start_batch` to inject a target-side cache + that has already been prefilled (typically via :func:`batched_prefill` + across multiple in-flight requests on a single GPU). When supplied: + + * the prefix-cache lookup is bypassed entirely (we don't pollute the + shared ``KVPrefixCache`` with per-request entries — V1 trade-off); + * the local :func:`prefill` call is a no-op (its prompt slice is + length 0); + * cache offset is assumed to be ``len(all_prompt_tokens) - 1`` so the + decode loop seeds from the last prompt token (identical shape to + the existing ``is_exact_hit`` path of ``KVPrefixCache.get_kv_cache``). + + Eligibility is enforced by the caller — see + :meth:`SequentialGenerator._batch_eligible_for_prefill`. + """ # Ensure that generation stats only contains peak memory for this generation mx.reset_peak_memory() # TODO: Randomise task seed and set in taskparams, instead of hard coding as 42. @@ -578,17 +855,135 @@ def mlx_generate( if is_bench and not task.use_prefix_cache: kv_prefix_cache = None - # Use prefix cache if available, otherwise create fresh cache + # Resolve drafting strategy up-front so cache setup below can branch on + # the *effective* mode rather than the unfiltered ``draft_model``. + # Precedence: per-request ``draft_mode`` > per-request ``use_drafter`` > + # ``EXO_DRAFT_MODE`` env var > implicit default (``model`` if a drafter + # is loaded, else ``none``). Distributed runs always degrade to + # ``none`` because mlx_lm does not yet route either model-drafter or + # n-gram drafting through the pipeline-parallel path. + request_use_drafter = task.use_drafter + request_num_draft_tokens = task.num_draft_tokens + request_draft_mode = task.draft_mode + effective_num_draft_tokens = ( + request_num_draft_tokens + if request_num_draft_tokens is not None + else num_draft_tokens + ) or 0 + max_tokens = task.max_output_tokens or MAX_TOKENS + # ``asymmetric_drafter_rank`` is set on every target rank in an + # asymmetric placement (it's a property of the placement, not of + # any one rank). ``asymmetric_drafter_transport`` is set only on + # the target root rank (rank 0 of the target subgroup), which owns + # the socket to the drafter. Both ranks must enter the pipelined + # branch because they need to make matching TP collectives every + # round; the non-root rank consumes drafts via a rank-0 broadcast + # on the target subgroup (see :class:`PipelinedModelDrafter`). + asymmetric_drafter_active = ( + asymmetric_drafter_rank is not None and request_use_drafter is not False + ) + asymmetric_drafter_is_root = ( + asymmetric_drafter_active and asymmetric_drafter_transport is not None + ) + if asymmetric_drafter_active: + # Asymmetric placement: the drafter lives on a separate node, + # talking to target rank 0 over a TCP socket owned by the + # ``RemoteTransport`` wire. This bypasses the legacy "group is + # not None -> draft_mode = none" demotion (which exists because + # mlx_lm's own speculative_generate_step doesn't handle + # pipeline collectives). The pipelined+remote path is the + # whole point of the asymmetric topology, so honor it + # unconditionally on every target rank. + draft_mode: DraftMode = "pipelined" + elif group is not None: + draft_mode = "none" + else: + draft_mode = resolve_draft_mode( + has_drafter_model=draft_model is not None, + request_use_drafter=request_use_drafter, + request_draft_mode=request_draft_mode, + ) + # Item 8: short-output skip applies to drafter-model paths + # (``"model"`` and ``"pipelined"``) where the drafter prefill cost + # dominates. N-gram drafting has no prefill (microsecond suffix- + # match per round) so the threshold is irrelevant; baseline non- + # spec wouldn't be cheaper anyway. + if ( + draft_mode in ("model", "pipelined") + and drafter_min_output_tokens is not None + and max_tokens <= drafter_min_output_tokens + ): + logger.info( + f"draft_mode demoted to 'none' for short request " + f"(max_tokens={max_tokens} <= {drafter_min_output_tokens})" + ) + draft_mode = "none" + effective_draft_model = ( + draft_model if draft_mode in ("model", "pipelined") else None + ) + # Reused below: drafter-model paths need paired drafter caches; the + # ngram and none paths don't. The variable name is preserved for + # readability with the existing cache bookkeeping code below. + spec_active = ( + draft_mode in ("model", "pipelined") and effective_draft_model is not None + ) + if effective_num_draft_tokens < 1: + # Defaulted to 0 above when the runner didn't pre-resolve K and the + # request didn't override either. Clamp to 1 so n-gram and model + # drafters don't crash on zero-K proposals. + effective_num_draft_tokens = 1 + + if asymmetric_drafter_is_root and asymmetric_drafter_transport is not None: + # Only the root has access to the transport's clamp; non-root + # target ranks pick up the (already-clamped) K from the + # broadcast wire-format size. As long as both ranks agree on + # the configured ``num_draft_tokens`` -- which they do, since + # it's derived deterministically from env / placement -- the + # broadcast slot count is identical and no per-rank clamp is + # required on the consumer. + from exo.worker.engines.mlx.generator.drafter_transport import ( + DrafterTransport as _DrafterTransport, + ) + from exo.worker.engines.mlx.generator.drafter_transport import ( + clamp_num_draft_tokens_to_transport, + ) + + if isinstance(asymmetric_drafter_transport, _DrafterTransport): + clamped_k, was_clamped = clamp_num_draft_tokens_to_transport( + effective_num_draft_tokens, asymmetric_drafter_transport + ) + if was_clamped: + logger.warning( + f"clamping num_draft_tokens={effective_num_draft_tokens} " + f"to transport max={clamped_k} " + f"(request_num_draft_tokens={request_num_draft_tokens}); " + f"raise EXO_NUM_DRAFT_TOKENS at runner startup to widen " + f"the wire-protocol budget" + ) + effective_num_draft_tokens = clamped_k + prefix_hit_length = 0 matched_index: int | None = None is_exact_hit = False - if kv_prefix_cache is None: + if precomputed_target_cache is not None: + # External batched-prefill path: caller supplies a cache already + # filled to ``len(all_prompt_tokens) - 1`` and we leave a single + # decode-seed token in ``prompt_tokens``. ``prefill()`` below + # short-circuits because the slice ``prompt_tokens[:-1]`` is + # empty; the prefix-cache update path is also skipped because + # ``matched_index`` stays None and ``is_exact_hit`` stays False. + caches = precomputed_target_cache + prompt_tokens = all_prompt_tokens[-1:] + prefix_hit_length = int(all_prompt_tokens.size) - 1 + elif kv_prefix_cache is None: caches = make_kv_cache(model=model) prompt_tokens = all_prompt_tokens else: caches, prompt_tokens, matched_index, is_exact_hit = ( kv_prefix_cache.get_kv_cache( - model, all_prompt_tokens, media_regions=media_regions + model, + all_prompt_tokens, + media_regions=media_regions, ) ) prefix_hit_length = len(all_prompt_tokens) - len(prompt_tokens) @@ -597,6 +992,47 @@ def mlx_generate( f"KV cache hit: {prefix_hit_length}/{len(all_prompt_tokens)} tokens cached ({100 * prefix_hit_length / len(all_prompt_tokens):.1f}%)" ) + # Drafter cache lookup. We mirror the target's prefix-cache contract on + # the drafter so multi-turn workloads don't pay the drafter's prefill + # cost on every request (item 6). The aligned_hit logic below ensures + # both caches start the spec loop at the same offset; mismatched + # offsets would corrupt mlx_lm's spec_step bookkeeping. + drafter_caches: KVCacheType = [] + drafter_matched_index: int | None = None + if spec_active and effective_draft_model is not None: + if drafter_kv_prefix_cache is None: + drafter_caches = make_kv_cache(model=effective_draft_model) + drafter_remaining = all_prompt_tokens + else: + ( + drafter_caches, + drafter_remaining, + drafter_matched_index, + _, + ) = drafter_kv_prefix_cache.get_kv_cache( + effective_draft_model, + all_prompt_tokens, + media_regions=media_regions, + ) + target_hit = prefix_hit_length + drafter_hit = len(all_prompt_tokens) - len(drafter_remaining) + aligned_hit = min(target_hit, drafter_hit) + # Trim whichever cache overshoots so both start at ``aligned_hit``. + if target_hit > aligned_hit: + mlx_trim_prompt_cache(cast(list[object], caches), target_hit - aligned_hit) # type: ignore[reportArgumentType] + prompt_tokens = all_prompt_tokens[aligned_hit:] + prefix_hit_length = aligned_hit + if matched_index is not None and aligned_hit < target_hit: + # Trimming below the prior match invalidates the + # update-in-place path; treat as a fresh add. + matched_index = None + is_exact_hit = False + if drafter_hit > aligned_hit: + drafter_overshoot = drafter_hit - aligned_hit + mlx_trim_prompt_cache(cast(list[object], drafter_caches), drafter_overshoot) # type: ignore[reportArgumentType] + if drafter_matched_index is not None and aligned_hit < drafter_hit: + drafter_matched_index = None + logits_processors: list[Callable[[mx.array, mx.array], mx.array]] = ( make_logits_processors( repetition_penalty=task.repetition_penalty, @@ -641,6 +1077,7 @@ def mlx_generate( use_remote = ( len(prompt_tokens) > REMOTE_PREFILL_MIN_TOKENS and task.prefill_endpoint is not None + and not spec_active ) remote_prefilled = False prefill_tps = 0.0 @@ -674,6 +1111,21 @@ def mlx_generate( on_prefill_progress, distributed_prompt_progress_callback, ) + # On the spec path we mirror exo's prefill on the drafter so its + # cache reaches the same offset as the target's (prompt - 2 after + # the trim(2) inside exo.prefill). mlx_lm's + # speculative_generate_step._prefill then advances both caches by + # 1 (decode_prompt size = 2 -> processes 1 token), arriving at + # prompt - 1 with ``y = decode_prompt[-1:]`` -- the canonical + # entry state for the spec loop. + if spec_active and effective_draft_model is not None: + drafter_prefill_tokens = prompt_tokens[:-2] + if drafter_prefill_tokens.size > 0: + _spec_drafter_prefill( + effective_draft_model, + drafter_caches, + drafter_prefill_tokens, + ) cache_snapshots: list[CacheSnapshot] | None = ssm_snapshots_list or None if kv_prefix_cache is not None and matched_index is not None and is_exact_hit: @@ -707,129 +1159,411 @@ def mlx_generate( prefill_tps=prefill_tps, ) - # stream_generate starts from the last token - last_token = prompt_tokens[-2:] + # Drafter prefix cache update (item 6). Snapshot the drafter cache + # *before* stream_generate starts mutating it so subsequent requests + # can resume from this prompt boundary instead of replaying the + # drafter prefill. + if ( + spec_active + and drafter_kv_prefix_cache is not None + and effective_draft_model is not None + ): + if ( + drafter_matched_index is not None + and prefix_hit_length >= min_prefix_hit_length + ): + drafter_kv_prefix_cache.update_kv_cache( + drafter_matched_index, + all_prompt_tokens, + drafter_caches, + None, + restore_pos=prefix_hit_length, + media_regions=media_regions, + prefill_tps=prefill_tps, + ) + else: + drafter_kv_prefix_cache.add_kv_cache( + all_prompt_tokens, + drafter_caches, + None, + media_regions=media_regions, + prefill_tps=prefill_tps, + ) + + # stream_generate starts from the last 2 tokens; caches already cover + # prompt[:-2] via exo's prefill + c.trim(2). The non-spec and spec paths + # share the same entry state -- spec just additionally has the drafter + # cache pre-aligned to the same offset (see drafter prefill above). + decode_prompt = prompt_tokens[-2:] max_tokens = task.max_output_tokens or MAX_TOKENS accumulated_text = "" generated_text_parts: list[str] = [] generation_start_time = time.perf_counter() usage: Usage | None = None + # Speculative decoding telemetry (item 4). `from_draft_count` is the + # number of tokens stream_generate flagged as drafter-accepted; we report + # it on the final GenerationStats so dashboards / clients can A/B + # configurations on real traffic. + from_draft_count = 0 logger.info("Starting decode") mx_barrier(group) - # Speculative decoding via mlx_lm: only enabled in the single-device path - # (group is None). Distributed speculative is not yet plumbed; passing a - # draft_model alongside a non-trivial group would be a no-op, so we drop - # it explicitly to make the caller contract clear. - effective_draft_model = draft_model if group is None else None - - for completion_tokens, out in enumerate( - stream_generate( - model=model, - tokenizer=tokenizer, - prompt=last_token, - max_tokens=max_tokens, - sampler=sampler, - logits_processors=logits_processors, - prompt_cache=caches, - prefill_step_size=1, - kv_group_size=KV_GROUP_SIZE, - kv_bits=KV_BITS, - draft_model=effective_draft_model, - ), - start=1, - ): - generated_text_parts.append(out.text) - accumulated_text += out.text - - # Check for stop sequences - text = out.text - finish_reason: FinishReason | None = cast( - FinishReason | None, out.finish_reason + # Dispatch to the selected drafting strategy via ``make_drafter``. + # The factory routes: + # * ``"model"`` -> mlx_lm.speculative_generate_step (well-tested upstream) + # * ``"pipelined"`` -> custom spec loop with cross-round speculation + # behind a ``DrafterTransport`` (in-process or remote) + # * ``"ngram"`` -> in-house n-gram suffix-match spec loop + # * ``"none"`` -> plain ``mlx_lm.stream_generate`` + # Per-task session for the asymmetric remote drafter (if active). + # Opened in the ``if`` branch below; closed in the ``finally`` at + # the end of the function so a fault, cancellation, or normal + # completion all funnel through ``session.shutdown()`` and free + # the drafter rank's per-session KV cache. Without this, every + # completed request would leak ~50-100 MB of KV cache on the + # drafter rank until the runner shuts down. + asymmetric_session: object | None = None + if asymmetric_drafter_active: + assert asymmetric_drafter_rank is not None + target_subgroup_size = group.size() if group is not None else 1 + from exo.worker.engines.mlx.generator.drafter_transport import ( + DrafterTransport as _DrafterTransport, ) - stop_matched = False - - if stop_sequences: - for stop_seq in stop_sequences: - if stop_seq in accumulated_text: - # Trim text to just before the stop sequence - stop_index = accumulated_text.find(stop_seq) - text_before_stop = accumulated_text[:stop_index] - chunk_start = len(accumulated_text) - len(out.text) - text = text_before_stop[chunk_start:] - finish_reason = "stop" - stop_matched = True - break - - is_done = finish_reason is not None - - stats: GenerationStats | None = None - if is_done: - stats = GenerationStats( - prompt_tps=float(prefill_tps or out.prompt_tps), - generation_tps=float(out.generation_tps), - prompt_tokens=int(prefill_tokens + out.prompt_tokens), - generation_tokens=int(out.generation_tokens), - peak_memory_usage=Memory.from_gb(out.peak_memory), + from exo.worker.engines.mlx.generator.remote_drafter import ( + RemoteTransport as _RemoteTransport, + ) + + if asymmetric_drafter_is_root: + # Target root rank: open a per-request session on the + # ``RemoteTransport`` wire so concurrent target requests + # don't interleave OP_FORWARD frames on the same socket. + # Test fakes pass a bare ``DrafterTransport``; in that + # singular-task path we use it directly. + if isinstance(asymmetric_drafter_transport, _RemoteTransport): + asymmetric_session = asymmetric_drafter_transport.open_session() + session_transport: object = asymmetric_session + elif isinstance(asymmetric_drafter_transport, _DrafterTransport): + session_transport = asymmetric_drafter_transport + else: + raise TypeError( + "asymmetric_drafter_transport must be a RemoteTransport " + "(production asymmetric placement) or a DrafterTransport " + "(test fakes); " + f"got {type(asymmetric_drafter_transport).__name__}" + ) + # Sync this request's drafter cache against the prompt before + # constructing the drafter wrapper. The session sends OP_PREFILL + # with prompt[:-2] (matching ``_spec_drafter_prefill``'s + # invariant: align the drafter's offset to ``len(prompt) - 2`` + # so the spec loop's first OP_FORWARD seeds from prompt[-2]). + _diag_t0 = time.perf_counter() + logger.info( + f"[spec-diag] rank 0: about to materialize prefill_prompt " + f"via tolist() ({all_prompt_tokens.size} prompt tokens total)" ) - if not stop_matched and out.finish_reason not in get_args(FinishReason): - logger.warning( - f"Model generated unexpected finish_reason: {out.finish_reason}" + prefill_prompt: list[int] = [ + int(t) for t in cast(list[int], all_prompt_tokens[:-2].tolist()) + ] + logger.info( + f"[spec-diag] rank 0: prefill_prompt materialized in " + f"{(time.perf_counter() - _diag_t0) * 1000:.1f}ms " + f"(len={len(prefill_prompt)}); about to send OP_PREFILL" + ) + try: + _diag_t1 = time.perf_counter() + cast(_DrafterTransport, session_transport).reset_and_prefill( + prefill_prompt ) - - total_prompt_tokens = len(all_prompt_tokens) - usage = Usage( - prompt_tokens=total_prompt_tokens, - completion_tokens=completion_tokens, - total_tokens=total_prompt_tokens + completion_tokens, - prompt_tokens_details=PromptTokensDetails( - cached_tokens=prefix_hit_length - ), - completion_tokens_details=CompletionTokensDetails(reasoning_tokens=0), + logger.info( + f"[spec-diag] rank 0: OP_PREFILL ACK received in " + f"{(time.perf_counter() - _diag_t1) * 1000:.1f}ms" + ) + drafter: Drafter = make_drafter( + mode=draft_mode, + num_draft_tokens=effective_num_draft_tokens, + draft_model=None, + draft_cache=None, + target_subgroup_size=target_subgroup_size, + pipelined_transport=session_transport, + target_group=group, + target_peer_fanout=target_peer_fanout, + is_target_root=True, + ) + except BaseException: + # ``make_drafter`` or ``reset_and_prefill`` raised; + # release the freshly-allocated session so the drafter + # rank doesn't hold its KV cache forever. + try: + if asymmetric_session is not None: + cast(_DrafterTransport, asymmetric_session).shutdown() + except Exception: + logger.opt(exception=True).warning( + "asymmetric drafter session shutdown raised " + "during error recovery; ignoring" + ) + asymmetric_session = None + raise + else: + # Non-root target rank in a multi-target placement: no + # socket, no session, no drafter prefill (the drafter rank + # only knows about the root's session). The consumer + # drafter receives drafts each round via a rank-0 + # broadcast on ``group``; the broadcast is the only + # cross-rank wire this rank needs. + assert group is not None and target_subgroup_size > 1, ( + "asymmetric_drafter non-root rank requires a target " + "subgroup of size > 1 (V1 single-target placements " + "only have rank 0; this branch should not be reached)" + ) + drafter = make_drafter( + mode=draft_mode, + num_draft_tokens=effective_num_draft_tokens, + draft_model=None, + draft_cache=None, + target_subgroup_size=target_subgroup_size, + pipelined_transport=None, + target_group=group, + target_peer_fanout=target_peer_fanout, + is_target_root=False, ) + else: + drafter = make_drafter( + mode=draft_mode, + num_draft_tokens=effective_num_draft_tokens, + draft_model=effective_draft_model if spec_active else None, + draft_cache=drafter_caches if spec_active else None, + ) - # Extract logprobs from the full vocabulary logprobs array - logprob: float | None = None - top_logprobs: list[TopLogprobItem] | None = None - if task.logprobs: - with mx.stream(generation_stream): - logprob, top_logprobs = extract_top_logprobs( - logprobs=out.logprobs, - tokenizer=tokenizer, - top_logprobs=task.top_logprobs or DEFAULT_TOP_LOGPROBS, - selected_token=out.token, - ) + # ``decode_prompt`` is the prefill-tail (last two tokens of the + # prompt). The cache is already aligned to ``all_prompt_tokens[:-2]`` + # via ``exo.prefill`` + ``trim(2)``; mlx_lm's internal ``_prefill`` + # advances by one more token, then the spec loop seeds from the + # last. ``full_context_tokens`` is the full prompt so the n-gram + # drafter can match against the entire history (including + # prefix-cached portions); other drafters ignore it. + full_context_tokens: list[int] = [ + int(t) for t in cast(list[int], all_prompt_tokens.tolist()) + ] + _spec_diag_rank = group.rank() if group is not None else 0 + logger.info( + f"[spec-diag] rank {_spec_diag_rank}: about to enter drafter.stream() " + f"(decode_prompt size={int(decode_prompt.size)}, " + f"max_tokens={max_tokens}, mode={draft_mode})" + ) - if is_done: - # Log generation stats - generation_elapsed = time.perf_counter() - generation_start_time - generated_tokens = len(generated_text_parts) - generation_tps = ( - generated_tokens / generation_elapsed if generation_elapsed > 0 else 0.0 + try: + for completion_tokens, out in enumerate( + drafter.stream( + model=model, + tokenizer=tokenizer, + prompt=decode_prompt, + context_tokens=full_context_tokens, + prompt_cache=caches, + max_tokens=max_tokens, + sampler=sampler, + logits_processors=logits_processors, + prefill_step_size=1, + ), + start=1, + ): + generated_text_parts.append(out.text) + accumulated_text += out.text + if getattr(out, "from_draft", False): + from_draft_count += 1 + + # Check for stop sequences + text = out.text + finish_reason: FinishReason | None = cast( + FinishReason | None, out.finish_reason ) - logger.debug( - f"Generation complete: prefill {prompt_tokens} tokens @ " - f"{prefill_tps:.1f} tok/s, generated {generated_tokens} tokens @ " - f"{generation_tps:.1f} tok/s" + stop_matched = False + + if stop_sequences: + for stop_seq in stop_sequences: + if stop_seq in accumulated_text: + # Trim text to just before the stop sequence + stop_index = accumulated_text.find(stop_seq) + text_before_stop = accumulated_text[:stop_index] + chunk_start = len(accumulated_text) - len(out.text) + text = text_before_stop[chunk_start:] + finish_reason = "stop" + stop_matched = True + break + + is_done = finish_reason is not None + + stats: GenerationStats | None = None + if is_done: + # Drafter telemetry: stamp the id whenever speculation + # actually ran for this request. The asymmetric + # ``"pipelined"`` path has no in-process draft model + # (the weights live on the drafter rank), so guarding + # solely on ``effective_draft_model is not None`` would + # spuriously zero out telemetry for the very topology + # the drafter buys us the most. We instead trust + # ``drafter.mode`` together with the asymmetric flag, + # which is set iff the placement actually wired a + # drafter rank into this instance. + telemetry_drafter_id: str | None = None + telemetry_k: int | None = None + if ( + drafter.mode == "model" and effective_draft_model is not None + ) or drafter.mode == "pipelined": + telemetry_k = effective_num_draft_tokens + if drafter_model_id is not None: + telemetry_drafter_id = str(drafter_model_id) + elif drafter.mode == "ngram": + telemetry_k = effective_num_draft_tokens + + # Pull per-round counters from the drafter when it + # surfaces them. Only the pipelined drafter does today; + # ``getattr(..., None)`` keeps this future-proof for + # drafter implementations that grow a ``metrics()`` + # method later. ``mlx_lm``'s built-in spec loop doesn't + # expose proposal counts, so the ``"model"`` mode + # surfaces only ``accepted_draft_tokens`` (from the + # ``from_draft`` flag on each yielded token). + drafter_metrics_fn = cast( + "Callable[[], dict[str, int]] | None", + getattr(drafter, "metrics", None), + ) + drafter_metrics: dict[str, int] = ( + drafter_metrics_fn() if drafter_metrics_fn is not None else {} + ) + proposed = int(drafter_metrics.get("proposed_draft_tokens", 0)) + spec_rounds = int(drafter_metrics.get("spec_decode_rounds", 0)) + + stats = GenerationStats( + prompt_tps=float(prefill_tps or out.prompt_tps), + generation_tps=float(out.generation_tps), + prompt_tokens=int(prefill_tokens + out.prompt_tokens), + generation_tokens=int(out.generation_tokens), + peak_memory_usage=Memory.from_gb(out.peak_memory), + drafter_model_id=telemetry_drafter_id, + accepted_draft_tokens=from_draft_count, + proposed_draft_tokens=proposed, + spec_decode_rounds=spec_rounds, + num_draft_tokens=telemetry_k, + draft_mode=drafter.mode, + ) + if not stop_matched and out.finish_reason not in get_args(FinishReason): + logger.warning( + f"Model generated unexpected finish_reason: {out.finish_reason}" + ) + + # OpenAI-compatible surface for spec-decode telemetry. + # ``accepted_prediction_tokens`` is OpenAI's term for + # tokens supplied by a Predicted Output that ended up in + # the completion -- semantically equivalent to our + # ``accepted_draft_tokens``. ``rejected_prediction_tokens`` + # is the count of predicted tokens that didn't make it, + # i.e. drafts that the verifier rejected. We can only + # populate this when the drafter surfaces a proposal + # count; otherwise leave it at 0 rather than guess. + rejected_prediction_tokens = ( + max(0, proposed - from_draft_count) if proposed > 0 else 0 + ) + total_prompt_tokens = len(all_prompt_tokens) + usage = Usage( + prompt_tokens=total_prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=total_prompt_tokens + completion_tokens, + prompt_tokens_details=PromptTokensDetails( + cached_tokens=prefix_hit_length + ), + completion_tokens_details=CompletionTokensDetails( + reasoning_tokens=0, + accepted_prediction_tokens=from_draft_count, + rejected_prediction_tokens=rejected_prediction_tokens, + ), + ) + + # Extract logprobs from the full vocabulary logprobs array + logprob: float | None = None + top_logprobs: list[TopLogprobItem] | None = None + if task.logprobs: + with mx.stream(generation_stream): + logprob, top_logprobs = extract_top_logprobs( + logprobs=out.logprobs, + tokenizer=tokenizer, + top_logprobs=task.top_logprobs or DEFAULT_TOP_LOGPROBS, + selected_token=out.token, + ) + + if is_done: + # Per-request generation summary. INFO level because it's + # one line per completed request -- bounded volume, and + # the operator absolutely needs visibility into drafter + # effectiveness without flipping ``-vv``. When the + # drafter ran, surface acceptance fraction + per-position + # acceptance rate (when proposal count is available) + + # rounds + K. + generation_elapsed = time.perf_counter() - generation_start_time + generated_tokens = len(generated_text_parts) + generation_tps = ( + generated_tokens / generation_elapsed + if generation_elapsed > 0 + else 0.0 + ) + base_msg = ( + f"Generation complete: prefill {prompt_tokens} tokens @ " + f"{prefill_tps:.1f} tok/s, generated {generated_tokens} " + f"tokens @ {generation_tps:.1f} tok/s" + ) + if stats is not None and stats.drafter_model_id is not None: + fraction = stats.drafter_acceptance_fraction + rate = stats.drafter_acceptance_rate + fraction_str = f"{fraction:.1%}" if fraction is not None else "n/a" + rate_str = f"{rate:.1%}" if rate is not None else "n/a" + drafter_msg = ( + f", drafter={stats.draft_mode}/" + f"{stats.drafter_model_id} " + f"K={stats.num_draft_tokens} " + f"rounds={stats.spec_decode_rounds} " + f"accepted={stats.accepted_draft_tokens}/" + f"{stats.proposed_draft_tokens or 'n/a'} " + f"(rate={rate_str}, " + f"fraction_of_emitted={fraction_str})" + ) + else: + drafter_msg = "" + logger.info(base_msg + drafter_msg) + if on_generation_token is not None: + on_generation_token() + + yield GenerationResponse( + text=text, + token=out.token, + logprob=logprob, + top_logprobs=top_logprobs, + finish_reason=finish_reason, + stats=stats, + usage=usage, ) - if on_generation_token is not None: - on_generation_token() - - yield GenerationResponse( - text=text, - token=out.token, - logprob=logprob, - top_logprobs=top_logprobs, - finish_reason=finish_reason, - stats=stats, - usage=usage, - ) - if is_done: - mx_barrier(group) - break + if is_done: + mx_barrier(group) + break + + # Limit accumulated_text to what's needed for stop sequence detection + if max_stop_len > 0 and len(accumulated_text) > max_stop_len: + accumulated_text = accumulated_text[-max_stop_len:] + finally: + # Free the per-request drafter-rank KV cache. ``shutdown`` is + # idempotent on ``_SessionHandle``; the ``try / except`` is + # belt-and-suspenders for the rare case where the wire is + # already torn down (e.g. runner shutdown raced this call). + if asymmetric_session is not None: + try: + from exo.worker.engines.mlx.generator.drafter_transport import ( + DrafterTransport as _DrafterTransport, + ) - # Limit accumulated_text to what's needed for stop sequence detection - if max_stop_len > 0 and len(accumulated_text) > max_stop_len: - accumulated_text = accumulated_text[-max_stop_len:] + cast(_DrafterTransport, asymmetric_session).shutdown() + except Exception: + logger.opt(exception=True).warning( + "asymmetric drafter session shutdown raised; the " + "drafter rank will free its session cache on its " + "next OP_SHUTDOWN" + ) diff --git a/src/exo/worker/engines/mlx/generator/pipelined_drafter.py b/src/exo/worker/engines/mlx/generator/pipelined_drafter.py new file mode 100644 index 0000000000..1621b66de1 --- /dev/null +++ b/src/exo/worker/engines/mlx/generator/pipelined_drafter.py @@ -0,0 +1,1242 @@ +"""Pipelined speculative-decoding spec loop. + +Implements :class:`PipelinedModelDrafter` -- a custom spec loop that +talks to the drafter through a :class:`DrafterTransport` (in-process, +remote, ...). The win over :class:`ModelDrafter` (which delegates to +``mlx_lm.speculative_generate_step``) is **cross-round speculation**: +while the target rank verifies round ``t``'s drafts, the drafter +speculatively starts round ``t + 1`` by predicting the would-be bonus +token and continuing for ``K`` more forwards. If the target's actual +bonus matches the drafter's predicted bonus, round ``t + 1``'s drafts +are already in hand by the time round ``t``'s verify finishes; if not, +the speculative work is rolled back and the standard non-speculative +path runs. + +Apple-Silicon caveat: MLX serialises Metal command queues per device, +so the in-process overlap factor between drafter and target forwards +is ~0.1-0.3 (parallelism is bounded by memory-bandwidth contention, +not GPU saturation). The architecture's payoff scales with topology: +on a multi-machine deployment where target verify includes a network +round-trip, the speculative drafter forward fully overlaps the +network latency and the gain unlocks. :class:`RemoteTransport` ships +exactly that case. + +Multi-target asymmetric placement (V2 ``target_subgroup_size > 1``) +-------------------------------------------------------------------- +The target group is tensor-parallel across N nodes; the drafter lives +on a different node and talks to target rank 0 over a TCP socket. Per- +round flow: + + 1. **Drafter -> target rank 0 (socket).** Rank 0 issues an + ``OP_FORWARD`` over the wire, gets back ``k_this`` drafts. + 2. **Rank 0 -> all target ranks (collective).** Rank 0 broadcasts + the drafts on the target subgroup via :func:`_broadcast_drafts`. + Non-root ranks receive into the same buffer shape; the broadcast + uses :func:`mx_broadcast_int_list` (a length-prefixed + ``all_sum``). Drafter-rank does NOT participate -- it isn't a + member of the target subgroup. + 3. **All target ranks (collective).** Run the verify forward + ``model([seed, *drafts])`` -- a TP all-reduce inside the model + makes logits byte-identical on every target rank. + 4. **Rank 0 samples + broadcasts target tokens.** The sampler is + non-deterministic (temperature > 0 uses MLX's per-rank PRNG) so + each rank would otherwise produce divergent ``target_tokens``, + diverge on ``num_accepted``, trim the prompt cache by different + amounts, and desync at the next TP forward. Rank 0 samples + locally and broadcasts the chosen tokens via + :func:`_broadcast_target_tokens`; non-root ranks consume the + broadcast and skip the sampler entirely. Determinism then falls + out of the broadcast contract rather than relying on RNG state + coordination. + 5. **All target ranks compute identical accept/reject.** Both ranks + compare ``target_tokens`` (now identical from broadcast) against + ``drafts`` (also identical from step 2), reach the same + ``num_accepted``, and trim the prompt cache by the same amount. + 6. **Drafter cache reconciliation on rank 0 only.** Rank 0 issues + any required ``OP_TRIM_CACHE`` / next-round ``OP_FORWARD`` over + the socket; non-root just waits for the next draft broadcast + round at step 2. + +The collective overhead per round is two small ``all_sum`` calls +(drafts ``k+1`` ints, target tokens ``k+1`` ints) -- microsecond- +range on Thunderbolt RDMA, negligible against the verify forward. + +Recovery: drafter-rank death mid-generation +------------------------------------------- +If the drafter rank crashes between rounds, root's +``transport.forward`` raises :class:`OSError` (subclassed as +``ConnectionError`` / ``BrokenPipeError`` depending on which side +closed). The recovery is layered: + + 1. **Within-request abort** (this module). Before re-raising, the + :func:`_pipelined_speculative_step` wrapper broadcasts + :data:`DRAFT_ABORT_SENTINEL` on the draft channel. Non-root + ranks decode the sentinel inside :func:`_broadcast_drafts` and + raise :class:`DrafterAbortedError`, exiting their spec loop in + lockstep with root rather than blocking on a next-round + broadcast that will never arrive. The + :class:`exo.worker.engines.mlx.generator.remote_drafter.RemoteTransport` + also flips a sticky ``is_failed`` flag so subsequent + :meth:`open_session` calls fail fast instead of allocating a + new spec session on a dead wire. + + 2. **Cross-request teardown** (control plane). The runner + subprocess that owned the failed transport surfaces the + exception out of ``mlx_generate``, the runner crashes, the + supervisor reports :class:`RunnerFailed`, and the master's + worker-plan ``_kill_runner`` rule shuts every peer rank down + in the same instance. A fresh placement is re-issued on the + next planning tick. + + 3. **Drafter-node disconnect** (control plane). When the drafter + *node* goes offline (rather than the drafter *process*), the + master's instance-deletion loop iterates + ``instance.all_node_to_runner`` (target + drafter) and emits + :class:`InstanceDeleted` once the drafter node leaves + ``connected_node_ids``. Workers pick up the deletion in the + usual plan tick. Total time-to-recovery is bounded by the + master's ``node_inactivity_timeout`` (5 s) plus the + supervisor's SIGTERM/SIGKILL escalation budget (worst case + ~25 s), the same envelope as a target-rank crash. + +Target-rank death (a peer target rank in the TP subgroup) takes +the same path as case 3 above: the master's instance-deletion +loop already covered ``shard_assignments.node_to_runner``; the +worker plan's ``_kill_runner`` rule gossips ``RunnerFailed`` +across the surviving ranks and the supervisor SIGKILL chain +unblocks any in-flight TP collectives. + +Cache accounting (drafter side) -- this is the only complex bit, so +spelled out here once and referenced from the code: + + Notation: ``O`` = drafter cache offset before round ``t``'s propose. + ``K`` = ``num_draft_tokens``. + + Round ``t`` propose, length-1 seed (partial-accept-from-prev case): + ``forward([seed_t], K)`` -> K outputs. K forwards, each adds 1 + position. Cache offset O+K. Cache content extends with + ``[seed_t, d_0..d_{K-2}]`` (the K-th draft d_{K-1} is the K-th + output, *not* fed back as input). + + Round ``t`` propose, length-2 seed (full-accept-from-prev case): + ``forward([drafts_{t-1}[-1], seed_t], K)`` -> K outputs. K forwards; + the first has length-2 input, so cache extends by K+1. + + Speculative round ``t + 1`` (cross-round speculation): + ``forward([drafts_t[-1]], K + 1)`` -> K+1 outputs. K+1 forwards, + cache extends by K+1. Outputs are + ``[d^pred_K, d^spec_0, ..., d^spec_{K-1}]``: the first is the + drafter's prediction of bonus_t (compared against actual bonus_t + to detect speculation hit); the rest are round t+1's drafts. + Cache offset after speculation: O+2K+1. + + Round ``t`` accept outcomes: + + * Partial accept (``num_accepted < K_this``): drafter cache trim + by ``max(K_this - num_accepted - 1, 0)``. If speculation was + active, also rollback ``K + 1``. Round ``t + 1``'s propose is a + length-1-seed call. + * Full accept, speculation MISS (``bonus_t != d^pred_K``): rollback + ``K + 1``. Round ``t + 1``'s propose is a length-2-seed call. + * Full accept, speculation HIT: no rollback. Drafter cache + offset O+2K+1, content matches what mlx_lm's ``_draft_generate`` + would produce after a length-2 first forward + K-1 length-1 + forwards in round t+1. Round ``t + 1``'s drafts come from the + speculative outputs; round ``t + 1`` skips its own propose call. + * Truncated last round (``K_this < K``): speculation is disabled + because there's no round t+1 to feed. + +The matching :func:`_pipelined_speculative_step` enforces this +accounting; any divergence between the comments above and the code is +a bug, please flag it. +""" + +from __future__ import annotations + +import contextlib +import os as _diag_os +import sys as _diag_sys +import time +from typing import Callable, Final, Generator, Sequence, cast, final + +import mlx.core as mx +from mlx_lm.generate import GenerationResponse +from mlx_lm.models.cache import trim_prompt_cache as mlx_trim_prompt_cache +from mlx_lm.tokenizer_utils import TokenizerWrapper + +from exo.worker.engines.mlx.generator.drafter import DraftMode +from exo.worker.engines.mlx.generator.drafter_transport import ( + DrafterTransport, + DraftFuture, +) +from exo.worker.engines.mlx.types import KVCacheType, Model +from exo.worker.engines.mlx.utils_mlx import ( + TargetPeerFanout, + mx_broadcast_int_list, + target_peer_broadcast_int_list, +) +from exo.worker.runner.bootstrap import logger as _diag_logger + +# Per-round spec-decode diagnostics. Off by default; set +# ``EXO_SPEC_DIAG=1`` to enable. When enabled, each call writes both +# to loguru and to ``/tmp/spec_diag_.log`` so diagnostics survive +# whatever's swallowing the runner subprocess's stdout (loguru +# forwarding has been observed to drop on some nodes in our cluster). +# +# Added during gemma-4 asymmetric-drafter bring-up to localize a +# TP-collective deadlock; the hooks are kept (gated) so future +# correctness regressions can be isolated quickly without redeploying +# with new logging. +_SPEC_DIAG_ENABLED: Final[bool] = _diag_os.environ.get("EXO_SPEC_DIAG", "") in ( + "1", + "true", + "yes", +) + + +def _spec_diag(message: str) -> None: + """Emit a spec-decode diagnostic line. No-op unless ``EXO_SPEC_DIAG``.""" + if not _SPEC_DIAG_ENABLED: + return + _diag_logger.info(message) + try: + path = f"/tmp/spec_diag_{_diag_os.getpid()}.log" + with open(path, "a", encoding="utf-8") as fh: + _ = fh.write(f"{time.time():.6f} {message}\n") + except OSError: + try: + _ = _diag_sys.stderr.write(f"[spec-diag fallback] {message}\n") + _diag_sys.stderr.flush() + except OSError: + pass + + +# Length-prefix slot value reserved for the "drafter aborted" signal. +# Picked from the int32 positive range so it survives +# ``_validate_broadcast_values`` (well above any legitimate ``K``, +# below ``_MX_BROADCAST_MAX_VALUE`` so the validator accepts it). +DRAFT_ABORT_SENTINEL: Final[int] = (1 << 31) - 2 + + +@final +class DrafterAbortedError(RuntimeError): + """Raised by non-root target ranks when root signals draft abort. + + Root encodes :data:`DRAFT_ABORT_SENTINEL` in the broadcast + length-prefix slot when its ``transport.forward()`` raises + (drafter rank crashed, socket dropped, etc). Non-root ranks + decode the sentinel inside :func:`_broadcast_drafts` and raise + this exception so the spec loop on every rank exits in lockstep, + rather than non-root hanging forever on the next-round draft + broadcast that root will never send. + """ + + +def _get_eos_ids(tokenizer: TokenizerWrapper) -> list[int]: + eos: list[int] | None = getattr(tokenizer, "eos_token_ids", None) + if eos is None: + return [] + return eos + + +def _get_tokenizer_vocab_size(tokenizer: TokenizerWrapper) -> int | None: + """Return ``len(tokenizer.vocab)`` (or HF equivalent) when available. + + Used by the spec-decode loop as an early sanity check on emitted + token ids: anything outside ``[0, vocab_size)`` cannot have come + from a clean broadcast (the sampler and drafter both produce ids + in that range), so it always points at a wire-level corruption + upstream. Returns ``None`` when the tokenizer doesn't expose a + vocab size (extremely defensive; mlx_lm tokenizers do). + """ + inner: object = getattr(tokenizer, "_tokenizer", None) + if inner is None: + return None + vocab_size: object = getattr(inner, "vocab_size", None) + if isinstance(vocab_size, int) and vocab_size > 0: + return vocab_size + vocab: object = getattr(inner, "vocab", None) + if isinstance(vocab, dict) and vocab: + return max(cast("dict[object, int]", vocab).values()) + 1 + return None + + +def _process_logits_for_position( + raw_logits: mx.array, + prev_tokens: mx.array, + logits_processors: list[Callable[[mx.array, mx.array], mx.array]], +) -> mx.array: + """Apply logits processors and convert to logprobs (single position).""" + out = raw_logits + for proc in logits_processors: + out = proc(prev_tokens, out) + return out - mx.logsumexp(out, axis=-1, keepdims=True) + + +@final +class PipelinedModelDrafter: + """Speculative decoding via a drafter accessed through :class:`DrafterTransport`. + + Owns its own spec loop so the drafter can be remote (different MLX + rank) without the target rank loading the drafter model. The + transport-agnostic propose/trim primitives mean swapping + in-process for remote drafter placement is a one-line construction + change at :func:`make_drafter`; the spec loop is unaffected. + + Multi-target asymmetric placement (``target_subgroup_size > 1``): + the target root rank holds the drafter socket (``transport`` is + set, ``is_target_root=True``) and broadcasts each round's drafts on + ``target_group`` so non-root target ranks receive them in lockstep. + Non-root ranks construct with ``transport=None`` and consume the + broadcast each round; both ranks then run the same verify forward + (which is a TP collective on the model itself) and reach identical + accept/reject decisions deterministically because TP all-reduces + the final logits to be byte-identical on every rank. + """ + + def __init__( + self, + *, + transport: DrafterTransport | None, + num_draft_tokens: int, + target_group: mx.distributed.Group | None = None, + target_peer_fanout: TargetPeerFanout | None = None, + is_target_root: bool = True, + ) -> None: + if num_draft_tokens < 1: + raise ValueError(f"num_draft_tokens must be >= 1, got {num_draft_tokens}") + if transport is None: + # Multi-target consumer rank: no socket, drafts arrive via + # broadcast on ``target_group``. + if is_target_root: + raise ValueError( + "transport=None requires is_target_root=False (the " + "consumer rank does not own the drafter socket)" + ) + if target_group is None: + raise ValueError( + "transport=None requires a target_group to receive " + "draft broadcasts on" + ) + else: + if num_draft_tokens > transport.num_draft_tokens: + raise ValueError( + f"num_draft_tokens ({num_draft_tokens}) exceeds transport's " + f"max ({transport.num_draft_tokens})" + ) + if not is_target_root: + raise ValueError( + "is_target_root=False on a transport-owning rank is a " + "configuration error: the rank that holds the drafter " + "socket is the broadcast root by definition" + ) + self._transport = transport + self._num_draft_tokens = num_draft_tokens + self._target_group = target_group + self._target_peer_fanout = target_peer_fanout + self._is_target_root = is_target_root + # Per-request spec-decode telemetry. Mutated in place by the + # spec body each round; read by ``mlx_generate`` after streaming + # completes to populate ``GenerationStats``. Single-request + # lifecycle (a fresh drafter is built per request in + # ``mlx_generate``), so no thread-safety concerns. + self._metrics: dict[str, int] = { + "proposed_draft_tokens": 0, + "accepted_draft_tokens": 0, + "spec_decode_rounds": 0, + } + + @property + def mode(self) -> DraftMode: + return "pipelined" + + @property + def num_draft_tokens(self) -> int: + return self._num_draft_tokens + + def metrics(self) -> dict[str, int]: + """Snapshot of accumulated spec-decode metrics for this request. + + Keys: ``proposed_draft_tokens`` (total drafts proposed across all + rounds), ``accepted_draft_tokens`` (drafts the target accepted), + ``spec_decode_rounds`` (rounds executed). Acceptance rate is + ``accepted / proposed`` when ``proposed > 0``. Counters reset on + each new request via the per-request drafter construction in + ``mlx_generate``; mutate in lockstep with the spec loop. + """ + return dict(self._metrics) + + def stream( + self, + *, + model: Model, + tokenizer: TokenizerWrapper, + prompt: mx.array, + context_tokens: Sequence[int], + prompt_cache: KVCacheType, + max_tokens: int, + sampler: Callable[[mx.array], mx.array], + logits_processors: Sequence[Callable[[mx.array, mx.array], mx.array]], + prefill_step_size: int = 1, + ) -> Generator[GenerationResponse, None, None]: + yield from _pipelined_stream_generate( + model=model, + tokenizer=tokenizer, + prompt=prompt, + context_tokens=list(context_tokens), + prompt_cache=prompt_cache, + max_tokens=max_tokens, + sampler=sampler, + logits_processors=list(logits_processors), + transport=self._transport, + num_draft_tokens=self._num_draft_tokens, + prefill_step_size=prefill_step_size, + target_group=self._target_group, + target_peer_fanout=self._target_peer_fanout, + is_target_root=self._is_target_root, + metrics=self._metrics, + ) + + def shutdown(self) -> None: + """Release transport resources.""" + if self._transport is not None: + self._transport.shutdown() + + +def _pipelined_stream_generate( + *, + model: Model, + tokenizer: TokenizerWrapper, + prompt: mx.array, + context_tokens: list[int], + prompt_cache: KVCacheType, + max_tokens: int, + sampler: Callable[[mx.array], mx.array], + logits_processors: list[Callable[[mx.array, mx.array], mx.array]], + transport: DrafterTransport | None, + num_draft_tokens: int, + prefill_step_size: int, + target_group: mx.distributed.Group | None = None, + target_peer_fanout: TargetPeerFanout | None = None, + is_target_root: bool = True, + metrics: dict[str, int] | None = None, +) -> Generator[GenerationResponse, None, None]: + """Mirror of ``mlx_lm.stream_generate`` framing for the pipelined drafter. + + The framing (detokenisation, tps tracking, finish reasons) matches + :func:`exo.worker.engines.mlx.generator.drafter._ngram_stream_generate` + so the call site in ``mlx_generate`` doesn't branch on drafter type. + """ + detokenizer = tokenizer.detokenizer + detokenizer.reset() # type: ignore[reportUnknownMemberType] + eos_ids = _get_eos_ids(tokenizer) + # Vocab bound for early surfacing of broadcast corruption. + # ``add_token`` would otherwise blow up deep inside the SPM + # detokenizer with ``IndexError: list index out of range`` and + # the operator has to dig through the mlx_lm internals to learn + # which token id was bogus. + vocab_size = _get_tokenizer_vocab_size(tokenizer) + + token_iter = _pipelined_speculative_step( + prompt=prompt, + model=model, + transport=transport, + prompt_cache=prompt_cache, + max_tokens=max_tokens, + sampler=sampler, + logits_processors=logits_processors, + num_draft_tokens=num_draft_tokens, + prefill_step_size=prefill_step_size, + prompt_token_count=len(context_tokens), + target_group=target_group, + target_peer_fanout=target_peer_fanout, + is_target_root=is_target_root, + metrics=metrics, + ) + + prompt_size = len(context_tokens) + tic = time.perf_counter() + prompt_tps = 0.0 + n = -1 + token = 0 + logprobs = mx.zeros((1,)) + from_draft = False + finish_reason: str | None = None + for n, (token, logprobs, from_draft) in enumerate(token_iter): + if n == 0: + prompt_time = time.perf_counter() - tic + prompt_tps = prompt_size / prompt_time if prompt_time > 0 else 0.0 + tic = time.perf_counter() + if token in eos_ids: + finish_reason = "stop" + break + if vocab_size is not None and not 0 <= token < vocab_size: + raise RuntimeError( + f"pipelined drafter emitted token id {token} outside " + f"tokenizer vocab [0, {vocab_size}); " + "this is a wire-protocol bug in the spec-decode " + "broadcast path (cross-stream JACCL collision or " + "rank divergence). The runner will crash and the " + "supervisor will rebuild the instance." + ) + detokenizer.add_token(token) # type: ignore[reportUnknownMemberType] + if (n + 1) == max_tokens: + finish_reason = "length" + break + elapsed = time.perf_counter() - tic + yield GenerationResponse( + text=detokenizer.last_segment, + token=token, + logprobs=logprobs, + from_draft=from_draft, + prompt_tokens=prompt_size, + prompt_tps=prompt_tps, + generation_tokens=n + 1, + generation_tps=(n + 1) / elapsed if elapsed > 0 else 0.0, + peak_memory=mx.get_peak_memory() / 1e9, + finish_reason=None, + ) + + detokenizer.finalize() # type: ignore[reportUnknownMemberType] + elapsed = time.perf_counter() - tic + yield GenerationResponse( + text=detokenizer.last_segment, + token=token, + logprobs=logprobs, + from_draft=from_draft, + prompt_tokens=prompt_size, + prompt_tps=prompt_tps, + generation_tokens=n + 1 if n >= 0 else 0, + generation_tps=(n + 1) / elapsed if elapsed > 0 and n >= 0 else 0.0, + peak_memory=mx.get_peak_memory() / 1e9, + finish_reason=finish_reason or ("stop" if token in eos_ids else "length"), + ) + + +def _broadcast_int_list( + payload: list[int] | None, + *, + length: int, + target_group: mx.distributed.Group | None, + target_peer_fanout: TargetPeerFanout | None, + is_root: bool, +) -> list[int]: + """Pick the correct fixed-length int broadcast for the active wiring. + + Multi-target asymmetric placements ride + :func:`target_peer_broadcast_int_list` (TCP fanout, immune to + JACCL int/float wire conflation). Every other path -- single-rank + targets, symmetric multi-rank without a drafter, test fakes that + bring up a ``mx.distributed.Group`` without populating a fanout + -- falls through to :func:`mx_broadcast_int_list`. The fallback + is correct in those cases because the JACCL bug only manifests + when the spec-decode int broadcasts interleave with the model's + TP ``all_sum`` collectives on the same group; without spec + decode (no drafter) or without a multi-rank target (no TP + collectives) the interleaving cannot happen. + """ + if target_peer_fanout is not None: + return target_peer_broadcast_int_list( + payload, length, target_peer_fanout, is_root=is_root + ) + return mx_broadcast_int_list(payload, length, target_group, is_root=is_root) + + +def _broadcast_drafts( + drafts: list[int] | None, + *, + k: int, + target_group: mx.distributed.Group | None, + target_peer_fanout: TargetPeerFanout | None, + is_root: bool, +) -> list[int]: + """Rank-0 broadcast of a draft list, padded to ``k`` slots + length prefix. + + Wire format: ``[len(drafts), drafts[0], ..., drafts[len-1], 0, 0, ...]`` + of fixed length ``k + 1``. Encoding the length up front lets us use + a single fixed-size ``all_sum`` collective per round (vs. a + count-then-payload two-collective handshake) on the spec-decode hot + path -- the cost is a few unused int32 slots when the drafter + returns fewer than ``k`` drafts. + + Single-rank short-circuit (``target_group is None``): returns + ``drafts`` on the root and is a programming error elsewhere (the + consumer rank must always have a group to receive on). + """ + if target_group is None and target_peer_fanout is None: + if not is_root or drafts is None: + raise RuntimeError("non-root broadcast consumer requires target_group") + return list(drafts) + if is_root: + if drafts is None: + raise RuntimeError("root broadcaster requires drafts") + if len(drafts) > k: + raise RuntimeError( + f"drafts length ({len(drafts)}) exceeds k ({k}); " + "transport must clamp before broadcasting" + ) + payload = [len(drafts)] + list(drafts) + [0] * (k - len(drafts)) + broadcast = _broadcast_int_list( + payload, + length=k + 1, + target_group=target_group, + target_peer_fanout=target_peer_fanout, + is_root=True, + ) + else: + broadcast = _broadcast_int_list( + None, + length=k + 1, + target_group=target_group, + target_peer_fanout=target_peer_fanout, + is_root=False, + ) + actual_len = broadcast[0] + if actual_len == DRAFT_ABORT_SENTINEL: + # Root has flagged a drafter-side failure (see + # :func:`_broadcast_abort`). Surface a typed exception so the + # spec loop on this rank exits in lockstep with root rather + # than waiting on the next-round broadcast that won't arrive. + raise DrafterAbortedError( + "drafter aborted; root signalled abort via length-prefix " + "sentinel after a transport-side failure" + ) + if actual_len < 0 or actual_len > k: + raise RuntimeError( + f"draft broadcast decoded invalid length {actual_len} (buffer {broadcast})" + ) + return broadcast[1 : 1 + actual_len] + + +def _broadcast_abort( + *, + k: int, + target_group: mx.distributed.Group | None, + target_peer_fanout: TargetPeerFanout | None, +) -> None: + """Root-only: broadcast the abort sentinel on the draft channel. + + Encodes :data:`DRAFT_ABORT_SENTINEL` as the length-prefix of an + otherwise-zero ``k + 1`` int payload, matching the wire shape of + a normal :func:`_broadcast_drafts` round so non-root ranks + decode it on the same fixed-size collective they were already + waiting on. Non-root surfaces it as :class:`DrafterAbortedError`. + + Single-rank short-circuit (``target_group is None``): no peers + to notify, so this is a no-op. The local rank still re-raises + the underlying transport exception that triggered the abort. + """ + if target_group is None and target_peer_fanout is None: + return + payload = [DRAFT_ABORT_SENTINEL] + [0] * k + _ = _broadcast_int_list( + payload, + length=k + 1, + target_group=target_group, + target_peer_fanout=target_peer_fanout, + is_root=True, + ) + + +def _broadcast_target_tokens( + target_tokens: list[int] | None, + *, + k: int, + k_this: int, + target_group: mx.distributed.Group | None, + target_peer_fanout: TargetPeerFanout | None, + is_root: bool, +) -> list[int]: + """Rank-0 broadcast of post-verify sampled tokens, slot count ``k + 1``. + + Why a separate broadcast from the drafts: the sampler is the only + non-deterministic step in the verify path. With temperature > 0 + each target rank's MLX PRNG advances independently, so identical + logits produce divergent ``target_tokens`` and the ranks desync on + the next TP forward. Broadcasting the chosen tokens from rank 0 + makes the sampler effectively a rank-0 operation; non-root ranks + skip the sampler entirely. + + Wire format: fixed-size ``k + 1`` int buffer (the verify forward + always produces exactly ``k_this + 1`` tokens; trailing slots are + zero-padded so the buffer shape doesn't change with ``k_this``). + Both ranks know ``k_this`` from the prior draft broadcast, so we + skip the length prefix and slice on receive. + + Single-rank short-circuit (``target_group is None``): identity on + root; programming error on consumer (no broadcast peer). + """ + if target_group is None and target_peer_fanout is None: + if not is_root or target_tokens is None: + raise RuntimeError("non-root broadcast consumer requires target_group") + if len(target_tokens) != k_this + 1: + raise RuntimeError( + f"target_tokens length ({len(target_tokens)}) must " + f"equal k_this + 1 ({k_this + 1}); the verifier always " + "emits exactly that many tokens per round" + ) + return list(target_tokens) + if is_root: + if target_tokens is None: + raise RuntimeError("root broadcaster requires target_tokens") + if len(target_tokens) != k_this + 1: + raise RuntimeError( + f"target_tokens length ({len(target_tokens)}) must " + f"equal k_this + 1 ({k_this + 1}); the verifier always " + "emits exactly that many tokens per round" + ) + payload = list(target_tokens) + [0] * (k - k_this) + broadcast = _broadcast_int_list( + payload, + length=k + 1, + target_group=target_group, + target_peer_fanout=target_peer_fanout, + is_root=True, + ) + else: + broadcast = _broadcast_int_list( + None, + length=k + 1, + target_group=target_group, + target_peer_fanout=target_peer_fanout, + is_root=False, + ) + return broadcast[: k_this + 1] + + +def _pipelined_speculative_step( + *, + prompt: mx.array, + model: Model, + transport: DrafterTransport | None, + prompt_cache: KVCacheType, + max_tokens: int, + sampler: Callable[[mx.array], mx.array], + logits_processors: list[Callable[[mx.array, mx.array], mx.array]], + num_draft_tokens: int, + prefill_step_size: int, + prompt_token_count: int, + target_group: mx.distributed.Group | None = None, + target_peer_fanout: TargetPeerFanout | None = None, + is_target_root: bool = True, + metrics: dict[str, int] | None = None, +) -> Generator[tuple[int, mx.array, bool], None, None]: + """Public spec-step generator with drafter-failure recovery. + + Wraps :func:`_pipelined_speculative_step_body` so that any + :class:`OSError` originating from the drafter wire on the root + rank (socket close, broken pipe, peer reset, etc.) also + broadcasts :data:`DRAFT_ABORT_SENTINEL` to non-root target + ranks. Non-root decodes it inside :func:`_broadcast_drafts` + and raises :class:`DrafterAbortedError`, exiting the spec loop + in lockstep with root. Without this wrap, root would re-raise + cleanly while non-root sat indefinitely on the next-round + draft broadcast that root will never send. + + Non-root and single-rank placements pass through unchanged: + non-root never touches the transport (so there is nothing to + abort from); :func:`_broadcast_abort` short-circuits when + ``target_group is None`` (no peers to notify). The local rank + re-raises the underlying exception in both cases. + """ + inner = _pipelined_speculative_step_body( + prompt=prompt, + model=model, + transport=transport, + prompt_cache=prompt_cache, + max_tokens=max_tokens, + sampler=sampler, + logits_processors=logits_processors, + num_draft_tokens=num_draft_tokens, + prefill_step_size=prefill_step_size, + prompt_token_count=prompt_token_count, + target_group=target_group, + target_peer_fanout=target_peer_fanout, + is_target_root=is_target_root, + metrics=metrics, + ) + try: + yield from inner + except OSError: + if is_target_root: + # Recovery best-effort: if the abort broadcast itself + # fails (e.g. ``target_group`` is also dead), the + # supervisor SIGKILL chain still tears non-root + # runners down via the master's instance-deletion + # path. Suppression keeps the original ``OSError`` + # intact for the caller's traceback. + with contextlib.suppress(Exception): + _broadcast_abort( + k=num_draft_tokens, + target_group=target_group, + target_peer_fanout=target_peer_fanout, + ) + raise + + +def _pipelined_speculative_step_body( + *, + prompt: mx.array, + model: Model, + transport: DrafterTransport | None, + prompt_cache: KVCacheType, + max_tokens: int, + sampler: Callable[[mx.array], mx.array], + logits_processors: list[Callable[[mx.array, mx.array], mx.array]], + num_draft_tokens: int, + prefill_step_size: int, + prompt_token_count: int, + target_group: mx.distributed.Group | None = None, + target_peer_fanout: TargetPeerFanout | None = None, + is_target_root: bool = True, + metrics: dict[str, int] | None = None, +) -> Generator[tuple[int, mx.array, bool], None, None]: + """Cross-round speculative decoding loop using ``transport``. + + See module docstring for the cache-accounting derivation. This + function maintains: + + * ``drafts``: list[int] of length K_this -- this round's drafts. + * ``seed``: int -- the seed token for this round (target verify + consumes ``[seed, *drafts]``). + * ``next_round_inputs``: list[int] -- input shape for next round's + propose call (length 1 for partial-accept-from-this, length 2 + for full-accept-from-this). + * ``speculative_future``: optional Future from a speculative + forward issued in parallel with target verify. ``None`` when + speculation is not in flight. + + ``prompt_token_count`` is captured so logits processors that need + the running token count (rare, e.g. positional repetition penalty + that scales with absolute position) get accurate values. + + Multi-target asymmetric (``target_group is not None``): only the + target root rank holds the drafter ``transport``; non-root target + ranks pass ``transport=None`` and receive each round's drafts via + a rank-0 broadcast on ``target_group``. Both ranks then run the + verify forward in TP lockstep -- the model's final all-reduce + makes logits byte-identical across target ranks, so accept/reject + decisions and emitted token sequences match deterministically + without any further coordination. + """ + if (transport is None) and is_target_root: + raise RuntimeError( + "_pipelined_speculative_step: target root requires transport" + ) + if (transport is None) and target_group is None: + raise RuntimeError( + "_pipelined_speculative_step: non-root target rank requires " + "target_group to receive draft broadcasts" + ) + + k = num_draft_tokens + y = prompt.astype(mx.uint32) + + _diag_rank = ( + target_group.rank() + if target_group is not None + else (0 if is_target_root else -1) + ) + _spec_diag( + f"rank {_diag_rank}: spec body entered " + f"(prompt size={int(prompt.size)}, k={k}, root={is_target_root})" + ) + + # Mirror mlx_lm._prefill: caller has aligned ``prompt_cache`` to + # ``context_tokens[:-2]`` via ``exo.prefill`` + ``trim(2)``; this loop + # advances the cache by one more token, leaving ``y`` (length 1) as + # the seed for the spec loop. + _diag_prefill_iters = 0 + while y.size > 1: + _diag_prefill_t0 = time.perf_counter() + n_to_process = min(prefill_step_size, y.size - 1) + model(y[:n_to_process][None], cache=prompt_cache) + mx.eval([c.state for c in prompt_cache]) # type: ignore[reportArgumentType] + y = y[n_to_process:] + mx.clear_cache() + _spec_diag( + f"rank {_diag_rank}: spec-body prefill iter " + f"{_diag_prefill_iters} done in " + f"{(time.perf_counter() - _diag_prefill_t0) * 1000:.1f}ms " + f"(remaining y.size={int(y.size)})" + ) + _diag_prefill_iters += 1 + + _diag_seed_t0 = time.perf_counter() + seed = int(y.item()) + _spec_diag( + f"rank {_diag_rank}: seed materialized in " + f"{(time.perf_counter() - _diag_seed_t0) * 1000:.1f}ms (seed={seed})" + ) + # ``prev_tokens`` carries the running token sequence (prompt + + # emitted) so logits processors with state see consistent context. + # Mirror :func:`drafter._ngram_speculative_step`: start from prompt. + prev_tokens = mx.array([seed], dtype=mx.uint32) + del prompt_token_count # currently unused; kept for forward-compat + + # Round 0 propose: synchronous, no speculation possible yet because + # we don't have prior drafts to chain off of. On the root the + # drafter forward issues a socket round-trip; on non-root target + # ranks we skip that and just receive the broadcast. + if transport is not None: + _diag_fwd_t0 = time.perf_counter() + _spec_diag( + f"rank {_diag_rank}: round 0 about to call transport.forward([seed], k={k})" + ) + drafts_future = transport.forward([seed], k) + drafts_local: list[int] | None = drafts_future.result() + _spec_diag( + f"rank {_diag_rank}: round 0 transport.forward " + f"returned in {(time.perf_counter() - _diag_fwd_t0) * 1000:.1f}ms " + f"(drafts_local len={len(drafts_local) if drafts_local else 0})" + ) + else: + drafts_local = None + _diag_bcast_t0 = time.perf_counter() + _spec_diag( + f"rank {_diag_rank}: round 0 about to call " + f"_broadcast_drafts (root={is_target_root})" + ) + drafts = _broadcast_drafts( + drafts_local, + k=k, + target_group=target_group, + target_peer_fanout=target_peer_fanout, + is_root=is_target_root, + ) + _spec_diag( + f"rank {_diag_rank}: round 0 _broadcast_drafts done " + f"in {(time.perf_counter() - _diag_bcast_t0) * 1000:.1f}ms " + f"(drafts len={len(drafts)})" + ) + + speculative_future: DraftFuture | None = None + ntoks = 0 + _diag_round = 0 + + while ntoks < max_tokens: + budget = max_tokens - ntoks + k_this = min(k, len(drafts), budget) + if k_this < 1: + break + drafts = drafts[:k_this] + _diag_round += 1 + _spec_diag( + f"rank {_diag_rank}: round {_diag_round} top " + f"(ntoks={ntoks}, k_this={k_this})" + ) + + # ----- Cross-round speculation: dispatch in parallel with verify ----- + # Speculate only when: + # * full k_this drafts (truncated last rounds have no t+1 to feed), + # * budget remains for an entire next round's verify after this one. + # + # The speculative forward consumes ``drafts[-1]`` (= drafter's last + # draft this round) as its first input, doing k+1 forwards. The + # first output is the drafter's prediction of bonus_t (used to + # detect speculation hit); the remaining k outputs are round + # t+1's drafts if speculation hits. + # + # Speculation only fires on the rank that owns the transport. + # Non-root target ranks have no socket and would have nothing + # to dispatch; they catch up via the next-round broadcast. + speculation_active = ( + transport is not None + and k_this == k + and ntoks + (k_this + 1) + k + 1 <= max_tokens + and speculative_future is None + ) + if speculation_active: + assert transport is not None # narrowed by speculation_active + speculative_future = transport.forward([drafts[-1]], k + 1) + + # ----- Target verify ----- + seed_arr = mx.array([seed], dtype=mx.uint32) + draft_arr = mx.array(drafts, dtype=mx.uint32) + verify_input = mx.concatenate([seed_arr, draft_arr]) + _diag_verify_t0 = time.perf_counter() + logits = model(verify_input[None], cache=prompt_cache) + # CRITICAL: force eval of ``logits`` on every target rank so the + # TP all-reduce kernels embedded in ``model()`` actually launch + # before any rank proceeds to its next blocking step. Without + # this, non-root ranks dispatch the verify forward (lazy graph + # only) and then enter the TCP recv in ``_broadcast_target_tokens``, + # leaving the all-reduce un-launched on their side. The root + # rank's ``mx.eval(sampled_batch)`` then deadlocks waiting for + # the matching all-reduce on every peer. This mirrors the + # prefill loop's ``mx.eval([c.state for c in prompt_cache])``, + # which is what made the round-0 prefill collectives pair up + # correctly on both ranks. Cost: one synchronization per round + # (~the verify forward time, which we'd block on at the sampler + # step anyway on root). Benefit: guaranteed pairing of TP + # collectives across all target ranks under JACCL or ring. + mx.eval(logits) + _spec_diag( + f"rank {_diag_rank}: round {_diag_round} model(verify) + eval " + f"completed in {(time.perf_counter() - _diag_verify_t0) * 1000:.1f}ms " + f"(verify_len={k_this + 1})" + ) + # logits shape: (1, k_this + 1, vocab) + + target_logprobs: list[mx.array] + target_tokens: list[int] + # Fast path: every processor advertises position independence + # (or there are none). Apply them once to the batched + # ``(K+1, vocab)`` logits, sample all positions in one call, + # and pay a single host-device sync per round instead of K+1. + # On a target with ~10ms step time this saves ~10-15ms per + # round -- typically the difference between net-win and net-loss + # for spec-decode on fast quantised targets. + # + # Multi-target determinism: ``logits`` is byte-identical across + # target ranks because the model's final layer all-reduces it + # via TP. Logits processors are pure functions of ``logits`` and + # ``prev_tokens`` (also identical across ranks), so logprobs are + # identical too. The sampler is the only non-deterministic step + # (``mx.random.categorical`` uses MLX's per-rank PRNG). Rank 0 + # samples; non-root ranks skip the sampler and pick up tokens + # from the broadcast below. Logprobs are still computed locally + # on every rank because they're cheap and the yield contract + # passes them upward (the user only ever sees rank 0's, but + # keeping the local view matches the single-rank path). + position_independent = all( + getattr(p, "position_independent", False) for p in logits_processors + ) + if position_independent: + batched_logits = logits.squeeze(0) + for proc in logits_processors: + batched_logits = proc(prev_tokens, batched_logits) + batched_logprobs = batched_logits - mx.logsumexp( + batched_logits, axis=-1, keepdims=True + ) + target_logprobs = [batched_logprobs[i] for i in range(k_this + 1)] + if is_target_root: + _diag_sample_t0 = time.perf_counter() + _spec_diag( + f"rank {_diag_rank}: round {_diag_round} root: about to " + f"call sampler(batched_logprobs)" + ) + sampled_batch = sampler(batched_logprobs) + _spec_diag( + f"rank {_diag_rank}: round {_diag_round} root: about to " + f"mx.eval(sampled_batch) (this forces verify forward + " + f"all_sum to actually run)" + ) + mx.eval(sampled_batch) + _spec_diag( + f"rank {_diag_rank}: round {_diag_round} root: " + f"mx.eval(sampled_batch) done in " + f"{(time.perf_counter() - _diag_sample_t0) * 1000:.1f}ms" + ) + target_tokens = [int(t) for t in sampled_batch.tolist()] # type: ignore[reportUnknownArgumentType] + else: + # Filled by broadcast below; skip the sampler entirely. + target_tokens = [] + _spec_diag( + f"rank {_diag_rank}: round {_diag_round} non-root: " + f"skipped sampler, awaiting broadcast" + ) + else: + # Stateful path: logits processors (e.g. repetition penalty) + # depend on ``running_prev`` which only resolves between + # positions, so we can't batch. Per-position sync is the + # cost of correctness here. + # + # Cross-rank determinism subtlety: the loop's ``running_prev`` + # advances by the sampled token at each position. On rank 0 + # we sample to advance it; on non-root ranks we don't have + # the token yet (the broadcast happens after the loop), so + # we'd advance with the wrong tokens. To keep the per-rank + # codepath identical we sample on every rank and broadcast + # after; the broadcast then overwrites ``target_tokens`` so + # downstream accept/reject is identical. Per-rank sampler + # divergence inside this loop is harmless because nothing + # consumes ``target_tokens`` between sampler call and + # broadcast; it gets clobbered before use. + target_logprobs = [] + target_tokens = [] + running_prev = prev_tokens + for i in range(k_this + 1): + position_logits = logits[:, i, :].squeeze(0) + position_logprobs = _process_logits_for_position( + position_logits, running_prev, logits_processors + ) + sampled = sampler(position_logprobs) + mx.eval(sampled) + sampled_token = int(sampled.item()) + target_logprobs.append(position_logprobs) + target_tokens.append(sampled_token) + running_prev = mx.concatenate( + [running_prev, mx.array([sampled_token], dtype=mx.uint32)] + ) + + # Broadcast rank-0's chosen tokens to every target rank so + # accept/reject decisions are bit-identical. Single-rank + # placements (``target_group is None``) short-circuit to + # identity, so this is free for the non-multi-target paths. + _diag_tbcast_t0 = time.perf_counter() + _spec_diag( + f"rank {_diag_rank}: round {_diag_round} about to call " + f"_broadcast_target_tokens (root={is_target_root})" + ) + target_tokens = _broadcast_target_tokens( + target_tokens if is_target_root else None, + k=k, + k_this=k_this, + target_group=target_group, + target_peer_fanout=target_peer_fanout, + is_root=is_target_root, + ) + _spec_diag( + f"rank {_diag_rank}: round {_diag_round} _broadcast_target_tokens " + f"done in {(time.perf_counter() - _diag_tbcast_t0) * 1000:.1f}ms " + f"(target_tokens len={len(target_tokens)})" + ) + + # ----- Greedy accept loop ----- + num_accepted = 0 + for i in range(k_this): + if target_tokens[i] == drafts[i]: + num_accepted += 1 + else: + break + + # Per-round telemetry: ``k_this`` drafts proposed, + # ``num_accepted`` accepted by the greedy verifier. The bonus + # token (target's correction or full-accept tail) is *not* a + # draft, so it doesn't count against acceptance rate. Mutates + # the caller's dict in place; ``metrics is None`` for the + # synthetic single-rank tests that bypass the drafter wrapper. + if metrics is not None: + metrics["proposed_draft_tokens"] += k_this + metrics["accepted_draft_tokens"] += num_accepted + metrics["spec_decode_rounds"] += 1 + + # ----- Emit accepted drafts + correction/bonus ----- + emit_count = num_accepted + 1 + for j in range(emit_count): + tok = drafts[j] if j < num_accepted else target_tokens[j] + from_draft = j < num_accepted + yield tok, target_logprobs[j], from_draft + prev_tokens = mx.concatenate( + [prev_tokens, mx.array([tok], dtype=mx.uint32)] + ) + ntoks += 1 + if ntoks >= max_tokens: + break + + # ----- Target cache trim (rejected draft positions) ----- + # Verify forward extended target cache by k_this + 1; we keep + # ``num_accepted + 1`` of those (= emit_count) so trim + # ``k_this - num_accepted``. + target_trim = k_this - num_accepted + if target_trim > 0: + mlx_trim_prompt_cache(cast(list[object], prompt_cache), target_trim) # type: ignore[reportArgumentType] + + if ntoks >= max_tokens: + # Discard any in-flight speculation; we're done. Rolling back + # the drafter cache isn't strictly necessary (the loop is + # exiting), but keeps the cache in a consistent state for + # any subsequent runs that might reuse the transport. + if speculative_future is not None: + _drain_future(speculative_future) + assert transport is not None # speculative_future is set only on root + transport.trim_cache(k + 1) + speculative_future = None + break + + # ``next_seed`` is the target's chosen token at the rejection + # point (partial accept) or the bonus position (full accept). + next_seed: int + if num_accepted == k_this: + next_seed = target_tokens[k_this] + else: + next_seed = target_tokens[num_accepted] + + # ----- Drafter cache reconciliation + next-round setup ----- + # Only the root rank touches ``transport``; non-root target + # ranks compute ``next_drafts_local = None`` and pick up the + # actual drafts from the rank-0 broadcast below. + next_drafts_local: list[int] | None + if transport is not None: + if num_accepted < k_this: + # Partial accept (regardless of speculation state). + drafter_trim_partial = max(k_this - num_accepted - 1, 0) + if speculative_future is not None: + # Speculative work is bound to a different (assumed- + # full-accept) future; discard it and trim its k+1 + # positions plus the partial-accept trim. + _drain_future(speculative_future) + transport.trim_cache(k + 1 + drafter_trim_partial) + speculative_future = None + elif drafter_trim_partial > 0: + transport.trim_cache(drafter_trim_partial) + next_drafts_local = transport.forward([next_seed], k).result() + else: + # Full accept at this round. + if speculative_future is not None: + spec_outputs = speculative_future.result() + speculative_future = None + bonus_predicted = spec_outputs[0] + if bonus_predicted == next_seed: + # SPECULATION HIT. Round t+1's drafts come for free. + # Drafter cache state is correct (offset O+2k+1 + # matches what a length-2-seed propose for round + # t+1 would produce). + next_drafts_local = spec_outputs[1 : k + 1] + else: + # SPECULATION MISS. Rollback the k+1 speculative + # positions and run a standard length-2-seed + # propose for round t+1. + transport.trim_cache(k + 1) + next_drafts_local = transport.forward( + [drafts[-1], next_seed], k + ).result() + else: + # Full accept, speculation was inactive. Standard + # length-2-seed propose for round t+1. + next_drafts_local = transport.forward( + [drafts[-1], next_seed], k + ).result() + else: + next_drafts_local = None + + _diag_nbcast_t0 = time.perf_counter() + _spec_diag( + f"rank {_diag_rank}: round {_diag_round} about to call " + f"_broadcast_drafts (next, accepted={num_accepted}/{k_this})" + ) + next_drafts = _broadcast_drafts( + next_drafts_local, + k=k, + target_group=target_group, + target_peer_fanout=target_peer_fanout, + is_root=is_target_root, + ) + _spec_diag( + f"rank {_diag_rank}: round {_diag_round} _broadcast_drafts (next) " + f"done in {(time.perf_counter() - _diag_nbcast_t0) * 1000:.1f}ms " + f"(next_drafts len={len(next_drafts)})" + ) + + seed = next_seed + drafts = next_drafts + + +def _drain_future(future: DraftFuture) -> None: + """Block on ``future`` and discard its result. + + Used when speculation misses or the loop exits early: the drafter + forwards have already executed; we just need to ensure the future + is resolved before issuing dependent transport operations + (``trim_cache``, ``shutdown``). Exceptions from the forwards + surface elsewhere (transport's own error path); we suppress them + here to avoid double-reporting. + """ + import contextlib + + with contextlib.suppress(Exception): + future.result() + + +__all__ = ["PipelinedModelDrafter"] diff --git a/src/exo/worker/engines/mlx/generator/remote_drafter.py b/src/exo/worker/engines/mlx/generator/remote_drafter.py new file mode 100644 index 0000000000..207d1b98a5 --- /dev/null +++ b/src/exo/worker/engines/mlx/generator/remote_drafter.py @@ -0,0 +1,875 @@ +"""Drafter on a different node, IPC via a direct TCP socket. + +:class:`RemoteTransport` (a concrete :class:`DrafterTransport`) and the +matching :func:`drafter_serve_loop` carry the same uint32-array wire +protocol that the original implementation rode on top of +``mx.distributed.send/recv``, but they no longer require the drafter to +be a member of any ``mx.distributed.Group``. + +Why the change: ``mx.distributed`` on Apple Silicon (jaccl, ring) does +not implement ``Group.split``. As long as the drafter rank shared the +parent group with the target ranks, the target ranks could not run +TP/PP collectives without dragging the drafter in -- the V1 asymmetric +path was therefore limited to a single target rank. By moving the +drafter wire onto a plain TCP socket, the parent ``mx.distributed`` +group contains only target ranks (so target collectives work as +designed), and the drafter rank skips ``mx.distributed.init`` entirely. +The same code path works for parent_size 1 (single target) and +parent_size N (sharded target) without any backend feature gate. + +Wire protocol v3 (session-aware, socket-framed -- semantically +identical to v2 but without the ``mx.distributed`` framing): + + * **Command frame** (target -> drafter), :data:`COMMAND_FRAME_SIZE` + little-endian uint32s:: + + [op, num_inputs, num_forwards, input_0, input_1, trim_amount, + session_id, _, _] + + Fixed length so the receiver can call + :func:`drafter_socket.recv_uint32_frame` with a known shape. + ``session_id`` selects which per-session draft cache the drafter + rank routes the op to. ``OP_SHUTDOWN`` ignores ``session_id`` + (it tears down the entire serve loop). All other ops require a + valid ``session_id`` -- :data:`OP_PREFILL` allocates the session, + :data:`OP_END_SESSION` frees it, the rest reference an existing + session. Unused slots are zero-padded. + + * **Drafts frame** (drafter -> target), :data:`COMMAND_FRAME_SIZE` - + sized? No: the drafts buffer is sized to ``num_draft_tokens + 1``. + The target knows the buffer width statically from + :attr:`RemoteTransport.num_draft_tokens`. Padded with zeros if the + request asked for fewer than ``K + 1`` forwards (the caller knows + its requested count and slices accordingly). + + * **Ack frame** (drafter -> target), :data:`ACK_FRAME_SIZE` uint32s: + a single status byte (always ``0`` for "ok"). Sent after + ``OP_TRIM_CACHE``, ``OP_PREFILL``, ``OP_END_SESSION``, and + ``OP_SHUTDOWN`` so the target rank has a synchronisation point + against the drafter's cache state. + + * **OP_PREFILL prompt tail** (target -> drafter): when the command + frame's ``num_forwards`` slot is non-zero, the target follows the + command frame with a length-prefixed prompt-token payload (see + :func:`drafter_socket.send_variable_uint32_payload`). Empty + prompts skip the tail entirely. + +Op codes: :data:`OP_FORWARD` (1), :data:`OP_TRIM_CACHE` (2), +:data:`OP_SHUTDOWN` (3), :data:`OP_PREFILL` (4), +:data:`OP_END_SESSION` (5). + +Concurrency model: ``RemoteTransport`` exposes :meth:`open_session` +which allocates a fresh ``session_id`` and returns a session-scoped +:class:`DrafterTransport` view. Each in-flight target request gets +its own session handle; the underlying wire stays serial because a +single TCP connection cannot interleave reads/writes from multiple +threads, but the drafter rank multiplexes operations across sessions +by keying each op's KV-cache lookup on ``session_id``. The cap on +concurrent target requests is therefore set by the *target* runner +(``EXO_MAX_CONCURRENT_REQUESTS``), not by the drafter wire. + +Topology assumption: target rank 0 binds a TCP listener at instance +bootstrap; the drafter dials it. Address discovery flows through +:class:`DrafterPlacement` (host = target rank 0's advertised address, +port = ephemeral port allocated at placement time). One TCP +connection per asymmetric instance is sufficient because ops serialise +on a single socket. +""" + +from __future__ import annotations + +import contextlib +import itertools +import socket +import threading +from concurrent.futures import Future, ThreadPoolExecutor +from typing import TYPE_CHECKING, Callable, Final, final + +from exo.worker.engines.mlx.generator.drafter_socket import ( + recv_uint32_frame, + send_uint32_frame, + send_variable_uint32_payload, +) + +if TYPE_CHECKING: + from exo.worker.engines.mlx.generator.drafter_transport import DraftFuture + from exo.worker.engines.mlx.types import KVCacheType, Model + +import mlx.core as mx +from mlx_lm.models.cache import trim_prompt_cache as mlx_trim_prompt_cache + +# --------------------------------------------------------------------------- +# Wire protocol +# --------------------------------------------------------------------------- + +COMMAND_FRAME_SIZE: Final[int] = 9 +"""Fixed size of a command frame (uint32 ints). + +Carries [op, num_inputs, num_forwards, input_0, input_1, trim_amount, +session_id, 0, 0]. Two trailing zero slots reserved for future +extension without bumping the wire version on the byte layer.""" + +ACK_FRAME_SIZE: Final[int] = 1 +"""Fixed size of an ack frame (uint32 ints). The single int is reserved +for a status code; ``0`` means ok. Future revisions may surface error +states here without changing the wire format.""" + +OP_FORWARD: Final[int] = 1 +"""Drafter runs ``num_forwards`` forwards starting from +``inputs[:num_inputs]`` against ``sessions[session_id]``'s KV cache. +Replies with a Drafts frame.""" + +OP_TRIM_CACHE: Final[int] = 2 +"""Drafter trims ``trim_amount`` positions from +``sessions[session_id]``'s KV cache. Replies with an Ack frame so the +target has a sync point.""" + +OP_SHUTDOWN: Final[int] = 3 +"""Drafter exits its serve loop. Replies with an Ack frame, then the +serve loop returns. ``session_id`` is ignored -- this op tears down +the entire wire, not a single session. Per-session cleanup uses +:data:`OP_END_SESSION` instead.""" + +OP_PREFILL: Final[int] = 4 +"""Per-request setup: target announces a prompt of ``num_inputs`` (used +as ``num_prompt_tokens``) tokens for ``session_id``. The drafter +allocates a fresh KV cache for the session (or resets the existing +one to offset 0), recvs the prompt token array, runs prefill forwards +through the drafter model, then replies with an Ack frame. Issued +once at the start of every request so the spec loop's first +``OP_FORWARD`` seeds against an aligned drafter cache.""" + +OP_END_SESSION: Final[int] = 5 +"""Per-request teardown: drafter drops ``sessions[session_id]`` to free +the KV cache memory and replies with an Ack frame so the target has a +sync point. Idempotent: ending a non-existent session is also a +successful ack (sessions can drop themselves on the drafter side via +target shutdown without the target getting a chance to send this op). +""" + +ACK_OK: Final[int] = 0 + +SESSION_ID_NONE: Final[int] = 0xFFFFFFFF +"""Sentinel ``session_id`` for ops that don't address a session. + +``OP_SHUTDOWN`` carries this value because it tears down the whole +wire, not a single session. ``0`` is the first session id allocated by +the target's monotonic counter, so a sentinel out of that range avoids +a collision in wire-trace logs.""" + + +def _build_command_frame( + *, + op: int, + inputs: list[int], + num_forwards: int, + trim_amount: int, + session_id: int, +) -> list[int]: + """Pack command parameters into a fixed-length uint32 list. + + Layout: ``[op, num_inputs, num_forwards, input_0, input_1, trim_amount, session_id, 0, 0]``. + + ``inputs`` must have length 0, 1, or 2 (the spec loop only ever + passes length-1 or length-2 inputs to ``forward``; ``OP_TRIM_CACHE``, + ``OP_END_SESSION``, and ``OP_SHUTDOWN`` pass length 0). Out-of-band + lengths are a programming error and raise. + + ``session_id`` MUST fit in uint32. The target allocates session ids + monotonically per :class:`RemoteTransport` instance from a counter, + which gives ~4G sessions per runner lifetime -- plenty for any + realistic deployment. Wraparound is not handled (the runner would + have to serve > 4 billion concurrent requests; if that ever + happens, switch the counter to a free-list of recycled ids). + """ + if len(inputs) > 2: + raise ValueError(f"inputs length must be in [0, 2], got {len(inputs)}") + if not 0 <= session_id <= 0xFFFFFFFF: + raise ValueError(f"session_id must fit in uint32, got {session_id}") + return [ + op, + len(inputs), + num_forwards, + inputs[0] if len(inputs) >= 1 else 0, + inputs[1] if len(inputs) >= 2 else 0, + trim_amount, + session_id, + 0, + 0, + ] + + +def _decode_command_frame(flat: list[int]) -> tuple[int, list[int], int, int, int]: + """Inverse of :func:`_build_command_frame`. + + Returns ``(op, inputs, num_forwards, trim_amount, session_id)``. + """ + if len(flat) != COMMAND_FRAME_SIZE: + raise ValueError( + f"Command frame has {len(flat)} ints, expected {COMMAND_FRAME_SIZE}" + ) + op = flat[0] + num_inputs = flat[1] + num_forwards = flat[2] + trim_amount = flat[5] + session_id = flat[6] + inputs = flat[3 : 3 + num_inputs] + return op, inputs, num_forwards, trim_amount, session_id + + +# --------------------------------------------------------------------------- +# RemoteTransport (target side) +# --------------------------------------------------------------------------- + + +@final +class RemoteTransport: + """Wire-protocol owner for the asymmetric drafter rank (target side). + + Holds the long-lived TCP socket + IPC thread; vends per-request + :class:`_SessionHandle` instances via :meth:`open_session`. Each + handle implements :class:`DrafterTransport` so the spec loop code + is unchanged -- it just receives a session-scoped transport rather + than the shared one. + + Each wire op (forward / trim / prefill / end-session) is dispatched + on a single-worker :class:`ThreadPoolExecutor`. Wire ops therefore + serialise even when multiple in-flight target requests are calling + methods concurrently from different :class:`_SessionHandle` + instances, which is exactly what we need: a single TCP connection + cannot interleave reads/writes from multiple threads, but the + drafter rank multiplexes operations across sessions by keying its + KV-cache lookup on ``session_id``. + + Why a thread, given MLX is single-GIL? ``socket.recv`` blocks on + the network until the peer responds; running the wire round-trip + on a background thread lets the target's main thread issue MLX + target-verify dispatches in parallel. The drafter's actual compute + happens on the *drafter rank's* GPU, not on a thread of the calling + rank, so there's no GIL contention to worry about. + """ + + def __init__( + self, + *, + num_draft_tokens: int, + sock: socket.socket, + ) -> None: + if num_draft_tokens < 1: + raise ValueError(f"num_draft_tokens must be >= 1, got {num_draft_tokens}") + self._num_draft_tokens = num_draft_tokens + self._sock = sock + # Single-worker pool: every wire op (across all sessions) goes + # through it serially, which keeps ``socket.send/recv`` safe + # even when multiple :class:`_SessionHandle` instances are + # in flight on different target tasks. + self._executor = ThreadPoolExecutor( + max_workers=1, thread_name_prefix="exo-drafter-ipc" + ) + self._is_shutdown = False + # Sticky failure flag, set by the blocking wire helpers when a + # socket-level error escapes (drafter rank crashed, peer + # closed mid-frame, etc.). Once true, subsequent requests must + # not start a new spec session on this transport: the wire is + # unrecoverable until the runner is restarted (the master's + # instance-deletion path tears the placement down within ~5s + # of the drafter node leaving the topology). Callers consult + # :attr:`is_failed` before constructing a + # :class:`PipelinedModelDrafter`; the runner subprocess exits + # via the spec-loop exception if the failure happens + # mid-request (see ``_pipelined_speculative_step``'s abort + # sentinel). + self._is_failed = False + # Monotonic session id allocator. ``itertools.count`` gives us a + # thread-safe unsigned counter; we wrap it in a lock-free + # ``next()`` call inside :meth:`open_session` (Python's GIL + # makes the increment atomic for CPython, but the lock makes + # the contract explicit and survives a free-threaded build). + self._session_id_counter = itertools.count() + self._session_lock = threading.Lock() + + @property + def num_draft_tokens(self) -> int: + return self._num_draft_tokens + + @property + def is_failed(self) -> bool: + """True once a wire-level failure has been observed on this transport. + + Set by the blocking wire helpers when an :class:`OSError` + escapes (drafter rank crashed, peer closed mid-frame, etc.). + Sticky: there is no in-place recovery -- the runner must be + torn down (via the master's instance-deletion path) and a + fresh transport built. Callers consult this flag before + constructing a :class:`PipelinedModelDrafter`; if true the + request degrades to non-speculative decoding for the + remaining lifetime of the runner. + """ + return self._is_failed + + def _mark_failed(self) -> None: + """Internal: flip :attr:`is_failed` to True. Idempotent.""" + self._is_failed = True + + def open_session(self) -> "_SessionHandle": + """Allocate a fresh session and return a :class:`DrafterTransport` view. + + Each call yields a unique ``session_id``; the handle's + :meth:`_SessionHandle.shutdown` sends ``OP_END_SESSION`` so the + drafter rank can free the per-session KV cache. Forgetting to + call :meth:`_SessionHandle.shutdown` leaks a KV cache on the + drafter rank for that session id; ``RemoteTransport.shutdown`` + cleans up at process exit either way. + """ + if self._is_shutdown: + raise RuntimeError( + "RemoteTransport.open_session called after shutdown; the " + "drafter rank's serve loop has exited and won't respond" + ) + if self._is_failed: + raise RuntimeError( + "RemoteTransport.open_session called after a wire-level " + "failure was observed; the underlying socket is dead and " + "the runner must be torn down (master-driven instance " + "deletion) before a fresh session can be opened" + ) + with self._session_lock: + session_id = next(self._session_id_counter) + if session_id == SESSION_ID_NONE: + # 4G sessions exhausted; bump again so we never collide + # with the shutdown sentinel. In practice unreachable. + with self._session_lock: + session_id = next(self._session_id_counter) + return _SessionHandle(owner=self, session_id=session_id) + + def shutdown(self) -> None: + if self._is_shutdown: + return + self._is_shutdown = True + # Send shutdown to the drafter and wait for the ack so the + # drafter has a chance to drain its own state cleanly. + try: + self._executor.submit(self._shutdown_blocking).result(timeout=10.0) + except Exception: + # Drafter rank may already be torn down; the socket close + # below cleans up regardless. The shutdown contract is + # best-effort: if the wire is broken there is nothing to + # ack. + pass + finally: + self._executor.shutdown(wait=True) + with contextlib.suppress(OSError): + self._sock.close() + + # -- session-scoped wire ops (called by _SessionHandle) ------------- + + def _submit_forward( + self, session_id: int, inputs: list[int], num_forwards: int + ) -> "DraftFuture": + if self._is_shutdown: + raise RuntimeError( + "RemoteTransport.forward called after shutdown; the drafter " + "rank's serve loop has exited and won't respond" + ) + upper = self._num_draft_tokens + 1 + if not 1 <= num_forwards <= upper: + raise ValueError( + f"num_forwards must be in [1, {upper}], got {num_forwards}" + ) + if not 1 <= len(inputs) <= 2: + raise ValueError(f"inputs must have length 1 or 2, got {len(inputs)}") + return self._executor.submit( + self._forward_blocking, session_id, inputs, num_forwards + ) + + def _submit_trim(self, session_id: int, n_positions: int) -> None: + if self._is_shutdown: + raise RuntimeError("RemoteTransport.trim_cache called after shutdown") + if n_positions < 0: + raise ValueError(f"n_positions must be >= 0, got {n_positions}") + if n_positions == 0: + return + self._executor.submit(self._trim_blocking, session_id, n_positions).result() + + def _submit_prefill(self, session_id: int, prompt_tokens: list[int]) -> None: + if self._is_shutdown: + raise RuntimeError( + "RemoteTransport.reset_and_prefill called after shutdown" + ) + self._executor.submit( + self._reset_and_prefill_blocking, session_id, prompt_tokens + ).result() + + def _submit_end_session(self, session_id: int) -> None: + # Best-effort: if the wire is already shut down (process is + # tearing down), the session-side OP_END_SESSION would fail + # but the drafter rank is also exiting, so the cache is freed + # by process death anyway. + if self._is_shutdown: + return + self._executor.submit(self._end_session_blocking, session_id).result() + + # -- internals -------------------------------------------------------- + + def _forward_blocking( + self, session_id: int, inputs: list[int], num_forwards: int + ) -> list[int]: + """Send a forward command and recv the drafts. Runs on the IPC thread.""" + frame = _build_command_frame( + op=OP_FORWARD, + inputs=inputs, + num_forwards=num_forwards, + trim_amount=0, + session_id=session_id, + ) + try: + send_uint32_frame(self._sock, frame) + # Drafts buffer is fixed-size at K + 1 (the upper bound of + # any forward request); we slice to ``num_forwards`` here. + drafts = recv_uint32_frame(self._sock, self._num_draft_tokens + 1) + except OSError: + # Drafter rank closed the socket / peer reset / broken + # pipe. Mark the transport so subsequent + # ``open_session`` calls fail fast and the runner can be + # torn down (master-driven instance deletion) instead of + # silently producing nothing on every speculative round. + self._mark_failed() + raise + return drafts[:num_forwards] + + def _trim_blocking(self, session_id: int, n_positions: int) -> None: + """Send a trim command and wait for the ack.""" + frame = _build_command_frame( + op=OP_TRIM_CACHE, + inputs=[], + num_forwards=0, + trim_amount=n_positions, + session_id=session_id, + ) + try: + send_uint32_frame(self._sock, frame) + ack = recv_uint32_frame(self._sock, ACK_FRAME_SIZE) + except OSError: + self._mark_failed() + raise + if ack[0] != ACK_OK: + raise RuntimeError( + f"Drafter rank reported error code {ack[0]} " + f"for trim_cache(session={session_id}, n={n_positions})" + ) + + def _shutdown_blocking(self) -> None: + """Send shutdown command and wait for the ack.""" + frame = _build_command_frame( + op=OP_SHUTDOWN, + inputs=[], + num_forwards=0, + trim_amount=0, + session_id=SESSION_ID_NONE, + ) + send_uint32_frame(self._sock, frame) + # Best-effort recv: if the drafter has already torn down, the + # peer close will surface here. The caller is shutting down + # either way, so swallow recv failures. + with contextlib.suppress(ConnectionError, OSError): + recv_uint32_frame(self._sock, ACK_FRAME_SIZE) + + def _reset_and_prefill_blocking( + self, session_id: int, prompt_tokens: list[int] + ) -> None: + """Send the prefill command + token array and wait for the ack. + + The command frame announces ``num_prompt_tokens`` (encoded in + the ``num_forwards`` slot) and the ``session_id`` to allocate / + reset on the drafter rank. The prompt tail follows immediately + when non-empty, length-prefixed for parser robustness. + """ + num_prompt_tokens = len(prompt_tokens) + frame = _build_command_frame( + op=OP_PREFILL, + inputs=[], + num_forwards=num_prompt_tokens, + trim_amount=0, + session_id=session_id, + ) + try: + send_uint32_frame(self._sock, frame) + if num_prompt_tokens > 0: + send_variable_uint32_payload(self._sock, prompt_tokens) + ack = recv_uint32_frame(self._sock, ACK_FRAME_SIZE) + except OSError: + self._mark_failed() + raise + if ack[0] != ACK_OK: + raise RuntimeError( + f"Drafter rank reported error code {ack[0]} " + f"for reset_and_prefill(session={session_id}, " + f"{num_prompt_tokens} tokens)" + ) + + def _end_session_blocking(self, session_id: int) -> None: + """Send OP_END_SESSION and wait for the ack.""" + frame = _build_command_frame( + op=OP_END_SESSION, + inputs=[], + num_forwards=0, + trim_amount=0, + session_id=session_id, + ) + try: + send_uint32_frame(self._sock, frame) + ack = recv_uint32_frame(self._sock, ACK_FRAME_SIZE) + except OSError: + self._mark_failed() + raise + if ack[0] != ACK_OK: + raise RuntimeError( + f"Drafter rank reported error code {ack[0]} " + f"for end_session({session_id})" + ) + + +@final +class _SessionHandle: + """Per-request :class:`DrafterTransport` view of a :class:`RemoteTransport`. + + Each in-flight target task gets its own handle via + :meth:`RemoteTransport.open_session`. The handle's wire ops carry + the handle's ``session_id`` so the drafter rank can route them to + the right per-session KV cache. + + Lifecycle: + + * :meth:`reset_and_prefill` allocates the session on the drafter + rank and seeds its KV cache with the prompt prefix. + * :meth:`forward` / :meth:`trim_cache` advance / rollback the + session's KV cache. + * :meth:`shutdown` ends the session (sends ``OP_END_SESSION`` so + the drafter rank frees the KV cache). Idempotent; safe to call + from a generator's ``finally`` block. + + All methods raise :class:`RuntimeError` after :meth:`shutdown` so + use-after-end mistakes surface immediately rather than corrupting + a freshly allocated session that happens to reuse the id. + """ + + def __init__(self, *, owner: "RemoteTransport", session_id: int) -> None: + self._owner = owner + self._session_id = session_id + self._closed = False + + @property + def num_draft_tokens(self) -> int: + return self._owner.num_draft_tokens + + @property + def session_id(self) -> int: + return self._session_id + + def forward(self, inputs: list[int], num_forwards: int) -> "DraftFuture": + if self._closed: + raise RuntimeError( + f"_SessionHandle({self._session_id}).forward called after shutdown" + ) + return self._owner._submit_forward(self._session_id, inputs, num_forwards) # pyright: ignore[reportPrivateUsage] + + def trim_cache(self, n_positions: int) -> None: + if self._closed: + raise RuntimeError( + f"_SessionHandle({self._session_id}).trim_cache called after shutdown" + ) + self._owner._submit_trim(self._session_id, n_positions) # pyright: ignore[reportPrivateUsage] + + def reset_and_prefill(self, prompt_tokens: list[int]) -> None: + if self._closed: + raise RuntimeError( + f"_SessionHandle({self._session_id}).reset_and_prefill called after shutdown" + ) + self._owner._submit_prefill(self._session_id, prompt_tokens) # pyright: ignore[reportPrivateUsage] + + def shutdown(self) -> None: + """End the session on the drafter rank. Idempotent.""" + if self._closed: + return + self._closed = True + self._owner._submit_end_session(self._session_id) # pyright: ignore[reportPrivateUsage] + + +def make_remote_transport( + *, + draft_model: "Model | None" = None, + draft_cache: "KVCacheType | None" = None, + num_draft_tokens: int, + sock: socket.socket | None = None, +) -> "RemoteTransport": + """Construct a :class:`RemoteTransport` for the calling target rank. + + Returns the wire-protocol owner; per-task callers should call + :meth:`RemoteTransport.open_session` to obtain a session-scoped + :class:`DrafterTransport` view that the spec loop consumes. The + factory does not implement ``DrafterTransport`` directly because + its lifecycle is bound to the runner (long-lived) while the spec + loop's transport is bound to a single request (short-lived). + + Args: + draft_model: Ignored (the model lives on the drafter rank). + Included in the signature for parity with the in-process + factory so callers don't branch on transport kind. + draft_cache: Ignored (lives on the drafter rank). + num_draft_tokens: ``K`` -- max drafts per round. + sock: Connected TCP socket from target rank 0 to the drafter + rank. The runner bootstrap accepts the drafter's incoming + connection and hands the resulting socket here. + + Raises: + ValueError: required kwargs missing. + """ + del draft_model, draft_cache # not relevant on target rank + if sock is None: + raise ValueError( + "make_remote_transport requires `sock`; the asymmetric " + "instance bootstrap accepts the drafter's incoming TCP " + "connection and passes the connected socket here" + ) + return RemoteTransport( + num_draft_tokens=num_draft_tokens, + sock=sock, + ) + + +# --------------------------------------------------------------------------- +# drafter_serve_loop (drafter side) +# --------------------------------------------------------------------------- + + +def drafter_serve_loop( + *, + draft_model: "Model", + make_draft_cache: Callable[[], "KVCacheType"], + num_draft_tokens: int, + sock: socket.socket, +) -> None: + """Run the drafter rank's command-loop until ``OP_SHUTDOWN``. + + Receives :data:`COMMAND_FRAME_SIZE`-element command frames over + ``sock``, dispatches on the op code, executes the drafter-side + work, and replies with the appropriate frame. + + Maintains a per-session KV cache (``sessions[session_id]``) + allocated lazily on the first ``OP_PREFILL`` for each session and + freed by ``OP_END_SESSION`` (or implicitly by ``OP_SHUTDOWN``). + Multiple sessions may be live concurrently; the wire stays serial + but the drafter rank multiplexes by ``session_id``. + + See module docstring for the wire protocol. + """ + drafts_buffer_size = num_draft_tokens + 1 + sessions: dict[int, "KVCacheType"] = {} + + while True: + flat = recv_uint32_frame(sock, COMMAND_FRAME_SIZE) + op, inputs, num_forwards, trim_amount, session_id = _decode_command_frame(flat) + + if op == OP_SHUTDOWN: + # Drop every session's cache before the serve loop returns + # so the drafter rank's process exits with no dangling + # KV-cache references holding GPU memory. + sessions.clear() + send_uint32_frame(sock, [ACK_OK]) + return + + if op == OP_END_SESSION: + # Idempotent: ending a non-existent session is also a + # successful ack. Forgetful targets (e.g. a runner that + # crashed without calling shutdown on its session) are + # cleaned up by the next ``OP_SHUTDOWN`` either way. + sessions.pop(session_id, None) + send_uint32_frame(sock, [ACK_OK]) + continue + + if op == OP_TRIM_CACHE: + session_cache = sessions.get(session_id) + if session_cache is None: + raise RuntimeError( + f"OP_TRIM_CACHE for unknown session {session_id}; " + f"OP_PREFILL must allocate the session first" + ) + if trim_amount > 0: + # ``mlx_trim_prompt_cache`` is typed against ``List[Cache]`` + # but exo's ``KVCacheType`` is structurally a list of + # mlx_lm caches; the runtime types match exactly. We + # erase to ``Any`` here to bypass list invariance. + from typing import Any + from typing import cast as _cast + + mlx_trim_prompt_cache(_cast(Any, session_cache), trim_amount) # type: ignore[reportArgumentType] + send_uint32_frame(sock, [ACK_OK]) + continue + + if op == OP_FORWARD: + session_cache = sessions.get(session_id) + if session_cache is None: + raise RuntimeError( + f"OP_FORWARD for unknown session {session_id}; " + f"OP_PREFILL must allocate the session first" + ) + outputs = _run_drafter_forwards_remote( + draft_model=draft_model, + draft_cache=session_cache, + inputs=inputs, + num_forwards=num_forwards, + ) + # Pad to fixed-shape buffer so the target's recv pre-allocation matches. + padded = list(outputs) + [0] * (drafts_buffer_size - len(outputs)) + send_uint32_frame(sock, padded) + continue + + if op == OP_PREFILL: + # ``num_forwards`` is overloaded here as the prompt token + # count (see _build_command_frame call site in + # _reset_and_prefill_blocking). + num_prompt_tokens = num_forwards + # Allocate (or replace) the session's KV cache. Replacement + # semantics let a target re-use a session_id after + # OP_END_SESSION + OP_PREFILL without leaking the old cache. + session_cache = make_draft_cache() + sessions[session_id] = session_cache + _reset_and_prefill_remote( + draft_model=draft_model, + draft_cache=session_cache, + num_prompt_tokens=num_prompt_tokens, + sock=sock, + ) + send_uint32_frame(sock, [ACK_OK]) + continue + + # Unknown op code: this is a wire-protocol violation, not a + # recoverable error. Raise so the serve loop dies and the + # caller's ``RemoteTransport`` surfaces the broken-pipe error. + raise RuntimeError(f"Unknown op code from target rank: {op}") + + +def _run_drafter_forwards_remote( + *, + draft_model: "Model", + draft_cache: "KVCacheType", + inputs: list[int], + num_forwards: int, +) -> list[int]: + """Same forward semantics as ``InProcessTransport._run_drafter_forwards``. + + Kept as a free function to avoid importing the in-process transport + on the drafter rank (which only loads the drafter model, not any + target-side code). + """ + if num_forwards < 1: + raise ValueError(f"num_forwards must be >= 1, got {num_forwards}") + if not 1 <= len(inputs) <= 2: + raise ValueError(f"inputs must have length 1 or 2, got {len(inputs)}") + ys: list[mx.array] = [] + y = mx.array(inputs, dtype=mx.uint32) + for _ in range(num_forwards): + logits = draft_model(y[None], cache=draft_cache) + sampled = mx.argmax(logits[:, -1, :], axis=-1).astype(mx.uint32) + mx.async_eval(sampled) + ys.append(sampled) + y = sampled + mx.eval(ys + [c.state for c in draft_cache]) # type: ignore[reportArgumentType] + return [int(t.item()) for t in ys] + + +_DRAFTER_PREFILL_STEP_SIZE: Final[int] = 4096 +"""Chunk size for drafter-side prefill forwards. + +Mirrors :func:`exo.worker.engines.mlx.generator.generate._spec_drafter_prefill`'s +``step`` default. Drafter weights are small (typically <2 GB) so the +4096-token chunks comfortably fit in the drafter rank's command queue +without OOM, even at long prompts.""" + + +def _reset_and_prefill_remote( + *, + draft_model: "Model", + draft_cache: "KVCacheType", + num_prompt_tokens: int, + sock: socket.socket, +) -> None: + """Reset drafter cache and prefill against an incoming prompt. + + Pulled out as a free function (matches + :func:`_run_drafter_forwards_remote`) so the drafter rank doesn't + depend on any target-side code. The target rank already sent the + ``OP_PREFILL`` command frame; this function handles the cache + reset, recvs the prompt array (if any) over ``sock``, and runs the + prefill forwards. The serve loop sends the ack after this returns. + """ + # Trim cache to offset 0 so the new prompt starts cleanly. KVCache's + # offset is the only state we need to reset; SSM caches and other + # exotic types are not in scope for the drafter (drafter models are + # standard transformers by convention). If the offset is 0 the trim + # is a no-op. + current_offset = 0 + if draft_cache: + # Every cache entry shares the same offset for transformer + # drafters; use entry 0 as the source of truth. + cache_zero = draft_cache[0] + offset_attr = getattr(cache_zero, "offset", None) + if isinstance(offset_attr, int): + current_offset = offset_attr + if current_offset > 0: + from typing import cast as _cast + + mlx_trim_prompt_cache(_cast(list[object], draft_cache), current_offset) # type: ignore[reportArgumentType] + + if num_prompt_tokens == 0: + return + + # Pull the prompt array from the target rank. The header preceding + # the payload is sent by ``send_variable_uint32_payload`` and must + # match ``num_prompt_tokens`` -- mismatches indicate a wire-protocol + # bug rather than a recoverable error. + header = recv_uint32_frame(sock, 1) + received_count = header[0] + if received_count != num_prompt_tokens: + raise RuntimeError( + f"OP_PREFILL prompt header mismatch: command announced " + f"{num_prompt_tokens} tokens but payload header says " + f"{received_count}" + ) + prompt_tokens = recv_uint32_frame(sock, num_prompt_tokens) + tokens = mx.array(prompt_tokens, dtype=mx.uint32) + mx.eval(tokens) + + # Mirror :func:`_spec_drafter_prefill`: feed tokens through the + # drafter model in chunks, advancing its KV cache. + step = _DRAFTER_PREFILL_STEP_SIZE + cursor = 0 + while cursor < num_prompt_tokens: + chunk_end = min(cursor + step, num_prompt_tokens) + chunk = tokens[cursor:chunk_end] + draft_model(chunk[None], cache=draft_cache) + mx.eval([c.state for c in draft_cache]) # type: ignore[reportArgumentType] + cursor = chunk_end + + +__all__ = [ + "ACK_FRAME_SIZE", + "ACK_OK", + "COMMAND_FRAME_SIZE", + "OP_END_SESSION", + "OP_FORWARD", + "OP_PREFILL", + "OP_SHUTDOWN", + "OP_TRIM_CACHE", + "SESSION_ID_NONE", + "RemoteTransport", + "drafter_serve_loop", + "make_remote_transport", +] + + +# Suppress the unused-import warnings for the future-only Future type: +# ThreadPoolExecutor.submit returns ``Future`` which is structurally +# compatible with :data:`DraftFuture`, but we annotate the return type +# inside the class body and the import is otherwise unused. +_ = Future diff --git a/src/exo/worker/engines/mlx/generator/target_peer_socket.py b/src/exo/worker/engines/mlx/generator/target_peer_socket.py new file mode 100644 index 0000000000..e4e2fd3e3a --- /dev/null +++ b/src/exo/worker/engines/mlx/generator/target_peer_socket.py @@ -0,0 +1,191 @@ +"""Direct TCP socket transport for target-rank-to-peer broadcasts. + +Mirrors :mod:`drafter_socket` but for inter-target-rank communication +during the speculative-decode hot path. The hot path needs to broadcast +small int32 buffers from target rank 0 to every other target rank +(drafts on the way in, sampled tokens on the way out). The original +implementation rode :func:`mx.distributed.all_sum` and later +:func:`mx.distributed.send` / :func:`recv` over the same target group +that runs the model's tensor-parallel ``all_sum`` collectives. + +That coupling is the bug: the JACCL backend interleaves the int32 +broadcast with the float32 TP all-reduce on the same wire, occasionally +handing back logits memory in place of the requested int32 buffer. +Symptom is a deterministic out-of-vocabulary token id (the bit pattern +of a float32 logit reinterpreted as int32) emerging on the receiving +peer rank a few hundred milliseconds into generation. + +Fix: lift the int32 broadcasts off ``mx.distributed`` entirely. Target +rank 0 binds a TCP listener at instance bootstrap; every other target +rank dials in once and reuses the connection for the lifetime of the +runner. The wire is fundamentally separate from JACCL, so the model's +TP collectives and the spec-decode broadcasts can never collide. + +Wire frames are fixed-length little-endian int32 sequences, matching +:mod:`drafter_socket` for consistency. Unlike the drafter wire, every +spec-decode broadcast has a known shape (``k + 1`` ints), so no +length-prefixed payloads are needed. + +Threading model: the spec-decode loop is single-threaded per runner; +target rank 0's broadcast issues one ``sendall`` per peer, peers issue +one ``recv_into`` per round. No multiplexing, no out-of-order frames. +""" + +from __future__ import annotations + +import socket +import struct +import time +from typing import Final + +_INT32_MIN: Final[int] = -(1 << 31) +_INT32_MAX: Final[int] = (1 << 31) - 1 + + +def send_int32_frame(sock: socket.socket, values: list[int]) -> None: + """Send a fixed-length signed int32 frame over ``sock``. + + The spec-decode loop only ever broadcasts non-negative token ids + and length prefixes today, but signed int32 covers both that case + and any future sentinel (e.g. -1 for "end of stream") without + revisiting the wire format. Callers must guarantee the peer + expects exactly ``len(values)`` ints; no length header is sent. + """ + for index, value in enumerate(values): + if value < _INT32_MIN or value > _INT32_MAX: + raise ValueError( + f"target-peer frame value at index {index}={value} is out of " + f"int32 range [{_INT32_MIN}, {_INT32_MAX}]" + ) + payload = struct.pack(f"<{len(values)}i", *values) + sock.sendall(payload) + + +def recv_int32_frame(sock: socket.socket, count: int) -> list[int]: + """Receive ``count`` signed int32 ints over ``sock`` (no length prefix). + + Blocks until ``count * 4`` bytes have arrived, raising + :class:`ConnectionError` if the peer closes mid-frame so the + spec-decode loop surfaces a typed wire failure rather than a + silent truncated buffer. + """ + if count <= 0: + raise ValueError(f"count must be > 0, got {count}") + needed = count * 4 + buf = bytearray(needed) + view = memoryview(buf) + received = 0 + while received < needed: + chunk = sock.recv_into(view[received:], needed - received) + if chunk == 0: + raise ConnectionError( + f"target-peer wire closed mid-frame " + f"(received {received}/{needed} bytes)" + ) + received += chunk + unpacked = struct.unpack(f"<{count}i", bytes(buf)) + return list(unpacked) + + +def bind_target_peer_listener( + host: str, port: int, *, backlog: int +) -> socket.socket: + """Open and listen on ``(host, port)`` for peer target ranks to dial in. + + ``backlog`` is the expected number of dialing peers + (``target_world_size - 1``). ``SO_REUSEADDR`` is set so a stale + TIME_WAIT socket from a previous instance teardown does not block + rebind. Caller owns ``accept()`` (see :func:`accept_target_peers`) + and ``close()``. + """ + listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + listener.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + listener.bind((host, port)) + listener.listen(backlog) + return listener + + +def accept_target_peers( + listener: socket.socket, + *, + expected_peers: int, + timeout_seconds: float, +) -> list[socket.socket]: + """Accept exactly ``expected_peers`` incoming target-peer connections. + + Order of acceptance is not significant for the wire protocol -- + rank 0 issues one ``sendall`` per accepted socket per broadcast, + independent of which peer rank ended up where in the list. Callers + that need rank-indexed access (none in the current spec-decode + loop) must perform their own handshake on top of the returned + sockets. + + ``TCP_NODELAY`` is set on every accepted socket. Each broadcast is + a 24-to-200-byte int32 frame followed by a long pause (the + verifier's TP forward pass), so Nagle would add the full 40ms + delayed-ack timeout to every round. Disabling Nagle drops that to + sub-millisecond on Thunderbolt RDMA. + """ + if expected_peers <= 0: + raise ValueError( + f"accept_target_peers needs expected_peers >= 1, got {expected_peers}" + ) + listener.settimeout(timeout_seconds) + accepted: list[socket.socket] = [] + try: + for _ in range(expected_peers): + accept_result: tuple[socket.socket, object] = listener.accept() + conn: socket.socket = accept_result[0] + conn.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) + accepted.append(conn) + finally: + listener.settimeout(None) + return accepted + + +def dial_target_zero( + host: str, + port: int, + *, + total_timeout_seconds: float, + initial_backoff_seconds: float = 0.5, +) -> socket.socket: + """Dial target rank 0 from a peer target rank, retrying until success. + + Target rank 0 binds inside :func:`initialize_mlx` after + ``mlx_distributed_init`` returns; peers dial during the same + bootstrap step, so the listener may not yet be up when the first + dial attempt fires. Exponential backoff (capped at 5s) covers the + bind / accept race without spinning. Failure after + ``total_timeout_seconds`` raises :class:`ConnectionError`, which + the runner surfaces as a connect-task failure so the cluster does + not sit silently wedged. + """ + deadline = time.monotonic() + total_timeout_seconds + backoff = initial_backoff_seconds + last_error: BaseException | None = None + while time.monotonic() < deadline: + try: + conn = socket.create_connection( + (host, port), timeout=min(10.0, total_timeout_seconds) + ) + conn.settimeout(None) + conn.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) + return conn + except (ConnectionRefusedError, OSError, TimeoutError) as exc: + last_error = exc + time.sleep(backoff) + backoff = min(backoff * 2.0, 5.0) + raise ConnectionError( + f"target peer could not reach target rank 0 at {host}:{port} " + f"within {total_timeout_seconds:.0f}s (last error: {last_error!r})" + ) + + +__all__ = [ + "accept_target_peers", + "bind_target_peer_listener", + "dial_target_zero", + "recv_int32_frame", + "send_int32_frame", +] diff --git a/src/exo/worker/engines/mlx/tests/test_batched_prefill.py b/src/exo/worker/engines/mlx/tests/test_batched_prefill.py new file mode 100644 index 0000000000..e404191551 --- /dev/null +++ b/src/exo/worker/engines/mlx/tests/test_batched_prefill.py @@ -0,0 +1,270 @@ +# pyright: reportAny=false, reportUnknownVariableType=false +# pyright: reportUnknownMemberType=false, reportUnknownArgumentType=false +# pyright: reportUnknownLambdaType=false, reportPrivateUsage=false +# pyright: reportInvalidCast=false, reportArgumentType=false +"""Correctness tests for :func:`batched_prefill`. + +Validates that running K prefills in a single batched forward (the seam +:class:`SequentialGenerator` uses to absorb the residual 11s outliers +on the long-prompt mixed-traffic bench) produces bit-exact decode +state vs running K independent :func:`prefill` calls. We compare +post-prefill logits from the next decode tick rather than raw cache +state because mlx's ``BatchKVCache`` stores keys/values in a different +shape from ``KVCache`` after :meth:`extract` and exact-cache equality +would miss the question we actually care about — does the next forward +sample the same token? + +Uses tiny llama-style random weights (no model download) so the tests +stay fast enough to run on every CI invocation. +""" + +from pathlib import Path +from typing import cast + +import mlx.core as mx +import mlx.nn as nn +import mlx.utils +import pytest +from mlx_lm.sample_utils import make_sampler +from mlx_lm.tokenizer_utils import TokenizerWrapper +from transformers import AutoTokenizer + +from exo.worker.engines.mlx.cache import encode_prompt, make_kv_cache +from exo.worker.engines.mlx.generator.generate import ( + BatchedPrefillUnsupportedError, + batched_prefill, + prefill, +) +from exo.worker.engines.mlx.types import Model + +NUM_STEPS = 16 + + +def _init_random(model: nn.Module) -> None: + params = model.parameters() + new_params = mlx.utils.tree_map( + lambda p: mx.random.normal(shape=p.shape, dtype=p.dtype) + if isinstance(p, mx.array) + else p, + params, + ) + model.update(new_params) + mx.eval(model.parameters()) + + +def _make_tiny_llama() -> tuple[Model, TokenizerWrapper]: + from huggingface_hub import snapshot_download + from mlx_lm.models.llama import Model as LlamaModel + from mlx_lm.models.llama import ModelArgs + + mx.random.seed(42) + args = ModelArgs( + model_type="llama", + hidden_size=256, + num_hidden_layers=4, + intermediate_size=512, + num_attention_heads=4, + num_key_value_heads=2, + rms_norm_eps=1e-6, + vocab_size=248320, + rope_theta=10000.0, + tie_word_embeddings=True, + ) + model = LlamaModel(args) + _init_random(model) + + model_path = Path( + snapshot_download( + "mlx-community/Qwen3.5-35B-A3B-4bit", + allow_patterns=["tokenizer*", "*.jinja"], + ) + ) + hf_tokenizer = AutoTokenizer.from_pretrained(model_path) + tokenizer = TokenizerWrapper(hf_tokenizer) + return cast(Model, model), tokenizer + + +def _decode_one_token(model: Model, cache: object, last_token: int) -> mx.array: + """Run one forward with the seed token; return the (vocab,) logits. + + Mirrors the entry state ``mlx_generate`` hands to the spec loop: + cache is at offset ``len(prompt) - 1`` and the next forward feeds + the seed token (``prompt[-1]``). + """ + out = model(mx.array([[last_token]]), cache=cast(list[object], cache)) + mx.eval(out) + return out[0, -1] + + +@pytest.mark.slow +def test_batched_prefill_matches_sequential_for_two_prompts() -> None: + """B=2 batched_prefill must produce the same decode logits as 2x B=1 prefill. + + Compares the ``argmax`` token from the first decode forward after + prefill — that's the only invariant the spec loop reads from the + post-prefill cache, so bit-exact cache layout doesn't matter as + long as the next forward agrees. + """ + model, tokenizer = _make_tiny_llama() + sampler = make_sampler(temp=0.0) + + tokens_a = encode_prompt(tokenizer, "Write a short essay about AI.") + tokens_b = encode_prompt(tokenizer, "Explain evolution briefly.") + + # Sequential reference (per-slot prefill on prompt[:-1]; the + # exo.prefill helper advances cache to len(prompt) - 2 via its + # +1 / -2 dance). + cache_a_seq = make_kv_cache(model) + prefill(model, tokenizer, sampler, tokens_a[:-1], cache_a_seq, None, None, None) + cache_b_seq = make_kv_cache(model) + prefill(model, tokenizer, sampler, tokens_b[:-1], cache_b_seq, None, None, None) + + # Sequential decode: feed the prefill-tail's penultimate then last + # token to advance cache from offset N-2 to N-1, then sample the + # first generated logits. + last_a = int(tokens_a[-1].item()) + penult_a = int(tokens_a[-2].item()) + model(mx.array([[penult_a]]), cache=cast(list[object], cache_a_seq)) + seq_logits_a = _decode_one_token(model, cache_a_seq, last_a) + + last_b = int(tokens_b[-1].item()) + penult_b = int(tokens_b[-2].item()) + model(mx.array([[penult_b]]), cache=cast(list[object], cache_b_seq)) + seq_logits_b = _decode_one_token(model, cache_b_seq, last_b) + + # Batched: batched_prefill leaves cache at offset N-1 directly (no + # +1/-2 dance), so the equivalent decode is one forward on the + # last token only. + cache_a_batch = make_kv_cache(model) + cache_b_batch = make_kv_cache(model) + aggregate_tps, total_tokens = batched_prefill( + model=model, + prompt_tokens_list=[tokens_a, tokens_b], + caches_list=[cache_a_batch, cache_b_batch], + ) + assert aggregate_tps > 0.0 + assert total_tokens == int(tokens_a.size) - 1 + int(tokens_b.size) - 1 + + batch_logits_a = _decode_one_token(model, cache_a_batch, last_a) + batch_logits_b = _decode_one_token(model, cache_b_batch, last_b) + + # Decoded token must agree; small numerical drift in the logits is + # acceptable (different reduction order in the batched matmul) but + # the argmax must be identical. + assert int(mx.argmax(seq_logits_a).item()) == int(mx.argmax(batch_logits_a).item()) + assert int(mx.argmax(seq_logits_b).item()) == int(mx.argmax(batch_logits_b).item()) + + +@pytest.mark.slow +def test_batched_prefill_continues_decoding_correctly() -> None: + """After batched_prefill the per-slot decode must stay aligned for many steps. + + A single matching first-token argmax can be coincidence; we extend + the comparison to ``NUM_STEPS`` decoded tokens to catch cache-state + bugs that only show up after multiple forwards (e.g. an off-by-one + in BatchKVCache.extract that would skew RoPE positions). + """ + model, tokenizer = _make_tiny_llama() + sampler = make_sampler(temp=0.0) + + tokens_a = encode_prompt(tokenizer, "Hello there general kenobi.") + tokens_b = encode_prompt(tokenizer, "The quick brown fox jumps.") + + # Sequential reference run produces a token sequence per slot. + seq_tokens: list[list[int]] = [] + for tokens in (tokens_a, tokens_b): + cache_seq = make_kv_cache(model) + prefill(model, tokenizer, sampler, tokens[:-1], cache_seq, None, None, None) + last = int(tokens[-1].item()) + penult = int(tokens[-2].item()) + model(mx.array([[penult]]), cache=cast(list[object], cache_seq)) + next_tok = last + produced: list[int] = [] + for _ in range(NUM_STEPS): + logits = _decode_one_token(model, cache_seq, next_tok) + next_tok = int(mx.argmax(logits).item()) + produced.append(next_tok) + seq_tokens.append(produced) + + # Batched run. + cache_a = make_kv_cache(model) + cache_b = make_kv_cache(model) + batched_prefill( + model=model, + prompt_tokens_list=[tokens_a, tokens_b], + caches_list=[cache_a, cache_b], + ) + batch_tokens: list[list[int]] = [] + for tokens, cache in ((tokens_a, cache_a), (tokens_b, cache_b)): + last = int(tokens[-1].item()) + next_tok = last + produced = [] + for _ in range(NUM_STEPS): + logits = _decode_one_token(model, cache, next_tok) + next_tok = int(mx.argmax(logits).item()) + produced.append(next_tok) + batch_tokens.append(produced) + + # Mismatches downstream of step 0 still indicate a real cache + # bug; we tolerate up to one drift in NUM_STEPS as numerical + # slack but the first 8 tokens must agree. + assert seq_tokens[0][:8] == batch_tokens[0][:8] + assert seq_tokens[1][:8] == batch_tokens[1][:8] + + +def test_batched_prefill_empty_inputs_returns_zero() -> None: + """No-op on empty input: the caller may filter to zero eligible slots.""" + tps, total = batched_prefill( + model=cast(Model, object()), + prompt_tokens_list=[], + caches_list=[], + ) + assert tps == 0.0 + assert total == 0 + + +def test_batched_prefill_rejects_mismatched_lengths() -> None: + """``prompt_tokens_list`` and ``caches_list`` must agree on K.""" + with pytest.raises(ValueError, match="must have the same length"): + batched_prefill( + model=cast(Model, object()), + prompt_tokens_list=[mx.array([1, 2, 3]), mx.array([4, 5, 6])], + caches_list=[[]], + ) + + +def test_batched_prefill_rejects_short_prompts() -> None: + """Prompts < 2 tokens leave no decode-seed token after slicing.""" + with pytest.raises(ValueError, match="length >= 2"): + batched_prefill( + model=cast(Model, object()), + prompt_tokens_list=[mx.array([7])], + caches_list=[[]], + ) + + +def test_batched_prefill_unsupported_cache_raises_typed_error() -> None: + """Cache layers without ``merge`` must surface :class:`BatchedPrefillUnsupportedError`. + + The contract: callers (``SequentialGenerator._admit_queued_tasks``) + catch this typed error to fall back to per-slot prefill instead of + crashing the runner. + """ + + class _UnsupportedLayer: + # No ``merge`` classmethod => mlx_lm._merge_caches raises + # ``ValueError(f"{type} does not yet support batching with history")``. + pass + + cache_a: list[object] = [_UnsupportedLayer()] + cache_b: list[object] = [_UnsupportedLayer()] + + with pytest.raises(BatchedPrefillUnsupportedError): + batched_prefill( + model=cast(Model, object()), + prompt_tokens_list=[ + mx.array([1, 2, 3]), + mx.array([4, 5, 6]), + ], + caches_list=cast(list[object], [cache_a, cache_b]), + ) diff --git a/src/exo/worker/engines/mlx/utils_mlx.py b/src/exo/worker/engines/mlx/utils_mlx.py index a863ea4bb3..165abe6fb8 100644 --- a/src/exo/worker/engines/mlx/utils_mlx.py +++ b/src/exo/worker/engines/mlx/utils_mlx.py @@ -5,8 +5,9 @@ import tempfile import time from collections.abc import Generator +from dataclasses import dataclass, field from pathlib import Path -from typing import TYPE_CHECKING, Any, cast +from typing import TYPE_CHECKING, Any, Final, cast, final if TYPE_CHECKING: from exo.worker.engines.mlx.vision import VisionProcessor @@ -44,7 +45,7 @@ from exo.download.download_utils import build_model_path, resolve_existing_model from exo.shared.types.common import Host from exo.shared.types.memory import Memory -from exo.shared.types.tasks import TaskId, TextGeneration +from exo.shared.types.tasks import TextGeneration from exo.shared.types.text_generation import ChatTemplateValue, TextGenerationTaskParams from exo.shared.types.worker.instances import ( BoundInstance, @@ -104,13 +105,31 @@ def from_hosts(cls, hosts: list[Host]) -> "HostList": return cls(root=[str(host) for host in hosts]) +def _bound_rank(bound_instance: BoundInstance) -> int: + """Rank of this runner inside the parent ``mx.distributed`` group. + + Target ranks read this from their bound shard metadata; the drafter + rank reads it from :class:`DrafterPlacement` since the drafter has + no target shard. + """ + if bound_instance.is_drafter_rank: + placement = bound_instance.instance.drafter_placement + assert placement is not None # type narrowed by is_drafter_rank + return placement.drafter_rank + return bound_instance.bound_shard.device_rank + + def mlx_distributed_init( bound_instance: BoundInstance, ) -> mx.distributed.Group: + """Initialize MLX distributed for this rank's parent group. + + The parent group spans every rank declared by the instance: target + ranks plus, for asymmetric placement, the trailing drafter rank. + Target ranks split off into a subgroup at runtime via + :func:`initialize_mlx`; this helper just brings up the parent. """ - Initialize MLX distributed. - """ - rank = bound_instance.bound_shard.device_rank + rank = _bound_rank(bound_instance) logger.info(f"Starting initialization for rank {rank}") with tempfile.TemporaryDirectory() as tmpdir: @@ -179,33 +198,411 @@ def mlx_distributed_init( return group -def initialize_mlx( - bound_instance: BoundInstance, -) -> mx.distributed.Group: +@final +@dataclass(frozen=True) +class MlxGroupSplit: + """Target-side view of an instance's distributed wiring. + + Pre-v3 the asymmetric drafter rank was a member of the parent + ``mx.distributed`` group, and this struct carried the parent + a + target-only subgroup. Under the v3+ wire the drafter is NOT in any + ``mx.distributed.Group`` -- target ranks form their own group of + size ``target_world_size`` and the drafter dials a TCP socket. The + struct now carries: + + * ``parent`` / ``target_subgroup`` -- aliases for the same target + group (``parent is target_subgroup`` always under v3). Both + fields are retained so existing callers (builder.py, image + builder, generate.py) keep working without rev. ``None`` when + the target world size is 1 (the well-known "single rank, no + collectives needed" signal that + :func:`load_mlx_items`, :func:`mx_barrier`, :func:`mx_any` + already short-circuit on). + * ``drafter_socket`` -- the connected TCP socket between target + rank 0 and the drafter rank. Set ONLY on target rank 0 of an + asymmetric placement; ``None`` for any other rank. + * ``drafter_rank_in_parent`` -- advisory placement index of the + drafter (``placement.drafter_rank``). Carried for telemetry + and the few legacy call sites that branch on "is asymmetric"; + ``None`` for symmetric placement. + * ``target_peer_fanout`` -- inter-target-rank TCP fanout for + spec-decode int broadcasts (see :class:`TargetPeerFanout`). + ``None`` for single-target instances or symmetric placements + without a drafter (no spec-decode hot path; legacy + ``mx_broadcast_int_list`` is sufficient). + """ + + parent: mx.distributed.Group | None + target_subgroup: mx.distributed.Group | None + drafter_rank_in_parent: int | None + drafter_socket: object | None = None + """Connected ``socket.socket`` from target rank 0 to the drafter. + + Typed as ``object`` to keep the dataclass importable from modules + that don't import ``socket`` directly. Runtime callers + (:mod:`builder`) cast back to ``socket.socket`` before passing to + :func:`make_remote_transport`.""" + + target_peer_fanout: "TargetPeerFanout | None" = None + """Inter-target-rank TCP fanout for spec-decode int broadcasts. + + Allocated alongside the drafter socket on multi-target asymmetric + placements. ``None`` for single-target or symmetric instances. + Built once at bootstrap; the spec-decode loop reuses it for every + round.""" + + @property + def is_asymmetric(self) -> bool: + return self.drafter_rank_in_parent is not None + + +@final +@dataclass(frozen=True) +class TargetPeerFanout: + """Direct TCP int-broadcast wire between target rank 0 and its peers. + + Replaces :func:`mx.distributed.send` / :func:`recv` on the + spec-decode hot path. JACCL on Apple Silicon conflates int32 + broadcasts on the target group with the model's float32 TP + ``all_sum`` collectives; the former occasionally returns the + latter's logit memory reinterpreted as int32, surfacing as + out-of-vocab token ids (~``10^9``) deep in the SPM detokenizer. + + The model's TP ``all_sum`` collectives stay on JACCL/RDMA -- they + carry multi-MB tensor reductions where vendor RDMA wins + decisively. Only the tiny (~24-byte) int32 broadcasts move to TCP, + where Thunderbolt with ``TCP_NODELAY`` adds <100µs per round + (negligible against a ~30ms verifier forward). + + Topology: + * On target rank 0: ``peer_sockets`` holds one connection per + non-zero peer rank, indexed by peer rank. + * On a peer target rank (rank > 0): ``rank_zero_socket`` holds + the single connection back to rank 0. + + Both shapes are produced by :func:`_setup_target_peer_fanout` at + instance bootstrap and are immutable for the runner's lifetime. + Reconnect-on-failure is intentionally NOT supported: a transport + failure on this wire is treated as a hard runner failure (same as + a TP all-reduce failure) and the supervisor rebuilds the instance. + """ + + rank: int + """Caller's target rank inside the parent group; matches + ``MlxGroupSplit.parent.rank()`` when ``parent`` is set.""" + + peer_sockets: dict[int, object] = field(default_factory=dict) + """Rank 0 only: ``{peer_rank: socket.socket}``. Empty on rank > 0.""" + + rank_zero_socket: object | None = None + """Rank > 0 only: connected socket back to rank 0. ``None`` on rank 0.""" + + expected_world_size: int = 1 + """Target world size (every rank in the fanout sees the same value). + + Stored explicitly so the broadcast helpers can sanity-check that + rank 0's ``peer_sockets`` cover all peers without re-deriving the + world size from a possibly-discarded group handle.""" + + +def initialize_mlx(bound_instance: BoundInstance) -> MlxGroupSplit: + """Bring up the target ``mx.distributed`` group + (rank 0) drafter socket. + + Target ranks: initialise an ``mx.distributed.Group`` of size + ``parent_group_size`` (which under v3+ equals the number of target + shards -- the drafter is NOT a member of this group). Single-target + instances (``parent_group_size == 1``) short-circuit and return a + split with ``parent / target_subgroup = None``. + + Target rank 0 of an asymmetric placement additionally binds a TCP + listener on ``DrafterPlacement.drafter_socket_port`` and accepts + the drafter's incoming connection. The connected socket flows + through :class:`MlxGroupSplit.drafter_socket` to the builder, which + hands it to :func:`make_remote_transport`. + + The drafter rank does NOT call this function; its bootstrap + (:class:`DrafterRunner._handle_connect`) dials the socket directly + without touching ``mx.distributed`` at all. + """ + assert not bound_instance.is_drafter_rank, ( + "initialize_mlx should not be called on a drafter rank under " + "the v3+ asymmetric wire; DrafterRunner._handle_connect dials " + "the drafter socket directly without joining mx.distributed." + ) # should we unseed it? # TODO: pass in seed from params mx.random.seed(42) - assert len(bound_instance.instance.shard_assignments.node_to_runner) > 1, ( - "Tried to initialize mlx for a single node instance" + target_world_size = bound_instance.instance.parent_group_size + placement = bound_instance.instance.drafter_placement + + # Single-target instance: no mx.distributed group needed (other + # ranks short-circuit on the ``group is None`` signal). Drafter + # wire still exists for asymmetric placement. + parent: mx.distributed.Group | None = ( + None if target_world_size <= 1 else mlx_distributed_init(bound_instance) + ) + + drafter_rank_in_parent = placement.drafter_rank if placement is not None else None + + drafter_socket = _maybe_accept_drafter_socket( + bound_instance=bound_instance, + target_world_size=target_world_size, + placement=placement, + ) + + target_peer_fanout = _maybe_setup_target_peer_fanout( + bound_instance=bound_instance, + target_world_size=target_world_size, + placement=placement, + ) + + return MlxGroupSplit( + parent=parent, + target_subgroup=parent, + drafter_rank_in_parent=drafter_rank_in_parent, + drafter_socket=drafter_socket, + target_peer_fanout=target_peer_fanout, + ) + + +def _maybe_accept_drafter_socket( + *, + bound_instance: BoundInstance, + target_world_size: int, + placement: object, +) -> object | None: + """Bind + accept the drafter dial on target rank 0; otherwise return ``None``. + + Only target rank 0 of an asymmetric placement owns the drafter + wire. Other target ranks (rank >= 1) and symmetric placements + return ``None``. The caller embeds the result in + :class:`MlxGroupSplit.drafter_socket`. + + The accept call is sequential after :func:`mlx_distributed_init` + in the parent function. The drafter's :func:`dial_target` retries + with backoff for up to two minutes, which comfortably covers the + target group's bootstrap latency. If accept times out (drafter + unreachable / crashed), this raises :class:`socket.timeout`; the + runner surface bubbles it up as a connect-task failure so the + cluster doesn't sit silently wedged. + """ + from exo.shared.types.worker.instances import DrafterPlacement + + if placement is None: + return None + if not isinstance(placement, DrafterPlacement): + raise TypeError( + f"drafter_placement must be DrafterPlacement, got {type(placement)!r}" + ) + # Target rank 0 binds; other target ranks no-op. Symmetric placements + # land in the ``placement is None`` branch above. + if bound_instance.parent_rank != 0: + return None + del target_world_size # not needed once we know we're rank 0 + # Imported lazily to avoid pulling the socket transport into module + # import unless this code path is exercised. + from exo.worker.engines.mlx.generator.drafter_socket import ( + accept_drafter, + bind_target_listener, + ) + + # Bind to all interfaces so the drafter can dial whichever address + # ``DrafterPlacement.drafter_socket_host`` resolves to (LAN, + # Thunderbolt-bridge, Tailscale, etc.). The placement-time IP only + # serves as the address the drafter dials; target rank 0 doesn't + # need to advertise a specific bind address. + listener = bind_target_listener("0.0.0.0", placement.drafter_socket_port) + try: + logger.info( + f"target rank 0 listening for drafter on " + f"0.0.0.0:{placement.drafter_socket_port} " + f"(advertised {placement.drafter_socket_host})" + ) + conn = accept_drafter(listener, timeout_seconds=180.0) + logger.info("target rank 0 accepted drafter connection") + return conn + finally: + # Listener is single-shot (drafter dials once and stays + # connected for the instance lifetime); close it as soon as + # accept returns to free the port. + listener.close() + + +def _maybe_setup_target_peer_fanout( + *, + bound_instance: BoundInstance, + target_world_size: int, + placement: object, +) -> TargetPeerFanout | None: + """Bring up the inter-target-rank TCP int-broadcast wire. + + Multi-target asymmetric placements need a TCP fanout between + target rank 0 and its peers because the JACCL backend conflates + the model's float32 TP ``all_sum`` with int32 broadcasts on the + same group (see :class:`TargetPeerFanout` docstring). Single-rank + targets and symmetric placements (no drafter) have no spec-decode + hot path, so they don't need this wire and the function returns + ``None``. + + Bootstrap protocol: + + * Target rank 0 binds 0.0.0.0:``placement.target_peer_socket_port`` + and accepts ``target_world_size - 1`` incoming connections. + * Each non-zero target rank dials + ``placement.target_peer_hosts_by_rank[my_rank]:target_peer_socket_port`` + with bounded retry (the listener may not be up yet on the + first attempt because ``accept`` and ``connect`` race during + bootstrap). + + The drafter rank is NOT in this fanout: it has its own dedicated + wire to target rank 0 (see :func:`_maybe_accept_drafter_socket`). + Skipping the fanout for the drafter rank is the right call + because the drafter never broadcasts int frames to target peers + -- it only exchanges drafts/verify with rank 0. + + Failure mode: a dial timeout / accept timeout raises + :class:`ConnectionError` or :class:`socket.timeout`, which + bubbles up to the runner and surfaces as a connect-task failure. + The cluster does not silently wedge. + """ + from exo.shared.types.worker.instances import DrafterPlacement + + if placement is None or not isinstance(placement, DrafterPlacement): + return None + if target_world_size <= 1: + return None + if bound_instance.is_drafter_rank: + return None + + rank = bound_instance.parent_rank + expected_world_size = target_world_size + + # Imported lazily to avoid pulling the socket module into module + # import for runners that never reach this code path. + from exo.worker.engines.mlx.generator.target_peer_socket import ( + accept_target_peers, + bind_target_peer_listener, + dial_target_zero, + ) + + if rank == 0: + listener = bind_target_peer_listener( + "0.0.0.0", + placement.target_peer_socket_port, + backlog=expected_world_size - 1, + ) + try: + logger.info( + f"target rank 0 listening for {expected_world_size - 1} " + f"target peers on 0.0.0.0:{placement.target_peer_socket_port}" + ) + conns = accept_target_peers( + listener, + expected_peers=expected_world_size - 1, + timeout_seconds=180.0, + ) + logger.info( + f"target rank 0 accepted {len(conns)} target-peer " + "connection(s)" + ) + finally: + listener.close() + # The peer rank that wrote each connection is implicit (we + # accept in connection order, but peers can dial in any + # order). Spec-decode broadcasts don't need rank-indexed + # peers -- rank 0 sends the same payload to every peer per + # round -- so we store sockets in arbitrary stable order + # keyed by accept order. The spec-decode broadcast helper + # iterates ``peer_sockets.values()`` and ignores keys. + peer_sockets: dict[int, object] = {idx: c for idx, c in enumerate(conns)} + return TargetPeerFanout( + rank=0, + peer_sockets=peer_sockets, + rank_zero_socket=None, + expected_world_size=expected_world_size, + ) + + rank_zero_host = placement.target_peer_hosts_by_rank.get(str(rank)) + if rank_zero_host is None: + raise RuntimeError( + f"target peer rank {rank} (key={str(rank)!r}) has no entry " + f"in DrafterPlacement.target_peer_hosts_by_rank " + f"({placement.target_peer_hosts_by_rank}); placement is " + "malformed" + ) + logger.info( + f"target peer rank {rank} dialing target rank 0 at " + f"{rank_zero_host}:{placement.target_peer_socket_port}" + ) + conn = dial_target_zero( + rank_zero_host, + placement.target_peer_socket_port, + total_timeout_seconds=180.0, + ) + logger.info(f"target peer rank {rank} connected to target rank 0") + return TargetPeerFanout( + rank=rank, + peer_sockets={}, + rank_zero_socket=conn, + expected_world_size=expected_world_size, ) - return mlx_distributed_init(bound_instance) EXO_DISABLE_DRAFTER_ENV = "EXO_DISABLE_DRAFTER" +EXO_DRAFTER_PREFERENCE_ENV = "EXO_DRAFTER_PREFERENCE" + +# Allowed values for ``EXO_DRAFTER_PREFERENCE``. ``fastest`` picks the first +# drafter declared on the card (smallest by convention); ``highest_acceptance`` +# picks the last (largest by convention); ``auto`` defaults to ``fastest`` but +# may be tuned by future heuristics (e.g. observed acceptance rate). +_DRAFTER_PREFERENCE_VALUES: frozenset[str] = frozenset( + {"fastest", "highest_acceptance", "auto"} +) def _drafter_disabled_by_env() -> bool: return os.environ.get(EXO_DISABLE_DRAFTER_ENV, "").lower() in {"1", "true", "yes"} -def _maybe_load_drafter(model_card: ModelCard) -> Model | None: - """Load the drafter model declared on ``model_card``, if any. +def _drafter_preference() -> str: + raw = os.environ.get(EXO_DRAFTER_PREFERENCE_ENV, "auto").lower() + if raw not in _DRAFTER_PREFERENCE_VALUES: + logger.warning( + f"Unknown {EXO_DRAFTER_PREFERENCE_ENV}={raw!r}, falling back to 'auto'" + ) + return "auto" + return raw + - Returns ``None`` when the card has no drafter, the drafter weights are not - on disk, or the user has disabled drafter loading via - ``EXO_DISABLE_DRAFTER``. Drafter loading failures are logged and swallowed: - the target model continues to load and inference falls back to standard +def _select_drafter_id(candidates: list[ModelId], preference: str) -> ModelId | None: + """Pick a drafter id from a card's preference-ordered list. + + The card lists drafters in `[fastest, ..., highest_acceptance]` order. We + prefer drafters that are already on disk (so the chooser doesn't force a + surprise download); within the on-disk subset we honor the user's + preference. If nothing is on disk we fall back to the head of the list, + leaving the loader to log a "weights missing" warning. + """ + if not candidates: + return None + + on_disk = [cid for cid in candidates if resolve_existing_model(cid) is not None] + pool = on_disk if on_disk else candidates + + if preference == "highest_acceptance": + return pool[-1] + return pool[0] + + +def _maybe_load_drafter(model_card: ModelCard) -> tuple[ModelId, Model] | None: + """Load a drafter model declared on ``model_card``, if any. + + Returns the chosen ``(drafter_id, drafter_model)`` pair on success, or + ``None`` when the card declares no drafter, the chosen drafter's weights + are not on disk, ``EXO_DISABLE_DRAFTER`` is set, or the load itself + fails. Drafter loading failures are logged and swallowed: the target + model continues to load and inference falls back to standard (non-speculative) decoding. This helper is intentionally single-device only. Multi-device distributed @@ -213,22 +610,28 @@ def _maybe_load_drafter(model_card: ModelCard) -> Model | None: today (see ``mlx_generate``), so loading a drafter on those ranks would just waste memory. """ - drafter_id = model_card.drafter_model_id - if drafter_id is None: + candidates = list(model_card.drafter_model_ids) + if not candidates: return None if _drafter_disabled_by_env(): logger.info( - f"Drafter {drafter_id} declared by {model_card.model_id} but " + f"Drafter declared by {model_card.model_id} but " f"{EXO_DISABLE_DRAFTER_ENV} is set; skipping drafter load." ) return None + preference = _drafter_preference() + drafter_id = _select_drafter_id(candidates, preference) + if drafter_id is None: + return None + drafter_path = resolve_existing_model(drafter_id) if drafter_path is None: logger.warning( - f"Drafter {drafter_id} declared by {model_card.model_id} is not " - "downloaded; falling back to standard decoding. Pre-download the " - "drafter to enable speculative decoding." + f"Drafter {drafter_id} (preferred '{preference}') declared by " + f"{model_card.model_id} is not downloaded; falling back to " + "standard decoding. Pre-download the drafter to enable " + "speculative decoding." ) return None @@ -243,10 +646,26 @@ def _maybe_load_drafter(model_card: ModelCard) -> Model | None: ) return None logger.info( - f"Loaded drafter {drafter_id} for {model_card.model_id} in " - f"{(time.perf_counter() - drafter_start):.2f}s" + f"Loaded drafter {drafter_id} (preferred '{preference}') for " + f"{model_card.model_id} in {(time.perf_counter() - drafter_start):.2f}s" ) - return cast(Model, drafter_model) + return drafter_id, cast(Model, drafter_model) + + +def _drafter_weight_size_bytes(drafter_id: ModelId) -> int: + """Best-effort drafter-on-disk size for the wired-memory bump. + + Walks the drafter directory and sums file sizes. Returns 0 on any error + (the drafter weights aren't critical-path so we'd rather under-wire than + crash). + """ + drafter_path = resolve_existing_model(drafter_id) + if drafter_path is None: + return 0 + try: + return sum(p.stat().st_size for p in drafter_path.rglob("*") if p.is_file()) + except OSError: + return 0 def load_mlx_items( @@ -255,15 +674,46 @@ def load_mlx_items( ) -> Generator[ ModelLoadingResponse, None, - tuple[Model, TokenizerWrapper, "VisionProcessor | None", Model | None], + tuple[ + Model, + TokenizerWrapper, + "VisionProcessor | None", + Model | None, + ModelId | None, + ], ]: - set_wired_limit_for_model(get_weights_size(bound_instance.bound_shard)) + target_card = bound_instance.bound_shard.model_card + target_size = get_weights_size(bound_instance.bound_shard) + + # Pre-include drafter size in the wired-memory limit so the OS doesn't + # page out drafter weights between requests. We have to make this decision + # *before* loading the target because `set_wired_limit_for_model` configures + # the limit once. Skip the bump for asymmetric placements: the drafter + # weights live on a different node so they don't draw from this rank's + # wired pool. + combined_size = target_size + if ( + group is None + and bound_instance.instance.drafter_placement is None + and not _drafter_disabled_by_env() + and target_card.drafter_model_ids + ): + chosen = _select_drafter_id( + list(target_card.drafter_model_ids), _drafter_preference() + ) + if chosen is not None: + drafter_bytes = _drafter_weight_size_bytes(chosen) + if drafter_bytes > 0: + combined_size = target_size + Memory.from_bytes(drafter_bytes) + + set_wired_limit_for_model(combined_size) drafter_model: Model | None = None + drafter_id: ModelId | None = None if group is None: logger.info(f"Single device used for {bound_instance.instance}") - model_path = build_model_path(bound_instance.bound_shard.model_card.model_id) + model_path = build_model_path(target_card.model_id) start_time = time.perf_counter() model, _ = load_model(model_path, lazy=True, strict=False) # Eval layers one by one for progress reporting @@ -283,7 +733,16 @@ def load_mlx_items( logger.info(f"Time taken to load model: {(end_time - start_time):.2f}s") tokenizer = get_tokenizer(model_path, bound_instance.bound_shard) - drafter_model = _maybe_load_drafter(bound_instance.bound_shard.model_card) + # Skip the local in-process drafter when an asymmetric drafter + # rank exists for this instance: ``DrafterPlacement`` means the + # drafter is a separate ``DrafterRunner`` reachable via + # ``RemoteTransport`` over the parent group, and loading a + # second copy locally would just duplicate the weights and + # confuse the spec-decode loop. + if bound_instance.instance.drafter_placement is None: + drafter_pair = _maybe_load_drafter(target_card) + if drafter_pair is not None: + drafter_id, drafter_model = drafter_pair else: logger.info("Starting distributed init") @@ -297,6 +756,18 @@ def load_mlx_items( f"Time taken to shard and load model: {(end_time - start_time):.2f}s" ) + # Asymmetric multi-rank placement: the drafter weights live on + # a separate ``DrafterRunner``, so this rank doesn't load them + # locally (no ``drafter_model``). The model id, however, is + # known from the placement and is the only piece downstream + # telemetry needs to surface "this request used the X drafter". + # Without this, ``GenerationStats.drafter_model_id`` stays + # ``None`` for every multi-target asymmetric request even + # though the drafter is materially serving traffic. + drafter_placement = bound_instance.instance.drafter_placement + if drafter_placement is not None: + drafter_id = drafter_placement.drafter_model_id + mx.clear_cache() vision_config = bound_instance.bound_shard.model_card.vision @@ -321,7 +792,7 @@ def load_mlx_items( else: vision_processor = None - return cast(Model, model), tokenizer, vision_processor, drafter_model + return cast(Model, model), tokenizer, vision_processor, drafter_model, drafter_id def shard_and_load( @@ -943,9 +1414,7 @@ def mlx_cleanup( def mx_any(bool_: bool, group: mx.distributed.Group | None) -> bool: if group is None: return bool_ - num_true = mx.distributed.all_sum( - mx.array(bool_), group=group, stream=mx.default_stream(mx.Device(mx.cpu)) - ) + num_true = mx.distributed.all_sum(mx.array(bool_), group=group) mx.eval(num_true) return num_true.item() > 0 @@ -953,12 +1422,283 @@ def mx_any(bool_: bool, group: mx.distributed.Group | None) -> bool: def mx_barrier(group: mx.distributed.Group | None): if group is None: return - mx.eval( - mx.distributed.all_sum( - mx.array(1.0), group=group, stream=mx.default_stream(mx.Device(mx.cpu)) + mx.eval(mx.distributed.all_sum(mx.array(1.0), group=group)) + + +# ``int32`` lower / upper bounds. Values broadcast through +# :func:`mx_broadcast_int_list` must be non-negative (the wire protocol +# uses unsigned token IDs and length prefixes) AND fit in int32 with +# room for the all-sum to land back in range. Since exactly one rank +# contributes the values and the rest contribute zero, the sum is the +# root's values per element regardless of group size, so the per-element +# bound is plain int32 max. We tighten to ``2**31 - 1`` (positive int32 +# max) and reject negatives explicitly so a caller passing a Python +# ``-1`` doesn't silently wrap into a 4-billion-ish "valid" int32. +_MX_BROADCAST_MAX_VALUE: Final[int] = (1 << 31) - 1 +# Toggle to dump every broadcast call's send/recv buffers. Set via +# ``EXO_PROBE_BROADCAST=1`` for ad-hoc diagnostics; leave off in +# steady state because the per-token logging spam quickly dominates. +_BROADCAST_PROBE: Final[bool] = bool(os.environ.get("EXO_PROBE_BROADCAST")) + + +def mx_broadcast_int_list( + values: list[int] | None, + length: int, + group: mx.distributed.Group | None, + *, + is_root: bool, +) -> list[int]: + """Broadcast a fixed-length int list from one rank to all peers. + + Implementation: rank-0-fanout via :func:`mx.distributed.send` / + :func:`mx.distributed.recv`. Root issues one send to every peer in + the group; each peer issues exactly one matching recv from rank 0. + + History note: previous revisions used ``all_sum`` of an int32 + buffer where non-root ranks contributed zeros, which seems + elegant but turned out to corrupt the wire on multi-target spec + decode under JACCL. The model's TP layers also use ``all_sum`` on + the same target group, every layer, on float32 buffers; the + spec-decode hot path interleaved one ``all_sum`` for drafts and + one for sampled tokens between every pair of TP all-reduces. With + >100 in-flight ``all_sum`` collectives per round all on the same + group, JACCL's pairing logic occasionally matched our int32 + "broadcast" on rank A against the model's float32 TP all-reduce + on rank B, scrambling the int32 buffer. Symptom: token ids + ~10^9 emitted by the spec loop, IndexError deep in the SPM + detokenizer. Switching to ``send`` / ``recv`` makes the + broadcast a fundamentally different primitive than the TP + all-reduce, so JACCL has no opportunity to merge them. + + The fixed-length contract means the caller pads to ``length`` on + root and both ranks agree on ``length`` ahead of time, which keeps + the recv shape known statically. + + Args: + values: On root, a list of exactly ``length`` ints to broadcast. + Each value must be in ``[0, 2**31 - 1]``. Negative values are + rejected explicitly so a stray ``-1`` doesn't silently wrap to + ``0xFFFFFFFF`` and corrupt the broadcast. Ignored on non-root. + length: Buffer size, agreed by all ranks. Must be ``>= 1``. + group: Distributed group; ``None`` is a single-rank short-circuit + that simply returns ``values`` (root-only). + is_root: ``True`` on the rank holding the source values; ``False`` + elsewhere. Exactly one rank in ``group`` must pass ``True``. + + Returns: + A list of ``length`` ints identical on every rank in ``group``, + equal to root's ``values``. + + Raises: + ValueError: ``length`` is non-positive, the root's ``values`` are + ``None`` or wrong length, or any root value is out of int32 + range. These are caller bugs, not runtime conditions. + """ + if length < 1: + raise ValueError(f"mx_broadcast_int_list length must be >= 1, got {length}") + + if group is None: + if not is_root: + raise ValueError( + "mx_broadcast_int_list: single-rank short-circuit requires " + "is_root=True (only the root has source values)" + ) + if values is None or len(values) != length: + raise ValueError( + "mx_broadcast_int_list: single-rank call requires " + f"values of length {length}, got " + f"{None if values is None else len(values)}" + ) + _validate_broadcast_values(values) + return list(values) + + group_size = group.size() + + if is_root: + if values is None or len(values) != length: + raise ValueError( + "mx_broadcast_int_list root rank requires values of " + f"length {length}, got {None if values is None else len(values)}" + ) + _validate_broadcast_values(values) + send_buffer = mx.array(values, dtype=mx.int32) + for dst in range(1, group_size): + sent = mx.distributed.send(send_buffer, dst=dst, group=group) + mx.eval(sent) + if _BROADCAST_PROBE: + logger.warning( + f"mx_broadcast_int_list ROOT sent {values} (len={length})" + ) + return list(values) + + received = mx.distributed.recv( + shape=(length,), dtype=mx.int32, src=0, group=group + ) + mx.eval(received) + out = [int(v) for v in cast(list[int], received.tolist())] + if _BROADCAST_PROBE: + logger.warning( + f"mx_broadcast_int_list PEER recvd {out} (expected len={length})" ) + return out + + +def target_peer_broadcast_int_list( + values: list[int] | None, + length: int, + fanout: TargetPeerFanout, + *, + is_root: bool, +) -> list[int]: + """Broadcast a fixed-length signed int list over the TCP fanout. + + Drop-in replacement for :func:`mx_broadcast_int_list` on the + spec-decode hot path. Same shape contract (``length`` agreed by + every rank up front; root passes ``values``, peers pass + ``None``); the only difference is that this version rides direct + TCP sockets instead of ``mx.distributed.send`` / ``recv``, + sidestepping the JACCL int/float wire-conflation bug entirely. + + Wire format (every frame): ``length`` little-endian signed int32 + values, no header. The peer side knows ``length`` from the same + shape contract the caller agreed to. + + Args: + values: On root, exactly ``length`` int32-range values to + broadcast. Ignored on peers. + length: Buffer size, agreed by all ranks. Must be ``>= 1``. + fanout: Pre-built fanout from :func:`_maybe_setup_target_peer_fanout`. + Carries the per-rank role (rank 0 vs peer) and the connected + sockets. Mismatched ``is_root`` vs ``fanout.rank`` is a caller + bug and raises :class:`ValueError`. + is_root: ``True`` on rank 0, ``False`` elsewhere. Asserted + against ``fanout.rank``. + + Returns: + A list of ``length`` ints identical on every rank, equal to + root's ``values``. + + Raises: + ValueError: caller-bug conditions (length, values shape, + is_root vs rank mismatch). + ConnectionError: a peer closed the socket mid-frame; surfaces + as a runner failure for the supervisor to rebuild. + """ + import socket as _socket + + from exo.worker.engines.mlx.generator.target_peer_socket import ( + recv_int32_frame, + send_int32_frame, ) + if length < 1: + raise ValueError( + f"target_peer_broadcast_int_list length must be >= 1, got {length}" + ) + if is_root != (fanout.rank == 0): + raise ValueError( + f"target_peer_broadcast_int_list is_root={is_root} disagrees " + f"with fanout.rank={fanout.rank}; exactly one rank in the " + "fanout must pass is_root=True" + ) + if is_root: + if values is None or len(values) != length: + raise ValueError( + "target_peer_broadcast_int_list root rank requires values " + f"of length {length}, got " + f"{None if values is None else len(values)}" + ) + for sock in fanout.peer_sockets.values(): + assert isinstance(sock, _socket.socket) # narrow object -> socket + send_int32_frame(sock, values) + return list(values) + sock = fanout.rank_zero_socket + if sock is None: + raise RuntimeError( + "target_peer_broadcast_int_list called on peer rank but " + "fanout.rank_zero_socket is None; bootstrap must populate it" + ) + assert isinstance(sock, _socket.socket) + return recv_int32_frame(sock, length) + + +def mx_all_sum_int_list( + values: list[int], + length: int, + group: mx.distributed.Group | None, +) -> list[int]: + """Element-wise ``all_sum`` of an ``int32`` list across all ranks. + + Unlike :func:`mx_broadcast_int_list` (one-rank-contributes), every + rank contributes its own ``values`` and every rank sees the + element-wise sum. Used by the two-collective intersection + protocol in :func:`mx_all_gather_tasks` to vote on which tasks + every rank has locally: each rank emits a ``[0, 1]`` indicator + vector and the sum equals the group's vote count per slot. + + Same wire reliability story as :func:`mx_broadcast_int_list`: + rides MLX's well-exercised ``all_sum`` primitive, validates + int32 bounds explicitly so a stray Python ``-1`` doesn't wrap + silently. + + Args: + values: This rank's contribution. Length must equal ``length``; + each value must be in ``[0, 2**31 - 1]``. After all-sum the + per-element bound is ``group_size * max(value)`` -- callers + sizing for ``[0, 1]`` indicators sit far below int32 max for + any plausible ``group_size``. + length: Buffer size, agreed by all ranks. + group: Distributed group; ``None`` short-circuits to a copy of + ``values`` (single-rank vote sums to itself). + + Returns: + A list of length ``length`` with the element-wise sum of every + rank's ``values``, identical on every rank. + + Raises: + ValueError: ``length`` is non-positive, ``values`` length + mismatches, or any value is out of int32 range. + """ + if length < 1: + raise ValueError(f"mx_all_sum_int_list length must be >= 1, got {length}") + if len(values) != length: + raise ValueError( + f"mx_all_sum_int_list values must have length {length}, got {len(values)}" + ) + _validate_broadcast_values(values) + if group is None: + return list(values) + buffer = mx.array(values, dtype=mx.int32) + # ``all_sum`` is acceptable here because :func:`mx_all_sum_int_list` + # is only called from the task agreement protocol, which fires at + # admit boundaries -- not on the per-token spec-decode hot path. + # The thrash that broke the broadcast helper (interleaving with + # the model's TP all-reduce 100+ times per round) does not apply + # at this call frequency. + summed = mx.distributed.all_sum(buffer, group=group) + mx.eval(summed) + return [int(v) for v in cast(list[int], summed.tolist())] + + +def _validate_broadcast_values(values: list[int]) -> None: + """Range-check root-side broadcast values. + + Centralised so both the single-rank short-circuit and the multi- + rank all-sum path enforce identical contracts. Linear scan; for + ``length`` values this is microseconds and runs once per round on + the spec-decode hot path -- amortised free against an MLX + collective. + """ + for index, value in enumerate(values): + if value < 0 or value > _MX_BROADCAST_MAX_VALUE: + raise ValueError( + f"mx_broadcast_int_list values must be in " + f"[0, {_MX_BROADCAST_MAX_VALUE}]; " + f"index {index} = {value} is out of range " + f"(negatives wrap silently in int32 all-sum; values " + f">= 2**31 overflow)" + ) + def _parse_kimi_tool_calls(text: str): import regex as re @@ -995,54 +1735,188 @@ def _parse_single_tool(text: str) -> dict[str, Any]: return [_parse_single_tool(text)] +# Maximum number of tasks the agreement protocol can carry per round. +# Sized to ``EXO_MAX_CONCURRENT_REQUESTS`` (default 8) plus headroom for +# transient ``_maybe_queue`` build-up; tasks beyond this slot count get +# deferred to the next agreement round, never lost. Matches the sizing +# the supervisor already enforces via ``max_concurrent_tasks`` at the +# generator layer, so steady-state oversubscription is not a real +# concern. +_MX_AGREE_MAX_TASKS: Final[int] = 16 +# UUID4 string length (``len("01234567-...-...-...-............") == 36``). +# The agreement protocol broadcasts task IDs as fixed-width ASCII so +# every rank can decode the same canonical payload. Hashes are not +# enough on their own because root needs to specify *which* tasks are +# in the agreed set without leaving the consumer guessing on collision. +_MX_TASK_ID_BYTES: Final[int] = 36 +# Buffer layout: ``[count, task_id_bytes_0, task_id_bytes_1, ...]`` where +# each task_id slot is ``_MX_TASK_ID_BYTES`` ints (one ASCII char per +# int32 slot). A char fits trivially in int32, and using one slot per +# char avoids endian / packing concerns at the cost of ~4x bandwidth -- +# acceptable since this only runs at admit boundaries, not per-token. +_MX_AGREE_BUFFER_LEN: Final[int] = 1 + _MX_AGREE_MAX_TASKS * _MX_TASK_ID_BYTES + + def mx_all_gather_tasks( tasks: list[TextGeneration], group: mx.distributed.Group | None, ) -> tuple[list[TextGeneration], list[TextGeneration]]: - def encode_task_id(task_id: TaskId) -> list[int]: - utf8_task_id = task_id.encode() - return [ - int.from_bytes(utf8_task_id[i : i + 1]) for i in range(len(utf8_task_id)) - ] - - def decode_task_id(encoded_task_id: list[int]) -> TaskId: - return TaskId( - bytes.decode(b"".join((x).to_bytes(length=1) for x in encoded_task_id)) + """Two-phase intersection-based task agreement across target ranks. + + Returns ``(agreed, leftover)`` where: + + * ``agreed``: tasks every rank in the group has locally, in the + canonical order set by the root rank. Identical on every + rank by construction (the consensus is computed inside the + function, not after the return). + * ``leftover``: this rank's local tasks that didn't make it + into ``agreed`` (either root hasn't seen them yet or another + peer is still waiting on libp2p delivery). Every rank stashes + its leftover for the next agreement cycle. + + Wire protocol: + Phase 1 (broadcast root's IDs): + Root encodes ``[count, id_0_chars, ..., id_(count-1)_chars]`` + into a fixed ``_MX_AGREE_BUFFER_LEN`` int32 buffer + (zero-padded slots) and broadcasts via + :func:`mx_broadcast_int_list`. Non-root ranks decode it as + their canonical view of "candidate tasks". + Phase 2 (vote on intersection): + Every rank emits a ``[0, 1]`` vote vector indexed by phase-1 + slot: 1 means "I have this task locally", 0 means "I don't". + :func:`mx_all_sum_int_list` element-wise-sums the votes + across the group. A slot whose sum equals ``group_size`` is + agreed -- every rank had it. Slots below ``group_size`` are + deferred (they re-enter the next round once delivery + completes). + + Why intersection instead of root-authoritative: + Root-authoritative agreement (root admits all its tasks; non- + root admits only the subset it has locally) breaks the + collective-count contract. If root admits a task the non-root + doesn't have, non-root's ``_active_tasks`` stays empty, its + next ``step()`` calls ``agree_on_tasks`` again while root is + mid-``next(gen)`` issuing spec-loop ``all_sum`` collectives. + The two collective streams interleave on the wire and corrupt + each other's payloads (manifests as ``IndexError: list index + out of range`` in the detokenizer because broadcast tokens + arrive scrambled). Intersection keeps both ranks at the same + collective count: every rank that admits a task admits it on + the same step. + + Why ``group is None`` short-circuits without touching MLX: + ``mx.distributed.all_gather(group=None)`` delegates to MLX's + default group, which on an asymmetric runner is the parent + (target+drafter) group. The drafter rank is busy in + ``drafter_serve_loop`` doing its own ``recv`` on that same + default group, so an unguarded all-gather here cross-talks + with the drafter's wire protocol. When ``group is None`` we + are by construction the only participating rank, so every + task is trivially "agreed". + + Cost: + Two collectives per call (one broadcast + one all-sum), each + on small int32 buffers (~600 bytes). On Apple Silicon JACCL + this is sub-millisecond and runs only at admit boundaries, + not per token. + """ + if group is None: + return list(tasks), [] + + is_root = group.rank() == 0 + group_size = group.size() + + # ----- Phase 1: root broadcasts canonical task ID list ----- + if is_root: + admitted = tasks[:_MX_AGREE_MAX_TASKS] + payload: list[int] = [len(admitted)] + for task in admitted: + payload.extend(_encode_task_id(task.task_id)) + payload.extend([0] * (_MX_AGREE_BUFFER_LEN - len(payload))) + broadcast = mx_broadcast_int_list( + payload, _MX_AGREE_BUFFER_LEN, group, is_root=True + ) + else: + broadcast = mx_broadcast_int_list( + None, _MX_AGREE_BUFFER_LEN, group, is_root=False ) - uuid_byte_length = 36 + count = broadcast[0] + if count < 0 or count > _MX_AGREE_MAX_TASKS: + # Programming error: root encoded a count outside the agreed + # bounds. Hard failure -- buffer corrupt, can't decode safely. + raise RuntimeError( + f"mx_all_gather_tasks: broadcast count {count} outside " + f"[0, {_MX_AGREE_MAX_TASKS}]; broadcast buffer corrupt" + ) - n_tasks = len(tasks) - all_counts = cast( - list[int], - mx.distributed.all_gather(mx.array([n_tasks]), group=group).tolist(), - ) - max_tasks = max(all_counts) - world_size: int = 1 if group is None else group.size() + canonical_ids: list[str] = [] + for i in range(count): + start = 1 + i * _MX_TASK_ID_BYTES + end = start + _MX_TASK_ID_BYTES + canonical_ids.append(_decode_task_id(broadcast[start:end])) + + # ----- Phase 2: every rank votes on which canonical IDs it has ----- + local_by_id: dict[str, TextGeneration] = {t.task_id: t for t in tasks} + vote = [1 if cid in local_by_id else 0 for cid in canonical_ids] + vote.extend([0] * (_MX_AGREE_MAX_TASKS - count)) + summed_vote = mx_all_sum_int_list(vote, _MX_AGREE_MAX_TASKS, group) + + # ----- Phase 3: build agreed (intersection) and leftover ----- + agreed: list[TextGeneration] = [] + for i, cid in enumerate(canonical_ids): + if summed_vote[i] != group_size: + continue + local = local_by_id.pop(cid, None) + if local is None: + # Root contributed this ID but isn't a vote-counter on + # itself -- only possible if we're not root and we don't + # have the task. The vote sum requirement above handles + # this case (we'd have voted 0 and it wouldn't reach + # ``group_size``); reaching here means buffer corruption. + raise RuntimeError( + f"mx_all_gather_tasks: canonical id {cid} agreed by " + "vote but missing locally; vote/broadcast desync" + ) + agreed.append(local) + leftover = list(local_by_id.values()) + return agreed, leftover - if max_tasks == 0: - return [], [] - padded = [encode_task_id(task.task_id) for task in tasks] + [ - [0] * uuid_byte_length - ] * (max_tasks - n_tasks) +def _encode_task_id(task_id: str) -> list[int]: + """ASCII-encode ``task_id`` into ``_MX_TASK_ID_BYTES`` int32 slots. - assert all(len(encoded_task_id) == uuid_byte_length for encoded_task_id in padded) + Right-pads with zeros if ``task_id`` is shorter than the slot + count; raises if it's longer or contains non-ASCII (UUIDs are pure + ASCII by construction, so any rejection here points at upstream + bugs). + """ + encoded = task_id.encode("ascii") + if len(encoded) > _MX_TASK_ID_BYTES: + raise ValueError( + f"task_id {task_id!r} exceeds {_MX_TASK_ID_BYTES} bytes; " + "agreement buffer slot is sized for UUID4 strings only" + ) + out = [int(b) for b in encoded] + out.extend([0] * (_MX_TASK_ID_BYTES - len(out))) + return out - gathered = cast( - list[list[list[int]]], - mx.distributed.all_gather(mx.array(padded), group=group) - .reshape(world_size, max_tasks, -1) - .tolist(), - ) - all_task_ids: list[list[TaskId]] = [ - [decode_task_id(encoded_task_id) for encoded_task_id in rank_tasks[:count]] - for rank_tasks, count in zip(gathered, all_counts, strict=True) - ] - agreed_ids = set[TaskId].intersection(*(set(tids) for tids in all_task_ids)) +def _decode_task_id(slots: list[int]) -> str: + """Inverse of :func:`_encode_task_id`: int32 slots -> ASCII string. - local_tasks = {task.task_id: task for task in tasks} - agreed = [local_tasks[tid] for tid in sorted(agreed_ids)] - different = [task for task in tasks if task.task_id not in agreed_ids] - return agreed, different + Stops at the first zero byte (the encode pad), so the result is + bounded by ``_MX_TASK_ID_BYTES``. Any non-ASCII byte is rejected + locally rather than silently coerced; the broadcast contract + requires ASCII-only IDs. + """ + chars: list[str] = [] + for value in slots: + if value == 0: + break + if value < 0 or value > 127: + raise ValueError( + f"task_id slot {value} outside ASCII range; broadcast payload corrupt" + ) + chars.append(chr(value)) + return "".join(chars) diff --git a/src/exo/worker/main.py b/src/exo/worker/main.py index b35f946aac..7dff3065f5 100644 --- a/src/exo/worker/main.py +++ b/src/exo/worker/main.py @@ -3,7 +3,7 @@ from datetime import datetime, timezone import anyio -from anyio import fail_after, to_thread +from anyio import fail_after, move_on_after, to_thread from loguru import logger from exo.api.types import ImageEditsTaskParams @@ -109,6 +109,7 @@ async def run(self): tg.start_soon(self._forward_info, info_recv) tg.start_soon(self.plan_step) tg.start_soon(self._event_applier) + tg.start_soon(self._reconcile_instance_backoff) tg.start_soon(self._poll_connection_updates) finally: # Actual shutdown code - waits for all tasks to complete before executing. @@ -179,6 +180,17 @@ async def _event_applier(self): if isinstance(event, CustomModelCardDeleted): await delete_custom_card(event.model_id) + async def _reconcile_instance_backoff(self) -> None: + while True: + await anyio.sleep(1) + self._reconcile_instance_backoff_once() + + def _reconcile_instance_backoff_once(self) -> None: + live_instances = set(self.state.instances) + for instance_id in self._instance_backoff.tracked_keys(): + if instance_id not in live_instances: + self._instance_backoff.reset(instance_id) + async def plan_step(self): while True: await anyio.sleep(0.1) @@ -356,13 +368,25 @@ async def plan_step(self): await self._start_runner_task(task) async def shutdown(self): + self.event_sender.close() + self.command_sender.close() + self.download_command_sender.close() + for runner in self.runners.values(): + runner.shutdown() self._tg.cancel_tasks() - await self._stopped.wait() + with move_on_after(5) as scope: + await self._stopped.wait() + if scope.cancel_called: + logger.warning("Timed out waiting for Worker shutdown") async def _start_runner_task(self, task: Task): if (instance := self.state.instances.get(task.instance_id)) is not None: + # ``all_node_to_runner`` resolves both target and drafter ranks + # for asymmetric placement; ``node_to_runner`` alone misses the + # drafter rank because it lives on ``instance.drafter_placement``, + # not on ``shard_assignments``. await self.runners[ - instance.shard_assignments.node_to_runner[self.node_id] + instance.all_node_to_runner[self.node_id] ].start_task(task) def _create_supervisor(self, task: CreateRunner) -> RunnerSupervisor: diff --git a/src/exo/worker/plan.py b/src/exo/worker/plan.py index 3824e4bb7a..c47e6554c8 100644 --- a/src/exo/worker/plan.py +++ b/src/exo/worker/plan.py @@ -87,13 +87,52 @@ def _kill_runner( runner_id=runner_id, ) - for ( - global_runner_id - ) in runner.bound_instance.instance.shard_assignments.node_to_runner.values(): + # Restart-cascade rule: only fires when our local rank is + # ``RunnerRunning`` (mid-task), which guarantees we previously + # cleared the bootstrap collective with every peer rank in lock- + # step (warmup-complete on all ranks is a precondition for + # ``RunnerRunning`` -- see ``handle_generation_tasks``). If a + # peer is now ``RunnerIdle``, that is a backward jump only + # reachable by a process restart; the transient ``RunnerFailed`` + # was gossiped too briefly for the rule above to fire (the + # supervisor respawned the runner immediately and the new + # process emitted ``RunnerIdle`` right away). Without this rule + # the bootstrap predicate (``all_runners_connecting`` in + # ``_init_distributed_backend``) never fires and the respawned + # peer is stuck in ``RunnerIdle`` forever -- the failure mode + # observed in the K=8 sweep regression at 14:35:05. + # + # We restrict the trigger to ``RunnerRunning`` (not + # ``RunnerLoaded``/``RunnerReady``) because during initial + # bootstrap a peer can legitimately sit at ``RunnerIdle`` while + # we have completed our own loading -- ``LoadModel`` happens + # per-rank without a collective barrier (see ``runner.py`` + # case ``LoadModel``), so warmup-gate predicates need to keep + # waiting rather than tearing the cluster down. + instance = runner.bound_instance.instance + # Use ``all_runner_ids`` (target + drafter) so the staleness + # predicate fires for asymmetric placements where the drafter + # is the only peer (single-target + drafter on a different + # node). + is_multi_rank_instance = len(instance.all_runner_ids) > 1 + local_is_running = isinstance(runner.status, RunnerRunning) + + for global_runner_id in instance.all_runner_ids: if runner_id == global_runner_id: continue - if isinstance(all_runners.get(global_runner_id, None), RunnerFailed): + peer_status = all_runners.get(global_runner_id, None) + if isinstance(peer_status, RunnerFailed): + return Shutdown( + instance_id=instance_id, + runner_id=runner_id, + ) + + if ( + is_multi_rank_instance + and local_is_running + and isinstance(peer_status, RunnerIdle) + ): return Shutdown( instance_id=instance_id, runner_id=runner_id, @@ -108,7 +147,12 @@ def _create_runner( instance_backoff: KeyedBackoff[InstanceId], ) -> CreateRunner | None: for instance in instances.values(): - runner_id = instance.shard_assignments.node_to_runner.get(node_id, None) + # ``all_node_to_runner`` includes the asymmetric drafter rank + # when ``instance.drafter_placement`` is set, so the drafter + # node spawns its drafter runner the same way target nodes + # spawn target runners. + per_node_runners = instance.all_node_to_runner + runner_id = per_node_runners.get(node_id, None) if runner_id is None: continue @@ -118,7 +162,7 @@ def _create_runner( # don't create runners if any other nodes have runners that have failed - wait for them to fix themselves first. instance_has_failed_runner = any( isinstance(all_runners.get(remote_runner_id), RunnerFailed) - for remote_runner_id in instance.shard_assignments.node_to_runner.values() + for remote_runner_id in per_node_runners.values() if remote_runner_id != runner_id ) we_have_failed_before = isinstance(all_runners.get(runner_id), RunnerFailed) @@ -148,6 +192,14 @@ def _model_needs_download( } for runner in runners.values(): + # The drafter rank loads its model from disk; placement assumes + # the operator has pre-downloaded the drafter weights on the + # eligible node. Auto-download for drafter ranks is a TODO -- + # for now, the drafter runner fails loudly at load time if the + # weights are missing and the user fixes the cluster. + if runner.bound_instance.is_drafter_rank: + continue + model_id = runner.bound_instance.bound_shard.model_card.model_id if ( isinstance(runner.status, RunnerIdle) @@ -173,40 +225,68 @@ def _init_distributed_backend( ): for runner in runners.values(): instance = runner.bound_instance.instance - shard_assignments = instance.shard_assignments + runner_id = runner.bound_instance.bound_runner_id + bound_instance = runner.bound_instance + + runner_is_idle = isinstance(runner.status, RunnerIdle) + if not runner_is_idle: + continue - is_single_node_instance = len(shard_assignments.runner_to_shard) == 1 - if is_single_node_instance: + # Asymmetric drafter rank: dial-only, no ``mx.distributed`` init. + # Dispatch the ConnectToGroup task as soon as the drafter is + # idle. ``dial_target`` retries with backoff so an early dial + # before target rank 0 binds is recoverable. Decoupling the + # drafter from the target's collective barrier is what lets a + # multi-target asymmetric instance work without + # ``Group.split``. + if bound_instance.is_drafter_rank: + return ConnectToGroup(instance_id=instance.instance_id) + + # Single-target symmetric: no mx.distributed group at all. + # Single-target asymmetric *with* a drafter still needs the + # target rank to enter ``ConnectToGroup`` so it can bind the + # drafter listener. Differentiate via the placement. + is_single_rank_target = instance.parent_group_size == 1 + if is_single_rank_target and instance.drafter_placement is None: continue - runner_is_idle = isinstance(runner.status, RunnerIdle) - all_runners_connecting = all( + # Target-only barrier: drafter ranks are dispatched in the + # branch above and are NOT members of any ``mx.distributed`` + # group under the v3+ wire. Iterate ``shard_assignments`` so + # we get the target ranks alone. + target_runner_ids = list(instance.shard_assignments.runner_to_shard.keys()) + all_target_connecting = all( isinstance( - all_runners.get(global_runner_id), + all_runners.get(target_runner_id), (RunnerConnecting, RunnerIdle), ) - for global_runner_id in shard_assignments.runner_to_shard + for target_runner_id in target_runner_ids ) - if not (runner_is_idle and all_runners_connecting): + if not all_target_connecting: continue - runner_id = runner.bound_instance.bound_runner_id - - shard = runner.bound_instance.bound_shard - device_rank = shard.device_rank - world_size = shard.world_size - - assert device_rank < world_size - assert device_rank >= 0 - - accepting_ranks = device_rank < world_size - 1 - - # Rank = n-1 - connecting_rank_ready = device_rank == world_size - 1 and all( - isinstance(all_runners.get(global_runner_id, None), RunnerConnecting) - for global_runner_id in shard_assignments.runner_to_shard - if global_runner_id != runner_id + if is_single_rank_target: + # Single target rank in asymmetric placement: it still has + # to enter ConnectToGroup to bind the drafter listener and + # accept the dial. No mx.distributed barrier to honour. + return ConnectToGroup(instance_id=instance.instance_id) + + # Multi-target ranks: keep the original ordering -- earlier + # ranks dispatch immediately, the last target rank dispatches + # once every other target rank is already RunnerConnecting (or + # later). + parent_size = instance.parent_group_size # target ranks only + parent_rank = bound_instance.parent_rank + assert parent_rank < parent_size + assert parent_rank >= 0 + + accepting_ranks = parent_rank < parent_size - 1 + + connecting_rank_ready = parent_rank == parent_size - 1 and all( + isinstance(all_runners.get(target_runner_id, None), RunnerConnecting) + for target_runner_id in target_runner_ids + if target_runner_id != runner_id ) if not (accepting_ranks or connecting_rank_ready): @@ -226,6 +306,10 @@ def _load_model( instance = runner.bound_instance.instance shard_assignments = instance.shard_assignments + # Target shards must all be downloaded before any rank loads; + # the drafter's pre-downloaded weights are the operator's + # responsibility (see _model_needs_download), so we don't gate + # on its DownloadCompleted entry here. all_local_downloads_complete = all( nid in global_download_status and any( @@ -238,8 +322,19 @@ def _load_model( if not all_local_downloads_complete: continue - is_single_node_instance = len(instance.shard_assignments.runner_to_shard) == 1 - if is_single_node_instance and isinstance(runner.status, RunnerIdle): + # Single-target SYMMETRIC instance: no mx.distributed group and + # no drafter wire, so the runner can skip the ConnectToGroup + # collective and go straight to LoadModel. Single-target + # ASYMMETRIC (drafter on a different node) still has to enter + # ConnectToGroup so target rank 0 can bind the drafter socket + # listener; it falls through to the barrier check below. + is_single_rank_target = instance.parent_group_size == 1 + is_symmetric_placement = instance.drafter_placement is None + if ( + is_single_rank_target + and is_symmetric_placement + and isinstance(runner.status, RunnerIdle) + ): return LoadModel(instance_id=instance.instance_id) is_runner_waiting = isinstance(runner.status, RunnerConnected) @@ -249,7 +344,7 @@ def _load_model( all_runners.get(global_runner_id, None), (RunnerConnected, RunnerLoading, RunnerLoaded), ) - for global_runner_id in shard_assignments.runner_to_shard + for global_runner_id in instance.all_runner_ids ) if is_runner_waiting and all_ready_for_model: @@ -264,34 +359,58 @@ def _ready_to_warmup( ) -> StartWarmup | None: for runner in runners.values(): instance = runner.bound_instance.instance - shard_assignments = instance.shard_assignments - shard = runner.bound_instance.bound_shard - device_rank = shard.device_rank runner_id = runner.bound_instance.bound_runner_id - world_size = shard.world_size + bound_instance = runner.bound_instance is_runner_loaded = isinstance(runner.status, RunnerLoaded) + if not is_runner_loaded: + continue - assert device_rank < world_size - assert device_rank >= 0 + # ``RunnerWarmingUp`` is the canonical "ready to run warmup" state + # for an accepting rank, but a peer that has already advanced past + # warmup (``RunnerReady``/``RunnerRunning``) is *strictly past* + # the barrier we care about. Asymmetric drafter rank warmup is + # near-instant (one forward pass) so it can race past + # ``RunnerWarmingUp`` before the connecting rank's plan loop + # observes it; without including the post-warmup states the + # connecting rank stalls in ``RunnerLoaded`` forever. + post_loaded_states = ( + RunnerWarmingUp, + RunnerReady, + RunnerRunning, + ) + + # Drafter rank: warmup is independent (one drafter forward) so + # dispatch as soon as the drafter is RunnerLoaded. + if bound_instance.is_drafter_rank: + return StartWarmup(instance_id=instance.instance_id) + + # Target ranks: keep the rank-0-connector ordering across + # target-only ranks. The drafter rank is excluded from this + # barrier (its own warmup is independent). + parent_rank = bound_instance.parent_rank + parent_size = instance.parent_group_size # target ranks only - # Rank != 0 - accepting_ranks_ready = device_rank > 0 and all( + assert parent_rank < parent_size + assert parent_rank >= 0 + + target_runner_ids = list(instance.shard_assignments.runner_to_shard.keys()) + + accepting_ranks_ready = parent_rank > 0 and all( isinstance( - all_runners.get(global_runner_id, None), - (RunnerLoaded, RunnerWarmingUp), + all_runners.get(target_runner_id, None), + (RunnerLoaded, *post_loaded_states), ) - for global_runner_id in shard_assignments.runner_to_shard + for target_runner_id in target_runner_ids ) - # Rank = 0 - connecting_rank_ready = device_rank == 0 and all( - isinstance(all_runners.get(global_runner_id, None), RunnerWarmingUp) - for global_runner_id in shard_assignments.runner_to_shard - if global_runner_id != runner_id + connecting_rank_ready = parent_rank == 0 and all( + isinstance(all_runners.get(target_runner_id, None), post_loaded_states) + for target_runner_id in target_runner_ids + if target_runner_id != runner_id ) - if is_runner_loaded and (accepting_ranks_ready or connecting_rank_ready): + if accepting_ranks_ready or connecting_rank_ready: return StartWarmup(instance_id=instance.instance_id) return None @@ -338,7 +457,7 @@ def _pending_tasks( if isinstance(runner.status, (RunnerReady, RunnerRunning)) and all( isinstance(all_runners[global_runner_id], (RunnerReady, RunnerRunning)) - for global_runner_id in runner.bound_instance.instance.shard_assignments.runner_to_shard + for global_runner_id in runner.bound_instance.instance.all_runner_ids ): return task diff --git a/src/exo/worker/runner/bootstrap.py b/src/exo/worker/runner/bootstrap.py index 6a617a7a74..127affd077 100644 --- a/src/exo/worker/runner/bootstrap.py +++ b/src/exo/worker/runner/bootstrap.py @@ -1,5 +1,8 @@ +import faulthandler import os import resource +import signal +import sys import loguru @@ -23,6 +26,17 @@ def entrypoint( global logger logger = _logger + # Register SIGUSR1 -> dump Python tracebacks of every thread to stderr. + # Critical for diagnosing TP collective deadlocks: ``sample`` only sees + # C frames (which all reduce to ``cvwait``), but the divergence between + # ranks is at the Python orchestration layer. Sending ``kill -USR1 + # `` while the runner is stuck dumps the full Python stack of + # every thread without needing root for ``py-spy``. + faulthandler.enable(file=sys.stderr, all_threads=True) + faulthandler.register( + signal.SIGUSR1, file=sys.stderr, all_threads=True, chain=False + ) + soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE) resource.setrlimit(resource.RLIMIT_NOFILE, (min(max(soft, 2048), hard), hard)) @@ -32,12 +46,23 @@ def entrypoint( else: os.environ["MLX_METAL_FAST_SYNCH"] = "1" - runner_context = ( - f"instance_id={bound_instance.instance.instance_id} " - f"runner_id={bound_instance.bound_runner_id} " - f"node_id={bound_instance.bound_node_id} " - f"model_id={bound_instance.bound_shard.model_card.model_id}" - ) + if bound_instance.is_drafter_rank: + placement = bound_instance.instance.drafter_placement + assert placement is not None + runner_context = ( + f"instance_id={bound_instance.instance.instance_id} " + f"runner_id={bound_instance.bound_runner_id} " + f"node_id={bound_instance.bound_node_id} " + f"role=drafter " + f"drafter_model_id={placement.drafter_model_id}" + ) + else: + runner_context = ( + f"instance_id={bound_instance.instance.instance_id} " + f"runner_id={bound_instance.bound_runner_id} " + f"node_id={bound_instance.bound_node_id} " + f"model_id={bound_instance.bound_shard.model_card.model_id}" + ) logger.info( f"Runner bootstrap starting {runner_context} " f"fast_synch={os.environ['MLX_METAL_FAST_SYNCH']}" @@ -45,6 +70,23 @@ def entrypoint( # Import main after setting global logger - this lets us just import logger from this module try: + if bound_instance.is_drafter_rank: + # Drafter rank takes a separate code path: load only the + # drafter model, never enter the target generator, run the + # drafter serve loop until OP_SHUTDOWN. Apply the same + # mlx_lm patches the target rank uses so attention / + # rotating-cache fixes apply uniformly. + from exo.worker.engines.mlx.patches import apply_mlx_patches + + apply_mlx_patches() + + from exo.worker.runner.drafter_runner import DrafterRunner + + drafter_runner = DrafterRunner(bound_instance, event_sender, task_receiver) + logger.info(f"Starting drafter runner main loop {runner_context}") + drafter_runner.main() + return + from exo.worker.runner.runner import Runner builder: Builder diff --git a/src/exo/worker/runner/drafter_runner.py b/src/exo/worker/runner/drafter_runner.py new file mode 100644 index 0000000000..0442bcb9db --- /dev/null +++ b/src/exo/worker/runner/drafter_runner.py @@ -0,0 +1,350 @@ +"""Runner for an asymmetric drafter rank. + +The asymmetric placement layer (``master.placement``) selects a +drafter-eligible node whenever a model card lists +:attr:`ModelCard.drafter_eligible_nodes` and at least one eligible host +is socket-reachable from target rank 0. The drafter loads its own +(smaller) drafter model on that node and runs :func:`drafter_serve_loop` +to field forwards from target rank 0 over a direct TCP socket. + +Under the v3+ wire the drafter rank is NOT a member of the target +ranks' ``mx.distributed.Group``. It does not call +``mx.distributed.init`` at all -- it dials +``DrafterPlacement.drafter_socket_host:drafter_socket_port`` and runs +the serve loop over the resulting socket. Decoupling drafter IPC from +``mx.distributed`` lets target ranks of any size run TP/PP collectives +without requiring ``Group.split`` (which jaccl/ring backends do not +implement on Apple Silicon). + +This module follows the same lifecycle as :class:`exo.worker.runner.runner.Runner` +(``Idle -> Connecting -> Connected -> Loading -> Loaded -> WarmingUp -> +Ready -> Running``) so the worker plan's readiness checks (which iterate +``Instance.all_runner_ids``) treat the drafter rank like any other rank. +The internals differ: + + * No target shard, no tokenizer, no chat-completion handling. The + drafter has its own ``ModelCard`` and only loads the drafter + weights. + * No ``Engine`` wrapper. ``StartWarmup`` does a single forward to + JIT-compile Metal kernels, then the drafter steps directly into + :func:`drafter_serve_loop`, which blocks on socket recv until the + target rank sends ``OP_SHUTDOWN``. + * ``Shutdown`` arrives via the worker plan after target ranks have + already sent ``OP_SHUTDOWN``; on the drafter side we just clean up + state. + +The module is import-cheap: it does not pull in any target-side +generator code (``generate.py``, ``batch_generator.py``, etc.). The +drafter runs in its own process with its own model, so memory and +import time stay tight. +""" + +from __future__ import annotations + +import contextlib +import socket +import time +from typing import TYPE_CHECKING, cast, final + +import mlx.core as mx +from loguru import logger as loguru_logger +from mlx_lm.utils import load_model + +from exo.download.download_utils import build_model_path, resolve_existing_model +from exo.shared.types.events import ( + Event, + RunnerStatusUpdated, + TaskAcknowledged, + TaskStatusUpdated, +) +from exo.shared.types.tasks import ( + ConnectToGroup, + LoadModel, + Shutdown, + StartWarmup, + Task, + TaskId, + TaskStatus, +) +from exo.shared.types.worker.instances import BoundInstance, DrafterPlacement +from exo.shared.types.worker.runners import ( + RunnerConnected, + RunnerConnecting, + RunnerIdle, + RunnerLoaded, + RunnerLoading, + RunnerReady, + RunnerRunning, + RunnerShutdown, + RunnerShuttingDown, + RunnerStatus, + RunnerWarmingUp, +) +from exo.utils.channels import ClosedResourceError, EndOfStream, MpReceiver, MpSender + +if TYPE_CHECKING: + from exo.worker.engines.mlx.types import KVCacheType, Model + + +@final +class DrafterRunner: + """Lifecycle manager for the drafter rank in an asymmetric instance. + + Same task-driven state machine as the target runner -- the worker + plan dispatches ``ConnectToGroup``, ``LoadModel``, ``StartWarmup``, + and ``Shutdown`` in order; readiness gates iterate + ``Instance.all_runner_ids`` so the drafter participates in + barriers exactly like a target rank. + """ + + def __init__( + self, + bound_instance: BoundInstance, + event_sender: MpSender[Event], + task_receiver: MpReceiver[Task], + ) -> None: + assert bound_instance.is_drafter_rank, ( + "DrafterRunner can only be constructed for an asymmetric drafter " + "rank; check `bound_instance.is_drafter_rank` before instantiation." + ) + placement = bound_instance.instance.drafter_placement + assert placement is not None + self._placement: DrafterPlacement = placement + + self.bound_instance = bound_instance + self.runner_id = bound_instance.bound_runner_id + self.event_sender = event_sender + self.task_receiver = task_receiver + + self.drafter_socket: socket.socket | None = None + self.draft_model: Model | None = None + + self._setup_start = time.perf_counter() + self._update_status(RunnerIdle()) + loguru_logger.info( + f"DrafterRunner created (runner_id={self.runner_id} " + f"node={bound_instance.bound_node_id} " + f"drafter_model_id={self._placement.drafter_model_id} " + f"drafter_rank={self._placement.drafter_rank})" + ) + + def main(self) -> None: + try: + with self.task_receiver: + for task in self.task_receiver: + if not self._dispatch(task): + return + except (EndOfStream, ClosedResourceError): + loguru_logger.warning("DrafterRunner task stream closed") + + def _dispatch(self, task: Task) -> bool: + """Process one task; return ``False`` to exit the main loop.""" + self._send_task_status(task.task_id, TaskStatus.Running) + match task: + case ConnectToGroup() if isinstance(self.current_status, RunnerIdle): + self._handle_connect(task) + case LoadModel() if isinstance(self.current_status, RunnerConnected): + self._handle_load(task) + case StartWarmup() if isinstance(self.current_status, RunnerLoaded): + self._handle_start_warmup(task) + case Shutdown(): + self._handle_shutdown(task) + return False + case _: + raise ValueError( + f"DrafterRunner received {task.__class__.__name__} outside " + f"of state machine in {self.current_status=}" + ) + return True + + def _handle_connect(self, task: Task) -> None: + """Dial target rank 0's drafter listener; no mx.distributed init. + + Under the v3+ wire the drafter is outside the target's + ``mx.distributed.Group``. ``ConnectToGroup`` is the natural + place to establish the drafter wire (the lifecycle stage runs + in parallel with target ranks initialising mx.distributed, + which gives target rank 0 time to bind before we dial). + :func:`dial_target` retries with backoff up to two minutes, + comfortably covering target rank 0's bind delay. + """ + from exo.worker.engines.mlx.generator.drafter_socket import dial_target + + self._update_status(RunnerConnecting()) + self._acknowledge(task) + host = self._placement.drafter_socket_host + port = self._placement.drafter_socket_port + loguru_logger.info( + f"DrafterRunner dialing target rank 0 at {host}:{port} " + f"(drafter_model_id={self._placement.drafter_model_id})" + ) + self.drafter_socket = dial_target(host, port) + loguru_logger.info( + f"DrafterRunner connected over socket " + f"(drafter_rank={self._placement.drafter_rank})" + ) + self._send_task_status(task.task_id, TaskStatus.Complete) + self._update_status(RunnerConnected()) + + def _handle_load(self, task: Task) -> None: + drafter_id = self._placement.drafter_model_id + drafter_path = resolve_existing_model(drafter_id) + if drafter_path is None: + # Build a fallback path so the error message points at where + # the operator should drop the weights. + expected_path = build_model_path(drafter_id) + raise FileNotFoundError( + f"Drafter weights for {drafter_id} not found on this node " + f"(expected at {expected_path}). Asymmetric drafter " + "placement requires pre-downloading the drafter model " + "on every drafter-eligible node; auto-download is not " + "yet implemented for the drafter rank." + ) + + self._update_status(RunnerLoading(layers_loaded=0, total_layers=0)) + self._acknowledge(task) + + load_start = time.perf_counter() + loguru_logger.info(f"DrafterRunner loading {drafter_id} from {drafter_path}") + model, _ = load_model(drafter_path, lazy=True, strict=False) + mx.eval(model) + self.draft_model = cast("Model", model) + # ``draft_cache`` is no longer pre-allocated -- the serve loop + # multiplexes per-session caches keyed on ``session_id`` (target + # rank's :meth:`RemoteTransport.open_session` allocation) and + # builds each one lazily via ``make_kv_cache(model=...)`` on + # the matching ``OP_PREFILL``. Holding only the model means + # cluster-idle memory stays small (~drafter weights, no KV + # cache); active memory scales linearly with concurrent target + # requests, capped by the runner's ``EXO_MAX_CONCURRENT_REQUESTS``. + loguru_logger.info( + f"DrafterRunner loaded {drafter_id} in " + f"{(time.perf_counter() - load_start):.2f}s" + ) + + self._send_task_status(task.task_id, TaskStatus.Complete) + self._update_status(RunnerLoaded()) + + def _handle_start_warmup(self, task: Task) -> None: + from exo.worker.engines.mlx.cache import make_kv_cache + + assert self.drafter_socket is not None + assert self.draft_model is not None + + self._update_status(RunnerWarmingUp()) + self._acknowledge(task) + + # JIT-compile drafter Metal kernels with a single forward + # against a throwaway cache so the first real spec-decode round + # on the target rank doesn't eat the compile latency. The + # warmup cache is GC'd at the end of this method; per-session + # caches are allocated lazily inside :func:`drafter_serve_loop` + # on each ``OP_PREFILL``. + warmup_start = time.perf_counter() + warmup_cache = make_kv_cache(model=self.draft_model) + seed = mx.array([[0]], dtype=mx.uint32) + _ = self.draft_model(seed, cache=warmup_cache) + mx.eval([c.state for c in warmup_cache]) # type: ignore[reportArgumentType] + del warmup_cache + loguru_logger.info( + f"DrafterRunner warmup complete in " + f"{(time.perf_counter() - warmup_start):.2f}s; " + f"setup_total={(time.perf_counter() - self._setup_start):.2f}s" + ) + + self._send_task_status(task.task_id, TaskStatus.Complete) + # The drafter has no prefill server, so prefill_server_port is None. + self._update_status(RunnerReady(prefill_server_port=None)) + self._update_status(RunnerRunning()) + + # Enter the drafter serve loop. This blocks until the target + # rank sends OP_SHUTDOWN. The serve loop's send/recv use the + # parent group; target rank 0 is conventionally the only target + # rank that drives drafter IPC. + self._serve_loop() + + # OP_SHUTDOWN arrived; transition back to Ready so the worker + # plan's Shutdown task can drive us to RunnerShutdown. + self._update_status(RunnerReady(prefill_server_port=None)) + + def _serve_loop(self) -> None: + from exo.worker.engines.mlx.cache import make_kv_cache + from exo.worker.engines.mlx.generator.remote_drafter import drafter_serve_loop + + assert self.drafter_socket is not None + assert self.draft_model is not None + + # ``num_draft_tokens`` here only sizes the response buffer; the + # spec loop on the target side may issue forwards with + # ``num_forwards`` up to K+1, so we mirror exactly its config. + num_draft_tokens = self._num_draft_tokens() + loguru_logger.info( + f"DrafterRunner entering serve_loop " + f"(K={num_draft_tokens}, transport=tcp_socket)" + ) + # Capture ``draft_model`` in the closure so the serve loop can + # allocate per-session caches lazily without re-entering + # ``DrafterRunner`` state. Dummy assertion here to satisfy the + # type checker (``self.draft_model`` is ``Model | None`` at the + # field level but we asserted not None above). + draft_model = self.draft_model + + def _make_session_cache() -> "KVCacheType": + return make_kv_cache(model=draft_model) + + drafter_serve_loop( + draft_model=draft_model, + make_draft_cache=_make_session_cache, + num_draft_tokens=num_draft_tokens, + sock=self.drafter_socket, + ) + loguru_logger.info("DrafterRunner serve_loop exited via OP_SHUTDOWN") + + @staticmethod + def _num_draft_tokens() -> int: + # Same default the target-side builder uses; reading the env + # var keeps drafter and target in lock-step without an explicit + # IPC message at warmup time. + from exo.worker.runner.llm_inference.batch_generator import ( + DEFAULT_NUM_DRAFT_TOKENS, + EXO_NUM_DRAFT_TOKENS, + parse_env_int, + ) + + return parse_env_int(EXO_NUM_DRAFT_TOKENS, default=DEFAULT_NUM_DRAFT_TOKENS) + + def _handle_shutdown(self, task: Task) -> None: + loguru_logger.info("DrafterRunner shutting down") + self._update_status(RunnerShuttingDown()) + self._acknowledge(task) + # Release the model so the drafter rank's process frees its + # drafter weights before exiting. Per-session caches were owned + # by :func:`drafter_serve_loop`; they were dropped when the + # loop returned via ``OP_SHUTDOWN``. + self.draft_model = None + if self.drafter_socket is not None: + with contextlib.suppress(OSError): + self.drafter_socket.close() + self.drafter_socket = None + import gc + + gc.collect() + self._send_task_status(task.task_id, TaskStatus.Complete) + self._update_status(RunnerShutdown()) + + # -- helpers --------------------------------------------------------- + + def _update_status(self, status: RunnerStatus) -> None: + self.current_status: RunnerStatus = status + self.event_sender.send( + RunnerStatusUpdated(runner_id=self.runner_id, runner_status=status) + ) + + def _send_task_status(self, task_id: TaskId, status: TaskStatus) -> None: + self.event_sender.send(TaskStatusUpdated(task_id=task_id, task_status=status)) + + def _acknowledge(self, task: Task) -> None: + self.event_sender.send(TaskAcknowledged(task_id=task.task_id)) + + +__all__ = ["DrafterRunner"] diff --git a/src/exo/worker/runner/llm_inference/batch_generator.py b/src/exo/worker/runner/llm_inference/batch_generator.py index c11b5ba533..c46ddbe7b0 100644 --- a/src/exo/worker/runner/llm_inference/batch_generator.py +++ b/src/exo/worker/runner/llm_inference/batch_generator.py @@ -1,6 +1,8 @@ +import contextlib import itertools +import os import time -from collections import deque +from collections import OrderedDict, deque from collections.abc import Generator, Iterator from dataclasses import dataclass, field from typing import BinaryIO @@ -27,18 +29,22 @@ from exo.utils.channels import MpReceiver, MpSender from exo.worker.disaggregated.server import PrefillRequest from exo.worker.engines.base import Engine -from exo.worker.engines.mlx.cache import KVPrefixCache +from exo.worker.engines.mlx.cache import KVPrefixCache, encode_prompt, make_kv_cache from exo.worker.engines.mlx.disaggregated.adapter import write_cache_to_wire from exo.worker.engines.mlx.disaggregated.serve import run_prefill_for_request from exo.worker.engines.mlx.generator.batch_generate import ExoBatchGenerator from exo.worker.engines.mlx.generator.generate import ( + BatchedPrefillUnsupportedError, PrefillCancelled, + batched_prefill, mlx_generate, warmup_inference, ) -from exo.worker.engines.mlx.types import Model +from exo.worker.engines.mlx.generator.remote_drafter import RemoteTransport +from exo.worker.engines.mlx.types import KVCacheType, Model from exo.worker.engines.mlx.utils_mlx import ( apply_chat_template, + fix_unmatched_think_end_tokens, mx_all_gather_tasks, mx_any, ) @@ -68,6 +74,66 @@ def gen(self) -> Generator[T | None]: EXO_RUNNER_MUST_OOM = "EXO RUNNER MUST OOM" EXO_RUNNER_MUST_TIMEOUT = "EXO RUNNER MUST TIMEOUT" +# Drafter-tuning env vars. Read once per process at SequentialGenerator +# construction time so every request in this runner sees the same K and +# short-skip threshold (avoids surprises mid-stream). +EXO_NUM_DRAFT_TOKENS = "EXO_NUM_DRAFT_TOKENS" +EXO_DRAFTER_MIN_OUTPUT_TOKENS = "EXO_DRAFTER_MIN_OUTPUT_TOKENS" +EXO_ADAPTIVE_DRAFT_TOKENS = "EXO_ADAPTIVE_DRAFT_TOKENS" # "1" to enable +DEFAULT_NUM_DRAFT_TOKENS = 5 # purpose-built family pairs hit ~80% acceptance +DEFAULT_DRAFTER_MIN_OUTPUT_TOKENS = 16 + +# Batched prefill (B>=2 prompts processed in one forward) is the +# remaining lever for slot-1 TTFT on long-prompt mixed traffic. The +# round-robin landed in PR #15 cut slot-1 TTFT 5.2x by interleaving +# decode ticks; the residual 11s outliers in the 6K-token +# long_context_summary bench are entirely sequential per-slot +# prefills. Setting ``EXO_BATCH_PREFILL=0`` disables the optimisation +# (escape hatch for shared-prefix workloads where the per-slot +# prefix-cache hit rate exceeds the batched-forward speedup; see +# ``mlx_generate``'s ``precomputed_target_cache`` docstring for the +# trade-off rationale). +EXO_BATCH_PREFILL = "EXO_BATCH_PREFILL" +# Rolling-window size used by adaptive K. Keep small so the controller is +# responsive to traffic shifts (code completion vs reasoning) without +# oscillating on per-request noise. +ADAPTIVE_K_WINDOW = 8 + + +def adaptive_num_draft_tokens(rolling_fractions: list[float], fallback: int) -> int: + """Pick K (num_draft_tokens) from a rolling window of acceptance fractions. + + The bands are based on the geometric expectation + ``(1 - p^(K+1)) / (1 - p)`` from the speculative-decoding literature: + K=2 is the right call when the drafter is missing, K=4 around 50-75% + acceptance, K=6 above 75%. Below the warmup threshold (need at least 2 + observations) we fall back to the configured default rather than + twitching at K=2 on first request. + """ + if len(rolling_fractions) < 2: + return fallback + average = sum(rolling_fractions) / len(rolling_fractions) + if average < 0.5: + return 2 + if average < 0.75: + return 4 + return 6 + + +def parse_env_int(name: str, default: int, minimum: int = 1) -> int: + raw = os.environ.get(name) + if raw is None: + return default + try: + value = int(raw) + except ValueError: + logger.warning(f"{name}={raw!r} is not a valid int; falling back to {default}") + return default + if value < minimum: + logger.warning(f"{name}={value} below minimum {minimum}; clamping to {minimum}") + return minimum + return value + def _check_for_debug_prompts(task_params: TextGenerationTaskParams) -> None: """Check for debug prompt triggers in the input.""" @@ -102,6 +168,49 @@ class SequentialGenerator(Engine): # `mlx_generate` itself enforces ``draft_model=None`` whenever ``group is # not None``; this field is only ever populated for single-device runners. draft_model: Model | None = None + # Parallel KVPrefixCache for the drafter so multi-turn conversations + # don't pay drafter prefill on every request. None disables drafter + # prefix caching (single-shot drafter prefill on every call). + drafter_kv_prefix_cache: KVPrefixCache | None = None + # The chosen drafter's ModelId. Used for telemetry (GenerationStats) so + # dashboards can attribute speedup to a specific drafter. + draft_model_id: ModelId | None = None + # K (num_draft_tokens) for speculative_generate_step. None falls back to + # the env var EXO_NUM_DRAFT_TOKENS, then DEFAULT_NUM_DRAFTER_TOKENS. + num_draft_tokens: int | None = None + # max_output_tokens threshold below which the drafter is skipped per + # request. None falls back to the env var EXO_DRAFTER_MIN_OUTPUT_TOKENS. + drafter_min_output_tokens: int | None = None + # Item 7: when True, K is recomputed each request from a rolling window + # of observed acceptance fractions. Disabled by default so K stays + # predictable for benchmarking. + adaptive_draft_tokens: bool = False + # Asymmetric placement telemetry: ``drafter_rank_in_parent`` mirrors + # :attr:`DrafterPlacement.drafter_rank` (advisory only; the drafter + # is NOT a member of any ``mx.distributed.Group`` under the v3+ + # wire). ``None`` for symmetric/single-device builds. When set + # together with ``remote_drafter_transport``, every request runs + # the pipelined+remote drafter path: the spec loop talks to the + # drafter via the dedicated drafter TCP socket owned by + # ``RemoteTransport`` rather than ``mx.distributed`` collectives. + drafter_rank_in_parent: int | None = None + # Long-lived transport bound to the drafter rank. Allocated once at + # builder.build() time; reused across requests so the executor + # thread + drafter cache lifecycle isn't paid per-request. Each + # in-flight request opens its own session via + # :meth:`RemoteTransport.open_session`; the per-session handle is + # the actual ``DrafterTransport`` consumed by the spec loop. Closed + # in :meth:`close` (sends ``OP_SHUTDOWN`` to the drafter rank). + remote_drafter_transport: RemoteTransport | None = None + # Inter-target-rank TCP fanout for spec-decode int broadcasts. + # Allocated alongside the drafter wire on multi-target asymmetric + # placements (see :class:`TargetPeerFanout`); ``None`` for + # single-target / symmetric instances. The runner stores it so the + # spec-decode loop can sidestep ``mx.distributed.send`` / ``recv`` + # for inter-target int broadcasts -- those collide with the + # model's TP ``all_sum`` collectives on the JACCL backend and + # silently corrupt the int wire. + target_peer_fanout: object | None = None check_for_cancel_every: int = 50 _cancelled_tasks: set[TaskId] = field(default_factory=set, init=False) @@ -109,7 +218,37 @@ class SequentialGenerator(Engine): _maybe_cancel: list[TextGeneration] = field(default_factory=list, init=False) _all_tasks: dict[TaskId, TextGeneration] = field(default_factory=dict, init=False) _queue: deque[TextGeneration] = field(default_factory=deque, init=False) - _active: ( + # Rolling window of recently-observed drafter-acceptance fractions for + # adaptive K. Only populated when adaptive_draft_tokens is True. + _recent_acceptance: deque[float] = field( + default_factory=lambda: deque(maxlen=ADAPTIVE_K_WINDOW), + init=False, + ) + # Maximum number of in-flight tasks the runner will round-robin through + # in :meth:`step`. Set to 1 by ``builder.build`` whenever the runner + # owns a long-lived ``RemoteTransport`` (asymmetric pipelined drafter): + # the wire protocol assumes one in-flight prefill/forward session, so + # interleaving two target requests on the same socket would corrupt + # the drafter's per-request state. For all other configurations + # (no drafter, n-gram drafter, in-process model drafter where every + # ``mlx_generate`` call allocates its own draft KVCache) this defaults + # to ``EXO_MAX_CONCURRENT_REQUESTS`` and gives concurrent requests the + # cooperative-scheduling semantics the dispatcher always claimed but + # never delivered: prior to this field every spec-config runner pinned + # ``_active`` to a singular slot and slot 1's TTFT equalled slot 0's + # full completion time (measured 14s on a K=3 single-host n-gram bench + # in the PR #15 concurrency leg). + max_concurrent_tasks: int = 1 + # Currently in-flight tasks, keyed by ``TaskId`` for O(1) cancel/finish. + # Insertion order is the round-robin order; ``OrderedDict`` makes that + # preservation explicit (CPython dicts already preserve it but we want + # the contract to be load-bearing). Capped by ``max_concurrent_tasks``; + # ``step`` round-robins one ``next(gen)`` call per active task per + # tick. Each tuple is (task, mlx generator, response queue, parsed- + # output generator) -- the same shape the previous singular ``_active`` + # slot held, just multiplexed. + _active_tasks: OrderedDict[ + TaskId, tuple[ TextGeneration, # mlx generator that does work @@ -118,9 +257,15 @@ class SequentialGenerator(Engine): GeneratorQueue[GenerationResponse], # generator to get parsed outputs Iterator[GenerationChunk | None], - ] - | None - ) = field(default=None, init=False) + ], + ] = field(default_factory=OrderedDict, init=False) + # Tasks that failed during ``_build_generator`` or mid-stream. Drained + # by ``step`` so per-task failures surface as ``FinishedResponse`` to + # the caller without taking down the runner subprocess. We accept the + # rank-desync risk: ``_build_generator`` failures are deterministic + # in practice (config / per-request K mismatch) so all ranks fail + # together; any non-deterministic failure was already a desync hazard. + _pending_failed: list[TaskId] = field(default_factory=list, init=False) def warmup(self): self.check_for_cancel_every = warmup_inference( @@ -128,6 +273,7 @@ def warmup(self): tokenizer=self.tokenizer, group=self.group, model_id=self.model_id, + draft_model=self.draft_model, ) def submit( @@ -169,39 +315,126 @@ def step( ) -> Iterator[ tuple[TaskId, GenerationChunk | FinishedResponse | CancelledResponse] ]: - if self._active is None: + output: list[ + tuple[TaskId, GenerationChunk | CancelledResponse | FinishedResponse] + ] = [] + + # Top up the active set from the queue. ``agree_on_tasks`` is a + # collective op across the MLX group; we only call it when there + # might be new work to admit (active set has slack and queue is + # potentially non-empty after ``agree_on_tasks`` runs). Calling + # it on every tick is safe but wastes a collective when the + # active set is already full. + if len(self._active_tasks) < self.max_concurrent_tasks: self.agree_on_tasks() + self._admit_queued_tasks() + + # Drain failures recorded by ``_start_next`` (this tick or any + # prior tick that left them queued) so the runner loop marks + # them complete and proceeds with the next task instead of + # tearing down the subprocess (regression: K=8 ValueError took + # the target rank with it on 14:35:05). + while self._pending_failed: + output.append((self._pending_failed.pop(0), FinishedResponse())) + + if not self._active_tasks: + return itertools.chain( + iter(output), + map( + lambda task: (task, CancelledResponse()), + self._cancelled_tasks, + ), + ) - if self._queue: - self._start_next() - else: - return map( - lambda task: (task, CancelledResponse()), self._cancelled_tasks - ) + # Round-robin one ``next(gen)`` per active task. Each generator + # owns its own KV cache (``mlx_generate`` allocates fresh caches + # per request), so interleaving generators per-tick is safe -- the + # only shared state is the model weights themselves, which are + # read-only during forward. Snapshot the items so per-task + # exceptions can ``del self._active_tasks[task_id]`` mid-iteration + # without invalidating the loop. + for task_id, (task, gen, queue, output_generator) in list( + self._active_tasks.items() + ): + try: + response = next(gen) + queue.push(response) + # Observe drafter acceptance once the final stats arrive. We + # do this here (and not in mlx_generate) because the rolling + # buffer is owned by the generator and must persist across + # requests for adaptive K to converge. + if ( + self.adaptive_draft_tokens + and response.stats is not None + and response.stats.drafter_model_id is not None + and response.stats.generation_tokens > 0 + ): + fraction = ( + response.stats.accepted_draft_tokens + / response.stats.generation_tokens + ) + self._recent_acceptance.append(fraction) + # drain potentially many responses every time + while (parsed := next(output_generator, None)) is not None: + output.append((task_id, parsed)) - assert self._active is not None + except (StopIteration, PrefillCancelled): + output.append((task_id, FinishedResponse())) + del self._active_tasks[task_id] - task, gen, queue, output_generator = self._active - output: list[ - tuple[TaskId, GenerationChunk | CancelledResponse | FinishedResponse] - ] = [] - try: - response = next(gen) - queue.push(response) - # drain potentially many responses every time - while (parsed := next(output_generator, None)) is not None: - output.append((task.task_id, parsed)) + except Exception as e: + # ALWAYS log first. Without this, an exception silently + # swallowed on a non-root target rank presents to the + # operator as "rank 1 returned ready in 0.4 s with no + # tokens"; the actual error -- which may be a master + # divergence, an MLX collective desync, or a bad model + # weights load -- is invisible. Logging is unconditional + # because the multi-rank re-raise path below also relies + # on it (the supervisor records the message but not the + # traceback). + logger.opt(exception=True).error( + "generator.step raised; " + f"task_id={task_id} " + f"command_id={task.command_id} " + f"device_rank={self.device_rank} " + f"group_size={self.group.size() if self.group is not None else 1} " + f"exc={type(e).__name__}: {e}" + ) - except (StopIteration, PrefillCancelled): - output.append((task.task_id, FinishedResponse())) - self._active = None - if self._queue: - self._start_next() + # Multi-rank targets MUST re-raise. Any exception here + # (whether a request-level bug or a system-level MLX + # error) means this rank exited the generator without + # participating in the verify-forward TP collective the + # peer rank is now waiting on. Swallowing leaves the + # peer hung indefinitely; raising hands control to + # ``handle_generation_tasks`` -> supervisor -> + # ``RunnerFailed``. The peer's ``_kill_runner`` rule + # then tears down its own runner via the + # ``RunnerFailed``-on-peer trigger (see + # ``worker/plan.py``), the master rebuilds the instance + # via ``CreateRunner``, and the next request sees a + # fresh group. Total recovery is bounded by the + # supervisor escalation chain (~25 s), not "manual + # operator restart". + # + # Single-rank runners keep the legacy swallow path: a + # malformed request shouldn't crash the (only) runner + # and break unrelated concurrent tasks sharing the + # process. + if self.group is not None and self.group.size() > 1: + self._send_error(task, e) + del self._active_tasks[task_id] + raise - except Exception as e: - self._send_error(task, e) - self._active = None - raise + self._send_error(task, e) + del self._active_tasks[task_id] + output.append((task_id, FinishedResponse())) + + # Top up again if we just retired any task -- keeps slot 1's + # TTFT independent of slot 0's completion length, which is the + # whole point of ``max_concurrent_tasks > 1``. + if self._queue and len(self._active_tasks) < self.max_concurrent_tasks: + self._admit_queued_tasks() return filter( lambda chunk: ( @@ -213,13 +446,142 @@ def step( ), ) - def _start_next(self) -> None: - task = self._queue.popleft() + def _admit_queued_tasks(self) -> None: + """Top up ``_active_tasks`` from ``_queue``, batching prefill when possible. + + Cooperatively schedules eligible tasks through a single + :func:`batched_prefill` forward when ``EXO_BATCH_PREFILL`` is on + (default) and at least 2 tasks pass the eligibility filter + (``_batch_eligible_for_prefill``). Ineligible tasks (vision, + remote prefill, in-process model drafter, etc.) and any task + in a single-eligible-task admit cycle fall back to the + per-slot :meth:`_start_one` path. Eligibility is read at admit + time so a request that becomes ineligible mid-tick (e.g. + because ``EXO_BATCH_PREFILL`` was toggled) cleanly degrades. + + The function never raises; per-task setup failures are routed + through :meth:`_send_error` + ``_pending_failed`` (same + liveness contract as :meth:`_start_one`). + """ + if not self._queue: + return + + # Drain the queue up to the active-set slack, then partition by + # batch eligibility. We can't peek-without-pop because + # ``self._queue`` is a deque drained by the caller, so collect + # candidates first and re-route into ``_start_one`` if the + # batch path bails. + slack = self.max_concurrent_tasks - len(self._active_tasks) + candidates: list[TextGeneration] = [] + while self._queue and len(candidates) < slack: + candidates.append(self._queue.popleft()) + + if not candidates: + return + + batch_enabled = os.environ.get(EXO_BATCH_PREFILL, "1") != "0" + if not batch_enabled: + for task in candidates: + self._start_one(task) + return + + eligible: list[tuple[TextGeneration, mx.array, KVCacheType]] = [] + leftover: list[TextGeneration] = [] + for task in candidates: + prep = self._prepare_for_batch_prefill(task) + if prep is None: + leftover.append(task) + else: + eligible.append(prep) + + logger.debug( + f"_admit_queued_tasks candidates={len(candidates)} " + f"eligible={len(eligible)} leftover={len(leftover)} " + f"slack={slack} batch_enabled={batch_enabled}" + ) + + # Single-eligible: a batched forward of size 1 has no parallelism + # win and adds the PromptBatch + merge_caches overhead, so just + # take the per-slot path. + if len(eligible) < 2: + for task in candidates: + self._start_one(task) + return + + prompts = [tup[1] for tup in eligible] + caches = [tup[2] for tup in eligible] + try: - gen = self._build_generator(task) + tps, total = batched_prefill( + model=self.model, + prompt_tokens_list=prompts, + caches_list=caches, + ) + logger.info( + f"batched_prefill: {len(eligible)} slots, {total} tokens " + f"({tps:.1f} tok/s aggregate)" + ) + for task, prompt_tokens, cache in eligible: + self._emit_prefill_complete(task, prompt_tokens) + self._start_one(task, precomputed_target_cache=cache) + for task in leftover: + self._start_one(task) + return + except BatchedPrefillUnsupportedError: + logger.info( + "batched_prefill unsupported for this model/cache; " + "falling back to per-slot prefill" + ) + for task in candidates: + self._start_one(task) + return except Exception as e: + # Untyped failure: charge the error to every batched task so + # one bad request doesn't take the runner down. ``leftover`` + # tasks were not part of the failed batch and proceed + # normally on the per-slot path. + for task, _, _ in eligible: + self._send_error(task, e) + self._pending_failed.append(task.task_id) + for task in leftover: + self._start_one(task) + return + + def _start_one( + self, + task: TextGeneration, + *, + precomputed_target_cache: KVCacheType | None = None, + ) -> None: + """Build one slot's generator and add it to ``_active_tasks``. + + ``precomputed_target_cache`` is forwarded to ``mlx_generate`` to + skip its prefix-cache lookup + local prefill. Set by + :meth:`_admit_queued_tasks` after a batched prefill; ``None`` + otherwise. + """ + # Only forward ``precomputed_target_cache`` when it was set so + # existing test seams that monkeypatch ``_build_generator`` with + # the legacy ``(self, task)`` signature still work; the per-slot + # admit path (``precomputed_target_cache is None``) is the + # default and predates the batched-prefill seam. + try: + if precomputed_target_cache is None: + gen = self._build_generator(task) + else: + gen = self._build_generator( + task, precomputed_target_cache=precomputed_target_cache + ) + except Exception as e: + # Preserve runner liveness: surface the error to the client + # via ``_send_error`` and queue a ``FinishedResponse`` for + # ``step`` to drain on the next tick. The active set is + # unchanged so the next ``step`` either picks up the next + # queued task or returns idle (instead of asserting and + # crashing the subprocess). self._send_error(task, e) - raise + self._pending_failed.append(task.task_id) + return queue = GeneratorQueue[GenerationResponse]() if task.task_params.bench: @@ -236,7 +598,103 @@ def _start_next(self) -> None: self.model_id, task.task_params.tools, ) - self._active = (task, gen, queue, output_generator) + self._active_tasks[task.task_id] = (task, gen, queue, output_generator) + + def _batch_eligible_for_prefill(self, task: TextGeneration) -> bool: + """Return ``True`` when ``task`` can be co-prefilled with peers. + + V1 eligibility is narrow on purpose: only single-rank text-only + generation without remote prefill or an in-process model + drafter. The asymmetric pipelined drafter still qualifies + because ``draft_model`` is ``None`` on the target rank — the + drafter cache lives on the remote rank and is prefilled per- + session over the wire, independent of target prefill batching. + + Multi-rank target paths (TP/PP) are excluded because + :func:`pipeline_parallel_prefill`'s collective semantics need + per-slot driver loops; a follow-up can lift this once the + batched forward is folded into the pipeline driver. + """ + params = task.task_params + if self.group is not None and self.group.size() > 1: + return False + if params.images: + return False + if params.prefill_endpoint is not None: + return False + # In-process model drafter ("model" mode) needs a paired + # drafter prefill aligned to the target's offset; batching + # only the target without batching the drafter would desync + # them. The asymmetric drafter (``self.draft_model is None`` + # but ``remote_drafter_transport is not None``) is fine + # because its drafter prefill goes over the wire per-session. + return self.draft_model is None + + def _prepare_for_batch_prefill( + self, task: TextGeneration + ) -> tuple[TextGeneration, mx.array, KVCacheType] | None: + """Encode the prompt and allocate a fresh cache for batched prefill. + + Returns ``None`` when ``task`` is ineligible or when the + encoded prompt is too short to leave a decode-seed token + (length < 2). The encoding mirrors :func:`mlx_generate`'s + ``encode_prompt`` + ``fix_unmatched_think_end_tokens`` so the + cache offset agreed by ``batched_prefill`` matches what + ``mlx_generate`` later sees on the inner side of + ``precomputed_target_cache``. + """ + if not self._batch_eligible_for_prefill(task): + return None + try: + prompt_str = apply_chat_template(self.tokenizer, task.task_params) + prompt_tokens = encode_prompt(self.tokenizer, prompt_str) + prompt_tokens = fix_unmatched_think_end_tokens( + prompt_tokens, self.tokenizer + ) + except Exception: + # Encoding failure surfaces through the per-slot path so + # the existing ``_send_error`` plumbing reports it; we + # don't swallow it here. + logger.opt(exception=True).warning( + "Prompt encoding failed during batch-prefill prep; " + "falling back to per-slot path" + ) + return None + if int(prompt_tokens.size) < 2: + return None + try: + cache = make_kv_cache(self.model) + except Exception: + logger.opt(exception=True).warning( + "make_kv_cache failed during batch-prefill prep; " + "falling back to per-slot path" + ) + return None + return (task, prompt_tokens, cache) + + def _emit_prefill_complete( + self, task: TextGeneration, prompt_tokens: mx.array + ) -> None: + """Fire a single ``processed=total`` ``PrefillProgressChunk``. + + ``batched_prefill`` runs as one forward so per-chunk progress + events would mix slots. We elide intermediate progress and + emit a single completion event per slot at the end of the + batched forward so dashboards stop showing 0% prefill. + """ + if self.device_rank != 0: + return + total = int(prompt_tokens.size) + self.event_sender.send( + ChunkGenerated( + command_id=task.command_id, + chunk=PrefillProgressChunk( + model=self.model_id, + processed_tokens=total, + total_tokens=total, + ), + ) + ) def _send_error(self, task: TextGeneration, e: Exception) -> None: if self.device_rank == 0: @@ -251,7 +709,12 @@ def _send_error(self, task: TextGeneration, e: Exception) -> None: ) ) - def _build_generator(self, task: TextGeneration) -> Generator[GenerationResponse]: + def _build_generator( + self, + task: TextGeneration, + *, + precomputed_target_cache: KVCacheType | None = None, + ) -> Generator[GenerationResponse]: _check_for_debug_prompts(task.task_params) prompt = apply_chat_template(self.tokenizer, task.task_params) @@ -288,6 +751,17 @@ def on_generation_token() -> None: self.agree_on_tasks() + # Adaptive K (item 7): when enabled, recompute K from the rolling + # window of observed acceptance fractions. The configured value + # (`self.num_draft_tokens`) is the warmup fallback used until the + # window has enough data. + if self.adaptive_draft_tokens and self.num_draft_tokens is not None: + effective_num_draft_tokens: int | None = adaptive_num_draft_tokens( + list(self._recent_acceptance), fallback=self.num_draft_tokens + ) + else: + effective_num_draft_tokens = self.num_draft_tokens + return mlx_generate( model=self.model, tokenizer=self.tokenizer, @@ -300,9 +774,51 @@ def on_generation_token() -> None: group=self.group, vision_processor=self.vision_processor, draft_model=self.draft_model, + drafter_kv_prefix_cache=self.drafter_kv_prefix_cache, + drafter_model_id=self.draft_model_id, + num_draft_tokens=effective_num_draft_tokens, + drafter_min_output_tokens=self.drafter_min_output_tokens, + asymmetric_drafter_rank=self.drafter_rank_in_parent, + asymmetric_drafter_transport=self.remote_drafter_transport, + target_peer_fanout=self.target_peer_fanout, + precomputed_target_cache=precomputed_target_cache, ) def close(self) -> None: + if self.remote_drafter_transport is not None: + try: + self.remote_drafter_transport.shutdown() + except Exception: + # Drafter rank may already be gone (e.g. due to a + # parallel shutdown of the cluster); log and continue + # so target-side cleanup isn't blocked on a peer that + # can't ack. The shutdown call is idempotent so a + # later retry is harmless. + logger.opt(exception=True).warning( + "Drafter rank shutdown failed; continuing close" + ) + self.remote_drafter_transport = None + # Close every TCP socket the target-peer fanout owns (one per + # peer on rank 0, single rank-zero socket on peers). Inline + # the socket import + isinstance check to keep this module's + # top-level imports thin. ``OSError`` here is benign -- the + # peer may already have closed (e.g. supervisor SIGKILL chain) + # and we just want to free the local FDs before the runner + # exits. + if self.target_peer_fanout is not None: + from exo.worker.engines.mlx.utils_mlx import TargetPeerFanout as _Fanout + + if isinstance(self.target_peer_fanout, _Fanout): + import socket as _socket + + for sock in self.target_peer_fanout.peer_sockets.values(): + if isinstance(sock, _socket.socket): + with contextlib.suppress(OSError): + sock.close() + if isinstance(self.target_peer_fanout.rank_zero_socket, _socket.socket): + with contextlib.suppress(OSError): + self.target_peer_fanout.rank_zero_socket.close() + self.target_peer_fanout = None del self.model, self.tokenizer, self.group def serve_prefill(self, request: PrefillRequest, wfile: BinaryIO) -> None: diff --git a/src/exo/worker/runner/runner.py b/src/exo/worker/runner/runner.py index ac5d054808..c366bfa117 100644 --- a/src/exo/worker/runner/runner.py +++ b/src/exo/worker/runner/runner.py @@ -1,3 +1,4 @@ +import os import queue import threading import time @@ -58,6 +59,44 @@ from exo.worker.runner.bootstrap import logger PREFILL_PICKUP_TIMEOUT_SECONDS = 3 + +# Window the runner blocks on ``_work_queue`` after the initial task +# is admitted, looking for sibling burst-arrivals that should land in +# the same ``SequentialGenerator._admit_queued_tasks`` window so their +# prefills can be batched. +# +# Empirically (3-node TB-RDMA Big Brain, gemma-4-26b-a4b-it-4bit on +# smbpt, 2 concurrent client requests dispatched within microseconds +# at the bench harness): the master process records both +# ``Executing command: TextGeneration`` events 15-33ms apart, but +# they reach the runner subprocess's ``_work_queue`` 150-200ms apart +# because of libp2p pubsub fan-out + mp-channel hop from the worker +# process to the runner subprocess. The original 20ms default +# missed slot #2 by ~130ms and ``batched_prefill`` never fired. +# 200ms catches it reliably; the cost is +200ms TTFT for genuinely +# solo requests, but the burst-coalesce only runs ONCE per +# ``handle_generation_tasks`` entry (i.e. only when transitioning +# from RunnerReady -> RunnerRunning, not on every admit), so +# back-to-back requests on a warm instance pay this only on the +# first wave. Set ``EXO_BURST_COALESCE_MS=0`` to disable +# (per-slot prefill on every request). +EXO_BURST_COALESCE_MS = "EXO_BURST_COALESCE_MS" +DEFAULT_BURST_COALESCE_MS = 200 + + +def _parse_burst_coalesce_ms() -> int: + raw = os.environ.get(EXO_BURST_COALESCE_MS) + if raw is None: + return DEFAULT_BURST_COALESCE_MS + try: + value = int(raw) + except ValueError: + logger.warning( + f"{EXO_BURST_COALESCE_MS}={raw!r} is not a valid int; " + f"falling back to {DEFAULT_BURST_COALESCE_MS}ms" + ) + return DEFAULT_BURST_COALESCE_MS + return max(0, value) PREFILL_FINISH_TIMEOUT_SECONDS = 300 @@ -120,6 +159,13 @@ def __init__( self._prefill_server: PrefillServer | None = None self._prefill_server_port: int | None = None self._work_queue: queue.Queue[WorkItem] = queue.Queue() + # Slot for a non-generation item picked up by + # ``_coalesce_burst_generation_tasks`` -- consumed by the main + # loop in ``handle_generation_tasks`` before its next + # ``_work_queue.get_nowait()`` so the FIFO order between burst + # text-gens and a trailing ``Shutdown`` / ``PrefillTask`` / + # ``_TaskStreamClosed`` is preserved. + self._burst_deferred_item: WorkItem | None = None self._task_reader_thread: threading.Thread | None = None logger.info("runner created") @@ -326,11 +372,159 @@ def submit_generation(self, task: GenerationTask): self.active_tasks[task.task_id] = task self.generator.submit(task) + def _drain_pending_work_items(self, max_drain: int = 32) -> "ExitCode | None": + """Non-blocking drain of immediately-available ``_work_queue`` items. + + Called between every ``step()`` iteration in the main generation + loop. Submits ``GenerationTask`` siblings via the existing + ``submit_generation`` path so the next ``step()``'s + ``agree_on_tasks`` + ``_admit_queued_tasks`` sees them all in + the same admit window (this is what extends ``batched_prefill`` + coverage past the initial 2-slot burst -- e.g. concurrency=4 + where the 3rd and 4th slots straggle ~1s behind the first + pair). + + Specials end the drain and are handled in arrival order: + + * :class:`_TaskStreamClosed` -> return :attr:`ExitCode.Shutdown` + to break the main loop. + * :class:`PrefillTask` -> serve it (synchronous, blocks until + done) then return ``None`` so the main loop continues. + * :class:`Shutdown` -> shut the runner down and return + :attr:`ExitCode.Shutdown`. + + Returns ``None`` to signal "keep looping" (queue exhausted or + only generation tasks were drained), an ``ExitCode`` to signal + the main loop should exit. + + ``max_drain`` is a defensive bound. In practice the queue + carries 1-4 burst tasks at a time; the drain returns far + sooner via ``queue.Empty``. + """ + for _ in range(max_drain): + if self._burst_deferred_item is not None: + item = self._burst_deferred_item + self._burst_deferred_item = None + else: + try: + item = self._work_queue.get_nowait() + except queue.Empty: + return None + if isinstance(item, _TaskStreamClosed): + return ExitCode.Shutdown + if isinstance(item, PrefillTask): + self._serve_prefill(item) + # ``_serve_prefill`` is synchronous; we yield back to + # the main loop here so the next ``step()`` runs + # before we drain more items, matching the + # pre-refactor cadence where one ``PrefillTask`` per + # iteration was the maximum. + return None + if item.task_id in self.seen: + logger.warning("repeat task - potential error") + continue + self.seen.add(item.task_id) + match item: + case TextGeneration() | ImageGeneration() | ImageEdits(): + self.acknowledge_task(item) + self.submit_generation(item) + case Shutdown(): + self.shutdown(item) + return ExitCode.Shutdown + case _: + raise ValueError( + f"Received {item.__class__.__name__} outside of " + f"state machine in {self.current_status=}" + ) + return None + + def _coalesce_burst_generation_tasks(self, max_drain: int = 32) -> None: + """Pull pending ``GenerationTask`` items into the generator's queue. + + Called from :meth:`handle_generation_tasks` after the initial + ``submit_generation`` so the upcoming ``step()`` call admits the + full burst together. Stops at the first non-generation item + (``PrefillTask`` / ``_TaskStreamClosed`` / ``Shutdown``) and + stashes that item in :attr:`_burst_deferred_item` so the main + loop sees it before its next ``_work_queue.get_nowait()`` -- + re-queueing at the tail would race with the listener thread + and silently re-order ``Shutdown`` past burst tasks. + + After draining whatever is immediately available, blocks on the + queue for up to ``EXO_BURST_COALESCE_MS`` (default 20ms) to + catch sibling burst-arrivals whose libp2p delivery straggles + behind the first request -- without this, two concurrent + client requests reliably miss the same admit window because + only the first arrives before the runner reaches ``step()``. + + ``max_drain`` is a defensive bound so a saturated upstream + producer can't starve the first ``step()`` indefinitely; in + practice the work queue carries 1-2 burst-tasks at a time. + """ + budget_ms = _parse_burst_coalesce_ms() + deadline = time.monotonic() + budget_ms / 1000.0 if budget_ms > 0 else None + drained = 0 + start = time.monotonic() + for _ in range(max_drain): + try: + item = self._work_queue.get_nowait() + except queue.Empty: + if deadline is None: + break + remaining = deadline - time.monotonic() + if remaining <= 0: + break + try: + item = self._work_queue.get(timeout=remaining) + except queue.Empty: + break + if isinstance(item, TextGeneration | ImageGeneration | ImageEdits): + if item.task_id in self.seen: + continue + self.seen.add(item.task_id) + self.acknowledge_task(item) + self.submit_generation(item) + drained += 1 + continue + self._burst_deferred_item = item + break + elapsed_ms = (time.monotonic() - start) * 1000.0 + # ``info`` when we actually batched (drained>=1) so operators see the + # value the coalesce delivered; ``debug`` when nothing batched, so + # solo-request runners stay quiet. + if drained >= 1: + logger.info( + f"burst-coalesce drained={drained} budget_ms={budget_ms} " + f"elapsed_ms={elapsed_ms:.1f} " + f"deferred={self._burst_deferred_item is not None}" + ) + else: + logger.debug( + f"burst-coalesce drained=0 budget_ms={budget_ms} " + f"elapsed_ms={elapsed_ms:.1f} " + f"deferred={self._burst_deferred_item is not None}" + ) + def handle_generation_tasks(self, starting_task: GenerationTask): assert isinstance(self.current_status, RunnerReady) assert isinstance(self.generator, Engine) - logger.info(f"received chat request: {starting_task}") + # Log identifiers only. The full ``starting_task`` is a deep + # Pydantic model whose default ``__str__`` recursively repr's + # every field (including ``chat_template_messages`` and any + # nested token / image structures). On a multi-rank target + # placement the worker plans the same TextGeneration repeatedly + # while a runner is busy, so logging the full model on every + # entry has been observed to peg rank 0 inside ``list_repr`` / + # ``long_to_decimal_string`` for minutes (peak physical + # footprint ~300 GB) and prevent it from ever entering the + # model forward -- which the peer rank then deadlocks on inside + # the first TP collective. + logger.info( + "received chat request task_id=" + f"{starting_task.task_id} command_id={starting_task.command_id} " + f"task_type={starting_task.__class__.__name__}" + ) self.update_status(RunnerRunning()) logger.info("runner running") self.acknowledge_task(starting_task) @@ -338,6 +532,20 @@ def handle_generation_tasks(self, starting_task: GenerationTask): self.submit_generation(starting_task) + # Coalesce burst-arrivals: drain TextGeneration / ImageGeneration / + # ImageEdits items already sitting in ``_work_queue`` and submit + # them BEFORE the first ``step()``. Without this, two concurrent + # client requests that arrive within a few ms see the runner + # admit task #1 alone (its prefill starts on the very first + # ``step()``) and task #2 only joins on the next iteration -- + # which defeats batched-prefill admission entirely (the + # ``_admit_queued_tasks`` candidate list never has B>=2 tasks). + # Non-task items (PrefillTask / _TaskStreamClosed / Shutdown) + # are left in the queue so the main loop's match block handles + # them in order; we stop draining at the first non-task item to + # preserve queue ordering. + self._coalesce_burst_generation_tasks() + while self.active_tasks: results = self.generator.step() @@ -355,30 +563,25 @@ def handle_generation_tasks(self, starting_task: GenerationTask): for task_id in finished: self.active_tasks.pop(task_id, None) - try: - item = self._work_queue.get_nowait() - except queue.Empty: - continue - if isinstance(item, _TaskStreamClosed): - return ExitCode.Shutdown - if isinstance(item, PrefillTask): - self._serve_prefill(item) - continue - if item.task_id in self.seen: - logger.warning("repeat task - potential error") - continue - self.seen.add(item.task_id) - match item: - case TextGeneration() | ImageGeneration() | ImageEdits(): - self.acknowledge_task(item) - self.submit_generation(item) - case Shutdown(): - self.shutdown(item) - return ExitCode.Shutdown - case _: - raise ValueError( - f"Received {item.__class__.__name__} outside of state machine in {self.current_status=}" - ) + # Drain ALL immediately-available items so concurrent + # burst-arrivals that landed during the previous + # ``step()`` (e.g. slots 3/4 of a concurrency=4 wave that + # arrived behind slots 1/2 by libp2p straggle) are + # submitted before the NEXT ``step()`` runs + # ``agree_on_tasks`` + ``_admit_queued_tasks``. Without + # this, the original code drained one item per iteration, + # so the second admit cycle still saw a single candidate + # and fell through to per-slot prefill -- we lose + # batched-prefill on every slot beyond the first wave. + # + # Specials (``_TaskStreamClosed`` / ``PrefillTask`` / + # ``Shutdown``) terminate the drain and are handled in + # arrival order. The ``_burst_deferred_item`` slot is + # checked first for FIFO preservation against the entry- + # time burst-coalesce. + exit_code = self._drain_pending_work_items() + if exit_code is not None: + return exit_code self.update_status(RunnerReady(prefill_server_port=self._prefill_server_port)) logger.info("runner ready") diff --git a/src/exo/worker/runner/supervisor.py b/src/exo/worker/runner/supervisor.py index 4b64a4e9df..ae2508e7e2 100644 --- a/src/exo/worker/runner/supervisor.py +++ b/src/exo/worker/runner/supervisor.py @@ -15,6 +15,7 @@ from loguru import logger from exo.shared.types.chunks import ErrorChunk +from exo.shared.types.common import ModelId from exo.shared.types.events import ( ChunkGenerated, Event, @@ -37,6 +38,7 @@ RunnerFailed, RunnerIdle, RunnerLoading, + RunnerReady, RunnerRunning, RunnerShuttingDown, RunnerStatus, @@ -53,7 +55,12 @@ @dataclass(eq=False) class RunnerSupervisor: - shard_metadata: ShardMetadata + # ``None`` when ``bound_instance.is_drafter_rank`` is true: the drafter + # rank has no shard (it serves the full drafter model, not a slice of + # the target). Use the ``model_id`` property instead of reaching + # through ``shard_metadata.model_card`` so the same access pattern + # works for target and drafter runners. + shard_metadata: ShardMetadata | None bound_instance: BoundInstance runner_process: mp.Process initialize_timeout: float @@ -96,7 +103,12 @@ def create( daemon=True, ) - shard_metadata = bound_instance.bound_shard + # Drafter ranks have no shard (they own the full drafter model); + # only target ranks slice the model into shards. Use ``model_id`` + # for logging so both code paths share the same surface. + shard_metadata = ( + None if bound_instance.is_drafter_rank else bound_instance.bound_shard + ) self = cls( bound_instance=bound_instance, @@ -109,19 +121,39 @@ def create( _event_sender=event_sender, ) logger.info( - "Created runner supervisor " - f"{self._runner_context()} model_id={self.shard_metadata.model_card.model_id}" + f"Created runner supervisor {self._runner_context()} " + f"model_id={self.model_id}" ) return self + @property + def model_id(self) -> ModelId: + """Model loaded by the supervised runner. + + For target ranks this is the sharded model ID from + ``shard_metadata``; for drafter ranks it is the drafter model + ID from ``DrafterPlacement``. The two callers that previously + reached through ``shard_metadata.model_card.model_id`` only + needed the model id for logging / error chunks, both of which + also make sense for the drafter rank. + """ + if self.shard_metadata is not None: + return self.shard_metadata.model_card.model_id + placement = self.bound_instance.instance.drafter_placement + assert placement is not None, ( + "supervisor with no shard_metadata must be on a drafter rank " + "but its instance has no DrafterPlacement; this should have " + "been validated by BoundInstance" + ) + return placement.drafter_model_id + async def run(self): self.runner_process.start() self._started_at = current_time() logger.info( - "Runner process started " - f"{self._runner_context()} pid={self.runner_process.pid} " - f"model_id={self.shard_metadata.model_card.model_id}" + f"Runner process started {self._runner_context()} " + f"pid={self.runner_process.pid} model_id={self.model_id}" ) try: async with self._tg as tg: @@ -129,8 +161,8 @@ async def run(self): tg.start_soon(self._forward_events) finally: logger.info( - "Runner supervisor shutting down " - f"{self._runner_context()} pid={self.runner_process.pid} " + f"Runner supervisor shutting down {self._runner_context()} " + f"model_id={self.model_id} pid={self.runner_process.pid} " f"rss_mb={self._runner_rss_mb()}" ) if not self._cancel_watch_runner.cancel_called: @@ -229,6 +261,19 @@ async def start_task(self, task: Task): self.in_progress.pop(task.task_id, None) logger.warning(f"Task {task} dropped, runner closed communication.") return + # Generation tasks (Text/Image/Edits) on a warmed-up runner do not need + # the per-task ack-wait gate: the runner state machine accepts them + # in any order while ``RunnerReady``/``RunnerRunning``, and waiting + # for ack here serialises worker->runner dispatch one task at a time. + # This caps batched-prefill (in ``SequentialGenerator``) at B=2 even + # when the bench fires conc=4: slot #3 only ships after the runner + # acks slot #2, which only happens after batched_prefill completes. + # Lifecycle tasks (LoadModel, StartWarmup, ConnectToGroup, Shutdown, + # CancelTask) keep the gate so state transitions stay ordered. + is_generation_task = isinstance(task, (TextGeneration, ImageGeneration, ImageEdits)) + runner_is_warm = isinstance(self.status, (RunnerReady, RunnerRunning)) + if is_generation_task and runner_is_warm: + return await event.wait() async def cancel_task(self, task_id: TaskId): @@ -339,7 +384,7 @@ async def _check_runner(self, e: Exception) -> None: ChunkGenerated( command_id=task.command_id, chunk=ErrorChunk( - model=self.shard_metadata.model_card.model_id, + model=self.model_id, error_message=( "Runner shutdown before completing command " f"({cause})" diff --git a/src/exo/worker/tests/unittests/test_mlx/test_drafter_abstraction.py b/src/exo/worker/tests/unittests/test_mlx/test_drafter_abstraction.py new file mode 100644 index 0000000000..f399b44c73 --- /dev/null +++ b/src/exo/worker/tests/unittests/test_mlx/test_drafter_abstraction.py @@ -0,0 +1,333 @@ +"""Tests for the ``Drafter`` abstraction. + +These cover the pure-Python pieces - mode resolution, n-gram suffix +matching, and the spec-loop accept arithmetic - so they don't need MLX +weights or a GPU. End-to-end correctness with a real model is exercised +by the smoke + bench scripts in ``scripts/``. +""" + +from __future__ import annotations + +import pytest + +from exo.worker.engines.mlx.generator.drafter import ( + ALL_DRAFT_MODES, + EXO_DRAFT_MODE_ENV, + DraftMode, + EagleDrafter, + LookaheadDrafter, + NgramDrafter, + NoSpecDrafter, + make_drafter, + parse_draft_mode, + resolve_draft_mode, +) + + +def test_all_draft_modes_match_literal() -> None: + """``ALL_DRAFT_MODES`` must be the runtime mirror of the ``DraftMode`` Literal.""" + assert ALL_DRAFT_MODES == ( + "model", + "pipelined", + "ngram", + "eagle", + "lookahead", + "none", + ) + + +def test_eagle_drafter_scaffold_raises_on_stream() -> None: + """``EagleDrafter`` is a scaffolding stub; ``stream`` must fail loudly. + + The factory dispatch + ``Drafter`` protocol shape are the durable + contract here; the actual auxiliary-head loop is intentionally not + implemented yet. A future PR fills this in. + """ + drafter = make_drafter( + mode="eagle", + num_draft_tokens=3, + draft_model=None, + draft_cache=None, + ) + assert isinstance(drafter, EagleDrafter) + assert drafter.mode == "eagle" + assert drafter.num_draft_tokens == 3 + with pytest.raises(NotImplementedError, match="EagleDrafter is a scaffolding"): + # ``stream`` is a generator function; ``next()`` triggers the body. + next( + drafter.stream( + model=object(), # type: ignore[arg-type] + tokenizer=object(), # type: ignore[arg-type] + prompt=object(), # type: ignore[arg-type] + context_tokens=[], + prompt_cache=[], + max_tokens=1, + sampler=lambda x: x, + logits_processors=[], + ) + ) + + +def test_lookahead_drafter_scaffold_raises_on_stream() -> None: + """``LookaheadDrafter`` is a scaffolding stub; ``stream`` must fail loudly.""" + drafter = make_drafter( + mode="lookahead", + num_draft_tokens=3, + draft_model=None, + draft_cache=None, + ) + assert isinstance(drafter, LookaheadDrafter) + assert drafter.mode == "lookahead" + assert drafter.num_draft_tokens == 3 + assert drafter.window_size == 5 + assert drafter.ngram_size == 3 + with pytest.raises(NotImplementedError, match="LookaheadDrafter is a scaffolding"): + next( + drafter.stream( + model=object(), # type: ignore[arg-type] + tokenizer=object(), # type: ignore[arg-type] + prompt=object(), # type: ignore[arg-type] + context_tokens=[], + prompt_cache=[], + max_tokens=1, + sampler=lambda x: x, + logits_processors=[], + ) + ) + + +@pytest.mark.parametrize( + ("raw", "default", "expected"), + [ + (None, "model", "model"), + (None, "none", "none"), + ("model", "none", "model"), + ("MODEL", "none", "model"), + (" ngram ", "none", "ngram"), + ("pipelined", "none", "pipelined"), + ("PIPELINED", "model", "pipelined"), + ("none", "model", "none"), + ("garbage", "model", "model"), + ("garbage", "none", "none"), + ], +) +def test_parse_draft_mode( + raw: str | None, default: DraftMode, expected: DraftMode +) -> None: + assert parse_draft_mode(raw, default) == expected + + +def test_parse_draft_mode_warns_on_unknown_value( + monkeypatch: pytest.MonkeyPatch, + caplog: pytest.LogCaptureFixture, +) -> None: + monkeypatch.delenv(EXO_DRAFT_MODE_ENV, raising=False) + parse_draft_mode("totally-bogus", "none") + # Loguru-driven logger doesn't pipe to caplog by default; just assert + # the call didn't raise. The warning is documented in the docstring. + + +class TestResolveDraftMode: + def test_explicit_request_mode_wins_over_use_drafter(self) -> None: + # Per-request draft_mode beats the use_drafter shortcut. + assert ( + resolve_draft_mode( + has_drafter_model=True, + request_use_drafter=False, + request_draft_mode="ngram", + ) + == "ngram" + ) + + def test_use_drafter_false_maps_to_none(self) -> None: + assert ( + resolve_draft_mode( + has_drafter_model=True, + request_use_drafter=False, + request_draft_mode=None, + ) + == "none" + ) + + def test_default_with_drafter_loaded(self, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.delenv(EXO_DRAFT_MODE_ENV, raising=False) + assert ( + resolve_draft_mode( + has_drafter_model=True, + request_use_drafter=None, + request_draft_mode=None, + ) + == "model" + ) + + def test_default_without_drafter_loaded( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + monkeypatch.delenv(EXO_DRAFT_MODE_ENV, raising=False) + assert ( + resolve_draft_mode( + has_drafter_model=False, + request_use_drafter=None, + request_draft_mode=None, + ) + == "none" + ) + + def test_env_override_with_drafter_loaded( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + monkeypatch.setenv(EXO_DRAFT_MODE_ENV, "ngram") + assert ( + resolve_draft_mode( + has_drafter_model=True, + request_use_drafter=None, + request_draft_mode=None, + ) + == "ngram" + ) + + def test_model_mode_without_drafter_demotes_to_none( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + monkeypatch.delenv(EXO_DRAFT_MODE_ENV, raising=False) + assert ( + resolve_draft_mode( + has_drafter_model=False, + request_use_drafter=None, + request_draft_mode="model", + ) + == "none" + ) + + def test_pipelined_mode_without_drafter_demotes_to_none( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + # Same misconfiguration safety net as ``"model"``: requesting + # ``"pipelined"`` without a loaded drafter must fall back to + # ``"none"`` rather than hard-failing or producing a no-op + # drafter that silently degrades throughput. + monkeypatch.delenv(EXO_DRAFT_MODE_ENV, raising=False) + assert ( + resolve_draft_mode( + has_drafter_model=False, + request_use_drafter=None, + request_draft_mode="pipelined", + ) + == "none" + ) + + def test_pipelined_mode_with_drafter_loaded_passes_through( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + monkeypatch.delenv(EXO_DRAFT_MODE_ENV, raising=False) + assert ( + resolve_draft_mode( + has_drafter_model=True, + request_use_drafter=None, + request_draft_mode="pipelined", + ) + == "pipelined" + ) + + def test_explicit_none_with_drafter_loaded( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + monkeypatch.delenv(EXO_DRAFT_MODE_ENV, raising=False) + assert ( + resolve_draft_mode( + has_drafter_model=True, + request_use_drafter=None, + request_draft_mode="none", + ) + == "none" + ) + + +class TestNgramDrafterPropose: + """The proposer is pure list logic; no MLX involved.""" + + def test_returns_empty_when_context_is_too_short(self) -> None: + drafter = NgramDrafter(num_draft_tokens=4, min_match=2, max_match=4) + # Need at least min_match + 1 tokens for a match to be possible + # (suffix of length min_match plus one earlier match position). + assert drafter.propose([1, 2], 4) == [] + + def test_returns_empty_when_no_match(self) -> None: + drafter = NgramDrafter(num_draft_tokens=4, min_match=2, max_match=4) + # Tokens are unique - no suffix appears earlier. + assert drafter.propose([10, 20, 30, 40, 50], 4) == [] + + def test_finds_simple_repetition(self) -> None: + # Suffix [1, 2] appears at start; following tokens are [3, 4]. + drafter = NgramDrafter(num_draft_tokens=4, min_match=2, max_match=4) + assert drafter.propose([1, 2, 3, 4, 1, 2], 2) == [3, 4] + + def test_proposes_up_to_k_tokens(self) -> None: + drafter = NgramDrafter(num_draft_tokens=10, min_match=2, max_match=4) + # K=2 caps proposal to 2 even though 4 follow the match. + assert drafter.propose([1, 2, 3, 4, 5, 6, 1, 2], 2) == [3, 4] + + def test_prefers_longer_match(self) -> None: + # Suffix [2, 3] appears at index 1; suffix [1, 2, 3] appears at + # index 0 (length 3, longer). Should prefer the longer one and + # return [4, 5] (the tokens after the longer match). + drafter = NgramDrafter(num_draft_tokens=4, min_match=2, max_match=4) + ctx = [1, 2, 3, 4, 5, 6, 7, 1, 2, 3] + # Last 3 tokens are [1, 2, 3]; longest match starts at 0. + # Following tokens at start were [4, 5]. + assert drafter.propose(ctx, 4)[:2] == [4, 5] + + def test_prefers_recent_match_when_tied(self) -> None: + # Two matches of suffix [9, 9] at same length; prefer the more + # recent one (locality of reference). + drafter = NgramDrafter(num_draft_tokens=2, min_match=2, max_match=2) + ctx = [9, 9, 1, 9, 9, 2, 9, 9] + # Recent match at index 3, followed by [2]. Earliest match at 0, + # followed by [1]. Prefer recent -> [2]. + result = drafter.propose(ctx, 1) + assert result == [2] + + def test_returns_empty_for_zero_k(self) -> None: + drafter = NgramDrafter(num_draft_tokens=4, min_match=2, max_match=4) + assert drafter.propose([1, 2, 3, 1, 2], 0) == [] + + def test_validates_constructor_args(self) -> None: + with pytest.raises(ValueError, match="num_draft_tokens"): + NgramDrafter(num_draft_tokens=0) + with pytest.raises(ValueError, match="min_match"): + NgramDrafter(num_draft_tokens=2, min_match=0) + with pytest.raises(ValueError, match="max_match"): + NgramDrafter(num_draft_tokens=2, min_match=4, max_match=2) + + +def test_drafter_modes_match_implementation_class() -> None: + """Each concrete drafter exposes the right ``mode`` literal.""" + assert NoSpecDrafter().mode == "none" + assert NgramDrafter(num_draft_tokens=2).mode == "ngram" + + +def test_make_drafter_dispatches_correctly() -> None: + none_drafter = make_drafter( + mode="none", num_draft_tokens=4, draft_model=None, draft_cache=None + ) + assert isinstance(none_drafter, NoSpecDrafter) + ngram_drafter = make_drafter( + mode="ngram", num_draft_tokens=4, draft_model=None, draft_cache=None + ) + assert isinstance(ngram_drafter, NgramDrafter) + + +def test_make_drafter_rejects_model_without_pieces() -> None: + with pytest.raises(ValueError, match="draft_model"): + make_drafter( + mode="model", num_draft_tokens=4, draft_model=None, draft_cache=None + ) + + +def test_ngram_drafter_proposal_caps_at_k() -> None: + # The spec loop tops up ``K = min(max_tokens - ntoks, num_draft_tokens)`` + # before each round; the proposer must respect that cap so we don't + # overrun ``max_tokens`` in the verify forward. + drafter = NgramDrafter(num_draft_tokens=10, min_match=2, max_match=4) + result = drafter.propose([1, 2, 3, 4, 1, 2], 3) + assert len(result) <= 3 diff --git a/src/exo/worker/tests/unittests/test_mlx/test_drafter_builder.py b/src/exo/worker/tests/unittests/test_mlx/test_drafter_builder.py index b1600b8911..67218d78cd 100644 --- a/src/exo/worker/tests/unittests/test_mlx/test_drafter_builder.py +++ b/src/exo/worker/tests/unittests/test_mlx/test_drafter_builder.py @@ -29,7 +29,11 @@ ) -def _build_mlx_builder(*, draft_model: Model | None) -> MlxBuilder: +def _build_mlx_builder( + *, + draft_model: Model | None, + draft_model_id: ModelId | None = None, +) -> MlxBuilder: fake_tokenizer = MagicMock(spec=TokenizerWrapper) fake_tokenizer.has_tool_calling = False fake_tokenizer.tool_call_start = None @@ -45,6 +49,7 @@ def _build_mlx_builder(*, draft_model: Model | None) -> MlxBuilder: group=None, vision_processor=None, draft_model=draft_model, + draft_model_id=draft_model_id, ) @@ -73,10 +78,38 @@ def test_mlx_builder_forces_sequential_when_drafter_loaded( """When a drafter model is present, BatchGenerator can't use it, so we must fall back to SequentialGenerator regardless of EXO_NO_BATCH.""" monkeypatch.delenv("EXO_NO_BATCH", raising=False) + monkeypatch.delenv("EXO_NUM_DRAFT_TOKENS", raising=False) + monkeypatch.delenv("EXO_DRAFTER_MIN_OUTPUT_TOKENS", raising=False) fake_drafter = cast(Model, MagicMock()) - builder = _build_mlx_builder(draft_model=fake_drafter) + drafter_id = ModelId("mlx-community/test-drafter") + builder = _build_mlx_builder(draft_model=fake_drafter, draft_model_id=drafter_id) engine = builder.build() assert isinstance(engine, SequentialGenerator) assert engine.draft_model is fake_drafter + assert engine.draft_model_id == drafter_id + # Defaults should be applied so dashboards see the actual K in use. + assert engine.num_draft_tokens is not None and engine.num_draft_tokens >= 2 + assert ( + engine.drafter_min_output_tokens is not None + and engine.drafter_min_output_tokens > 0 + ) + + +def test_mlx_builder_honours_env_overrides_for_drafter_tuning( + monkeypatch: pytest.MonkeyPatch, +) -> None: + monkeypatch.setenv("EXO_NUM_DRAFT_TOKENS", "7") + monkeypatch.setenv("EXO_DRAFTER_MIN_OUTPUT_TOKENS", "32") + fake_drafter = cast(Model, MagicMock()) + builder = _build_mlx_builder( + draft_model=fake_drafter, + draft_model_id=ModelId("mlx-community/test-drafter"), + ) + + engine = builder.build() + + assert isinstance(engine, SequentialGenerator) + assert engine.num_draft_tokens == 7 + assert engine.drafter_min_output_tokens == 32 diff --git a/src/exo/worker/tests/unittests/test_mlx/test_drafter_loader.py b/src/exo/worker/tests/unittests/test_mlx/test_drafter_loader.py index aef0f0789d..02fa5a39e0 100644 --- a/src/exo/worker/tests/unittests/test_mlx/test_drafter_loader.py +++ b/src/exo/worker/tests/unittests/test_mlx/test_drafter_loader.py @@ -3,10 +3,13 @@ These tests exercise the policy-only branches of drafter loading so they can run in CI without GPUs or downloaded model weights: -- Cards with no drafter return ``None``. +- Cards with no drafters return ``None``. - Drafter weights missing from disk falls back to ``None`` (warned, not errored). - ``EXO_DISABLE_DRAFTER`` short-circuits even when weights are present. +- ``EXO_DRAFTER_PREFERENCE`` picks the right drafter from the candidate list + (fastest = head, highest_acceptance = tail), and on-disk drafters are + preferred over not-yet-downloaded ones. The "actually call ``mlx_lm.utils.load_model``" branch is exercised by the end-to-end smoke harness, not unit tests. @@ -23,7 +26,7 @@ from exo.worker.engines.mlx.types import Model -def _card_with_drafter(drafter_id: ModelId | None) -> ModelCard: +def _card_with_drafters(drafter_ids: list[ModelId]) -> ModelCard: return ModelCard( model_id=ModelId("mlx-community/test-target"), storage_size=Memory.from_gb(1.0), @@ -31,15 +34,15 @@ def _card_with_drafter(drafter_id: ModelId | None) -> ModelCard: hidden_size=768, supports_tensor=True, tasks=["TextGeneration"], # pyright: ignore[reportArgumentType] - drafter_model_id=drafter_id, + drafter_model_ids=drafter_ids, ) -def test_maybe_load_drafter_returns_none_when_no_drafter_declared( +def test_maybe_load_drafter_returns_none_when_no_drafters_declared( monkeypatch: pytest.MonkeyPatch, ) -> None: monkeypatch.delenv(utils_mlx.EXO_DISABLE_DRAFTER_ENV, raising=False) - card = _card_with_drafter(None) + card = _card_with_drafters([]) def fail_resolve(*_args: object, **_kwargs: object) -> Path | None: raise AssertionError("resolve_existing_model should not be called") @@ -53,7 +56,8 @@ def test_maybe_load_drafter_returns_none_when_drafter_weights_missing( monkeypatch: pytest.MonkeyPatch, ) -> None: monkeypatch.delenv(utils_mlx.EXO_DISABLE_DRAFTER_ENV, raising=False) - card = _card_with_drafter(ModelId("mlx-community/missing-drafter")) + monkeypatch.delenv(utils_mlx.EXO_DRAFTER_PREFERENCE_ENV, raising=False) + card = _card_with_drafters([ModelId("mlx-community/missing-drafter")]) def missing_resolve(_model_id: ModelId) -> Path | None: return None @@ -72,7 +76,7 @@ def test_maybe_load_drafter_disabled_by_env_skips_filesystem_check( monkeypatch: pytest.MonkeyPatch, tmp_path: Path ) -> None: monkeypatch.setenv(utils_mlx.EXO_DISABLE_DRAFTER_ENV, "1") - card = _card_with_drafter(ModelId("mlx-community/some-drafter")) + card = _card_with_drafters([ModelId("mlx-community/some-drafter")]) def fail_resolve(*_args: object, **_kwargs: object) -> Path | None: raise AssertionError("resolve_existing_model must not run when disabled") @@ -87,7 +91,8 @@ def test_maybe_load_drafter_swallows_load_errors( ) -> None: """A drafter present on disk that fails to load must not break the target.""" monkeypatch.delenv(utils_mlx.EXO_DISABLE_DRAFTER_ENV, raising=False) - card = _card_with_drafter(ModelId("mlx-community/broken-drafter")) + monkeypatch.delenv(utils_mlx.EXO_DRAFTER_PREFERENCE_ENV, raising=False) + card = _card_with_drafters([ModelId("mlx-community/broken-drafter")]) def fixed_resolve(_model_id: ModelId) -> Path | None: return tmp_path @@ -106,7 +111,8 @@ def test_maybe_load_drafter_returns_loaded_model_on_success( monkeypatch: pytest.MonkeyPatch, tmp_path: Path ) -> None: monkeypatch.delenv(utils_mlx.EXO_DISABLE_DRAFTER_ENV, raising=False) - card = _card_with_drafter(ModelId("mlx-community/fake-drafter")) + monkeypatch.delenv(utils_mlx.EXO_DRAFTER_PREFERENCE_ENV, raising=False) + card = _card_with_drafters([ModelId("mlx-community/fake-drafter")]) def fixed_resolve(_model_id: ModelId) -> Path | None: return tmp_path @@ -127,4 +133,63 @@ def noop_eval(*_args: object, **_kwargs: object) -> None: monkeypatch.setattr(utils_mlx.mx, "eval", noop_eval) result = utils_mlx._maybe_load_drafter(card) # pyright: ignore[reportPrivateUsage] - assert result is cast(Model, sentinel) + assert result is not None + drafter_id, drafter_model = result + assert drafter_id == ModelId("mlx-community/fake-drafter") + assert drafter_model is cast(Model, sentinel) + + +def test_select_drafter_id_default_is_fastest( + monkeypatch: pytest.MonkeyPatch, tmp_path: Path +) -> None: + """When all candidates are on disk and preference is 'fastest' (default), + return the head of the candidate list (smallest by convention).""" + + def resolve_all_on_disk(_model_id: ModelId) -> Path | None: + return tmp_path + + monkeypatch.setattr(utils_mlx, "resolve_existing_model", resolve_all_on_disk) + candidates = [ + ModelId("mlx-community/e2b-drafter"), + ModelId("mlx-community/e4b-drafter"), + ] + chosen = utils_mlx._select_drafter_id(candidates, "fastest") # pyright: ignore[reportPrivateUsage] + assert chosen == ModelId("mlx-community/e2b-drafter") + + +def test_select_drafter_id_highest_acceptance_picks_tail( + monkeypatch: pytest.MonkeyPatch, tmp_path: Path +) -> None: + def resolve_all_on_disk(_model_id: ModelId) -> Path | None: + return tmp_path + + monkeypatch.setattr(utils_mlx, "resolve_existing_model", resolve_all_on_disk) + candidates = [ + ModelId("mlx-community/e2b-drafter"), + ModelId("mlx-community/e4b-drafter"), + ] + chosen = utils_mlx._select_drafter_id(candidates, "highest_acceptance") # pyright: ignore[reportPrivateUsage] + assert chosen == ModelId("mlx-community/e4b-drafter") + + +def test_select_drafter_id_prefers_on_disk( + monkeypatch: pytest.MonkeyPatch, tmp_path: Path +) -> None: + """If the user prefers e4b but only e2b is on disk, fall back to e2b + rather than logging a 'weights missing' warning the user didn't cause.""" + e2b = ModelId("mlx-community/e2b-drafter") + e4b = ModelId("mlx-community/e4b-drafter") + + def resolve_only_e2b(model_id: ModelId) -> Path | None: + return tmp_path if model_id == e2b else None + + monkeypatch.setattr(utils_mlx, "resolve_existing_model", resolve_only_e2b) + chosen = utils_mlx._select_drafter_id([e2b, e4b], "highest_acceptance") # pyright: ignore[reportPrivateUsage] + assert chosen == e2b + + +def test_drafter_preference_unknown_value_falls_back_to_auto( + monkeypatch: pytest.MonkeyPatch, +) -> None: + monkeypatch.setenv(utils_mlx.EXO_DRAFTER_PREFERENCE_ENV, "totally-bogus") + assert utils_mlx._drafter_preference() == "auto" # pyright: ignore[reportPrivateUsage] diff --git a/src/exo/worker/tests/unittests/test_mlx/test_drafter_tuning.py b/src/exo/worker/tests/unittests/test_mlx/test_drafter_tuning.py new file mode 100644 index 0000000000..bdcd650eb0 --- /dev/null +++ b/src/exo/worker/tests/unittests/test_mlx/test_drafter_tuning.py @@ -0,0 +1,173 @@ +"""Tests for drafter tuning knobs (num_draft_tokens, short-skip, env helpers). + +End-to-end MLX inference can't run in unit tests (no GPUs/weights), so we +test the *policy* helpers that decide whether speculative decoding is active +and how many draft tokens to issue per round. +""" + +from typing import cast + +import pytest + +from exo.worker.engines.mlx.generator.generate import resolve_speculative_decoding +from exo.worker.engines.mlx.types import Model +from exo.worker.runner.llm_inference.batch_generator import ( + DEFAULT_DRAFTER_MIN_OUTPUT_TOKENS, + DEFAULT_NUM_DRAFT_TOKENS, + EXO_DRAFTER_MIN_OUTPUT_TOKENS, + EXO_NUM_DRAFT_TOKENS, + adaptive_num_draft_tokens, + parse_env_int, +) + + +def test_parse_env_int_returns_default_when_unset( + monkeypatch: pytest.MonkeyPatch, +) -> None: + monkeypatch.delenv("EXO_FAKE_VAR_FOR_TEST", raising=False) + assert parse_env_int("EXO_FAKE_VAR_FOR_TEST", 5) == 5 + + +def test_parse_env_int_clamps_to_minimum( + monkeypatch: pytest.MonkeyPatch, +) -> None: + monkeypatch.setenv("EXO_FAKE_VAR_FOR_TEST", "0") + assert parse_env_int("EXO_FAKE_VAR_FOR_TEST", 5, minimum=1) == 1 + + +def test_parse_env_int_falls_back_on_garbage( + monkeypatch: pytest.MonkeyPatch, +) -> None: + monkeypatch.setenv("EXO_FAKE_VAR_FOR_TEST", "not-a-number") + assert parse_env_int("EXO_FAKE_VAR_FOR_TEST", 5) == 5 + + +def test_parse_env_int_accepts_valid_value( + monkeypatch: pytest.MonkeyPatch, +) -> None: + monkeypatch.setenv("EXO_FAKE_VAR_FOR_TEST", "9") + assert parse_env_int("EXO_FAKE_VAR_FOR_TEST", 5) == 9 + + +def test_default_constants_are_sane() -> None: + assert DEFAULT_NUM_DRAFT_TOKENS >= 2 + assert DEFAULT_DRAFTER_MIN_OUTPUT_TOKENS > 0 + assert EXO_NUM_DRAFT_TOKENS == "EXO_NUM_DRAFT_TOKENS" + assert EXO_DRAFTER_MIN_OUTPUT_TOKENS == "EXO_DRAFTER_MIN_OUTPUT_TOKENS" + + +def _fake_model() -> Model: + return cast(Model, object()) + + +def test_resolve_speculative_decoding_distributed_drops_drafter() -> None: + """Multi-device runs never pass the drafter through.""" + import mlx.core as mx + + drafter = _fake_model() + fake_group = cast(mx.distributed.Group, object()) + eff, kwargs = resolve_speculative_decoding( + draft_model=drafter, + group=fake_group, + max_tokens=128, + num_draft_tokens=5, + drafter_min_output_tokens=16, + ) + assert eff is None + assert kwargs == {} + + +def test_resolve_speculative_decoding_no_drafter_returns_empty_kwargs() -> None: + eff, kwargs = resolve_speculative_decoding( + draft_model=None, + group=None, + max_tokens=128, + num_draft_tokens=5, + drafter_min_output_tokens=16, + ) + assert eff is None + assert kwargs == {} + + +def test_resolve_speculative_decoding_short_max_tokens_drops_drafter() -> None: + """Item 8: short generations skip the drafter.""" + drafter = _fake_model() + eff, kwargs = resolve_speculative_decoding( + draft_model=drafter, + group=None, + max_tokens=8, + num_draft_tokens=5, + drafter_min_output_tokens=16, + ) + assert eff is None + assert kwargs == {} + + +def test_resolve_speculative_decoding_threshold_boundary_drops_drafter() -> None: + """``<=`` threshold means equality also skips the drafter.""" + drafter = _fake_model() + eff, _ = resolve_speculative_decoding( + draft_model=drafter, + group=None, + max_tokens=16, + num_draft_tokens=5, + drafter_min_output_tokens=16, + ) + assert eff is None + + +def test_resolve_speculative_decoding_passes_k_through() -> None: + """Item 1: num_draft_tokens flows into stream_generate kwargs.""" + drafter = _fake_model() + eff, kwargs = resolve_speculative_decoding( + draft_model=drafter, + group=None, + max_tokens=512, + num_draft_tokens=5, + drafter_min_output_tokens=16, + ) + assert eff is drafter + assert kwargs == {"num_draft_tokens": 5} + + +def test_adaptive_num_draft_tokens_uses_fallback_until_warmup() -> None: + """With <2 observations the controller hasn't warmed up yet.""" + assert adaptive_num_draft_tokens([], fallback=5) == 5 + assert adaptive_num_draft_tokens([0.9], fallback=7) == 7 + + +def test_adaptive_num_draft_tokens_low_acceptance_uses_k2() -> None: + """Drafter is missing badly -- don't waste cycles speculating.""" + assert adaptive_num_draft_tokens([0.1, 0.2, 0.3], fallback=5) == 2 + + +def test_adaptive_num_draft_tokens_mid_acceptance_uses_k4() -> None: + assert adaptive_num_draft_tokens([0.6, 0.65, 0.6], fallback=5) == 4 + + +def test_adaptive_num_draft_tokens_high_acceptance_uses_k6() -> None: + assert adaptive_num_draft_tokens([0.85, 0.9, 0.8], fallback=5) == 6 + + +def test_adaptive_num_draft_tokens_band_boundaries() -> None: + """0.5 is the K=2 -> K=4 boundary; 0.75 is K=4 -> K=6.""" + # average exactly 0.5 -> K=4 (>= 0.5) + assert adaptive_num_draft_tokens([0.5, 0.5], fallback=5) == 4 + # average exactly 0.75 -> K=6 (>= 0.75) + assert adaptive_num_draft_tokens([0.75, 0.75], fallback=5) == 6 + # average just under 0.5 -> K=2 + assert adaptive_num_draft_tokens([0.499, 0.499], fallback=5) == 2 + + +def test_resolve_speculative_decoding_no_k_means_no_kwarg() -> None: + """If caller doesn't override K, mlx_lm uses its default (currently 2).""" + drafter = _fake_model() + eff, kwargs = resolve_speculative_decoding( + draft_model=drafter, + group=None, + max_tokens=512, + num_draft_tokens=None, + drafter_min_output_tokens=16, + ) + assert eff is drafter + assert kwargs == {} diff --git a/src/exo/worker/tests/unittests/test_mlx/test_pipelined_drafter.py b/src/exo/worker/tests/unittests/test_mlx/test_pipelined_drafter.py new file mode 100644 index 0000000000..e1fb6f14ed --- /dev/null +++ b/src/exo/worker/tests/unittests/test_mlx/test_pipelined_drafter.py @@ -0,0 +1,1063 @@ +"""Tests for :mod:`pipelined_drafter` and :mod:`drafter_transport`. + +The cross-round speculation accounting is the only complex piece, so +these tests focus on: + + * The :class:`DrafterTransport` Protocol contract (any implementation + that satisfies the Protocol must accept the call sequence the spec + loop emits). + * The spec-loop's cache-trim arithmetic for partial accept, full + accept, speculation hit, and speculation miss -- exercised through + a deterministic fake transport that records every call so we can + assert on the trim/forward sequence without spinning up MLX + weights. + * Transport-kind parsing (``EXO_DRAFTER_TRANSPORT`` env var). + +End-to-end correctness with real MLX weights is exercised by the smoke ++ bench scripts; this file stays MLX-free so it runs in seconds on CI. +""" + +from __future__ import annotations + +from concurrent.futures import Future +from dataclasses import dataclass, field +from typing import Final + +import pytest + +from exo.worker.engines.mlx.generator.drafter_transport import ( + ALL_TRANSPORT_KINDS, + EXO_DRAFTER_TRANSPORT_ENV, + DrafterTransport, + DraftFuture, + clamp_num_draft_tokens_to_transport, + parse_transport_kind, + transport_factory_for, +) + +# --------------------------------------------------------------------------- +# Test fixtures: deterministic fake transport +# --------------------------------------------------------------------------- + + +@dataclass +class _Call: + """One method call against the fake transport, in arrival order.""" + + kind: str # "forward" or "trim" + inputs: tuple[int, ...] = () + num_forwards: int = 0 + n_positions: int = 0 + + +@dataclass +class _ForwardScript: + """Pre-recorded outputs for the next ``forward`` call.""" + + outputs: list[int] + + +@dataclass +class FakeTransport: + """A :class:`DrafterTransport` that records calls and returns scripted drafts. + + Used to exercise the spec loop's bookkeeping without running MLX. + Every ``forward`` consumes one entry from ``script``; if the script + is exhausted, the test has hit a code path it didn't predict and + the transport raises (failing the test loudly). + """ + + num_draft_tokens_value: int + script: list[_ForwardScript] = field(default_factory=list) + calls: list[_Call] = field(default_factory=list) + cache_offset: int = 0 + + @property + def num_draft_tokens(self) -> int: + return self.num_draft_tokens_value + + def forward(self, inputs: list[int], num_forwards: int) -> DraftFuture: + if not 1 <= num_forwards <= self.num_draft_tokens_value + 1: + raise ValueError(f"num_forwards out of bounds: {num_forwards}") + if not 1 <= len(inputs) <= 2: + raise ValueError(f"inputs length out of bounds: {len(inputs)}") + if not self.script: + raise AssertionError( + "FakeTransport.forward called without script entry; " + "test missed a code path" + ) + entry = self.script.pop(0) + if len(entry.outputs) != num_forwards: + raise AssertionError( + f"Script entry has {len(entry.outputs)} outputs; " + f"forward asked for {num_forwards}" + ) + self.calls.append( + _Call(kind="forward", inputs=tuple(inputs), num_forwards=num_forwards) + ) + # Cache extends by ``len(inputs) + num_forwards - 1`` per spec. + self.cache_offset += len(inputs) + num_forwards - 1 + future: DraftFuture = Future() + future.set_result(list(entry.outputs)) + return future + + def trim_cache(self, n_positions: int) -> None: + if n_positions < 0: + raise ValueError(f"n_positions must be >= 0, got {n_positions}") + if n_positions > self.cache_offset: + raise AssertionError( + f"Trim {n_positions} would exceed cache offset {self.cache_offset}; " + "spec loop is over-trimming" + ) + self.calls.append(_Call(kind="trim", n_positions=n_positions)) + self.cache_offset -= n_positions + + def reset_and_prefill(self, prompt_tokens: list[int]) -> None: + # Mirror RemoteTransport semantics: reset cache to 0, then + # extend by len(prompt_tokens). The FakeTransport doesn't + # actually run a model, so the offset bookkeeping is the only + # observable side-effect tests care about. + self.cache_offset = len(prompt_tokens) + self.calls.append( + _Call(kind="reset_and_prefill", n_positions=len(prompt_tokens)) + ) + + def shutdown(self) -> None: + return + + +def test_fake_transport_satisfies_protocol() -> None: + """The fake transport must structurally satisfy :class:`DrafterTransport`.""" + transport: DrafterTransport = FakeTransport(num_draft_tokens_value=4) + assert isinstance(transport, DrafterTransport) + + +# --------------------------------------------------------------------------- +# Transport-kind parsing +# --------------------------------------------------------------------------- + + +_KIND_DEFAULT: Final[str] = "inprocess" + + +@pytest.mark.parametrize( + ("raw", "expected"), + [ + (None, _KIND_DEFAULT), + ("inprocess", "inprocess"), + ("INPROCESS", "inprocess"), + (" inprocess ", "inprocess"), + ], +) +def test_parse_transport_kind_recognised(raw: str | None, expected: str) -> None: + """Only ``inprocess`` is a valid transport-kind keyword. + + The legacy ``"remote"`` keyword was a factory hint for the + ``mx.distributed``-backed asymmetric drafter; the v3+ asymmetric + wire is built directly from the runner bootstrap with a connected + socket and never goes through the env-var factory. + """ + assert parse_transport_kind(raw, _KIND_DEFAULT) == expected + + +def test_parse_transport_kind_rejects_legacy_remote() -> None: + """Legacy ``"remote"`` keyword falls back to the default with a warning. + + The asymmetric remote transport is built directly from the runner + bootstrap in v3+; an env-var hint of ``"remote"`` no longer has a + factory backing and must degrade to ``inprocess`` rather than crash. + """ + assert parse_transport_kind("remote", _KIND_DEFAULT) == _KIND_DEFAULT + assert parse_transport_kind("Remote", _KIND_DEFAULT) == _KIND_DEFAULT + + +def test_parse_transport_kind_falls_back_for_unknown() -> None: + # Unknown kinds warn and fall back to the default rather than + # raising; that mirrors how ``parse_draft_mode`` handles unknown + # ``EXO_DRAFT_MODE`` values. + assert parse_transport_kind("totally-bogus", _KIND_DEFAULT) == _KIND_DEFAULT + + +def test_all_transport_kinds_match_factory_dispatch() -> None: + """Every kind in :data:`ALL_TRANSPORT_KINDS` must have a factory. + + The factory may raise ``NotImplementedError`` (Layer B's remote + transport does), but :func:`transport_factory_for` itself must + always return a callable -- the dispatch table is part of the + public contract. + """ + for kind in ALL_TRANSPORT_KINDS: + factory = transport_factory_for(kind) + assert callable(factory) + + +def test_transport_factory_for_rejects_unknown() -> None: + with pytest.raises(ValueError, match="Unknown drafter transport kind"): + transport_factory_for("totally-bogus") + + +# --------------------------------------------------------------------------- +# Spec loop arithmetic via the fake transport +# --------------------------------------------------------------------------- + + +# These tests exercise the cache-trim arithmetic *as the spec loop +# emits it*, without running the MLX target. We construct call traces +# the loop would produce for a known accept pattern and assert the +# trim/forward sequence matches the formula derived in the +# pipelined_drafter module docstring. +# +# Strategy: don't actually run the spec loop (which needs an MLX +# target). Instead, simulate the spec loop's transport calls +# imperatively for each scenario and assert the cache offset / call +# sequence matches what the docstring promises. + + +class TestSpecLoopArithmetic: + """Trace the transport-call sequence for canonical accept patterns.""" + + def test_partial_accept_no_speculation(self) -> None: + """Partial accept (n=2 of K=4): trim K-n-1 = 1, propose [target_correction].""" + k = 4 + n = 2 + transport = FakeTransport( + num_draft_tokens_value=k, + script=[ + # Round 0: 4 drafts. + _ForwardScript(outputs=[10, 11, 12, 13]), + # Round 1: 4 drafts after partial-accept setup. + _ForwardScript(outputs=[20, 21, 22, 23]), + ], + ) + + # Round 0 propose. + drafts = transport.forward([1], k).result() + assert drafts == [10, 11, 12, 13] + assert transport.cache_offset == k # 4 positions + + # Spec loop: partial accept after target verify (n=2, drafts[2] mismatched). + # Transport bookkeeping for next round: + # * trim k - n - 1 = 1 position + # * propose [target_correction] (length 1), k outputs + transport.trim_cache(k - n - 1) + assert transport.cache_offset == k - 1 # 3 positions + + # Next round propose with length-1 input. + next_drafts = transport.forward([99], k).result() + assert next_drafts == [20, 21, 22, 23] + # Cache extends by k (length-1 input + k-1 length-1 forwards = k). + assert transport.cache_offset == k - 1 + k # 7 positions + + # Verify call trace. + assert [c.kind for c in transport.calls] == [ + "forward", + "trim", + "forward", + ] + assert transport.calls[1].n_positions == 1 + + def test_full_accept_no_speculation(self) -> None: + """Full accept (n=k): no trim; next round propose has length-2 input.""" + k = 4 + transport = FakeTransport( + num_draft_tokens_value=k, + script=[ + _ForwardScript(outputs=[10, 11, 12, 13]), + _ForwardScript(outputs=[20, 21, 22, 23]), + ], + ) + + transport.forward([1], k).result() + assert transport.cache_offset == k + + # Full accept: no trim. Next round propose with [drafts[-1], bonus]. + next_drafts = transport.forward([13, 99], k).result() + assert next_drafts == [20, 21, 22, 23] + # Cache extends by k + 1 (length-2 input + k-1 length-1 forwards). + assert transport.cache_offset == k + (k + 1) + + assert [c.kind for c in transport.calls] == ["forward", "forward"] + assert transport.calls[1].inputs == (13, 99) + assert transport.calls[1].num_forwards == k + + def test_speculation_hit(self) -> None: + """Full accept + speculation hit: round t+1 drafts come for free.""" + k = 4 + transport = FakeTransport( + num_draft_tokens_value=k, + script=[ + # Round 0 propose: [10, 11, 12, 13]. + _ForwardScript(outputs=[10, 11, 12, 13]), + # Speculative round (input=[13], k+1 outputs): + # outputs[0] = drafter's bonus prediction; outputs[1..k] = round + # 1's drafts. + _ForwardScript(outputs=[99, 30, 31, 32, 33]), + ], + ) + + # Round 0 propose. + round0_drafts = transport.forward([1], k).result() + assert round0_drafts == [10, 11, 12, 13] + + # Speculative call. + spec_outputs = transport.forward([13], k + 1).result() + assert spec_outputs == [99, 30, 31, 32, 33] + # After speculation: cache extended by k (round 0) + (k + 1) + # (speculation) = 2k+1 positions. + assert transport.cache_offset == k + (k + 1) + + # Speculation hit: target's bonus_t == 99 == spec_outputs[0]. + # Round 1's drafts = spec_outputs[1:k+1]. + round1_drafts = spec_outputs[1 : k + 1] + assert round1_drafts == [30, 31, 32, 33] + + # No additional transport calls (drafter cache state already + # correct for round 1). + assert [c.kind for c in transport.calls] == ["forward", "forward"] + + def test_speculation_miss_full_accept(self) -> None: + """Full accept but bonus mismatched: rollback k+1, length-2 propose.""" + k = 4 + transport = FakeTransport( + num_draft_tokens_value=k, + script=[ + _ForwardScript(outputs=[10, 11, 12, 13]), + _ForwardScript(outputs=[88, 80, 81, 82, 83]), # speculative + _ForwardScript(outputs=[40, 41, 42, 43]), # round 1 standard + ], + ) + + transport.forward([1], k).result() + spec_outputs = transport.forward([13], k + 1).result() + # bonus_t = 99 (target), spec_outputs[0] = 88 -> miss. + + # Rollback the k+1 speculative positions. + transport.trim_cache(k + 1) + assert transport.cache_offset == k # back to round-0 state + + # Standard length-2-seed propose for round 1: [drafts[-1], bonus_t]. + round1_drafts = transport.forward([13, 99], k).result() + assert round1_drafts == [40, 41, 42, 43] + + del spec_outputs + kinds = [c.kind for c in transport.calls] + assert kinds == ["forward", "forward", "trim", "forward"] + assert transport.calls[2].n_positions == k + 1 + assert transport.calls[3].inputs == (13, 99) + + def test_speculation_miss_partial_accept(self) -> None: + """Partial accept with speculation in flight: rollback k+1 + partial trim.""" + k = 4 + n = 2 + transport = FakeTransport( + num_draft_tokens_value=k, + script=[ + _ForwardScript(outputs=[10, 11, 12, 13]), + _ForwardScript(outputs=[88, 80, 81, 82, 83]), # speculative + _ForwardScript(outputs=[50, 51, 52, 53]), # round 1 + ], + ) + + transport.forward([1], k).result() + transport.forward([13], k + 1).result() + # cache offset: k + (k + 1) = 2k + 1 = 9 + + # Partial accept at round 0: speculation is invalid AND partial + # trim is needed. The combined trim is (k + 1) + (k - n - 1). + combined_trim = (k + 1) + (k - n - 1) + transport.trim_cache(combined_trim) + # cache offset: 2k + 1 - combined_trim = n + 1 = 3 + assert transport.cache_offset == n + 1 + + # Round 1 standard propose with length-1 input. + round1_drafts = transport.forward([99], k).result() + assert round1_drafts == [50, 51, 52, 53] + + kinds = [c.kind for c in transport.calls] + assert kinds == ["forward", "forward", "trim", "forward"] + assert transport.calls[2].n_positions == combined_trim + + +# --------------------------------------------------------------------------- +# PipelinedModelDrafter wiring +# --------------------------------------------------------------------------- + + +def test_pipelined_drafter_mode_is_pipelined() -> None: + # Imported lazily so this file stays importable without the drafter + # module's MLX-bound siblings; the import itself is what we're + # exercising (catches accidental syntax errors in pipelined_drafter + # that the type checker might miss for runtime-only paths). + from exo.worker.engines.mlx.generator.pipelined_drafter import ( + PipelinedModelDrafter, + ) + + transport = FakeTransport(num_draft_tokens_value=4) + drafter = PipelinedModelDrafter(transport=transport, num_draft_tokens=4) + assert drafter.mode == "pipelined" + assert drafter.num_draft_tokens == 4 + + +def test_pipelined_drafter_validates_num_draft_tokens() -> None: + from exo.worker.engines.mlx.generator.pipelined_drafter import ( + PipelinedModelDrafter, + ) + + transport = FakeTransport(num_draft_tokens_value=4) + with pytest.raises(ValueError, match="num_draft_tokens"): + PipelinedModelDrafter(transport=transport, num_draft_tokens=0) + with pytest.raises(ValueError, match="exceeds transport's max"): + PipelinedModelDrafter(transport=transport, num_draft_tokens=10) + + +def test_pipelined_drafter_shutdown_delegates() -> None: + """Shutdown should propagate to the transport so remote serve loops drain cleanly.""" + from exo.worker.engines.mlx.generator.pipelined_drafter import ( + PipelinedModelDrafter, + ) + + shutdown_calls: list[None] = [] + + class _ShutdownRecorder(FakeTransport): + def shutdown(self) -> None: + shutdown_calls.append(None) + + transport = _ShutdownRecorder(num_draft_tokens_value=4) + drafter = PipelinedModelDrafter(transport=transport, num_draft_tokens=4) + drafter.shutdown() + assert len(shutdown_calls) == 1 + + +# --------------------------------------------------------------------------- +# Transport-kind environment plumbing +# --------------------------------------------------------------------------- + + +def test_make_drafter_pipelined_without_model_or_transport_raises( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """``make_drafter("pipelined", ...)`` requires either a model+cache or a transport. + + The env-var-driven factory path is gone in v3+ (asymmetric remote + transport is constructed directly by the runner bootstrap). Calling + ``make_drafter`` with neither a builder-supplied transport nor a + drafter model + cache must raise a clear error -- it has no way to + construct the in-process transport. + """ + from exo.worker.engines.mlx.generator.drafter import make_drafter + + monkeypatch.delenv(EXO_DRAFTER_TRANSPORT_ENV, raising=False) + with pytest.raises(ValueError, match="pipelined"): + make_drafter( + mode="pipelined", + num_draft_tokens=4, + draft_model=None, + draft_cache=None, + ) + + +# --------------------------------------------------------------------------- +# Asymmetric placement entry points +# --------------------------------------------------------------------------- + + +def test_make_drafter_uses_supplied_pipelined_transport() -> None: + """When ``pipelined_transport`` is supplied, ``make_drafter`` must reuse it. + + Asymmetric placement allocates a long-lived RemoteTransport at + SequentialGenerator build time so executor + drafter cache lifecycle + aren't paid per-request. The factory entry point must accept that + pre-built transport instead of constructing a new one. + """ + from exo.worker.engines.mlx.generator.drafter import make_drafter + from exo.worker.engines.mlx.generator.pipelined_drafter import ( + PipelinedModelDrafter, + ) + + transport = FakeTransport(num_draft_tokens_value=4) + drafter = make_drafter( + mode="pipelined", + num_draft_tokens=4, + draft_model=None, + draft_cache=None, + pipelined_transport=transport, + ) + assert isinstance(drafter, PipelinedModelDrafter) + # The drafter must wrap the supplied transport, not a freshly- + # constructed one (would be a behavioural regression because the + # remote drafter cache + executor would be leaked on every request). + drafter.shutdown() + assert transport.calls == [] # FakeTransport.shutdown is a no-op + + +def test_make_drafter_rejects_non_protocol_pipelined_transport() -> None: + """``pipelined_transport`` must implement ``DrafterTransport``.""" + from exo.worker.engines.mlx.generator.drafter import make_drafter + + class NotATransport: + pass + + with pytest.raises(TypeError, match="DrafterTransport"): + make_drafter( + mode="pipelined", + num_draft_tokens=4, + draft_model=None, + draft_cache=None, + pipelined_transport=NotATransport(), + ) + + +class TestClampNumDraftTokensToTransport: + """Per-request K must be clamped to the transport's wire-protocol max. + + Regression coverage: aborted K=8 sweep at 14:35:05 raised + ``ValueError`` deep inside :class:`PipelinedModelDrafter` and killed + the target runner subprocess (PR #15). The clamp helper exists so + ``generate.py`` can defend the runner from malformed per-request + overrides without ever reaching the drafter constructor. + """ + + def test_clamp_no_op_when_request_within_budget(self) -> None: + transport = FakeTransport(num_draft_tokens_value=5) + clamped, was_clamped = clamp_num_draft_tokens_to_transport(3, transport) + assert clamped == 3 + assert was_clamped is False + + def test_clamp_no_op_when_request_equals_budget(self) -> None: + transport = FakeTransport(num_draft_tokens_value=5) + clamped, was_clamped = clamp_num_draft_tokens_to_transport(5, transport) + assert clamped == 5 + assert was_clamped is False + + def test_clamp_applies_when_request_exceeds_budget(self) -> None: + transport = FakeTransport(num_draft_tokens_value=5) + clamped, was_clamped = clamp_num_draft_tokens_to_transport(8, transport) + assert clamped == 5 + assert was_clamped is True + + def test_clamp_pathological_request(self) -> None: + transport = FakeTransport(num_draft_tokens_value=5) + clamped, was_clamped = clamp_num_draft_tokens_to_transport(1024, transport) + assert clamped == 5 + assert was_clamped is True + + def test_clamp_rejects_zero_or_negative(self) -> None: + transport = FakeTransport(num_draft_tokens_value=5) + with pytest.raises(ValueError, match="requested_num_draft_tokens"): + clamp_num_draft_tokens_to_transport(0, transport) + with pytest.raises(ValueError, match="requested_num_draft_tokens"): + clamp_num_draft_tokens_to_transport(-1, transport) + + def test_clamped_k_constructs_pipelined_drafter_safely(self) -> None: + """Smoke: clamped K must satisfy ``PipelinedModelDrafter`` validation. + + The whole point of the clamp is that the value flowing into + :class:`PipelinedModelDrafter` never exceeds ``transport.num_draft_tokens``. + Construct the drafter with the clamped K to prove the pre-fix + regression path is gone. + """ + from exo.worker.engines.mlx.generator.pipelined_drafter import ( + PipelinedModelDrafter, + ) + + transport = FakeTransport(num_draft_tokens_value=5) + # Pre-fix: K=8 raised ValueError here and killed the subprocess. + clamped, _ = clamp_num_draft_tokens_to_transport(8, transport) + drafter = PipelinedModelDrafter(transport=transport, num_draft_tokens=clamped) + assert drafter.num_draft_tokens == 5 + + +def test_make_drafter_pipelined_multi_target_requires_target_group() -> None: + """V2 boundary: multi-target asymmetric requires a target_group for the + rank-0 -> peer broadcast of drafts each round. Building the root-side + drafter without ``target_group`` is a configuration error: the spec + loop would race on a missing collective and silently desync. + """ + from exo.worker.engines.mlx.generator.drafter import make_drafter + + transport = FakeTransport(num_draft_tokens_value=4) + with pytest.raises(ValueError, match="requires target_group"): + make_drafter( + mode="pipelined", + num_draft_tokens=4, + draft_model=None, + draft_cache=None, + pipelined_transport=transport, + target_subgroup_size=2, + target_group=None, + ) + + +def test_make_drafter_pipelined_consumer_rank_requires_target_group() -> None: + """V2 boundary: a non-root target rank (no transport) must receive a + ``target_group`` so the broadcast can land. Without it the consumer + drafter would have no way to obtain drafts and the round 0 verify + would deadlock against the root's TP collective. + """ + from exo.worker.engines.mlx.generator.drafter import make_drafter + + with pytest.raises(ValueError, match="requires target_group"): + make_drafter( + mode="pipelined", + num_draft_tokens=4, + draft_model=None, + draft_cache=None, + pipelined_transport=None, + target_subgroup_size=2, + target_group=None, + is_target_root=False, + ) + + +def test_make_drafter_pipelined_consumer_for_three_target_ranks() -> None: + """V2 multi-target with N target ranks (N >= 2): every non-root rank + constructs the same transport-less consumer drafter. Exercise N=3 + explicitly so the broadcast contract is not implicitly bound to + ``target_subgroup_size == 2`` (the case the cluster bench covers). + """ + from exo.worker.engines.mlx.generator.drafter import make_drafter + from exo.worker.engines.mlx.generator.pipelined_drafter import ( + PipelinedModelDrafter, + ) + + class _StubGroup: + def size(self) -> int: + return 3 + + def rank(self) -> int: + return 2 + + drafter = make_drafter( + mode="pipelined", + num_draft_tokens=4, + draft_model=None, + draft_cache=None, + pipelined_transport=None, + target_subgroup_size=3, + target_group=_StubGroup(), + is_target_root=False, + ) + assert isinstance(drafter, PipelinedModelDrafter) + assert drafter.mode == "pipelined" + assert drafter.num_draft_tokens == 4 + + +def test_make_drafter_pipelined_root_for_three_target_ranks() -> None: + """V2 multi-target root with N=3 ranks: identical contract to N=2 + -- the root owns the transport and broadcasts on the target group. + The collective is N-ary (``mx.distributed.all_sum``), so the + construction has no special-casing for N == 2 and we want a test + asserting that explicitly. + """ + from exo.worker.engines.mlx.generator.drafter import make_drafter + from exo.worker.engines.mlx.generator.pipelined_drafter import ( + PipelinedModelDrafter, + ) + + class _StubGroup: + def size(self) -> int: + return 3 + + def rank(self) -> int: + return 0 + + transport = FakeTransport(num_draft_tokens_value=4) + drafter = make_drafter( + mode="pipelined", + num_draft_tokens=4, + draft_model=None, + draft_cache=None, + pipelined_transport=transport, + target_subgroup_size=3, + target_group=_StubGroup(), + is_target_root=True, + ) + assert isinstance(drafter, PipelinedModelDrafter) + + +# --------------------------------------------------------------------------- +# Broadcast helpers (single-rank short-circuit) +# --------------------------------------------------------------------------- + + +class TestBroadcastDrafts: + """``_broadcast_drafts`` length-prefix encoding contract. + + Multi-rank behaviour is covered by the cluster bench (real + ``mx.distributed.all_sum``). The single-rank short-circuit is the + only path we can exercise in unit tests, but it captures the most + important contract bug: the length-prefix decoder rejecting + nonsensical lengths from a corrupted broadcast. + """ + + def test_single_rank_short_circuit_root(self) -> None: + from exo.worker.engines.mlx.generator.pipelined_drafter import ( + _broadcast_drafts, # pyright: ignore[reportPrivateUsage] + ) + + out: list[int] = _broadcast_drafts( + [10, 20], + k=4, + target_group=None, + target_peer_fanout=None, + is_root=True, + ) + assert out == [10, 20] + + def test_single_rank_short_circuit_consumer_rejected(self) -> None: + # Consumer rank in single-rank mode is a configuration bug -- + # there's no peer to broadcast from. Surface it loudly. + from exo.worker.engines.mlx.generator.pipelined_drafter import ( + _broadcast_drafts, # pyright: ignore[reportPrivateUsage] + ) + + with pytest.raises(RuntimeError, match="non-root"): + _broadcast_drafts( + None, + k=4, + target_group=None, + target_peer_fanout=None, + is_root=False, + ) + + def test_single_rank_root_requires_drafts(self) -> None: + from exo.worker.engines.mlx.generator.pipelined_drafter import ( + _broadcast_drafts, # pyright: ignore[reportPrivateUsage] + ) + + # ``drafts is None`` on root in the short-circuit path is a + # caller bug (the runner never has a None drafts list when it + # owns the wire). + with pytest.raises(RuntimeError, match="non-root"): + _broadcast_drafts( + None, + k=4, + target_group=None, + target_peer_fanout=None, + is_root=False, + ) + + +class TestBroadcastTargetTokens: + """``_broadcast_target_tokens`` carries the verifier's sampled + tokens from rank 0 to non-root target ranks so accept/reject is + bit-identical across the target subgroup. + + Without this broadcast, every rank's ``mx.random.categorical`` call + returns RNG-divergent tokens (default temperature is 0.7 in the + API path), the ranks reach different ``num_accepted``, trim the + target's prompt cache by different amounts, and the next TP + forward consumes mismatched cache state -- a silent garbage-output + bug. These tests pin the contract so a future refactor can't + accidentally drop the broadcast. + """ + + def test_single_rank_short_circuit_root(self) -> None: + from exo.worker.engines.mlx.generator.pipelined_drafter import ( + _broadcast_target_tokens, # pyright: ignore[reportPrivateUsage] + ) + + # k_this + 1 == 3 tokens: the seed-bonus + drafts emitted per + # round in a K=4, k_this=2 partial round. + out: list[int] = _broadcast_target_tokens( + [10, 20, 30], + k=4, + k_this=2, + target_group=None, + target_peer_fanout=None, + is_root=True, + ) + assert out == [10, 20, 30] + + def test_single_rank_consumer_rejected(self) -> None: + from exo.worker.engines.mlx.generator.pipelined_drafter import ( + _broadcast_target_tokens, # pyright: ignore[reportPrivateUsage] + ) + + with pytest.raises(RuntimeError, match="non-root"): + _broadcast_target_tokens( + None, + k=4, + k_this=2, + target_group=None, + target_peer_fanout=None, + is_root=False, + ) + + def test_root_rejects_wrong_length(self) -> None: + # Verifier always emits exactly ``k_this + 1`` tokens; anything + # else means the spec loop is calling the broadcast with stale + # state. Raise rather than silently right-pad. + from exo.worker.engines.mlx.generator.pipelined_drafter import ( + _broadcast_target_tokens, # pyright: ignore[reportPrivateUsage] + ) + + with pytest.raises(RuntimeError, match="must equal k_this"): + _broadcast_target_tokens( + [10, 20], + k=4, + k_this=2, + target_group=None, + target_peer_fanout=None, + is_root=True, + ) + + +def test_make_drafter_pipelined_root_rank_with_no_transport_rejected() -> None: + """Configuration error: ``is_target_root=True`` implies this rank owns + the drafter socket; the caller must pass a transport. Reaching the + multi-target consumer branch with ``is_target_root=True`` is a + placement bug we want to surface loudly rather than silently drop. + """ + from exo.worker.engines.mlx.generator.drafter import make_drafter + + class _StubGroup: + def size(self) -> int: + return 2 + + def rank(self) -> int: + return 0 + + with pytest.raises(ValueError, match="is_target_root=True"): + make_drafter( + mode="pipelined", + num_draft_tokens=4, + draft_model=None, + draft_cache=None, + pipelined_transport=None, + target_subgroup_size=2, + target_group=_StubGroup(), + is_target_root=True, + ) + + +# --------------------------------------------------------------------------- +# Drafter-death recovery: abort sentinel + wrap behaviour +# --------------------------------------------------------------------------- + + +class TestDrafterAbortRecovery: + """Recovery contract when the drafter rank dies mid-generation. + + Pre-fix failure mode: root's ``transport.forward`` raised + ``OSError`` and re-raised cleanly out of ``mlx_generate``, but + non-root target ranks blocked indefinitely on the next-round + draft broadcast (the sole inter-rank coordination channel for + spec decode). The abort sentinel + wrap + ``RemoteTransport`` + failure flag together convert that hang into a fast, lockstep + exit on every rank, with the runner subprocess crashing so the + master's instance-deletion path can rebuild the placement. + + The cluster bench covers the full multi-rank flow against real + ``mx.distributed``; these unit tests pin the single-rank + invariants that are reachable without spinning up a peer group. + """ + + def test_broadcast_abort_short_circuits_without_group(self) -> None: + # ``target_group is None`` (single-rank placement) means there + # are no peers to notify; the abort broadcast must be a no-op + # rather than raising or contacting any wire layer. + from exo.worker.engines.mlx.generator.pipelined_drafter import ( + _broadcast_abort, # pyright: ignore[reportPrivateUsage] + ) + + # Should not raise; should not require any group machinery. + _broadcast_abort(k=4, target_group=None, target_peer_fanout=None) + + def test_sentinel_value_is_in_validator_range(self) -> None: + # The sentinel must satisfy ``_validate_broadcast_values`` + # (positive int32) so a real cluster broadcast doesn't reject + # it before non-root ranks have a chance to decode it. + from exo.worker.engines.mlx.generator.pipelined_drafter import ( + DRAFT_ABORT_SENTINEL, + ) + from exo.worker.engines.mlx.utils_mlx import ( + _MX_BROADCAST_MAX_VALUE, # pyright: ignore[reportPrivateUsage] + _validate_broadcast_values, # pyright: ignore[reportPrivateUsage] + ) + + assert 0 < DRAFT_ABORT_SENTINEL < _MX_BROADCAST_MAX_VALUE + # Must also exceed any plausible draft length so it can never + # collide with a legitimate length-prefix. + assert DRAFT_ABORT_SENTINEL > 1_000_000 + # Validator round-trip with the wire payload root would emit. + _validate_broadcast_values([DRAFT_ABORT_SENTINEL] + [0] * 4) + + def test_broadcast_drafts_decodes_sentinel_to_abort_error( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + # Multi-rank receive path: when ``mx_broadcast_int_list`` + # returns a buffer whose length-prefix is the sentinel, + # ``_broadcast_drafts`` raises ``DrafterAbortedError`` so the + # spec loop can exit in lockstep with the dead root rank. + from exo.worker.engines.mlx.generator import pipelined_drafter + from exo.worker.engines.mlx.generator.pipelined_drafter import ( + DRAFT_ABORT_SENTINEL, + DrafterAbortedError, + _broadcast_drafts, # pyright: ignore[reportPrivateUsage] + ) + + k = 4 + + def fake_broadcast( + values: list[int] | None, + length: int, + group: object, + *, + is_root: bool, + ) -> list[int]: + del values, group, is_root + assert length == k + 1 + return [DRAFT_ABORT_SENTINEL] + [0] * k + + monkeypatch.setattr(pipelined_drafter, "mx_broadcast_int_list", fake_broadcast) + + sentinel_group = object() # opaque; the fake never inspects + with pytest.raises(DrafterAbortedError, match="drafter aborted"): + _broadcast_drafts( + None, + k=k, + target_group=sentinel_group, # pyright: ignore[reportArgumentType] + target_peer_fanout=None, + is_root=False, + ) + + def test_spec_step_wrap_root_broadcasts_abort_on_oserror( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + # Inject a body that immediately raises OSError; the wrap + # must call ``_broadcast_abort`` (root path) before re-raising + # so non-root ranks unblock their pending broadcast. + from exo.worker.engines.mlx.generator import pipelined_drafter + + broadcast_calls: list[tuple[int, object]] = [] + + def fake_abort( + *, k: int, target_group: object, target_peer_fanout: object + ) -> None: + del target_peer_fanout + broadcast_calls.append((k, target_group)) + + def fake_body(**kwargs: object): + del kwargs + raise ConnectionError("drafter rank closed mid-frame") + yield # pragma: no cover -- generator marker + + monkeypatch.setattr(pipelined_drafter, "_broadcast_abort", fake_abort) + monkeypatch.setattr( + pipelined_drafter, + "_pipelined_speculative_step_body", + fake_body, + ) + + sentinel_group = object() + gen = pipelined_drafter._pipelined_speculative_step( # pyright: ignore[reportPrivateUsage] + prompt=None, # pyright: ignore[reportArgumentType] + model=None, # pyright: ignore[reportArgumentType] + transport=None, + prompt_cache=None, # pyright: ignore[reportArgumentType] + max_tokens=8, + sampler=lambda x: x, + logits_processors=[], + num_draft_tokens=4, + prefill_step_size=512, + prompt_token_count=0, + target_group=sentinel_group, # pyright: ignore[reportArgumentType] + is_target_root=True, + ) + with pytest.raises(ConnectionError, match="drafter rank closed"): + next(gen) + assert broadcast_calls == [(4, sentinel_group)] + + def test_spec_step_wrap_non_root_does_not_broadcast( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + # Non-root has no transport to fail on; if a non-root somehow + # raises OSError (e.g. a peer-side issue surfaces this way), + # we must NOT issue an abort broadcast -- only root owns that + # signal. Re-raising preserves the original error for the + # caller's traceback without a phantom broadcast. + from exo.worker.engines.mlx.generator import pipelined_drafter + + broadcast_calls: list[tuple[int, object]] = [] + + def fake_abort( + *, k: int, target_group: object, target_peer_fanout: object + ) -> None: + del target_peer_fanout + broadcast_calls.append((k, target_group)) + + def fake_body(**kwargs: object): + del kwargs + raise ConnectionError("non-root saw socket failure somehow") + yield # pragma: no cover + + monkeypatch.setattr(pipelined_drafter, "_broadcast_abort", fake_abort) + monkeypatch.setattr( + pipelined_drafter, + "_pipelined_speculative_step_body", + fake_body, + ) + + gen = pipelined_drafter._pipelined_speculative_step( # pyright: ignore[reportPrivateUsage] + prompt=None, # pyright: ignore[reportArgumentType] + model=None, # pyright: ignore[reportArgumentType] + transport=None, + prompt_cache=None, # pyright: ignore[reportArgumentType] + max_tokens=8, + sampler=lambda x: x, + logits_processors=[], + num_draft_tokens=4, + prefill_step_size=512, + prompt_token_count=0, + target_group=object(), # pyright: ignore[reportArgumentType] + is_target_root=False, + ) + with pytest.raises(ConnectionError): + next(gen) + assert broadcast_calls == [] + + def test_spec_step_wrap_swallows_abort_broadcast_failure( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + # If the abort broadcast itself fails (e.g. ``target_group`` + # is also dead), the original transport error must still + # surface intact -- the master's instance-deletion path is + # the SIGKILL backstop, so swallowing the recovery error + # avoids masking the root cause in the caller's traceback. + from exo.worker.engines.mlx.generator import pipelined_drafter + + def fake_abort( + *, k: int, target_group: object, target_peer_fanout: object + ) -> None: + del k, target_group, target_peer_fanout + raise RuntimeError("group is also dead") + + def fake_body(**kwargs: object): + del kwargs + raise ConnectionError("primary failure") + yield # pragma: no cover + + monkeypatch.setattr(pipelined_drafter, "_broadcast_abort", fake_abort) + monkeypatch.setattr( + pipelined_drafter, + "_pipelined_speculative_step_body", + fake_body, + ) + + gen = pipelined_drafter._pipelined_speculative_step( # pyright: ignore[reportPrivateUsage] + prompt=None, # pyright: ignore[reportArgumentType] + model=None, # pyright: ignore[reportArgumentType] + transport=None, + prompt_cache=None, # pyright: ignore[reportArgumentType] + max_tokens=8, + sampler=lambda x: x, + logits_processors=[], + num_draft_tokens=4, + prefill_step_size=512, + prompt_token_count=0, + target_group=object(), # pyright: ignore[reportArgumentType] + is_target_root=True, + ) + with pytest.raises(ConnectionError, match="primary failure"): + next(gen) diff --git a/src/exo/worker/tests/unittests/test_mlx/test_remote_drafter.py b/src/exo/worker/tests/unittests/test_mlx/test_remote_drafter.py new file mode 100644 index 0000000000..7beb4316e9 --- /dev/null +++ b/src/exo/worker/tests/unittests/test_mlx/test_remote_drafter.py @@ -0,0 +1,642 @@ +"""Tests for :mod:`remote_drafter` -- wire protocol + transport behaviour. + +The asymmetric drafter wire is a plain TCP socket under the v3+ design; +unit tests use ``socket.socketpair()`` to exercise both sides of the +protocol end-to-end without an MLX backend or extra processes. End-to- +end correctness against a real cluster is exercised by the multi-host +benchmark runs, not in unit tests. +""" + +from __future__ import annotations + +import socket +import struct +import threading +from collections.abc import Iterator + +import pytest + +from exo.worker.engines.mlx.generator.remote_drafter import ( + ACK_FRAME_SIZE, + ACK_OK, + COMMAND_FRAME_SIZE, + OP_END_SESSION, + OP_FORWARD, + OP_PREFILL, + OP_SHUTDOWN, + OP_TRIM_CACHE, + SESSION_ID_NONE, + RemoteTransport, + _build_command_frame, # type: ignore[reportPrivateUsage] + _decode_command_frame, # type: ignore[reportPrivateUsage] +) + +# --------------------------------------------------------------------------- +# Wire protocol: command frames +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + ("op", "inputs", "num_forwards", "trim_amount", "session_id"), + [ + (OP_FORWARD, [42], 4, 0, 0), + (OP_FORWARD, [10, 20], 5, 0, 7), + (OP_TRIM_CACHE, [], 0, 7, 3), + (OP_SHUTDOWN, [], 0, 0, SESSION_ID_NONE), + (OP_PREFILL, [], 1024, 0, 1), + (OP_PREFILL, [], 0, 0, 0), + (OP_END_SESSION, [], 0, 0, 42), + (OP_FORWARD, [1], 2, 0, 0xFFFFFFFE), + ], +) +def test_command_frame_round_trip( + op: int, + inputs: list[int], + num_forwards: int, + trim_amount: int, + session_id: int, +) -> None: + """Every command shape we send must round-trip through encode + decode.""" + flat = _build_command_frame( + op=op, + inputs=inputs, + num_forwards=num_forwards, + trim_amount=trim_amount, + session_id=session_id, + ) + assert len(flat) == COMMAND_FRAME_SIZE + + decoded_op, decoded_inputs, decoded_num_forwards, decoded_trim, decoded_sid = ( + _decode_command_frame(flat) + ) + assert decoded_op == op + assert decoded_inputs == inputs + assert decoded_num_forwards == num_forwards + assert decoded_trim == trim_amount + assert decoded_sid == session_id + + +def test_command_frame_rejects_long_inputs() -> None: + with pytest.raises(ValueError, match=r"inputs length must be in \[0, 2\]"): + _build_command_frame( + op=OP_FORWARD, + inputs=[1, 2, 3], + num_forwards=4, + trim_amount=0, + session_id=0, + ) + + +def test_command_frame_rejects_session_id_out_of_uint32_range() -> None: + with pytest.raises(ValueError, match=r"session_id must fit in uint32"): + _build_command_frame( + op=OP_FORWARD, + inputs=[1], + num_forwards=2, + trim_amount=0, + session_id=2**33, + ) + + +def test_decode_rejects_wrong_size() -> None: + with pytest.raises(ValueError, match=r"expected 9"): + _decode_command_frame([0, 0, 0]) + + +# --------------------------------------------------------------------------- +# Helpers for socketpair-based wire tests +# --------------------------------------------------------------------------- + + +def _socket_pair() -> tuple[socket.socket, socket.socket]: + """Return ``(target_side, drafter_side)`` connected unix sockets.""" + target_side, drafter_side = socket.socketpair(socket.AF_UNIX, socket.SOCK_STREAM) + target_side.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) + return target_side, drafter_side + + +def _read_uint32s(sock: socket.socket, count: int) -> list[int]: + needed = count * 4 + buf = bytearray(needed) + received = 0 + while received < needed: + view = memoryview(buf)[received:] + chunk = sock.recv_into(view, needed - received) + if chunk == 0: + raise ConnectionError( + f"socket closed mid-frame ({received}/{needed} bytes)" + ) + received += chunk + return list(struct.unpack(f"<{count}I", bytes(buf))) + + +def _write_uint32s(sock: socket.socket, values: list[int]) -> None: + sock.sendall(struct.pack(f"<{len(values)}I", *values)) + + +def _make_transport( + num_draft_tokens: int = 4, +) -> tuple[RemoteTransport, socket.socket]: + """Build a :class:`RemoteTransport` paired with a drafter-side socket.""" + target_sock, drafter_sock = _socket_pair() + transport = RemoteTransport(num_draft_tokens=num_draft_tokens, sock=target_sock) + return transport, drafter_sock + + +# --------------------------------------------------------------------------- +# RemoteTransport (target side) over a real socket pair +# --------------------------------------------------------------------------- + + +def test_open_session_allocates_unique_session_ids() -> None: + transport, drafter_side = _make_transport() + try: + a = transport.open_session() + b = transport.open_session() + c = transport.open_session() + assert a.session_id != b.session_id + assert b.session_id != c.session_id + assert a.num_draft_tokens == transport.num_draft_tokens + finally: + # Drain anything pending (forwarding ends + transport shutdown). + # We never sent any commands, so the wire is clean. Close the + # drafter side first so the transport's shutdown gets a clean + # peer-closed signal instead of hanging on the ack recv. + drafter_side.close() + transport.shutdown() + + +def test_session_handle_forward_serialises_command_with_session_id() -> None: + transport, drafter_side = _make_transport() + try: + session = transport.open_session() + future = session.forward([42], num_forwards=4) + + cmd = _read_uint32s(drafter_side, COMMAND_FRAME_SIZE) + op, inputs, num_forwards, trim, sid = _decode_command_frame(cmd) + assert op == OP_FORWARD + assert inputs == [42] + assert num_forwards == 4 + assert trim == 0 + assert sid == session.session_id + + # Reply with K+1 = 5 drafts; the spec loop slices to num_forwards. + _write_uint32s(drafter_side, [10, 11, 12, 13, 0]) + assert future.result() == [10, 11, 12, 13] + finally: + drafter_side.close() + transport.shutdown() + + +def test_session_handle_trim_cache_emits_session_scoped_command() -> None: + transport, drafter_side = _make_transport() + try: + session = transport.open_session() + + def _trim() -> None: + session.trim_cache(3) + + thread = threading.Thread(target=_trim) + thread.start() + cmd = _read_uint32s(drafter_side, COMMAND_FRAME_SIZE) + op, _, _, trim, sid = _decode_command_frame(cmd) + assert op == OP_TRIM_CACHE + assert trim == 3 + assert sid == session.session_id + _write_uint32s(drafter_side, [ACK_OK]) + thread.join(timeout=2.0) + assert not thread.is_alive() + finally: + drafter_side.close() + transport.shutdown() + + +def test_session_handle_trim_cache_zero_is_noop() -> None: + transport, drafter_side = _make_transport() + try: + session = transport.open_session() + session.trim_cache(0) + # Nothing must have been written: drafter_side.recv with + # MSG_DONTWAIT should fail with BlockingIOError. + drafter_side.setblocking(False) + with pytest.raises(BlockingIOError): + drafter_side.recv(1) + finally: + drafter_side.setblocking(True) + drafter_side.close() + transport.shutdown() + + +def test_session_handle_reset_and_prefill_sends_command_array_and_recv_ack() -> None: + transport, drafter_side = _make_transport() + try: + session = transport.open_session() + prompt = [101, 102, 103, 104, 105] + + def _prefill() -> None: + session.reset_and_prefill(prompt) + + thread = threading.Thread(target=_prefill) + thread.start() + + cmd = _read_uint32s(drafter_side, COMMAND_FRAME_SIZE) + op, inputs, num_forwards, trim, sid = _decode_command_frame(cmd) + assert op == OP_PREFILL + assert inputs == [] + assert num_forwards == len(prompt) + assert trim == 0 + assert sid == session.session_id + + # Length-prefixed prompt tail: 1 uint32 header + N tokens. + header = _read_uint32s(drafter_side, 1)[0] + assert header == len(prompt) + tokens = _read_uint32s(drafter_side, len(prompt)) + assert tokens == prompt + + _write_uint32s(drafter_side, [ACK_OK]) + thread.join(timeout=2.0) + assert not thread.is_alive() + finally: + drafter_side.close() + transport.shutdown() + + +def test_session_handle_reset_and_prefill_empty_prompt_skips_array_send() -> None: + transport, drafter_side = _make_transport() + try: + session = transport.open_session() + + def _prefill() -> None: + session.reset_and_prefill([]) + + thread = threading.Thread(target=_prefill) + thread.start() + + cmd = _read_uint32s(drafter_side, COMMAND_FRAME_SIZE) + op, _, num_forwards, _, _ = _decode_command_frame(cmd) + assert op == OP_PREFILL + assert num_forwards == 0 + + # No length-prefixed payload should follow on an empty prompt. + # Confirm by acking immediately and joining. + _write_uint32s(drafter_side, [ACK_OK]) + thread.join(timeout=2.0) + assert not thread.is_alive() + finally: + drafter_side.close() + transport.shutdown() + + +def test_session_handle_shutdown_sends_op_end_session() -> None: + transport, drafter_side = _make_transport() + try: + session = transport.open_session() + + def _shutdown() -> None: + session.shutdown() + + thread = threading.Thread(target=_shutdown) + thread.start() + + cmd = _read_uint32s(drafter_side, COMMAND_FRAME_SIZE) + op, _, _, _, sid = _decode_command_frame(cmd) + assert op == OP_END_SESSION + assert sid == session.session_id + _write_uint32s(drafter_side, [ACK_OK]) + thread.join(timeout=2.0) + assert not thread.is_alive() + + # Idempotent: a second shutdown is a no-op (no new wire op). + session.shutdown() + drafter_side.setblocking(False) + with pytest.raises(BlockingIOError): + drafter_side.recv(1) + finally: + drafter_side.setblocking(True) + drafter_side.close() + transport.shutdown() + + +def test_session_handle_rejects_use_after_shutdown() -> None: + transport, drafter_side = _make_transport() + try: + session = transport.open_session() + + def _shutdown_then_ack() -> None: + cmd = _read_uint32s(drafter_side, COMMAND_FRAME_SIZE) + op, _, _, _, _ = _decode_command_frame(cmd) + assert op == OP_END_SESSION + _write_uint32s(drafter_side, [ACK_OK]) + + thread = threading.Thread(target=_shutdown_then_ack) + thread.start() + session.shutdown() + thread.join(timeout=2.0) + + with pytest.raises(RuntimeError, match="after shutdown"): + _ = session.forward([1], num_forwards=2) + with pytest.raises(RuntimeError, match="after shutdown"): + session.trim_cache(2) + with pytest.raises(RuntimeError, match="after shutdown"): + session.reset_and_prefill([1, 2, 3]) + finally: + drafter_side.close() + transport.shutdown() + + +def test_remote_transport_shutdown_sends_op_and_drains_executor() -> None: + transport, drafter_side = _make_transport() + try: + + def _ack_shutdown() -> None: + cmd = _read_uint32s(drafter_side, COMMAND_FRAME_SIZE) + op, _, _, _, _ = _decode_command_frame(cmd) + assert op == OP_SHUTDOWN + _write_uint32s(drafter_side, [ACK_OK]) + + thread = threading.Thread(target=_ack_shutdown) + thread.start() + transport.shutdown() + thread.join(timeout=2.0) + + # Idempotent: a second shutdown is a no-op (no new wire op). + transport.shutdown() + finally: + drafter_side.close() + + +def test_remote_transport_rejects_use_after_shutdown() -> None: + transport, drafter_side = _make_transport() + + def _ack_shutdown() -> None: + try: + cmd = _read_uint32s(drafter_side, COMMAND_FRAME_SIZE) + op, _, _, _, _ = _decode_command_frame(cmd) + assert op == OP_SHUTDOWN + _write_uint32s(drafter_side, [ACK_OK]) + except (ConnectionError, OSError): + pass + + thread = threading.Thread(target=_ack_shutdown) + thread.start() + transport.shutdown() + thread.join(timeout=2.0) + + with pytest.raises(RuntimeError, match="after shutdown"): + _ = transport.open_session() + drafter_side.close() + + +def test_remote_transport_rejects_invalid_num_draft_tokens() -> None: + target_sock, drafter_sock = _socket_pair() + try: + with pytest.raises(ValueError, match="num_draft_tokens"): + RemoteTransport(num_draft_tokens=0, sock=target_sock) + finally: + target_sock.close() + drafter_sock.close() + + +# --------------------------------------------------------------------------- +# Drafter-death recovery: ``RemoteTransport.is_failed`` flag +# --------------------------------------------------------------------------- + + +def test_remote_transport_is_failed_starts_false() -> None: + """A freshly-constructed transport is healthy.""" + transport, drafter_side = _make_transport() + try: + assert transport.is_failed is False + finally: + drafter_side.close() + transport.shutdown() + + +def test_remote_transport_marks_failed_when_drafter_closes_mid_forward() -> None: + """The blocking forward helper flips ``is_failed`` on socket close. + + Pre-fix failure mode: a peer-side close mid-frame raised + ``ConnectionError`` once but left the transport looking healthy, + so subsequent ``open_session`` calls would happily allocate a + fresh session against a dead wire and the spec loop would re- + discover the failure on every request. + """ + transport, drafter_side = _make_transport() + try: + session = transport.open_session() + future = session.forward([42], num_forwards=4) + # Drain the command frame so the drafter side is in a known + # state, then close it before responding -- this models a + # drafter rank that crashed after receiving the request. + _ = _read_uint32s(drafter_side, COMMAND_FRAME_SIZE) + drafter_side.close() + with pytest.raises((ConnectionError, OSError)): + future.result(timeout=2.0) + assert transport.is_failed is True + finally: + # ``shutdown`` is best-effort against a dead wire; the + # contextlib.suppress inside it swallows the secondary error. + transport.shutdown() + + +def test_remote_transport_open_session_rejects_after_failure() -> None: + """Once a wire-level failure has surfaced, no new session is allowed. + + Subsequent requests must NOT allocate a fresh session on a known- + dead wire -- the runner will be torn down by the master's + instance-deletion path and a new placement issued. ``open_session`` + raising RuntimeError is the fail-loud signal that bridges that + gap. + """ + transport, drafter_side = _make_transport() + try: + session = transport.open_session() + future = session.forward([42], num_forwards=4) + _ = _read_uint32s(drafter_side, COMMAND_FRAME_SIZE) + drafter_side.close() + with pytest.raises((ConnectionError, OSError)): + future.result(timeout=2.0) + assert transport.is_failed is True + + with pytest.raises(RuntimeError, match="wire-level failure"): + _ = transport.open_session() + finally: + transport.shutdown() + + +def test_remote_transport_marks_failed_when_drafter_closes_mid_trim() -> None: + """The trim helper also flips ``is_failed`` on socket close. + + Trim is on the cache-reconciliation path between rounds; failure + here surfaces the same way as a forward failure and must mark + the transport so the next request fails fast. + """ + transport, drafter_side = _make_transport() + try: + session = transport.open_session() + + def _do_trim() -> Exception | None: + try: + session.trim_cache(3) + except Exception as exc: + return exc + return None + + result_box: list[Exception | None] = [] + + def _runner() -> None: + result_box.append(_do_trim()) + + thread = threading.Thread(target=_runner) + thread.start() + # Drain the command frame, then drop the connection without + # acking -- mid-trim drafter death. + _ = _read_uint32s(drafter_side, COMMAND_FRAME_SIZE) + drafter_side.close() + thread.join(timeout=2.0) + assert not thread.is_alive() + assert isinstance(result_box[0], (ConnectionError, OSError)) + assert transport.is_failed is True + finally: + transport.shutdown() + + +# --------------------------------------------------------------------------- +# drafter_serve_loop dispatch +# --------------------------------------------------------------------------- + + +def _empty_cache_factory() -> object: + """Drop-in factory for tests that don't actually run forwards.""" + return [] + + +def test_drafter_serve_loop_handles_shutdown_immediately() -> None: + """A bare OP_SHUTDOWN frame must terminate the serve loop with an ACK.""" + from exo.worker.engines.mlx.generator.remote_drafter import drafter_serve_loop + + target_sock, drafter_sock = _socket_pair() + try: + # Write the shutdown frame from the target side BEFORE entering + # the serve loop so the recv inside the loop completes + # immediately. + shutdown_frame = _build_command_frame( + op=OP_SHUTDOWN, + inputs=[], + num_forwards=0, + trim_amount=0, + session_id=SESSION_ID_NONE, + ) + _write_uint32s(target_sock, shutdown_frame) + + drafter_serve_loop( + draft_model=None, # pyright: ignore[reportArgumentType] + make_draft_cache=_empty_cache_factory, # pyright: ignore[reportArgumentType] + num_draft_tokens=4, + sock=drafter_sock, + ) + + ack = _read_uint32s(target_sock, ACK_FRAME_SIZE) + assert ack[0] == ACK_OK + finally: + target_sock.close() + drafter_sock.close() + + +def test_drafter_serve_loop_handles_end_session_for_unknown_session() -> None: + """``OP_END_SESSION`` for an unknown session is a successful no-op ack. + + Idempotent semantics: a target that crashed without sending + ``OP_END_SESSION`` for a session is cleaned up by the next + ``OP_SHUTDOWN``; targets that retry ``OP_END_SESSION`` after a + transient network error still see an ack. + """ + from exo.worker.engines.mlx.generator.remote_drafter import drafter_serve_loop + + target_sock, drafter_sock = _socket_pair() + try: + end_frame = _build_command_frame( + op=OP_END_SESSION, + inputs=[], + num_forwards=0, + trim_amount=0, + session_id=99, + ) + shutdown_frame = _build_command_frame( + op=OP_SHUTDOWN, + inputs=[], + num_forwards=0, + trim_amount=0, + session_id=SESSION_ID_NONE, + ) + _write_uint32s(target_sock, end_frame) + _write_uint32s(target_sock, shutdown_frame) + + drafter_serve_loop( + draft_model=None, # pyright: ignore[reportArgumentType] + make_draft_cache=_empty_cache_factory, # pyright: ignore[reportArgumentType] + num_draft_tokens=4, + sock=drafter_sock, + ) + + ack_end = _read_uint32s(target_sock, ACK_FRAME_SIZE) + ack_shutdown = _read_uint32s(target_sock, ACK_FRAME_SIZE) + assert ack_end[0] == ACK_OK + assert ack_shutdown[0] == ACK_OK + finally: + target_sock.close() + drafter_sock.close() + + +def test_drafter_serve_loop_rejects_unknown_op() -> None: + """An unknown op code must crash the serve loop loudly.""" + from exo.worker.engines.mlx.generator.remote_drafter import drafter_serve_loop + + target_sock, drafter_sock = _socket_pair() + try: + # Hand-build an unknown op code (255 is not a defined op). + bogus = [255, 0, 0, 0, 0, 0, 0, 0, 0] + _write_uint32s(target_sock, bogus) + + with pytest.raises(RuntimeError, match="Unknown op code"): + drafter_serve_loop( + draft_model=None, # pyright: ignore[reportArgumentType] + make_draft_cache=_empty_cache_factory, # pyright: ignore[reportArgumentType] + num_draft_tokens=4, + sock=drafter_sock, + ) + finally: + target_sock.close() + drafter_sock.close() + + +def test_drafter_serve_loop_rejects_op_for_unknown_session() -> None: + """``OP_TRIM_CACHE`` / ``OP_FORWARD`` against an unallocated session crashes.""" + from exo.worker.engines.mlx.generator.remote_drafter import drafter_serve_loop + + target_sock, drafter_sock = _socket_pair() + try: + trim_frame = _build_command_frame( + op=OP_TRIM_CACHE, + inputs=[], + num_forwards=0, + trim_amount=2, + session_id=42, + ) + _write_uint32s(target_sock, trim_frame) + + with pytest.raises(RuntimeError, match="OP_TRIM_CACHE for unknown session"): + drafter_serve_loop( + draft_model=None, # pyright: ignore[reportArgumentType] + make_draft_cache=_empty_cache_factory, # pyright: ignore[reportArgumentType] + num_draft_tokens=4, + sock=drafter_sock, + ) + finally: + target_sock.close() + drafter_sock.close() + + +# Used by other tests that need to import _ from this module without +# triggering "unused" linter errors on intermediate Iterator hints. +_ = Iterator diff --git a/src/exo/worker/tests/unittests/test_mlx/test_utils_mlx_broadcast.py b/src/exo/worker/tests/unittests/test_mlx/test_utils_mlx_broadcast.py new file mode 100644 index 0000000000..2b00bbf0d5 --- /dev/null +++ b/src/exo/worker/tests/unittests/test_mlx/test_utils_mlx_broadcast.py @@ -0,0 +1,369 @@ +"""Unit tests for the MLX utility primitives used by the V2 multi-target spec loop. + +These exercise the contracts that the asymmetric pipelined drafter +relies on for cross-rank determinism without spinning up MLX or +``mx.distributed``: + + * :func:`mx_broadcast_int_list` -- length / range / root contract. + The single-rank short-circuit can be exercised directly; the + multi-rank ``all_sum`` path is covered indirectly because it + delegates value validation to the same helper. + * :func:`_validate_broadcast_values` -- the int32 bounds are tighter + than Python's ``int`` range, so out-of-range values from a callsite + bug must raise rather than wrap silently. + * :func:`_encode_task_id` / :func:`_decode_task_id` -- ASCII codec + used by ``mx_all_gather_tasks`` to broadcast canonical task IDs. + Round-trip and bounds are verifiable without MLX. + * :func:`mx_all_gather_tasks` -- the single-rank short-circuit. The + multi-rank root-authoritative agreement path needs an actual + ``mx.distributed`` group, so we cover the structural contract here + and the cluster bench exercises the real collective. + +Kept MLX-free so it runs in milliseconds on CI alongside the rest of +the unittest suite. +""" + +from __future__ import annotations + +import pytest + +from exo.shared.types.common import CommandId, ModelId +from exo.shared.types.tasks import TaskId, TextGeneration +from exo.shared.types.text_generation import ( + InputMessage, + InputMessageContent, + TextGenerationTaskParams, +) +from exo.shared.types.worker.instances import InstanceId +from exo.worker.engines.mlx.utils_mlx import ( + _MX_BROADCAST_MAX_VALUE, # pyright: ignore[reportPrivateUsage] + _MX_TASK_ID_BYTES, # pyright: ignore[reportPrivateUsage] + _decode_task_id, # pyright: ignore[reportPrivateUsage] + _encode_task_id, # pyright: ignore[reportPrivateUsage] + _validate_broadcast_values, # pyright: ignore[reportPrivateUsage] + mx_all_gather_tasks, + mx_broadcast_int_list, +) + +# --------------------------------------------------------------------------- +# Validation helper (unit, no MLX needed) +# --------------------------------------------------------------------------- + + +class TestValidateBroadcastValues: + """``_validate_broadcast_values`` rejects values that would corrupt + the int32 ``all_sum`` buffer: negatives wrap on cast, and values + >= 2**31 overflow on sum.""" + + def test_accepts_zero(self) -> None: + _validate_broadcast_values([0, 0, 0]) + + def test_accepts_typical_token_ids(self) -> None: + # Gemma-4 vocab is ~256k; well inside int32 positive range. + _validate_broadcast_values([0, 1, 256_000, 999_999]) + + def test_accepts_max_value(self) -> None: + _validate_broadcast_values([_MX_BROADCAST_MAX_VALUE]) + + def test_rejects_negative(self) -> None: + with pytest.raises(ValueError, match="out of range"): + _validate_broadcast_values([0, -1, 0]) + + def test_rejects_overflow(self) -> None: + with pytest.raises(ValueError, match="out of range"): + _validate_broadcast_values([_MX_BROADCAST_MAX_VALUE + 1]) + + def test_error_includes_offending_index(self) -> None: + with pytest.raises(ValueError, match=r"index 2 = -7"): + _validate_broadcast_values([0, 1, -7, 3]) + + +# --------------------------------------------------------------------------- +# mx_broadcast_int_list (single-rank short-circuit + contract) +# --------------------------------------------------------------------------- + + +class TestMxBroadcastIntListSingleRank: + """The ``group is None`` short-circuit covers single-rank deployments + (the V1 single-target path and the non-distributed test fakes). + Multi-rank cluster behaviour is exercised by the cluster bench + because it needs a real ``mx.distributed`` group.""" + + def test_returns_values_when_root(self) -> None: + result = mx_broadcast_int_list([1, 2, 3], length=3, group=None, is_root=True) + assert result == [1, 2, 3] + # Returned list must be a new object so mutating it doesn't + # corrupt the caller's source list. + result[0] = 99 + # No assertion on the source -- just exercising that the call + # didn't share storage. ``list(values)`` semantics. + + def test_rejects_zero_length(self) -> None: + with pytest.raises(ValueError, match="length must be >= 1"): + mx_broadcast_int_list([], length=0, group=None, is_root=True) + + def test_rejects_length_mismatch(self) -> None: + with pytest.raises(ValueError, match="length 3"): + mx_broadcast_int_list([1, 2], length=3, group=None, is_root=True) + + def test_rejects_none_values_on_root(self) -> None: + with pytest.raises(ValueError, match="length 3"): + mx_broadcast_int_list(None, length=3, group=None, is_root=True) + + def test_rejects_consumer_in_single_rank(self) -> None: + # Only the root has source values; ``is_root=False`` with no + # group means there's no peer to broadcast from -- caller bug. + with pytest.raises(ValueError, match="single-rank short-circuit"): + mx_broadcast_int_list([1, 2, 3], length=3, group=None, is_root=False) + + def test_validates_values_on_root(self) -> None: + with pytest.raises(ValueError, match="out of range"): + mx_broadcast_int_list([0, -1], length=2, group=None, is_root=True) + + +# --------------------------------------------------------------------------- +# Task-list hashing (drift detection) +# --------------------------------------------------------------------------- + + +def _make_task(task_id: str) -> TextGeneration: + """Build a minimal :class:`TextGeneration` for hash-based drift tests. + + The hash function only inspects ``task_id``; the rest of the fields + are filled with the smallest valid values that satisfy Pydantic's + strict-mode validation. Keep the construction here so the cluster- + facing types' field churn doesn't ripple through every assertion + body. + """ + return TextGeneration( + task_id=TaskId(task_id), + instance_id=InstanceId(), + command_id=CommandId(), + task_params=TextGenerationTaskParams( + model=ModelId("mlx-community/test-model"), + input=[ + InputMessage(role="user", content=InputMessageContent("hello")), + ], + max_output_tokens=1, + ), + ) + + +class TestTaskIdCodec: + """``_encode_task_id`` / ``_decode_task_id`` are the wire codec for + the root-authoritative agreement protocol. Round-trip must be + exact and bounds must be enforced; otherwise a corrupt payload + silently misagrees on which task to admit.""" + + def test_round_trip_uuid4(self) -> None: + ident = "01234567-89ab-cdef-0123-456789abcdef" + encoded = _encode_task_id(ident) + assert len(encoded) == _MX_TASK_ID_BYTES + assert _decode_task_id(encoded) == ident + + def test_short_id_is_zero_padded(self) -> None: + encoded = _encode_task_id("alpha") + # Trailing slots stay zero so the decoder's null terminator + # logic stops at the right place. + assert encoded[5:] == [0] * (_MX_TASK_ID_BYTES - 5) + assert _decode_task_id(encoded) == "alpha" + + def test_rejects_oversize_id(self) -> None: + too_long = "a" * (_MX_TASK_ID_BYTES + 1) + with pytest.raises(ValueError, match="exceeds"): + _encode_task_id(too_long) + + def test_rejects_non_ascii_byte_on_decode(self) -> None: + bogus = [200] + [0] * (_MX_TASK_ID_BYTES - 1) + with pytest.raises(ValueError, match="outside ASCII range"): + _decode_task_id(bogus) + + def test_decoder_stops_at_null(self) -> None: + # Two real chars, then a null, then garbage: decoder must + # stop at the null and ignore the trailing data. + slots = [ord("a"), ord("b"), 0, ord("z")] + [0] * (_MX_TASK_ID_BYTES - 4) + assert _decode_task_id(slots) == "ab" + + +# --------------------------------------------------------------------------- +# mx_all_gather_tasks single-rank short-circuit +# --------------------------------------------------------------------------- + + +class TestMxAllGatherTasksSingleRank: + """Single-rank short-circuit: returns the local task list as-is and + never invokes a collective. The multi-rank root-authoritative path + needs an actual ``mx.distributed`` group and is exercised by the + cluster bench.""" + + def test_empty_input(self) -> None: + agreed, different = mx_all_gather_tasks([], group=None) + assert agreed == [] + assert different == [] + + def test_passes_through_tasks(self) -> None: + tasks = [_make_task("task-1"), _make_task("task-2")] + agreed, different = mx_all_gather_tasks(tasks, group=None) + assert agreed == tasks + assert different == [] + + def test_returns_a_copy(self) -> None: + # The caller mutates ``self._maybe_queue`` after the gather; + # the returned list must be a different object so post-gather + # mutation doesn't corrupt the agreement view. + tasks = [_make_task("task-1")] + agreed, _different = mx_all_gather_tasks(tasks, group=None) + assert agreed is not tasks + + +# --------------------------------------------------------------------------- +# Two-phase intersection agreement: end-to-end via in-process simulation +# --------------------------------------------------------------------------- + + +def _agree_intersection( + rank_views: list[list[TextGeneration]], +) -> list[tuple[list[TextGeneration], list[TextGeneration]]]: + """Run the two-phase intersection protocol entirely in-process. + + Mirrors :func:`mx_all_gather_tasks` for ``len(rank_views)`` ranks + without spinning up MLX. Phase 1 is root's broadcast (the first + entry in ``rank_views`` is treated as root); phase 2 is the + cross-rank vote (sum of indicator vectors). Returns each rank's + ``(agreed, leftover)`` pair so tests can assert that all ranks + land on the SAME ``agreed`` set, which is the whole point of the + protocol -- without it, divergent admit decisions leave one rank + in the spec loop while the other re-enters ``agree_on_tasks``, + causing collective-stream cross-talk and downstream + ``IndexError`` in the detokenizer when broadcast token slots + arrive scrambled. + """ + from exo.worker.engines.mlx.utils_mlx import ( + _MX_AGREE_BUFFER_LEN, # pyright: ignore[reportPrivateUsage] + _MX_AGREE_MAX_TASKS, # pyright: ignore[reportPrivateUsage] + _MX_TASK_ID_BYTES, # pyright: ignore[reportPrivateUsage] + _decode_task_id, # pyright: ignore[reportPrivateUsage] + _encode_task_id, # pyright: ignore[reportPrivateUsage] + ) + + if not rank_views: + return [] + group_size = len(rank_views) + root_tasks = rank_views[0] + + admitted = root_tasks[:_MX_AGREE_MAX_TASKS] + payload: list[int] = [len(admitted)] + for task in admitted: + payload.extend(_encode_task_id(task.task_id)) + payload.extend([0] * (_MX_AGREE_BUFFER_LEN - len(payload))) + + count = payload[0] + canonical_ids: list[str] = [] + for i in range(count): + start = 1 + i * _MX_TASK_ID_BYTES + end = start + _MX_TASK_ID_BYTES + canonical_ids.append(_decode_task_id(payload[start:end])) + + rank_locals: list[dict[TaskId, TextGeneration]] = [ + {t.task_id: t for t in tasks} for tasks in rank_views + ] + votes_per_rank = [ + [1 if cid in local else 0 for cid in canonical_ids] for local in rank_locals + ] + summed = [sum(votes[i] for votes in votes_per_rank) for i in range(count)] + + results: list[tuple[list[TextGeneration], list[TextGeneration]]] = [] + for local in rank_locals: + agreed: list[TextGeneration] = [] + local_remaining = dict(local) + for i, cid in enumerate(canonical_ids): + if summed[i] != group_size: + continue + task = local_remaining.pop(TaskId(cid), None) + if task is not None: + agreed.append(task) + leftover = list(local_remaining.values()) + results.append((agreed, leftover)) + return results + + +class TestIntersectionAgreement: + """Cross-rank intersection semantics. The protocol's correctness + contract is that every rank that returns from + :func:`mx_all_gather_tasks` lands on the SAME ``agreed`` set, so + the next ``_admit_queued_tasks`` admits identical tasks on every + rank -- preventing the divergence that historically led to + cross-talk between admit collectives and spec-loop collectives.""" + + def test_unanimous_admission(self) -> None: + a_root = _make_task("alpha") + a_peer = _make_task("alpha") + results = _agree_intersection([[a_root], [a_peer]]) + assert len(results) == 2 + for agreed, leftover in results: + assert [t.task_id for t in agreed] == ["alpha"] + assert leftover == [] + + def test_root_only_task_deferred_on_both_ranks(self) -> None: + # Root has task that peer hasn't received yet: NEITHER rank + # admits it. This is the whole reason for intersection + # rather than root-authoritative. + results = _agree_intersection([[_make_task("alpha")], []]) + for agreed, _ in results: + assert agreed == [] + assert [t.task_id for t in results[0][1]] == ["alpha"] + assert results[1][1] == [] + + def test_peer_only_task_deferred_on_both_ranks(self) -> None: + results = _agree_intersection([[], [_make_task("future")]]) + for agreed, _ in results: + assert agreed == [] + assert results[0][1] == [] + assert [t.task_id for t in results[1][1]] == ["future"] + + def test_partial_overlap_only_intersection_admitted(self) -> None: + a_root = _make_task("alpha") + a_peer = _make_task("alpha") + beta = _make_task("beta") + gamma = _make_task("gamma") + results = _agree_intersection([[a_root, beta], [a_peer, gamma]]) + for agreed, _ in results: + assert [t.task_id for t in agreed] == ["alpha"] + assert [t.task_id for t in results[0][1]] == ["beta"] + assert [t.task_id for t in results[1][1]] == ["gamma"] + + def test_three_rank_intersection(self) -> None: + # 3-rank target: agreed is what *every* rank has. Anything + # short of unanimous stays out. + results = _agree_intersection( + [ + [_make_task("alpha"), _make_task("beta")], + [_make_task("alpha"), _make_task("beta")], + [_make_task("alpha")], + ] + ) + for agreed, _ in results: + assert [t.task_id for t in agreed] == ["alpha"] + + def test_canonical_order_is_root_order(self) -> None: + ids_root = ["c", "a", "b"] + ids_peer = ["b", "a", "c"] + results = _agree_intersection( + [ + [_make_task(i) for i in ids_root], + [_make_task(i) for i in ids_peer], + ] + ) + for agreed, _ in results: + assert [t.task_id for t in agreed] == ids_root + + def test_root_caps_at_max_tasks(self) -> None: + from exo.worker.engines.mlx.utils_mlx import ( + _MX_AGREE_MAX_TASKS, # pyright: ignore[reportPrivateUsage] + ) + + many = [_make_task(f"t{i:02d}") for i in range(_MX_AGREE_MAX_TASKS + 4)] + peer_copy = [_make_task(t.task_id) for t in many] + results = _agree_intersection([many, peer_copy]) + for agreed, _ in results: + assert len(agreed) == _MX_AGREE_MAX_TASKS diff --git a/src/exo/worker/tests/unittests/test_plan/test_runner_lifecycle.py b/src/exo/worker/tests/unittests/test_plan/test_runner_lifecycle.py index bfde8a1d1e..6858fab348 100644 --- a/src/exo/worker/tests/unittests/test_plan/test_runner_lifecycle.py +++ b/src/exo/worker/tests/unittests/test_plan/test_runner_lifecycle.py @@ -6,7 +6,9 @@ from exo.shared.types.worker.runners import ( RunnerFailed, RunnerId, + RunnerIdle, RunnerReady, + RunnerRunning, RunnerStatus, ) from exo.utils.keyed_backoff import KeyedBackoff @@ -182,6 +184,147 @@ def test_plan_does_not_create_runner_when_supervisor_already_present(): assert result is None +def test_plan_kills_local_when_peer_cycled_back_to_idle(): + """ + Restart-cascade regression: a peer rank crashed mid-task, its supervisor + immediately respawned a fresh process which emitted ``RunnerIdle``, and + the transient ``RunnerFailed`` window was too short for our plan loop to + observe. The local rank is still ``RunnerRunning`` from before the peer + crash. Without this rule the bootstrap predicate (``all_runners_connecting`` + in ``_init_distributed_backend``) never fires and the respawned peer is + stuck in ``RunnerIdle`` forever. + + See PR #15 (regression: aborted K=8 sweep at 14:35:05). + """ + shard1 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=0, world_size=2) + shard2 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=1, world_size=2) + instance = get_mlx_ring_instance( + instance_id=INSTANCE_1_ID, + model_id=MODEL_A_ID, + node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID}, + runner_to_shard={RUNNER_1_ID: shard1, RUNNER_2_ID: shard2}, + ) + bound_instance = BoundInstance( + instance=instance, bound_runner_id=RUNNER_1_ID, bound_node_id=NODE_A + ) + runner = FakeRunnerSupervisor(bound_instance=bound_instance, status=RunnerRunning()) + + runners = {RUNNER_1_ID: runner} + instances = {INSTANCE_1_ID: instance} + all_runners: dict[RunnerId, RunnerStatus] = { + RUNNER_1_ID: RunnerRunning(), + # Peer just respawned: process is up but hasn't initialized + # the distributed backend yet. + RUNNER_2_ID: RunnerIdle(), + } + + result = plan_mod.plan( + node_id=NODE_A, + runners=runners, # type: ignore[arg-type] + global_download_status={NODE_A: []}, + instances=instances, + all_runners=all_runners, + tasks={}, + input_chunk_buffer={}, + image_cache={}, + instance_backoff=KeyedBackoff(), + download_backoff=KeyedBackoff(), + ) + + assert isinstance(result, Shutdown) + assert result.instance_id == INSTANCE_1_ID + assert result.runner_id == RUNNER_1_ID + + +def test_plan_does_not_kill_local_when_peer_idle_but_local_only_loaded(): + """ + During initial bootstrap a peer can legitimately sit at ``RunnerIdle`` + while we have completed our own ``LoadModel`` (loading is per-rank + without a collective barrier; see ``runner.py`` case ``LoadModel``). + The restart-cascade rule must NOT fire here -- only ``RunnerRunning`` + on the local rank guarantees we previously cleared warmup with all + peers, which is the precondition that makes a peer ``RunnerIdle`` + a process-restart signal. + """ + shard1 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=0, world_size=2) + shard2 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=1, world_size=2) + instance = get_mlx_ring_instance( + instance_id=INSTANCE_1_ID, + model_id=MODEL_A_ID, + node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID}, + runner_to_shard={RUNNER_1_ID: shard1, RUNNER_2_ID: shard2}, + ) + bound_instance = BoundInstance( + instance=instance, bound_runner_id=RUNNER_1_ID, bound_node_id=NODE_A + ) + runner = FakeRunnerSupervisor(bound_instance=bound_instance, status=RunnerReady()) + + runners = {RUNNER_1_ID: runner} + instances = {INSTANCE_1_ID: instance} + all_runners: dict[RunnerId, RunnerStatus] = { + RUNNER_1_ID: RunnerReady(), + RUNNER_2_ID: RunnerIdle(), + } + + result = plan_mod.plan( + node_id=NODE_A, + runners=runners, # type: ignore[arg-type] + global_download_status={NODE_A: []}, + instances=instances, + all_runners=all_runners, + tasks={}, + input_chunk_buffer={}, + image_cache={}, + instance_backoff=KeyedBackoff(), + download_backoff=KeyedBackoff(), + ) + + assert not isinstance(result, Shutdown), ( + "RunnerReady + peer=Idle is normal initial bootstrap; cascade " + "rule must only fire after the local rank has been observed in " + "RunnerRunning (proving warmup completed for all ranks)" + ) + + +def test_plan_does_not_kill_single_rank_instance_on_idle_self(): + """ + The restart-cascade rule must only fire on multi-rank instances. For a + single-rank instance the local runner cycling through ``RunnerIdle`` + on its own is a normal transient (initial spawn) and there is no peer + that needs to re-bootstrap. + """ + shard = get_pipeline_shard_metadata(model_id=MODEL_A_ID, device_rank=0) + instance = get_mlx_ring_instance( + instance_id=INSTANCE_1_ID, + model_id=MODEL_A_ID, + node_to_runner={NODE_A: RUNNER_1_ID}, + runner_to_shard={RUNNER_1_ID: shard}, + ) + bound_instance = BoundInstance( + instance=instance, bound_runner_id=RUNNER_1_ID, bound_node_id=NODE_A + ) + runner = FakeRunnerSupervisor(bound_instance=bound_instance, status=RunnerRunning()) + + runners = {RUNNER_1_ID: runner} + instances = {INSTANCE_1_ID: instance} + all_runners: dict[RunnerId, RunnerStatus] = {RUNNER_1_ID: RunnerRunning()} + + result = plan_mod.plan( + node_id=NODE_A, + runners=runners, # type: ignore[arg-type] + global_download_status={NODE_A: []}, + instances=instances, + all_runners=all_runners, + tasks={}, + input_chunk_buffer={}, + image_cache={}, + instance_backoff=KeyedBackoff(), + download_backoff=KeyedBackoff(), + ) + + assert not isinstance(result, Shutdown) + + def test_plan_does_not_create_runner_for_unassigned_node(): """ If this node does not appear in shard_assignments.node_to_runner, diff --git a/src/exo/worker/tests/unittests/test_plan/test_warmup.py b/src/exo/worker/tests/unittests/test_plan/test_warmup.py index 46e372f6c1..d87f67c7b5 100644 --- a/src/exo/worker/tests/unittests/test_plan/test_warmup.py +++ b/src/exo/worker/tests/unittests/test_plan/test_warmup.py @@ -5,6 +5,8 @@ RunnerIdle, RunnerLoaded, RunnerLoading, + RunnerReady, + RunnerRunning, RunnerWarmingUp, ) from exo.utils.keyed_backoff import KeyedBackoff @@ -321,6 +323,58 @@ def test_plan_does_not_start_warmup_for_accepting_rank_until_all_loaded_or_warmi assert result is None +def test_plan_starts_warmup_for_connecting_rank_when_peer_already_ready(): + """ + Regression test for the asymmetric drafter race: the drafter rank's + warmup is near-instant (one forward pass) so by the time the + connecting rank's plan loop polls for state the drafter has often + already advanced past ``RunnerWarmingUp`` to ``RunnerReady`` / + ``RunnerRunning``. The connecting rank must still treat that as + "the peer is past the warmup barrier" and start its own warmup, + otherwise it stalls in ``RunnerLoaded`` forever. + """ + shard0 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=0, world_size=2) + shard1 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=1, world_size=2) + instance = get_mlx_ring_instance( + instance_id=INSTANCE_1_ID, + model_id=MODEL_A_ID, + node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID}, + runner_to_shard={RUNNER_1_ID: shard0, RUNNER_2_ID: shard1}, + ) + + bound_instance = BoundInstance( + instance=instance, bound_runner_id=RUNNER_1_ID, bound_node_id=NODE_A + ) + local_runner = FakeRunnerSupervisor( + bound_instance=bound_instance, status=RunnerLoaded() + ) + + runners = {RUNNER_1_ID: local_runner} + instances = {INSTANCE_1_ID: instance} + + for peer_status in (RunnerReady(), RunnerRunning()): + all_runners = { + RUNNER_1_ID: RunnerLoaded(), + RUNNER_2_ID: peer_status, + } + result = plan_mod.plan( + node_id=NODE_A, + runners=runners, # type: ignore + global_download_status={NODE_A: []}, + instances=instances, + all_runners=all_runners, + tasks={}, + input_chunk_buffer={}, + image_cache={}, + instance_backoff=KeyedBackoff(), + download_backoff=KeyedBackoff(), + ) + assert isinstance(result, StartWarmup), ( + f"connecting rank should start warmup when peer is {type(peer_status).__name__}" + ) + assert result.instance_id == INSTANCE_1_ID + + def test_plan_does_not_start_warmup_for_connecting_rank_until_others_warming(): """ Connecting rank (device_rank == 0) should not start warmup diff --git a/src/exo/worker/tests/unittests/test_runner/test_event_ordering.py b/src/exo/worker/tests/unittests/test_runner/test_event_ordering.py index dac8d884c4..5cffd7d8d5 100644 --- a/src/exo/worker/tests/unittests/test_runner/test_event_ordering.py +++ b/src/exo/worker/tests/unittests/test_runner/test_event_ordering.py @@ -124,12 +124,21 @@ class MockLoadOutput: @pytest.fixture def patch_out_mlx(monkeypatch: pytest.MonkeyPatch): - # initialize_mlx returns a mock group - monkeypatch.setattr(mlx_builder, "initialize_mlx", make_nothin(MockGroup())) + # initialize_mlx returns an MlxGroupSplit; for symmetric placement the + # target subgroup is the same object as the parent. + from exo.worker.engines.mlx.utils_mlx import MlxGroupSplit + + mock_group = MockGroup() + mock_split = MlxGroupSplit( + parent=mock_group, # pyright: ignore[reportArgumentType] + target_subgroup=mock_group, # pyright: ignore[reportArgumentType] + drafter_rank_in_parent=None, + ) + monkeypatch.setattr(mlx_builder, "initialize_mlx", make_nothin(mock_split)) def lmi_gen(): yield MockLoadOutput(1, 1) - return (1, MockTokenizer, None, None) + return (1, MockTokenizer, None, None, None) monkeypatch.setattr(mlx_builder, "load_mlx_items", make_nothin(lmi_gen())) monkeypatch.setattr(mlx_batch_generator, "warmup_inference", make_nothin(1)) diff --git a/src/exo/worker/tests/unittests/test_runner/test_sequential_generator_batch_prefill.py b/src/exo/worker/tests/unittests/test_runner/test_sequential_generator_batch_prefill.py new file mode 100644 index 0000000000..778bfd5076 --- /dev/null +++ b/src/exo/worker/tests/unittests/test_runner/test_sequential_generator_batch_prefill.py @@ -0,0 +1,371 @@ +# pyright: reportAny=false, reportUnknownVariableType=false +# pyright: reportUnknownMemberType=false, reportUnknownArgumentType=false +# pyright: reportUnknownLambdaType=false, reportPrivateUsage=false +# pyright: reportInvalidCast=false, reportArgumentType=false +"""Integration tests for :meth:`SequentialGenerator._admit_queued_tasks`. + +These tests verify the routing decisions in the batched-prefill path: +which queued tasks get co-prefilled in a single forward, which fall +back to per-slot, and how the env-var gate / eligibility predicate +combine. The actual numerical correctness of :func:`batched_prefill` +is covered by ``tests/test_mlx/test_batched_prefill.py`` against a +real (random-weight) model; these tests stub the prefill function +itself and assert on the SequentialGenerator's branching only. +""" + +from __future__ import annotations + +from collections import OrderedDict, deque +from collections.abc import Generator +from typing import Any, cast + +import mlx.core as mx +import pytest + +from exo.shared.types.common import CommandId, ModelId +from exo.shared.types.events import Event +from exo.shared.types.tasks import TextGeneration +from exo.shared.types.text_generation import ( + InputMessage, + InputMessageContent, + TextGenerationTaskParams, +) +from exo.shared.types.worker.instances import InstanceId +from exo.shared.types.worker.runner_response import GenerationResponse +from exo.utils.channels import MpSender +from exo.worker.engines.mlx.cache import KVPrefixCache +from exo.worker.engines.mlx.types import KVCacheType +from exo.worker.runner.llm_inference import batch_generator as bg_mod +from exo.worker.runner.llm_inference.batch_generator import ( + EXO_BATCH_PREFILL, + BatchedPrefillUnsupportedError, + SequentialGenerator, +) + + +class _FakeEventSender: + def __init__(self) -> None: + self.events: list[Event] = [] + + def send(self, event: Event) -> None: + self.events.append(event) + + +def _make_text_task( + text: str, + *, + images: list[str] | None = None, + prefill_endpoint: str | None = None, + bench: bool = True, +) -> TextGeneration: + extra_kwargs: dict[str, object] = {} + if images is not None: + extra_kwargs["images"] = images + if prefill_endpoint is not None: + extra_kwargs["prefill_endpoint"] = prefill_endpoint + return TextGeneration( + instance_id=InstanceId("instance"), + command_id=CommandId(f"cmd-{text}"), + task_params=TextGenerationTaskParams( + model=ModelId("mlx-community/test-model"), + input=[InputMessage(role="user", content=InputMessageContent(text))], + bench=bench, + **extra_kwargs, + ), + ) + + +def _bare_seq_generator( + sender: _FakeEventSender, + initial_queue: deque[TextGeneration], + *, + draft_model: object | None = None, + group: object | None = None, + max_concurrent_tasks: int = 4, +) -> SequentialGenerator: + """Construct a SequentialGenerator without invoking dataclass init. + + The dataclass __init__ wants a real MLX model + tokenizer. We bypass + it and stub only the attributes the admit/start path reads. + """ + g = object.__new__(SequentialGenerator) + g.model = cast(Any, object()) + g.tokenizer = cast(Any, object()) + g.model_id = ModelId("mlx-community/test-model") + g.device_rank = 0 + g.event_sender = cast(MpSender[Event], cast(object, sender)) + g.group = cast(Any, group) + g.kv_prefix_cache = cast(KVPrefixCache | None, None) + g.tool_parser = None + g.vision_processor = None + g.draft_model = cast(Any, draft_model) + g.drafter_kv_prefix_cache = None + g.draft_model_id = None + g.num_draft_tokens = None + g.drafter_min_output_tokens = None + g.adaptive_draft_tokens = False + g.drafter_rank_in_parent = None + g.remote_drafter_transport = None + g.check_for_cancel_every = 50 + g._cancelled_tasks = set() + g._maybe_queue = [] + g._maybe_cancel = [] + g._all_tasks = {task.task_id: task for task in initial_queue} + g._queue = initial_queue + g._active_tasks = OrderedDict() + g._pending_failed = [] + g._recent_acceptance = deque() + g.max_concurrent_tasks = max_concurrent_tasks + return g + + +@pytest.fixture(autouse=True) +def _clear_env( # pyright: ignore[reportUnusedFunction] + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Default to enabled so each test sets the env explicitly when needed.""" + monkeypatch.delenv(EXO_BATCH_PREFILL, raising=False) + + +def _stub_prep_to_eligible( + monkeypatch: pytest.MonkeyPatch, + eligible_ids: set[str], +) -> None: + """Stub ``_prepare_for_batch_prefill`` to mark ``eligible_ids`` as eligible. + + The stub returns a tuple shaped like the production helper for + eligible tasks (with a length-3 mx.array prompt and an empty cache + list as a placeholder); ineligible tasks return ``None`` so the + caller routes them to the per-slot path. + """ + + def fake_prep( + _self: SequentialGenerator, task: TextGeneration + ) -> tuple[TextGeneration, mx.array, KVCacheType] | None: + if str(task.command_id) in eligible_ids: + return (task, mx.array([1, 2, 3]), cast(KVCacheType, [])) + return None + + monkeypatch.setattr(SequentialGenerator, "_prepare_for_batch_prefill", fake_prep) + + +def _stub_start_one(monkeypatch: pytest.MonkeyPatch) -> list[tuple[str, bool]]: + """Stub ``_start_one`` to record (command_id, used_precomputed_cache) calls.""" + calls: list[tuple[str, bool]] = [] + + def fake_start_one( + gen: SequentialGenerator, + task: TextGeneration, + *, + precomputed_target_cache: KVCacheType | None = None, + ) -> None: + calls.append((str(task.command_id), precomputed_target_cache is not None)) + gen._active_tasks[task.task_id] = ( + task, + cast(Generator[GenerationResponse], iter(())), + cast(Any, object()), + cast(Any, iter(())), + ) + + monkeypatch.setattr(SequentialGenerator, "_start_one", fake_start_one) + return calls + + +def _stub_batched_prefill( + monkeypatch: pytest.MonkeyPatch, + *, + side_effect: BaseException | None = None, +) -> list[int]: + """Stub :func:`batched_prefill`. Returns the list of batch sizes seen. + + When ``side_effect`` is provided the stub raises it instead of + returning success — used to test the fallback paths. + """ + seen_batch_sizes: list[int] = [] + + def fake_batched( + *, + model: object, + prompt_tokens_list: list[mx.array], + caches_list: list[KVCacheType], + **_: object, + ) -> tuple[float, int]: + del model, caches_list + seen_batch_sizes.append(len(prompt_tokens_list)) + if side_effect is not None: + raise side_effect + return 100.0, sum(int(p.size) - 1 for p in prompt_tokens_list) + + monkeypatch.setattr(bg_mod, "batched_prefill", fake_batched) + return seen_batch_sizes + + +def test_two_eligible_tasks_use_batched_prefill_path( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Two batch-eligible tasks must share one ``batched_prefill`` call.""" + sender = _FakeEventSender() + tasks = [_make_text_task(f"t{i}") for i in range(2)] + g = _bare_seq_generator(sender, deque(tasks)) + + _stub_prep_to_eligible(monkeypatch, {f"cmd-t{i}" for i in range(2)}) + calls = _stub_start_one(monkeypatch) + sizes = _stub_batched_prefill(monkeypatch) + + g._admit_queued_tasks() + + assert sizes == [2], "exactly one batched_prefill call with B=2" + assert [c[0] for c in calls] == ["cmd-t0", "cmd-t1"] + assert all(used for _, used in calls), ( + "every eligible task must receive a precomputed_target_cache" + ) + + +def test_single_eligible_task_falls_back_to_per_slot( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """A 1-eligible admit cycle skips batched_prefill (no parallelism win).""" + sender = _FakeEventSender() + tasks = [_make_text_task("only")] + g = _bare_seq_generator(sender, deque(tasks)) + + _stub_prep_to_eligible(monkeypatch, {"cmd-only"}) + calls = _stub_start_one(monkeypatch) + sizes = _stub_batched_prefill(monkeypatch) + + g._admit_queued_tasks() + + assert sizes == [], "batched_prefill must not be called for a single slot" + assert calls == [("cmd-only", False)] + + +def test_mixed_eligibility_routes_correctly( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Eligible + ineligible tasks split: batched for the eligible 2, per-slot for the rest.""" + sender = _FakeEventSender() + tasks = [_make_text_task(f"t{i}") for i in range(4)] + g = _bare_seq_generator(sender, deque(tasks)) + + _stub_prep_to_eligible(monkeypatch, {"cmd-t0", "cmd-t2"}) + calls = _stub_start_one(monkeypatch) + sizes = _stub_batched_prefill(monkeypatch) + + g._admit_queued_tasks() + + assert sizes == [2] + by_id = {cid: used for cid, used in calls} + assert by_id["cmd-t0"] is True + assert by_id["cmd-t2"] is True + assert by_id["cmd-t1"] is False + assert by_id["cmd-t3"] is False + + +def test_env_var_disables_batching(monkeypatch: pytest.MonkeyPatch) -> None: + """``EXO_BATCH_PREFILL=0`` must skip batched_prefill entirely.""" + monkeypatch.setenv(EXO_BATCH_PREFILL, "0") + sender = _FakeEventSender() + tasks = [_make_text_task(f"t{i}") for i in range(3)] + g = _bare_seq_generator(sender, deque(tasks)) + + _stub_prep_to_eligible(monkeypatch, {f"cmd-t{i}" for i in range(3)}) + calls = _stub_start_one(monkeypatch) + sizes = _stub_batched_prefill(monkeypatch) + + g._admit_queued_tasks() + + assert sizes == [] + assert all(not used for _, used in calls) + assert {cid for cid, _ in calls} == {f"cmd-t{i}" for i in range(3)} + + +def test_unsupported_cache_falls_back_to_per_slot( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """:class:`BatchedPrefillUnsupportedError` must demote every candidate to per-slot. + + This is the runner-liveness contract: a model whose cache layers + do not implement ``merge``/``extract`` (e.g. ``DeepseekV4Cache``) + surfaces the unsupported error from inside the helper; the + SequentialGenerator must catch it and continue with the per-slot + prefill path instead of crashing the runner subprocess. + """ + sender = _FakeEventSender() + tasks = [_make_text_task(f"t{i}") for i in range(2)] + g = _bare_seq_generator(sender, deque(tasks)) + + _stub_prep_to_eligible(monkeypatch, {f"cmd-t{i}" for i in range(2)}) + calls = _stub_start_one(monkeypatch) + _stub_batched_prefill( + monkeypatch, + side_effect=BatchedPrefillUnsupportedError("test: unsupported cache layer"), + ) + + g._admit_queued_tasks() + + assert calls == [("cmd-t0", False), ("cmd-t1", False)] + + +def test_distributed_group_disqualifies_batching() -> None: + """Multi-rank target must not batch; pipeline_parallel_prefill owns the driver loop.""" + sender = _FakeEventSender() + task = _make_text_task("only") + + class _FakeGroup: + def size(self) -> int: + return 4 + + g = _bare_seq_generator(sender, deque([task]), group=_FakeGroup()) + assert ( + g._batch_eligible_for_prefill(task) + is False + ) + + +def test_vision_request_disqualifies_batching() -> None: + """Vision prep needs per-task embed-table patching; never batch.""" + sender = _FakeEventSender() + task = _make_text_task("img-task", images=["data:image/png;base64,..."]) + g = _bare_seq_generator(sender, deque([task])) + assert ( + g._batch_eligible_for_prefill(task) + is False + ) + + +def test_remote_prefill_disqualifies_batching() -> None: + """Remote prefill ships the cache off-target; the local batched forward is moot.""" + sender = _FakeEventSender() + task = _make_text_task("rem", prefill_endpoint="http://prefill:8000") + g = _bare_seq_generator(sender, deque([task])) + assert ( + g._batch_eligible_for_prefill(task) + is False + ) + + +def test_inprocess_drafter_disqualifies_batching() -> None: + """In-process model drafter needs paired drafter prefill; V1 only batches the asymmetric (no draft_model) path.""" + sender = _FakeEventSender() + task = _make_text_task("draft") + g = _bare_seq_generator(sender, deque([task]), draft_model=object()) + assert ( + g._batch_eligible_for_prefill(task) + is False + ) + + +def test_asymmetric_drafter_target_qualifies_for_batching() -> None: + """Asymmetric drafter target rank has ``draft_model=None`` so it batches. + + Drafter prefill happens out-of-band over the wire (per-session + ``OP_PREFILL``) so the target-side batching is independent of + drafter alignment. + """ + sender = _FakeEventSender() + task = _make_text_task("asym") + g = _bare_seq_generator(sender, deque([task]), draft_model=None) + assert ( + g._batch_eligible_for_prefill(task) + is True + ) diff --git a/src/exo/worker/tests/unittests/test_runner/test_sequential_generator_errors.py b/src/exo/worker/tests/unittests/test_runner/test_sequential_generator_errors.py new file mode 100644 index 0000000000..56f4a20a2b --- /dev/null +++ b/src/exo/worker/tests/unittests/test_runner/test_sequential_generator_errors.py @@ -0,0 +1,428 @@ +"""Resilience tests for :class:`SequentialGenerator`. + +Regression coverage for PR #15: a per-task ``ValueError`` raised during +drafter construction (e.g. K above the transport's wire-protocol budget) +must not propagate out of ``step()`` and crash the runner subprocess. +The pre-fix behaviour was that ``_start_next`` re-raised after sending +the error chunk, which propagated through ``handle_generation_tasks`` +and triggered ``RunnerFailed`` on the supervisor, leaving the peer rank +wedged in ``RunnerRunning`` while the respawned target sat in +``RunnerIdle`` forever. + +These tests bypass the SequentialGenerator dataclass __init__ (which +needs a full MLX model + tokenizer stack) and patch only the failing +hot-spot, mirroring the pattern used by ``test_batch_generator_errors``. +""" + +from __future__ import annotations + +from collections import OrderedDict, deque +from collections.abc import Iterator +from typing import Any, cast + +import pytest + +from exo.shared.types.chunks import ErrorChunk +from exo.shared.types.common import CommandId, ModelId +from exo.shared.types.events import ChunkGenerated, Event +from exo.shared.types.tasks import TextGeneration +from exo.shared.types.text_generation import ( + InputMessage, + InputMessageContent, + TextGenerationTaskParams, +) +from exo.shared.types.worker.instances import InstanceId +from exo.utils.channels import MpSender +from exo.worker.runner.llm_inference.batch_generator import ( + FinishedResponse, + GeneratorQueue, + SequentialGenerator, +) + + +class _FakeEventSender: + def __init__(self) -> None: + self.events: list[Event] = [] + + def send(self, event: Event) -> None: + self.events.append(event) + + +def _make_text_task(text: str = "hello", bench: bool = False) -> TextGeneration: + return TextGeneration( + instance_id=InstanceId("instance"), + command_id=CommandId(f"command-{text}"), + task_params=TextGenerationTaskParams( + model=ModelId("mlx-community/test-model"), + input=[ + InputMessage(role="user", content=InputMessageContent(text)), + ], + bench=bench, + ), + ) + + +def _bare_sequential_generator( + sender: _FakeEventSender, + queue: deque[TextGeneration], +) -> SequentialGenerator: + """Construct a :class:`SequentialGenerator` without running its dataclass init. + + Only the attributes touched by ``step()`` / ``_start_next()`` / + ``_send_error()`` are wired in, so the test stays MLX-free and focused + on the resilience contract. + """ + generator = object.__new__(SequentialGenerator) + generator.model_id = ModelId("mlx-community/test-model") + generator.device_rank = 0 + generator.tokenizer = cast(Any, object()) + generator.event_sender = cast(MpSender[Event], cast(object, sender)) + generator.group = None + generator._maybe_queue = [] # pyright: ignore[reportPrivateUsage] + generator._maybe_cancel = [] # pyright: ignore[reportPrivateUsage] + generator._all_tasks = { # pyright: ignore[reportPrivateUsage] + task.task_id: task for task in queue + } + generator._queue = queue # pyright: ignore[reportPrivateUsage] + generator._cancelled_tasks = set() # pyright: ignore[reportPrivateUsage] + generator._active_tasks = OrderedDict() # pyright: ignore[reportPrivateUsage] + generator._pending_failed = [] # pyright: ignore[reportPrivateUsage] + generator._recent_acceptance = deque() # pyright: ignore[reportPrivateUsage] + generator.adaptive_draft_tokens = False + generator.max_concurrent_tasks = 1 + return generator + + +def test_start_next_failure_emits_finished_and_does_not_raise( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Drafter construction failure must surface as ``FinishedResponse``.""" + sender = _FakeEventSender() + task = _make_text_task("first") + generator = _bare_sequential_generator(sender, deque([task])) + + def boom(_self: SequentialGenerator, _task: TextGeneration) -> None: + raise ValueError("num_draft_tokens (8) exceeds transport's max (5)") + + def no_agree(_self: SequentialGenerator) -> None: + return None + + monkeypatch.setattr( + SequentialGenerator, + "_build_generator", + boom, + ) + monkeypatch.setattr( + SequentialGenerator, + "agree_on_tasks", + no_agree, + ) + + results = list(generator.step()) + + assert len(results) >= 1 + assert results[0][0] == task.task_id + assert isinstance(results[0][1], FinishedResponse) + assert ( + len(generator._active_tasks) == 0 # pyright: ignore[reportPrivateUsage] + ), "no active task should be set after failed _start_next" + assert len(sender.events) == 1 + assert isinstance(sender.events[0], ChunkGenerated) + assert isinstance(sender.events[0].chunk, ErrorChunk) + assert "num_draft_tokens" in sender.events[0].chunk.error_message + + +def test_runner_survives_sequential_failure_and_serves_next_task( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """After a per-task failure the runner must still serve the next task. + + This is the core regression: pre-fix, the first task's failure + propagated out of ``step()`` and tore down the runner subprocess, so + the second task never got a chance to run. We use two failing tasks + so the test stays MLX-free; what matters is that ``step()`` survives + both failures and surfaces them as ``FinishedResponse`` rather than + propagating an exception out of the runner loop. + + Post-concurrency-refactor (PR #15 round-robin), ``step`` drains the + queue up to ``max_concurrent_tasks`` per tick rather than admitting + one task per tick, so both failures may surface on tick 1. The + contract that matters is unchanged: every queued task must reach + ``_build_generator`` and surface a ``FinishedResponse`` without + raising. + """ + sender = _FakeEventSender() + first = _make_text_task("first") + second = _make_text_task("second") + generator = _bare_sequential_generator(sender, deque([first, second])) + + call_log: list[str] = [] + + def boom(_self: SequentialGenerator, task: TextGeneration) -> object: + call_log.append(str(task.task_id)) + raise ValueError("num_draft_tokens (8) exceeds transport's max (5)") + + def no_agree(_self: SequentialGenerator) -> None: + return None + + monkeypatch.setattr( + SequentialGenerator, + "_build_generator", + boom, + ) + monkeypatch.setattr( + SequentialGenerator, + "agree_on_tasks", + no_agree, + ) + + finished_task_ids: set[Any] = set() + while finished_task_ids != {first.task_id, second.task_id}: + produced = list(generator.step()) + for task_id, response in produced: + if isinstance(response, FinishedResponse): + finished_task_ids.add(task_id) + # Guard the loop: with max_concurrent_tasks=1 (helper default) + # this finishes in one or two ticks; if step() ever loops without + # progress the runner has regressed and we want a hard fail. + if not produced and not generator._queue and not generator._pending_failed: # pyright: ignore[reportPrivateUsage] + break + + assert finished_task_ids == {first.task_id, second.task_id}, ( + "both tasks must surface as FinishedResponse" + ) + assert call_log == [str(first.task_id), str(second.task_id)], ( + "both tasks must reach _build_generator -- pre-fix the first " + "failure propagated and the second task never got a chance" + ) + assert len(sender.events) == 2, "both failures must emit ErrorChunks" + + +def test_round_robin_advances_all_active_tasks_per_tick( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """``max_concurrent_tasks > 1`` must advance every active task per ``step``. + + Pre-fix, ``SequentialGenerator._active`` was a singular slot and slot + 1's TTFT equalled slot 0's *full* completion time -- the 14s figure + measured in the PR #15 concurrency leg. The fix admits up to + ``max_concurrent_tasks`` simultaneous in-flight tasks and round- + robins one ``next(gen)`` per task per ``step``, so slot 1's TTFT is + bounded by its own prefill plus a constant number of slot-0 token + times. We assert the contract (both tasks make progress on the same + tick) without standing up an MLX model. + """ + sender = _FakeEventSender() + # ``bench=True`` short-circuits the parser pipeline so ``_start_next`` + # never touches ``tokenizer.apply_chat_template`` -- the test stays + # focused on the round-robin contract. + first = _make_text_task("first", bench=True) + second = _make_text_task("second", bench=True) + generator = _bare_sequential_generator(sender, deque([first, second])) + generator.max_concurrent_tasks = 2 + + yielded_per_task: dict[Any, int] = {first.task_id: 0, second.task_id: 0} + + def fake_build( + _self: SequentialGenerator, task: TextGeneration + ) -> Iterator[object]: + # Each generator yields a sentinel object three times so we can + # observe round-robin progression without depending on MLX. The + # parsed-output generator is an empty iterator -- ``step`` is + # tested through its bookkeeping (``_active_tasks`` membership, + # task progress), not through chunk emission. + def gen() -> Iterator[object]: + for _ in range(3): + yielded_per_task[task.task_id] += 1 + yield object() + + return gen() + + def no_agree(_self: SequentialGenerator) -> None: + return None + + monkeypatch.setattr(SequentialGenerator, "_build_generator", fake_build) + monkeypatch.setattr(SequentialGenerator, "agree_on_tasks", no_agree) + + list(generator.step()) + + assert yielded_per_task[first.task_id] == 1, ( + "first task must advance one token on tick 1" + ) + assert yielded_per_task[second.task_id] == 1, ( + "second task must ALSO advance one token on tick 1 -- this is " + "the round-robin contract; pre-fix it would have been 0 because " + "the singular ``_active`` slot was held by the first task" + ) + assert ( + len(generator._active_tasks) == 2 # pyright: ignore[reportPrivateUsage] + ), "both tasks must be in the active set" + + +def test_round_robin_respects_max_concurrent_tasks( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """``max_concurrent_tasks=1`` (asymmetric default) must stay singular. + + ``RemoteTransport``'s wire protocol is per-session, so the asymmetric + placement leaves ``max_concurrent_tasks`` at 1 at builder time. This + test asserts the cap is honoured in ``step``: with two queued tasks + and a cap of 1, only the first is admitted; the second waits until + the first retires. + """ + sender = _FakeEventSender() + first = _make_text_task("first", bench=True) + second = _make_text_task("second", bench=True) + generator = _bare_sequential_generator(sender, deque([first, second])) + generator.max_concurrent_tasks = 1 + + admitted_order: list[Any] = [] + + def fake_build( + _self: SequentialGenerator, task: TextGeneration + ) -> Iterator[object]: + admitted_order.append(task.task_id) + + # Generator yields once then exhausts on the next ``next()``. + def gen() -> Iterator[object]: + yield object() + + return gen() + + def no_agree(_self: SequentialGenerator) -> None: + return None + + monkeypatch.setattr(SequentialGenerator, "_build_generator", fake_build) + monkeypatch.setattr(SequentialGenerator, "agree_on_tasks", no_agree) + + # Tick 1: cap=1 admits only the first task; second remains queued. + list(generator.step()) + assert admitted_order == [first.task_id], ( + "only the first task may be admitted when cap=1" + ) + assert ( + first.task_id in generator._active_tasks # pyright: ignore[reportPrivateUsage] + ), "first task is mid-stream after one yield" + assert len(generator._queue) == 1, ( # pyright: ignore[reportPrivateUsage] + "second task must remain queued under cap=1" + ) + + # Tick 2: first generator exhausts (StopIteration on second ``next``) + # and the slot frees up; the cap-respecting top-up admits second. + list(generator.step()) + assert admitted_order == [first.task_id, second.task_id], ( + "second task must be admitted on tick 2 after first retires" + ) + assert ( + first.task_id not in generator._active_tasks # pyright: ignore[reportPrivateUsage] + ), "first task must have retired" + + +def test_round_robin_per_task_error_does_not_kill_other_active_tasks( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """A faulty generator must finish only its own task; siblings keep advancing. + + With ``max_concurrent_tasks > 1`` a single malformed request must + not knock peer in-flight tasks off the runner. This is a strictly + stronger version of the K=8-cancel resilience contract. + """ + sender = _FakeEventSender() + good = _make_text_task("good") + bad = _make_text_task("bad") + generator = _bare_sequential_generator(sender, deque()) + generator.max_concurrent_tasks = 2 + + good_yields = [0] + + def good_gen() -> Iterator[object]: + for _ in range(5): + good_yields[0] += 1 + yield object() + + class _BoomError(Exception): + pass + + def bad_gen() -> Iterator[object]: + raise _BoomError("doomed mid-stream") + yield # pyright: ignore[reportUnreachable] + + # Use real ``GeneratorQueue`` instances per task so ``queue.push`` + # in ``step`` doesn't blow up; outputs are drained via per-task + # ``output_generator`` iterators (empty here -- the contract under + # test is task-membership in ``_active_tasks``, not chunk content). + generator._active_tasks[good.task_id] = ( # pyright: ignore[reportPrivateUsage] + good, + cast(Any, good_gen()), + GeneratorQueue(), + iter([]), + ) + generator._active_tasks[bad.task_id] = ( # pyright: ignore[reportPrivateUsage] + bad, + cast(Any, bad_gen()), + GeneratorQueue(), + iter([]), + ) + + # ``cast(Any, ...)`` above is required because ``_active_tasks`` + # expects ``Generator[GenerationResponse]`` and our test stubs yield + # plain ``object()`` to keep the test MLX-free; the stubs satisfy the + # iterator protocol that ``next(gen)`` relies on, which is the only + # thing ``step`` actually requires. + + def no_agree(_self: SequentialGenerator) -> None: + return None + + monkeypatch.setattr(SequentialGenerator, "agree_on_tasks", no_agree) + + results = list(generator.step()) + + assert good_yields[0] == 1, "good task must still advance on the bad-task tick" + bad_finished = any( + r[0] == bad.task_id and isinstance(r[1], FinishedResponse) for r in results + ) + assert bad_finished, "bad task must surface as FinishedResponse" + assert ( + good.task_id in generator._active_tasks # pyright: ignore[reportPrivateUsage] + ), "good task must remain active after sibling failure" + assert ( + bad.task_id not in generator._active_tasks # pyright: ignore[reportPrivateUsage] + ), "bad task must be evicted from the active set" + assert len(sender.events) == 1 + assert isinstance(sender.events[0], ChunkGenerated) + assert isinstance(sender.events[0].chunk, ErrorChunk) + + +def test_step_exception_during_next_does_not_raise() -> None: + """An exception during ``next(gen)`` mid-stream must surface as Finished, not crash.""" + sender = _FakeEventSender() + task = _make_text_task() + generator = _bare_sequential_generator(sender, deque()) + + class _BoomError(Exception): + pass + + def faulty_gen() -> Iterator[object]: + raise _BoomError("runtime fault inside spec loop") + yield # pyright: ignore[reportUnreachable] + + generator._active_tasks[task.task_id] = ( # pyright: ignore[reportPrivateUsage] + task, + cast(Any, faulty_gen()), + GeneratorQueue(), + iter([]), + ) + + results = list(generator.step()) + + assert any( + result[0] == task.task_id and isinstance(result[1], FinishedResponse) + for result in results + ) + assert ( + len(generator._active_tasks) == 0 # pyright: ignore[reportPrivateUsage] + ) + assert len(sender.events) == 1 + assert isinstance(sender.events[0], ChunkGenerated) + assert isinstance(sender.events[0].chunk, ErrorChunk) + assert "runtime fault" in sender.events[0].chunk.error_message diff --git a/src/exo/worker/tests/unittests/test_worker_instance_backoff.py b/src/exo/worker/tests/unittests/test_worker_instance_backoff.py new file mode 100644 index 0000000000..b0052c1eb7 --- /dev/null +++ b/src/exo/worker/tests/unittests/test_worker_instance_backoff.py @@ -0,0 +1,36 @@ +# pyright: reportPrivateUsage=false + +from exo.shared.types.common import ModelId, NodeId +from exo.shared.types.state import State +from exo.shared.types.worker.instances import InstanceId, MlxRingInstance +from exo.shared.types.worker.runners import ShardAssignments +from exo.utils.keyed_backoff import KeyedBackoff +from exo.worker.main import Worker + + +def _make_instance(instance_id: InstanceId) -> MlxRingInstance: + return MlxRingInstance( + instance_id=instance_id, + shard_assignments=ShardAssignments( + model_id=ModelId("test-model"), + node_to_runner={}, + runner_to_shard={}, + ), + hosts_by_node={NodeId("node-1"): []}, + ephemeral_port=1, + ) + + +def test_worker_reconciles_instance_backoff_from_state() -> None: + live_instance_id = InstanceId("inst-live") + deleted_instance_id = InstanceId("inst-deleted") + worker = object.__new__(Worker) + worker.state = State(instances={live_instance_id: _make_instance(live_instance_id)}) + worker._instance_backoff = KeyedBackoff[InstanceId]() + worker._instance_backoff.record_attempt(live_instance_id) + worker._instance_backoff.record_attempt(deleted_instance_id) + + worker._reconcile_instance_backoff_once() + + assert worker._instance_backoff.attempts(live_instance_id) == 1 + assert worker._instance_backoff.attempts(deleted_instance_id) == 0