diff --git a/resources/inference_model_cards/mlx-community--gemma-4-26b-a4b-it-4bit.toml b/resources/inference_model_cards/mlx-community--gemma-4-26b-a4b-it-4bit.toml index 51be323ec2..863203b743 100644 --- a/resources/inference_model_cards/mlx-community--gemma-4-26b-a4b-it-4bit.toml +++ b/resources/inference_model_cards/mlx-community--gemma-4-26b-a4b-it-4bit.toml @@ -8,6 +8,7 @@ family = "gemma" quantization = "4bit" base_model = "Gemma 4 26B A4B" capabilities = ["text", "vision"] +drafter_model_id = "mlx-community/gemma-4-e2b-it-4bit" context_length = 262144 diff --git a/resources/inference_model_cards/mlx-community--gemma-4-26b-a4b-it-6bit.toml b/resources/inference_model_cards/mlx-community--gemma-4-26b-a4b-it-6bit.toml index c984d44b7d..32a0a84d56 100644 --- a/resources/inference_model_cards/mlx-community--gemma-4-26b-a4b-it-6bit.toml +++ b/resources/inference_model_cards/mlx-community--gemma-4-26b-a4b-it-6bit.toml @@ -8,6 +8,7 @@ family = "gemma" quantization = "6bit" base_model = "Gemma 4 26B A4B" capabilities = ["text", "vision"] +drafter_model_id = "mlx-community/gemma-4-e2b-it-6bit" context_length = 262144 diff --git a/resources/inference_model_cards/mlx-community--gemma-4-26b-a4b-it-8bit.toml b/resources/inference_model_cards/mlx-community--gemma-4-26b-a4b-it-8bit.toml index fe2583668c..3201ec8283 100644 --- a/resources/inference_model_cards/mlx-community--gemma-4-26b-a4b-it-8bit.toml +++ b/resources/inference_model_cards/mlx-community--gemma-4-26b-a4b-it-8bit.toml @@ -8,6 +8,7 @@ family = "gemma" quantization = "8bit" base_model = "Gemma 4 26B A4B" capabilities = ["text", "vision"] +drafter_model_id = "mlx-community/gemma-4-e2b-it-8bit" context_length = 262144 diff --git a/resources/inference_model_cards/mlx-community--gemma-4-26b-a4b-it-bf16.toml b/resources/inference_model_cards/mlx-community--gemma-4-26b-a4b-it-bf16.toml index ea4dbbfc59..39ea210a64 100644 --- a/resources/inference_model_cards/mlx-community--gemma-4-26b-a4b-it-bf16.toml +++ 
b/resources/inference_model_cards/mlx-community--gemma-4-26b-a4b-it-bf16.toml @@ -8,6 +8,7 @@ family = "gemma" quantization = "bf16" base_model = "Gemma 4 26B A4B" capabilities = ["text", "vision"] +drafter_model_id = "mlx-community/gemma-4-e2b-it-bf16" context_length = 262144 diff --git a/resources/inference_model_cards/mlx-community--gemma-4-31b-it-4bit.toml b/resources/inference_model_cards/mlx-community--gemma-4-31b-it-4bit.toml index cb8e63580f..87a7584cbb 100644 --- a/resources/inference_model_cards/mlx-community--gemma-4-31b-it-4bit.toml +++ b/resources/inference_model_cards/mlx-community--gemma-4-31b-it-4bit.toml @@ -8,6 +8,7 @@ family = "gemma" quantization = "4bit" base_model = "Gemma 4 31B" capabilities = ["text", "vision"] +drafter_model_id = "mlx-community/gemma-4-e2b-it-4bit" context_length = 262144 diff --git a/resources/inference_model_cards/mlx-community--gemma-4-31b-it-6bit.toml b/resources/inference_model_cards/mlx-community--gemma-4-31b-it-6bit.toml index 845620626d..0e0314e119 100644 --- a/resources/inference_model_cards/mlx-community--gemma-4-31b-it-6bit.toml +++ b/resources/inference_model_cards/mlx-community--gemma-4-31b-it-6bit.toml @@ -8,6 +8,7 @@ family = "gemma" quantization = "6bit" base_model = "Gemma 4 31B" capabilities = ["text", "vision"] +drafter_model_id = "mlx-community/gemma-4-e2b-it-6bit" context_length = 262144 diff --git a/resources/inference_model_cards/mlx-community--gemma-4-31b-it-8bit.toml b/resources/inference_model_cards/mlx-community--gemma-4-31b-it-8bit.toml index 332a9b0053..0e33f6ff58 100644 --- a/resources/inference_model_cards/mlx-community--gemma-4-31b-it-8bit.toml +++ b/resources/inference_model_cards/mlx-community--gemma-4-31b-it-8bit.toml @@ -8,6 +8,7 @@ family = "gemma" quantization = "8bit" base_model = "Gemma 4 31B" capabilities = ["text", "vision"] +drafter_model_id = "mlx-community/gemma-4-e2b-it-8bit" context_length = 262144 diff --git 
a/resources/inference_model_cards/mlx-community--gemma-4-31b-it-bf16.toml b/resources/inference_model_cards/mlx-community--gemma-4-31b-it-bf16.toml index 6fc0a2dcaa..1da7e56e9d 100644 --- a/resources/inference_model_cards/mlx-community--gemma-4-31b-it-bf16.toml +++ b/resources/inference_model_cards/mlx-community--gemma-4-31b-it-bf16.toml @@ -8,6 +8,7 @@ family = "gemma" quantization = "bf16" base_model = "Gemma 4 31B" capabilities = ["text", "vision"] +drafter_model_id = "mlx-community/gemma-4-e2b-it-bf16" context_length = 262144 diff --git a/src/exo/shared/models/model_cards.py b/src/exo/shared/models/model_cards.py index 0d9acc1d02..410bdfc5bb 100644 --- a/src/exo/shared/models/model_cards.py +++ b/src/exo/shared/models/model_cards.py @@ -153,6 +153,11 @@ class ModelCard(FrozenModel): is_custom: bool = False vision: VisionCardConfig | None = None sampling_defaults: SamplingDefaults = Field(default_factory=SamplingDefaults) + # Optional speculative-decoding draft model. When set, runners will load the + # named model alongside the target and pass it as `draft_model` to mlx_lm's + # `stream_generate`, enabling MLX-side speculative decoding. The drafter MUST + # share a tokenizer with the target. + drafter_model_id: ModelId | None = None @model_validator(mode="after") def _autodetect_vision(self) -> "ModelCard": diff --git a/src/exo/shared/tests/test_model_cards_drafter.py b/src/exo/shared/tests/test_model_cards_drafter.py new file mode 100644 index 0000000000..302bcd3368 --- /dev/null +++ b/src/exo/shared/tests/test_model_cards_drafter.py @@ -0,0 +1,72 @@ +"""Tests for the optional `drafter_model_id` field on ModelCard. + +The field declares a speculative-decoding draft model that runners may load +alongside the target. Coverage: +- ModelCard accepts and serialises the field. +- Cards with no drafter declared default to `None`. +- The Gemma 4 large-instruct cards point to the e2b drafter. 
+""" + +import pytest + +from exo.shared.models.model_cards import ModelCard, ModelId, get_model_cards +from exo.shared.types.memory import Memory + + +@pytest.mark.asyncio +async def test_drafter_model_id_defaults_to_none() -> None: + cards = {card.model_id: card for card in await get_model_cards()} + qwen_id = ModelId("mlx-community/Qwen3-30B-A3B-4bit") + if qwen_id in cards: + assert cards[qwen_id].drafter_model_id is None + + +@pytest.mark.asyncio +async def test_gemma4_31b_cards_declare_e2b_drafter() -> None: + cards = {card.model_id: card for card in await get_model_cards()} + expectations = { + "mlx-community/gemma-4-31b-it-4bit": "mlx-community/gemma-4-e2b-it-4bit", + "mlx-community/gemma-4-31b-it-6bit": "mlx-community/gemma-4-e2b-it-6bit", + "mlx-community/gemma-4-31b-it-8bit": "mlx-community/gemma-4-e2b-it-8bit", + "mlx-community/gemma-4-31b-it-bf16": "mlx-community/gemma-4-e2b-it-bf16", + } + for target_str, expected_drafter_str in expectations.items(): + target_id = ModelId(target_str) + assert target_id in cards, f"{target_id} card missing" + card = cards[target_id] + assert card.drafter_model_id == ModelId(expected_drafter_str), ( + f"{target_id} drafter mismatch: got {card.drafter_model_id!r}" + ) + + +@pytest.mark.asyncio +async def test_gemma4_26b_cards_declare_e2b_drafter() -> None: + cards = {card.model_id: card for card in await get_model_cards()} + expectations = { + "mlx-community/gemma-4-26b-a4b-it-4bit": "mlx-community/gemma-4-e2b-it-4bit", + "mlx-community/gemma-4-26b-a4b-it-6bit": "mlx-community/gemma-4-e2b-it-6bit", + "mlx-community/gemma-4-26b-a4b-it-8bit": "mlx-community/gemma-4-e2b-it-8bit", + "mlx-community/gemma-4-26b-a4b-it-bf16": "mlx-community/gemma-4-e2b-it-bf16", + } + for target_str, expected_drafter_str in expectations.items(): + target_id = ModelId(target_str) + assert target_id in cards, f"{target_id} card missing" + card = cards[target_id] + assert card.drafter_model_id == ModelId(expected_drafter_str), ( + f"{target_id} 
drafter mismatch: got {card.drafter_model_id!r}" + ) + + +def test_model_card_explicit_drafter_round_trip() -> None: + card = ModelCard( + model_id=ModelId("mlx-community/test-target"), + storage_size=Memory.from_gb(1.0), + n_layers=12, + hidden_size=768, + supports_tensor=True, + tasks=["TextGeneration"], # pyright: ignore[reportArgumentType] + drafter_model_id=ModelId("mlx-community/test-drafter"), + ) + assert card.drafter_model_id == ModelId("mlx-community/test-drafter") + dump = card.model_dump(exclude_none=True) + assert dump["drafter_model_id"] == "mlx-community/test-drafter" diff --git a/src/exo/worker/engines/mlx/generator/generate.py b/src/exo/worker/engines/mlx/generator/generate.py index 2e3d051251..c7a7612693 100644 --- a/src/exo/worker/engines/mlx/generator/generate.py +++ b/src/exo/worker/engines/mlx/generator/generate.py @@ -540,6 +540,7 @@ def mlx_generate( distributed_prompt_progress_callback: Callable[[], None] | None = None, on_generation_token: Callable[[], None] | None = None, vision_processor: VisionProcessor | None = None, + draft_model: Model | None = None, ) -> Generator[GenerationResponse]: # Ensure that generation stats only contains peak memory for this generation mx.reset_peak_memory() @@ -717,6 +718,12 @@ def mlx_generate( logger.info("Starting decode") mx_barrier(group) + # Speculative decoding via mlx_lm: only enabled in the single-device path + # (group is None). Distributed speculative decoding is not yet plumbed; passing a + # draft_model alongside a non-trivial group would be a no-op, so we drop + # it explicitly to make the caller contract clear. + effective_draft_model = draft_model if group is None else None + for completion_tokens, out in enumerate( stream_generate( model=model, @@ -729,6 +736,7 @@ def mlx_generate( prefill_step_size=1, kv_group_size=KV_GROUP_SIZE, kv_bits=KV_BITS, + draft_model=effective_draft_model, ), start=1, ):