Drafter abstraction + Gemma 4 MTP + Qwen 3.5/3.6 DFlash + multi-device coupled-drafter speculative decoding #2079

Open
team-wcv wants to merge 2 commits into exo-explore:main from team-wcv:feature/gemma4-mtp-coupled-drafter
team-wcv commented May 10, 2026

Lands the full speculative-decoding stack on top of the drafter_model_id ModelCard foundation: drafter abstraction, in-process tuning, asymmetric pipelined drafter, production hardening, Gemma 4 multi-token-prediction (MTP), Qwen 3.5 / 3.6 DFlash coupled drafters, and multi-device tensor-parallel coupled-drafter dispatch.

This PR is self-contained: the foundation patch from the now-closed #2065 (bdf1a12d) is included here as the first commit (f383ef0a, same patch rebased onto the latest main). Two-commit structure preserved for review clarity (foundation + bundle).

Per @rltakashige on #2065: "We're also looking into MTP, and we'll handle the issues you've mentioned at the same time as well as use something like this." — feel free to cherry-pick whatever pieces are useful for upstream's own MTP work.

Headlines

All numbers are from MacBook Pros with Apple M5 Max. wc-smbp and wc-smbpt carry 128 GB unified memory each; wc-bmbp carries 48 GB. Unless noted otherwise, generation tokens/s are the median of 10 runs per A/B side at T=0, across 5 scenarios (short_repetitive, code_completion, creative_prose, factual_qa, long_context_summary).

Single-device DFlash A/B (Qwen 3.5 / 3.6)

| Target | Quant | Arch | Host | Target gen_tps | DFlash gen_tps | Speedup | Accept |
|---|---|---|---|---|---|---|---|
| Qwen 3.5 4B | 8bit | dense | wc-smbp | 97.24 | 404.38 | 4.16x | 93.2% |
| Qwen 3.6 27B | 8bit | dense | wc-smbpt | 14.98 | 49.13 | 3.28x | 92.6% |
| Qwen 3.6 35B-A3B | 8bit | MoE | wc-smbpt | 87.70 | 377.49 | 4.30x | 92.6% |

Multi-device DFlash A/B (Qwen 3.5 122B-A10B, 2-node TP + JACCL/RDMA)

| Target | Placement | Target gen_tps | DFlash gen_tps | Speedup | Accept |
|---|---|---|---|---|---|
| Qwen 3.5 122B-A10B 8bit MoE | TP2 across wc-smbp + wc-smbpt over RDMA | 52.61 | 159.00 | 3.02x | 93.75% |

DFlash holds at 3.02x or better at every scale tested, including the 122B-A10B MoE running across two M5 Max MacBook Pros with tensor parallelism over a Thunderbolt-bridge RDMA edge.
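
As a quick sanity check on the tables above, the speedup column can be recomputed from the two gen_tps columns. This is only a sketch of the arithmetic; the dictionary and function names are illustrative and not part of the bench harness.

```python
# Target and DFlash generation rates (tokens/s), copied from the tables above.
RESULTS = {
    "Qwen 3.5 4B 8bit": (97.24, 404.38),
    "Qwen 3.6 27B 8bit": (14.98, 49.13),
    "Qwen 3.6 35B-A3B 8bit": (87.70, 377.49),
    "Qwen 3.5 122B-A10B 8bit (TP2)": (52.61, 159.00),
}


def speedup(target_tps: float, drafter_tps: float) -> float:
    """Ratio of speculative throughput to baseline throughput."""
    return drafter_tps / target_tps


# Round to two decimals, matching the precision used in the tables.
SPEEDUPS = {name: round(speedup(t, d), 2) for name, (t, d) in RESULTS.items()}

for name, value in SPEEDUPS.items():
    print(f"{name}: {value:.2f}x")
```

The smallest ratio is the TP2 run, which is why 3.02x is the floor quoted for the whole set.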

Single-device MTP A/B (Gemma 4)

Median of 2 runs per scenario on wc-smbp (M5 Max, 128 GB). The all-scenario median is dragged down by a single long_context_summary outlier where the assistant head goes out-of-distribution; the high-accept scenarios consistently win double-digit speedups.

| Target | Quant | Arch | Best scenario speedup | All-scenario median |
|---|---|---|---|---|
| Gemma 4 31B | 4bit | dense | +13.2% on code_completion | +5.4% |
| Gemma 4 26B-A4B | 4bit | MoE | +22.1% on code_completion | -1.6% |

Full scenario tables and the routing-heuristic discussion live in bench/results/mtp/REPORT.md. The DFlash gains are much larger than the MTP gains because the DFlash drafter is architecturally tied to the gated-delta-net targets: hidden states from selected layers plus the recurrent SSM state are captured and replayed, whereas Gemma 4's MTP assistant only sees the final hidden state.

What's in this PR

Drafter abstraction (src/exo/worker/engines/mlx/generator/drafter.py)

  • Drafter Protocol with stream, metrics, and DraftMode resolution (model / pipelined / none).
  • EXO_DRAFT_MODE env override; K + warmup + KV-cache + n-gram tuning knobs.
  • Coupled-aware resolve_draft_mode keeps pipelined mode gated on a standard sibling drafter so coupled-only runs never deadlock.
  • GenerationStats.drafter_kind: Literal["standard", "mtp", "dflash", "ngram", "none"] so OpenAI-compatible CompletionTokensDetails and the dashboard can surface which speculative path actually dispatched.
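
To make the shape of the abstraction concrete, here is a minimal sketch of what a `Drafter` Protocol and a coupled-aware `resolve_draft_mode` could look like. The method signatures, the fallback from `pipelined` to `model`, and the default when `EXO_DRAFT_MODE` is unset are all assumptions for illustration; the real definitions in drafter.py may differ.

```python
import os
from typing import Iterator, Literal, Mapping, Protocol, get_args

DraftMode = Literal["model", "pipelined", "none"]


class Drafter(Protocol):
    """Hypothetical method shapes; only the names come from the PR text."""

    def stream(self, prompt_tokens: list[int], max_tokens: int) -> Iterator[int]: ...

    def metrics(self) -> Mapping[str, float]: ...


def resolve_draft_mode(
    has_standard_drafter: bool,
    has_coupled_drafter: bool,
    env: Mapping[str, str] = os.environ,
) -> DraftMode:
    """Pick a draft mode, honouring the EXO_DRAFT_MODE override when it is valid."""
    requested = env.get("EXO_DRAFT_MODE", "").strip().lower()
    if requested and requested not in get_args(DraftMode):
        raise ValueError(f"unknown EXO_DRAFT_MODE: {requested!r}")

    # With no drafter loaded at all there is nothing to dispatch.
    if not (has_standard_drafter or has_coupled_drafter):
        return "none"
    if requested == "none":
        return "none"
    if requested == "pipelined":
        # Pipelined mode needs a standard sibling drafter. A coupled-only run
        # falls back to in-process drafting (assumed behaviour) instead of
        # waiting on a pipeline peer that will never exist.
        return "pipelined" if has_standard_drafter else "model"
    # Assumed default: in-process drafting whenever any drafter is available.
    return "model"
```

For example, `resolve_draft_mode(False, True, {"EXO_DRAFT_MODE": "pipelined"})` yields `"model"` in this sketch, which is the coupled-only case the gating is meant to protect.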

Coupled drafters (generator/coupled_drafter.py, vendor/)

  • CoupledModelDrafter shim around mlx-vlm's _mtp_rounds / _dflash_rounds with a CoupledDrafterKind = Literal["mtp", "dflash"] dispatch table. Both kinds are wired and ship in DISPATCHABLE_COUPLED_DRAFTER_KINDS = frozenset({"mtp", "dflash"}).
  • Gemma 4 MTP: vendored attach_mtp_hooks attaches forward_with_capture + rollback_speculative_cache to the wrapped Gemma 4 text model. Target adapter captures the final hidden state per draft round and replays KV cache rewinds on rejection.
  • Qwen 3.5 / 3.6 DFlash: vendored qwen3_5_dflash_hooks for the hybrid attention + gated-delta-net architecture. The DFlash drafter consumes captured hidden states at specific layer indices plus the 11-tuple GdnState (last_recurrent_state + window kv buffers + delta caches); rollback replays both KV pages and SSM states. Works for both dense (Qwen 3.6 27B) and MoE (Qwen 3.6 35B-A3B, 256 experts × 8 active) targets because the existing qwen3_5.DecoderLayer.SparseMoeBlock already handles routing — no MoE-specific vendor work needed.
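
The capture, verify, and rollback cycle described above can be illustrated with a toy greedy (T=0) speculative loop. This is not mlx-vlm's `_mtp_rounds` or `_dflash_rounds`: the models are plain Python callables, the "cache" is a token list standing in for KV pages and recurrent state, and every name is hypothetical. It only shows why output stays identical to plain greedy decoding when rejected drafts are rolled back.

```python
from typing import Callable, Sequence

NextToken = Callable[[Sequence[int]], int]


def greedy_decode(target: NextToken, prompt: list[int], n: int) -> list[int]:
    """Baseline: one target call per generated token."""
    out = list(prompt)
    for _ in range(n):
        out.append(target(out))
    return out[len(prompt):]


def speculative_decode(
    target: NextToken, drafter: NextToken, prompt: list[int], n: int, k: int = 4
) -> tuple[list[int], float]:
    """Greedy speculative decoding; returns (tokens, acceptance rate)."""
    cache = list(prompt)  # stand-in for KV cache / recurrent state
    proposed = accepted = 0
    while len(cache) - len(prompt) < n:
        checkpoint = len(cache)
        # 1. Drafter proposes k tokens autoregressively.
        draft: list[int] = []
        for _ in range(k):
            draft.append(drafter(cache + draft))
        # 2. Target verifies; a real engine does this in one batched forward.
        n_ok = 0
        for i, tok in enumerate(draft):
            if target(cache + draft[:i]) != tok:
                break
            n_ok += 1
        proposed += k
        accepted += n_ok
        # 3. Speculative state was written for all k drafts; roll back the
        #    rejected suffix so the cache matches the accepted prefix only.
        cache.extend(draft)
        del cache[checkpoint + n_ok:]
        # 4. Target supplies the correction (or bonus) token.
        cache.append(target(cache))
    return cache[len(prompt):len(prompt) + n], accepted / proposed


def toy_target(ctx: Sequence[int]) -> int:
    return (ctx[-1] * 3 + len(ctx)) % 11


def toy_drafter(ctx: Sequence[int]) -> int:
    # Agrees with the target except at every fifth position.
    guess = toy_target(ctx)
    return (guess + 1) % 11 if len(ctx) % 5 == 0 else guess


baseline = greedy_decode(toy_target, [1, 2, 3], 20)
spec_tokens, accept_rate = speculative_decode(toy_target, toy_drafter, [1, 2, 3], 20)
```

In the real hooks the rollback step has to restore more than a token list (KV pages for Gemma 4, plus the `GdnState` for the gated-delta-net layers), but the invariant is the same: after rejection, the target state must look as if the rejected tokens were never processed.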

Multi-device tensor-parallel coupled-drafter dispatch

The previous loader hard-coded if group is None and the generator hard-coded draft_mode = "none" whenever group is not None, so the coupled drafter never ran on TP placements, which is exactly the regime 122B-class targets live in. Lifted via:

  • _try_load_collocated_drafter is now called from both the single-device and the symmetric multi-rank branches in utils_mlx.py. The multi-device call passes allow_standard_drafter_fallback=False because the generator still can't dispatch standard drafters through group, so a loaded standard drafter would only waste memory.
  • mlx_generate only forces draft_mode = "none" for multi-device when coupled_drafter_eligible is false.
  • builder.py selects SequentialGenerator (speculative-capable) when coupled_drafter_dispatchable is true, even with group is not None.
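
The three conditionals above can be summarised as small pure functions. This is a sketch of the decision logic only, under the assumption that single-device runs were already speculative-capable; the function names are hypothetical and the real code lives inline in utils_mlx.py, generate.py, and builder.py.

```python
from typing import Literal, Optional

DraftMode = Literal["model", "pipelined", "none"]


def allow_standard_drafter_fallback(group: Optional[object]) -> bool:
    """Standard drafters are only worth loading on single-device placements."""
    return group is None


def effective_draft_mode(
    requested: DraftMode, group: Optional[object], coupled_drafter_eligible: bool
) -> DraftMode:
    """Force "none" on multi-device only when no coupled drafter can run."""
    if group is not None and not coupled_drafter_eligible:
        return "none"
    return requested


def uses_sequential_generator(
    group: Optional[object], coupled_drafter_dispatchable: bool
) -> bool:
    """Whether the speculative-capable SequentialGenerator is selected.

    Assumption: the single-device path (group is None) already used it.
    """
    return group is None or coupled_drafter_dispatchable
```

With a TP group present, `effective_draft_mode("model", group, True)` keeps speculation on, while the same call with `False` returns `"none"` as before this PR.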

Correctness: each TP rank's __call__ reduces its output to the full hidden state (via the in-layer ShardedToAllLinear / ShardedMoE all-sums), so the replicated drafter on every rank consumes an identical hidden state and produces identical draft tokens and bonus samples under the shared mx.random.seed(seed) set at the top of each generation step. The inline docstring and the report explain the three lockstep guarantees. Validated end-to-end on Qwen 3.5 122B-A10B over JACCL/RDMA across wc-smbp + wc-smbpt.
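
The lockstep argument can be illustrated without MLX. In this sketch two simulated ranks hold partial outputs, an all-sum gives each rank the same full hidden state, and a replicated toy drafter seeded identically samples the same tokens on both. Python's `random.Random(seed)` stands in for `mx.random.seed(seed)`, and all names are hypothetical.

```python
import random


def all_sum(partials: list[list[float]]) -> list[list[float]]:
    """Every rank receives the element-wise sum of all ranks' partial outputs."""
    total = [sum(values) for values in zip(*partials)]
    return [list(total) for _ in partials]


def draft_tokens(hidden: list[float], seed: int, k: int) -> list[int]:
    """Toy replicated drafter: samples k token ids weighted by the hidden state."""
    rng = random.Random(seed)  # stand-in for the shared mx.random.seed(seed)
    weights = [abs(h) + 1e-3 for h in hidden]
    return rng.choices(range(len(hidden)), weights=weights, k=k)


# Two ranks, each holding only its shard's partial contribution.
rank_partials = [[0.1, 0.7, 0.2, 0.0], [0.3, 0.1, 0.4, 0.2]]
full_hidden_per_rank = all_sum(rank_partials)

# Same hidden state + same seed on every rank gives identical drafts,
# so the ranks never diverge on which tokens were proposed.
drafts_per_rank = [draft_tokens(h, seed=1234, k=4) for h in full_hidden_per_rank]
```

If either ingredient is missing (a rank sees a different hidden state, or seeds differ), the ranks could propose different tokens and the tensor-parallel forward would desynchronise.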

Loader & download fixes

  • Coupled drafter loader recognises the ModelCard.coupled_drafter declaration, downloads via the existing peer-aware shard downloader, and gates on mlx-vlm >= 0.5.0 being installed.
  • Single-file safetensors.index.json bootstrap: DFlash drafters that ship with just model.safetensors no longer trip the shard downloader (the index is synthesized).
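
One way to synthesize such an index is to read the tensor names from the safetensors header (an 8-byte little-endian length followed by a JSON header) and map each name to the single weights file. The sketch below assumes the Hugging Face `model.safetensors.index.json` layout with `metadata.total_size` and `weight_map`; the function names are hypothetical and the PR's actual bootstrap may build the index differently.

```python
import json
import struct
from pathlib import Path


def read_safetensors_header(path: Path) -> dict:
    """Parse the JSON header at the start of a .safetensors file."""
    with path.open("rb") as f:
        (header_len,) = struct.unpack("<Q", f.read(8))  # little-endian u64
        return json.loads(f.read(header_len))


def synthesize_index(model_dir: Path, weights_name: str = "model.safetensors") -> Path:
    """Write an index for a single-file checkpoint if none exists yet."""
    index_path = model_dir / "model.safetensors.index.json"
    if index_path.exists():
        return index_path

    header = read_safetensors_header(model_dir / weights_name)
    # "__metadata__" is an optional free-form entry, not a tensor.
    tensors = {k: v for k, v in header.items() if k != "__metadata__"}
    total_size = sum(
        end - begin for begin, end in (t["data_offsets"] for t in tensors.values())
    )
    index = {
        "metadata": {"total_size": total_size},
        # Every tensor lives in the one file the drafter ships with.
        "weight_map": {name: weights_name for name in sorted(tensors)},
    }
    index_path.write_text(json.dumps(index, indent=2))
    return index_path
```

A downloader that expects an index can then treat the single-file drafter like a one-shard checkpoint.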

Model cards

  • All 8 Gemma 4 31B / 26B-A4B cards (4/6/8bit + bf16) declare the MTP coupled_drafter.
  • New cards: mlx-community/Qwen3.5-4B-MLX-8bit, mlx-community/Qwen3.5-122B-A10B-8bit.
  • Updated cards: mlx-community/Qwen3.6-27B-8bit, mlx-community/Qwen3.6-35B-A3B-8bit (point at z-lab DFlash drafters).

In-process drafter tuning, asymmetric pipeline, production hardening

The full speculative-decoding stack also lands the prior team-wcv iterations:

  • K, warmup, KV-cache, n-gram strategy tuning knobs.
  • Asymmetric pipelined DrafterRunner over mx.distributed + socket transports + concurrency for uneven-memory clusters (wc-bmbp 48 GB peers paired with the 128 GB hosts is the canonical case).
  • Resilience, TP fanout, telemetry, bench harness.

Tests

  • Drafter abstraction + tuning: covered.
  • Gemma 4 MTP: test_gemma4_mtp_hooks.py + dispatch tests.
  • Qwen 3.5 DFlash: test_qwen3_5_dflash_hooks.py (captured-hidden-state shapes, GdnState rollback semantics, 1-D bonus-sample fix).
  • DFlash dispatch: test_coupled_drafter_dflash_dispatch.py.
  • Multi-device dispatch: test_coupled_drafter_multi_device.py — 8 tests pinning loader fallback policy, dispatch eligibility, and the critical conditionals in generate.py + builder.py.

Full suite (team-wcv branch): 1056 passing, basedpyright strict 0 errors project-wide, ruff check clean, ruff format clean.

Bench results

Per-request raw JSON + scenario tables checked in under bench/results/{mtp,dflash}/. Two A/B reports document methodology and headline numbers:

  • bench/results/mtp/REPORT.md — Gemma 4 26B-A4B + 31B at 4bit (per-scenario tables, headline-vs-peak discussion, routing-heuristic recommendation).
  • bench/results/dflash/REPORT.md — Qwen 3.5 4B + Qwen 3.6 27B + Qwen 3.6 35B-A3B + Qwen 3.5 122B-A10B TP2, with the multi-device correctness note.

…rate

Adds the surface-level support for speculative decoding via mlx_lm's
stream_generate(draft_model=...) on the single-device generation path:

- `ModelCard.drafter_model_id: ModelId | None`: declarative pointer to a
  drafter model that runners may load alongside the target. The drafter
  must share a tokenizer with the target; this is the caller's
  responsibility to enforce.
- `mlx_generate(draft_model=...)`: forwarded to `stream_generate` when
  `group is None` (single-device). Distributed-mode draft is dropped
  explicitly, since mlx_lm's speculative decoding does not yet plumb
  through tensor-parallel groups.
- Eight Gemma 4 model cards (gemma-4-26b-a4b-it and gemma-4-31b-it,
  4bit/6bit/8bit/bf16) declare gemma-4-e2b-it (matching quant) as their
  drafter. The Gemma 4 family shares a tokenizer across e2b/e4b/26b/31b,
  so e2b is a valid drafter.

Drafter loading at builder/runner bootstrap is intentionally not in this
patch — keeping the diff focused on the model-card schema and the
single-device generate plumbing. Wiring drafter download and
load_drafter() into MlxBuilder is straightforward follow-up work.

Tests:
- test_model_cards_drafter.py: 4 tests covering default-None,
  Gemma 4 31b/26b drafter pointers, and round-trip of an explicit value.
team-wcv force-pushed the feature/gemma4-mtp-coupled-drafter branch from 0ae6d25 to a76781e on May 10, 2026 22:06
team-wcv changed the title from "Drafter abstraction + Gemma 4 MTP coupled-drafter speculative decoding (extends #2065)" to "Drafter abstraction + Gemma 4 MTP + Qwen 3.5/3.6 DFlash + multi-device coupled-drafter speculative decoding (extends #2065)" on May 10, 2026
team-wcv changed the title from "Drafter abstraction + Gemma 4 MTP + Qwen 3.5/3.6 DFlash + multi-device coupled-drafter speculative decoding (extends #2065)" to "Drafter abstraction + Gemma 4 MTP + Qwen 3.5/3.6 DFlash + multi-device coupled-drafter speculative decoding" on May 10, 2026
team-wcv force-pushed the feature/gemma4-mtp-coupled-drafter branch 2 times, most recently from 0a18ae3 to 028ef03, on May 10, 2026 22:29
…e coupled drafter

Lands the full speculative-decoding stack on top of the
``drafter_model_id`` ModelCard foundation:

1. Drafter abstraction (``Drafter`` Protocol with ``stream`` /
   ``metrics`` / ``DraftMode``) and the ``CoupledModelDrafter``
   shim around mlx-vlm's ``_mtp_rounds`` / ``_dflash_rounds``.
   ``GenerationStats.drafter_kind`` ∈ {standard, mtp, dflash, ngram,
   none} so OpenAI ``CompletionTokensDetails`` + the dashboard
   surface which speculative path actually dispatched.

2. In-process drafter tuning: K, warmup, KV cache, n-gram strategy.

3. Asymmetric pipelined drafter for uneven-memory clusters --
   ``DrafterRunner`` + mx.distributed / socket transports + concurrency.

4. Production hardening: resilience, TP fanout, telemetry, bench.

5. Gemma 4 MTP coupled drafter (Phase 1-3). New
   ``ModelCard.coupled_drafter`` field; ``mlx-vlm>=0.5.0`` loader
   + per-kind target-side hook attachment
   (``attach_mtp_hooks`` for Gemma 4). 31B and 26B-A4B at all four
   quants declare the coupled MTP drafter.

   Headline: Gemma 4 31B 4bit + MTP drafter at T=0 jumps from
   13.8 t/s to 24.7 t/s with byte-identical output (single M3 Ultra).

6. Qwen 3.5 / 3.6 DFlash coupled drafter. Vendored
   ``forward_with_capture`` + ``rollback_speculative_cache`` for the
   hybrid attention / gated-delta-net architecture. The drafter
   consumes captured hidden states + an 11-tuple ``GdnState`` and
   replays them on rejection.

   Headlines (median over 10 runs per A/B side, T=0):
     Qwen 3.5 4B  8bit (dense, wc-smbp)        97.24 -> 404.38 t/s  4.16x
     Qwen 3.6 27B 8bit (dense, wc-smbpt)       14.98 ->  49.13 t/s  3.28x
     Qwen 3.6 35B-A3B 8bit (MoE, wc-smbpt)     87.70 -> 377.49 t/s  4.30x
     Qwen 3.5 122B-A10B 8bit (MoE, TP2 RDMA)   52.61 -> 159.00 t/s  3.02x

7. Multi-device coupled drafter dispatch (tensor-parallel). The
   previous loader hard-coded ``if group is None`` and the
   generator hard-coded ``draft_mode = "none"`` whenever
   ``group is not None``, so the coupled drafter never ran on TP
   placements -- exactly the regime 122B-class targets live in.
   Lifted via:

   * ``_try_load_collocated_drafter`` is now called from both the
     single-device and the symmetric multi-rank branches. The
     multi-device call passes ``allow_standard_drafter_fallback=
     False`` because the generator still can't dispatch standard
     drafters through ``group``, so a loaded standard drafter
     would only waste memory.
   * ``mlx_generate`` only forces ``draft_mode = "none"`` for
     multi-device when ``coupled_drafter_eligible`` is false.
   * ``builder.py`` selects ``SequentialGenerator``
     (speculative-capable) when ``coupled_drafter_dispatchable``
     is true, even with ``group is not None``.

   Correctness: each TP rank's per-rank ``__call__`` reduces its
   output to the full hidden state (via the in-layer
   ``ShardedToAllLinear`` / ``ShardedMoE`` all-sums), so the
   replicated drafter consumes an identical hidden state and
   produces identical draft tokens / bonus samples under the
   shared ``mx.random.seed(seed)`` set at the top of each
   generation step. 122B-A10B + JACCL/RDMA across two MacBook
   Pros validates the path end-to-end.

8. Single-file ``safetensors.index.json`` bootstrap. DFlash
   drafters that ship with just ``model.safetensors`` no longer
   trip the shard downloader.

9. Bench results + reports. ``bench/results/{mtp,dflash}/REPORT.md``
   document the A/B methodology and headline numbers. Raw
   per-request gen_tps + acceptance JSON committed for
   reproducibility.

Tests: 1056 passing, basedpyright 0 errors project-wide,
ruff clean.
team-wcv force-pushed the feature/gemma4-mtp-coupled-drafter branch from 028ef03 to 5dae97d on May 10, 2026 23:25