Skip to content

Gemma-4 drafter tuning + DrafterTransport architecture (model | pipelined | ngram | none)#15

Closed
team-wcv wants to merge 72 commits into
feature/gemma4-drafter-loaderfrom
feature/gemma4-drafter-tuning
Closed

Gemma-4 drafter tuning + DrafterTransport architecture (model | pipelined | ngram | none)#15
team-wcv wants to merge 72 commits into
feature/gemma4-drafter-loaderfrom
feature/gemma4-drafter-tuning

Conversation

@team-wcv
Copy link
Copy Markdown
Owner

@team-wcv team-wcv commented May 7, 2026

Two layers, both behind the unified DrafterTransport interface so Layer B is purely additive:

Layer A (this commit + the previous 8 in the branch) -- shipping now:

  • Drafter abstraction with three modes: model (mlx_lm spec_step), ngram (in-house suffix-match spec loop), none (plain stream_generate).
  • 10-item drafter tuning: per-request K, num_draft_tokens API, EXO_DRAFT_MODE env var, prefix cache for drafter, drafter prefill on spec path, telemetry (from_draft_count, draft_mode), short-output skip, multi-node placement warning, etc.
  • NEW: DrafterTransport Protocol + InProcessTransport + PipelinedModelDrafter with cross-round speculation (drafter forward for round t+1 overlaps target verify of round t; on hit, round t+1 propose is free; on miss, K+1 positions are rolled back).
  • 17 new tests covering the transport contract, factory dispatch, and cache-trim arithmetic for partial / full / hit / miss via a deterministic FakeTransport.

Layer B (next commit on this branch) -- in progress:

  • RemoteTransport over mx.distributed.send/recv -- carries draft tokens between drafter and target ranks at jaccl (RDMA over Thunderbolt-bridge / IB-verbs) or ring (TCP) speed depending on the instance backend. Reuses exo's existing RDMA infrastructure -- no parallel network stack.
  • Asymmetric MlxJacclInstance / MlxRingInstance with a drafter_rank field so rank 0 can load drafter-only while ranks 1..N pipeline-parallel the target.
  • Placement code: replaces the multi-node-drafter warning with actual drafter-on-separate-rank placement.
  • Runner bootstrap: drafter rank loads only the drafter, runs a serve loop that handles the wire protocol from remote_drafter.py (already final).
  • Subgroup setup: Group.split so target's pipeline-parallel collectives don't drag the drafter rank in.
  • Twin testing on wc-smbp + wc-smbpt: TCP/IP via ring + RDMA via jaccl over Thunderbolt-bridge.

Apple Silicon caveat: in-process pipelining wins ~10-30% on the Mac Studio (MLX serializes Metal command queues per device); the unambiguous gain is in Layer B's remote transport where target verify includes a cross-machine network round-trip and the speculative drafter forward fully overlaps it. That's why Layer A's pipelined mode and Layer B's RemoteTransport are the same architecture -- the transport is the only thing that changes.

Pre-commit checks pass:

  • uv run basedpyright -- 0 errors
  • uv run ruff check -- clean
  • uv run pytest -- 580 passed, 1 skipped, 199 deselected (slow, opt-in)

jw-wcv added 16 commits May 6, 2026 19:37
Move the schema from a single drafter to a preference-ordered list so
runners can pick from multiple candidates (e.g. e2b for fastest,
e4b for highest acceptance) at startup time via the new
EXO_DRAFTER_PREFERENCE env var.

- ModelCard.drafter_model_ids: list[ModelId] (default empty)
- All 8 Gemma 4 large-instruct cards now declare both e2b and e4b
  drafters at matching quantisation, in fastest-first order.
- _maybe_load_drafter selects from the candidate list, preferring
  on-disk drafters within the user's preference setting.
- _select_drafter_id helper exposed for testing the policy.
- DownloadCoordinator chains downloads for ALL declared drafters so
  the runtime preference can be flipped without an on-demand fetch.
- API /v1/models surfaces drafter_model_ids: list[str].
- New tests for the selection helper.

X-Orchestraitor-Plan: gemma4-drafter-tuning
Item 1 — Expose num_draft_tokens via EXO_NUM_DRAFT_TOKENS (default 5).
mlx_lm's speculative_generate_step defaults to K=2 which is conservative
for unknown drafters. For purpose-built family pairs (Gemma 4 e2b -> 31b)
acceptance is ~80% and K=5-6 is the sweet spot — at p=0.8 going from K=2
to K=5 yields roughly +60% effective throughput.

Item 3 — Drafter participates in warmup_inference. The first request after
launch no longer pays a cold-cache penalty on the drafter side.

Item 8 — Skip drafter for short generations. When max_tokens is at or
below EXO_DRAFTER_MIN_OUTPUT_TOKENS (default 16) the drafter is dropped
for that request: speculative decoding's fixed setup cost dominates the
saved decode time on classification / structured-output workloads.

resolve_speculative_decoding factored out as a pure helper so the policy
is unit-testable without spinning up MLX.

X-Orchestraitor-Plan: gemma4-drafter-tuning
stream_generate yields from_draft on every token. We now thread that
into the per-request GenerationStats so dashboards/clients can A/B
configurations on real traffic and (next commit) so adaptive K can
self-tune off observed acceptance.

- GenerationStats gains drafter_model_id, accepted_draft_tokens,
  num_draft_tokens, and a drafter_acceptance_fraction property.
- mlx_generate counts from_draft hits and stamps the chosen drafter id
  on the final stats. effective_draft_model is the source of truth, so
  short-skip and distributed paths correctly report "no drafter".
- load_mlx_items now returns the drafter id alongside the model so the
  builder can plumb it into SequentialGenerator -> mlx_generate.
- Dashboard message metadata + ChatMessages.svelte render a SPEC chip
  next to TTFT/tok/s showing acceptance % and K, with the drafter id
  surfaced via a tooltip.

X-Orchestraitor-Plan: gemma4-drafter-tuning
When EXO_ADAPTIVE_DRAFT_TOKENS=1 is set, K is recomputed each request
from a rolling window of observed acceptance fractions:

  acceptance_rate < 0.5  -> K=2  (drafter is missing, don't waste guesses)
  0.5 <= rate < 0.75     -> K=4
  rate >= 0.75           -> K=6

Heterogeneous traffic (code completion = predictable patterns =
high acceptance, reasoning = lower) gets a real win here over a
fixed K. Disabled by default so K stays predictable for benchmarking.

The rolling window is sized at 8 (ADAPTIVE_K_WINDOW), with a
2-observation warmup that falls back to the configured static K.
adaptive_num_draft_tokens is a pure helper so the band logic is
unit-tested separately from the SequentialGenerator wiring.

X-Orchestraitor-Plan: gemma4-drafter-tuning
mlx_lm's speculative_generate_step splits its prompt_cache argument as
[: len(target.layers)] for the target and [len(target.layers):] for the
drafter, so we can hand it a concatenated cache list and have mlx_lm
wire each side to the right model.

When a drafter is loaded, MlxBuilder allocates a parallel KVPrefixCache
for it. mlx_generate looks up the drafter's cache on the same prompt the
target's cache hit on, prefills any uncovered tail with a simple per-step
forward (single-device, no distributed sync needed), stores the
post-prefill state back to the drafter's prefix cache, and concatenates
target + drafter caches into the prompt_cache passed to stream_generate.

Multi-turn Codex / chat workloads that grow turn-by-turn now get drafter
prefix hits instead of paying full drafter prefill on every turn. The
benefit scales with K (number of draft tokens per round): bigger K means
drafter prefill is a larger share of total time, so the prefix cache hit
saves more.

The drafter prefill helper deliberately avoids the full target prefill
machinery (group sync, pipeline parallelism, SSM snapshots, progress
callbacks) -- drafters are small and single-device by construction.

X-Orchestraitor-Plan: gemma4-drafter-tuning
Add `use_drafter: bool | None` and `num_draft_tokens: int | None` to
TextGenerationTaskParams and the chat-completions wire schema. None on
either field means "use the runner's configured defaults", so existing
clients see no behaviour change.

Use cases:
- A latency-sensitive client can set `use_drafter: false` to skip
  speculative decoding entirely on its tight-loop path.
- A throughput-sensitive client (e.g. a Cursor-style integration that
  knows the prompt is long-form code completion) can raise K via
  `num_draft_tokens: 8` for that single request.

The actual K used is reported back in GenerationStats.num_draft_tokens
so the dashboard reflects per-request overrides.

X-Orchestraitor-Plan: gemma4-drafter-tuning
Speculative decoding is single-device only in mlx_lm. Existing placement
already prefers single-node via get_smallest_cycles (a 1-node cycle wins
over any N>1-node cycle when memory permits), so the common case of
"declare drafter -> get drafter" already works.

The corner case the user flagged: a quant that's too big to fit on any
single node falls back to TP and silently loses the drafter. We now log
a clear, actionable warning at placement time when a card with
drafter_model_ids gets placed across N>1 nodes, telling the operator to
re-place a smaller quant for the speculative-decoding speedup.

Auto-swapping to a smaller quant at placement time (the user's
"single-node + 4-bit on the biggest box over TP + 8-bit" heuristic)
requires the placement engine to know about quant variants of the same
base model. That's a larger restructuring; the warning unblocks the
operator-driven workflow today.

X-Orchestraitor-Plan: gemma4-drafter-tuning
mlx_lm's speculative_generate_step splits its prompt_cache argument
into [: len(target.layers)] for the target and [len(target.layers):]
for the drafter. When draft_model is set we MUST hand it a cache
list that contains both halves -- if we pass only the target caches,
mlx_lm sees an empty drafter slice and the first prefill stalls.

Item 6's path only built the drafter caches when a KVPrefixCache was
wired in (the production runner builds one in MlxBuilder.build()), but
warmup_inference calls mlx_generate with drafter_kv_prefix_cache=None,
so warmup hung indefinitely on wc-studio under
gemma-4-26b-a4b-it-4bit + e2b drafter.

Restructured so we always build + prefill a drafter cache when the
drafter is active, and only consult/update the prefix cache when one
is provided. This unblocks warmup and keeps the multi-turn prefix-hit
behavior intact for the real request path.

X-Orchestraitor-Plan: gemma4-drafter-tuning
Reverts the force_plain_kv_cache experiment. The original hang in
exo's spec-decoding path was a cache double-fill, not a cache-type
issue: mlx_lm.speculative_generate_step._prefill always processes
prompt[:y.size-1] regardless of cache offset, so exo's prefill (which
already filled the target cache) plus mlx_lm's _prefill resulted in
2x prompt ingestion and garbage logits.

Forcing plain KVCache for sliding-window layers actually made things
worse: KVCache.make_mask returns variable-shape mx.array masks that
trigger Metal kernel recompiles on every offset change, while
RotatingKVCache.make_mask returns either a bounded-shape mask or the
"causal" string sentinel — one compile that's reused.

