Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ family = "gemma"
quantization = "4bit"
base_model = "Gemma 4 26B A4B"
capabilities = ["text", "vision"]
drafter_model_id = "mlx-community/gemma-4-e2b-it-4bit"

context_length = 262144

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ family = "gemma"
quantization = "6bit"
base_model = "Gemma 4 26B A4B"
capabilities = ["text", "vision"]
drafter_model_id = "mlx-community/gemma-4-e2b-it-6bit"

context_length = 262144

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ family = "gemma"
quantization = "8bit"
base_model = "Gemma 4 26B A4B"
capabilities = ["text", "vision"]
drafter_model_id = "mlx-community/gemma-4-e2b-it-8bit"

context_length = 262144

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ family = "gemma"
quantization = "bf16"
base_model = "Gemma 4 26B A4B"
capabilities = ["text", "vision"]
drafter_model_id = "mlx-community/gemma-4-e2b-it-bf16"

context_length = 262144

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ family = "gemma"
quantization = "4bit"
base_model = "Gemma 4 31B"
capabilities = ["text", "vision"]
drafter_model_id = "mlx-community/gemma-4-e2b-it-4bit"

context_length = 262144

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ family = "gemma"
quantization = "6bit"
base_model = "Gemma 4 31B"
capabilities = ["text", "vision"]
drafter_model_id = "mlx-community/gemma-4-e2b-it-6bit"

context_length = 262144

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ family = "gemma"
quantization = "8bit"
base_model = "Gemma 4 31B"
capabilities = ["text", "vision"]
drafter_model_id = "mlx-community/gemma-4-e2b-it-8bit"

context_length = 262144

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ family = "gemma"
quantization = "bf16"
base_model = "Gemma 4 31B"
capabilities = ["text", "vision"]
drafter_model_id = "mlx-community/gemma-4-e2b-it-bf16"

context_length = 262144

Expand Down
5 changes: 5 additions & 0 deletions src/exo/shared/models/model_cards.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,11 @@ class ModelCard(FrozenModel):
is_custom: bool = False
vision: VisionCardConfig | None = None
sampling_defaults: SamplingDefaults = Field(default_factory=SamplingDefaults)
# Optional speculative-decoding draft model. When set, runners will load the
# named model alongside the target and pass it as `draft_model` to mlx_lm's
# `stream_generate`, enabling MLX-side speculative decoding. The drafter MUST
# share a tokenizer with the target.
drafter_model_id: ModelId | None = None

@model_validator(mode="after")
def _autodetect_vision(self) -> "ModelCard":
Expand Down
72 changes: 72 additions & 0 deletions src/exo/shared/tests/test_model_cards_drafter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
"""Tests for the optional `drafter_model_id` field on ModelCard.

The field declares a speculative-decoding draft model that runners may load
alongside the target. Coverage:
- ModelCard accepts and serialises the field.
- Cards with no drafter declared default to `None`.
- The Gemma 4 large-instruct cards point to the e2b drafter.
"""

import pytest

from exo.shared.models.model_cards import ModelCard, ModelId, get_model_cards
from exo.shared.types.memory import Memory


@pytest.mark.asyncio
async def test_drafter_model_id_defaults_to_none() -> None:
    """A card that declares no drafter exposes `drafter_model_id` as `None`.

    The Qwen3-30B-A3B 4-bit card is used as the probe. When that card is not
    among the cards returned by `get_model_cards()`, the test is reported as
    skipped instead of passing without having asserted anything.
    """
    cards = {card.model_id: card for card in await get_model_cards()}
    qwen_id = ModelId("mlx-community/Qwen3-30B-A3B-4bit")
    if qwen_id not in cards:
        # Make the missing probe card visible in the test report rather than
        # letting the test succeed vacuously.
        pytest.skip(f"{qwen_id} card not available; nothing to check")
    assert cards[qwen_id].drafter_model_id is None


