Drafter abstraction + Gemma 4 MTP + Qwen 3.5/3.6 DFlash + multi-device coupled-drafter speculative decoding #2079
Open
team-wcv wants to merge 2 commits into
…rate

Adds the surface-level support for speculative decoding via mlx_lm's stream_generate(draft_model=...) on the single-device generation path:

- `ModelCard.drafter_model_id: ModelId | None`: declarative pointer to a drafter model that runners may load alongside the target. The drafter must share a tokenizer with the target; this is the caller's responsibility to enforce.
- `mlx_generate(draft_model=...)`: forwarded to `stream_generate` when `group is None` (single-device). Distributed-mode draft is dropped explicitly, since mlx_lm's speculative decoding does not yet plumb through tensor-parallel groups.
- Eight Gemma 4 model cards (gemma-4-26b-a4b-it and gemma-4-31b-it, 4bit/6bit/8bit/bf16) declare gemma-4-e2b-it (matching quant) as their drafter. The Gemma 4 family shares a tokenizer across e2b/e4b/26b/31b, so e2b is a valid drafter.

Drafter loading at builder/runner bootstrap is intentionally not in this patch, which keeps the diff focused on the model-card schema and the single-device generate plumbing. Wiring drafter download and load_drafter() into MlxBuilder is straightforward follow-up work.

Tests:

- test_model_cards_drafter.py: 4 tests covering default-None, Gemma 4 31b/26b drafter pointers, and round-trip of an explicit value.
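A minimal sketch of the two pieces this commit describes, under simplifying assumptions. The `ModelCard` dataclass, the `stream_generate_kwargs` helper, and the model ids below are illustrative stand-ins rather than the code in this patch; only the `drafter_model_id` field name and the `group is None` rule are taken from the description above.

```python
from dataclasses import dataclass
from typing import Any, Dict, Optional


@dataclass(frozen=True)
class ModelCard:
    """Simplified stand-in for the real model card schema."""
    model_id: str
    # Declarative pointer to a drafter; None means "no drafter declared".
    drafter_model_id: Optional[str] = None


def stream_generate_kwargs(draft_model: Optional[Any], group: Optional[Any]) -> Dict[str, Any]:
    """Build the extra kwargs that would be forwarded to stream_generate.

    The draft model is forwarded only on the single-device path
    (group is None); with a distributed group it is dropped explicitly.
    """
    kwargs: Dict[str, Any] = {}
    if draft_model is not None and group is None:
        kwargs["draft_model"] = draft_model
    return kwargs


# Hypothetical ids, shown only to illustrate a "matching quant" pairing.
card = ModelCard(model_id="gemma-4-31b-it-4bit", drafter_model_id="gemma-4-e2b-it-4bit")
single_device = stream_generate_kwargs(draft_model="drafter", group=None)
distributed = stream_generate_kwargs(draft_model="drafter", group=object())
```

Here `single_device` carries the drafter and `distributed` is empty, which mirrors the "dropped explicitly" behaviour. Tokenizer compatibility is not checked anywhere in the sketch, matching the note that enforcement is left to the caller.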
…e coupled drafter
Lands the full speculative-decoding stack on top of the
``drafter_model_id`` ModelCard foundation:
1. Drafter abstraction (``Drafter`` Protocol with ``stream`` /
``metrics`` / ``DraftMode``) and the ``CoupledModelDrafter``
shim around mlx-vlm's ``_mtp_rounds`` / ``_dflash_rounds``.
``GenerationStats.drafter_kind`` ∈ {standard, mtp, dflash, ngram,
none} so OpenAI ``CompletionTokensDetails`` + the dashboard
surface which speculative path actually dispatched.
2. In-process drafter tuning: K, warmup, KV cache, n-gram strategy.
3. Asymmetric pipelined drafter for uneven-memory clusters --
``DrafterRunner`` + mx.distributed / socket transports + concurrency.
4. Production hardening: resilience, TP fanout, telemetry, bench.
5. Gemma 4 MTP coupled drafter (Phase 1-3). New
``ModelCard.coupled_drafter`` field; ``mlx-vlm>=0.5.0`` loader
+ per-kind target-side hook attachment
(``attach_mtp_hooks`` for Gemma 4). 31B and 26B-A4B at all four
quants declare the coupled MTP drafter.
Headline: Gemma 4 31B 4bit + MTP drafter at T=0 jumps from
13.8 t/s to 24.7 t/s with byte-identical output (single M3 Ultra).
6. Qwen 3.5 / 3.6 DFlash coupled drafter. Vendored
``forward_with_capture`` + ``rollback_speculative_cache`` for the
hybrid attention / gated-delta-net architecture. The drafter
consumes captured hidden states + an 11-tuple ``GdnState`` and
replays them on rejection.
Headlines (median over 10 runs per A/B side, T=0):
  Qwen 3.5 4B 8bit        (dense, wc-smbp)    97.24 -> 404.38 t/s   4.16x
  Qwen 3.6 27B 8bit       (dense, wc-smbpt)   14.98 ->  49.13 t/s   3.28x
  Qwen 3.6 35B-A3B 8bit   (MoE, wc-smbpt)     87.70 -> 377.49 t/s   4.30x
  Qwen 3.5 122B-A10B 8bit (MoE, TP2 RDMA)     52.61 -> 159.00 t/s   3.02x
7. Multi-device coupled drafter dispatch (tensor-parallel). The
previous loader hard-coded ``if group is None`` and the
generator hard-coded ``draft_mode = "none"`` whenever
``group is not None``, so the coupled drafter never ran on TP
placements -- exactly the regime 122B-class targets live in.
Lifted via:
* ``_try_load_collocated_drafter`` is now called from both the
single-device and the symmetric multi-rank branches. The
multi-device call passes ``allow_standard_drafter_fallback=
False`` because the generator still can't dispatch standard
drafters through ``group``, so a loaded standard drafter
would only waste memory.
* ``mlx_generate`` only forces ``draft_mode = "none"`` for
multi-device when ``coupled_drafter_eligible`` is false.
* ``builder.py`` selects ``SequentialGenerator``
(speculative-capable) when ``coupled_drafter_dispatchable``
is true, even with ``group is not None``.
Correctness: each TP rank's per-rank ``__call__`` reduces its
output to the full hidden state (via the in-layer
``ShardedToAllLinear`` / ``ShardedMoE`` all-sums), so the
replicated drafter consumes an identical hidden state and
produces identical draft tokens / bonus samples under the
shared ``mx.random.seed(seed)`` set at the top of each
generation step. 122B-A10B + JACCL/RDMA across two MacBook
Pros validates the path end-to-end.
8. Single-file ``safetensors.index.json`` bootstrap. DFlash
drafters that ship with just ``model.safetensors`` no longer
trip the shard downloader.
9. Bench results + reports. ``bench/results/{mtp,dflash}/REPORT.md``
document the A/B methodology and headline numbers. Raw
per-request gen_tps + acceptance JSON committed for
reproducibility.
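Why T=0 output can stay byte-identical with a drafter attached (item 5) can be illustrated with a toy greedy draft-and-verify loop. Everything in this sketch is hypothetical: `DrafterMetrics`, `ToyDrafter`, the signature of `stream`, and the integer "models" are stand-ins, since the commit only names `stream`, `metrics`, and the `drafter_kind` values. A real implementation scores all K draft tokens in one target forward pass and rewinds caches on rejection; the sketch verifies token by token to keep the logic visible.

```python
from dataclasses import dataclass
from typing import Callable, List, Literal, Protocol

DrafterKind = Literal["standard", "mtp", "dflash", "ngram", "none"]
NextToken = Callable[[List[int]], int]  # greedy (T=0) next-token function


@dataclass
class DrafterMetrics:
    proposed: int = 0
    accepted: int = 0


class Drafter(Protocol):
    # Hypothetical shape; the real Protocol may differ.
    kind: DrafterKind
    metrics: DrafterMetrics

    def stream(self, context: List[int], k: int) -> List[int]: ...


class ToyDrafter:
    kind: DrafterKind = "standard"

    def __init__(self, guess: NextToken) -> None:
        self._guess = guess
        self.metrics = DrafterMetrics()

    def stream(self, context: List[int], k: int) -> List[int]:
        ctx, draft = list(context), []
        for _ in range(k):
            draft.append(self._guess(ctx))
            ctx.append(draft[-1])
        self.metrics.proposed += k
        return draft


def greedy_decode(target: NextToken, prompt: List[int], max_new: int) -> List[int]:
    tokens = list(prompt)
    for _ in range(max_new):
        tokens.append(target(tokens))
    return tokens[len(prompt):]


def speculative_greedy(target: NextToken, drafter: Drafter, prompt: List[int],
                       max_new: int, k: int = 4) -> List[int]:
    tokens = list(prompt)
    while len(tokens) - len(prompt) < max_new:
        ctx = list(tokens)
        for tok in drafter.stream(tokens, k):
            expected = target(ctx)
            ctx.append(expected)      # always keep the target's own token
            if expected != tok:       # first mismatch: discard the rest of the draft
                break
            drafter.metrics.accepted += 1
        else:
            ctx.append(target(ctx))   # every draft accepted: add one bonus token
        tokens = ctx
    return tokens[len(prompt):len(prompt) + max_new]


def target_model(ctx: List[int]) -> int:
    return (sum(ctx) * 31 + len(ctx)) % 50


def draft_model(ctx: List[int]) -> int:
    # Deliberately wrong at some positions to exercise rejection.
    return 0 if len(ctx) % 3 == 0 else target_model(ctx)
```

Because every emitted token is the target's own greedy choice for its context, the speculative output equals plain greedy decoding regardless of how often the drafter is wrong; the drafter only changes how many target steps are needed.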
Tests: 1056 passing, basedpyright 0 errors project-wide,
ruff clean.
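The multi-device gating in item 7 reduces to two small conditionals. The sketch below is an assumption about their shape, not the code in `generate.py` or `utils_mlx.py`: the helper names `effective_draft_mode` and `allow_standard_fallback` are hypothetical, while `coupled_drafter_eligible`, `allow_standard_drafter_fallback`, and the draft-mode values come from the commit message and PR description.

```python
from typing import Literal, Optional

DraftMode = Literal["model", "pipelined", "none"]


def effective_draft_mode(requested: DraftMode, group: Optional[object],
                         coupled_drafter_eligible: bool) -> DraftMode:
    """Force "none" only for multi-device runs without an eligible coupled drafter."""
    if group is not None and not coupled_drafter_eligible:
        return "none"
    return requested


def allow_standard_fallback(group: Optional[object]) -> bool:
    """Value a loader call might pass as allow_standard_drafter_fallback.

    On multi-device placements a standard drafter cannot be dispatched
    through the group, so loading one would only waste memory.
    """
    return group is None


tp_group = object()  # stand-in for a tensor-parallel group handle
```

Before this change the first conditional was effectively `group is not None` alone, which is why coupled drafters never ran on TP placements.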
Lands the full speculative-decoding stack on top of the `drafter_model_id` ModelCard foundation: drafter abstraction, in-process tuning, asymmetric pipelined drafter, production hardening, Gemma 4 multi-token-prediction (MTP), Qwen 3.5 / 3.6 DFlash coupled drafters, and multi-device tensor-parallel coupled-drafter dispatch.

This PR is self-contained: the foundation patch from the now-closed #2065 (`bdf1a12d`) is included here as the first commit (`f383ef0a`, same patch rebased onto the latest `main`). The two-commit structure is preserved for review clarity (foundation + bundle).

Headlines

All numbers are from MacBook Pros with Apple M5 Max. `wc-smbp` and `wc-smbpt` carry 128 GB unified memory each; `wc-bmbp` carries 48 GB. Generation tokens/s are the median of 10 runs per A/B side, T=0, across 5 scenarios (`short_repetitive`, `code_completion`, `creative_prose`, `factual_qa`, `long_context_summary`).

Single-device DFlash A/B (Qwen 3.5 / 3.6)

| Target | Type, host | Baseline t/s | DFlash t/s | Speedup |
| --- | --- | --- | --- | --- |
| Qwen 3.5 4B 8bit | dense, wc-smbp | 97.24 | 404.38 | 4.16x |
| Qwen 3.6 27B 8bit | dense, wc-smbpt | 14.98 | 49.13 | 3.28x |
| Qwen 3.6 35B-A3B 8bit | MoE, wc-smbpt | 87.70 | 377.49 | 4.30x |

Multi-device DFlash A/B (Qwen 3.5 122B-A10B, 2-node TP + JACCL/RDMA)

| Target | Type, placement | Baseline t/s | DFlash t/s | Speedup |
| --- | --- | --- | --- | --- |
| Qwen 3.5 122B-A10B 8bit | MoE, TP2 RDMA | 52.61 | 159.00 | 3.02x |

DFlash holds at or above 3.02x at every scale tested, including the 122B-A10B MoE running across two M5 Max MacBook Pros with tensor parallelism over a Thunderbolt-bridge RDMA edge.
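The speedup column is simply the drafter median divided by the baseline median. As a quick check, the snippet below recomputes it from the headline medians quoted in the commit message; the dictionary layout and the `speedup` helper are illustrative and are not part of the bench tooling.

```python
# (baseline t/s, DFlash t/s) medians as quoted in the commit message headlines.
HEADLINES = {
    "Qwen 3.5 4B 8bit": (97.24, 404.38),
    "Qwen 3.6 27B 8bit": (14.98, 49.13),
    "Qwen 3.6 35B-A3B 8bit": (87.70, 377.49),
    "Qwen 3.5 122B-A10B 8bit (TP2)": (52.61, 159.00),
}


def speedup(baseline_tps: float, drafter_tps: float) -> float:
    """Ratio of generation throughput with the drafter to throughput without it."""
    return drafter_tps / baseline_tps


speedups = {name: round(speedup(b, d), 2) for name, (b, d) in HEADLINES.items()}
worst_case = min(speedups.values())

for name, ratio in speedups.items():
    print(f"{name}: {ratio:.2f}x")
```

The smallest ratio is the two-node 122B-A10B run, which is where the 3.02x floor comes from.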
Single-device MTP A/B (Gemma 4)

Median of 2 runs per scenario on `wc-smbp` (M5 Max, 128 GB). The all-scenario median is dragged down by a single `long_context_summary` outlier where the assistant head goes out-of-distribution; the high-accept scenarios consistently win double-digit speedups.

Full scenario tables + the routing-heuristic discussion live in `bench/results/mtp/REPORT.md`. The DFlash gains are much larger than MTP because the DFlash drafter is architecturally tied to the gated-delta-net targets (every layer's hidden state + the recurrent SSM state are captured + replayed) whereas Gemma 4's MTP assistant only sees the final hidden state.

What's in this PR
Drafter abstraction (`src/exo/worker/engines/mlx/generator/drafter.py`)

- `Drafter` Protocol with `stream`, `metrics`, and `DraftMode` resolution (`model` / `pipelined` / `none`).
- `EXO_DRAFT_MODE` env override; K + warmup + KV-cache + n-gram tuning knobs.
- `resolve_draft_mode` keeps `pipelined` mode gated on a standard sibling drafter so coupled-only runs never deadlock.
- `GenerationStats.drafter_kind: Literal["standard", "mtp", "dflash", "ngram", "none"]` so OpenAI-compatible `CompletionTokensDetails` and the dashboard can surface which speculative path actually dispatched.

Coupled drafters (`generator/coupled_drafter.py`, `vendor/`)

- `CoupledModelDrafter` shim around mlx-vlm's `_mtp_rounds` / `_dflash_rounds` with a `CoupledDrafterKind = Literal["mtp", "dflash"]` dispatch table. Both kinds are wired and ship in `DISPATCHABLE_COUPLED_DRAFTER_KINDS = frozenset({"mtp", "dflash"})`.
- `attach_mtp_hooks` attaches `forward_with_capture` + `rollback_speculative_cache` to the wrapped Gemma 4 text model. The target adapter captures the final hidden state per draft round and replays KV cache rewinds on rejection.
- `qwen3_5_dflash_hooks` for the hybrid attention + gated-delta-net architecture. The DFlash drafter consumes captured hidden states at specific layer indices plus the 11-tuple `GdnState` (last_recurrent_state + window kv buffers + delta caches); rollback replays both KV pages and SSM states. Works for both dense (Qwen 3.6 27B) and MoE (Qwen 3.6 35B-A3B, 256 experts × 8 active) targets because the existing `qwen3_5.DecoderLayer.SparseMoeBlock` already handles routing, so no MoE-specific vendor work is needed.

Multi-device tensor-parallel coupled-drafter dispatch
The previous loader hard-coded `if group is None` and the generator hard-coded `draft_mode = "none"` whenever `group is not None`, so the coupled drafter never ran on TP placements, which is exactly the regime 122B-class targets live in. The restriction is lifted via:

- `_try_load_collocated_drafter` is now called from both the single-device and the symmetric multi-rank branches in `utils_mlx.py`. The multi-device call passes `allow_standard_drafter_fallback=False` because the generator still can't dispatch standard drafters through `group`, so a loaded standard drafter would only waste memory.
- `mlx_generate` only forces `draft_mode = "none"` for multi-device when `coupled_drafter_eligible` is false.
- `builder.py` selects `SequentialGenerator` (speculative-capable) when `coupled_drafter_dispatchable` is true, even with `group is not None`.

Correctness: each TP rank's per-rank `__call__` reduces its output to the full hidden state (via the in-layer `ShardedToAllLinear` / `ShardedMoE` all-sums), so the replicated drafter consumes an identical hidden state and produces identical draft tokens + bonus samples under the shared `mx.random.seed(seed)` set at the top of each generation step. The inline docstring + report explain the three lockstep guarantees. Validated end-to-end on Qwen 3.5 122B-A10B over JACCL/RDMA across `wc-smbp` + `wc-smbpt`.

Loader & download fixes
- The coupled-drafter loader reads the `ModelCard.coupled_drafter` declaration, downloads via the existing peer-aware shard downloader, and gates on `mlx-vlm >= 0.5.0` being installed.
- Single-file `safetensors.index.json` bootstrap: DFlash drafters that ship with just `model.safetensors` no longer trip the shard downloader (the index is synthesized).

Model cards

- Gemma 4 31B and 26B-A4B cards (all four quants) declare `coupled_drafter`.
- Qwen 3.5: `mlx-community/Qwen3.5-4B-MLX-8bit`, `mlx-community/Qwen3.5-122B-A10B-8bit`.
- Qwen 3.6: `mlx-community/Qwen3.6-27B-8bit`, `mlx-community/Qwen3.6-35B-A3B-8bit` (point at z-lab DFlash drafters).

In-process drafter tuning, asymmetric pipeline, production hardening
The full speculative-decoding stack also lands the prior team-wcv iterations:

- `DrafterRunner` over `mx.distributed` + socket transports + concurrency for uneven-memory clusters (`wc-bmbp` 48 GB peers paired with the 128 GB hosts is the canonical case).

Tests
- `test_gemma4_mtp_hooks.py` + dispatch tests.
- `test_qwen3_5_dflash_hooks.py` (captured-hidden-state shapes, GdnState rollback semantics, 1-D bonus-sample fix).
- `test_coupled_drafter_dflash_dispatch.py`.
- `test_coupled_drafter_multi_device.py`: 8 tests pinning loader fallback policy, dispatch eligibility, and the critical conditionals in `generate.py` + `builder.py`.

Full suite (team-wcv branch): 1056 passing, `basedpyright` strict 0 errors project-wide, `ruff check` clean, `ruff format` clean.

Bench results
Per-request raw JSON + scenario tables are checked in under `bench/results/{mtp,dflash}/`. Two A/B reports document methodology and headline numbers:

- `bench/results/mtp/REPORT.md`: Gemma 4 26B-A4B + 31B at 4bit (per-scenario tables, headline-vs-peak discussion, routing-heuristic recommendation).
- `bench/results/dflash/REPORT.md`: Qwen 3.5 4B + Qwen 3.6 27B + Qwen 3.6 35B-A3B + Qwen 3.5 122B-A10B TP2, with the multi-device correctness note.