Fix:
  * make_kv_cache reverts to honoring model.make_cache() (native mix
    of RotatingKVCache + KVCache for gemma-4).
  * On the spec path, exo's prefill and KVPrefixCache are bypassed;
    we hand mlx_lm fresh native caches for target+drafter and the
    full prompt, and let speculative_generate_step._prefill drive
    ingestion exactly as it expects.
  * Non-spec path is unchanged.

Verified end-to-end on wc-studio:
  warmup_inference (target + drafter, cold process):
    50 tokens in 1.55s @ 33 tok/s
The drafter prefix cache parameter was plumbed through the runner but
never actually used inside ``mlx_generate``: the spec path bypassed the
prefix-cache code entirely and let mlx_lm's spec_step._prefill drive
both target and drafter caches from scratch. This left
drafter_kv_prefix_cache permanently empty and forced the drafter to
re-prefill on every request.

Changes:
  * Resolve speculative-decoding policy (use_drafter, num_draft_tokens,
    drafter_min_output_tokens, distributed) up-front so cache setup
    branches on the *effective* drafter rather than the unfiltered one.
  * On the spec path, look up both target and drafter prefix caches,
    align them at min(target_hit, drafter_hit) so spec_step sees
    identical offsets.
  * Add ``_spec_drafter_prefill`` helper that advances the drafter
    cache to match the target after exo.prefill (target ends at
    prompt - 2 via prefill+trim; drafter must mirror that).
  * After both prefills, snapshot the drafter cache via
    ``add_kv_cache``/``update_kv_cache`` so subsequent requests can
    skip the drafter prefill entirely on prefix hits.
  * Fix telemetry: ``GenerationStats.num_draft_tokens`` now populates
    whenever speculation actually ran, not only when both
    drafter_model_id *and* effective_draft_model are present.

Smoke matrix on wc-studio (gemma-4 26b a4b 4bit + e2b 4bit):
  warmup-with-drafter:    100 tok / 2.15s
  from_draft-telemetry:   31/40 accepted, drafter_id correctly stamped
  K-override (K=2/K=8):   both reported in stats
  use_drafter=False:      0 accepted, drafter_id=None
  skip-short (max<min):   0 accepted, drafter_id=None
  drafter-prefix-cache:   entries grow request-over-request
  adaptive-K:             40 tok / 1.86s, 24 accepted
Refactor speculative decoding around a small ``Drafter`` Protocol so
``mlx_generate`` can dispatch on mode rather than branching at every
cache and stream-generate site. Three concrete drafters cover today's
needs and leave a clean seam for EAGLE/Medusa/lookahead/distributed
drafters in future PRs:

  * ``NoSpecDrafter``: pass-through to ``mlx_lm.stream_generate``.
  * ``ModelDrafter``: wraps ``mlx_lm.stream_generate(draft_model=...)``;
    the existing well-tested upstream spec loop keeps owning the
    model-drafter path (no behaviour change for current users).
  * ``NgramDrafter``: in-house spec loop that proposes drafts by
    suffix-matching the running token context against itself - zero
    drafter compute, no extra KV cache, no warmup. Match-strength-
    adaptive K (proposal length capped to match length) keeps weak
    matches cheap so worst case == baseline.

New surface:

  * ``EXO_DRAFT_MODE`` env var (model | ngram | none) sets the
    process-wide default; falls back to ``model`` if a drafter model
    is loaded, else ``none``. Builder respects the env var and forces
    SequentialGenerator when ngram is requested without a drafter
    model.
  * ``TaskParams.draft_mode`` per-request override (wins over both
    ``use_drafter`` and the env var).
  * ``GenerationStats.draft_mode`` telemetry stamps which strategy
    actually ran, alongside the existing ``drafter_model_id`` /
    ``accepted_draft_tokens`` / ``num_draft_tokens`` fields.

Bench (Mac Studio M3 Ultra, gemma-4-26b-a4b 4-bit, 200 gen tokens):

                         baseline   ngram   model
  echo_rag (RAG)            76.6    75.8    56.9   tps
  code_repeat               70.9    57.2    45.2   tps
  structured_json           76.7    63.1    53.5   tps
  creative_novel            77.1    58.5    43.0   tps
  reasoning_chain           76.7    59.5    49.5   tps

The honest read: on memory-bandwidth-bound single-device 4-bit Mac
Studio inference, both spec modes are net-negative (the K+1-token
verify forward costs nearly K+1 times a single-token forward, so
break-even acceptance fraction is K/(K+1) ~ 80% which most workloads
don't reach). ``none`` remains the right default on this hardware;
``ngram`` and ``model`` are wired and tested for slower-target /
distributed regimes where the economics flip. Phase 2 (pipelined
drafter+verify) and Phase 3 (drafter-on-other-device for distributed
runs) ride on the same Drafter seam.

29 new unit tests cover ``parse_draft_mode``, ``resolve_draft_mode``
precedence, ``NgramDrafter.propose`` (longest-match, recency tie-
break, adaptive K cap), and ``make_drafter`` dispatch.
@team-wcv team-wcv changed the title Gemma 4 drafter tuning: K, telemetry, KV prefix, adaptive K, per-request overrides, placement warning Gemma-4 drafter tuning + Drafter abstraction (model | ngram | none) May 7, 2026
@team-wcv
Copy link
Copy Markdown
Owner Author

team-wcv commented May 7, 2026

Update: added Drafter abstraction + n-gram drafting strategy (commit 2f52b82)

Refactored speculative decoding around a small Drafter Protocol so mlx_generate dispatches on mode rather than branching at every cache and stream-generate site. Three concrete drafters cover today's needs and leave a clean seam for EAGLE/Medusa/lookahead/distributed drafters in future PRs.

New surface

  • DraftMode = Literal["model", "ngram", "none"]
  • Drafter Protocol at the stream-factory level (stream(...) -> Generator[GenerationResponse]).
  • Three concrete drafters:
    • NoSpecDrafter — pass-through to mlx_lm.stream_generate.
    • ModelDrafter — wraps mlx_lm.stream_generate(draft_model=...); the well-tested upstream spec loop keeps owning the model-drafter path (no behaviour change for current users).
    • NgramDrafter — in-house spec loop. Proposes drafts by suffix-matching the running token context against itself. Match-strength-adaptive K (proposal length capped to match length) keeps weak matches cheap so worst case == baseline.
  • EXO_DRAFT_MODE env var (model | ngram | none) for process-wide default.
  • TaskParams.draft_mode per-request override (wins over both use_drafter and the env var).
  • GenerationStats.draft_mode telemetry stamps which strategy actually ran.

Bench (Mac Studio M3 Ultra, gemma-4-26b-a4b 4-bit, 200 gen tokens)

workload baseline tps ngram tps model tps
echo_rag (RAG) 76.6 75.8 56.9
code_repeat 70.9 57.2 45.2
structured_json 76.7 63.1 53.5
creative_novel 77.1 58.5 43.0
reasoning_chain 76.7 59.5 49.5

The honest read: on memory-bandwidth-bound single-device 4-bit Mac Studio inference, both spec modes are net-negative. The K+1-token verify forward costs nearly K+1 times a single-token forward, so break-even acceptance fraction is K/(K+1) ~ 80%, which most workloads don't reach. none remains the right default on this hardware. ngram and model are wired and tested for slower-target / distributed regimes where the economics flip.

Phase 2 (pipelined drafter+verify) and Phase 3 (drafter-on-other-device for distributed runs) ride on the same Drafter seam.

Tests

29 new unit tests in test_drafter_abstraction.py covering parse_draft_mode, resolve_draft_mode precedence, NgramDrafter.propose (longest-match, recency tie-break, adaptive K cap, validation), and make_drafter dispatch. All 557 existing tests still pass; basedpyright clean; ruff clean.

Foundation for the unified pipelined / distributed-drafter architecture.
Lands the transport-agnostic interface and the in-process implementation
so Layer B's RemoteTransport over mx.distributed.send/recv (RDMA via
jaccl, TCP via ring) is purely additive.

* drafter_transport.py: DrafterTransport Protocol with two primitives
  (forward, trim_cache) the spec loop composes into propose / commit /
  rollback. InProcessTransport mirrors mlx_lm._draft_generate but with
  cache-state bookkeeping via mlx_lm.trim_prompt_cache. The K+1 upper
  bound on num_forwards is for the speculative bonus-prediction
  forward.
* pipelined_drafter.py: PipelinedModelDrafter + custom spec loop with
  cross-round speculation. While target verifies round t, drafter runs
  K+1 forwards starting from drafts_t[-1] to predict bonus_t and
  generate round t+1's drafts. On hit, round t+1 propose is free; on
  miss, the K+1 speculative positions are rolled back. Cache-trim
  arithmetic for partial / full / hit / miss is documented in the
  module docstring and exercised by unit tests.
* remote_drafter.py: stub surface for Layer B. NotImplementedError now
  with a pointer to the planned topology change; the wire-protocol
  constants (COMMAND_FRAME_SIZE, OP_PROPOSE / OP_TRIM_CACHE /
  OP_SHUTDOWN) are final so Layer B's drafter_serve_loop is purely
  additive.
* drafter.py: DraftMode gains "pipelined". make_drafter routes the
  new mode through transport_factory_for(EXO_DRAFTER_TRANSPORT) so
  switching in-process vs remote drafter placement is one env var.
* generate.py: spec_active / effective_draft_model widened to include
  "pipelined" so target prefill, drafter prefill, and prefix-cache
  bookkeeping all run for the new mode.
* builder.py: forces SequentialGenerator for "pipelined" (BatchGenerator
  has no spec-decoding hook, same reasoning as ngram).
* api.py / text_generation.py: extends the request / telemetry literals.
* tests: 17 new tests covering Protocol contract, transport-kind
  parsing, factory dispatch, and cross-round-speculation cache-trim
  arithmetic via a deterministic FakeTransport (no MLX weights). Plus
  3 new mode-resolution tests for the pipelined demote-to-none path.

Layer B (RemoteTransport + asymmetric MlxJacclInstance / MlxRingInstance
with drafter_rank field + placement + runner serve-loop + subgroup setup
+ twin testing) lands in the next commit on this branch.
@team-wcv team-wcv changed the title Gemma-4 drafter tuning + Drafter abstraction (model | ngram | none) Gemma-4 drafter tuning + DrafterTransport architecture (model | pipelined | ngram | none) May 7, 2026
jw-wcv added 5 commits May 7, 2026 01:08
Fills out the RemoteTransport DrafterTransport implementation and the
matching drafter-rank serve loop using mx.distributed.send/recv on the
parent group. Network primitive is whatever exo's distributed group
already negotiated for the cluster: jaccl/RDMA when available, ring/TCP
fallback otherwise -- "rdma is already built into exo do not reinvent
the wheel".

* Fixed-size 8-uint32 command frame keeps per-round IPC overhead
  microseconds, not milliseconds.
* OP_FORWARD / OP_TRIM_CACHE / OP_SHUTDOWN with explicit ack frames
  give the target deterministic sync points after every command.
