Add Gemma 4 drafter support: ModelCard.drafter_model_id + mlx_generate plumbing #2065
team-wcv wants to merge 1 commit into
…rate

Adds the surface-level support for speculative decoding via mlx_lm's stream_generate(draft_model=...) on the single-device generation path:

- `ModelCard.drafter_model_id: ModelId | None`: declarative pointer to a drafter model that runners may load alongside the target. The drafter must share a tokenizer with the target; this is the caller's responsibility to enforce.
- `mlx_generate(draft_model=...)`: forwarded to `stream_generate` when `group is None` (single-device). In distributed mode the draft model is dropped explicitly, since mlx_lm's speculative decoding does not yet plumb through tensor-parallel groups.
- Eight Gemma 4 model cards (gemma-4-26b-a4b-it and gemma-4-31b-it, 4bit/6bit/8bit/bf16) declare gemma-4-e2b-it (matching quant) as their drafter. The Gemma 4 family shares a tokenizer across e2b/e4b/26b/31b, so e2b is a valid drafter.

Drafter loading at builder/runner bootstrap is intentionally not in this patch, which keeps the diff focused on the model-card schema and the single-device generate plumbing. Wiring drafter download and load_drafter() into MlxBuilder is straightforward follow-up work.

Tests:

- test_model_cards_drafter.py: 4 tests covering default-None, Gemma 4 31b/26b drafter pointers, and round-trip of an explicit value.
Thanks for this contribution! We're also looking into MTP; we'll address the issues you've mentioned at the same time and use something like this.
Closing as superseded by #2079 (now self-contained: includes this PR's commit as its first commit, plus the full drafter abstraction + Gemma 4 MTP + Qwen 3.5/3.6 DFlash + multi-device tensor-parallel coupled-drafter stack). The first commit of #2079 ( If you'd prefer the smaller 11-file foundation diff for an initial sanity pass before the larger one, the commit's still right there as the first commit of #2079 and easy to cherry-pick onto
Superseded by #2079.
Motivation
Google's Gemma 4 May 2026 update shipped MTP drafters — small drafter weights that pair with the larger instruct models for speculative decoding via mlx_lm's `stream_generate(draft_model=...)` API. mlx_lm has supported this for several releases; exo currently doesn't expose it.
This PR adds the surface-level support to declare drafters in model cards and plumb them through the single-device generation path. It's intentionally scoped small so it can land quickly; drafter loading at builder/runner bootstrap and distributed-mode speculative decoding are both straightforward follow-up work.
Related: #1685 (David's MTP exploration). David's `david/speculative-mtp` branch has an alternative direction (custom MTP module + Qwen3.5 MoE kernels) that's much larger in scope; this PR is the minimal mlx_lm-native path.
Changes
`shared/models/model_cards.py` — new optional field
```python
class ModelCard(FrozenModel):
...
drafter_model_id: ModelId | None = None
```
When set, runners may load the named drafter model alongside the target and pass it as `draft_model` to mlx_lm's `stream_generate`. The drafter MUST share a tokenizer with the target; that contract is the responsibility of the caller (i.e. whoever sets the field on a card).
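As a minimal sketch of how an optional, default-`None` field like this behaves, the example below uses a frozen dataclass as a stand-in for exo's `FrozenModel` (whose implementation is not shown here). `ModelCardSketch` and the short model names are illustrative assumptions, not the real exo types or registry IDs.

```python
from dataclasses import asdict, dataclass
from typing import NewType, Optional

# Stand-in for exo's ModelId type (assumption: a string-like identifier).
ModelId = NewType("ModelId", str)


@dataclass(frozen=True)
class ModelCardSketch:
    """Hypothetical, simplified model card; not exo's real ModelCard."""

    model_id: ModelId
    # Optional pointer to a drafter; None means "no speculative decoding".
    drafter_model_id: Optional[ModelId] = None


# A card that omits the field keeps the default.
plain_card = ModelCardSketch(model_id=ModelId("gemma-4-e2b-it-4bit"))

# A card that declares a drafter; whoever sets this is responsible for
# making sure both models share a tokenizer.
target_card = ModelCardSketch(
    model_id=ModelId("gemma-4-31b-it-4bit"),
    drafter_model_id=ModelId("gemma-4-e2b-it-4bit"),
)

# Round-trip through a plain dict, as a serialization layer might do.
round_tripped = ModelCardSketch(**asdict(target_card))
```

Because the field defaults to `None`, existing cards that never mention a drafter continue to load unchanged.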
`worker/engines/mlx/generator/generate.py` — pipe `draft_model` to `stream_generate`
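The pass-through rule described in this PR (forward `draft_model` only when `group is None`, drop it otherwise) can be sketched roughly as follows. This is not the actual diff: `generate_sketch` and `stream_generate_stub` are hypothetical names, and the stub merely stands in for mlx_lm's `stream_generate` so the example runs without MLX installed.

```python
from typing import Any, Dict, Iterator, List, Optional


def stream_generate_stub(prompt: str, **kwargs: Any) -> Iterator[str]:
    """Stand-in for mlx_lm's stream_generate; only reports which mode ran."""
    mode = "speculative" if kwargs.get("draft_model") is not None else "plain"
    yield f"{mode}:{prompt}"


def generate_sketch(
    prompt: str,
    draft_model: Optional[object] = None,
    group: Optional[object] = None,
) -> List[str]:
    """Forward draft_model only on the single-device path (group is None)."""
    extra: Dict[str, Any] = {}
    if draft_model is not None and group is None:
        extra["draft_model"] = draft_model
    # In distributed mode the drafter is dropped rather than forwarded.
    return list(stream_generate_stub(prompt, **extra))


single_device = generate_sketch("hi", draft_model=object())
distributed = generate_sketch("hi", draft_model=object(), group=object())
no_drafter = generate_sketch("hi")
```

Dropping the drafter silently in distributed mode keeps callers from having to branch on topology, at the cost of the speedup not applying there.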
Model cards — 8 Gemma 4 large/instruct variants
Each of `gemma-4-26b-a4b-it` and `gemma-4-31b-it` (4bit / 6bit / 8bit / bf16) now declares the matching-quant `gemma-4-e2b-it` as its drafter. The Gemma 4 family shares one tokenizer across e2b / e4b / 26b / 31b, so e2b is a valid drafter for both 31b and 26b.
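The matching-quant pairing can be illustrated with a small lookup. The ID format `<base>-<quant>` is an assumption based on the names used elsewhere in this description (for example `gemma-4-31b-it-4bit`); real registry IDs may carry an organization prefix or differ in other ways, and `drafter_for` is a hypothetical helper.

```python
from typing import Dict, Optional

QUANTS = ("4bit", "6bit", "8bit", "bf16")
TARGET_BASES = ("gemma-4-26b-a4b-it", "gemma-4-31b-it")
DRAFTER_BASE = "gemma-4-e2b-it"

# Build the eight target -> drafter pairs, keeping the quantization identical.
DRAFTERS: Dict[str, str] = {
    f"{base}-{quant}": f"{DRAFTER_BASE}-{quant}"
    for base in TARGET_BASES
    for quant in QUANTS
}


def drafter_for(target_id: str) -> Optional[str]:
    """Return the matching-quant drafter ID, or None if no drafter is declared."""
    return DRAFTERS.get(target_id)
```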
Why It Works
mlx_lm's `stream_generate(model, tokenizer, prompt, draft_model=...)` already implements the verify-and-accept loop for speculative decoding when `draft_model` is supplied. exo's `mlx_generate` was already calling `stream_generate` for single-device decode; the only missing piece was the parameter pass-through.
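For readers unfamiliar with the verify-and-accept loop, here is a toy, greedy-only illustration of the general idea. It is not mlx_lm's implementation: the "models" are plain functions from a token list to the next token, and verification calls the target once per position, whereas a real implementation checks all drafted tokens in a single target forward pass. All names are hypothetical.

```python
from typing import Callable, List

NextToken = Callable[[List[int]], int]


def greedy_decode(model: NextToken, prompt: List[int], max_new: int) -> List[int]:
    """Baseline: ask the model for one token at a time."""
    tokens = list(prompt)
    for _ in range(max_new):
        tokens.append(model(tokens))
    return tokens[len(prompt):]


def speculative_decode(
    target: NextToken,
    draft: NextToken,
    prompt: List[int],
    max_new: int,
    num_draft: int = 3,
) -> List[int]:
    """Greedy speculative decoding: the draft proposes, the target verifies."""
    tokens = list(prompt)
    produced = 0
    while produced < max_new:
        # 1. The draft model proposes num_draft tokens autoregressively.
        proposal: List[int] = []
        for _ in range(num_draft):
            proposal.append(draft(tokens + proposal))
        # 2. The target checks each proposed token in order.
        accepted: List[int] = []
        for tok in proposal:
            expected = target(tokens + accepted)
            if expected != tok:
                accepted.append(expected)  # first mismatch: take the target's token
                break
            accepted.append(tok)
        else:
            # Every proposal matched, so the target contributes one bonus token.
            accepted.append(target(tokens + accepted))
        accepted = accepted[: max_new - produced]
        tokens.extend(accepted)
        produced += len(accepted)
    return tokens[len(prompt):]


def toy_target(ctx: List[int]) -> int:
    return (ctx[-1] * 3 + 1) % 7


def toy_draft(ctx: List[int]) -> int:
    # Agrees with the target except after token 4, to force some rejections.
    return 0 if ctx[-1] == 4 else toy_target(ctx)


baseline = greedy_decode(toy_target, [2], 12)
speculative = speculative_decode(toy_target, toy_draft, [2], 12)
```

In this greedy variant every emitted token is one the target would have chosen anyway, so the output matches plain target decoding; the drafter only affects how many tokens are settled per verification round.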
Single-device behavior is the primary win for users running exo as a local serving stack (one Apple Silicon device hosting both target and drafter). Distributed (tensor-parallel) speculative decoding is a much larger lift and not in this PR.
What is intentionally NOT in this PR
Test Plan
Automated Testing
```
src/exo/shared/tests/test_model_cards_drafter.py ....
=== 4 passed in 0.16s ===
```
`uv run basedpyright` and `uv run ruff check` both clean.
Manual Testing
This PR is schema + plumbing only. Manual validation requires the follow-up loader PR. We've tested an internal version of the full stack (loader + this plumbing) on a single M5 Max 128GB running `gemma-4-31b-it-4bit` with the `gemma-4-e2b-it-4bit` drafter and observed the expected ~2x decode-tps speedup on chat-completion prompts.
Notes for reviewers