Add Gemma 4 drafter support: ModelCard.drafter_model_id + mlx_generate plumbing #2065

Closed
team-wcv wants to merge 1 commit into exo-explore:main from team-wcv:feature/gemma4-drafter-support

@team-wcv
Contributor

@team-wcv team-wcv commented May 7, 2026

Motivation

Google's Gemma 4 May 2026 update shipped MTP drafters — small drafter weights that pair with the larger instruct models for speculative decoding via mlx_lm's `stream_generate(draft_model=...)` API. mlx_lm has supported this for several releases; exo currently doesn't expose it.

This PR adds the surface-level support to declare drafters in model cards and plumb them through the single-device generation path. It's intentionally scoped small so it can land quickly; drafter loading at builder/runner bootstrap and distributed-mode speculative decoding are both straightforward follow-up work.

Related: #1685 (David's MTP exploration). David's `david/speculative-mtp` branch has an alternative direction (custom MTP module + Qwen3.5 MoE kernels) that's much larger in scope; this PR is the minimal mlx_lm-native path.

Changes

`shared/models/model_cards.py` — new optional field

```python
class ModelCard(FrozenModel):
    ...
    drafter_model_id: ModelId | None = None
```

When set, runners may load the named drafter model alongside the target and pass it as `draft_model` to mlx_lm's `stream_generate`. The drafter MUST share a tokenizer with the target; that contract is the responsibility of the caller (i.e. whoever sets the field on a card).
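As a rough illustration of the field's semantics, the sketch below uses a frozen standard-library dataclass as a stand-in for exo's `FrozenModel`. The stand-in `ModelId` alias, the `resolve_drafter` helper, the `CARDS` registry, and the short model ids are all assumptions made for the example; none of them is part of this PR.

```python
from __future__ import annotations

from dataclasses import dataclass

# Stand-in for exo's ModelId type (assumption: treated as a plain string here).
ModelId = str


@dataclass(frozen=True)
class ModelCard:
    """Simplified stand-in for exo's ModelCard; only the fields needed here."""

    model_id: ModelId
    drafter_model_id: ModelId | None = None  # optional; None means "no drafter"


def resolve_drafter(card: ModelCard, cards: dict[ModelId, ModelCard]) -> ModelCard | None:
    """Hypothetical helper: return the drafter's card, or None if none is declared."""
    if card.drafter_model_id is None:
        return None
    # A KeyError here means a card points at a drafter that is not registered.
    return cards[card.drafter_model_id]


# Hypothetical registry; the real cards live in shared/models/model_cards.py.
CARDS = {
    "gemma-4-e2b-it-4bit": ModelCard("gemma-4-e2b-it-4bit"),
    "gemma-4-31b-it-4bit": ModelCard(
        "gemma-4-31b-it-4bit", drafter_model_id="gemma-4-e2b-it-4bit"
    ),
}
```

Because the field defaults to `None`, existing cards need no changes, and the helper does nothing to check the shared-tokenizer contract, which stays with whoever sets the field.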

`worker/engines/mlx/generator/generate.py` — pipe `draft_model` to `stream_generate`

  • Add `draft_model: Model | None = None` parameter to `mlx_generate`.
  • Forward it to `stream_generate(draft_model=...)` only when `group is None` (single-device path). When a distributed group is present, the drafter is explicitly dropped rather than forwarded: mlx_lm's speculative decoding does not yet plumb through tensor-parallel groups, so passing a drafter alongside a non-trivial group would currently be a no-op at best and an error at worst.
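The forwarding rule in the two bullets above can be sketched as follows. This is not the actual `mlx_generate` body: `mlx_generate_sketch` and `stub_stream_generate` are hypothetical names, and the stub only stands in for mlx_lm's `stream_generate` so the example runs without MLX installed.

```python
from typing import Any, Callable, Dict, Iterator, Optional


def stub_stream_generate(model: Any, tokenizer: Any, prompt: str, **kwargs: Any) -> Iterator[str]:
    """Stub standing in for mlx_lm's stream_generate; it only reports its kwargs."""
    yield "speculative" if kwargs.get("draft_model") is not None else "plain"


def mlx_generate_sketch(
    model: Any,
    tokenizer: Any,
    prompt: str,
    group: Optional[Any] = None,
    draft_model: Optional[Any] = None,
    stream_fn: Callable[..., Iterator[str]] = stub_stream_generate,
) -> Iterator[str]:
    """Forward draft_model only on the single-device path (group is None)."""
    extra: Dict[str, Any] = {}
    if draft_model is not None and group is None:
        extra["draft_model"] = draft_model
    # With a distributed group the drafter is dropped on purpose, not forwarded.
    yield from stream_fn(model, tokenizer, prompt, **extra)
```

Building the keyword arguments conditionally means the distributed path calls `stream_generate` exactly as it did before the change.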

Model cards — 8 Gemma 4 large/instruct variants

Each of `gemma-4-26b-a4b-it` and `gemma-4-31b-it` (4bit / 6bit / 8bit / bf16) now declares the matching-quant `gemma-4-e2b-it` as its drafter. The Gemma 4 family shares one tokenizer across e2b / e4b / 26b / 31b, so e2b is a valid drafter for both 31b and 26b.

| Target | Drafter |
| --- | --- |
| `gemma-4-31b-it-{4bit,6bit,8bit,bf16}` | `gemma-4-e2b-it-{4bit,6bit,8bit,bf16}` |
| `gemma-4-26b-a4b-it-{4bit,6bit,8bit,bf16}` | `gemma-4-e2b-it-{4bit,6bit,8bit,bf16}` |
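The matching-quant pairing can be expressed as a small lookup. The sketch below is illustrative only: `drafter_for` and the constants are hypothetical names, and the short ids mirror the ones used in this description rather than the full ids stored on the cards.

```python
from typing import Optional

QUANTS = ("4bit", "6bit", "8bit", "bf16")
TARGET_FAMILIES = ("gemma-4-31b-it", "gemma-4-26b-a4b-it")
DRAFTER_FAMILY = "gemma-4-e2b-it"


def drafter_for(target_id: str) -> Optional[str]:
    """Return the matching-quant e2b drafter id, or None for non-target models."""
    for family in TARGET_FAMILIES:
        for quant in QUANTS:
            if target_id == f"{family}-{quant}":
                return f"{DRAFTER_FAMILY}-{quant}"
    return None


# The eight target ids covered by the pairing.
ALL_TARGETS = [f"{family}-{quant}" for family in TARGET_FAMILIES for quant in QUANTS]
```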

Why It Works

mlx_lm's `stream_generate(model, tokenizer, prompt, draft_model=...)` already implements the verify-and-accept loop for speculative decoding when `draft_model` is supplied. exo's `mlx_generate` was already calling `stream_generate` for single-device decode; the only missing piece was the parameter pass-through.
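For readers unfamiliar with the verify-and-accept loop, here is a toy greedy version. It is not mlx_lm's implementation: the two "models" are hypothetical deterministic functions over integer tokens, and verification is done one position at a time where a real engine would use a single batched forward pass of the target.

```python
from typing import Callable, List

NextToken = Callable[[List[int]], int]


def target_next(seq: List[int]) -> int:
    """Toy 'large' model: deterministic next token from the sequence so far."""
    return (sum(seq) * 7 + len(seq)) % 10


def draft_next(seq: List[int]) -> int:
    """Toy 'small' model: agrees with the target except when len(seq) % 4 == 0."""
    guess = target_next(seq)
    return (guess + 1) % 10 if len(seq) % 4 == 0 else guess


def greedy_decode(model: NextToken, prompt: List[int], n: int) -> List[int]:
    """Plain autoregressive greedy decoding, used as the reference output."""
    seq = list(prompt)
    for _ in range(n):
        seq.append(model(seq))
    return seq[len(prompt):]


def speculative_decode(
    target: NextToken, draft: NextToken, prompt: List[int], n: int, k: int = 3
) -> List[int]:
    """Greedy speculative decoding: draft proposes k tokens, target verifies them."""
    seq = list(prompt)
    end = len(prompt) + n
    while len(seq) < end:
        # 1) The drafter proposes k tokens autoregressively.
        proposal: List[int] = []
        for _ in range(k):
            proposal.append(draft(seq + proposal))
        # 2) The target checks each proposed position in order.
        for tok in proposal:
            expected = target(seq)
            seq.append(expected)  # the emitted token is always the target's choice
            if tok != expected or len(seq) >= end:
                break  # mismatch (or budget reached): discard the rest of the proposal
        else:
            # 3) Every proposal was accepted, so the target adds one bonus token.
            if len(seq) < end:
                seq.append(target(seq))
    return seq[len(prompt):]
```

Since every emitted token is the target's own greedy choice, the output matches plain greedy decoding with the target; the drafter only affects how many target steps can be verified together.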

Single-device behavior is the primary win for users running exo as a local serving stack (one Apple Silicon device hosting both target and drafter). Distributed (tensor-parallel) speculative decoding is a much larger lift and not in this PR.

What is intentionally NOT in this PR

  • Drafter loading: `MlxBuilder` doesn't yet load the drafter at bootstrap. Wiring `load_model(drafter_path)` from `mlx_lm.utils` and pinning the loaded module on `MlxBuilder` is a few-line follow-up; left out here to keep this diff narrowly focused on the schema + generate plumbing.
  • Drafter download flow: the download manifest doesn't yet pull `drafter_model_id` weights alongside the target. This probably belongs next to the existing `download/coordinator.py` paths.
  • Distributed speculative decoding: not yet supported in mlx_lm, and even if it were, would need exo's own group-aware verify-and-accept logic. Not in this PR.

Test Plan

Automated Testing

```
src/exo/shared/tests/test_model_cards_drafter.py ....
=== 4 passed in 0.16s ===
```

`uv run basedpyright` and `uv run ruff check` both clean.

Manual Testing

This PR is schema + plumbing only. Manual validation requires the follow-up loader PR. We've tested an internal version of the full stack (loader + this plumbing) on a single M5 Max 128GB running `gemma-4-31b-it-4bit` with the `gemma-4-e2b-it-4bit` drafter and observed the expected ~2x decode-tps speedup on chat-completion prompts.

Notes for reviewers

  • Happy to follow up immediately with the `load_drafter` wiring + download manifest support if maintainers want this in one bigger PR rather than three. Splitting it out here so the schema can be reviewed independently.
  • David's MTP issue (#1685) discusses a richer custom-kernel approach that's complementary to this; the `drafter_model_id` field doesn't preclude either path.

…rate

Adds the surface-level support for speculative decoding via mlx_lm's
stream_generate(draft_model=...) on the single-device generation path:

- `ModelCard.drafter_model_id: ModelId | None`: declarative pointer to a
  drafter model that runners may load alongside the target. The drafter
  must share a tokenizer with the target; this is the caller's
  responsibility to enforce.
- `mlx_generate(draft_model=...)`: forwarded to `stream_generate` when
  `group is None` (single-device). Distributed-mode draft is dropped
  explicitly, since mlx_lm's speculative decoding does not yet plumb
  through tensor-parallel groups.
- Eight Gemma 4 model cards (gemma-4-26b-a4b-it and gemma-4-31b-it,
  4bit/6bit/8bit/bf16) declare gemma-4-e2b-it (matching quant) as their
  drafter. The Gemma 4 family shares a tokenizer across e2b/e4b/26b/31b,
  so e2b is a valid drafter.

Drafter loading at builder/runner bootstrap is intentionally not in this
patch — keeping the diff focused on the model-card schema and the
single-device generate plumbing. Wiring drafter download and
load_drafter() into MlxBuilder is straightforward follow-up work.

Tests:
- test_model_cards_drafter.py: 4 tests covering default-None,
  Gemma 4 31b/26b drafter pointers, and round-trip of an explicit value.
@rltakashige
Collaborator

Thanks for this contribution! We're also looking into MTP, and we'll handle the issues you've mentioned at the same time as well as use something like this.

@team-wcv
Contributor Author

Closing as superseded by #2079 (now self-contained: includes this PR's commit as its first commit, plus the full drafter abstraction + Gemma 4 MTP + Qwen 3.5/3.6 DFlash + multi-device tensor-parallel coupled-drafter stack).

The first commit of #2079 (f383ef0a) is a rebase of bdf1a12d (this PR's head) onto the latest main — same author, date, message, byte-identical patch — so closing this PR has zero impact on what's available for upstream review or merge. #2079 explicitly notes it can be merged standalone in one shot.

If you'd prefer the smaller 11-file foundation diff for an initial sanity pass before the larger one, the commit's still right there as the first commit of #2079 and easy to cherry-pick onto main separately. Happy to reopen this if that's the preferred review workflow.

@team-wcv
Contributor Author

Superseded by #2079.

@team-wcv team-wcv closed this May 10, 2026