* Background ThreadPoolExecutor on the target rank lets PipelinedModelDrafter
  return a non-blocking DraftFuture so verify(t) and drafter_forward(t+1)
  overlap on the wire.
V1 of operator-controlled asymmetric drafter placement. Operators tag
specific nodes as drafter hosts on a model card; placement appends a
drafter-only rank to the parent mx.distributed group on one of those
nodes whenever target ranks land elsewhere. Target ranks split off into
a target subgroup at runtime; the parent group is reserved for
target<->drafter point-to-point IPC over RemoteTransport.

Shape of the change

* ModelCard gains drafter_eligible_nodes: list[NodeId] = [] -- empty by
  default preserves legacy in-process drafter behaviour.
* DrafterPlacement type captures drafter_node_id, drafter_runner_id,
  drafter_model_id, and the drafter's index in the parent group
  (always last rank == world_size - 1, by convention).
* MlxRingInstance / MlxJacclInstance carry an optional drafter_placement;
  BoundInstance gains is_drafter_rank and a relaxed validator that
  accepts the drafter's runner id without a target shard.
* DrafterPlacementDegraded event + DrafterPlacementDegradationReason
  surface placement-time downgrades (no eligible node alive, all
  eligible nodes already targets, no reachable transport, drafter node
  below memory floor) so dashboards/CLI can show why asymmetric was
  denied. The user's request still completes -- placement returns the
  legacy single-device or multi-node-without-drafter instance and the
  event rides alongside InstanceCreated.
* place_instance picks an eligible drafter node that is alive,
  outside the target cycle, all-to-all reachable from every target
  rank (RDMA for MlxJaccl, socket for MlxRing), and above a 6 GB
  conservative memory floor.
* Target cycle selection now reserves drafter-eligible nodes when
  possible -- a user listing one node as drafter-eligible no longer
  sees that node grabbed as a target rank.