@pytest.mark.asyncio
async def test_gemma4_31b_cards_declare_e2b_drafter() -> None:
    """Each Gemma 4 31B card names the e2b drafter of the same quantization."""
    loaded = await get_model_cards()
    cards = {entry.model_id: entry for entry in loaded}
    # (target card id, expected drafter id) pairs, one per quantization.
    pairs = (
        ("mlx-community/gemma-4-31b-it-4bit", "mlx-community/gemma-4-e2b-it-4bit"),
        ("mlx-community/gemma-4-31b-it-6bit", "mlx-community/gemma-4-e2b-it-6bit"),
        ("mlx-community/gemma-4-31b-it-8bit", "mlx-community/gemma-4-e2b-it-8bit"),
        ("mlx-community/gemma-4-31b-it-bf16", "mlx-community/gemma-4-e2b-it-bf16"),
    )
    for target_name, drafter_name in pairs:
        target_id = ModelId(target_name)
        # The card must exist before its drafter can be inspected.
        assert target_id in cards, f"{target_id} card missing"
        card = cards[target_id]
        assert card.drafter_model_id == ModelId(drafter_name), (
            f"{target_id} drafter mismatch: got {card.drafter_model_id!r}"
        )


@pytest.mark.asyncio
async def test_gemma4_26b_cards_declare_e2b_drafter() -> None:
    """Each Gemma 4 26B A4B card names the e2b drafter of the same quantization."""
    loaded = await get_model_cards()
    cards = {entry.model_id: entry for entry in loaded}
    # (target card id, expected drafter id) pairs, one per quantization.
    pairs = (
        ("mlx-community/gemma-4-26b-a4b-it-4bit", "mlx-community/gemma-4-e2b-it-4bit"),
        ("mlx-community/gemma-4-26b-a4b-it-6bit", "mlx-community/gemma-4-e2b-it-6bit"),
        ("mlx-community/gemma-4-26b-a4b-it-8bit", "mlx-community/gemma-4-e2b-it-8bit"),
        ("mlx-community/gemma-4-26b-a4b-it-bf16", "mlx-community/gemma-4-e2b-it-bf16"),
    )
    for target_name, drafter_name in pairs:
        target_id = ModelId(target_name)
        # The card must exist before its drafter can be inspected.
        assert target_id in cards, f"{target_id} card missing"
        card = cards[target_id]
        assert card.drafter_model_id == ModelId(drafter_name), (
            f"{target_id} drafter mismatch: got {card.drafter_model_id!r}"
        )


def test_model_card_explicit_drafter_round_trip() -> None:
    """An explicitly supplied drafter id is kept on the card and in its dump."""
    card = ModelCard(
        model_id=ModelId("mlx-community/test-target"),
        storage_size=Memory.from_gb(1.0),
        n_layers=12,
        hidden_size=768,
        supports_tensor=True,
        tasks=["TextGeneration"],  # pyright: ignore[reportArgumentType]
        drafter_model_id=ModelId("mlx-community/test-drafter"),
    )

    # The attribute holds the id that was passed in.
    expected_id = ModelId("mlx-community/test-drafter")
    assert card.drafter_model_id == expected_id

    # The dumped form carries the same id under the field's own name.
    serialized = card.model_dump(exclude_none=True)
    assert serialized["drafter_model_id"] == "mlx-community/test-drafter"
8 changes: 8 additions & 0 deletions src/exo/worker/engines/mlx/generator/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -540,6 +540,7 @@ def mlx_generate(
distributed_prompt_progress_callback: Callable[[], None] | None = None,
on_generation_token: Callable[[], None] | None = None,
vision_processor: VisionProcessor | None = None,
draft_model: Model | None = None,
) -> Generator[GenerationResponse]:
# Ensure that generation stats only contains peak memory for this generation
mx.reset_peak_memory()
Expand Down Expand Up @@ -717,6 +718,12 @@ def mlx_generate(
logger.info("Starting decode")
mx_barrier(group)

# Speculative decoding via mlx_lm is only enabled in the single-device path
# (group is None). Distributed speculative decoding is not yet plumbed, so a
# draft_model passed alongside a non-trivial group is ignored; we drop it
# explicitly here to make the caller contract clear.
effective_draft_model = draft_model if group is None else None

for completion_tokens, out in enumerate(
stream_generate(
model=model,
Expand All @@ -729,6 +736,7 @@ def mlx_generate(
prefill_step_size=1,
kv_group_size=KV_GROUP_SIZE,
kv_bits=KV_BITS,
draft_model=effective_draft_model,
),
start=1,
):
Expand Down