Gemma-4 drafter tuning + DrafterTransport architecture (model | pipelined | ngram | none)#15
team-wcv wants to merge 72 commits into
Move the schema from a single drafter to a preference-ordered list so runners can pick from multiple candidates (e.g. e2b for fastest, e4b for highest acceptance) at startup time via the new EXO_DRAFTER_PREFERENCE env var.

- ModelCard.drafter_model_ids: list[ModelId] (default empty)
- All 8 Gemma 4 large-instruct cards now declare both e2b and e4b drafters at matching quantisation, in fastest-first order.
- _maybe_load_drafter selects from the candidate list, preferring on-disk drafters within the user's preference setting.
- _select_drafter_id helper exposed for testing the policy.
- DownloadCoordinator chains downloads for ALL declared drafters so the runtime preference can be flipped without an on-demand fetch.
- API /v1/models surfaces drafter_model_ids: list[str].
- New tests for the selection helper.

X-Orchestraitor-Plan: gemma4-drafter-tuning
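The selection policy described above might look roughly like the sketch below. This is not the PR's `_select_drafter_id`: the function name, the substring match on the preference string, the fallback when nothing is on disk, and the model ids in the usage note are all assumptions made for illustration.

```python
from typing import Optional, Sequence, Set


def select_drafter_id_sketch(
    candidates: Sequence[str],
    on_disk: Set[str],
    preference: Optional[str] = None,
) -> Optional[str]:
    """Pick one drafter id from a fastest-first candidate list.

    Assumed policy (hypothetical): candidates whose id contains the
    preference string are tried first; within that order the first
    candidate already on disk wins; if none is on disk, the first
    candidate in preference order is returned (it would need a download).
    """
    if not candidates:
        return None
    ordered = list(candidates)
    if preference:
        # Stable sort: preferred ids move to the front, ties keep card order.
        ordered.sort(key=lambda model_id: preference not in model_id)
    for model_id in ordered:
        if model_id in on_disk:
            return model_id
    return ordered[0]
```

With hypothetical ids `["gemma-4-e2b-it-4bit", "gemma-4-e4b-it-4bit"]` and a preference of `"e4b"`, the sketch returns the e4b id when both are on disk and falls back to the e2b id when only that one has been downloaded.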
Item 1 -- Expose num_draft_tokens via EXO_NUM_DRAFT_TOKENS (default 5). mlx_lm's speculative_generate_step defaults to K=2, which is conservative for unknown drafters. For purpose-built family pairs (Gemma 4 e2b -> 31b) acceptance is ~80% and K=5-6 is the sweet spot: at p=0.8, going from K=2 to K=5 yields roughly +60% effective throughput.

Item 3 -- Drafter participates in warmup_inference. The first request after launch no longer pays a cold-cache penalty on the drafter side.

Item 8 -- Skip drafter for short generations. When max_tokens is at or below EXO_DRAFTER_MIN_OUTPUT_TOKENS (default 16) the drafter is dropped for that request: speculative decoding's fixed setup cost dominates the saved decode time on classification / structured-output workloads.

resolve_speculative_decoding is factored out as a pure helper so the policy is unit-testable without spinning up MLX.

X-Orchestraitor-Plan: gemma4-drafter-tuning
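To get a feel for how K interacts with the acceptance rate, the sketch below uses the usual idealised model in which each draft token is accepted independently with probability p. Under that assumption one verify pass emits (1 - p^(K+1)) / (1 - p) tokens in expectation. The function name is hypothetical, and the model ignores drafter and verify cost entirely.

```python
def expected_tokens_per_round(p: float, k: int) -> float:
    """Expected tokens emitted per target verify pass.

    Assumes i.i.d. per-token acceptance with probability p and k draft
    tokens per round; counts the accepted drafts plus the one token the
    target always contributes.
    """
    if not 0.0 <= p <= 1.0:
        raise ValueError("p must be in [0, 1]")
    if k < 0:
        raise ValueError("k must be non-negative")
    if p == 1.0:
        return float(k + 1)
    return (1.0 - p ** (k + 1)) / (1.0 - p)


if __name__ == "__main__":
    low = expected_tokens_per_round(0.8, 2)
    high = expected_tokens_per_round(0.8, 5)
    print(f"K=2: {low:.2f}  K=5: {high:.2f}  ratio: {high / low:.2f}")
```

Under this idealised count, K=5 emits about 1.5x the tokens per verify pass that K=2 does at p=0.8. The realised throughput gain also depends on drafter and verify cost, which the formula leaves out.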
stream_generate yields from_draft on every token. We now thread that into the per-request GenerationStats so dashboards/clients can A/B configurations on real traffic and (next commit) so adaptive K can self-tune off observed acceptance.

- GenerationStats gains drafter_model_id, accepted_draft_tokens, num_draft_tokens, and a drafter_acceptance_fraction property.
- mlx_generate counts from_draft hits and stamps the chosen drafter id on the final stats. effective_draft_model is the source of truth, so short-skip and distributed paths correctly report "no drafter".
- load_mlx_items now returns the drafter id alongside the model so the builder can plumb it into SequentialGenerator -> mlx_generate.
- Dashboard message metadata + ChatMessages.svelte render a SPEC chip next to TTFT/tok/s showing acceptance % and K, with the drafter id surfaced via a tooltip.

X-Orchestraitor-Plan: gemma4-drafter-tuning
When EXO_ADAPTIVE_DRAFT_TOKENS=1 is set, K is recomputed each request from a rolling window of observed acceptance fractions:

- acceptance_rate < 0.5 -> K=2 (drafter is missing, don't waste guesses)
- 0.5 <= rate < 0.75 -> K=4
- rate >= 0.75 -> K=6

Heterogeneous traffic (code completion = predictable patterns = high acceptance, reasoning = lower) gets a real win here over a fixed K. Disabled by default so K stays predictable for benchmarking.

The rolling window is sized at 8 (ADAPTIVE_K_WINDOW), with a 2-observation warmup that falls back to the configured static K. adaptive_num_draft_tokens is a pure helper so the band logic is unit-tested separately from the SequentialGenerator wiring.

X-Orchestraitor-Plan: gemma4-drafter-tuning
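The band logic above can be sketched as a pure function over a bounded window. The thresholds, window size, and warmup count come from the commit message. The `ADAPTIVE_K_WARMUP` name, the use of a plain mean over the window, and the exact signature are assumptions.

```python
from collections import deque
from typing import Deque

ADAPTIVE_K_WINDOW = 8   # window size named in the commit message
ADAPTIVE_K_WARMUP = 2   # hypothetical constant name for the 2-observation warmup


def adaptive_num_draft_tokens_sketch(history: Deque[float], static_k: int) -> int:
    """Map recent acceptance fractions to K using the three bands.

    Assumes the window is aggregated with a plain mean; the real helper
    may weight observations differently.
    """
    if len(history) < ADAPTIVE_K_WARMUP:
        return static_k          # not enough data yet: keep the configured K
    rate = sum(history) / len(history)
    if rate < 0.5:
        return 2
    if rate < 0.75:
        return 4
    return 6


# The deque drops the oldest observation once the window is full.
window: Deque[float] = deque(maxlen=ADAPTIVE_K_WINDOW)
```

A caller would append one acceptance fraction per finished request and ask for K before the next one, so a run of low-acceptance reasoning prompts gradually pulls K down without any explicit reset.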
mlx_lm's speculative_generate_step splits its prompt_cache argument as [: len(target.layers)] for the target and [len(target.layers):] for the drafter, so we can hand it a concatenated cache list and have mlx_lm wire each side to the right model. When a drafter is loaded, MlxBuilder allocates a parallel KVPrefixCache for it. mlx_generate looks up the drafter's cache on the same prompt the target's cache hit on, prefills any uncovered tail with a simple per-step forward (single-device, no distributed sync needed), stores the post-prefill state back to the drafter's prefix cache, and concatenates target + drafter caches into the prompt_cache passed to stream_generate. Multi-turn Codex / chat workloads that grow turn-by-turn now get drafter prefix hits instead of paying full drafter prefill on every turn. The benefit scales with K (number of draft tokens per round): bigger K means drafter prefill is a larger share of total time, so the prefix cache hit saves more. The drafter prefill helper deliberately avoids the full target prefill machinery (group sync, pipeline parallelism, SSM snapshots, progress callbacks) -- drafters are small and single-device by construction. X-Orchestraitor-Plan: gemma4-drafter-tuning
Add `use_drafter: bool | None` and `num_draft_tokens: int | None` to TextGenerationTaskParams and the chat-completions wire schema. None on either field means "use the runner's configured defaults", so existing clients see no behaviour change.

Use cases:

- A latency-sensitive client can set `use_drafter: false` to skip speculative decoding entirely on its tight-loop path.
- A throughput-sensitive client (e.g. a Cursor-style integration that knows the prompt is long-form code completion) can raise K via `num_draft_tokens: 8` for that single request.

The actual K used is reported back in GenerationStats.num_draft_tokens so the dashboard reflects per-request overrides.

X-Orchestraitor-Plan: gemma4-drafter-tuning
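One way to combine the per-request overrides with the runner defaults and the short-generation skip is a single pure function, sketched below. The name `resolve_policy_sketch`, its keyword arguments, and the `SpecPolicy` tuple are hypothetical. The defaults (K=5, minimum 16 output tokens) and the "None means use the default" rule are taken from the commit messages.

```python
from typing import NamedTuple, Optional


class SpecPolicy(NamedTuple):
    use_drafter: bool
    num_draft_tokens: int


def resolve_policy_sketch(
    *,
    drafter_loaded: bool,
    distributed: bool,
    max_tokens: Optional[int],
    use_drafter: Optional[bool] = None,
    num_draft_tokens: Optional[int] = None,
    default_num_draft_tokens: int = 5,
    drafter_min_output_tokens: int = 16,
) -> SpecPolicy:
    """Decide whether speculation runs for one request, and with which K."""
    off = SpecPolicy(False, 0)
    if not drafter_loaded or distributed:
        return off                      # no drafter, or multi-node placement
    if use_drafter is False:
        return off                      # explicit per-request opt-out
    if max_tokens is not None and max_tokens <= drafter_min_output_tokens:
        return off                      # short generation: setup cost dominates
    # None means "use the runner's configured default".
    k = default_num_draft_tokens if num_draft_tokens is None else num_draft_tokens
    if k < 1:
        return off
    return SpecPolicy(True, k)
```

Keeping the decision free of MLX objects is what makes it cheap to unit-test: every branch can be exercised with plain booleans and integers.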
Speculative decoding is single-device only in mlx_lm. Existing placement already prefers single-node via get_smallest_cycles (a 1-node cycle wins over any N>1-node cycle when memory permits), so the common case of "declare drafter -> get drafter" already works. The corner case the user flagged: a quant that's too big to fit on any single node falls back to TP and silently loses the drafter. We now log a clear, actionable warning at placement time when a card with drafter_model_ids gets placed across N>1 nodes, telling the operator to re-place a smaller quant for the speculative-decoding speedup. Auto-swapping to a smaller quant at placement time (the user's "single-node + 4-bit on the biggest box over TP + 8-bit" heuristic) requires the placement engine to know about quant variants of the same base model. That's a larger restructuring; the warning unblocks the operator-driven workflow today. X-Orchestraitor-Plan: gemma4-drafter-tuning
mlx_lm's speculative_generate_step splits its prompt_cache argument into [: len(target.layers)] for the target and [len(target.layers):] for the drafter. When draft_model is set we MUST hand it a cache list that contains both halves -- if we pass only the target caches, mlx_lm sees an empty drafter slice and the first prefill stalls. Item 6's path only built the drafter caches when a KVPrefixCache was wired in (the production runner builds one in MlxBuilder.build()), but warmup_inference calls mlx_generate with drafter_kv_prefix_cache=None, so warmup hung indefinitely on wc-studio under gemma-4-26b-a4b-it-4bit + e2b drafter. Restructured so we always build + prefill a drafter cache when the drafter is active, and only consult/update the prefix cache when one is provided. This unblocks warmup and keeps the multi-turn prefix-hit behavior intact for the real request path. X-Orchestraitor-Plan: gemma4-drafter-tuning
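The slicing behaviour described above can be shown with plain lists standing in for per-layer cache objects. This is only an illustration of the split as the commit message describes it: the helper name is hypothetical and no mlx_lm code is called.

```python
from typing import List, Sequence, Tuple


def split_prompt_cache(
    prompt_cache: Sequence[object], num_target_layers: int
) -> Tuple[List[object], List[object]]:
    """Split a concatenated cache list the way the commit message describes:
    the first num_target_layers entries belong to the target, the rest to
    the drafter."""
    return list(prompt_cache[:num_target_layers]), list(prompt_cache[num_target_layers:])


# Stand-ins for per-layer KV caches (strings instead of real cache objects).
target_caches = [f"target-layer-{i}" for i in range(4)]
drafter_caches = [f"drafter-layer-{i}" for i in range(2)]

# Correct: hand over both halves, target first.
t, d = split_prompt_cache(target_caches + drafter_caches, len(target_caches))

# The failure mode: only the target caches are passed, so the drafter
# slice comes back empty.
_, empty = split_prompt_cache(target_caches, len(target_caches))
```

The second call is the situation warmup hit: the drafter slice is empty even though a draft model is set, which is why the fix always builds a drafter cache whenever the drafter is active.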
…he trim breaks spec)
Reverts the force_plain_kv_cache experiment. The original hang in
exo's spec-decoding path was a cache double-fill, not a cache-type
issue: mlx_lm.speculative_generate_step._prefill always processes
prompt[:y.size-1] regardless of cache offset, so exo's prefill (which
already filled the target cache) plus mlx_lm's _prefill resulted in
2x prompt ingestion and garbage logits.
Forcing plain KVCache for sliding-window layers actually made things
worse: KVCache.make_mask returns variable-shape mx.array masks that
trigger Metal kernel recompiles on every offset change, while
RotatingKVCache.make_mask returns either a bounded-shape mask or the
"causal" string sentinel — one compile that's reused.
Fix:
* make_kv_cache reverts to honoring model.make_cache() (native mix
of RotatingKVCache + KVCache for gemma-4).
* On the spec path, exo's prefill and KVPrefixCache are bypassed;
we hand mlx_lm fresh native caches for target+drafter and the
full prompt, and let speculative_generate_step._prefill drive
ingestion exactly as it expects.
* Non-spec path is unchanged.
Verified end-to-end on wc-studio:
warmup_inference (target + drafter, cold process):
50 tokens in 1.55s @ 33 tok/s
The drafter prefix cache parameter was plumbed through the runner but
never actually used inside ``mlx_generate``: the spec path bypassed the
prefix-cache code entirely and let mlx_lm's spec_step._prefill drive
both target and drafter caches from scratch. This left
drafter_kv_prefix_cache permanently empty and forced the drafter to
re-prefill on every request.
Changes:
* Resolve speculative-decoding policy (use_drafter, num_draft_tokens,
drafter_min_output_tokens, distributed) up-front so cache setup
branches on the *effective* drafter rather than the unfiltered one.
* On the spec path, look up both target and drafter prefix caches,
align them at min(target_hit, drafter_hit) so spec_step sees
identical offsets.
* Add ``_spec_drafter_prefill`` helper that advances the drafter
cache to match the target after exo.prefill (target ends at
prompt - 2 via prefill+trim; drafter must mirror that).
* After both prefills, snapshot the drafter cache via
``add_kv_cache``/``update_kv_cache`` so subsequent requests can
skip the drafter prefill entirely on prefix hits.
* Fix telemetry: ``GenerationStats.num_draft_tokens`` now populates
whenever speculation actually ran, not only when both
drafter_model_id *and* effective_draft_model are present.
Smoke matrix on wc-studio (gemma-4 26b a4b 4bit + e2b 4bit):
warmup-with-drafter: 100 tok / 2.15s
from_draft-telemetry: 31/40 accepted, drafter_id correctly stamped
K-override (K=2/K=8): both reported in stats
use_drafter=False: 0 accepted, drafter_id=None
skip-short (max<min): 0 accepted, drafter_id=None
drafter-prefix-cache: entries grow request-over-request
adaptive-K: 40 tok / 1.86s, 24 accepted
Refactor speculative decoding around a small ``Drafter`` Protocol so
``mlx_generate`` can dispatch on mode rather than branching at every
cache and stream-generate site. Three concrete drafters cover today's
needs and leave a clean seam for EAGLE/Medusa/lookahead/distributed
drafters in future PRs:
* ``NoSpecDrafter``: pass-through to ``mlx_lm.stream_generate``.
* ``ModelDrafter``: wraps ``mlx_lm.stream_generate(draft_model=...)``;
the existing well-tested upstream spec loop keeps owning the
model-drafter path (no behaviour change for current users).
* ``NgramDrafter``: in-house spec loop that proposes drafts by
suffix-matching the running token context against itself - zero
drafter compute, no extra KV cache, no warmup. Match-strength-
adaptive K (proposal length capped to match length) keeps weak
matches cheap so worst case == baseline.
New surface:
* ``EXO_DRAFT_MODE`` env var (model | ngram | none) sets the
process-wide default; falls back to ``model`` if a drafter model
is loaded, else ``none``. Builder respects the env var and forces
SequentialGenerator when ngram is requested without a drafter
model.
* ``TaskParams.draft_mode`` per-request override (wins over both
``use_drafter`` and the env var).
* ``GenerationStats.draft_mode`` telemetry stamps which strategy
actually ran, alongside the existing ``drafter_model_id`` /
``accepted_draft_tokens`` / ``num_draft_tokens`` fields.
Bench (Mac Studio M3 Ultra, gemma-4-26b-a4b 4-bit, 200 gen tokens):
                  baseline   ngram   model
echo_rag (RAG)        76.6    75.8    56.9 tps
code_repeat           70.9    57.2    45.2 tps
structured_json       76.7    63.1    53.5 tps
creative_novel        77.1    58.5    43.0 tps
reasoning_chain       76.7    59.5    49.5 tps
The honest read: on memory-bandwidth-bound single-device 4-bit Mac
Studio inference, both spec modes are net-negative (the K+1-token
verify forward costs nearly K+1 times a single-token forward, so
break-even acceptance fraction is K/(K+1) ~ 80% which most workloads
don't reach). ``none`` remains the right default on this hardware;
``ngram`` and ``model`` are wired and tested for slower-target /
distributed regimes where the economics flip. Phase 2 (pipelined
drafter+verify) and Phase 3 (drafter-on-other-device for distributed
runs) ride on the same Drafter seam.
29 new unit tests cover ``parse_draft_mode``, ``resolve_draft_mode``
precedence, ``NgramDrafter.propose`` (longest-match, recency tie-
break, adaptive K cap), and ``make_drafter`` dispatch.
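The suffix-matching proposal step that ``NgramDrafter`` is described as performing can be sketched in a few lines. This is an illustrative reimplementation, not the PR's code: the function name, the `max_ngram` bound, and the linear right-to-left scan are assumptions. The longest-match rule, the recency tie-break, and the cap of proposal length to match length follow the commit message.

```python
from typing import List, Sequence


def ngram_propose_sketch(
    context: Sequence[int], max_k: int, max_ngram: int = 8
) -> List[int]:
    """Propose draft tokens by matching the context's suffix against itself.

    Tries the longest suffix first; among equal-length matches the most
    recent earlier occurrence wins. The proposal is capped at
    min(max_k, match_length) so weak matches stay cheap.
    """
    n_ctx = len(context)
    for n in range(min(max_ngram, n_ctx - 1), 0, -1):
        suffix = list(context[n_ctx - n:])
        # Scan right-to-left so the first hit is the most recent occurrence.
        # start == n_ctx - n would be the suffix itself, so begin one earlier.
        for start in range(n_ctx - n - 1, -1, -1):
            if list(context[start:start + n]) == suffix:
                limit = min(max_k, n)
                return list(context[start + n:start + n + limit])
    return []  # no match: the caller falls back to plain decoding
```

A single-token match proposes only one token, while a three-token match may propose up to three, which is how the cap keeps the cost of a wrong guess proportional to the evidence behind it.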
Update: added Drafter abstraction + n-gram drafting strategy (commit 2f52b82)

Refactored speculative decoding around a small ``Drafter`` Protocol.

New surface

Bench (Mac Studio M3 Ultra, gemma-4-26b-a4b 4-bit, 200 gen tokens)

The honest read: on memory-bandwidth-bound single-device 4-bit Mac Studio inference, both spec modes are net-negative. The K+1-token verify forward costs nearly K+1 times a single-token forward, so break-even acceptance fraction is K/(K+1) ~ 80%, which most workloads don't reach. Phase 2 (pipelined drafter+verify) and Phase 3 (drafter-on-other-device for distributed runs) ride on the same ``Drafter`` seam.

Tests

29 new unit tests in …
Foundation for the unified pipelined / distributed-drafter architecture. Lands the transport-agnostic interface and the in-process implementation so Layer B's RemoteTransport over mx.distributed.send/recv (RDMA via jaccl, TCP via ring) is purely additive.

* drafter_transport.py: DrafterTransport Protocol with two primitives (forward, trim_cache) the spec loop composes into propose / commit / rollback. InProcessTransport mirrors mlx_lm._draft_generate but with cache-state bookkeeping via mlx_lm.trim_prompt_cache. The K+1 upper bound on num_forwards is for the speculative bonus-prediction forward.
* pipelined_drafter.py: PipelinedModelDrafter + custom spec loop with cross-round speculation. While target verifies round t, drafter runs K+1 forwards starting from drafts_t[-1] to predict bonus_t and generate round t+1's drafts. On hit, round t+1 propose is free; on miss, the K+1 speculative positions are rolled back. Cache-trim arithmetic for partial / full / hit / miss is documented in the module docstring and exercised by unit tests.
* remote_drafter.py: stub surface for Layer B. NotImplementedError now with a pointer to the planned topology change; the wire-protocol constants (COMMAND_FRAME_SIZE, OP_PROPOSE / OP_TRIM_CACHE / OP_SHUTDOWN) are final so Layer B's drafter_serve_loop is purely additive.
* drafter.py: DraftMode gains "pipelined". make_drafter routes the new mode through transport_factory_for(EXO_DRAFTER_TRANSPORT) so switching in-process vs remote drafter placement is one env var.
* generate.py: spec_active / effective_draft_model widened to include "pipelined" so target prefill, drafter prefill, and prefix-cache bookkeeping all run for the new mode.
* builder.py: forces SequentialGenerator for "pipelined" (BatchGenerator has no spec-decoding hook, same reasoning as ngram).
* api.py / text_generation.py: extends the request / telemetry literals.
* tests: 17 new tests covering Protocol contract, transport-kind parsing, factory dispatch, and cross-round-speculation cache-trim arithmetic via a deterministic FakeTransport (no MLX weights). Plus 3 new mode-resolution tests for the pipelined demote-to-none path.

Layer B (RemoteTransport + asymmetric MlxJacclInstance / MlxRingInstance with drafter_rank field + placement + runner serve-loop + subgroup setup + twin testing) lands in the next commit on this branch.
Fills out the RemoteTransport DrafterTransport implementation and the matching drafter-rank serve loop using mx.distributed.send/recv on the parent group. Network primitive is whatever exo's distributed group already negotiated for the cluster: jaccl/RDMA when available, ring/TCP fallback otherwise -- "rdma is already built into exo do not reinvent the wheel".

* Fixed-size 8-uint32 command frame keeps per-round IPC overhead at microseconds, not milliseconds.
* OP_FORWARD / OP_TRIM_CACHE / OP_SHUTDOWN with explicit ack frames give the target deterministic sync points after every command.
* Background ThreadPoolExecutor on the target rank lets PipelinedModelDrafter return a non-blocking DraftFuture so verify(t) and drafter_forward(t+1) overlap on the wire.
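A fixed-size frame of eight unsigned 32-bit words is easy to picture with the standard `struct` module. The sketch below packs to bytes purely for illustration; per the commit message the PR sends frames over `mx.distributed.send/recv`. The opcode values, the slot layout (opcode in word 0, arguments after it), and the little-endian byte order are all assumptions.

```python
import struct
from typing import Tuple

COMMAND_FRAME_WORDS = 8
_FRAME = struct.Struct("<8I")          # 8 little-endian uint32 = 32 bytes

# Hypothetical opcode values; only the names appear in the PR text.
OP_FORWARD = 1
OP_TRIM_CACHE = 2
OP_SHUTDOWN = 3


def encode_frame(op: int, *args: int) -> bytes:
    """Pack an opcode plus up to 7 integer arguments, zero-padding the rest."""
    if len(args) > COMMAND_FRAME_WORDS - 1:
        raise ValueError("too many arguments for a fixed-size frame")
    words = [op, *args] + [0] * (COMMAND_FRAME_WORDS - 1 - len(args))
    return _FRAME.pack(*words)


def decode_frame(buf: bytes) -> Tuple[int, Tuple[int, ...]]:
    """Unpack a frame into (opcode, argument words)."""
    words = _FRAME.unpack(buf)
    return words[0], words[1:]
```

Because every frame has the same size, the receiver can post one fixed-shape receive per command and never needs a length negotiation, which is what keeps the per-round overhead small.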
V1 of operator-controlled asymmetric drafter placement. Operators tag specific nodes as drafter hosts on a model card; placement appends a drafter-only rank to the parent mx.distributed group on one of those nodes whenever target ranks land elsewhere. Target ranks split off into a target subgroup at runtime; the parent group is reserved for target<->drafter point-to-point IPC over RemoteTransport.

Shape of the change

* ModelCard gains drafter_eligible_nodes: list[NodeId] = [] -- empty by default preserves legacy in-process drafter behaviour.
* DrafterPlacement type captures drafter_node_id, drafter_runner_id, drafter_model_id, and the drafter's index in the parent group (always last rank == world_size - 1, by convention).
* MlxRingInstance / MlxJacclInstance carry an optional drafter_placement; BoundInstance gains is_drafter_rank and a relaxed validator that accepts the drafter's runner id without a target shard.
* DrafterPlacementDegraded event + DrafterPlacementDegradationReason surface placement-time downgrades (no eligible node alive, all eligible nodes already targets, no reachable transport, drafter node below memory floor) so dashboards/CLI can show why asymmetric was denied. The user's request still completes -- placement returns the legacy single-device or multi-node-without-drafter instance and the event rides alongside InstanceCreated.
* place_instance picks an eligible drafter node that is alive, outside the target cycle, all-to-all reachable from every target rank (RDMA for MlxJaccl, socket for MlxRing), and above a 6 GB conservative memory floor.
* Target cycle selection now reserves drafter-eligible nodes when possible -- a user listing one node as drafter-eligible no longer sees that node grabbed as a target rank.
* Multi-node-without-drafter warning updated to point operators at the drafter_eligible_nodes opt-in (alongside the existing "use a smaller quant" suggestion).
* Master forwards degradation events alongside transition events so state is consistent with what the operator sees.

Why N+1 rank from day one

Bigger than V1 needs (twins is N=1 + drafter), but the architecture treats N=1 the same as N=2 or N=8 -- same parent group init, same group.split, same RemoteTransport. Building it asymmetric-aware now avoids paying refactor cost later when the cluster grows.

Tested

10 new tests under test_placement_drafter_asymmetric.py cover the happy path (ring + jaccl), each degradation reason, multi-eligible fallthrough, and a serialisation round-trip. test_model_cards_drafter.py covers the new field's defaults and round-trip. All 39 placement tests pass; full shared/master/worker unit suite still 383 green.

Layer B6 (utils_mlx group split), B7 (drafter runner dispatch), and B8 (target rank wires PipelinedModelDrafter+RemoteTransport) follow.
initialize_mlx now returns an MlxGroupSplit dataclass holding the parent group, the target subgroup, and the drafter rank index in the parent. Symmetric placement (no drafter) returns an MlxGroupSplit where parent and target subgroup are the same group object, so all existing call sites that pass `self.group` to tensor / pipeline / batch collectives keep working unchanged. Asymmetric placement uses Group.split with deterministic colors: target ranks color=0, drafter rank color=1. Target ranks get a size-N subgroup; the drafter ends up alone in a size-1 subgroup it never uses. Group.split is a collective, so all ranks (including the drafter) call it together. The parent group is reserved for RemoteTransport.send/recv between target rank 0 and the drafter rank. mlx_distributed_init learned to read the rank from DrafterPlacement.drafter_rank when called on the drafter rank (which has no bound_shard). The single-rank assertion now checks parent_group_size instead of len(runner_to_shard) so asymmetric single-target setups pass through to mx.distributed.init. MlxBuilder gained parent_group + drafter_rank_in_parent fields alongside the existing target-subgroup `group`; the image builder keeps using only target_subgroup since it never declares a drafter. The test mock for initialize_mlx now wraps a MockGroup in a symmetric MlxGroupSplit.
Introduces the runner-side dispatch for the asymmetric drafter rank so the worker plan iterates uniformly over target ranks + drafter rank when fanning out lifecycle tasks.

* `BaseInstance.all_runner_ids` / `all_node_to_runner` and `BoundInstance.parent_rank` / `is_drafter_rank` so plan-time iteration doesn't have to special-case the drafter.
* `worker/plan.py` uses those helpers for `_kill_runner`, `_create_runner`, `_init_distributed_backend`, `_load_model`, `_ready_to_warmup`, `_pending_tasks`. Drafter ranks skip `_model_needs_download` (V1 assumes operator pre-download) and the parent_group_size gates accommodate the asymmetric N+1 rank topology with a single target rank.
* New `DrafterRunner` mirrors the target runner state machine (`ConnectToGroup` -> `LoadModel` -> `StartWarmup` -> running) and enters `drafter_serve_loop` after warmup. Single-forward warmup primes Metal kernels so the first real OP_FORWARD doesn't pay JIT-compile cost.
* `runner/bootstrap.py` dispatches on `is_drafter_rank` and runs `DrafterRunner` with the same mlx_lm patches the target rank uses.

Type-clean under basedpyright, ruff clean, all 404 unit tests pass.
Closes the loop on the asymmetric topology: when an instance's `DrafterPlacement` is set, the target rank constructs a long-lived `RemoteTransport` and routes every request through `PipelinedModelDrafter` talking over `mx.distributed.send/recv` to the drafter rank. RDMA (`MlxJaccl`) and TCP (`MlxRing`) backends both ride the existing `mx.distributed.Group` -- the only knob that picks the wire format is the group's backend, exactly as upstream's send/recv contract guarantees.

Wire protocol (extends Layer A's command frames):

* New `OP_PREFILL`: per-request setup. Target announces a prompt of `num_prompt_tokens` (encoded in the `num_forwards` slot of the fixed-size command frame), then sends a `(num_prompt_tokens,)` uint32 array of token ids. Drafter trims its KV cache to offset 0 and runs prefill forwards in 4096-token chunks (mirroring `_spec_drafter_prefill`'s step), then acks. Issued at the start of every request so the spec loop's first OP_FORWARD seeds against an aligned drafter cache.

DrafterTransport API additions:

* `reset_and_prefill(prompt_tokens)` on the Protocol, implemented as no-op for `InProcessTransport` (the legacy mlx_generate path owns drafter cache prefill externally) and as the `OP_PREFILL` IPC for `RemoteTransport`. The Protocol method exists so `make_drafter` doesn't have to dispatch on transport kind to issue per-request setup.

Generator wiring:

* `make_drafter` accepts `pipelined_transport` (a pre-built `DrafterTransport`) so SequentialGenerator can allocate the RemoteTransport once at build time and reuse it across requests -- per-request executor + drafter cache lifecycle would defeat the asymmetric setup's whole point. Multi-target asymmetric is gated behind a `NotImplementedError` (V1 supports a single target rank; N>1 needs draft-broadcast on the target subgroup, planned follow-up).
* `mlx_generate` accepts `asymmetric_parent_group`, `asymmetric_drafter_rank`, `asymmetric_drafter_transport`. When active, bypasses the legacy `group is not None -> draft_mode = "none"` demotion (mlx_lm's own spec loop can't handle pipeline collectives; ours doesn't need to since we restrict to N=1 target ranks for V1) and forces `draft_mode = "pipelined"`.
* `MlxBuilder` populates `parent_group` and `drafter_rank_in_parent` in `connect()` and allocates the long-lived `RemoteTransport` in `build()` when asymmetric placement is active. Forces `SequentialGenerator` (BatchGenerator has no spec-decoding hook).
* `SequentialGenerator.close()` shuts down the transport (sends `OP_SHUTDOWN`) so the drafter rank's serve loop drains cleanly on runner teardown.

Tests:

* New `test_remote_drafter.py`: 19 tests covering command-frame round-trip, `RemoteTransport.forward/trim_cache/reset_and_prefill/shutdown` with mocked `mx.distributed`, idempotent shutdown, use-after-shutdown rejection, and `drafter_serve_loop` dispatch (OP_SHUTDOWN, OP_TRIM_CACHE, unknown op rejection).
* `test_pipelined_drafter.py`: extends `FakeTransport` with `reset_and_prefill` and adds 3 tests for `make_drafter`'s asymmetric entry points (uses-supplied-transport, rejects-non-protocol, rejects-multi-target).

Twin-machine testing recipe documented in `remote_drafter.py` module docstring (TCP via `MlxRing`, RDMA via `MlxJaccl` over Thunderbolt). basedpyright clean, ruff clean, 407 unit tests pass.
Layer B complete: pipelined+remote drafter end-to-end (B7-B9)

Two new commits closing the asymmetric drafter loop:

| Layer | Status |
|---|---|
| B1-B6 (placement + group split) | merged in earlier commits |
| B7 DrafterRunner + plan helpers | 93a44032 |
| B8 target generator wiring | 8add4551 |
| B9 tests + twin docs | 8add4551 |
| B10 commit + push | this comment |
| Twin benchmarks (TCP + RDMA) | next, on wc-smbp + wc-smbpt |
Standalone OpenAI-compatible-API hammer that captures per-request generation_stats (prompt/gen tps, ttft, drafter accept fraction, draft_mode) across a fixed prompt mix (short repetitive, code, prose, factual). Used to compare local 4-mode runs (none/model/ngram/ pipelined-in-process) and asymmetric N+1 deployments (jaccl RDMA vs ring TCP) without coupling to the existing exo_bench planning harness.
The drafter rank in asymmetric N+1 placement has no shard (it owns the full drafter model, not a slice of the target). RunnerSupervisor.create was unconditionally calling bound_instance.bound_shard during process spawn, which asserts on drafter ranks and crashed the worker before the DrafterRunner could even start. Make shard_metadata Optional, branch on bound_instance.is_drafter_rank when populating it, and read model_id through a new property that falls back to DrafterPlacement.drafter_model_id when shard_metadata is None. The two existing usages (logging, error chunk on shutdown) stay identical for the target rank and now also work for drafter ranks.
brings back EXO_CACHE_HOME as always ~/.cache/exo/, and store the node id in there. no random copies now!
Tracing the conc=4 staircase (slots 0-1 batch but slots 2 and 3 each pay solo prefill, ttft 2719 ms / 3696 ms) through worker.main + supervisor.start_task uncovers the actual upstream serialiser: ``await event.wait()`` at the end of ``RunnerSupervisor.start_task`` blocks the worker's command loop on the runner's per-task ``TaskAcknowledged`` for *every* task, including ``TextGeneration``. With the master fanning all 4 ``Executing command: TextGeneration`` log lines out within 12 ms, the worker dispatches them serially: slot 0 -> ack -> slot 1 -> ack -> slot 2 -> ack -> slot 3.

Slot 1 lands on the runner's ``_work_queue`` during the entry-time burst-coalesce (200 ms window) and batches with slot 0, but slot 2's send is still gated on slot 1's ack inside the runner -- which only fires after slot 1's prefill batches *and* the runner's ``_admit_queued_tasks`` runs ``acknowledge_task`` for it. By then the burst-coalesce is closed and the next admit cycle sees candidates=1.

Lift the ack-wait for ``TextGeneration`` / ``ImageGeneration`` / ``ImageEdits`` *only when the runner is past warmup* (``RunnerReady`` or ``RunnerRunning``). The runner's state machine admits those tasks unconditionally in those states, so ordering between siblings is preserved by the underlying ``mp_channel`` (FIFO) without needing the worker-level handshake. Lifecycle tasks (``LoadModel``, ``StartWarmup``, ``ConnectToGroup``, ``Shutdown``, ``CancelTask``, etc.) keep the gate so state transitions stay strictly ordered.

The ``pending`` map continues to receive the ack via ``_ev_recv`` and the on-runner-death sweep at line 320 still sets all events, so cancellation / shutdown teardown is unchanged.

This does NOT remove backpressure: ``mp_channel`` blocks on a full buffer, and the runner's ``_work_queue`` is still drained sequentially by the same single thread. It only removes the coupling between worker-side dispatch latency and runner-side prefill latency for warm generation tasks.

With this change the worker fans the conc=4 burst onto ``_work_queue`` within ms of the master's dispatch, and the runner's burst-coalesce + the between-step drain together batch all 4 slots.

389 worker+master unit tests pass; basedpyright/ruff clean. Live conc=4 bench validation in next push.
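The gating rule can be written as a small predicate. The task and state names below are the ones quoted in the commit message, but representing them as strings and the function name `must_wait_for_ack` are simplifications made for this sketch. The real code presumably dispatches on task and state types.

```python
GENERATION_TASKS = frozenset({"TextGeneration", "ImageGeneration", "ImageEdits"})
WARM_RUNNER_STATES = frozenset({"RunnerReady", "RunnerRunning"})


def must_wait_for_ack(task_kind: str, runner_state: str) -> bool:
    """Return True when the worker should block on TaskAcknowledged.

    Generation tasks sent to a warmed-up runner skip the wait and rely on
    the FIFO channel for ordering; every other combination keeps the gate
    so lifecycle transitions stay strictly ordered.
    """
    warm_generation = (
        task_kind in GENERATION_TASKS and runner_state in WARM_RUNNER_STATES
    )
    return not warm_generation
```

Note that the predicate only governs whether the dispatcher waits. The acknowledgement itself is still delivered and recorded, so teardown paths that sweep pending events are unaffected.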
Conc=4 follow-up: ack-gate lifted, full batched prefill coverage

The earlier conc=4 staircase (slots 0-1 batched, slots 2-3 solo, ttft 1700/1700/2700/3700 ms) traced to … Lift the ack-wait gate for … Pushed in 70830e8.

Conc=4 bench (smbp+smbpt+bmbp, jaccl/RDMA, 2 runs × 4 slots × 384 tokens, long_context_summary)

Diagnostic from runner log on smbp (placement target this run): Two …

Median TTFT regression

The +244 ms median TTFT is structural: pre-fix slot 2 paid 1000 ms of solo prefill (TTFT 2700 ms) -- post-fix it pays 1420 ms in a 2-slot batch (TTFT 3250 ms). The batch is faster per token but the second pair has to wait for the first pair's prefill to finish, so 50% of slots see TTFT shift up. Worst-case TTFT (slot 3 in the pre-fix order) drops 502 ms because slot 3 no longer has to wait through a serial chain. This is the right trade-off when the operator cares about p99 / aggregate throughput; if p50 matters more for a workload, …

Future lever

A single B=4 prefill (instead of two B=2) would cut total prefill wall-clock further. That requires extending …

Three commits in this slice (all pushed to …
EXO_BURST_COALESCE_MS sweep at conc=4 (final lever)

Same cluster (smbp+smbpt+bmbp jaccl/RDMA), same code (…

Findings

Diagnostic from runner log: …

Recommendation

Keep the default at 200ms (already committed). The default targets the most common operator profile (low p50 TTFT, conc<=2 streams hitting an instance) and hands B=4 batching to operators who explicitly opt in via …

Documented profiles:

The conc=4 work in this PR is now end-to-end:

The …
The pipelined drafter's accept loop did mx.eval(sampled) + .item() per position, costing K+1 host-device syncs per spec round (~2-3ms each on Apple Silicon). On a target with ~10ms step time, the K+1 sync overhead alone is enough to flip spec-decode from net-win to net-loss; that is why the asymmetric drafter benched slower than solo across every K and every prompt type. When logits_processors is empty (the common case at temperature=0 or with no repetition penalty), all K+1 sampling decisions are independent of one another, so we batch them into a single sampler() call and a single mx.eval()+tolist() sync. The stateful path is preserved verbatim for callers that need per-position context (e.g. repetition penalty against running_prev). Expected effect on the asymmetric q4-drafter / q4-MoE-target bench: previous loss of ~14% on code_completion K=3 should narrow or flip to a win.
The previous batched-sampling fast path in PipelinedModelDrafter only
fired when logits_processors was empty. In practice the runner always
prepends ban_token_ids(eos_ids) for bench/length-controlled requests
and make_logits_processors() in the default config returns an empty
processor list otherwise -- so the original gate hit the slow path
on every code path that exercises the bench harness. That is why the
prior commit measured a no-op.
This change:
* Tags ban_token_ids with position_independent=True, since it
discards its history arg and only masks fixed token ids.
* Generalises the spec-loop fast-path gate from "no processors" to
"all processors are position-independent", and applies the
processors once to the batched (K+1, vocab) logits before the
single batched sampler call.
For requests with stateful processors (repetition_penalty,
presence_penalty, frequency_penalty), no attribute is set on the
returned processor so position_independent stays False and the
existing per-position path runs unchanged.
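The gate described above can be sketched in plain Python. This is an illustration only: `ban_token_ids` here is a stand-in with the same intent as the helper named in the commit message, while its signature, the list-of-lists logits, `can_batch`, and the `repetition_penalty` stub are assumptions, not the runner's code.

```python
from typing import Callable, Sequence

Logits = list[list[float]]  # (K+1, vocab) rows; stand-in for an mx.array
Processor = Callable[[Sequence[int], Logits], Logits]


def ban_token_ids(banned: Sequence[int]) -> Processor:
    """Hypothetical processor: masks fixed token ids and ignores history."""
    banned_set = set(banned)

    def process(_history: Sequence[int], logits: Logits) -> Logits:
        return [
            [float("-inf") if i in banned_set else v for i, v in enumerate(row)]
            for row in logits
        ]

    # Tag the processor so the spec loop may apply it once to the whole batch.
    process.position_independent = True  # type: ignore[attr-defined]
    return process


def repetition_penalty(_history: Sequence[int], logits: Logits) -> Logits:
    """Stateful stand-in: left untagged, so it forces the per-position path."""
    return logits


def can_batch(processors: Sequence[Processor]) -> bool:
    """Fast-path gate: every processor must be tagged position-independent."""
    return all(getattr(p, "position_independent", False) for p in processors)
```

Because the default for a missing attribute is `False`, any processor that has not been audited falls back to the per-position path, which is the conservative direction.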
…port)
Lift the architectural cap on asymmetric placement: the drafter rank no longer joins mx.distributed.Group, so target ranks are free to run TP/PP collectives without needing Group.split (which jaccl/ring don't implement on Apple Silicon). The wire is now a plain TCP socket between target rank 0 and the drafter rank, carrying the same v3 frames the mx.distributed wire used.
Result: sharded multi-rank target + asymmetric drafter on a third node is now placeable on jaccl/RDMA and ring/TCP backends alike. The drafter wire is independent of the target backend and only requires socket reachability between target rank 0 and the drafter node.
- New drafter_socket module with bind/dial/frame helpers (TCP_NODELAY, exponential backoff dial, length-prefixed prompt tail).
- DrafterPlacement carries drafter_socket_host + drafter_socket_port; placement layer resolves target rank 0's IP via find_ip_prioritised (Thunderbolt > Ethernet > WiFi) and allocates an ephemeral port.
- BaseInstance.parent_group_size now returns the target-only rank count (drafter is on TCP, never in mx.distributed).
- initialize_mlx asserts not-on-drafter-rank, builds target-only group, and on target rank 0 of an asymmetric placement binds the listener and accepts the drafter's incoming connection.
- DrafterRunner._handle_connect dials directly; serve loop runs over the socket with no mx.distributed dependency.
- worker/plan.py: drafter ranks dispatch ConnectToGroup + StartWarmup independently; target ranks retain rank-ordered collective barriers.
- Retired the EXO_DRAFTER_TRANSPORT="remote" factory path (asymmetric is now built directly from the runner bootstrap).
- Tests rewritten on top of socket.socketpair() so both sides of the wire run end-to-end without mocking mx.distributed.
The function returns the SINK end of node_id -> other_node_id connections (i.e. other_node_id's reachable address). To resolve target rank 0's IP for the drafter to dial, the drafter must be node_id and target must be other_node_id; I had them swapped, which produced the drafter's own IP and would have caused dial failure at runtime.
V1 only supported ``target_subgroup_size=1``: a single target rank plus an asymmetric drafter on a separate node. Multi-target placements (e.g. the 3-node TB-RDMA cluster running ``gemma-4-31b-it-bf16`` tensor-sharded across ``wc-smbp`` + ``wc-smbpt`` with the drafter on ``wc-bmbp``) raised ``NotImplementedError`` because non-root target ranks had no way to participate in the spec loop -- they would either skip the drafter entirely (rank-asymmetric draft_mode → TP collective desync) or attempt to load drafter weights twice and never agree on tokens.
The fix is a fixed-size rank-0 broadcast on the target subgroup at every drafts-update site (round 0 propose, partial-accept rebuild, full-accept hit, full-accept miss). The broadcast piggybacks on JACCL's well-tested ``all_sum`` collective via a length-prefixed ``int32`` buffer of size ``k+1``; non-root target ranks contribute zeros. ``mx_all_gather`` is deliberately avoided because of an observed JACCL corruption pattern on small int buffers documented in ``mx_all_gather_tasks``.
The non-root ``PipelinedModelDrafter`` is built with ``transport=None`` and consumes the broadcast each round; both ranks then run an identical verify forward in TP lockstep. Determinism falls out for free: TP all-reduces logits to be byte-identical on every rank, so accept/reject and emitted tokens match without any further coordination.
Sustained measurements on the V2-multi cluster:
- 256 tok / short prompt: nodraft 13.94 tok/s -> drafter 15.73 tok/s (+12.8%)
- 256 tok / long prompt: nodraft 10.41 tok/s -> drafter 12.90 tok/s (+24%)
- 512 tok / long prompt: nodraft 8.63 tok/s -> drafter 9.36-9.69 tok/s (+8.5-12%)
Three drafter requests in a row, no deadlocks, coherent output.
Adjacent fixes that landed during the bring-up:
- ``runner.py``: log compact task identifiers instead of the full ``TextGeneration`` Pydantic model on every entry. Repeated re-planning of the same task while a runner was busy was funneling the entire chat-template + token blob through ``loguru`` once per tick; that hit ``list_repr`` recursion (~300 GB peak physical) and starved the forward loop, causing rank 0 to never enter its TP collective and rank 1 to park forever in ``mlx::core::eval_impl`` cvwait.
- ``logging.py``: pass ``diagnose=False`` on every loguru handler. Without this, an exception in a frame that closed over a large list would trigger ``_better_exceptions`` to ``repr()`` every local -- same ``list_repr`` storm pattern, just inside the crash reporter, turning a quick OOM into a multi-minute hang.
- ``bootstrap.py``: register ``faulthandler.SIGUSR1`` in every runner subprocess so future TP collective hangs can be diagnosed with full Python tracebacks (root not required, unlike py-spy on macOS).
- ``mx_all_gather_tasks``: trust local state and short-circuit (``return list(tasks), []``). The protocol is defensive -- the master pushes the same task list to every target rank -- and JACCL's ``all_gather`` of small int buffers is unreliable on Apple Silicon (sporadic ``[1, 1068875521]`` corruption that drove a 144 GB padded-buffer allocation). If the master ever diverges across ranks, the next TP forward will fail loudly instead of completing on stale data.
- ``utils_mlx.load_mlx_items``: surface the drafter ``ModelId`` from ``DrafterPlacement`` on multi-rank distributed loads so ``GenerationStats.drafter_model_id`` is non-null for asymmetric multi-target requests (previously only the single-device load path set it).
- ``generate.py``: telemetry now stamps ``drafter_model_id`` and ``K`` for ``mode="pipelined"``; the asymmetric branch builds either the transport-owning or broadcast-consumer drafter based on ``asymmetric_drafter_is_root``.
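To make the rank-0 broadcast concrete, here is a minimal pure-Python simulation of a sum-based broadcast. It assumes the buffer layout is "length in slot 0, drafts after it", which is one reading of the length-prefixed ``k+1`` buffer; ``encode_drafts``, ``decode_drafts`` and the local ``all_sum`` stand-in are hypothetical names, not the helpers in the PR.

```python
def encode_drafts(drafts: list[int], k: int, is_root: bool) -> list[int]:
    """Build the fixed-size (k+1) buffer one rank contributes to the sum."""
    if not is_root:
        return [0] * (k + 1)  # non-root ranks contribute zeros
    if len(drafts) > k:
        raise ValueError("more drafts than the fixed buffer can hold")
    padding = [0] * (k - len(drafts))
    return [len(drafts), *drafts, *padding]


def all_sum(buffers: list[list[int]]) -> list[int]:
    """Local stand-in for the collective: element-wise sum across ranks."""
    return [sum(column) for column in zip(*buffers)]


def decode_drafts(summed: list[int]) -> list[int]:
    """Recover root's drafts from the summed buffer via the length prefix."""
    length = summed[0]
    return summed[1 : 1 + length]
```

Since every non-root contribution is zero, the element-wise sum equals root's buffer on every rank, which is why a reduction can double as a broadcast here.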
Closes the correctness gaps left open by the initial V2-multi multi-target asymmetric drafter ship:
- **Sampler determinism (cross-rank correctness fix).** The default ``make_sampler`` uses ``temperature=0.7`` -- ``mx.random.categorical`` reads MLX's per-rank PRNG state, so identical post-TP-all-reduce logits produced divergent ``target_tokens`` across target ranks, divergent ``num_accepted``, divergent ``mlx_trim_prompt_cache`` amounts, and silent KV-cache desync at the next TP forward. Rank 0 now samples and broadcasts ``target_tokens`` (k+1 ints) on the target subgroup via a new ``_broadcast_target_tokens`` helper; non-root ranks skip the sampler entirely and consume the broadcast. Both ranks then compute identical accept/reject from byte-identical ``drafts`` and ``target_tokens``. The fix adds one small ``mx.distributed.all_sum`` per round (microsecond-range on Thunderbolt RDMA, negligible against the verify forward).
- **mx_all_gather_tasks drift detection.** The trust-local short-circuit kept us safe from JACCL's ``all_gather`` corruption, but silently masked any master-side plan divergence across target ranks. Now hashes the local task list (deterministic 31-bit polynomial hash, ``PYTHONHASHSEED``-independent) and rides an ``all_sum`` collective to compare hashes against root. Mismatch raises locally and surfaces in ``handle_generation_tasks`` rather than corrupting later TP forwards on stale state. The hash mixes per-byte and includes a separator so transposition + concatenation collisions don't slip through.
- **mx_broadcast_int_list range validation.** Negatives wrap silently on ``int32`` cast and values >= 2**31 overflow on ``all_sum``. Both are caller-side bugs; centralised ``_validate_broadcast_values`` enforces ``[0, 2**31 - 1]`` on every entry. Length must be >= 1; ``is_root=False`` with ``group is None`` is now rejected as a configuration bug instead of returning the wrong list.
- **Drafter-death failure mode documented.** When the drafter rank crashes between rounds, root's ``transport.forward`` raises and ``mlx_generate``'s ``finally`` shuts the broken session down cleanly; non-root target ranks block on the next broadcast until the runner is restarted (same failure mode as any TP-rank death). Documented in the module docstring with a note on how a future termination-sentinel pass could exit non-root cleanly.
- **Dashboard surfaces draft_mode.** ``DrafterStats`` now carries the drafter mode through to ``ChatMessages.svelte`` which renders a mode-specific pill (``PIPELINED`` for asymmetric remote, ``MODEL`` for in-process, ``NGRAM`` for suffix lookup, ``SPEC`` fallback for older runner builds). Distinguishes V2 multi-target spec runs from in-process drafting at a glance for A/B comparisons.
Tests:
- New ``test_utils_mlx_broadcast.py`` covers the single-rank short-circuit contract for ``mx_broadcast_int_list``, the int32 range validation, the task-list drift hash (deterministic, ordering-sensitive, no transposition or concatenation collisions), and the ``mx_all_gather_tasks`` short-circuit.
- ``test_pipelined_drafter.py`` adds three V2-multi cases: 3+ target ranks (consumer + root) construct correctly through ``make_drafter``, and the new ``_broadcast_drafts`` / ``_broadcast_target_tokens`` helpers reject configuration bugs (consumer with no group, wrong ``k_this`` length).
698 passed, 1 skipped, 2 deselected (pre-existing model-card test failures from local custom_model_cards override). Lint + basedpyright clean.
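A rough sketch of the two guards named above, the drift hash and the range validation, could look like the following. The base, seed, separator byte and the assumption that task ids are ASCII strings are all illustrative choices; the actual constants and helper signatures in the PR are not shown in this thread.

```python
_MOD = 2**31 - 1  # keeps every hash inside the non-negative int32 range
_BASE = 131       # hypothetical multiplier
_SEP = 0xFF       # hypothetical separator; cannot occur in an ASCII task id


def task_list_hash(task_ids: list[str]) -> int:
    """Polynomial hash over bytes; does not depend on PYTHONHASHSEED."""
    h = 7
    for task_id in task_ids:
        for byte in task_id.encode("ascii"):
            h = (h * _BASE + byte + 1) % _MOD
        # Mixing a separator after each id keeps ["ab", "c"] and
        # ["a", "bc"] from feeding the same byte stream into the hash.
        h = (h * _BASE + _SEP + 1) % _MOD
    return h


def validate_broadcast_values(values: list[int]) -> None:
    """Reject payloads that would wrap or overflow as int32 under a sum."""
    if len(values) < 1:
        raise ValueError("broadcast payload must contain at least one value")
    for value in values:
        if not 0 <= value <= 2**31 - 1:
            raise ValueError(f"value {value} is outside [0, 2**31 - 1]")
```

The built-in ``hash()`` is avoided on purpose in this sketch: string hashing is randomised per process unless ``PYTHONHASHSEED`` is pinned, so two ranks would disagree on identical lists.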
Spec-loop level (drafter death):
- Reserve DRAFT_ABORT_SENTINEL in the broadcast length-prefix slot;
raise DrafterAbortedError on non-root when root signals abort.
- Wrap _pipelined_speculative_step body so any OSError on root also
broadcasts the sentinel before re-raising. Non-root exits in
lockstep instead of hanging on the next-round draft broadcast.
- Add RemoteTransport.is_failed sticky flag, set by the blocking
wire helpers on socket close. open_session() rejects after a
failure so subsequent requests can't allocate a session against a
dead wire; the runner crashes via the spec-loop exception and the
master's instance-deletion path rebuilds the placement.
Control-plane level (worker / drafter node disconnect):
- Master's instance-deletion loop now iterates all_node_to_runner
(target + drafter) instead of shard_assignments.node_to_runner
(target only). Drafter-node disconnects therefore tear the
instance down on the same path as a target-rank disconnect, which
routes into the existing supervisor SIGTERM/SIGKILL escalation
chain. Total recovery is bounded by node_inactivity_timeout (5 s)
+ supervisor escalation (~25 s).
Tests:
- TestDrafterAbortRecovery: sentinel range, broadcast short-circuit,
decode -> DrafterAbortedError, root broadcasts on OSError, non-root
does not, recovery best-effort suppression.
- RemoteTransport.is_failed: initial state, set on mid-forward close,
set on mid-trim close, open_session rejection after failure.
- test_asymmetric_all_node_to_runner_includes_drafter_for_disconnect_check:
pins the contract the master fix relies on.
Docs:
- pipelined_drafter.py module docstring rewritten ("Known limitation"
-> "Recovery") with the three-layer model and pointer to target-
rank-death sharing the same control-plane path.
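The spec-loop half of this recovery can be sketched as follows. The sentinel value, the buffer layout, and the ``propose`` / ``broadcast`` callables are assumptions made for illustration; only the names ``DRAFT_ABORT_SENTINEL`` and ``DrafterAbortedError`` come from the commit message.

```python
from typing import Callable

DRAFT_ABORT_SENTINEL = 2**31 - 1  # hypothetical value; must exceed any real k


class DrafterAbortedError(RuntimeError):
    """Raised on non-root ranks when root signals that the drafter died."""


def encode_round(drafts: list[int], k: int, abort: bool = False) -> list[int]:
    """Length-prefixed buffer; the length slot doubles as the abort flag."""
    if abort:
        return [DRAFT_ABORT_SENTINEL] + [0] * k
    return [len(drafts), *drafts] + [0] * (k - len(drafts))


def decode_round(buffer: list[int]) -> list[int]:
    """Non-root side: raise instead of waiting on a dead drafter."""
    if buffer[0] == DRAFT_ABORT_SENTINEL:
        raise DrafterAbortedError("root signalled drafter abort")
    return buffer[1 : 1 + buffer[0]]


def root_round(
    propose: Callable[[], list[int]],
    broadcast: Callable[[list[int]], None],
    k: int,
) -> list[int]:
    """Root side: on a wire failure, tell the peers first, then re-raise."""
    try:
        drafts = propose()
    except OSError:
        broadcast(encode_round([], k, abort=True))
        raise
    broadcast(encode_round(drafts, k))
    return drafts
```

Reusing the length slot means the abort signal needs no extra collective: peers learn about the failure on the broadcast they were already waiting for.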
Two bugs caused the V2 multi-target spec loop to hang silently under sustained generation, manifesting as: rank 1 returns to ``RunnerReady`` in 0.4s with no tokens emitted while rank 0 stays stuck in ``mx.eval(sampled_batch)`` inside ``_pipelined_speculative_step_body``; the request times out at the API client.
1. ``SequentialGenerator.step`` was silently swallowing every exception except ``StopIteration``/``PrefillCancelled``. On a multi-rank target group this means a per-rank exception (e.g. the master-divergence ``RuntimeError`` raised below, an MLX collective desync, or a model-side error) makes the offending rank exit the generator, return to the runner's ``_work_queue.get()`` idle, and leave its peer hung in the next TP collective forever. The previous swallow + ``DO NOT re-raise`` comment was written before the supervisor's peer-failure ``_kill_runner`` rule existed; with that rule in place the correct multi-rank behaviour is to escalate (raise -> ``RunnerFailed`` -> peer auto-tears down -> master rebuilds). Single-rank runners keep the swallow path so a malformed request doesn't crash the only generator. Always log with ``logger.opt(exception=True)`` so the peer's hang stack stops being mysterious.
2. ``mx_all_gather_tasks`` is now root-authoritative rather than strict-drift-detecting. The strict detector correctly fired on real divergence but treated libp2p delivery races (master pushing the same task to ranks at slightly different times) as request-failure conditions, so any 2k+ token request had a high chance of catching a benign race window and bailing. The new protocol broadcasts root's task IDs as canonical ASCII slots, every rank admits the subset of root's IDs it has locally, leftover local-only tasks defer to the next round. Rides the well-exercised ``all_sum`` primitive via ``mx_broadcast_int_list``; never desyncs the collective. Tasks beyond ``_MX_AGREE_MAX_TASKS`` (16) defer.
Together these turn "rank-1 silent fast-exit + hang" into "both ranks complete the request" or, on a real hardware failure, "supervisor escalates and master rebuilds the instance" -- never indefinite hang.
Tests: 26 unit tests in ``test_utils_mlx_broadcast.py`` covering the encode/decode codec, the consumer-side filtering, leftover semantics, root-authoritative ordering, and the ``_MX_AGREE_MAX_TASKS`` cap. All 189 MLX unit tests pass.
The root-authoritative agreement protocol from the previous commit
caused a worse failure mode under sustained generation: when rank 0
admitted a task that rank 1 hadn't yet received via libp2p,
``_active_tasks`` diverged across ranks. Rank 0 would call
``next(gen)`` to advance the spec loop (issuing several ``all_sum``
collectives) while rank 1, with an empty ``_active_tasks``, would
loop back to ``step()`` and re-enter ``agree_on_tasks`` (issuing
its own ``all_sum``). The two collective streams interleaved on the
wire and corrupted each other's payloads, presenting as
``IndexError: list index out of range`` in the detokenizer because
the broadcast token slots arrived scrambled. The supervisor's
peer-failure rule then tore the instance down -- a real recovery,
but not a useful one because the same poison-pill task got
re-placed and crashed the rebuild on every single attempt.
The new protocol is two-phase intersection:
Phase 1: root broadcasts its candidate task IDs as ASCII slots
(same wire format as before; ``mx_broadcast_int_list``).
Phase 2: every rank emits a ``[0, 1]`` indicator vector saying
"I have this canonical ID locally"; the new
``mx_all_sum_int_list`` helper element-wise-sums those
vectors. A slot whose sum equals ``group_size`` is
agreed (every rank had it); slots below ``group_size``
are deferred to the next round.
The agreed set is identical on every rank by construction, so the
collective-count guarantee that ``step()`` depends on
(``if len(_active_tasks) < max_concurrent_tasks: agree_on_tasks()``
firing on the same step on every rank) is preserved across all
delivery-race scenarios. Cost is one extra ``all_sum`` per
``agree_on_tasks`` call -- sub-millisecond on Apple Silicon JACCL,
runs only at admit boundaries.
Tests: rewrote ``test_utils_mlx_broadcast.py`` to exercise the
intersection contract for representative race scenarios (root-only,
peer-only, partial overlap, 3-rank, ordering, max-cap). 190 MLX
unit tests pass.
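The two-phase intersection can be simulated without MLX to see why every rank ends up with the same agreed set. In this sketch the collectives are replaced by plain list operations, and ``agree_on_tasks`` takes all ranks' local lists at once; that signature is an assumption for illustration and differs from the real per-rank helper.

```python
def agree_on_tasks(
    local_ids_by_rank: list[list[str]], max_tasks: int = 16
) -> list[list[str]]:
    """Simulate both phases; returns the agreed list as seen by each rank."""
    group_size = len(local_ids_by_rank)

    # Phase 1: root (rank 0) broadcasts its candidate ids, capped at max_tasks.
    candidates = local_ids_by_rank[0][:max_tasks]

    # Phase 2: each rank emits a 0/1 indicator per candidate slot ...
    indicators = [
        [1 if task_id in local else 0 for task_id in candidates]
        for local in local_ids_by_rank
    ]
    # ... and an element-wise sum (stand-in for all_sum) counts the votes.
    sums = [sum(column) for column in zip(*indicators)]

    # A slot is agreed only when every rank holds that task locally.
    agreed = [t for t, s in zip(candidates, sums) if s == group_size]
    return [list(agreed) for _ in range(group_size)]
```

Every rank derives ``agreed`` from the same broadcast and the same summed vector, so a task that has reached only some ranks is deferred everywhere rather than admitted on a subset.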
Symptom in the wild: ~300ms after decode begins on a multi-target asymmetric placement (`gemma-4-31b-it-bf16` over JACCL with the `gemma-4-e2b-it-4bit` drafter), the non-root target rank crashes inside mlx_lm's SPM detokenizer with `IndexError: list index out of range` because the emitted token id sits well outside the tokenizer vocab.
Root cause is JACCL collective cross-talk between the model's tensor-parallel `all_sum` (which dispatches on the default Metal stream because `auto_parallel.py` does not pass a `stream=` argument) and the spec-decode broadcasts in `mx_broadcast_int_list` / `mx_all_sum_int_list` / `mx_any` / `mx_barrier`, which deliberately forced themselves onto the CPU stream. With separate streams JACCL sees two independent dispatch queues per rank, so the rendezvous matching is no longer FIFO across the whole group: a CPU `all_sum` on rank 0 can pair with a Metal TP `all_sum` on rank 1 mid-forward, silently corrupting the broadcast buffer. The hot path emits two broadcasts per spec round interleaved with ~120 layer all_sums, which is exactly the pathological pattern.
Fix: drop `stream=mx.default_stream(mx.Device(mx.cpu))` from every distributed helper that runs on the same target group as the model TP. The collectives now ride the input array's default stream (Metal), matching the TP all-reduces and giving JACCL one in-order dispatch queue per rank for the entire group.
Defence in depth: surface a typed `RuntimeError` (with the offending token id) in `_pipelined_stream_generate` before the SPM detokenizer explodes, so any future broadcast corruption is obvious instead of appearing as an unrelated stack frame deep in `mlx_lm`.
Stream alignment alone did not stop the IndexError -- token id 1083194908 (well outside vocab) still surfaced on rank-1 within ~300ms of decode start. Strong evidence that JACCL is conflating our int32 broadcast `all_sum` with the model's float32 TP `all_sum` on the same target group when both are issued back-to-back on the same stream. Switching the hot-path broadcast to `mx.distributed.send` / `recv` makes it a fundamentally different primitive than the model TP all-reduce, so JACCL has no opportunity to merge the two on the wire. The cost is one wire round-trip per peer per call (vs one collective for the all_sum), which on a 2-rank target subgroup is identical network traffic. Left `mx_all_sum_int_list`, `mx_any`, and `mx_barrier` on `all_sum` because they only fire at admit/cancel boundaries (not per token), so the interleaving frequency that broke the broadcast helper does not apply. Restored the comment to reflect that.
The model's TP all_sum collectives stay on JACCL/RDMA -- those are the
multi-MB tensor reductions the vendor stack is optimised for. But the
tiny ~24-byte int32 broadcasts the spec-decode loop runs between
target rank 0 and its peers (drafts in / sampled tokens out) are NOT
safe on the same wire: probe diagnostics from the prior run showed the
JACCL backend conflating the int32 broadcast with the model's float32
TP all-reduce, returning logit memory reinterpreted as int32 (token id
1083948012 = float ~4.87). The resulting out-of-vocab id surfaced as
an IndexError deep in the SPM detokenizer a few hundred ms into every
generation.
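The "logit memory reinterpreted as int32" diagnosis is easy to reproduce with the standard library: take the four bytes of the corrupt token id and read them back as an IEEE-754 float32. The helper name below is made up for this illustration; the two ids are the ones quoted in this thread.

```python
import struct


def int32_bits_as_float(value: int) -> float:
    """Reinterpret the 4 bytes of an int32 as an IEEE-754 float32."""
    return struct.unpack("<f", struct.pack("<i", value))[0]


# Both out-of-vocab ids seen on rank 1 decode to small, logit-sized floats.
for token_id in (1083948012, 1083194908):
    print(token_id, round(int32_bits_as_float(token_id), 3))
```

A genuine token id would decode to a denormal or otherwise meaningless float, so a plausible single-digit value is a strong hint that the buffer held float32 activations.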
Switching the broadcast primitive (all_sum -> send/recv) didn't help
-- both rode the same target group as the model TP and JACCL kept
mismatching them. The fix is to lift the int wire off mx.distributed
entirely:
* target_peer_socket.py: bind/accept/dial + send_int32_frame /
recv_int32_frame helpers (mirrors drafter_socket.py).
* DrafterPlacement: target_peer_socket_port + per-peer
target_peer_hosts_by_rank, allocated by master/placement.py.
* MlxGroupSplit: target_peer_fanout slot, populated by
initialize_mlx (rank 0 binds + accepts N-1 peers; peers dial in).
* TargetPeerFanout: rank-aware dataclass holding the connected
sockets; target_peer_broadcast_int_list rides them.
* pipelined_drafter: _broadcast_drafts / _broadcast_target_tokens /
_broadcast_abort take a fanout, fall through to the legacy
mx_broadcast_int_list when fanout is None (single-rank,
symmetric, test fakes).
* Plumbing: builder -> SequentialGenerator -> mlx_generate ->
make_drafter -> PipelinedModelDrafter, threading the fanout
through every layer.
Performance: each round adds two ~24-byte TCP frames over Thunderbolt
with TCP_NODELAY, ~100us per broadcast vs. ~30ms verifier forward. <1%
per-round overhead in exchange for a wire that cannot collide with TP
collectives. Same precedent as the drafter wire (which has used a
direct TCP socket since the V3 redesign because JACCL doesn't support
Group.split on Apple Silicon).
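For orientation, a minimal version of int32 framing over a stream socket might look like this. The function names match the helpers listed above, but the frame layout (unsigned 32-bit count followed by little-endian int32 values) and the signatures are assumptions; the real ``target_peer_socket.py`` is not shown in this thread.

```python
import socket
import struct


def send_int32_frame(sock: socket.socket, values: list[int]) -> None:
    """Send a count followed by that many little-endian int32 values."""
    sock.sendall(struct.pack(f"<I{len(values)}i", len(values), *values))


def _recv_exact(sock: socket.socket, size: int) -> bytes:
    """Read exactly `size` bytes; stream sockets may return short reads."""
    data = bytearray()
    while len(data) < size:
        chunk = sock.recv(size - len(data))
        if not chunk:
            raise ConnectionError("peer closed the socket mid-frame")
        data += chunk
    return bytes(data)


def recv_int32_frame(sock: socket.socket) -> list[int]:
    """Read one frame written by send_int32_frame."""
    (count,) = struct.unpack("<I", _recv_exact(sock, 4))
    return list(struct.unpack(f"<{count}i", _recv_exact(sock, 4 * count)))


if __name__ == "__main__":
    root, peer = socket.socketpair()
    send_int32_frame(root, [5, 101, 202, 303, 404, 505])  # 24-byte payload
    print(recv_int32_frame(peer))
    root.close()
    peer.close()
```

On a real TCP connection the sender would also set ``sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)`` so that frames this small are not held back by Nagle's algorithm; ``socket.socketpair()`` is used here only to keep the example local.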
The event-router serialises events to JSON for gossipsub fan-out; JSON has no integer dict keys, so a Pydantic dict[int, str] field fails strict re-validation on the receiving worker with "Input should be a valid integer [type=int_type, input_value='1']". Surfaced as: master placed the instance, broadcast InstanceCreated, worker rank 0 (smbp) crashed mid-validate, ASGI loop tore down the runner before MLX init even started, master logged RunnerFailed and gave up. Fix: store rank keys as strings in the wire type and stringify at the consumer (initialize_mlx looks up str(rank) instead of rank). Same shape change pydantic recommends for any int-keyed dict that crosses a JSON boundary.
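The key coercion is visible with the standard ``json`` module alone, independent of Pydantic. The addresses and the ``host_for_rank`` helper below are hypothetical and only illustrate the "stringify at the consumer" lookup.

```python
import json

# Sender side: a rank -> host mapping with integer keys (made-up addresses).
hosts_by_rank = {1: "10.0.0.2", 2: "10.0.0.3"}

# JSON object keys are always strings, so the round trip changes the key type.
wire = json.loads(json.dumps(hosts_by_rank))


def host_for_rank(hosts: dict[str, str], rank: int) -> str:
    """Consumer side: look up by the stringified rank."""
    return hosts[str(rank)]


print(list(wire))              # keys come back as strings
print(host_for_rank(wire, 1))
```

Declaring the wire type with string keys makes the sender and receiver agree on one shape, so strict validation on the receiving worker no longer sees a type it did not expect.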
Adds timestamped logger.info checkpoints around every blocking call on the rank-0 spec path (tolist materialization, OP_PREFILL ack, spec-body prefill iters, seed materialization, OP_FORWARD round 0, broadcast round 0) so we can read off the exact step that wedges against rank 1's draft-broadcast recv. Plain logger.info, not gated; the loops are post-prefill and run at ~30 ms granularity so the noise is bounded.
The runner subprocess on smbpt produces zero loguru output -- its
stdout/stderr aren't reaching tee, so the previous loguru-only
checkpoints were invisible on rank 0 (which is on smbpt this run).
Adds an unconditional file-side-channel that writes plain timestamped
lines to /tmp/spec_diag_<pid>.log alongside the loguru emit. Makes
diagnostics survive whatever's swallowing rank 0's stdout.
Adds checkpoints around every blocking call in the per-round main
loop:
* round top
* model(verify_input) dispatch (NOT eval'd yet -- MLX is lazy)
* sampler() call (root only)
* mx.eval(sampled_batch) -- the spot most likely to wedge because
forcing eval here is what actually fires the deferred verify
forward + TP all_sum
* _broadcast_target_tokens (TCP)
* _broadcast_drafts (next, TCP)
The previous rank-1 logs showed it sailing past round-0 broadcast and
then waiting in sock_recv on _broadcast_target_tokens, while rank 0
was wedged in mlx::core::eval -- consistent with the verify-forward
all_sum failing to pair up. These per-round checkpoints will show
exactly which round wedges and which step inside the round.
Diagnostics from /tmp/spec_diag_<pid>.log captured the exact deadlock:
rank 0 (root): "about to mx.eval(sampled_batch)" -- blocked
forcing the verify forward + all_sum to run
rank 1: "about to call _broadcast_target_tokens" -- blocked
in TCP recv after dispatching model(verify_input)
but never forcing eval of ``logits``
MLX is lazy. ``logits = model(verify_input)`` builds a graph but does
not launch kernels. The TP all-reduce embedded in the final layer is
queued but never executed on rank 1, because rank 1's next step is a
pure-Python TCP recv that does not touch ``logits``. Rank 0's
``mx.eval(sampled_batch)`` then waits forever for rank 1's matching
all_sum, which never fires. Hang.
The fix mirrors the spec-body prefill loop (which worked correctly):
``mx.eval([c.state for c in prompt_cache])`` after every ``model()``
call to force collective kernels to actually run on every rank.
Apply the same pattern to the verify forward: ``mx.eval(logits)``
immediately after ``model(verify_input)``. Both ranks now block until
the TP all-reduce completes. Cost is one extra sync per round, but
rank 0 was going to block on this same sync at the sampler step
anyway -- the only thing that changes is rank 1 now correctly
participates in the collective instead of dropping out.
This is a JACCL / TP-collective correctness fix, independent of the
TCP fanout (which is for the int32 hot-path broadcasts and is working
correctly per the same diagnostics: drafts_broadcast completed in
50ms on rank 1).
The /tmp/spec_diag_<pid>.log side-channel writes ~10 lines per spec round; on an 8k-token run that's ~16k log lines per pid. Useful for correctness regressions, noise in steady-state. Default off; ``EXO_SPEC_DIAG=1`` re-enables. Hooks themselves stay in place so future TP-collective regressions can be isolated immediately without rebuilding with new logging.
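A gated side-channel of this kind needs only a few lines. The sketch below assumes the gate is an exact match on ``EXO_SPEC_DIAG=1`` and that each call opens, appends and closes the file; ``spec_diag`` and its ``directory`` parameter are hypothetical names, not the hook in the PR.

```python
import os
import time


def spec_diag(message: str, *, directory: str = "/tmp") -> None:
    """Append one timestamped line to a per-pid file when EXO_SPEC_DIAG=1."""
    if os.environ.get("EXO_SPEC_DIAG") != "1":
        return  # default off: steady-state runs write nothing
    path = os.path.join(directory, f"spec_diag_{os.getpid()}.log")
    # Open/append/close per call so a line survives even if the process
    # wedges inside the very next blocking call.
    with open(path, "a", encoding="utf-8") as handle:
        handle.write(f"{time.time():.6f} {message}\n")
```

Writing to a file rather than stdout is the point of the design: the checkpoints stay readable even when a subprocess's stdout/stderr never reach the log collector.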
Closes the visibility gap that was making "the drafter doesn't deadlock"
indistinguishable from "the drafter is actually working" in benches.
What was missing:
* Non-streaming chat-completions responses dropped GenerationStats on
the floor in collect_chat_response, so callers got no drafter info
even when the drafter ran for thousands of rounds.
* GenerationStats had no proposed-draft count, so the classical
accepted/proposed acceptance rate was unknowable from the API.
* The per-request log was at DEBUG and had no drafter info, so
operators had to grep without knowing what to grep for.
What this adds:
pipelined_drafter:
* Per-request metrics dict on PipelinedModelDrafter, mutated in
lockstep with the spec body's accept loop. Tracks
proposed_draft_tokens (sum of k_this), accepted_draft_tokens
(sum of num_accepted), and spec_decode_rounds.
* Drafter.metrics() returns a snapshot. Other drafter
implementations are unchanged; mlx_generate uses getattr() so
they degrade to {}.
GenerationStats:
* proposed_draft_tokens, spec_decode_rounds new fields (default 0).
* drafter_acceptance_rate = accepted/proposed (classical metric)
alongside the existing drafter_acceptance_fraction
= accepted/generated (speedup-mapping metric).
* Doc clarifies the difference between the two and which to
use when.
Usage / OpenAI compatibility:
* usage.completion_tokens_details.accepted_prediction_tokens
mirrors accepted_draft_tokens (OpenAI's Predicted Outputs term).
* rejected_prediction_tokens = max(0, proposed - accepted) when
the drafter surfaces a proposal count; 0 otherwise.
ChatCompletionResponse:
* generation_stats: GenerationStats | None top-level extension
so non-streaming clients see the full stats object. OpenAI
clients ignore unknown fields; exo benches and the dashboard
read it directly.
collect_chat_response:
* Tracks chunk.stats across the stream and surfaces it on the
final response. Both TokenChunk and ToolCallChunk paths.
Per-request log (mlx_generate):
* Bumped from DEBUG to INFO and now includes drafter mode, id,
K, rounds, accepted/proposed counts, acceptance rate, and
fraction-of-emitted. One bounded line per completed request,
so dashboards / operators can answer "is the drafter helping?"
without flipping verbose.
This is the visibility leg of the gemma-4 drafter rollout. It's a
strict superset of the prior surface (no field removals, all defaults
match the previous behavior), so existing clients keep working.
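The relationship between the two ratios and the OpenAI-style rejected count can be written down directly. ``DraftStats`` is a hypothetical, trimmed-down stand-in for ``GenerationStats``, and returning 0.0 for an empty denominator is an assumption of this sketch.

```python
from dataclasses import dataclass


@dataclass
class DraftStats:
    """Hypothetical subset of the counters described above."""

    generated_tokens: int
    proposed_draft_tokens: int = 0
    accepted_draft_tokens: int = 0
    spec_decode_rounds: int = 0

    @property
    def drafter_acceptance_rate(self) -> float:
        """Classical metric: accepted / proposed."""
        if self.proposed_draft_tokens == 0:
            return 0.0
        return self.accepted_draft_tokens / self.proposed_draft_tokens

    @property
    def drafter_acceptance_fraction(self) -> float:
        """Speedup-mapping metric: accepted / generated."""
        if self.generated_tokens == 0:
            return 0.0
        return self.accepted_draft_tokens / self.generated_tokens

    @property
    def rejected_prediction_tokens(self) -> int:
        """max(0, proposed - accepted); 0 when no proposal count exists."""
        return max(0, self.proposed_draft_tokens - self.accepted_draft_tokens)
```

For example, a request that emitted 256 tokens from 300 proposals with 180 accepted has an acceptance rate of 0.6 but an emitted fraction of about 0.70, which is why the two numbers answer different questions.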
bench_compare.py: Drafter vs no-drafter A/B at the same prompt and lengths. Surfaces per-run TPS, drafter telemetry (acceptance rate, fraction-of-emitted, K, rounds), and the speedup ratio. Used to demonstrate that the drafter is helping (or, just as importantly, NOT helping at low acceptance rates).
bench_concurrent.py: Fires N parallel chat-completions requests at the same instance to measure overlap of in-flight spec-decode sessions. Reports per-request and aggregate TPS so we can tell when adding concurrency hurts vs helps.
Both write JSON to /tmp by default; pair with the new generation_stats and per-request log line to get full visibility into drafter effectiveness.
`_custom_cards_dir` resolves to `$EXO_DATA_HOME/custom_model_cards`, where dev workstations keep operator-edited cards (e.g. trimmed drafter lists for memory-constrained clusters). Those overrides layer on top of the shipped TOML, so the gemma-4 card-content gates were asserting against whatever the operator last wrote instead of the shipped data they're supposed to protect. Add an autouse fixture mirroring the pattern in test_model_cards.py: point `_custom_cards_dir` at a fresh tmp dir per test and clear `_card_cache` so the next refresh sees only the shipped builtins. This unblocks the full pytest run that previously failed with mismatches like e2b-it-4bit vs e2b-it-bf16.
Closing in favour of the new 4-PR drafter stack:
This split was driven by the file-overlap topology of the 74 commits (heavy churn on
Verification (content-preserving):
The previous claim that lifting the 14s slot-1 TTFT required upstream
``position_ids`` for tree attention conflated two separate problems:
1. **Tree attention** (EAGLE / Medusa / lookahead) wants K candidate
continuations verified in a *single* forward whose siblings need
different RoPE positions in the same step. ``mlx_lm`` derives every
position's RoPE id from ``KVCache.offset`` (a single ``int``), so
this collapses to linear verify on Apple Silicon. Genuinely blocked
on ``ml-explore/mlx-lm#846``.
2. **Concurrent target requests** (this PR) want N requests' generators
to make progress per ``step``. Each ``mlx_generate`` call already
allocates its own KV cache, so two generators are independent in
everything but model weights (read-only during forward). Round-
robin scheduling at the generator level needs zero changes to
``mlx_lm`` because every ``next(gen)`` is a *single-position*
forward against that generator's own cache.
The previous ``SequentialGenerator._active`` was a singular slot, so
slot 1's TTFT equalled slot 0's full completion time -- 14s on the
PR #15 K=3 single-host n-gram concurrency leg.
This change:
* Replaces ``_active: tuple | None`` with ``_active_tasks: OrderedDict``
capped by ``max_concurrent_tasks``. ``step`` admits up to the cap from
the queue, then round-robins one ``next(gen)`` per active task per
tick. ``OrderedDict`` makes the round-robin order explicit so all
ranks see the same admit/iterate sequence (collective ops in
``agree_on_tasks`` need this).
* Sets ``max_concurrent_tasks = EXO_MAX_CONCURRENT_REQUESTS`` (default
8) for everything except the asymmetric pipelined+remote path, which
stays at 1. The asymmetric cap is real: ``RemoteTransport``'s wire
protocol is per-session and concurrent target requests would
interleave ``OP_PREFILL`` / ``OP_FORWARD`` frames on the same socket
and corrupt the drafter rank's per-request KV state. Lifting *that*
cap requires extending the wire protocol with a request-id field --
separate change.
* Strengthens the K=8-cancel resilience contract: a single in-flight
task's runtime exception now evicts only that task. The previous
contract was "the runner must survive"; the new contract is "every
*other* in-flight task must keep advancing too".
Tests added in ``test_sequential_generator_errors``:
* ``test_round_robin_advances_all_active_tasks_per_tick`` -- the core
contract: with cap=2, both tasks advance one yield per ``step``
(the singular-slot version would have been 0 for slot 1).
* ``test_round_robin_respects_max_concurrent_tasks`` -- cap=1
(asymmetric default) admits only one task per tick.
* ``test_round_robin_per_task_error_does_not_kill_other_active_tasks``
-- a faulty generator evicts only itself; siblings keep advancing.
The pre-existing ``test_runner_survives_sequential_failure_and_serves_
next_task`` was tightened to express the post-refactor contract: a
queue of failing tasks may surface all their failures on tick 1 (the
old ``_active = None`` reset only admitted one per tick).
All 642 tests pass; ``basedpyright`` 0 errors; ``ruff`` clean.
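The scheduling contract above (admit up to the cap, advance each active generator once per tick, evict only the task that fails) fits in a short standalone class. ``RoundRobin`` below is a simplified illustration with made-up names and a made-up error payload; it leaves out cancellation, collectives and everything else the real ``SequentialGenerator`` handles.

```python
from collections import OrderedDict, deque
from typing import Iterator


class RoundRobin:
    """Toy scheduler: one next() per active task per step()."""

    def __init__(self, max_concurrent_tasks: int) -> None:
        self.max_concurrent_tasks = max_concurrent_tasks
        self.queue: deque[tuple[str, Iterator[str]]] = deque()
        self.active: OrderedDict[str, Iterator[str]] = OrderedDict()

    def submit(self, task_id: str, gen: Iterator[str]) -> None:
        self.queue.append((task_id, gen))

    def step(self) -> list[tuple[str, str]]:
        # Admit queued tasks until the concurrency cap is reached.
        while self.queue and len(self.active) < self.max_concurrent_tasks:
            task_id, gen = self.queue.popleft()
            self.active[task_id] = gen

        emitted: list[tuple[str, str]] = []
        # Iterate over a snapshot so eviction during the loop is safe.
        for task_id in list(self.active):
            try:
                emitted.append((task_id, next(self.active[task_id])))
            except StopIteration:
                del self.active[task_id]  # finished normally
            except Exception as exc:
                del self.active[task_id]  # evict only the failing task
                emitted.append((task_id, f"error: {exc}"))
        return emitted
```

With ``max_concurrent_tasks=1`` this degenerates to the old singular-slot behaviour, which matches the cap the asymmetric pipelined path keeps.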
The PR #15 round-robin that landed in 456bbb3 cut slot-1 TTFT 5.2x by interleaving decode ticks across in-flight requests, but the residual 11s outliers in the mixed-prompt long_context_summary bench are 6K-token prefills that still run sequentially per-tick: every active slot pays its full prefill cost on the same GPU before its first decode token.

Batched prefill folds K eligible slots' prefills into a single ``PromptProcessingBatch.prompt`` call, amortising weight loads across the batch, then hands per-slot pre-filled caches back to ``mlx_generate`` via the new ``precomputed_target_cache`` seam.

Eligibility (V1, narrow on purpose):

* single-rank target (multi-rank pipeline-parallel prefill keeps its driver loop -- folding it in needs a follow-up that touches ``pipeline_parallel_prefill``'s collective semantics);
* no vision (``task_params.images``);
* no remote prefill endpoint (the disaggregated-prefill path already runs on a sibling instance via ``InstanceLink``);
* no in-process model drafter (its drafter prefill must stay aligned to the target's offset; batching only the target would desync them).

The asymmetric pipelined drafter still qualifies because ``self.draft_model is None`` on the target rank -- its drafter prefill goes over the wire per-session and is independent of target prefill batching.

``EXO_BATCH_PREFILL=0`` is the env-var escape hatch for shared-prefix workloads where the per-slot prefix-cache hit rate exceeds the batched-forward speedup.

Implementation:

* ``batched_prefill`` in ``generate.py`` wraps ``PromptProcessingBatch`` (the upstream mlx-lm helper that handles right-padding + ``prepare(lengths, right_padding)`` + ``finalize()``), slices each prompt to ``prompt[:-1]`` so the post-prefill cache offset lands at ``len(prompt) - 1`` (matching the existing exact-prefix-hit invariant), and re-extracts per-sequence caches in place. Cache layers without ``merge``/``extract`` (e.g. ``DeepseekV4Cache``) raise the typed ``BatchedPrefillUnsupportedError`` so the caller falls back to per-slot prefill instead of crashing the runner.
* ``mlx_generate`` accepts ``precomputed_target_cache``: when set, the prefix-cache lookup is bypassed (V1 trade-off -- we don't pollute the shared cache with per-request entries), the local ``prefill`` slice is empty (one-token decode seed), and the exact-prefix-hit shape is reproduced so the rest of the function is untouched.
* ``SequentialGenerator`` factors out admission into ``_admit_queued_tasks``, which collects up-to-slack candidates, filters via ``_batch_eligible_for_prefill``, and either runs one ``batched_prefill`` (>=2 eligible) or per-slot ``_start_one``. The per-slot path is preserved verbatim for backwards compatibility with monkeypatched ``_build_generator`` test seams (it only forwards ``precomputed_target_cache`` when set).
* Untyped failures in the batched forward are charged to every batched task via ``_send_error`` + ``_pending_failed`` so a single malformed request can't take down the runner -- same liveness contract as the per-slot path.

Tests:

* ``test_batched_prefill.py``: bit-exact correctness vs sequential prefill on B=2, decode continuity after batched prefill, empty-input zero return, length-mismatch / short-prompt ``ValueError``, and ``BatchedPrefillUnsupportedError`` on a cache type without ``merge``.
* ``test_sequential_generator_batch_prefill.py``: routing matrix -- two eligible tasks take the batched path, single-eligible falls through to per-slot, mixed eligibility splits cleanly, ``EXO_BATCH_PREFILL=0`` disables, unsupported cache falls back, and each individual disqualifier (group, vision, remote prefill, in-process drafter) routes to per-slot. Asymmetric drafter target is verified to qualify.

Bench (still pending): B=1 vs B=2 prefill on 6K-token ``long_context_summary`` -- expected ~1.5-1.8x aggregate prefill tps on a single GPU. The 6K outliers should drop from ~11s to ~6s on B=2, eliminating the worst-case slot-1 TTFT in the mixed bench.

Cluster validation gated on TB /30 RDMA repair (operator-driven sudo required, see ``bb rdma repair all``).

X-Orchestraitor-Task: gemma4-drafter-tuning
X-Orchestraitor-Plan: PR #15 batched-prefill leg
X-Agent-Platform: cursor-claude-opus
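The admission routing and prompt slicing described in this commit message can be sketched in plain Python. This is a minimal illustration, not the code in ``generate.py`` or ``SequentialGenerator``: the ``Task`` dataclass, ``batch_eligible``, ``plan_admission``, ``prefill_slices``, the pad id of 0, and the two-token minimum prompt length are all assumptions made for the example, and no MLX calls are involved.

```python
from dataclasses import dataclass, field


@dataclass
class Task:
    """Hypothetical stand-in for a queued generation task."""
    task_id: str
    prompt: list[int]
    images: list[bytes] = field(default_factory=list)
    remote_prefill: bool = False


def batch_eligible(task, *, world_size, has_inprocess_drafter):
    """Apply the four V1 disqualifiers listed in the commit message."""
    if world_size > 1:         # multi-rank pipeline-parallel target
        return False
    if task.images:            # vision request
        return False
    if task.remote_prefill:    # disaggregated prefill runs elsewhere
        return False
    if has_inprocess_drafter:  # drafter cache must stay aligned to the target
        return False
    return True


def plan_admission(tasks, *, world_size=1, has_inprocess_drafter=False, enabled=True):
    """Return (batched_ids, per_slot_ids) for one admit cycle.

    The batched path is only taken when at least two tasks are eligible;
    otherwise every task goes through the per-slot path.
    """
    eligible = [
        t for t in tasks
        if enabled and batch_eligible(
            t, world_size=world_size, has_inprocess_drafter=has_inprocess_drafter
        )
    ]
    if len(eligible) < 2:
        return [], [t.task_id for t in tasks]
    eligible_ids = {t.task_id for t in eligible}
    per_slot = [t.task_id for t in tasks if t.task_id not in eligible_ids]
    return [t.task_id for t in eligible], per_slot


def prefill_slices(prompts, pad_id=0):
    """Slice each prompt to prompt[:-1] and right-pad to a common width.

    After prefill the cache offset is len(prompt) - 1, leaving the last
    token as the one-token decode seed. pad_id=0 is an assumption.
    """
    if any(len(p) < 2 for p in prompts):
        raise ValueError("each prompt needs at least 2 tokens")
    sliced = [p[:-1] for p in prompts]
    lengths = [len(s) for s in sliced]
    width = max(lengths, default=0)
    right_padding = [width - n for n in lengths]
    padded = [s + [pad_id] * pad for s, pad in zip(sliced, right_padding)]
    return padded, lengths, right_padding
```

Under these assumptions, a cycle with two text tasks and one vision task batches the two text tasks and leaves the vision task on the per-slot path, which is the "mixed eligibility splits cleanly" case from the routing matrix.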
The PR #15 batched-prefill work landed in 890c7ae, but in production it never fired: two concurrent client requests reliably miss the same ``SequentialGenerator._admit_queued_tasks`` window because the runner calls ``step()`` immediately after submitting the *starting* task and only polls ``_work_queue`` for the second task **after** that first ``step()`` has already run prefill on slot 0. By the time the second admit cycle sees task #2, slot 0 is mid-decode and there's nothing left to batch with.

Cluster bench on 3-node TB-RDMA Big Brain (smbp + smbpt + bmbp, gemma-4-26b-a4b-it-4bit single-rank on smbp, drafter disabled, concurrency=2, long_context_summary, 3 runs):

    EXO_BATCH_PREFILL=0  agg=272.23 t/s  slot0_ttft=1119ms  slot1_ttft=2143ms
    EXO_BATCH_PREFILL=1  agg=269.03 t/s  slot0_ttft=1114ms  slot1_ttft=2135ms

Near-identical numbers across both runs are the smoking gun: batched_prefill never triggered, the per-slot fallback ran in both cases, and slot 1 paid the full prefill of slot 0 on top of its own. Confirmed by the absence of the ``batched_prefill: N slots`` info log on the runner.

This commit:

* Adds ``Runner._coalesce_burst_generation_tasks`` -- pulls ``TextGeneration`` / ``ImageGeneration`` / ``ImageEdits`` items from ``_work_queue`` and submits them via the existing ``submit_generation`` path before the first ``step()``.
* Calls it from ``handle_generation_tasks`` immediately after the starting-task submit so the upcoming admit cycle sees the full burst together.
* Blocks on the queue for up to ``EXO_BURST_COALESCE_MS`` (default 20ms) so libp2p-delivery jitter doesn't lose task #2 to the before-step gap. Tuned at 20ms because TB-routed libp2p delivery on concurrent client requests typically straggles 5-15ms; values >50ms start adding user-visible TTFT to solo requests. ``EXO_BURST_COALESCE_MS=0`` disables (per-slot every time).
* Routes any non-task item picked up during the drain (``PrefillTask`` / ``_TaskStreamClosed`` / ``Shutdown``) into a new ``_burst_deferred_item`` slot consumed by the main loop before its next ``_work_queue.get_nowait()``. Re-queueing at the tail would race with the listener thread and silently re-order ``Shutdown`` past burst tasks -- this preserves FIFO without needing an appendleft-capable queue.

The 307 worker unit tests all pass; basedpyright and ruff clean. Bench validation pending TB-RDMA bringup retry (next commit).

X-Orchestraitor-Task: gemma4-drafter-tuning
X-Orchestraitor-Plan: PR #15 batched-prefill bench-fix
X-Agent-Platform: cursor-claude-opus
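The burst-coalescing drain described above can be sketched with the standard-library ``queue`` module. This is an illustrative sketch, not the body of ``Runner._coalesce_burst_generation_tasks``: the ``GenerationTask`` and ``Shutdown`` classes, the ``coalesce_burst`` function, and the choice to stop draining at the first non-generation item are assumptions. Only the ``EXO_BURST_COALESCE_MS`` name, its 20ms default, and the "0 disables" behaviour come from the commit message.

```python
import os
import queue
import time
from dataclasses import dataclass


@dataclass
class GenerationTask:
    """Hypothetical stand-in for TextGeneration / ImageGeneration / ImageEdits."""
    task_id: str


@dataclass
class Shutdown:
    """Hypothetical stand-in for a control item that must keep its FIFO position."""


def window_ms_from_env(env=os.environ):
    """Read the coalesce window; 20ms default as stated in the commit message."""
    return float(env.get("EXO_BURST_COALESCE_MS", "20"))


def coalesce_burst(work_queue, window_ms):
    """Drain generation tasks that arrive within window_ms.

    Returns (tasks, deferred). ``deferred`` is the first non-generation
    item seen, or None. The caller is expected to process it before
    polling the queue again, which keeps FIFO order without re-queueing
    the item at the tail.
    """
    tasks = []
    if window_ms <= 0:  # coalescing disabled
        return tasks, None
    deadline = time.monotonic() + window_ms / 1000.0
    while True:
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            return tasks, None
        try:
            item = work_queue.get(timeout=remaining)
        except queue.Empty:
            return tasks, None
        if isinstance(item, GenerationTask):
            tasks.append(item)
        else:
            # Stop at the first control item so nothing behind it is reordered.
            return tasks, item
```

Holding the control item in a single deferred slot, rather than putting it back, is what avoids the race with a producer thread that may append new items between the ``get`` and the ``put``.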
Two layers, both behind the unified `DrafterTransport` interface so Layer B is purely additive:

Layer A (this commit + the previous 8 in the branch) -- shipping now:

- `Drafter` abstraction with three modes: `model` (mlx_lm spec_step), `ngram` (in-house suffix-match spec loop), `none` (plain stream_generate).
- (`from_draft_count`, `draft_mode`), short-output skip, multi-node placement warning, etc.
- `DrafterTransport` Protocol + `InProcessTransport` + `PipelinedModelDrafter` with cross-round speculation (drafter forward for round t+1 overlaps target verify of round t; on hit, round t+1 propose is free; on miss, K+1 positions are rolled back).

Layer B (next commit on this branch) -- in progress:

- `RemoteTransport` over `mx.distributed.send/recv` -- carries draft tokens between drafter and target ranks at jaccl (RDMA over Thunderbolt-bridge / IB-verbs) or ring (TCP) speed depending on the instance backend. Reuses exo's existing RDMA infrastructure -- no parallel network stack.
- `MlxJacclInstance` / `MlxRingInstance` with a `drafter_rank` field so rank 0 can load drafter-only while ranks 1..N pipeline-parallel the target.
- `remote_drafter.py` (already final).
- `Group.split` so target's pipeline-parallel collectives don't drag the drafter rank in.

Apple Silicon caveat: in-process pipelining wins ~10-30% on the Mac Studio (MLX serializes Metal command queues per device); the unambiguous gain is in Layer B's remote transport where target verify includes a cross-machine network round-trip and the speculative drafter forward fully overlaps it. That's why Layer A's `pipelined` mode and Layer B's `RemoteTransport` are the same architecture -- the transport is the only thing that changes.

Pre-commit checks pass:

- `uv run basedpyright` -- 0 errors
- `uv run ruff check` -- clean
- `uv run pytest` -- 580 passed, 1 skipped, 199 deselected (slow, opt-in)
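For the `ngram` mode listed under Layer A, one common way to do suffix-match drafting is to look for the most recent earlier occurrence of the context's trailing n-gram and propose the tokens that followed it. The sketch below shows that generic technique only; it is not taken from this branch, and `propose_ngram`, `accept_prefix`, and the `max_ngram` parameter are hypothetical names chosen for the example.

```python
def propose_ngram(tokens, num_draft, max_ngram=3):
    """Propose up to num_draft tokens by suffix matching within the context.

    Tries the longest suffix first (up to max_ngram tokens) and scans
    right to left so the most recent earlier occurrence wins. Returns an
    empty list when nothing matches, in which case the caller would fall
    back to plain decoding for that round.
    """
    for n in range(min(max_ngram, len(tokens) - 1), 0, -1):
        suffix = tokens[-n:]
        # The last valid start is len(tokens) - n - 1, which excludes the suffix itself.
        for start in range(len(tokens) - n - 1, -1, -1):
            if tokens[start:start + n] == suffix:
                follow = tokens[start + n:start + n + num_draft]
                if follow:
                    return follow
    return []


def accept_prefix(draft, verified):
    """Count how many leading draft tokens the target model agreed with."""
    accepted = 0
    for d, v in zip(draft, verified):
        if d != v:
            break
        accepted += 1
    return accepted
```

A drafter like this needs no second model and no extra weights in memory, so its proposals are nearly free; the trade-off is that it only helps when the output repeats spans already present in the context.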