* Multi-node-without-drafter warning updated to point operators at the
  drafter_eligible_nodes opt-in (alongside the existing "use a smaller
  quant" suggestion).
* Master forwards degradation events alongside transition events so
  state is consistent with what the operator sees.

Why N+1 rank from day one

Bigger than V1 needs (twins is N=1 + drafter), but the architecture
treats N=1 the same as N=2 or N=8 -- same parent group init, same
group.split, same RemoteTransport. Building it asymmetric-aware now
avoids paying refactor cost later when the cluster grows.

Tested

10 new tests under test_placement_drafter_asymmetric.py cover the happy
path (ring + jaccl), each degradation reason, multi-eligible fallthrough,
and a serialisation round-trip. test_model_cards_drafter.py covers the
new field's defaults and round-trip. All 39 placement tests pass; full
shared/master/worker unit suite still 383 green.

Layer B6 (utils_mlx group split), B7 (drafter runner dispatch), and B8
(target rank wires PipelinedModelDrafter+RemoteTransport) follow.
initialize_mlx now returns an MlxGroupSplit dataclass holding the
parent group, the target subgroup, and the drafter rank index in the
parent. Symmetric placement (no drafter) returns an MlxGroupSplit
where parent and target subgroup are the same group object, so all
existing call sites that pass `self.group` to tensor / pipeline /
batch collectives keep working unchanged.

Asymmetric placement uses Group.split with deterministic colors:
target ranks color=0, drafter rank color=1. Target ranks get a
size-N subgroup; the drafter ends up alone in a size-1 subgroup it
never uses. Group.split is a collective, so all ranks (including
the drafter) call it together. The parent group is reserved for
RemoteTransport.send/recv between target rank 0 and the drafter rank.

mlx_distributed_init learned to read the rank from
DrafterPlacement.drafter_rank when called on the drafter rank
(which has no bound_shard). The single-rank assertion now checks
parent_group_size instead of len(runner_to_shard) so asymmetric
single-target setups pass through to mx.distributed.init.

MlxBuilder gained parent_group + drafter_rank_in_parent fields
alongside the existing target-subgroup `group`; the image builder
keeps using only target_subgroup since it never declares a drafter.
The test mock for initialize_mlx now wraps a MockGroup in a
symmetric MlxGroupSplit.
Introduces the runner-side dispatch for the asymmetric drafter rank so
the worker plan iterates uniformly over target ranks + drafter rank
when fanning out lifecycle tasks.

* `BaseInstance.all_runner_ids` / `all_node_to_runner` and
  `BoundInstance.parent_rank` / `is_drafter_rank` so plan-time iteration
  doesn't have to special-case the drafter.
* `worker/plan.py` uses those helpers for `_kill_runner`,
  `_create_runner`, `_init_distributed_backend`, `_load_model`,
  `_ready_to_warmup`, `_pending_tasks`. Drafter ranks skip
  `_model_needs_download` (V1 assumes operator pre-download) and the
  parent_group_size gates accommodate the asymmetric N+1 rank topology
  with a single target rank.
* New `DrafterRunner` mirrors the target runner state machine
  (`ConnectToGroup` -> `LoadModel` -> `StartWarmup` -> running) and
  enters `drafter_serve_loop` after warmup. Single-forward warmup
  primes Metal kernels so the first real OP_FORWARD doesn't pay
  JIT-compile cost.
* `runner/bootstrap.py` dispatches on `is_drafter_rank` and runs
  `DrafterRunner` with the same mlx_lm patches the target rank uses.

Type-clean under basedpyright, ruff clean, all 404 unit tests pass.
Closes the loop on the asymmetric topology: when an instance's
`DrafterPlacement` is set, the target rank constructs a long-lived
`RemoteTransport` and routes every request through `PipelinedModelDrafter`
talking over `mx.distributed.send/recv` to the drafter rank. RDMA
(`MlxJaccl`) and TCP (`MlxRing`) backends both ride the existing
`mx.distributed.Group` -- the only knob that picks the wire format
is the group's backend, exactly as upstream's send/recv contract
guarantees.

Wire protocol (extends Layer A's command frames):
* New `OP_PREFILL`: per-request setup. Target announces a prompt of
  `num_prompt_tokens` (encoded in the `num_forwards` slot of the
  fixed-size command frame), then sends a `(num_prompt_tokens,)`
  uint32 array of token ids. Drafter trims its KV cache to offset 0
  and runs prefill forwards in 4096-token chunks (mirroring
  `_spec_drafter_prefill`'s step), then acks. Issued at the start of
  every request so the spec loop's first OP_FORWARD seeds against an
  aligned drafter cache.

DrafterTransport API additions:
* `reset_and_prefill(prompt_tokens)` on the Protocol, implemented as
  no-op for `InProcessTransport` (the legacy mlx_generate path owns
  drafter cache prefill externally) and as the `OP_PREFILL` IPC for
  `RemoteTransport`. The Protocol method exists so `make_drafter`
  doesn't have to dispatch on transport kind to issue per-request setup.

Generator wiring:
* `make_drafter` accepts `pipelined_transport` (a pre-built
  `DrafterTransport`) so SequentialGenerator can allocate the
  RemoteTransport once at build time and reuse it across requests --
  per-request executor + drafter cache lifecycle would defeat the
  asymmetric setup's whole point. Multi-target asymmetric is gated
  behind a `NotImplementedError` (V1 supports a single target rank;
  N>1 needs draft-broadcast on the target subgroup, planned follow-up).
* `mlx_generate` accepts `asymmetric_parent_group`,
  `asymmetric_drafter_rank`, `asymmetric_drafter_transport`. When
  active, bypasses the legacy `group is not None -> draft_mode = "none"`
  demotion (mlx_lm's own spec loop can't handle pipeline collectives;
  ours doesn't need to since we restrict to N=1 target ranks for V1)
  and forces `draft_mode = "pipelined"`.
* `MlxBuilder` populates `parent_group` and `drafter_rank_in_parent`
  in `connect()` and allocates the long-lived `RemoteTransport` in
  `build()` when asymmetric placement is active. Forces
  `SequentialGenerator` (BatchGenerator has no spec-decoding hook).
* `SequentialGenerator.close()` shuts down the transport (sends
  `OP_SHUTDOWN`) so the drafter rank's serve loop drains cleanly on
  runner teardown.

Tests:
* New `test_remote_drafter.py`: 19 tests covering command-frame
  round-trip, `RemoteTransport.forward/trim_cache/reset_and_prefill/shutdown`
  with mocked `mx.distributed`, idempotent shutdown, use-after-shutdown
  rejection, and `drafter_serve_loop` dispatch (OP_SHUTDOWN, OP_TRIM_CACHE,
  unknown op rejection).
* `test_pipelined_drafter.py`: extends `FakeTransport` with
  `reset_and_prefill` and adds 3 tests for `make_drafter`'s asymmetric
  entry points (uses-supplied-transport, rejects-non-protocol,
  rejects-multi-target).

Twin-machine testing recipe documented in `remote_drafter.py` module
docstring (TCP via `MlxRing`, RDMA via `MlxJaccl` over Thunderbolt).

basedpyright clean, ruff clean, 407 unit tests pass.
@team-wcv
Copy link
Copy Markdown
Owner Author

team-wcv commented May 7, 2026

Layer B complete: pipelined+remote drafter end-to-end (B7-B9)

Two new commits closing the asymmetric drafter loop:

93a44032 Add DrafterRunner + plan helpers for asymmetric drafter rank (B7)

Runner-side dispatch for the asymmetric drafter rank. The worker plan now iterates uniformly over target ranks + drafter rank when fanning out lifecycle tasks (ConnectToGroup -> LoadModel -> StartWarmup -> running). New DrafterRunner mirrors the target runner state machine and enters drafter_serve_loop after warming up its drafter model. The drafter rank skips _model_needs_download (V1 assumes operator pre-download) so it doesn't fault during plan execution.

8add4551 Wire pipelined+remote drafter at target rank (asymmetric N+1 path) (B8+B9)

Closes the loop. When instance.drafter_placement is set the target rank constructs a long-lived RemoteTransport and routes every request through PipelinedModelDrafter over mx.distributed.send/recv. RDMA (MlxJaccl) and TCP (MlxRing) ride the existing group; the only knob picking wire format is the group's backend.

Wire protocol additions:

  • OP_PREFILL: per-request setup. Target announces num_prompt_tokens in the command frame, then sends the prompt token array. Drafter trims its KV cache to offset 0 and runs prefill in 4096-token chunks (mirrors _spec_drafter_prefill's step), then acks. Issued at the start of every request so the spec loop's first OP_FORWARD seeds against an aligned drafter cache.

DrafterTransport API:

  • reset_and_prefill(prompt_tokens) on the Protocol. No-op for in-process (legacy mlx_generate path owns drafter cache externally), full IPC for remote.

Generator wiring:

  • make_drafter accepts a pre-built pipelined_transport so SequentialGenerator allocates the RemoteTransport once at build time (per-request executor + drafter cache lifecycle would defeat the asymmetric setup's whole point).
  • mlx_generate accepts asymmetric_parent_group / asymmetric_drafter_rank / asymmetric_drafter_transport. When active, bypasses the legacy group is not None -> draft_mode = "none" demotion and forces pipelined mode.
  • MlxBuilder populates parent group + drafter rank from connect() and allocates the long-lived RemoteTransport in build() when asymmetric.
  • SequentialGenerator.close() shuts down the transport (sends OP_SHUTDOWN) so the drafter rank's serve loop drains cleanly on teardown.

V1 boundary:
Multi-target asymmetric (N>1 target ranks + 1 drafter rank) is gated behind a NotImplementedError -- placement still allows N>1 to keep telemetry honest, but the runner side fails loudly so misconfiguration doesn't silently fall back to an arbitrary mode. N>1 needs draft-broadcast on the target subgroup, planned follow-up. V1 covers the twins case (single target rank + single drafter rank, RDMA or TCP) end-to-end.

Tests: new test_remote_drafter.py covers command-frame round-trip, RemoteTransport.forward / trim_cache / reset_and_prefill / shutdown with mocked mx.distributed, use-after-shutdown rejection, and drafter_serve_loop dispatch. test_pipelined_drafter.py extends with make_drafter asymmetric entry-point tests. Twin-machine testing recipe documented in remote_drafter.py module docstring.

basedpyright clean, ruff clean, 407 unit tests pass.

Status

Layer Status
B1-B6 (placement + group split) merged in earlier commits
B7 DrafterRunner + plan helpers 93a44032
B8 target generator wiring 8add4551
B9 tests + twin docs 8add4551
B10 commit + push this comment
Twin benchmarks (TCP + RDMA) next, on wc-smbp + wc-smbpt

jw-wcv and others added 4 commits May 7, 2026 02:06
Standalone OpenAI-compatible-API hammer that captures per-request
generation_stats (prompt/gen tps, ttft, drafter accept fraction,
draft_mode) across a fixed prompt mix (short repetitive, code, prose,
factual). Used to compare local 4-mode runs (none/model/ngram/
pipelined-in-process) and asymmetric N+1 deployments (jaccl RDMA vs ring
TCP) without coupling to the existing exo_bench planning harness.
The drafter rank in asymmetric N+1 placement has no shard (it owns the
full drafter model, not a slice of the target). RunnerSupervisor.create
was unconditionally calling bound_instance.bound_shard during process
spawn, which asserts on drafter ranks and crashed the worker before the
DrafterRunner could even start.

Make shard_metadata Optional, branch on bound_instance.is_drafter_rank
when populating it, and read model_id through a new property that
falls back to DrafterPlacement.drafter_model_id when shard_metadata is
None. The two existing usages (logging, error chunk on shutdown) stay
identical for the target rank and now also work for drafter ranks.
brings back EXO_CACHE_HOME as always ~/.cache/exo/, and store the node
id in there. no random copies now!
Tracing the conc=4 staircase (slots 0-1 batch but slots 2 and 3
each pay solo prefill, ttft 2719 ms / 3696 ms) through worker.
main + supervisor.start_task uncovers the actual upstream
serialiser: ``await event.wait()`` at the end of
``RunnerSupervisor.start_task`` blocks the worker's command loop
on the runner's per-task ``TaskAcknowledged`` for *every* task,
including ``TextGeneration``. With the master fanning all 4
``Executing command: TextGeneration`` log lines out within 12 ms,
the worker dispatches them serially: slot 0 -> ack -> slot 1 ->
ack -> slot 2 -> ack -> slot 3. Slot 1 lands on the runner's
``_work_queue`` during the entry-time burst-coalesce (200 ms
window) and batches with slot 0, but slot 2's send is still
gated on slot 1's ack inside the runner -- which only fires
after slot 1's prefill batches *and* the runner's
``_admit_queued_tasks`` runs ``acknowledge_task`` for it. By
then the burst-coalesce is closed and the next admit cycle sees
candidates=1.

Lift the ack-wait for ``TextGeneration`` / ``ImageGeneration`` /
``ImageEdits`` *only when the runner is past warmup*
(``RunnerReady`` or ``RunnerRunning``). The runner's state machine
admits those tasks unconditionally in those states, so ordering
between siblings is preserved by the underlying ``mp_channel``
(FIFO) without needing the worker-level handshake. Lifecycle
tasks (``LoadModel``, ``StartWarmup``, ``ConnectToGroup``,
``Shutdown``, ``CancelTask``, etc.) keep the gate so state
transitions stay strictly ordered. The ``pending`` map continues
to receive the ack via ``_ev_recv`` and the on-runner-death
sweep at line 320 still sets all events, so cancellation /
shutdown teardown is unchanged.

This does NOT remove backpressure: ``mp_channel`` blocks on a
full buffer, and the runner's ``_work_queue`` is still drained
sequentially by the same single thread. It only removes the
coupling between worker-side dispatch latency and runner-side
prefill latency for warm generation tasks. With this change the
worker fans the conc=4 burst onto ``_work_queue`` within ms of
the master's dispatch, and the runner's burst-coalesce + the
between-step drain together batch all 4 slots.

389 worker+master unit tests pass; basedpyright/ruff clean.
Live conc=4 bench validation in next push.
@team-wcv
Copy link
Copy Markdown
Owner Author

team-wcv commented May 8, 2026

Conc=4 follow-up: ack-gate lifted, full batched prefill coverage

The earlier conc=4 staircase (slots 0-1 batched, slots 2-3 solo, ttft 1700/1700/2700/3700 ms) traced to RunnerSupervisor.start_task's await event.wait() -- the worker's command loop was blocked waiting for the runner's TaskAcknowledged between every dispatch. With the master fanning all 4 Executing command: TextGeneration log lines out within 12 ms, the worker was nonetheless dispatching them serially: slot 0 -> ack -> slot 1 -> ack -> slot 2 -> ack -> slot 3. Slot 1 made the entry burst-coalesce window; slots 2 and 3 missed it because their dispatches were gated on a runner ack that only fired after the previous slot's prefill completed.

Lift the ack-wait gate for TextGeneration / ImageGeneration / ImageEdits only when the runner is past warmup (RunnerReady or RunnerRunning). Lifecycle tasks (LoadModel / StartWarmup / ConnectToGroup / Shutdown / CancelTask) keep the gate so state transitions stay strictly ordered. The mp_channel underneath _task_sender is FIFO, so sibling-ordering between concurrent generation tasks is preserved without the worker-level handshake.

Pushed in 70830e8.

Conc=4 bench (smbp+smbpt+bmbp, jaccl/RDMA, 2 runs × 4 slots × 384 tokens, long_context_summary)

Metric single-drain (conc=4 baseline) ack-lifted (conc=4 treatment) Δ
Aggregate gen_tps 175.51 179.28 +2.1 %
Median per-request gen_tps 25.93 27.12 +4.6 %
Median TTFT 2352 ms 2596 ms +244 ms
Worst-case TTFT 3984 ms 3482 ms −502 ms
TTFT pattern 1700/1700/2700/3700 1830/1830/3250/3250 --
batched_prefill calls / run 1 (B=2, slots 0-1 only) 2 (B=2 each, slots 0-1 and 2-3) --

Diagnostic from runner log on smbp (placement target this run):

[ 08:12:28.4 INFO ] burst-coalesce drained=1 budget_ms=200 elapsed_ms=204.7 deferred=False
[ 08:12:29.7 INFO ] batched_prefill: 2 slots, 4260 tokens (3228.7 tok/s aggregate)
[ 08:12:31.1 INFO ] batched_prefill: 2 slots, 4260 tokens (3736.7 tok/s aggregate)
[ 08:12:44.9 INFO ] burst-coalesce drained=1 budget_ms=200 elapsed_ms=202.6 deferred=False
[ 08:12:46.3 INFO ] batched_prefill: 2 slots, 4260 tokens (3085.2 tok/s aggregate)
[ 08:12:47.8 INFO ] batched_prefill: 2 slots, 4260 tokens (3315.2 tok/s aggregate)

Two batched_prefill per run, one for each pair of slots, both running > 3 GB/s aggregate over RDMA. Slots 2-3 batch via the between-iteration _drain_pending_work_items (398440c) -- they arrive on _work_queue together because the ack-gate no longer serialises their dispatch.

Median TTFT regression

The +244 ms median TTFT is structural: pre-fix slot 2 paid 1000 ms of solo prefill (TTFT 2700 ms) -- post-fix it pays 1420 ms in a 2-slot batch (TTFT 3250 ms). The batch is faster per token but the second pair has to wait for the first pair's prefill to finish, so 50 % of slots see TTFT shift up. Worst-case TTFT (slot 3 in the pre-fix order) drops 502 ms because slot 3 no longer has to wait through a serial chain. This is the right trade-off when the operator cares about p99 / aggregate throughput; if p50 matters more for a workload, EXO_BURST_COALESCE_MS=0 reverts to the single-slot-per-admit path.

Future lever

A single B=4 prefill (instead of two B=2) would cut total prefill wall-clock further. That requires extending EXO_BURST_COALESCE_MS past the natural inter-task arrival window (currently ~50 ms post-ack-fix) so all 4 slots arrive before the first step() admit. Trade-off: solo and conc=2 workloads pay the wider window. Filing as a follow-up; not in this PR.

Three commits in this slice (all pushed to feature/gemma4-drafter-tuning)

  • 79bfaf2 Downgrade burst-coalesce/admit chatter to debug for solo runners
  • 398440c Drain all pending work items between step() iterations
  • 70830e8 Lift per-task ack-wait gate for warm-runner generation tasks

389 worker+master unit tests pass, basedpyright/ruff clean across all three.

@team-wcv
Copy link
Copy Markdown
Owner Author

team-wcv commented May 8, 2026

EXO_BURST_COALESCE_MS sweep at conc=4 (final lever)

Same cluster (smbp+smbpt+bmbp jaccl/RDMA), same code (70830e8e ack-gate lifted), same bench (long_context_summary, 2 runs × 4 slots × 384 tokens, gemma-4-26b-a4b-it-4bit on smbp). Only the runner-side EXO_BURST_COALESCE_MS changes.

Config Aggregate gen_tps Median per-req gen_tps Best / Median / Worst TTFT (ms) TTFT spread batched_prefill per run Prefill aggregate (tok/s)
200ms (current default) 179.28 27.12 1830 / 2596 / 3482 1652 ms 2× B=2, 4260 tok each 3228 + 3736 / 3085 + 3315
500ms 177.22 28.14 3376 / 3520 / 3668 5 ms 1× B=4, 8520 tok 3814 / 3384
750ms 172.27 27.89 3653 / 3774 / 3909 18 ms 1× B=4, 8520 tok 3829 / 3532

Findings

  • 500ms is the saturation point at this concurrency. drained=3 for both 500ms and 750ms; the 4 slots arrive within 500ms of slot 0. 750ms just adds 250ms of pure wait time without changing batching outcome.
  • B=4 prefill is 5-10% faster per token than 2×B=2 of the same total token budget (3814 vs ~3482 average across the two B=2 batches), because there's only one MoE-router pass and less Python overhead.
  • Per-request gen_tps is best at 500ms (28.14 vs 27.12 at 200ms, +3.8%) because every slot decodes from t=prefill_done; in 200ms config slot 0/1 decode for ~1400ms before slots 2/3 even start prefill.
  • Best-case TTFT is best at 200ms (1830ms vs 3376ms at 500ms) because slots 0/1 don't wait for the longer coalesce window. If a workload values "first token of fastest slot", 200ms wins.
  • TTFT fairness is best at 500ms (5ms spread vs 1652ms spread). The 1830 → 3482ms staircase at 200ms means slots 2/3 see ~2x slot-0's TTFT; at 500ms all 4 see the same TTFT within a few ms.

Diagnostic from runner log:

# 500ms config
[ 08:17:29 INFO ] burst-coalesce drained=3 budget_ms=500 elapsed_ms=510.1 deferred=False
[ 08:17:32 INFO ] batched_prefill: 4 slots, 8520 tokens (3814.3 tok/s aggregate)
[ 08:17:46 INFO ] burst-coalesce drained=3 budget_ms=500 elapsed_ms=510.1 deferred=False
[ 08:17:49 INFO ] batched_prefill: 4 slots, 8520 tokens (3384.0 tok/s aggregate)

# 750ms config (same drained=3, just 250ms more wait)
[ 08:19:56 INFO ] burst-coalesce drained=3 budget_ms=750 elapsed_ms=760.1 deferred=False
[ 08:19:58 INFO ] batched_prefill: 4 slots, 8520 tokens (3828.9 tok/s aggregate)

Recommendation

Keep the default at 200ms (already committed). The default targets the most common operator profile (low p50 TTFT, conc<=2 streams hitting an instance) and hands B=4 batching to operators who explicitly opt in via EXO_BURST_COALESCE_MS=500.

Documented profiles:

  • EXO_BURST_COALESCE_MS=0 -- per-slot prefill on every request (lowest TTFT for solo, no batched_prefill)
  • EXO_BURST_COALESCE_MS=200 (default) -- 2-slot batching for conc=2 streams; conc>=4 falls back to round-robin pairs
  • EXO_BURST_COALESCE_MS=500 -- full B=4 batching for conc=4 fan-in workloads (uniform fairness, +3.8% per-request gen_tps, sacrifices best-case TTFT)
  • EXO_BURST_COALESCE_MS=750+ -- saturated; adds latency without gain on this cluster

The conc=4 work in this PR is now end-to-end:

  • 79bfaf27 Downgrade burst-coalesce/admit chatter to debug
  • 398440cf Drain all pending work items between step() iterations
  • 70830e8e Lift per-task ack-wait gate for warm-runner generation tasks

The EXO_BURST_COALESCE_MS variable already lives in runner.py (74812f22) -- no new code needed. Operators tune the trade between p50 TTFT and aggregate-fairness/throughput via that single knob.

jw-wcv added 21 commits May 7, 2026 22:32
The pipelined drafter's accept loop did mx.eval(sampled) + .item()
per position, costing K+1 host-device syncs per spec round
(~2-3ms each on Apple Silicon). On a target with ~10ms step time,
the K+1 sync overhead alone is enough to flip spec-decode from
net-win to net-loss; that is why the asymmetric drafter benched
slower than solo across every K and every prompt type.

When logits_processors is empty (the common case at temperature=0
or with no repetition penalty), all K+1 sampling decisions are
independent of one another, so we batch them into a single
sampler() call and a single mx.eval()+tolist() sync. The stateful
path is preserved verbatim for callers that need per-position
context (e.g. repetition penalty against running_prev).

Expected effect on the asymmetric q4-drafter / q4-MoE-target
bench: previous loss of ~14% on code_completion K=3 should
narrow or flip to a win.
The previous batched-sampling fast path in PipelinedModelDrafter only
fired when logits_processors was empty. In practice the runner always
prepends ban_token_ids(eos_ids) for bench/length-controlled requests
and make_logits_processors() in the default config returns an empty
processor list otherwise -- so the original gate hit the slow path
on every code path that exercises the bench harness. That is why the
prior commit measured a no-op.

This change:

  * Tags ban_token_ids with position_independent=True, since it
    discards its history arg and only masks fixed token ids.
  * Generalises the spec-loop fast-path gate from "no processors" to
    "all processors are position-independent", and applies the
    processors once to the batched (K+1, vocab) logits before the
    single batched sampler call.

For requests with stateful processors (repetition_penalty,
presence_penalty, frequency_penalty), no attribute is set on the
returned processor so position_independent stays False and the
existing per-position path runs unchanged.
…port)

Lift the architectural cap on asymmetric placement: the drafter rank no
longer joins mx.distributed.Group, so target ranks are free to run TP/PP
collectives without needing Group.split (which jaccl/ring don't
implement on Apple Silicon). The wire is now a plain TCP socket between
target rank 0 and the drafter rank, carrying the same v3 frames the
mx.distributed wire used.

Result: sharded multi-rank target + asymmetric drafter on a third node
is now placeable on jaccl/RDMA and ring/TCP backends alike. The drafter
wire is independent of the target backend and only requires
socket reachability between target rank 0 and the drafter node.

- New drafter_socket module with bind/dial/frame helpers (TCP_NODELAY,
  exponential backoff dial, length-prefixed prompt tail).
- DrafterPlacement carries drafter_socket_host + drafter_socket_port;
  placement layer resolves target rank 0's IP via find_ip_prioritised
  (Thunderbolt > Ethernet > WiFi) and allocates an ephemeral port.
- BaseInstance.parent_group_size now returns the target-only rank count
  (drafter is on TCP, never in mx.distributed).
- initialize_mlx asserts not-on-drafter-rank, builds target-only group,
  and on target rank 0 of an asymmetric placement binds the listener
  and accepts the drafter's incoming connection.
- DrafterRunner._handle_connect dials directly; serve loop runs over
  the socket with no mx.distributed dependency.
- worker/plan.py: drafter ranks dispatch ConnectToGroup + StartWarmup
  independently; target ranks retain rank-ordered collective barriers.
- Retired the EXO_DRAFTER_TRANSPORT="remote" factory path (asymmetric
  is now built directly from the runner bootstrap).
- Tests rewritten on top of socket.socketpair() so both sides of the
  wire run end-to-end without mocking mx.distributed.
The function returns the SINK end of node_id -> other_node_id
connections (i.e. other_node_id's reachable address). To resolve target
rank 0's IP for the drafter to dial, the drafter must be node_id and
target must be other_node_id; I had them swapped, which produced the
drafter's own IP and would have caused dial failure at runtime.
V1 only supported ``target_subgroup_size=1``: a single target rank plus an
asymmetric drafter on a separate node. Multi-target placements (e.g. the
3-node TB-RDMA cluster running ``gemma-4-31b-it-bf16`` tensor-sharded
across ``wc-smbp`` + ``wc-smbpt`` with the drafter on ``wc-bmbp``) raised
``NotImplementedError`` because non-root target ranks had no way to
participate in the spec loop -- they would either skip the drafter
entirely (rank-asymmetric draft_mode → TP collective desync) or attempt
to load drafter weights twice and never agree on tokens.

The fix is a fixed-size rank-0 broadcast on the target subgroup at every
drafts-update site (round 0 propose, partial-accept rebuild, full-accept
hit, full-accept miss). The broadcast piggybacks on JACCL's well-tested
``all_sum`` collective via a length-prefixed ``int32`` buffer of size
``k+1``; non-root target ranks contribute zeros. ``mx_all_gather`` is
deliberately avoided because of an observed JACCL corruption pattern on
small int buffers documented in ``mx_all_gather_tasks``.

The non-root ``PipelinedModelDrafter`` is built with ``transport=None``
and consumes the broadcast each round; both ranks then run an identical
verify forward in TP lockstep. Determinism falls out for free: TP all-
reduces logits to be byte-identical on every rank, so accept/reject and
emitted tokens match without any further coordination.

Sustained measurements on the V2-multi cluster:
- 256 tok / short prompt: nodraft 13.94 tok/s -> drafter 15.73 tok/s (+12.8%)
- 256 tok / long prompt:  nodraft 10.41 tok/s -> drafter 12.90 tok/s (+24%)
- 512 tok / long prompt:  nodraft  8.63 tok/s -> drafter 9.36-9.69 tok/s (+8.5-12%)

Three drafter requests in a row, no deadlocks, coherent output.

Adjacent fixes that landed during the bring-up:

- ``runner.py``: log compact task identifiers instead of the full
  ``TextGeneration`` Pydantic model on every entry. Repeated re-planning
  of the same task while a runner was busy was funneling the entire
  chat-template + token blob through ``loguru`` once per tick; that hit
  ``list_repr`` recursion (~300 GB peak physical) and starved the
  forward loop, causing rank 0 to never enter its TP collective and
  rank 1 to park forever in ``mlx::core::eval_impl`` cvwait.
- ``logging.py``: pass ``diagnose=False`` on every loguru handler.
  Without this, an exception in a frame that closed over a large list
  would trigger ``_better_exceptions`` to ``repr()`` every local --
  same ``list_repr`` storm pattern, just inside the crash reporter,
  turning a quick OOM into a multi-minute hang.
- ``bootstrap.py``: register ``faulthandler.SIGUSR1`` in every runner
  subprocess so future TP collective hangs can be diagnosed with full
  Python tracebacks (root not required, unlike py-spy on macOS).
- ``mx_all_gather_tasks``: trust local state and short-circuit
  (``return list(tasks), []``). The protocol is defensive -- the master
  pushes the same task list to every target rank -- and JACCL's
  ``all_gather`` of small int buffers is unreliable on Apple Silicon
  (sporadic ``[1, 1068875521]`` corruption that drove a 144 GB padded-
  buffer allocation). If the master ever diverges across ranks, the
  next TP forward will fail loudly instead of completing on stale data.
- ``utils_mlx.load_mlx_items``: surface the drafter ``ModelId`` from
  ``DrafterPlacement`` on multi-rank distributed loads so
  ``GenerationStats.drafter_model_id`` is non-null for asymmetric
  multi-target requests (previously only the single-device load path
  set it).
- ``generate.py``: telemetry now stamps ``drafter_model_id`` and ``K``
  for ``mode="pipelined"``; the asymmetric branch builds either the
  transport-owning or broadcast-consumer drafter based on
  ``asymmetric_drafter_is_root``.
Closes the correctness gaps left open by the initial V2-multi
multi-target asymmetric drafter ship:

- **Sampler determinism (cross-rank correctness fix).** The default
  ``make_sampler`` uses ``temperature=0.7`` -- ``mx.random.categorical``
  reads MLX's per-rank PRNG state, so identical post-TP-all-reduce
  logits produced divergent ``target_tokens`` across target ranks,
  divergent ``num_accepted``, divergent ``mlx_trim_prompt_cache``
  amounts, and silent KV-cache desync at the next TP forward. Rank 0
  now samples and broadcasts ``target_tokens`` (k+1 ints) on the
  target subgroup via a new ``_broadcast_target_tokens`` helper;
  non-root ranks skip the sampler entirely and consume the broadcast.
  Both ranks then compute identical accept/reject from byte-identical
  ``drafts`` and ``target_tokens``. The fix adds one small
  ``mx.distributed.all_sum`` per round (microsecond-range on
  Thunderbolt RDMA, negligible against the verify forward).

- **mx_all_gather_tasks drift detection.** The trust-local short-
  circuit kept us safe from JACCL's ``all_gather`` corruption, but
  silently masked any master-side plan divergence across target
  ranks. Now hashes the local task list (deterministic 31-bit
  polynomial hash, ``PYTHONHASHSEED``-independent) and rides an
  ``all_sum`` collective to compare hashes against root. Mismatch
  raises locally and surfaces in ``handle_generation_tasks`` rather
  than corrupting later TP forwards on stale state. The hash mixes
  per-byte and includes a separator so transposition + concatenation
  collisions don't slip through.

- **mx_broadcast_int_list range validation.** Negatives wrap silently
  on ``int32`` cast and values >= 2**31 overflow on ``all_sum``.
  Both are caller-side bugs; centralised ``_validate_broadcast_values``
  enforces ``[0, 2**31 - 1]`` on every entry. Length must be >= 1;
  ``is_root=False`` with ``group is None`` is now rejected as a
  configuration bug instead of returning the wrong list.

- **Drafter-death failure mode documented.** When the drafter rank
  crashes between rounds, root's ``transport.forward`` raises and
  ``mlx_generate``'s ``finally`` shuts the broken session down
  cleanly; non-root target ranks block on the next broadcast until
  the runner is restarted (same failure mode as any TP-rank death).
  Documented in the module docstring with a note on how a future
  termination-sentinel pass could exit non-root cleanly.

- **Dashboard surfaces draft_mode.** ``DrafterStats`` now carries the
  drafter mode through to ``ChatMessages.svelte`` which renders a
  mode-specific pill (``PIPELINED`` for asymmetric remote, ``MODEL``
  for in-process, ``NGRAM`` for suffix lookup, ``SPEC`` fallback for
  older runner builds). Distinguishes V2 multi-target spec runs from
  in-process drafting at a glance for A/B comparisons.

Tests:
- New ``test_utils_mlx_broadcast.py`` covers the single-rank
  short-circuit contract for ``mx_broadcast_int_list``, the int32
  range validation, the task-list drift hash (deterministic,
  ordering-sensitive, no transposition or concatenation collisions),
  and the ``mx_all_gather_tasks`` short-circuit.
- ``test_pipelined_drafter.py`` adds three V2-multi cases: 3+ target
  ranks (consumer + root) construct correctly through ``make_drafter``,
  and the new ``_broadcast_drafts`` / ``_broadcast_target_tokens``
  helpers reject configuration bugs (consumer with no group, wrong
  ``k_this`` length).

698 passed, 1 skipped, 2 deselected (pre-existing model-card test
failures from local custom_model_cards override). Lint + basedpyright
clean.
Spec-loop level (drafter death):
- Reserve DRAFT_ABORT_SENTINEL in the broadcast length-prefix slot;
  raise DrafterAbortedError on non-root when root signals abort.
- Wrap _pipelined_speculative_step body so any OSError on root also
  broadcasts the sentinel before re-raising. Non-root exits in
  lockstep instead of hanging on the next-round draft broadcast.
- Add RemoteTransport.is_failed sticky flag, set by the blocking
  wire helpers on socket close. open_session() rejects after a
  failure so subsequent requests can't allocate a session against a
  dead wire; the runner crashes via the spec-loop exception and the
  master's instance-deletion path rebuilds the placement.

Control-plane level (worker / drafter node disconnect):
- Master's instance-deletion loop now iterates all_node_to_runner
  (target + drafter) instead of shard_assignments.node_to_runner
  (target only). Drafter-node disconnects therefore tear the
  instance down on the same path as a target-rank disconnect, which
  routes into the existing supervisor SIGTERM/SIGKILL escalation
  chain. Total recovery is bounded by node_inactivity_timeout (5 s)
  + supervisor escalation (~25 s).

Tests:
- TestDrafterAbortRecovery: sentinel range, broadcast short-circuit,
  decode -> DrafterAbortedError, root broadcasts on OSError, non-root
  does not, recovery best-effort suppression.
- RemoteTransport.is_failed: initial state, set on mid-forward close,
  set on mid-trim close, open_session rejection after failure.
- test_asymmetric_all_node_to_runner_includes_drafter_for_disconnect_check:
  pins the contract the master fix relies on.

Docs:
- pipelined_drafter.py module docstring rewritten ("Known limitation"
  -> "Recovery") with the three-layer model and pointer to target-
  rank-death sharing the same control-plane path.
Two bugs caused the V2 multi-target spec loop to hang silently
under sustained generation, manifesting as: rank 1 returns to
``RunnerReady`` in 0.4s with no tokens emitted while rank 0 stays
stuck in ``mx.eval(sampled_batch)`` inside
``_pipelined_speculative_step_body``; the request times out at
the API client.

1. ``SequentialGenerator.step`` was silently swallowing every
   exception except ``StopIteration``/``PrefillCancelled``. On a
   multi-rank target group this means a per-rank exception (e.g.
   the master-divergence ``RuntimeError`` raised below, an MLX
   collective desync, or a model-side error) makes the offending
   rank exit the generator, return to the runner's
   ``_work_queue.get()`` idle, and leave its peer hung in the
   next TP collective forever. The previous swallow + ``DO NOT
   re-raise`` comment was written before the supervisor's
   peer-failure ``_kill_runner`` rule existed; with that rule in
   place the correct multi-rank behaviour is to escalate (raise
   -> ``RunnerFailed`` -> peer auto-tears down -> master
   rebuilds). Single-rank runners keep the swallow path so a
   malformed request doesn't crash the only generator. Always
   log with ``logger.opt(exception=True)`` so the peer's hang
   stack stops being mysterious.

2. ``mx_all_gather_tasks`` is now root-authoritative rather than
   strict-drift-detecting. The strict detector correctly fired
   on real divergence but treated libp2p delivery races (master
   pushing the same task to ranks at slightly different times)
   as request-failure conditions, so any 2k+ token request had
   a high chance of catching a benign race window and bailing.
   The new protocol broadcasts root's task IDs as canonical
   ASCII slots, every rank admits the subset of root's IDs it
   has locally, leftover local-only tasks defer to the next
   round. Rides the well-exercised ``all_sum`` primitive via
   ``mx_broadcast_int_list``; never desyncs the collective.
   Tasks beyond ``_MX_AGREE_MAX_TASKS`` (16) defer.

Together these turn "rank-1 silent fast-exit + hang" into
"both ranks complete the request" or, on a real hardware
failure, "supervisor escalates and master rebuilds the
instance" -- never indefinite hang.

Tests: 26 unit tests in ``test_utils_mlx_broadcast.py`` covering
the encode/decode codec, the consumer-side filtering, leftover
semantics, root-authoritative ordering, and the
``_MX_AGREE_MAX_TASKS`` cap. All 189 MLX unit tests pass.
The root-authoritative agreement protocol from the previous commit
caused a worse failure mode under sustained generation: when rank 0
admitted a task that rank 1 hadn't yet received via libp2p,
``_active_tasks`` diverged across ranks. Rank 0 would call
``next(gen)`` to advance the spec loop (issuing several ``all_sum``
collectives) while rank 1, with an empty ``_active_tasks``, would
loop back to ``step()`` and re-enter ``agree_on_tasks`` (issuing
its own ``all_sum``). The two collective streams interleaved on the
wire and corrupted each other's payloads, presenting as
``IndexError: list index out of range`` in the detokenizer because
the broadcast token slots arrived scrambled. The supervisor's
peer-failure rule then tore the instance down -- a real recovery,
but not a useful one because the same poison-pill task got
re-placed and crashed the rebuild on every single attempt.

The new protocol is two-phase intersection:

  Phase 1: root broadcasts its candidate task IDs as ASCII slots
           (same wire format as before; ``mx_broadcast_int_list``).
  Phase 2: every rank emits a ``[0, 1]`` indicator vector saying
           "I have this canonical ID locally"; the new
           ``mx_all_sum_int_list`` helper element-wise-sums those
           vectors. A slot whose sum equals ``group_size`` is
           agreed (every rank had it); slots below ``group_size``
           are deferred to the next round.

The agreed set is identical on every rank by construction, so the
collective-count guarantee that ``step()`` depends on
(``if len(_active_tasks) < max_concurrent_tasks: agree_on_tasks()``
firing on the same step on every rank) is preserved across all
delivery-race scenarios. Cost is one extra ``all_sum`` per
``agree_on_tasks`` call -- sub-millisecond on Apple Silicon JACCL,
runs only at admit boundaries.

Tests: rewrote ``test_utils_mlx_broadcast.py`` to exercise the
intersection contract for representative race scenarios (root-only,
peer-only, partial overlap, 3-rank, ordering, max-cap). 190 MLX
unit tests pass.
Sympton on the wild: ~300ms after decode begins on a multi-target
asymmetric placement (`gemma-4-31b-it-bf16` over JACCL with the
`gemma-4-e2b-it-4bit` drafter), the non-root target rank crashes inside
mlx_lm's SPM detokenizer with `IndexError: list index out of range`
because the emitted token id sits well outside the tokenizer vocab.

Root cause is JACCL collective cross-talk between the model's
tensor-parallel `all_sum` (which dispatches on the default Metal stream
because `auto_parallel.py` does not pass a `stream=` argument) and the
spec-decode broadcasts in `mx_broadcast_int_list` /
`mx_all_sum_int_list` / `mx_any` / `mx_barrier`, which deliberately
forced themselves onto the CPU stream. With separate streams JACCL
sees two independent dispatch queues per rank, so the rendezvous
matching is no longer FIFO across the whole group: a CPU `all_sum`
on rank 0 can pair with a Metal TP `all_sum` on rank 1 mid-forward,
silently corrupting the broadcast buffer. The hot path emits two
broadcasts per spec round interleaved with ~120 layer all_sums, which
is exactly the pathological pattern.

Fix: drop `stream=mx.default_stream(mx.Device(mx.cpu))` from every
distributed helper that runs on the same target group as the model TP.
The collectives now ride the input array's default stream (Metal),
matching the TP all-reduces and giving JACCL one in-order dispatch
queue per rank for the entire group.

Defence in depth: surface a typed `RuntimeError` (with the offending
token id) in `_pipelined_stream_generate` before the SPM detokenizer
explodes, so any future broadcast corruption is obvious instead of
appearing as an unrelated stack frame deep in `mlx_lm`.
Stream alignment alone did not stop the IndexError -- token id
1083194908 (well outside vocab) still surfaced on rank-1 within ~300ms
of decode start. Strong evidence that JACCL is conflating our int32
broadcast `all_sum` with the model's float32 TP `all_sum` on the same
target group when both are issued back-to-back on the same stream.

Switching the hot-path broadcast to `mx.distributed.send` / `recv`
makes it a fundamentally different primitive than the model TP
all-reduce, so JACCL has no opportunity to merge the two on the wire.
The cost is one wire round-trip per peer per call (vs one collective
for the all_sum), which on a 2-rank target subgroup is identical
network traffic.

Left `mx_all_sum_int_list`, `mx_any`, and `mx_barrier` on `all_sum`
because they only fire at admit/cancel boundaries (not per token),
so the interleaving frequency that broke the broadcast helper does
not apply. Restored the comment to reflect that.
The model's TP all_sum collectives stay on JACCL/RDMA -- those are the
multi-MB tensor reductions the vendor stack is optimised for. But the
tiny ~24-byte int32 broadcasts the spec-decode loop runs between
target rank 0 and its peers (drafts in / sampled tokens out) are NOT
safe on the same wire: probe diagnostics from the prior run showed the
JACCL backend conflating the int32 broadcast with the model's float32
TP all-reduce, returning logit memory reinterpreted as int32 (token id
1083948012 = float ~4.78). The resulting out-of-vocab id surfaced as
an IndexError deep in the SPM detokenizer a few hundred ms into every
generation.

Switching the broadcast primitive (all_sum -> send/recv) didn't help
-- both rode the same target group as the model TP and JACCL kept
mismatching them. The fix is to lift the int wire off mx.distributed
entirely:

  * target_peer_socket.py: bind/accept/dial + send_int32_frame /
    recv_int32_frame helpers (mirrors drafter_socket.py).
  * DrafterPlacement: target_peer_socket_port + per-peer
    target_peer_hosts_by_rank, allocated by master/placement.py.
  * MlxGroupSplit: target_peer_fanout slot, populated by
    initialize_mlx (rank 0 binds + accepts N-1 peers; peers dial in).
  * TargetPeerFanout: rank-aware dataclass holding the connected
    sockets; target_peer_broadcast_int_list rides them.
  * pipelined_drafter: _broadcast_drafts / _broadcast_target_tokens /
    _broadcast_abort take a fanout, fall through to the legacy
    mx_broadcast_int_list when fanout is None (single-rank,
    symmetric, test fakes).
  * Plumbing: builder -> SequentialGenerator -> mlx_generate ->
    make_drafter -> PipelinedModelDrafter, threading the fanout
    through every layer.

Performance: each round adds two ~24-byte TCP frames over Thunderbolt
with TCP_NODELAY, ~100us per broadcast vs. ~30ms verifier forward. <1%
per-round overhead in exchange for a wire that cannot collide with TP
collectives. Same precedent as the drafter wire (which has used a
direct TCP socket since the V3 redesign because JACCL doesn't support
Group.split on Apple Silicon).
The event-router serialises events to JSON for gossipsub fan-out;
JSON has no integer dict keys, so a Pydantic dict[int, str] field
fails strict re-validation on the receiving worker with
"Input should be a valid integer [type=int_type, input_value='1']".

Surfaced as: master placed the instance, broadcast InstanceCreated,
worker rank 0 (smbp) crashed mid-validate, ASGI loop tore down the
runner before MLX init even started, master logged
RunnerFailed and gave up.

Fix: store rank keys as strings in the wire type and stringify at
the consumer (initialize_mlx looks up str(rank) instead of rank).
Same shape change pydantic recommends for any int-keyed dict that
crosses a JSON boundary.
Drops timestamped logger.info checkpoints around every blocking call on
the rank-0 spec path (tolist materialization, OP_PREFILL ack,
spec-body prefill iters, seed materialization, OP_FORWARD round 0,
broadcast round 0) so we can read off the exact step that wedges
against rank 1's draft-broadcast recv.

Plain logger.info, not gated; the loops are post-prefill and run at
~30 ms granularity so the noise is bounded.
The runner subprocess on smbpt produces zero loguru output -- its
stdout/stderr aren't reaching tee, so the previous loguru-only
checkpoints were invisible on rank 0 (which is on smbpt this run).

Adds an unconditional file-side-channel that writes plain timestamped
lines to /tmp/spec_diag_<pid>.log alongside the loguru emit. Makes
diagnostics survive whatever's swallowing rank 0's stdout.

Adds checkpoints around every blocking call in the per-round main
loop:
  * round top
  * model(verify_input) dispatch (NOT eval'd yet -- MLX is lazy)
  * sampler() call (root only)
  * mx.eval(sampled_batch) -- the spot most likely to wedge because
    forcing eval here is what actually fires the deferred verify
    forward + TP all_sum
  * _broadcast_target_tokens (TCP)
  * _broadcast_drafts (next, TCP)

The previous rank-1 logs showed it sailing past round-0 broadcast and
then waiting in sock_recv on _broadcast_target_tokens, while rank 0
was wedged in mlx::core::eval -- consistent with the verify-forward
all_sum failing to pair up. These per-round checkpoints will show
exactly which round wedges and which step inside the round.
Diagnostics from /tmp/spec_diag_<pid>.log captured the exact deadlock:

  rank 0 (root): "about to mx.eval(sampled_batch)" -- blocked
                 forcing the verify forward + all_sum to run
  rank 1:        "about to call _broadcast_target_tokens" -- blocked
                 in TCP recv after dispatching model(verify_input)
                 but never forcing eval of ``logits``

MLX is lazy. ``logits = model(verify_input)`` builds a graph but does
not launch kernels. The TP all-reduce embedded in the final layer is
queued but never executed on rank 1, because rank 1's next step is a
pure-Python TCP recv that does not touch ``logits``. Rank 0's
``mx.eval(sampled_batch)`` then waits forever for rank 1's matching
all_sum, which never fires. Hang.

The fix mirrors the spec-body prefill loop (which worked correctly):
``mx.eval([c.state for c in prompt_cache])`` after every ``model()``
call to force collective kernels to actually run on every rank.

Apply the same pattern to the verify forward: ``mx.eval(logits)``
immediately after ``model(verify_input)``. Both ranks now block until
the TP all-reduce completes. Cost is one extra sync per round, but
rank 0 was going to block on this same sync at the sampler step
anyway -- the only thing that changes is rank 1 now correctly
participates in the collective instead of dropping out.

This is a JACCL / TP-collective correctness fix, independent of the
TCP fanout (which is for the int32 hot-path broadcasts and is working
correctly per the same diagnostics: drafts_broadcast completed in
50ms on rank 1).
The /tmp/spec_diag_<pid>.log side-channel writes ~10 lines per spec
round; on an 8k-token run that's ~16k log lines per pid. Useful for
correctness regressions, noise in steady-state.

Default off; ``EXO_SPEC_DIAG=1`` re-enables. Hooks themselves stay in
place so future TP-collective regressions can be isolated immediately
without rebuilding with new logging.
Closes the visibility gap that was making "the drafter doesn't deadlock"
indistinguishable from "the drafter is actually working" in benches.

What was missing:
  * Non-streaming chat-completions responses dropped GenerationStats on
    the floor in collect_chat_response, so callers got no drafter info
    even when the drafter ran for thousands of rounds.
  * GenerationStats had no proposed-draft count, so the classical
    accepted/proposed acceptance rate was unknowable from the API.
  * The per-request log was at DEBUG and had no drafter info, so
    operators had to grep without knowing what to grep for.

What this adds:

  pipelined_drafter:
    * Per-request metrics dict on PipelinedModelDrafter, mutated in
      lockstep with the spec body's accept loop. Tracks
      proposed_draft_tokens (sum of k_this), accepted_draft_tokens
      (sum of num_accepted), and spec_decode_rounds.
    * Drafter.metrics() returns a snapshot. Other drafter
      implementations are unchanged; mlx_generate uses getattr() so
      they degrade to {}.

  GenerationStats:
    * proposed_draft_tokens, spec_decode_rounds new fields (default 0).
    * drafter_acceptance_rate = accepted/proposed (classical metric)
      alongside the existing drafter_acceptance_fraction
      = accepted/generated (speedup-mapping metric).
    * Doc clarifies the difference between the two and which to
      use when.

  Usage / OpenAI compatibility:
    * usage.completion_tokens_details.accepted_prediction_tokens
      mirrors accepted_draft_tokens (OpenAI's Predicted Outputs term).
    * rejected_prediction_tokens = max(0, proposed - accepted) when
      the drafter surfaces a proposal count; 0 otherwise.

  ChatCompletionResponse:
    * generation_stats: GenerationStats | None top-level extension
      so non-streaming clients see the full stats object. OpenAI
      clients ignore unknown fields; exo benches and the dashboard
      read it directly.

  collect_chat_response:
    * Tracks chunk.stats across the stream and surfaces it on the
      final response. Both TokenChunk and ToolCallChunk paths.

  Per-request log (mlx_generate):
    * Bumped from DEBUG to INFO and now includes drafter mode, id,
      K, rounds, accepted/proposed counts, acceptance rate, and
      fraction-of-emitted. One bounded line per completed request,
      so dashboards / operators can answer "is the drafter helping?"
      without flipping verbose.

This is the visibility leg of the gemma-4 drafter rollout. It's a
strict superset of the prior surface (no field removals, all defaults
match the previous behavior), so existing clients keep working.
bench_compare.py:
  Drafter vs no-drafter A/B at the same prompt and lengths.
  Surfaces per-run TPS, drafter telemetry (acceptance rate,
  fraction-of-emitted, K, rounds), and the speedup ratio.
  Used to demonstrate that the drafter is helping (or, just as
  importantly, NOT helping at low acceptance rates).

bench_concurrent.py:
  Fires N parallel chat-completions requests at the same instance
  to measure overlap of in-flight spec-decode sessions. Reports
  per-request and aggregate TPS so we can tell when adding
  concurrency hurts vs helps.

Both write JSON to /tmp by default; pair with the new
generation_stats and per-request log line to get full visibility
into drafter effectiveness.
`_custom_cards_dir` resolves to `$EXO_DATA_HOME/custom_model_cards`,
where dev workstations keep operator-edited cards (e.g. trimmed
drafter lists for memory-constrained clusters). Those overrides
layer on top of the shipped TOML, so the gemma-4 card-content
gates were asserting against whatever the operator last wrote
instead of the shipped data they're supposed to protect.

Add an autouse fixture mirroring the pattern in test_model_cards.py:
point `_custom_cards_dir` at a fresh tmp dir per test and clear
`_card_cache` so the next refresh sees only the shipped builtins.

This unblocks the full pytest run that previously failed with
mismatches like e2b-it-4bit vs e2b-it-bf16.
@team-wcv
Copy link
Copy Markdown
Owner Author

team-wcv commented May 9, 2026

Closing in favour of the new 4-PR drafter stack:

New PR Scope
#18 - foundations ModelCard drafter_model_ids field, gemma-4 card TOMLs, auto-download, Drafter interface, n-gram strategy.
#19 - in-process tuning K/warmup/short-skip, KV prefix cache, per-request overrides, mlx_lm native cache wiring.
#20 - asymmetric pipelined drafter V1 (mx.distributed), concurrency block, EAGLE/lookahead scaffolds, V2 sharded multi-target, V3 socket transport.
#21 - production hardening Resilience (drafter-death + worker-disconnect recovery, two-phase intersection consensus), TP fanout (TCP int broadcast, the mx.eval(logits) deadlock fix), drafter telemetry, A/B + concurrent benches.

This split was driven by the file-overlap topology of the 74 commits (heavy churn on generate.py, pipelined_drafter.py, utils_mlx.py), so the new PRs are stacked rather than independent. PR-INFRA orthogonal cluster fixes are split to #17.

Verification (content-preserving):

@team-wcv team-wcv closed this May 9, 2026
team-wcv pushed a commit that referenced this pull request May 9, 2026
The previous claim that lifting the 14s slot-1 TTFT required upstream
``position_ids`` for tree attention conflated two separate problems:

  1. **Tree attention** (EAGLE / Medusa / lookahead) wants K candidate
     continuations verified in a *single* forward whose siblings need
     different RoPE positions in the same step. ``mlx_lm`` derives every
     position's RoPE id from ``KVCache.offset`` (a single ``int``), so
     this collapses to linear verify on Apple Silicon. Genuinely blocked
     on ``ml-explore/mlx-lm#846``.

  2. **Concurrent target requests** (this PR) want N requests' generators
     to make progress per ``step``. Each ``mlx_generate`` call already
     allocates its own KV cache, so two generators are independent in
     everything but model weights (read-only during forward). Round-
     robin scheduling at the generator level needs zero changes to
     ``mlx_lm`` because every ``next(gen)`` is a *single-position*
     forward against that generator's own cache.

The previous ``SequentialGenerator._active`` was a singular slot, so
slot 1's TTFT equalled slot 0's full completion time -- 14s on the
PR #15 K=3 single-host n-gram concurrency leg.

This change:

* Replaces ``_active: tuple | None`` with ``_active_tasks: OrderedDict``
  capped by ``max_concurrent_tasks``. ``step`` admits up to the cap from
  the queue, then round-robins one ``next(gen)`` per active task per
  tick. ``OrderedDict`` makes the round-robin order explicit so all
  ranks see the same admit/iterate sequence (collective ops in
  ``agree_on_tasks`` need this).

* Sets ``max_concurrent_tasks = EXO_MAX_CONCURRENT_REQUESTS`` (default
  8) for everything except the asymmetric pipelined+remote path, which
  stays at 1. The asymmetric cap is real: ``RemoteTransport``'s wire
  protocol is per-session and concurrent target requests would
  interleave ``OP_PREFILL`` / ``OP_FORWARD`` frames on the same socket
  and corrupt the drafter rank's per-request KV state. Lifting *that*
  cap requires extending the wire protocol with a request-id field --
  separate change.

* Strengthens the K=8-cancel resilience contract: a single in-flight
  task's runtime exception now evicts only that task. The previous
  contract was "the runner must survive"; the new contract is "every
  *other* in-flight task must keep advancing too".

Tests added in ``test_sequential_generator_errors``:

* ``test_round_robin_advances_all_active_tasks_per_tick`` -- the core
  contract: with cap=2, both tasks advance one yield per ``step``
  (the singular-slot version would have been 0 for slot 1).
* ``test_round_robin_respects_max_concurrent_tasks`` -- cap=1
  (asymmetric default) admits only one task per tick.
* ``test_round_robin_per_task_error_does_not_kill_other_active_tasks``
  -- a faulty generator evicts only itself; siblings keep advancing.

The pre-existing ``test_runner_survives_sequential_failure_and_serves_
next_task`` was tightened to express the post-refactor contract: a
queue of failing tasks may surface all their failures on tick 1 (the
old ``_active = None`` reset only admitted one per tick).

All 642 tests pass; ``basedpyright`` 0 errors; ``ruff`` clean.
team-wcv pushed a commit that referenced this pull request May 9, 2026
The PR #15 round-robin landed in 456bbb3 cut slot-1 TTFT 5.2x by
interleaving decode ticks across in-flight requests, but the residual
11s outliers in the mixed-prompt long_context_summary bench are
6K-token prefills that still run sequentially per-tick: every active
slot pays its full prefill cost on the same GPU before its first
decode token. Batched prefill folds K eligible slots' prefills into a
single ``PromptProcessingBatch.prompt`` call, amortising weight loads
across the batch, then hands per-slot pre-filled caches back to
``mlx_generate`` via the new ``precomputed_target_cache`` seam.

Eligibility (V1, narrow on purpose):
* single-rank target (multi-rank pipeline-parallel prefill keeps its
  driver loop -- folding it in needs a follow-up that touches
  ``pipeline_parallel_prefill``'s collective semantics);
* no vision (``task_params.images``);
* no remote prefill endpoint (the disaggregated-prefill path already
  runs on a sibling instance via ``InstanceLink``);
* no in-process model drafter (its drafter prefill must stay aligned
  to the target's offset; batching only the target would desync them).

The asymmetric pipelined drafter still qualifies because
``self.draft_model is None`` on the target rank -- its drafter
prefill goes over the wire per-session and is independent of target
prefill batching. ``EXO_BATCH_PREFILL=0`` is the env-var escape hatch
for shared-prefix workloads where the per-slot prefix-cache hit rate
exceeds the batched-forward speedup.

Implementation:
* ``batched_prefill`` in ``generate.py`` wraps ``PromptProcessingBatch``
  (the upstream mlx-lm helper that handles right-padding +
  ``prepare(lengths, right_padding)`` + ``finalize()``), slices each
  prompt to ``prompt[:-1]`` so the post-prefill cache offset lands
  at ``len(prompt) - 1`` (matching the existing exact-prefix-hit
  invariant), and re-extracts per-sequence caches in place. Cache
  layers without ``merge``/``extract`` (e.g. ``DeepseekV4Cache``)
  raise the typed ``BatchedPrefillUnsupportedError`` so the caller
  falls back to per-slot prefill instead of crashing the runner.
* ``mlx_generate`` accepts ``precomputed_target_cache``: when set,
  the prefix-cache lookup is bypassed (V1 trade-off -- we don't
  pollute the shared cache with per-request entries), the local
  ``prefill`` slice is empty (one-token decode seed), and the
  exact-prefix-hit shape is reproduced so the rest of the function
  is untouched.
* ``SequentialGenerator`` factors out admission into
  ``_admit_queued_tasks`` which collects up-to-slack candidates,
  filters via ``_batch_eligible_for_prefill``, and either runs one
  ``batched_prefill`` (>=2 eligible) or per-slot ``_start_one``.
  Per-slot path is preserved verbatim for backwards compatibility
  with monkeypatched ``_build_generator`` test seams (only forwards
  ``precomputed_target_cache`` when set).
* Untyped failures in the batched forward are charged to every
  batched task via ``_send_error`` + ``_pending_failed`` so a single
  malformed request can't take down the runner -- same liveness
  contract as the per-slot path.

Tests:
* ``test_batched_prefill.py``: bit-exact correctness vs sequential
  prefill on B=2, decode continuity after batched prefill, empty-
  input zero return, length-mismatch / short-prompt ``ValueError``,
  and ``BatchedPrefillUnsupportedError`` on a cache type without
  ``merge``.
* ``test_sequential_generator_batch_prefill.py``: routing matrix
  -- two eligible tasks take the batched path, single-eligible
  falls through to per-slot, mixed eligibility splits cleanly,
  ``EXO_BATCH_PREFILL=0`` disables, unsupported cache falls back,
  and each individual disqualifier (group, vision, remote prefill,
  in-process drafter) routes to per-slot. Asymmetric drafter target
  is verified to qualify.

Bench (still pending): B=1 vs B=2 prefill on 6K-token
``long_context_summary`` -- expected ~1.5-1.8x aggregate prefill tps
on a single GPU. The 6K outliers should drop from ~11s to ~6s on
B=2, eliminating the worst-case slot-1 TTFT in the mixed bench.

Cluster validation gated on TB /30 RDMA repair (operator-driven sudo
required, see ``bb rdma repair all``).

X-Orchestraitor-Task: gemma4-drafter-tuning
X-Orchestraitor-Plan: PR #15 batched-prefill leg
X-Agent-Platform: cursor-claude-opus
team-wcv pushed a commit that referenced this pull request May 9, 2026
The PR #15 batched-prefill work landed in 890c7ae, but in production
it never fired: two concurrent client requests reliably miss the same
``SequentialGenerator._admit_queued_tasks`` window because the runner
calls ``step()`` immediately after submitting the *starting* task and
only polls ``_work_queue`` for the second task **after** that first
``step()`` has already run prefill on slot 0. By the time the
second admit cycle sees task #2, slot 0 is mid-decode and there's
nothing left to batch with.

Cluster bench on 3-node TB-RDMA Big Brain (smbp + smbpt + bmbp,
gemma-4-26b-a4b-it-4bit single-rank on smbp, drafter disabled,
concurrency=2, long_context_summary, 3 runs):

  EXO_BATCH_PREFILL=0  agg=272.23 t/s  slot0_ttft=1119ms  slot1_ttft=2143ms
  EXO_BATCH_PREFILL=1  agg=269.03 t/s  slot0_ttft=1114ms  slot1_ttft=2135ms

Identical numbers across both runs is the smoking gun: batched_prefill
never triggered, the per-slot fallback ran in both cases, and slot 1
paid the full prefill of slot 0 on top of its own. Confirmed by the
absence of the ``batched_prefill: N slots`` info log on the runner.

This commit:

* Adds ``Runner._coalesce_burst_generation_tasks`` -- pulls
  ``TextGeneration`` / ``ImageGeneration`` / ``ImageEdits`` items
  from ``_work_queue`` and submits them via the existing
  ``submit_generation`` path before the first ``step()``.
* Calls it from ``handle_generation_tasks`` immediately after the
  starting-task submit so the upcoming admit cycle sees the full
  burst together.
* Blocks on the queue for up to ``EXO_BURST_COALESCE_MS`` (default
  20ms) so libp2p-delivery jitter doesn't lose task #2 to the
  before-step gap. Tuned at 20ms because TB-routed libp2p delivery
  on concurrent client requests typically straggles 5-15ms; values
  >50ms start adding user-visible TTFT to solo requests.
  ``EXO_BURST_COALESCE_MS=0`` disables (per-slot every time).
* Routes any non-task item picked up during the drain
  (``PrefillTask`` / ``_TaskStreamClosed`` / ``Shutdown``) into a
  new ``_burst_deferred_item`` slot consumed by the main loop
  before its next ``_work_queue.get_nowait()``. Re-queueing at the
  tail would race with the listener thread and silently re-order
  ``Shutdown`` past burst tasks -- this preserves FIFO without
  needing an appendleft-capable queue.

The 307 worker unit tests all pass; basedpyright and ruff clean.
Bench validation pending TB-RDMA bringup retry (next commit).

X-Orchestraitor-Task: gemma4-drafter-tuning
X-Orchestraitor-Plan: PR #15 batched-prefill bench-fix
X-Agent-Platform: cursor-claude-opus
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants