diff --git a/app/EXO/EXO/ExoProcessController.swift b/app/EXO/EXO/ExoProcessController.swift
index ff756379b1..c2672e4c8c 100644
--- a/app/EXO/EXO/ExoProcessController.swift
+++ b/app/EXO/EXO/ExoProcessController.swift
@@ -8,6 +8,7 @@ private let hfEndpointKey = "EXOHFEndpoint"
private let enableImageModelsKey = "EXOEnableImageModels"
private let offlineModeKey = "EXOOfflineMode"
private let fastSynchEnabledKey = "EXOFastSynchEnabled"
+private let nativeMTPEnabledKey = "EXONativeMTPEnabled"
private let onboardingCompletedKey = "EXOOnboardingCompleted"
private let defaultModelsDirKey = "EXODefaultModelsDir"
private let additionalModelsDirsKey = "EXOAdditionalModelsDirs"
@@ -109,6 +110,17 @@ final class ExoProcessController: ObservableObject {
UserDefaults.standard.set(fastSynchEnabled, forKey: fastSynchEnabledKey)
}
}
+ @Published var nativeMTPEnabled: Bool = {
+ if UserDefaults.standard.object(forKey: nativeMTPEnabledKey) == nil {
+ return true
+ }
+ return UserDefaults.standard.bool(forKey: nativeMTPEnabledKey)
+ }()
+ {
+ didSet {
+ UserDefaults.standard.set(nativeMTPEnabled, forKey: nativeMTPEnabledKey)
+ }
+ }
@Published var defaultModelsDir: String = {
return UserDefaults.standard.string(forKey: defaultModelsDirKey) ?? ""
}()
@@ -366,6 +378,7 @@ final class ExoProcessController: ObservableObject {
environment["EXO_OFFLINE"] = "true"
}
environment["EXO_FAST_SYNCH"] = fastSynchEnabled ? "true" : "false"
+ environment["EXO_NATIVE_MTP_ENABLED"] = nativeMTPEnabled ? "1" : "0"
var paths: [String] = []
if let existing = environment["PATH"], !existing.isEmpty {
diff --git a/app/EXO/EXO/Views/SettingsView.swift b/app/EXO/EXO/Views/SettingsView.swift
index bb1f1bb798..0786cf8052 100644
--- a/app/EXO/EXO/Views/SettingsView.swift
+++ b/app/EXO/EXO/Views/SettingsView.swift
@@ -16,6 +16,7 @@ struct SettingsView: View {
@State private var pendingEnableImageModels = false
@State private var pendingOfflineMode = false
@State private var pendingFastSynchEnabled = false
+ @State private var pendingNativeMTPEnabled = true
@State private var pendingDefaultModelsDir: String = ""
@State private var pendingAdditionalModelsDirs: String = ""
@State private var pendingReadOnlyModelsDirs: String = ""
@@ -54,6 +55,7 @@ struct SettingsView: View {
pendingEnableImageModels = controller.enableImageModels
pendingOfflineMode = controller.offlineMode
pendingFastSynchEnabled = controller.fastSynchEnabled
+ pendingNativeMTPEnabled = controller.nativeMTPEnabled
pendingDefaultModelsDir = controller.defaultModelsDir
pendingAdditionalModelsDirs = controller.additionalModelsDirs
pendingReadOnlyModelsDirs = controller.readOnlyModelsDirs
@@ -131,6 +133,15 @@ struct SettingsView: View {
.foregroundColor(.secondary)
}
+ Section {
+ Toggle("Use Native MTP", isOn: $pendingNativeMTPEnabled)
+ Text(
+ "Enable native multi-token-prediction speculative decoding for supported model checkpoints. On by default."
+ )
+ .font(.caption)
+ .foregroundColor(.secondary)
+ }
+
Section {
HStack {
Spacer()
@@ -607,6 +618,7 @@ struct SettingsView: View {
private var hasModelChanges: Bool {
pendingEnableImageModels != controller.enableImageModels
+ || pendingNativeMTPEnabled != controller.nativeMTPEnabled
}
private var hasAdvancedChanges: Bool {
@@ -630,6 +642,7 @@ struct SettingsView: View {
private func applyModelSettings() {
controller.enableImageModels = pendingEnableImageModels
+ controller.nativeMTPEnabled = pendingNativeMTPEnabled
restartIfRunning()
}
diff --git a/dashboard/src/lib/components/ModelPickerGroup.svelte b/dashboard/src/lib/components/ModelPickerGroup.svelte
index c09600b076..87917aaf52 100644
--- a/dashboard/src/lib/components/ModelPickerGroup.svelte
+++ b/dashboard/src/lib/components/ModelPickerGroup.svelte
@@ -9,6 +9,7 @@
capabilities?: string[];
family?: string;
is_custom?: boolean;
+ native_mtp?: { default_k: number; max_k: number } | null;
}
interface ModelGroup {
@@ -127,6 +128,11 @@
}
}
+ // Whether any variant in this group ships native MTP (informational badge).
+ const hasNativeMtp = $derived(
+ group.variants.some((v) => v.native_mtp != null),
+ );
+
// Check if this group's model is currently selected (for single-variant groups)
const isMainSelected = $derived(
!group.hasMultipleVariants &&
@@ -309,6 +315,15 @@
{/if}
{/each}
+
+ {#if hasNativeMtp}
+
+ MTP
+
+ {/if}
@@ -523,6 +538,16 @@
{variant.quantization || "default"}
+
+ {#if variant.native_mtp}
+
+ MTP
+
+ {/if}
+
ModelLis
capabilities=card.capabilities,
reasoning_dialect=card.reasoning_dialect,
context_length=card.context_length,
+ native_mtp=(
+ NativeMTPModelInfo(
+ default_k=card.native_mtp.default_k,
+ max_k=card.native_mtp.max_k,
+ )
+ if card.native_mtp is not None
+ else None
+ ),
)
for card in cards
]
diff --git a/src/exo/api/tests/test_generation_stats.py b/src/exo/api/tests/test_generation_stats.py
new file mode 100644
index 0000000000..9cc467f1d5
--- /dev/null
+++ b/src/exo/api/tests/test_generation_stats.py
@@ -0,0 +1,78 @@
+"""Regression tests for speculative-decode generation telemetry."""
+
+from __future__ import annotations
+
+from exo.api.types.api import GenerationStats
+from exo.shared.types.memory import Memory
+
+
+def _stats(
+ *,
+ generation_tokens: int = 100,
+ accepted: int = 25,
+ proposed: int = 0,
+ drafter_model_id: str | None = None,
+ draft_mode: str | None = None,
+ drafter_kind: str | None = None,
+ num_draft_tokens: int | None = None,
+) -> GenerationStats:
+ return GenerationStats(
+ prompt_tps=0.0,
+ prompt_tokens=10,
+ generation_tps=10.0,
+ generation_tokens=generation_tokens,
+ peak_memory_usage=Memory.from_bytes(0),
+ accepted_draft_tokens=accepted,
+ proposed_draft_tokens=proposed,
+ drafter_model_id=drafter_model_id,
+ num_draft_tokens=num_draft_tokens,
+ draft_mode=draft_mode, # pyright: ignore[reportArgumentType]
+ drafter_kind=drafter_kind, # pyright: ignore[reportArgumentType]
+ )
+
+
+def test_acceptance_fraction_is_none_for_explicit_none_mode() -> None:
+ stats = _stats(draft_mode="none", accepted=0)
+ assert stats.drafter_acceptance_fraction is None
+
+
+def test_acceptance_fraction_reports_for_model_runs() -> None:
+ stats = _stats(
+ draft_mode="model",
+ drafter_model_id="some-org/drafter-7b",
+ accepted=40,
+ )
+ assert stats.drafter_acceptance_fraction == 0.40
+
+
+def test_acceptance_metrics_report_for_native_mtp_runs_without_drafter_id() -> None:
+ stats = _stats(
+ draft_mode="model",
+ drafter_model_id=None,
+ drafter_kind="native_mtp",
+ num_draft_tokens=2,
+ accepted=25,
+ proposed=40,
+ )
+ assert stats.drafter_acceptance_fraction == 0.25
+ assert stats.drafter_acceptance_rate == 0.625
+
+
+def test_acceptance_fraction_legacy_payload_without_draft_mode() -> None:
+ legacy_with_drafter = _stats(
+ draft_mode=None,
+ drafter_model_id="legacy-org/drafter",
+ accepted=10,
+ )
+ legacy_without_drafter = _stats(
+ draft_mode=None,
+ drafter_model_id=None,
+ accepted=0,
+ )
+ assert legacy_with_drafter.drafter_acceptance_fraction is not None
+ assert legacy_without_drafter.drafter_acceptance_fraction is None
+
+
+def test_acceptance_fraction_zero_generation_tokens_returns_none() -> None:
+ stats = _stats(generation_tokens=0, accepted=0, draft_mode="model")
+ assert stats.drafter_acceptance_fraction is None
diff --git a/src/exo/api/tests/test_model_list_native_mtp.py b/src/exo/api/tests/test_model_list_native_mtp.py
new file mode 100644
index 0000000000..9d11d2c5af
--- /dev/null
+++ b/src/exo/api/tests/test_model_list_native_mtp.py
@@ -0,0 +1,63 @@
+import pytest
+
+from exo.api.main import API
+from exo.shared.models import model_cards
+from exo.shared.models.model_cards import (
+ ModelCard,
+ ModelTask,
+ NativeMTPConfig,
+)
+from exo.shared.types.backends import Backend
+from exo.shared.types.common import ModelId
+from exo.shared.types.memory import Memory
+from exo.shared.types.state import State
+
+
+def _native_mtp_card(model_id: str, *, default_k: int, max_k: int) -> ModelCard:
+ return ModelCard(
+ model_id=ModelId(model_id),
+ storage_size=Memory.from_mb(1),
+ n_layers=1,
+ hidden_size=1,
+ supports_tensor=True,
+ tasks=[ModelTask.TextGeneration],
+ backends=[Backend.MlxMetal],
+ native_mtp=NativeMTPConfig(
+ num_layers=1,
+ default_k=default_k,
+ max_k=max_k,
+ ),
+ )
+
+
+@pytest.mark.asyncio
+async def test_models_response_includes_native_mtp_for_native_mtp_cards(
+ monkeypatch: pytest.MonkeyPatch,
+) -> None:
+ cards = [
+ _native_mtp_card("Jundot/Qwen3.6-27B-oQ8-mtp", default_k=2, max_k=3),
+ _native_mtp_card(
+ "alvarolizama/Qwen3.6-35B-A3B-oQ8-mtp",
+ default_k=1,
+ max_k=3,
+ ),
+ ]
+
+ async def _fake_list_all() -> list[ModelCard]:
+ return cards
+
+ api = API.__new__(API)
+ api.state = State()
+ monkeypatch.setattr(model_cards.card_cache, "list_all", _fake_list_all)
+
+ response = await api.get_models()
+
+ by_id = {item.id: item for item in response.data}
+ qwen27_native_mtp = by_id["Jundot/Qwen3.6-27B-oQ8-mtp"].native_mtp
+ qwen35_native_mtp = by_id["alvarolizama/Qwen3.6-35B-A3B-oQ8-mtp"].native_mtp
+ assert qwen27_native_mtp is not None
+ assert qwen27_native_mtp.default_k == 2
+ assert qwen27_native_mtp.max_k == 3
+ assert qwen35_native_mtp is not None
+ assert qwen35_native_mtp.default_k == 1
+ assert qwen35_native_mtp.max_k == 3
diff --git a/src/exo/api/types/__init__.py b/src/exo/api/types/__init__.py
index 9cb2f834fa..7aea349ea3 100644
--- a/src/exo/api/types/__init__.py
+++ b/src/exo/api/types/__init__.py
@@ -40,6 +40,7 @@
from .api import LogprobsContentItem as LogprobsContentItem
from .api import ModelList as ModelList
from .api import ModelListModel as ModelListModel
+from .api import NativeMTPModelInfo as NativeMTPModelInfo
from .api import NodePowerStats as NodePowerStats
from .api import PlaceInstanceParams as PlaceInstanceParams
from .api import PlacementPreview as PlacementPreview
diff --git a/src/exo/api/types/api.py b/src/exo/api/types/api.py
index 8cfa10dd1a..c994d3dca0 100644
--- a/src/exo/api/types/api.py
+++ b/src/exo/api/types/api.py
@@ -29,6 +29,16 @@ class ErrorResponse(BaseModel):
error: ErrorInfo
+class NativeMTPModelInfo(BaseModel):
+ """Native Multi-Token Prediction metadata surfaced for cards whose
+ on-disk checkpoint ships MTP weights handled by exo's in-process
+ draft+verify loop. ``default_k`` is the K used when a request doesn't
+ override; ``max_k`` is the card-declared upper bound."""
+
+ default_k: int
+ max_k: int
+
+
class ModelListModel(BaseModel):
id: str
object: str = "model"
@@ -49,6 +59,9 @@ class ModelListModel(BaseModel):
base_model: str = Field(default="")
capabilities: list[str] = Field(default_factory=list)
reasoning_dialect: ReasoningDialect = "none"
+ # Native MTP metadata for checkpoints whose card declares exo's
+ # in-process MTP draft+verify path. ``None`` for all other models.
+ native_mtp: NativeMTPModelInfo | None = None
class ModelList(BaseModel):
@@ -167,6 +180,39 @@ class GenerationStats(BaseModel):
generation_tokens: int
peak_memory_usage: Memory
prefix_cache_hit: Literal["none", "partial", "exact"] = "none"
+ drafter_model_id: str | None = None
+ accepted_draft_tokens: int = 0
+ proposed_draft_tokens: int = 0
+ spec_decode_rounds: int = 0
+ num_draft_tokens: int | None = None
+ draft_mode: (
+ Literal["model", "pipelined", "ngram", "eagle", "lookahead", "none"] | None
+ ) = None
+ drafter_kind: Literal["standard", "mtp", "dflash", "native_mtp"] | None = None
+
+ @property
+ def drafter_acceptance_fraction(self) -> float | None:
+ """Fraction of generated tokens that came from accepted drafts."""
+ if self.generation_tokens == 0:
+ return None
+ if self.draft_mode is None:
+ if self.drafter_model_id is None and self.drafter_kind is None:
+ return None
+ elif self.draft_mode == "none":
+ return None
+ return self.accepted_draft_tokens / self.generation_tokens
+
+ @property
+ def drafter_acceptance_rate(self) -> float | None:
+ """Classical speculative acceptance rate: accepted / proposed."""
+ if self.proposed_draft_tokens == 0:
+ return None
+ if self.draft_mode is None:
+ if self.drafter_model_id is None and self.drafter_kind is None:
+ return None
+ elif self.draft_mode == "none":
+ return None
+ return self.accepted_draft_tokens / self.proposed_draft_tokens
class ImageGenerationStats(BaseModel):
diff --git a/src/exo/shared/models/model_cards.py b/src/exo/shared/models/model_cards.py
index 8911b9323c..a9c020cb0e 100644
--- a/src/exo/shared/models/model_cards.py
+++ b/src/exo/shared/models/model_cards.py
@@ -154,6 +154,47 @@ class SamplingDefaults(SamplingValues):
non_thinking: SamplingValues | None = None
+class NativeMTPConfig(FrozenModel):
+ """Declares that a target checkpoint ships native Multi-Token Prediction
+ weights handled by exo's vendored MTP-aware loader.
+
+ Set this field on cards whose on-disk checkpoint exposes recoverable
+ MTP tensors -- typically MTPLX-format Qwen3.5/3.6 artefacts that ship a
+ separate ``mtp.safetensors`` sidecar (``mlx_lm_extra_tensors.mtp_file``
+ in ``config.json``), or the original HuggingFace layout where the MTP
+ tensors are embedded in the main shards under the ``mtp.*`` prefix.
+
+ The loader at :mod:`exo.worker.engines.mlx.vendor.qwen3_5_mtp_loader`
+ consumes this declaration through ``utils_mlx.load_mlx_items``: when
+ the card has ``native_mtp`` set, the placement is single-node, and
+ :func:`exo.worker.engines.mlx.mtp_probe.probe_mtp_weights` confirms
+ the tensors are present on disk, the loader dispatches via
+ :func:`load_mtp_model` instead of stock ``mlx_lm.utils.load_model``.
+
+ Native MTP is structurally single-node only: the verify forward through
+ a TP-sharded target would amortise K+1 tokens over K+1 compute units,
+ eating the MTP speedup. The guard
+ :func:`exo.worker.engines.mlx.utils_mlx.is_native_mtp_runnable`
+ enforces this -- the loader silently falls back to the stock load path
+ on multi-rank instances.
+ """
+
+ # Number of MTP transformer layers in this checkpoint. Matches
+ # ``text_config.mtp_num_hidden_layers`` in ``config.json``.
+ # Qwen3.6-27B MTPLX ships with 1.
+ num_layers: PositiveInt
+ # Default K for draft+verify when a request doesn't override.
+ default_k: PositiveInt = 3
+ # Maximum K the runner is allowed to use. Above this, verify cost
+ # dominates.
+ max_k: PositiveInt = 3
+ # Hugging Face filename of the separate MTP safetensors when shipped
+ # in MTPLX format. ``None`` when MTP tensors are embedded in main
+ # shards (original HuggingFace layout) -- the loader probes both
+ # paths regardless of this hint.
+ mtp_file: str | None = None
+
+
class ModelCard(FrozenModel):
model_id: ModelId
storage_size: Memory
@@ -175,6 +216,13 @@ class ModelCard(FrozenModel):
is_custom: bool = False
vision: VisionCardConfig | None = None
sampling_defaults: SamplingDefaults = Field(default_factory=SamplingDefaults)
+ # Optional declaration that the target checkpoint ships native
+ # Multi-Token Prediction (MTP) weights handled directly by exo's
+ # vendored Qwen3.5/3.6 MTP-aware loader. See :class:`NativeMTPConfig`
+ # for the field-by-field semantics and the loader / placement gates
+ # that consume it. Single-node only; ``None`` (the default) preserves
+ # legacy behaviour and is purely additive.
+ native_mtp: NativeMTPConfig | None = None
@model_validator(mode="after")
def _autodetect_vision(self) -> "ModelCard":
diff --git a/src/exo/worker/engines/mlx/builder.py b/src/exo/worker/engines/mlx/builder.py
index af7c75bb9d..0284fc63e8 100644
--- a/src/exo/worker/engines/mlx/builder.py
+++ b/src/exo/worker/engines/mlx/builder.py
@@ -21,6 +21,8 @@
from exo.worker.runner.llm_inference.tool_parsers import make_mlx_parser
from .cache import KVPrefixCache
+from .generator.native_mtp_drafter import is_native_mtp_dispatchable
+from .native_mtp_config import native_mtp_enabled_from_env
from .types import Model
from .utils_mlx import (
initialize_mlx,
@@ -38,6 +40,11 @@ class MlxBuilder(Builder):
tokenizer: TokenizerWrapper | None = None
group: mx.distributed.Group | None = None
vision_processor: VisionProcessor | None = None
+ # Native-MTP K bounds captured from the model card at load time, used
+ # by ``build`` to configure the generator. ``None`` unless the card
+ # declares ``native_mtp``.
+ native_mtp_default_k: int | None = None
+ native_mtp_max_k: int | None = None
def connect(self, bound_instance: BoundInstance) -> None:
self.group = initialize_mlx(bound_instance)
@@ -48,6 +55,10 @@ def load(self, bound_instance: BoundInstance) -> Generator[ModelLoadingResponse]
self.tokenizer,
self.vision_processor,
) = yield from load_mlx_items(bound_instance, self.group)
+ native_mtp = bound_instance.bound_shard.model_card.native_mtp
+ if native_mtp is not None:
+ self.native_mtp_default_k = native_mtp.default_k
+ self.native_mtp_max_k = native_mtp.max_k
def close(self) -> None:
with contextlib.suppress(NameError, AttributeError):
@@ -83,8 +94,19 @@ def build(
kv_prefix_cache = KVPrefixCache(self.group)
device_rank = 0 if self.group is None else self.group.rank()
- if os.environ.get("EXO_NO_BATCH"):
- logger.info("using SequentialGenerator (batching disabled)")
+ # Native MTP runs the single-request draft+verify loop inside
+ # ``mlx_generate`` (SequentialGenerator), so force sequential mode
+ # when the target loaded as a vendored MTP-aware model and native
+ # MTP is enabled. ``is_native_mtp_dispatchable`` is an isinstance
+ # check, so it is only ever True for single-node MTP checkpoints.
+ native_mtp_dispatchable = (
+ self.group is None
+ and native_mtp_enabled_from_env()
+ and is_native_mtp_dispatchable(self.inference_model)
+ )
+ if native_mtp_dispatchable or os.environ.get("EXO_NO_BATCH"):
+ reason = "native MTP" if native_mtp_dispatchable else "batching disabled"
+ logger.info(f"using SequentialGenerator ({reason})")
return SequentialGenerator(
model=self.inference_model,
tokenizer=self.tokenizer,
@@ -96,6 +118,12 @@ def build(
cancel_receiver=self.cancel_receiver,
event_sender=self.event_sender,
vision_processor=vision_processor,
+ native_mtp_default_k=(
+ self.native_mtp_default_k if native_mtp_dispatchable else None
+ ),
+ native_mtp_max_k=(
+ self.native_mtp_max_k if native_mtp_dispatchable else None
+ ),
)
else:
logger.info("using BatchGenerator")
diff --git a/src/exo/worker/engines/mlx/generator/generate.py b/src/exo/worker/engines/mlx/generator/generate.py
index 2e3d051251..ff54a1c7fb 100644
--- a/src/exo/worker/engines/mlx/generator/generate.py
+++ b/src/exo/worker/engines/mlx/generator/generate.py
@@ -6,6 +6,9 @@
from typing import Callable, Generator, cast, get_args
import mlx.core as mx
+from mlx_lm.generate import (
+ GenerationResponse as MlxGenerationResponse,
+)
from mlx_lm.generate import (
maybe_quantize_kv_cache,
stream_generate,
@@ -55,7 +58,15 @@
KV_GROUP_SIZE,
MAX_TOKENS,
)
+from exo.worker.engines.mlx.generator.native_mtp_drafter import (
+ NativeMTPDrafter,
+ is_native_mtp_dispatchable,
+)
from exo.worker.engines.mlx.generator.remote_prefill import remote_prefill
+from exo.worker.engines.mlx.native_mtp_config import (
+ native_mtp_enabled_from_env,
+ resolve_native_mtp_num_draft_tokens,
+)
from exo.worker.engines.mlx.types import KVCacheType, Model
from exo.worker.engines.mlx.utils_mlx import (
apply_chat_template,
@@ -540,6 +551,8 @@ def mlx_generate(
distributed_prompt_progress_callback: Callable[[], None] | None = None,
on_generation_token: Callable[[], None] | None = None,
vision_processor: VisionProcessor | None = None,
+ native_mtp_default_k: int | None = None,
+ native_mtp_max_k: int | None = None,
) -> Generator[GenerationResponse]:
# Ensure that generation stats only contains peak memory for this generation
mx.reset_peak_memory()
@@ -717,8 +730,59 @@ def mlx_generate(
logger.info("Starting decode")
mx_barrier(group)
- for completion_tokens, out in enumerate(
- stream_generate(
+ # Native MTP dispatch: when the target loaded as a vendored MTP-aware
+ # model (single-node, env-enabled), draft+verify through its own MTP
+ # head. ``NativeMTPDrafter.stream`` is a drop-in for ``stream_generate``
+ # -- it yields the same ``GenerationResponse`` objects, primes its own
+ # MTP cache from the full prompt, and consumes the KV cache that
+ # ``prefill`` already aligned to ``prompt_tokens[:-2]`` (the 2-token
+ # ``last_token`` tail is the decode seed for both paths).
+ native_mtp_active = (
+ group is None
+ and native_mtp_enabled_from_env()
+ and is_native_mtp_dispatchable(model)
+ )
+ native_mtp_drafter: NativeMTPDrafter | None = None
+ native_k: int | None = None
+ if native_mtp_active:
+ native_k, native_k_clamped = resolve_native_mtp_num_draft_tokens(
+ request_num_draft_tokens=None,
+ configured_num_draft_tokens=None,
+ card_default_k=native_mtp_default_k
+ if native_mtp_default_k is not None
+ else 1,
+ card_max_k=native_mtp_max_k if native_mtp_max_k is not None else 1,
+ )
+ if native_k_clamped:
+ logger.info(f"native MTP K clamped to {native_k}")
+ full_context_tokens = [
+ int(t) for t in cast(list[int], all_prompt_tokens.tolist())
+ ]
+ logger.info(f"native MTP dispatch: K={native_k}, single-node")
+ native_mtp_drafter = NativeMTPDrafter(k=native_k)
+ token_stream: Generator[MlxGenerationResponse] = native_mtp_drafter.stream(
+ model=model,
+ tokenizer=tokenizer,
+ prompt=last_token,
+ context_tokens=full_context_tokens,
+ prompt_cache=caches,
+ max_tokens=max_tokens,
+ sampler=sampler,
+ logits_processors=logits_processors,
+ prefill_step_size=1,
+ # Lossless speculative sampling: thread the raw sampling params so
+ # the drafter can reconstruct the request's distribution (the
+ # ``sampler`` callable alone is not enough -- it only yields a
+ # token, not the per-step distribution the accept test needs).
+ # Defaults mirror the ``make_sampler`` call above so MTP-on and
+ # MTP-off sample from the same distribution.
+ temperature=task.temperature if task.temperature is not None else 0.7,
+ top_p=task.top_p if task.top_p is not None else 1.0,
+ top_k=task.top_k if task.top_k is not None else 0,
+ min_p=task.min_p if task.min_p is not None else 0.05,
+ )
+ else:
+ token_stream = stream_generate(
model=model,
tokenizer=tokenizer,
prompt=last_token,
@@ -729,9 +793,9 @@ def mlx_generate(
prefill_step_size=1,
kv_group_size=KV_GROUP_SIZE,
kv_bits=KV_BITS,
- ),
- start=1,
- ):
+ )
+
+ for completion_tokens, out in enumerate(token_stream, start=1):
generated_text_parts.append(out.text)
accumulated_text += out.text
@@ -758,12 +822,27 @@ def mlx_generate(
stats: GenerationStats | None = None
if is_done:
+ native_mtp_metrics = (
+ native_mtp_drafter.metrics()
+ if native_mtp_drafter is not None and native_k is not None
+ else {}
+ )
stats = GenerationStats(
prompt_tps=float(prefill_tps or out.prompt_tps),
generation_tps=float(out.generation_tps),
prompt_tokens=int(prefill_tokens + out.prompt_tokens),
generation_tokens=int(out.generation_tokens),
peak_memory_usage=Memory.from_gb(out.peak_memory),
+ accepted_draft_tokens=native_mtp_metrics.get(
+ "accepted_draft_tokens", 0
+ ),
+ proposed_draft_tokens=native_mtp_metrics.get(
+ "proposed_draft_tokens", 0
+ ),
+ spec_decode_rounds=native_mtp_metrics.get("spec_decode_rounds", 0),
+ num_draft_tokens=native_k if native_mtp_drafter is not None else None,
+ draft_mode="model" if native_mtp_drafter is not None else "none",
+ drafter_kind="native_mtp" if native_mtp_drafter is not None else None,
)
if not stop_matched and out.finish_reason not in get_args(FinishReason):
logger.warning(
diff --git a/src/exo/worker/engines/mlx/generator/native_mtp_drafter.py b/src/exo/worker/engines/mlx/generator/native_mtp_drafter.py
new file mode 100644
index 0000000000..1b784ca029
--- /dev/null
+++ b/src/exo/worker/engines/mlx/generator/native_mtp_drafter.py
@@ -0,0 +1,1890 @@
+"""Native Multi-Token Prediction drafter for Qwen3.5/3.6.
+
+Native MTP drafts come from the target model's own MTP head: a single
+transformer layer that consumes the target's post-norm hidden state +
+the previously-sampled token and predicts the next token via the shared
+``lm_head``. No sibling drafter LM is loaded and no separate model KV
+cache exists -- the MTP layer has its own KV cache that's populated from
+prompt prefill and extended on every accepted draft.
+
+**Cache rollback (Qwen3.6 GatedDeltaNet)**
+
+``mlx_trim_prompt_cache`` decrements KV offsets but does NOT roll back
+the GatedDeltaNet linear-attention recurrent state on Qwen3.5/3.6
+targets, so a bare trim after the batched verify leaves the GDN cache
+on the SPECULATIVE trajectory (the rejected drafts have been folded
+into the recurrent state and can't be undone via offset
+arithmetic). The ``trim_after_batched_probe`` measurement showed
+54.2% logit mismatch (Logit L2 distances 300-1000) without the fix.
+
+The loop snapshots GDN state before the batched verify; on partial accept it
+trims the main cache fully, restores GDN from the snapshot, and re-forwards
+``[current_token, *drafts[:n_accepted]]`` so the cache is advanced through
+retained tokens only. On full accept the cache is already on-trajectory, so
+rollback is skipped.
+
+The MTP cache is full-attention (``vendor/qwen3_5_mtp.py``
+``fa_layer_idx``), so its trim path uses plain
+:func:`mlx_trim_prompt_cache` -- no rollback needed.
+
+The loop:
+
+1. **Prime** the MTP KV cache by walking prompt positions through the
+ main model on a fresh cache to capture post-norm hiddens, then
+ feeding those into ``model.mtp_update_cache`` so the MTP attention
+ sees prompt context before the first draft fires.
+2. **Draft** K tokens by chaining ``model.mtp_forward`` -- each step
+ consumes the previous step's post-norm hidden so the MTP layer
+ builds a self-referential chain. K is bounded by
+ ``card.native_mtp.max_k`` upstream (default 3).
+3. **Verify** all K drafts in a single ``(K+1)``-token target forward.
+ Walk the target's preferred tokens; accept while they match drafts,
+ emit one bonus from the target's choice at the first mismatch (or
+ after a full-accept). Trim both caches by ``K - num_accepted``.
+
+Native MTP is single-node-only by design (the loader gate
+:func:`exo.worker.engines.mlx.utils_mlx._native_mtp_loader_eligible`
+only fires on single-node placements): a TP-sharded verify forward
+amortises ``K+1`` tokens over ``K+1`` compute units, eating the MTP
+speedup.
+"""
+
+# pyright: reportUnknownMemberType=false, reportUnknownArgumentType=false, reportUnknownVariableType=false, reportAttributeAccessIssue=false, reportCallIssue=false, reportArgumentType=false, reportAny=false
+# This module composes over untyped mlx-lm tokenizer / vendor MTP
+# attributes; the type-checker can't see through nn.Module subclassing
+# in the way mlx-lm uses it. Mirrors the same pragma block as
+# ``vendor/qwen3_5_mtp.py``.
+
+from __future__ import annotations
+
+import os
+import time
+from collections.abc import Callable, Generator, Iterable, Sequence
+from typing import Any, Literal, cast, final
+
+import mlx.core as mx
+import mlx.nn as nn
+from mlx.nn.layers.distributed import sum_gradients
+from mlx_lm.generate import GenerationResponse
+from mlx_lm.models.base import create_attention_mask, create_ssm_mask
+from mlx_lm.models.cache import trim_prompt_cache as mlx_trim_prompt_cache
+from mlx_lm.sample_utils import apply_min_p, apply_top_k, apply_top_p
+from mlx_lm.tokenizer_utils import TokenizerWrapper
+
+from exo.worker.engines.mlx.spec_cache import (
+ restore_cache,
+ rollback_after_verify,
+ snapshot_untrimmable_cache_lazy,
+)
+from exo.worker.engines.mlx.types import KVCacheType, Model
+from exo.worker.runner.bootstrap import logger
+
+_MoeVerifierPolicy = Literal[
+ "safe",
+ "hybrid_full_accept",
+ "batched",
+ "row_moe",
+ "route_locked",
+]
+_DENSE_GDN_STATE_HISTORY_MODEL_TYPES = frozenset({"qwen3_5", "qwen3_5_text"})
+_MOE_GDN_STATE_HISTORY_MODEL_TYPES = frozenset({"qwen3_5_moe", "qwen3_5_moe_text"})
+
+
+def is_native_mtp_dispatchable(model: object) -> bool:
+ """Return ``True`` when ``model`` is a vendored native-MTP Model instance.
+
+ The loader (:mod:`exo.worker.engines.mlx.vendor.qwen3_5_mtp_loader`)
+ swaps in :class:`vendor.qwen3_5_mtp.Model` (or its
+ ``_PatchedMtpModel`` subclass) when the card declares ``native_mtp``
+ and the on-disk weights are recoverable. We ``isinstance``-check
+ against that concrete class rather than probe for the
+ ``mtp_forward`` / ``make_mtp_cache`` / ``mtp_update_cache`` methods
+ structurally because ``unittest.mock.MagicMock()`` auto-creates any
+ attribute on access -- a structural check would falsely engage the
+ NativeMTPDrafter for any mocked target model and break unrelated
+ routing tests. The import is local to keep this module's import
+ surface small for workers that never load Qwen3.5/3.6.
+
+ This is the dispatch-side gate. The builder-side gate
+ (:func:`exo.worker.engines.mlx.utils_mlx.is_native_mtp_runnable`)
+ consults the card + placement before the model is loaded.
+ """
+ try:
+ from exo.worker.engines.mlx.vendor.qwen3_5_mtp import Model as _VendorMtpModel
+ except ImportError:
+ return False
+ return isinstance(model, _VendorMtpModel)
+
+
+def _eos_ids_from_tokenizer(tokenizer: TokenizerWrapper) -> list[int]:
+ """Tokenizer-agnostic EOS lookup mirroring drafter.py / coupled_drafter.py."""
+ eos = getattr(tokenizer, "eos_token_ids", None)
+ if eos is None:
+ eos = getattr(tokenizer, "eos_token_id", None)
+ if eos is None:
+ return []
+ if isinstance(eos, int):
+ return [eos]
+ if isinstance(eos, Iterable) and not isinstance(eos, (str, bytes)):
+ return [int(x) for x in eos]
+ try:
+ return [int(eos)]
+ except (TypeError, ValueError):
+ return []
+
+
+def _token_array_to_int(token: mx.array) -> int:
+ """Read a single-token ``mx.array`` as a Python ``int``."""
+ return int(token.reshape(()).item())
+
+
+@final
+class _SpecSampler:
+ """Reconstructs the request's sampling distribution for lossless MTP.
+
+ The drafter only receives a ``sampler`` (``logits -> token``), which is
+ not enough for probability-ratio (Leviathan/Chen) acceptance: that needs
+ the *distribution* the request samples from. We rebuild it by replicating
+ mlx-lm's ``make_sampler`` chain (``top_p`` / ``min_p`` / ``top_k`` masking,
+ then ``/temperature`` and a softmax) so the draft distribution ``q`` and
+ the target distribution ``p`` are computed with identical transforms.
+
+ Mirrors ``mtp_generate_step`` in mlx-lm PR #990 (the PR this drafter
+ derives from): ``q``/``p`` are the temperature-adjusted, filtered
+ distributions; draft tokens are sampled from ``q``; a draft token ``x`` is
+ accepted with probability ``min(1, p(x)/q(x))``; on rejection the
+ replacement is drawn from the normalized residual ``(p - q)+``; after a
+ full accept the bonus is drawn from ``p``. With ``temperature == 0`` this
+ reduces to "accept iff ``x == argmax(p)``", recovering exact greedy.
+
+ Note: ``logits_processors`` (repetition/presence/frequency penalties,
+ logit bias) are intentionally NOT applied here. They are stateful over the
+ emitted-token history and cannot be replayed consistently across the K
+ chained MTP draft steps and the batched K+1 verify, so applying them
+ inconsistently would BREAK the lossless guarantee rather than help it.
+ Penalties stay unhandled on the native-MTP path (as before); only the
+ request's temperature/top_p/top_k/min_p now govern the distribution.
+ """
+
+ def __init__(
+ self,
+ *,
+ temperature: float,
+ top_p: float,
+ top_k: int,
+ min_p: float,
+ min_tokens_to_keep: int,
+ ) -> None:
+ self._temperature: float = temperature
+ self._is_greedy: bool = temperature <= 0.0
+ self._top_p: float = top_p
+ self._top_k: int = top_k
+ self._min_p: float = min_p
+ self._min_tokens_to_keep: int = max(1, min_tokens_to_keep)
+
+ @property
+ def is_greedy(self) -> bool:
+ return self._is_greedy
+
+ def _filter(self, logprobs: mx.array) -> mx.array:
+ """Apply the request's masking filters (top_p / min_p / top_k)."""
+ masked = logprobs
+ if 0.0 < self._top_p < 1.0:
+ masked = apply_top_p(masked, self._top_p)
+ if self._min_p != 0.0:
+ masked = apply_min_p(masked, self._min_p, self._min_tokens_to_keep)
+ if self._top_k > 0:
+ masked = apply_top_k(masked, self._top_k)
+ return masked
+
+ def distribution(self, last_logits: mx.array) -> mx.array:
+ """Return the request's sampling distribution ``p`` for ``(1, V)`` logits.
+
+ ``last_logits`` is the final-position logit row, shape ``(1, V)``. The
+ returned array is normalized probabilities over the vocab with masked
+ entries at exactly ``0`` (temperature-adjusted, filtered).
+ """
+ logprobs = last_logits - mx.logsumexp(last_logits, axis=-1, keepdims=True)
+ if self._is_greedy:
+ # Degenerate distribution at the argmax; the ratio test below
+ # reduces to "x == argmax(p)" and the residual to argmax(p).
+ argmax = mx.argmax(logprobs, axis=-1)
+ return mx.zeros_like(logprobs).at[0, argmax].add(1.0)
+ masked = self._filter(logprobs)
+ scaled = masked / self._temperature
+ accept_logprobs = scaled - mx.logsumexp(scaled, axis=-1, keepdims=True)
+ return mx.exp(accept_logprobs)
+
+ def sample_from_distribution(self, distribution: mx.array) -> mx.array:
+ """Sample one token id ``(1, 1)`` int32 from a probability row ``(1, V)``."""
+ if self._is_greedy:
+ return mx.argmax(distribution, axis=-1).reshape(1, 1).astype(mx.int32)
+ # ``mx.log(0) -> -inf`` which categorical treats as probability 0.
+ return (
+ mx.random.categorical(mx.log(distribution)).reshape(1, 1).astype(mx.int32)
+ )
+
+ def residual_token(self, target_dist: mx.array, draft_dist: mx.array) -> mx.array:
+ """Sample the replacement token from the residual ``(p - q)+`` (or ``p``)."""
+ if self._is_greedy:
+ return mx.argmax(target_dist, axis=-1).reshape(1, 1).astype(mx.int32)
+ residual = mx.maximum(target_dist - draft_dist, 0.0)
+ normalizer = residual.sum(axis=-1, keepdims=True)
+ # When the residual mass is ~0 (draft already covered all target mass),
+ # fall back to sampling from the target distribution directly.
+ distribution = mx.where(normalizer > 0, residual, target_dist)
+ return self.sample_from_distribution(distribution)
+
+
+def _accept_draft_token(
+ *,
+ is_greedy: bool,
+ draft_token: int,
+ target_argmax: int,
+ target_prob: float,
+ draft_prob: float,
+ uniform: float,
+) -> bool:
+ """Probability-ratio acceptance for one drafted token.
+
+ Greedy: accept iff the drafted token equals the target's argmax.
+ Stochastic: accept with probability ``min(1, p(x)/q(x))`` against a
+ pre-drawn ``uniform`` in ``[0, 1)`` (Leviathan/Chen). ``draft_prob`` is
+ the draft head's probability ``q(x)`` for the drafted token; a draft that
+ the target assigns zero mass (``p(x) == 0``, e.g. masked out by top_k) is
+ always rejected.
+ """
+ if is_greedy:
+ return draft_token == target_argmax
+ if draft_prob <= 0.0:
+ # The draft token was not in the draft's own support; treat as a hard
+ # reject so the residual replacement (which equals p here) is emitted.
+ return False
+ return uniform < (target_prob / draft_prob)
+
+
+def _spec_accept_walk(
+ *,
+ spec_sampler: _SpecSampler,
+ verify_logits: mx.array,
+ draft_tokens: Sequence[int],
+ draft_distributions: Sequence[mx.array],
+ accept_uniforms: Sequence[float],
+ k: int,
+) -> tuple[int, mx.array]:
+ """Probability-ratio accept walk over a batched ``(1, K+1, V)`` verify.
+
+ Returns ``(num_accepted, replacement_token_arr)`` where
+ ``replacement_token_arr`` is the ``(1, 1)`` int32 token emitted after the
+ accepted prefix -- the residual sample at the first rejected position, or
+ the target bonus sample at position ``num_accepted == K`` (full accept).
+ The caller's cache machinery already handles trimming/rollback to land the
+ cache on ``[current_token, *drafts[:num_accepted]]``.
+ """
+ # Per-position target distributions p_0..p_K.
+ target_dists = [
+ spec_sampler.distribution(verify_logits[:, i, :]) for i in range(k + 1)
+ ]
+ # Consolidated readback: single sync for all K+1 argmax predictions.
+ target_argmax_arr = mx.argmax(verify_logits, axis=-1).astype(mx.int32)
+ mx.eval(target_argmax_arr)
+ target_argmaxes: list[int] = [
+ int(t) for t in list(target_argmax_arr.squeeze(0).tolist())
+ ]
+ num_accepted = 0
+ for i in range(k):
+ draft_token = draft_tokens[i]
+ target_prob = float(target_dists[i][0, draft_token].item())
+ draft_prob = float(draft_distributions[i][0, draft_token].item())
+ if _accept_draft_token(
+ is_greedy=spec_sampler.is_greedy,
+ draft_token=draft_token,
+ target_argmax=target_argmaxes[i],
+ target_prob=target_prob,
+ draft_prob=draft_prob,
+ uniform=accept_uniforms[i],
+ ):
+ num_accepted += 1
+ continue
+ # Rejected at position i: replacement from residual (p_i - q_i)+.
+ replacement = spec_sampler.residual_token(
+ target_dists[i], draft_distributions[i]
+ )
+ return num_accepted, replacement
+ # Full accept: bonus sampled from the target distribution at position K.
+ bonus = spec_sampler.sample_from_distribution(target_dists[k])
+ return num_accepted, bonus
+
+
+def _native_mtp_model_type(model: Model) -> str:
+ """Best-effort model_type lookup for vendored Qwen3.5/3.6 models."""
+ outer_args = getattr(model, "args", None)
+ outer_model_type = getattr(outer_args, "model_type", None)
+ if isinstance(outer_model_type, str):
+ return outer_model_type
+ language_model = getattr(model, "language_model", None)
+ args = getattr(language_model, "args", None)
+ model_type = getattr(args, "model_type", None)
+ return model_type if isinstance(model_type, str) else ""
+
+
+def _requires_sequential_gdn_path(model: Model) -> bool:
+ """Return True for 35B-A3B/MoE native MTP, where batched GDN state drifts.
+
+ The dense 27B path has benchmarked well with the batched verifier. The
+ qwen3_5_moe 35B path does not: batched prefill changes the recurrent
+ GatedDeltaNet state enough that MTP drafts the wrong first token, and a
+ batched verifier can produce wrong later-position bonus tokens. Keep this
+ slow path scoped to MoE until we have a fused GDN scan that is numerically
+ equivalent to token-by-token recurrent updates.
+ """
+ return _native_mtp_model_type(model) in {"qwen3_5_moe", "qwen3_5_moe_text"}
+
+
+def _gdn_state_history_commit_default(
+ *,
+ model_type: str,
+ moe_verifier_policy: _MoeVerifierPolicy,
+) -> bool:
+ """Return whether GDN prefix-state commit is default-on for this model."""
+ if model_type in _DENSE_GDN_STATE_HISTORY_MODEL_TYPES:
+ return True
+ return (
+ model_type in _MOE_GDN_STATE_HISTORY_MODEL_TYPES
+ and moe_verifier_policy == "route_locked"
+ )
+
+
+def _gdn_state_history_kernel_supported(model: Model) -> bool:
+ """Return True when the loaded GDN layers fit the state-history kernel."""
+ language_model = getattr(model, "language_model", None)
+ inner = getattr(language_model, "model", None)
+ layers = getattr(inner, "layers", None)
+ if not isinstance(layers, list):
+ return False
+ for layer in layers:
+ if not bool(getattr(layer, "is_linear", False)):
+ continue
+ gdn = getattr(layer, "linear_attn", None)
+ head_k_dim = getattr(gdn, "head_k_dim", None)
+ head_v_dim = getattr(gdn, "head_v_dim", None)
+ if not isinstance(head_k_dim, int) or not isinstance(head_v_dim, int):
+ return False
+ return head_k_dim >= 32 and head_k_dim % 32 == 0 and head_v_dim > 0
+ return False
+
+
+def _moe_verifier_policy() -> _MoeVerifierPolicy:
+ """Return the qwen3_5_moe verifier policy for native-MTP experiments.
+
+ ``route_locked`` is the optimized production default for Qwen3.6 MoE.
+ ``safe`` is the coherent reference path: token-by-token verification.
+ ``hybrid_full_accept`` first tries a batched verifier and commits it only
+ when every draft token matches; partial/miss rounds fall back to ``safe``.
+ ``batched`` is the research fast path for falsifying whether the batched
+ verifier can be made coherent under the current cache repair rules.
+ ``row_moe`` keeps attention/GatedDeltaNet batched but evaluates sparse
+ MoE blocks one token row at a time, matching the token-by-token verifier
+ while avoiding full sequential layer execution. ``route_locked`` row-steps
+ only the router gate, then runs the fixed-index expert and shared-expert
+ work batched.
+ """
+ raw = os.environ.get("EXO_NATIVE_MTP_MOE_VERIFY", "route_locked")
+ normalized = raw.strip().lower().replace("-", "_")
+ if normalized in {
+ "safe",
+ "hybrid_full_accept",
+ "batched",
+ "row_moe",
+ "route_locked",
+ }:
+ return cast(_MoeVerifierPolicy, normalized)
+ logger.warning(
+ f"[native MTP] unknown EXO_NATIVE_MTP_MOE_VERIFY={raw!r}; using 'safe'"
+ )
+ return "safe"
+
+
+def _env_flag(name: str, *, default: bool) -> bool:
+ raw = os.environ.get(name)
+ if raw is None:
+ return default
+ return raw.strip().lower() in {"1", "true", "yes", "on"}
+
+
+def _env_int(name: str, *, default: int, minimum: int) -> int:
+ raw = os.environ.get(name)
+ if raw is None:
+ return default
+ try:
+ value = int(raw)
+ except ValueError:
+ logger.warning(f"[native MTP] invalid {name}={raw!r}; using {default}")
+ return default
+ return max(minimum, value)
+
+
+def _env_float(name: str, *, default: float) -> float:
+ raw = os.environ.get(name)
+ if raw is None:
+ return default
+ try:
+ return float(raw)
+ except ValueError:
+ logger.warning(f"[native MTP] invalid {name}={raw!r}; using {default}")
+ return default
+
+
+_ASYNC_DRAFT_EVAL_ENABLED = _env_flag("EXO_NATIVE_MTP_ASYNC_DRAFT_EVAL", default=True)
+
+
+def _cache_state_arrays(cache: Sequence[Any]) -> list[mx.array]:
+ """Return materialisable cache state arrays without forcing evaluation."""
+ state_arrays: list[mx.array] = []
+ for entry in cache:
+ state_obj: object = getattr(entry, "state", None)
+ if isinstance(state_obj, mx.array):
+ state_arrays.append(state_obj)
+ elif isinstance(state_obj, (list, tuple)):
+ for inner in cast(Sequence[object], state_obj):
+ if isinstance(inner, mx.array):
+ state_arrays.append(inner)
+ return state_arrays
+
+
+def _materialize_cache_state(cache: Sequence[Any]) -> None:
+ """Force ``mx.eval`` on every materialisable state in ``cache``.
+
+ KVCache exposes ``.state`` as ``mx.array``; some custom caches use a
+ tuple/list of arrays. We walk both shapes so the cache writes
+ complete before the next forward consumes them.
+ """
+ state_arrays = _cache_state_arrays(cache)
+ if state_arrays:
+ mx.eval(state_arrays)
+
+
+def _async_materialize_cache_state(
+ cache: Sequence[Any],
+ *arrays: mx.array,
+) -> None:
+ """Schedule draft/cache evaluation without waiting immediately.
+
+ MLX evaluation is lazy. K-chain draft work can overlap with Python-side
+ verifier graph construction as long as we schedule the draft tokens and
+ MTP cache writes before building the target verify forward. The first
+ required readback still synchronizes, but it has less work left to wait on.
+ """
+ state_arrays = _cache_state_arrays(cache)
+ eval_arrays = [*arrays, *state_arrays]
+ if eval_arrays:
+ mx.async_eval(eval_arrays)
+
+
+def _is_sparse_moe_block(mlp: object) -> bool:
+ """Return True for Qwen3.5/Next sparse MoE blocks."""
+ return (
+ callable(getattr(mlp, "gate", None))
+ and callable(getattr(mlp, "switch_mlp", None))
+ and isinstance(getattr(mlp, "top_k", None), int)
+ )
+
+
+def _row_stepped_sparse_moe(mlp: object, mlp_input: mx.array) -> mx.array:
+ """Evaluate a Qwen sparse MoE block one verify token at a time."""
+ if not callable(mlp):
+ raise TypeError(f"mlp object is not callable: {type(mlp).__name__}")
+ return mx.concatenate(
+ [mlp(mlp_input[:, i : i + 1, :]) for i in range(int(mlp_input.shape[1]))],
+ axis=1,
+ )
+
+
+def _route_locked_sparse_moe(mlp: object, mlp_input: mx.array) -> mx.array:
+ """Evaluate sparse MoE with exact row-stepped routing and batched experts."""
+ gate = getattr(mlp, "gate", None)
+ switch_mlp = getattr(mlp, "switch_mlp", None)
+ shared_expert = getattr(mlp, "shared_expert", None)
+ shared_expert_gate = getattr(mlp, "shared_expert_gate", None)
+ top_k = getattr(mlp, "top_k", None)
+ if (
+ not callable(gate)
+ or not callable(switch_mlp)
+ or not callable(shared_expert)
+ or not callable(shared_expert_gate)
+ or not isinstance(top_k, int)
+ or top_k <= 0
+ ):
+ raise TypeError(
+ f"mlp object is not a supported sparse MoE: {type(mlp).__name__}"
+ )
+
+ sharding_group = getattr(mlp, "sharding_group", None)
+ x = (
+ sum_gradients(sharding_group)(mlp_input)
+ if sharding_group is not None
+ else mlp_input
+ )
+
+ # Qwen3.6-35B's BF16 router gate takes a different MLX kernel at M>1 than
+ # at M=1. Route rows one-by-one to match token verification exactly, then
+ # keep the expensive expert matmuls batched with fixed indices.
+ gates = mx.concatenate(
+ [gate(x[:, i : i + 1, :]) for i in range(int(x.shape[1]))],
+ axis=1,
+ )
+ gates = mx.softmax(gates, axis=-1, precise=True)
+ indices = mx.argpartition(gates, kth=-top_k, axis=-1)[..., -top_k:]
+ scores = mx.take_along_axis(gates, indices, axis=-1)
+ if bool(getattr(mlp, "norm_topk_prob", False)):
+ scores = scores / scores.sum(axis=-1, keepdims=True)
+
+ expert_output = cast(mx.array, switch_mlp(x, indices))
+ output = (expert_output * scores[..., None]).sum(axis=-2)
+ shared_gate_output = cast(mx.array, shared_expert_gate(x))
+ shared_expert_output = cast(mx.array, shared_expert(x))
+ shared_output = mx.sigmoid(shared_gate_output) * shared_expert_output
+ output = output + shared_output
+
+ if sharding_group is not None:
+ output = mx.distributed.all_sum(output, group=sharding_group)
+ return output
+
+
+def _target_forward_row_moe(
+ *,
+ model: Model,
+ inputs: mx.array,
+ cache: KVCacheType,
+) -> tuple[mx.array, mx.array]:
+ """Target forward with batched non-MoE ops and row-stepped sparse MoE.
+
+ ``scripts/native_mtp_verify_parity_probe.py`` shows qwen3_5_moe's first
+ batched verifier mismatch occurs inside sparse MoE gate/evaluation, while
+ GatedDeltaNet internals and full-attention cache updates stay exact for the
+ K+1 verify window. This forward preserves those batched exact operations
+ and only splits each sparse MoE call across token rows.
+ """
+ text_model = model.language_model
+ inner = text_model.model
+ hidden_states = inner.embed_tokens(inputs)
+ cache_list = cast("list[Any]", cache)
+ fa_mask = create_attention_mask(hidden_states, cache_list[inner.fa_idx])
+ ssm_mask = create_ssm_mask(hidden_states, cache_list[inner.ssm_idx])
+
+ for layer, layer_cache in zip(inner.layers, cache_list, strict=True):
+ selected_mask = ssm_mask if bool(layer.is_linear) else fa_mask
+ normalized = layer.input_layernorm(hidden_states)
+ if bool(layer.is_linear):
+ residual = layer.linear_attn(
+ normalized,
+ mask=selected_mask,
+ cache=layer_cache,
+ )
+ else:
+ residual = layer.self_attn(
+ normalized,
+ mask=selected_mask,
+ cache=layer_cache,
+ )
+ mid = hidden_states + residual
+ mlp_input = layer.post_attention_layernorm(mid)
+ mlp_output = (
+ _row_stepped_sparse_moe(layer.mlp, mlp_input)
+ if _is_sparse_moe_block(layer.mlp)
+ else layer.mlp(mlp_input)
+ )
+ hidden_states = mid + mlp_output
+
+ post = inner.norm(hidden_states)
+ if text_model.args.tie_word_embeddings:
+ logits = inner.embed_tokens.as_linear(post)
+ else:
+ logits = text_model.lm_head(post)
+ return logits, post
+
+
+def _target_forward_route_locked_moe(
+ *,
+ model: Model,
+ inputs: mx.array,
+ cache: KVCacheType,
+) -> tuple[mx.array, mx.array]:
+ """Target forward with row-locked MoE routing and batched expert work."""
+ text_model = model.language_model
+ inner = text_model.model
+ hidden_states = inner.embed_tokens(inputs)
+ cache_list = cast("list[Any]", cache)
+ fa_mask = create_attention_mask(hidden_states, cache_list[inner.fa_idx])
+ ssm_mask = create_ssm_mask(hidden_states, cache_list[inner.ssm_idx])
+
+ for layer, layer_cache in zip(inner.layers, cache_list, strict=True):
+ selected_mask = ssm_mask if bool(layer.is_linear) else fa_mask
+ normalized = layer.input_layernorm(hidden_states)
+ if bool(layer.is_linear):
+ residual = layer.linear_attn(
+ normalized,
+ mask=selected_mask,
+ cache=layer_cache,
+ )
+ else:
+ residual = layer.self_attn(
+ normalized,
+ mask=selected_mask,
+ cache=layer_cache,
+ )
+ mid = hidden_states + residual
+ mlp_input = layer.post_attention_layernorm(mid)
+ mlp_output = (
+ _route_locked_sparse_moe(layer.mlp, mlp_input)
+ if _is_sparse_moe_block(layer.mlp)
+ else layer.mlp(mlp_input)
+ )
+ hidden_states = mid + mlp_output
+
+ post = inner.norm(hidden_states)
+ if text_model.args.tie_word_embeddings:
+ logits = inner.embed_tokens.as_linear(post)
+ else:
+ logits = text_model.lm_head(post)
+ return logits, post
+
+
+GdnPrefixHistory = dict[int, list[Any]]
+
+
+def _lm_head_from_hidden(model: Model, hidden: mx.array) -> mx.array:
+ text_model = model.language_model
+ inner = text_model.model
+ if text_model.args.tie_word_embeddings:
+ return inner.embed_tokens.as_linear(hidden)
+ return text_model.lm_head(hidden)
+
+
+def _target_post_norm_hidden(
+ *,
+ model: Model,
+ inputs: mx.array,
+ cache: KVCacheType,
+) -> mx.array:
+ """Run the target body and return post-norm hidden without ``lm_head``.
+
+ Prompt-cache repair and MTP-cache priming need exact post-norm target
+ hidden states, not vocabulary logits. The vendored Qwen inner model exposes
+ that post-norm body output directly, so these internal paths can avoid a
+ vocabulary projection whose output would be discarded.
+ """
+ text_model = getattr(model, "language_model", None)
+ inner = getattr(text_model, "model", None)
+ if callable(inner):
+ result = inner(
+ inputs,
+ cache=cast("list[Any]", cache),
+ return_hidden=True,
+ )
+ if isinstance(result, tuple):
+ post_hidden = result[0]
+ if isinstance(post_hidden, mx.array):
+ return post_hidden
+ if isinstance(result, mx.array):
+ return result
+ _logits, hidden = cast(
+ tuple[mx.array, mx.array],
+ model(inputs, cache=cache, return_hidden=True),
+ )
+ return hidden
+
+
+def _entry_is_trimmable(entry: Any) -> bool:
+ method = getattr(entry, "is_trimmable", None)
+ if not callable(method):
+ return False
+ try:
+ return bool(method())
+ except Exception:
+ return False
+
+
+def _trim_trimmable_cache_entries(cache: Sequence[Any], n_tokens: int) -> None:
+ if n_tokens <= 0:
+ return
+ for entry in cache:
+ trim = getattr(entry, "trim", None)
+ if _entry_is_trimmable(entry) and callable(trim):
+ trim(n_tokens)
+
+
+def _set_cache_entry_state_lazy(entry: Any, state: Any) -> None:
+ replace_state = getattr(entry, "replace_state", None)
+ if callable(replace_state):
+ replace_state(state)
+ return
+ current = getattr(entry, "state", None)
+ if (
+ isinstance(current, list)
+ and isinstance(state, list)
+ and len(current) == len(state)
+ ):
+ current[:] = state
+ return
+ entry.state = state
+
+
+def _make_gated_delta_state_history_kernel() -> Any:
+ source = """
+ auto n = thread_position_in_grid.z;
+ auto b_idx = n / Hv;
+ auto hv_idx = n % Hv;
+ auto hk_idx = hv_idx / (Hv / Hk);
+ constexpr int n_per_t = Dk / 32;
+
+ auto q_ = q + b_idx * T * Hk * Dk + hk_idx * Dk;
+ auto k_ = k + b_idx * T * Hk * Dk + hk_idx * Dk;
+ auto v_ = v + b_idx * T * Hv * Dv + hv_idx * Dv;
+ y += b_idx * T * Hv * Dv + hv_idx * Dv;
+
+ auto dk_idx = thread_position_in_threadgroup.x;
+ auto dv_idx = thread_position_in_grid.y;
+
+ auto i_state = state_in + (n * Dv + dv_idx) * Dk;
+ auto o_state = state_out + (n * Dv + dv_idx) * Dk;
+
+ float state[n_per_t];
+ for (int i = 0; i < n_per_t; ++i) {
+ auto s_idx = n_per_t * dk_idx + i;
+ state[i] = static_cast(i_state[s_idx]);
+ }
+
+ auto a_ = a + b_idx * T * Hv;
+ auto b_ = b + b_idx * T * Hv;
+
+ for (int t = 0; t < T; ++t) {
+ float a_val = static_cast(a_[hv_idx]);
+ float dt_val = static_cast(dt_bias[hv_idx]);
+ float x_g = a_val + dt_val;
+ float sp = (x_g > 20.0f) ? x_g : log(1.0f + exp(x_g));
+ float g_val = exp(-exp(static_cast(A_log[hv_idx])) * sp);
+ float beta_val = 1.0f / (1.0f + exp(-static_cast(b_[hv_idx])));
+
+ float kv_mem = 0.0f;
+ for (int i = 0; i < n_per_t; ++i) {
+ auto s_idx = n_per_t * dk_idx + i;
+ state[i] = state[i] * g_val;
+ kv_mem += state[i] * k_[s_idx];
+ }
+ kv_mem = simd_sum(kv_mem);
+
+ auto delta = (v_[dv_idx] - kv_mem) * beta_val;
+
+ float out = 0.0f;
+ for (int i = 0; i < n_per_t; ++i) {
+ auto s_idx = n_per_t * dk_idx + i;
+ state[i] = state[i] + k_[s_idx] * delta;
+ out += state[i] * q_[s_idx];
+ }
+ out = simd_sum(out);
+ if (thread_index_in_simdgroup == 0) {
+ y[dv_idx] = static_cast(out);
+ }
+
+ auto hist = state_history
+ + (((b_idx * T + t) * Hv + hv_idx) * Dv + dv_idx) * Dk;
+ for (int i = 0; i < n_per_t; ++i) {
+ auto s_idx = n_per_t * dk_idx + i;
+ hist[s_idx] = static_cast(state[i]);
+ }
+
+ q_ += Hk * Dk;
+ k_ += Hk * Dk;
+ v_ += Hv * Dv;
+ y += Hv * Dv;
+ a_ += Hv;
+ b_ += Hv;
+ }
+ for (int i = 0; i < n_per_t; ++i) {
+ auto s_idx = n_per_t * dk_idx + i;
+ o_state[s_idx] = static_cast(state[i]);
+ }
+ """
+ return mx.fast.metal_kernel(
+ name="gated_delta_step_state_history",
+ input_names=["q", "k", "v", "a", "b", "A_log", "dt_bias", "state_in", "T"],
+ output_names=["y", "state_out", "state_history"],
+ source=source,
+ )
+
+
+_gated_delta_state_history_kernel: Any | None = None
+
+
+def _gated_delta_update_with_state_history(
+ *,
+ q: mx.array,
+ k: mx.array,
+ v: mx.array,
+ a: mx.array,
+ b: mx.array,
+ a_log: mx.array,
+ dt_bias: mx.array,
+ state: mx.array,
+) -> tuple[mx.array, mx.array, mx.array]:
+ global _gated_delta_state_history_kernel
+ kernel = _gated_delta_state_history_kernel
+ if kernel is None:
+ kernel = _make_gated_delta_state_history_kernel()
+ _gated_delta_state_history_kernel = kernel
+ typed_kernel = cast(
+ Callable[..., tuple[mx.array, mx.array, mx.array]],
+ kernel,
+ )
+ batch_size, verify_width, num_key_heads, key_dim = k.shape
+ num_value_heads, value_dim = v.shape[2:]
+ return typed_kernel(
+ inputs=[q, k, v, a, b, a_log, dt_bias, state, verify_width],
+ template=[
+ ("InT", q.dtype),
+ ("StT", state.dtype),
+ ("Dk", key_dim),
+ ("Dv", value_dim),
+ ("Hk", num_key_heads),
+ ("Hv", num_value_heads),
+ ],
+ grid=(32, value_dim, batch_size * num_value_heads),
+ threadgroup=(32, 4, 1),
+ output_shapes=[
+ (batch_size, verify_width, num_value_heads, value_dim),
+ state.shape,
+ (batch_size, verify_width, num_value_heads, value_dim, key_dim),
+ ],
+ output_dtypes=[q.dtype, state.dtype, state.dtype],
+ )
+
+
+def _gdn_linear_attn_with_state_history(
+ gdn: Any,
+ inputs: mx.array,
+ *,
+ cache: Any,
+) -> tuple[mx.array, list[Any]]:
+ """Run one GatedDeltaNet layer and return cache states after each prefix."""
+ batch_size, verify_width, _ = inputs.shape
+ qkv = gdn.in_proj_qkv(inputs)
+ z = gdn.in_proj_z(inputs).reshape(
+ batch_size, verify_width, gdn.num_v_heads, gdn.head_v_dim
+ )
+ b = gdn.in_proj_b(inputs)
+ a = gdn.in_proj_a(inputs)
+
+ if cache is not None and cache[0] is not None:
+ conv_state = cache[0]
+ else:
+ conv_state = mx.zeros(
+ (batch_size, gdn.conv_kernel_size - 1, gdn.conv_dim),
+ dtype=inputs.dtype,
+ )
+
+ conv_input = mx.concatenate([conv_state, qkv], axis=1)
+ n_keep = gdn.conv_kernel_size - 1
+ if cache is not None:
+ cache[0] = mx.contiguous(conv_input[:, -n_keep:, :])
+ conv_prefix_states = [
+ mx.contiguous(conv_input[:, prefix_index + 1 : prefix_index + 1 + n_keep, :])
+ for prefix_index in range(verify_width)
+ ]
+
+ conv_out = nn.silu(gdn.conv1d(conv_input))
+ q, k_arr, v = [
+ t.reshape(batch_size, verify_width, h, d)
+ for t, h, d in zip(
+ mx.split(conv_out, [gdn.key_dim, 2 * gdn.key_dim], -1),
+ [gdn.num_k_heads, gdn.num_k_heads, gdn.num_v_heads],
+ [gdn.head_k_dim, gdn.head_k_dim, gdn.head_v_dim],
+ strict=True,
+ )
+ ]
+
+ state = cache[1] if cache else None
+ if state is None:
+ state = mx.zeros(
+ (batch_size, gdn.num_v_heads, gdn.head_v_dim, gdn.head_k_dim),
+ dtype=mx.float32,
+ )
+ inv_scale = k_arr.shape[-1] ** -0.5
+ q = inv_scale * q * mx.rsqrt((q * q).sum(axis=-1, keepdims=True) + 1e-6)
+ k_arr = k_arr * mx.rsqrt((k_arr * k_arr).sum(axis=-1, keepdims=True) + 1e-6)
+
+ out, state, state_history = _gated_delta_update_with_state_history(
+ q=q,
+ k=k_arr,
+ v=v,
+ a=a,
+ b=b,
+ a_log=gdn.A_log,
+ dt_bias=gdn.dt_bias,
+ state=state,
+ )
+
+ if cache is not None:
+ cache[1] = state
+ cache.advance(verify_width)
+
+ out = gdn.norm(out, z)
+ out = gdn.out_proj(out.reshape(batch_size, verify_width, -1))
+ prefix_states = [
+ [conv_prefix_states[prefix_index], state_history[:, prefix_index, ...]]
+ for prefix_index in range(verify_width)
+ ]
+ return out, prefix_states
+
+
+def _target_forward_with_gdn_state_history(
+ *,
+ model: Model,
+ inputs: mx.array,
+ cache: KVCacheType,
+ policy: _MoeVerifierPolicy,
+) -> tuple[mx.array, mx.array, GdnPrefixHistory]:
+ """Target forward with normal-width GDN state-history outputs."""
+ text_model = model.language_model
+ inner = text_model.model
+ hidden_states = inner.embed_tokens(inputs)
+ cache_list = cast("list[Any]", cache)
+ fa_mask = create_attention_mask(hidden_states, cache_list[inner.fa_idx])
+ ssm_mask = create_ssm_mask(hidden_states, cache_list[inner.ssm_idx])
+ if isinstance(ssm_mask, mx.array):
+ raise ValueError("state-history GDN verifier does not support SSM masks")
+ prefix_history: GdnPrefixHistory = {}
+
+ for layer_index, (layer, layer_cache) in enumerate(
+ zip(inner.layers, cache_list, strict=True)
+ ):
+ selected_mask = ssm_mask if bool(layer.is_linear) else fa_mask
+ normalized = layer.input_layernorm(hidden_states)
+ if bool(layer.is_linear):
+ if selected_mask is not None:
+ raise ValueError(
+ "state-history GDN verifier requires unmasked decode cache"
+ )
+ residual, layer_history = _gdn_linear_attn_with_state_history(
+ layer.linear_attn,
+ normalized,
+ cache=layer_cache,
+ )
+ prefix_history[layer_index] = layer_history
+ else:
+ residual = layer.self_attn(
+ normalized,
+ mask=selected_mask,
+ cache=layer_cache,
+ )
+ mid = hidden_states + residual
+ mlp_input = layer.post_attention_layernorm(mid)
+ if _is_sparse_moe_block(layer.mlp):
+ if policy == "route_locked":
+ mlp_output = _route_locked_sparse_moe(layer.mlp, mlp_input)
+ elif policy == "row_moe":
+ mlp_output = _row_stepped_sparse_moe(layer.mlp, mlp_input)
+ else:
+ raise ValueError(
+ f"unsupported state-history MoE verifier policy: {policy}"
+ )
+ else:
+ mlp_output = layer.mlp(mlp_input)
+ hidden_states = mid + mlp_output
+
+ post = inner.norm(hidden_states)
+ logits = _lm_head_from_hidden(model, post)
+ return logits, post, prefix_history
+
+
+def _commit_gdn_prefix_state(
+ *,
+ cache: KVCacheType,
+ prefix_history: GdnPrefixHistory,
+ accept_count: int,
+ k: int,
+) -> None:
+ """Commit verifier cache to ``[current, drafts[:accept_count]]``."""
+ cache_list = cast("list[Any]", cache)
+ _trim_trimmable_cache_entries(cache_list, k - accept_count)
+ for cache_index, states_by_prefix in prefix_history.items():
+ _set_cache_entry_state_lazy(
+ cache_list[cache_index],
+ states_by_prefix[accept_count],
+ )
+
+
+def prime_mtp_cache_from_prompt(
+ *,
+ model: Model,
+ full_prompt_tokens: Sequence[int],
+ mtp_cache: list[Any],
+) -> int:
+ """Walk prefill positions through the main model to prime the MTP cache.
+
+ Forwards ``prompt[:-1]`` through a fresh main-model cache, captures
+ post-norm hidden states for every prefill position, then feeds
+ them through ``model.mtp_update_cache`` in a single batched call.
+ After priming, the MTP attention cache holds K/V for positions
+ ``0..N-2``; the first ``mtp_forward`` in the K-chain draft loop
+ adds position ``N-1`` naturally.
+
+ Returns the number of positions primed (``N-1``), or ``0`` when
+ the prompt is shorter than 2 tokens.
+
+ Known cost: this re-walks the prompt through a fresh main-model cache
+ to prime the separate MTP cache, so the prompt is prefilled twice (once
+ by ``exo.prefill`` for the main KV cache, once here for the MTP cache).
+ Folding the two into a single capture-prefill is a future optimisation
+ at the :func:`exo.worker.engines.mlx.prefill` seam, not in this drafter.
+ """
+ n = len(full_prompt_tokens)
+ if n < 2:
+ return 0
+ prefill_arr = mx.array(list(full_prompt_tokens[:-1]), dtype=mx.int32)
+ next_arr = mx.array(list(full_prompt_tokens[1:]), dtype=mx.int32)
+ fresh_cache: list[Any] = cast("list[Any]", model.make_cache())
+ prefill_hidden = _target_post_norm_hidden(
+ model=model,
+ inputs=prefill_arr[None],
+ cache=fresh_cache,
+ )
+ mx.eval(prefill_hidden)
+ model.mtp_update_cache(prefill_hidden, next_arr[None], mtp_cache=mtp_cache)
+ _materialize_cache_state(mtp_cache)
+ del fresh_cache
+ return n - 1
+
+
+def prime_mtp_cache_from_prompt_incremental(
+ *,
+ model: Model,
+ full_prompt_tokens: Sequence[int],
+ mtp_cache: list[Any],
+) -> int:
+ """Prime MTP cache by stepping target hiddens token by token.
+
+ This is the correctness path for qwen3_5_moe. Its recurrent
+ GatedDeltaNet state is sensitive to batched prefill; the batched
+ :func:`prime_mtp_cache_from_prompt` can produce hiddens that are close
+ enough for ordinary logits but wrong enough for MTP acceptance.
+ """
+ n = len(full_prompt_tokens)
+ if n < 2:
+ return 0
+ fresh_cache: list[Any] = cast("list[Any]", model.make_cache())
+ for i in range(n - 1):
+ token_arr = mx.array([[int(full_prompt_tokens[i])]], dtype=mx.int32)
+ next_arr = mx.array([[int(full_prompt_tokens[i + 1])]], dtype=mx.int32)
+ hidden = _target_post_norm_hidden(
+ model=model,
+ inputs=token_arr,
+ cache=fresh_cache,
+ )
+ mtp_hidden = model.mtp_update_cache(
+ hidden[:, -1:, :],
+ next_arr,
+ mtp_cache=mtp_cache,
+ )
+ mx.eval(hidden, mtp_hidden)
+ _materialize_cache_state(mtp_cache)
+ del fresh_cache
+ return n - 1
+
+
+def rebuild_prompt_cache_and_prime_mtp_cache_incremental(
+ *,
+ model: Model,
+ full_prompt_tokens: Sequence[int],
+ prompt_cache: KVCacheType,
+ mtp_cache: list[Any],
+) -> tuple[int, int]:
+ """Rebuild target prompt cache while priming MTP cache in one pass."""
+ n = len(full_prompt_tokens)
+ cache_list = cast("list[Any]", prompt_cache)
+ if n < 2:
+ cache_list[:] = cast("list[Any]", model.make_cache())
+ return 0, 0
+
+ tokens_to_prefill = max(0, n - 2)
+ fresh_cache: list[Any] = cast("list[Any]", model.make_cache())
+ prompt_boundary_snapshot: Any | None = None
+ for i in range(n - 1):
+ token_arr = mx.array([[int(full_prompt_tokens[i])]], dtype=mx.int32)
+ next_arr = mx.array([[int(full_prompt_tokens[i + 1])]], dtype=mx.int32)
+ hidden = _target_post_norm_hidden(
+ model=model,
+ inputs=token_arr,
+ cache=fresh_cache,
+ )
+ mtp_hidden = model.mtp_update_cache(
+ hidden[:, -1:, :],
+ next_arr,
+ mtp_cache=mtp_cache,
+ )
+ mx.eval(hidden, mtp_hidden)
+ if i + 1 == tokens_to_prefill:
+ prompt_boundary_snapshot = snapshot_untrimmable_cache_lazy(fresh_cache)
+
+ if tokens_to_prefill == 0:
+ cache_list[:] = cast("list[Any]", model.make_cache())
+ else:
+ _trim_trimmable_cache_entries(fresh_cache, (n - 1) - tokens_to_prefill)
+ if prompt_boundary_snapshot is None:
+ raise RuntimeError("missing prompt boundary snapshot")
+ restore_cache(fresh_cache, prompt_boundary_snapshot)
+ cache_list[:] = fresh_cache
+ _materialize_cache_state(mtp_cache)
+ return tokens_to_prefill, n - 1
+
+
+def rebuild_prompt_cache_incremental(
+ *,
+ model: Model,
+ full_prompt_tokens: Sequence[int],
+ prompt_cache: KVCacheType,
+) -> int:
+ """Replace ``prompt_cache`` with token-by-token state through prompt[:-2].
+
+ Exo's standard prefill uses mlx-lm's chunked prefill and then trims the
+ decode tail. That is fast, but qwen3_5_moe's GatedDeltaNet recurrent
+ state is not equivalent to a token-by-token cache under that path. Native
+ MTP needs the exact recurrent hidden/cache state, so the 35B MoE safe path
+ rebuilds it incrementally before the two-token decode tail is consumed.
+ """
+ fresh_cache: list[Any] = cast("list[Any]", model.make_cache())
+ tokens_to_prefill = max(0, len(full_prompt_tokens) - 2)
+ for token in full_prompt_tokens[:tokens_to_prefill]:
+ token_arr = mx.array([[int(token)]], dtype=mx.int32)
+ hidden = _target_post_norm_hidden(
+ model=model,
+ inputs=token_arr,
+ cache=fresh_cache,
+ )
+ mx.eval(hidden)
+ cache_list = cast("list[Any]", prompt_cache)
+ cache_list[:] = fresh_cache
+ return tokens_to_prefill
+
+
+@final
+class NativeMTPDrafter:
+ """Drafter-protocol shim for native MTP (vendored Qwen3.5/3.6).
+
+ Constructor takes ``k`` explicitly; the builder sources it from
+ :attr:`exo.shared.models.model_cards.NativeMTPConfig.default_k`
+ (overridable per-request via ``TaskParams.num_draft_tokens``).
+
+ The drafter reports ``mode="model"`` so existing telemetry counts
+ it as drafter-active. The architecture-level distinction
+ (``native_mtp`` vs sibling-LM ``model``) is surfaced via the
+ builder log line at dispatch time; a future :data:`DraftMode`
+ extension can expose it via ``GenerationStats`` if needed.
+ """
+
+ def __init__(self, *, k: int) -> None:
+ if k < 1:
+ raise ValueError(f"k must be >= 1, got {k}")
+ self._k: int = k
+ self._metrics: dict[str, int] = {
+ "proposed_draft_tokens": 0,
+ "accepted_draft_tokens": 0,
+ "spec_decode_rounds": 0,
+ }
+
+ @property
+ def mode(self) -> str:
+ return "model"
+
+ @property
+ def num_draft_tokens(self) -> int:
+ return self._k
+
+ def metrics(self) -> dict[str, int]:
+ return dict(self._metrics)
+
+ def stream(
+ self,
+ *,
+ model: Model,
+ tokenizer: TokenizerWrapper,
+ prompt: mx.array,
+ context_tokens: Sequence[int],
+ prompt_cache: KVCacheType,
+ max_tokens: int,
+ sampler: Callable[[mx.array], mx.array],
+ logits_processors: Sequence[Callable[[mx.array, mx.array], mx.array]],
+ prefill_step_size: int = 1,
+ temperature: float = 0.0,
+ top_p: float = 0.0,
+ top_k: int = 0,
+ min_p: float = 0.0,
+ min_tokens_to_keep: int = 1,
+ ) -> Generator[GenerationResponse, None, None]:
+ # LOSSLESS / distribution-preserving native MTP. The K-chain drafts
+ # are SAMPLED from the MTP head's per-step distribution ``q`` (after
+ # the request's temperature/top_p/top_k/min_p transforms) and the
+ # batched/sequential verify accepts each draft via probability-ratio
+ # (Leviathan/Chen) rejection sampling against the target distribution
+ # ``p`` -- so the emitted marginal matches the request's sampler at any
+ # temperature. With ``temperature == 0`` the scheme collapses to exact
+ # greedy (accept iff draft == argmax(p)), token-identical to plain
+ # greedy decode. See :class:`_SpecSampler` for the reconstruction of
+ # ``p``/``q`` from the raw sampling params (the request only hands us a
+ # ``sampler`` callable, which is insufficient -- we need the
+ # distribution, so the params are threaded in from ``mlx_generate``).
+ #
+ # ``logits_processors`` (repetition/presence/frequency penalties,
+ # logit bias) remain UNHANDLED on this path: they are stateful over
+ # emitted-token history and cannot be replayed consistently across the
+ # chained draft steps and the batched verify without breaking the
+ # lossless guarantee; applying them inconsistently is worse than not at
+ # all. ``sampler`` is likewise unused -- the params drive the
+ # reconstructed distribution directly. ``prefill_step_size`` is unused
+ # (MTP-cache priming is batched).
+ del sampler, logits_processors, prefill_step_size
+ spec_sampler = _SpecSampler(
+ temperature=temperature,
+ top_p=top_p,
+ top_k=top_k,
+ min_p=min_p,
+ min_tokens_to_keep=min_tokens_to_keep,
+ )
+ if not is_native_mtp_dispatchable(model):
+ raise RuntimeError(
+ "NativeMTPDrafter.stream called on a model without "
+ "mtp_forward / make_mtp_cache / mtp_update_cache. The "
+ "dispatch-site gate (is_native_mtp_dispatchable) should "
+ "have caught this."
+ )
+
+ detokenizer = tokenizer.detokenizer
+ detokenizer.reset()
+ eos_ids = _eos_ids_from_tokenizer(tokenizer)
+
+ prompt_tail_size = int(prompt.size)
+
+ mtp_cache: list[Any] = cast("list[Any]", model.make_mtp_cache())
+ if not mtp_cache:
+ raise RuntimeError(
+ "model.make_mtp_cache() returned empty list -- model "
+ "has no MTP head. Loader/dispatch drift."
+ )
+
+ model_type = _native_mtp_model_type(model)
+ sequential_gdn_path = _requires_sequential_gdn_path(model)
+ moe_verifier_policy = _moe_verifier_policy() if sequential_gdn_path else "safe"
+ if sequential_gdn_path and moe_verifier_policy != "safe":
+ logger.info(
+ f"[native MTP] qwen3_5_moe verifier policy={moe_verifier_policy!r}"
+ )
+ gdn_state_history_default = _gdn_state_history_commit_default(
+ model_type=model_type,
+ moe_verifier_policy=moe_verifier_policy,
+ )
+ gdn_state_history_supported = (
+ _gdn_state_history_kernel_supported(model)
+ if gdn_state_history_default
+ else False
+ )
+ gdn_state_history_commit = (
+ gdn_state_history_default
+ and gdn_state_history_supported
+ # Default-on for dense qwen3_5 and for qwen3_5_moe route-locked
+ # native MTP after broad production stream parity. Set the env var
+ # to 0 as a kill switch.
+ and _env_flag("EXO_NATIVE_MTP_GDN_STATE_HISTORY_COMMIT", default=True)
+ )
+ if gdn_state_history_default and not gdn_state_history_supported:
+ logger.warning(
+ f"[native MTP] {model_type or 'unknown model'} GDN "
+ "state-history prefix commit disabled: unsupported GDN head shape"
+ )
+ if gdn_state_history_commit:
+ logger.info(
+ f"[native MTP] {model_type or 'unknown model'} GDN "
+ "state-history prefix commit enabled"
+ )
+ row_moe_adaptive = (
+ sequential_gdn_path
+ and moe_verifier_policy == "row_moe"
+ and _env_flag("EXO_NATIVE_MTP_ROW_MOE_ADAPTIVE", default=True)
+ )
+ row_moe_adaptive_min_rounds = _env_int(
+ "EXO_NATIVE_MTP_ROW_MOE_ADAPTIVE_MIN_ROUNDS",
+ default=8,
+ minimum=1,
+ )
+ row_moe_adaptive_min_acceptance = _env_float(
+ "EXO_NATIVE_MTP_ROW_MOE_ADAPTIVE_MIN_ACCEPTANCE",
+ default=0.55,
+ )
+ row_moe_adaptive_switched = False
+ primed_positions = 0
+ if context_tokens:
+ prime_start = time.perf_counter()
+ if sequential_gdn_path:
+ repaired_positions, primed_positions = (
+ rebuild_prompt_cache_and_prime_mtp_cache_incremental(
+ model=model,
+ full_prompt_tokens=context_tokens,
+ prompt_cache=prompt_cache,
+ mtp_cache=mtp_cache,
+ )
+ )
+ prime_elapsed_ms = (time.perf_counter() - prime_start) * 1000
+ logger.info(
+ "[native MTP] qwen3_5_moe safe path rebuilt target cache "
+ f"from {repaired_positions} prompt positions and primed "
+ f"MTP cache from {primed_positions} prompt positions in "
+ f"{prime_elapsed_ms:.1f}ms"
+ )
+ else:
+ primed_positions = prime_mtp_cache_from_prompt(
+ model=model,
+ full_prompt_tokens=context_tokens,
+ mtp_cache=mtp_cache,
+ )
+ prime_elapsed_ms = (time.perf_counter() - prime_start) * 1000
+ logger.info(
+ f"[native MTP] primed MTP cache from {primed_positions} "
+ f"prompt positions in {prime_elapsed_ms:.1f}ms"
+ )
+
+ # First main-model forward: decode_in covers the 2-token prefill
+ # tail. The caller has aligned ``prompt_cache`` to
+ # ``full_prompt[:-2]`` via exo.prefill + trim(2); this brings
+ # the cache to ``len-1`` via position N-2 and returns logits
+ # at position N-1 (which we sample for the first emitted token).
+ decode_in = prompt.astype(mx.int32)[None]
+ if int(decode_in.shape[1]) > 1:
+ # Qwen3.5/3.6 GatedDeltaNet can produce different last-position
+ # logits for a small batched tail than for the same tokens stepped
+ # through the recurrent state. Mirror mlx-lm's prefill/decode split:
+ # advance every tail token except the last, then sample from the last.
+ tail_prefix_hidden = _target_post_norm_hidden(
+ model=model,
+ inputs=decode_in[:, :-1],
+ cache=prompt_cache,
+ )
+ first_logits, first_hidden = cast(
+ tuple[mx.array, mx.array],
+ model(decode_in[:, -1:], cache=prompt_cache, return_hidden=True),
+ )
+ mx.eval(tail_prefix_hidden, first_logits, first_hidden)
+ else:
+ first_logits, first_hidden = cast(
+ tuple[mx.array, mx.array],
+ model(decode_in, cache=prompt_cache, return_hidden=True),
+ )
+ mx.eval(first_logits, first_hidden)
+ # Sample the first emitted token from the request's distribution
+ # (argmax when greedy) instead of a hard argmax.
+ current_token_arr = spec_sampler.sample_from_distribution(
+ spec_sampler.distribution(first_logits[:, -1, :])
+ )
+ mx.eval(current_token_arr)
+ current_token = _token_array_to_int(current_token_arr)
+ current_hidden = first_hidden[:, -1:, :]
+
+ proposed = 0
+ accepted = 0
+ rounds = 0
+ ntoks = 0
+ finish_reason: str | None = None
+ prompt_tps_local = 0.0
+
+ tic = time.perf_counter()
+ if current_token in eos_ids:
+ detokenizer.add_token(current_token)
+ detokenizer.finalize()
+ yield GenerationResponse(
+ text=detokenizer.last_segment,
+ token=current_token,
+ logprobs=mx.zeros((1,)),
+ from_draft=False,
+ prompt_tokens=prompt_tail_size,
+ prompt_tps=0.0,
+ generation_tokens=1,
+ generation_tps=0.0,
+ peak_memory=mx.get_peak_memory() / 1e9,
+ finish_reason="stop",
+ )
+ return
+
+ prompt_time = time.perf_counter() - tic
+ prompt_tps_local = prompt_tail_size / prompt_time if prompt_time > 0 else 0.0
+ tic = time.perf_counter()
+ ntoks = 1
+ detokenizer.add_token(current_token)
+ elapsed = time.perf_counter() - tic
+ yield GenerationResponse(
+ text=detokenizer.last_segment,
+ token=current_token,
+ logprobs=mx.zeros((1,)),
+ from_draft=False,
+ prompt_tokens=prompt_tail_size,
+ prompt_tps=prompt_tps_local,
+ generation_tokens=ntoks,
+ generation_tps=ntoks / elapsed if elapsed > 0 else 0.0,
+ peak_memory=mx.get_peak_memory() / 1e9,
+ finish_reason=None,
+ )
+
+ # Main loop: cache offset = N, current_{token, hidden} from
+ # most recent emission.
+ while ntoks < max_tokens:
+ # ---- Draft K via chained mtp_forward ----
+ # Each draft token is SAMPLED from its MTP-step distribution ``q``
+ # (argmax when greedy); ``q`` is retained per step so the verify
+ # step can run the probability-ratio acceptance test.
+ draft_token_arrays: list[mx.array] = []
+ draft_distributions: list[mx.array] = []
+ chain_hidden = current_hidden
+ chain_tok_arr = current_token_arr
+ for _ in range(self._k):
+ mtp_logits, mtp_hidden = cast(
+ tuple[mx.array, mx.array],
+ model.mtp_forward(
+ chain_hidden,
+ chain_tok_arr,
+ mtp_cache=mtp_cache,
+ return_hidden=True,
+ ),
+ )
+ draft_dist = spec_sampler.distribution(mtp_logits[:, -1, :])
+ next_draft_arr = spec_sampler.sample_from_distribution(draft_dist)
+ draft_token_arrays.append(next_draft_arr)
+ draft_distributions.append(draft_dist)
+ chain_hidden = mtp_hidden[:, -1:, :]
+ chain_tok_arr = next_draft_arr
+ # Consolidated eval: single GPU dispatch for all draft arrays
+ all_drafts = (
+ draft_token_arrays[0]
+ if self._k == 1
+ else mx.concatenate(draft_token_arrays, axis=1)
+ )
+ if _ASYNC_DRAFT_EVAL_ENABLED:
+ _async_materialize_cache_state(
+ mtp_cache,
+ all_drafts,
+ *draft_distributions,
+ )
+ else:
+ mx.eval(all_drafts, *draft_distributions)
+ _materialize_cache_state(mtp_cache)
+ proposed += self._k
+ rounds += 1
+
+ # Pre-draw the per-position uniforms for the ratio test up-front so
+ # the acceptance walk needs no extra GPU sync per token.
+ accept_uniforms: list[float] = (
+ [0.0] * self._k
+ if spec_sampler.is_greedy
+ else cast(
+ "list[float]",
+ mx.random.uniform(shape=(self._k,)).tolist(),
+ )
+ )
+
+ # ---- Verify: target forward on [current_token, *drafts] ----
+ # Consolidated readback: single CPU-GPU sync for all K draft tokens
+ # instead of K sequential .item() calls (~50-150us each).
+ draft_tokens: list[int] = [
+ int(t) for t in list(all_drafts.reshape(-1).tolist())
+ ]
+ num_accepted_this_round = 0
+ bonus_tok = current_token
+ next_current_hidden = current_hidden
+ next_current_token_arr = current_token_arr
+ pre_verify_snapshot: Any | None = None
+ prefix_history: GdnPrefixHistory | None = None
+ verify_hidden: mx.array | None = None
+
+ if sequential_gdn_path:
+ used_batched_moe_verify = False
+ if moe_verifier_policy != "safe":
+ verify_in = mx.concatenate(
+ [current_token_arr, *draft_token_arrays], axis=1
+ )
+ if gdn_state_history_commit:
+ (
+ verify_logits,
+ verify_hidden,
+ prefix_history,
+ ) = _target_forward_with_gdn_state_history(
+ model=model,
+ inputs=verify_in,
+ cache=prompt_cache,
+ policy=moe_verifier_policy,
+ )
+ else:
+ pre_verify_snapshot = snapshot_untrimmable_cache_lazy(
+ cast("list[Any]", prompt_cache)
+ )
+ verify_logits, verify_hidden = (
+ _target_forward_row_moe(
+ model=model,
+ inputs=verify_in,
+ cache=prompt_cache,
+ )
+ if moe_verifier_policy == "row_moe"
+ else _target_forward_route_locked_moe(
+ model=model,
+ inputs=verify_in,
+ cache=prompt_cache,
+ )
+ if moe_verifier_policy == "route_locked"
+ else cast(
+ tuple[mx.array, mx.array],
+ model(
+ verify_in,
+ cache=prompt_cache,
+ return_hidden=True,
+ ),
+ )
+ )
+ mx.eval(verify_logits, verify_hidden)
+ batched_accepted, batched_replacement_arr = _spec_accept_walk(
+ spec_sampler=spec_sampler,
+ verify_logits=verify_logits,
+ draft_tokens=draft_tokens,
+ draft_distributions=draft_distributions,
+ accept_uniforms=accept_uniforms,
+ k=self._k,
+ )
+ mx.eval(batched_replacement_arr)
+
+ can_commit_batched = (
+ moe_verifier_policy in {"batched", "row_moe", "route_locked"}
+ or batched_accepted == self._k
+ )
+ if can_commit_batched:
+ num_accepted_this_round = batched_accepted
+ bonus_tok = _token_array_to_int(batched_replacement_arr)
+ next_current_token_arr = batched_replacement_arr
+ next_current_hidden = verify_hidden[
+ :,
+ num_accepted_this_round : num_accepted_this_round + 1,
+ :,
+ ]
+ used_batched_moe_verify = True
+ else:
+ assert pre_verify_snapshot is not None
+ rollback_after_verify(
+ cast("list[Any]", prompt_cache),
+ pre_verify_snapshot,
+ verified_tokens=self._k + 1,
+ )
+
+ if not used_batched_moe_verify:
+ # qwen3_5_moe safe verifier: feed only the retained prefix
+ # token-by-token. Batched verification corrupts later-
+ # position GatedDeltaNet state/logits on 35B-A3B, producing
+ # coherent-looking but wrong bonuses such as "if items"
+ # instead of "if len". Each position runs the same
+ # probability-ratio acceptance test as the batched paths.
+ for i in range(self._k + 1):
+ verify_token_arr = (
+ current_token_arr if i == 0 else draft_token_arrays[i - 1]
+ )
+ verify_logits_i, verify_hidden_i = cast(
+ tuple[mx.array, mx.array],
+ model(
+ verify_token_arr,
+ cache=prompt_cache,
+ return_hidden=True,
+ ),
+ )
+ target_dist_i = spec_sampler.distribution(
+ verify_logits_i[:, -1, :]
+ )
+ mx.eval(target_dist_i, verify_hidden_i)
+ next_current_hidden = verify_hidden_i[:, -1:, :]
+ if i < self._k:
+ draft_token = draft_tokens[i]
+ target_argmax = int(
+ mx.argmax(verify_logits_i[:, -1, :], axis=-1).item()
+ )
+ target_prob = float(target_dist_i[0, draft_token].item())
+ draft_prob = float(
+ draft_distributions[i][0, draft_token].item()
+ )
+ if _accept_draft_token(
+ is_greedy=spec_sampler.is_greedy,
+ draft_token=draft_token,
+ target_argmax=target_argmax,
+ target_prob=target_prob,
+ draft_prob=draft_prob,
+ uniform=accept_uniforms[i],
+ ):
+ num_accepted_this_round += 1
+ continue
+ # Rejected: replacement from residual (p_i - q_i)+.
+ replacement_arr = spec_sampler.residual_token(
+ target_dist_i, draft_distributions[i]
+ )
+ else:
+ # Full accept: bonus sampled from the target dist.
+ replacement_arr = spec_sampler.sample_from_distribution(
+ target_dist_i
+ )
+ mx.eval(replacement_arr)
+ bonus_tok = _token_array_to_int(replacement_arr)
+ next_current_token_arr = replacement_arr
+ break
+ elif num_accepted_this_round < self._k:
+ if gdn_state_history_commit:
+ assert prefix_history is not None
+ _commit_gdn_prefix_state(
+ cache=prompt_cache,
+ prefix_history=prefix_history,
+ accept_count=num_accepted_this_round,
+ k=self._k,
+ )
+ else:
+ assert pre_verify_snapshot is not None
+ rollback_after_verify(
+ cast("list[Any]", prompt_cache),
+ pre_verify_snapshot,
+ verified_tokens=self._k + 1,
+ )
+ repair_in = mx.concatenate(
+ [
+ current_token_arr,
+ *draft_token_arrays[:num_accepted_this_round],
+ ],
+ axis=1,
+ )
+ _repair_logits, _repair_hidden = (
+ _target_forward_row_moe(
+ model=model,
+ inputs=repair_in,
+ cache=prompt_cache,
+ )
+ if moe_verifier_policy == "row_moe"
+ else _target_forward_route_locked_moe(
+ model=model,
+ inputs=repair_in,
+ cache=prompt_cache,
+ )
+ if moe_verifier_policy == "route_locked"
+ else cast(
+ tuple[mx.array, mx.array],
+ model(
+ repair_in,
+ cache=prompt_cache,
+ return_hidden=True,
+ ),
+ )
+ )
+ mx.eval(_repair_logits, _repair_hidden)
+ next_current_hidden = _repair_hidden[:, -1:, :]
+ accepted += num_accepted_this_round
+ if (
+ row_moe_adaptive
+ and not row_moe_adaptive_switched
+ and rounds >= row_moe_adaptive_min_rounds
+ and proposed > 0
+ and (accepted / proposed) < row_moe_adaptive_min_acceptance
+ ):
+ logger.info(
+ "[native MTP] row_moe acceptance "
+ f"{accepted}/{proposed}={accepted / proposed:.1%} after "
+ f"{rounds} rounds; switching verifier to safe for this "
+ "generation"
+ )
+ moe_verifier_policy = "safe"
+ row_moe_adaptive_switched = True
+ else:
+ verify_in = mx.concatenate(
+ [current_token_arr, *draft_token_arrays], axis=1
+ )
+ if gdn_state_history_commit:
+ (
+ verify_logits,
+ verify_hidden,
+ prefix_history,
+ ) = _target_forward_with_gdn_state_history(
+ model=model,
+ inputs=verify_in,
+ cache=prompt_cache,
+ policy=moe_verifier_policy,
+ )
+ else:
+ # Snapshot GDN state BEFORE verify advances cache by K+1 so
+ # we can roll back on partial accept; see module docstring.
+ pre_verify_snapshot = snapshot_untrimmable_cache_lazy(
+ cast("list[Any]", prompt_cache)
+ )
+ verify_logits, verify_hidden = cast(
+ tuple[mx.array, mx.array],
+ model(verify_in, cache=prompt_cache, return_hidden=True),
+ )
+ mx.eval(verify_logits, verify_hidden)
+
+ # Probability-ratio accept-walk (lossless speculative sampling).
+ num_accepted_this_round, replacement_arr = _spec_accept_walk(
+ spec_sampler=spec_sampler,
+ verify_logits=verify_logits,
+ draft_tokens=draft_tokens,
+ draft_distributions=draft_distributions,
+ accept_uniforms=accept_uniforms,
+ k=self._k,
+ )
+ mx.eval(replacement_arr)
+ accepted += num_accepted_this_round
+ bonus_tok = _token_array_to_int(replacement_arr)
+ next_current_token_arr = replacement_arr
+ next_current_hidden = verify_hidden[
+ :, num_accepted_this_round : num_accepted_this_round + 1, :
+ ]
+
+ # Emit accepted drafts.
+ for i in range(num_accepted_this_round):
+ tok = draft_tokens[i]
+ ntoks += 1
+ detokenizer.add_token(tok)
+ elapsed = time.perf_counter() - tic
+ yield GenerationResponse(
+ text=detokenizer.last_segment,
+ token=tok,
+ logprobs=mx.zeros((1,)),
+ from_draft=True,
+ prompt_tokens=prompt_tail_size,
+ prompt_tps=prompt_tps_local,
+ generation_tokens=ntoks,
+ generation_tps=ntoks / elapsed if elapsed > 0 else 0.0,
+ peak_memory=mx.get_peak_memory() / 1e9,
+ finish_reason=None,
+ )
+ if tok in eos_ids:
+ finish_reason = "stop"
+ break
+ if ntoks >= max_tokens:
+ finish_reason = "length"
+ break
+
+ # Emit bonus.
+ if finish_reason is None:
+ ntoks += 1
+ detokenizer.add_token(bonus_tok)
+ elapsed = time.perf_counter() - tic
+ yield GenerationResponse(
+ text=detokenizer.last_segment,
+ token=bonus_tok,
+ logprobs=mx.zeros((1,)),
+ from_draft=False,
+ prompt_tokens=prompt_tail_size,
+ prompt_tps=prompt_tps_local,
+ generation_tokens=ntoks,
+ generation_tps=ntoks / elapsed if elapsed > 0 else 0.0,
+ peak_memory=mx.get_peak_memory() / 1e9,
+ finish_reason=None,
+ )
+ if bonus_tok in eos_ids:
+ finish_reason = "stop"
+
+ # ---- Cache trims ----
+ # Main cache: on partial accept (n < K) we trim back to the
+ # pre-verify offset and restore GDN from the snapshot, then
+ # re-forward [current_token, *drafts[:n]] to land cache at
+ # N + n + 1 with the GDN recurrent state advanced through
+ # the retained tokens only. On full accept (n == K) the
+ # batched verify's cache state at N+K+1 is already on-
+ # trajectory (verified by batched_vs_solo_probe in the
+ # research lane), so no rollback is needed.
+ if not sequential_gdn_path and num_accepted_this_round < self._k:
+ try:
+ if gdn_state_history_commit:
+ assert prefix_history is not None
+ _commit_gdn_prefix_state(
+ cache=prompt_cache,
+ prefix_history=prefix_history,
+ accept_count=num_accepted_this_round,
+ k=self._k,
+ )
+ else:
+ assert pre_verify_snapshot is not None
+ rollback_after_verify(
+ cast("list[Any]", prompt_cache),
+ pre_verify_snapshot,
+ verified_tokens=self._k + 1,
+ )
+ repair_in = mx.concatenate(
+ [
+ current_token_arr,
+ *draft_token_arrays[:num_accepted_this_round],
+ ],
+ axis=1,
+ )
+ _repair_logits, _repair_hidden = cast(
+ tuple[mx.array, mx.array],
+ model(repair_in, cache=prompt_cache, return_hidden=True),
+ )
+ mx.eval(_repair_logits, _repair_hidden)
+ next_current_hidden = _repair_hidden[:, -1:, :]
+ except Exception as exc:
+ logger.warning(
+ f"[native MTP] cache repair after partial accept "
+ f"(n={num_accepted_this_round}/K={self._k}) raised "
+ f"{type(exc).__name__}: {exc}; continuing with stale cache"
+ )
+ # MTP cache: chain drafting added K positions; only the
+ # accepted prefix's K/V corresponds to tokens that materialised
+ # in the main stream. The MTP attention layer is full-attention
+ # (qwen3_5_mtp.py uses ``fa_layer_idx``) so the stock trim is
+ # correct here -- no rollback asymmetry to worry about.
+ mtp_retained = min(self._k, num_accepted_this_round + 1)
+ mtp_trim = self._k - mtp_retained
+ if mtp_trim > 0:
+ try:
+ mlx_trim_prompt_cache(cast(list[object], mtp_cache), mtp_trim)
+ except Exception as exc:
+ logger.warning(
+ f"[native MTP] mtp cache trim({mtp_trim}) raised "
+ f"{type(exc).__name__}: {exc}; continuing with stale cache"
+ )
+
+ if finish_reason is not None or ntoks >= max_tokens:
+ if finish_reason is None and ntoks >= max_tokens:
+ finish_reason = "length"
+ break
+
+ current_token = bonus_tok
+ current_token_arr = next_current_token_arr
+ current_hidden = next_current_hidden
+
+ detokenizer.finalize()
+ if finish_reason is None:
+ finish_reason = "length" if ntoks >= max_tokens else "stop"
+ elapsed = time.perf_counter() - tic
+ acceptance = (accepted / proposed) if proposed > 0 else 0.0
+ self._metrics["proposed_draft_tokens"] = proposed
+ self._metrics["accepted_draft_tokens"] = accepted
+ self._metrics["spec_decode_rounds"] = rounds
+ logger.info(
+ f"[native MTP] K={self._k} stream done: tokens={ntoks}, "
+ f"proposed={proposed}, accepted={accepted}, "
+ f"acceptance={acceptance:.1%}, rounds={rounds}, "
+ f"finish={finish_reason}, primed_positions={primed_positions}"
+ )
+ yield GenerationResponse(
+ text=detokenizer.last_segment,
+ token=current_token,
+ logprobs=mx.zeros((1,)),
+ from_draft=False,
+ prompt_tokens=prompt_tail_size,
+ prompt_tps=prompt_tps_local,
+ generation_tokens=ntoks,
+ generation_tps=ntoks / elapsed if elapsed > 0 else 0.0,
+ peak_memory=mx.get_peak_memory() / 1e9,
+ finish_reason=finish_reason,
+ )
+
+
+__all__ = [
+ "NativeMTPDrafter",
+ "is_native_mtp_dispatchable",
+ "prime_mtp_cache_from_prompt",
+]
diff --git a/src/exo/worker/engines/mlx/mtp_probe.py b/src/exo/worker/engines/mlx/mtp_probe.py
new file mode 100644
index 0000000000..00226a4f26
--- /dev/null
+++ b/src/exo/worker/engines/mlx/mtp_probe.py
@@ -0,0 +1,244 @@
+"""MTP (Multi-Token Prediction) weight probing for Qwen3.6 models.
+
+Detects MTP weights across all known distribution formats:
+- Original HuggingFace: embedded in main shards with ``mtp.*`` prefix (15 tensors)
+- MTPLX quantized: separate ``mtp.safetensors`` file with ``mtp.*`` prefix (29 tensors)
+- oMLX quantized: embedded in main shards with ``language_model.mtp.*`` prefix (29 tensors)
+- mlx-community: stripped during quantization (0 tensors, unrecoverable)
+
+The probe is called before model loading so the loader can:
+1. Inject separate MTP weights into the main weight dict (MTPLX format)
+2. Patch ``sanitize()`` to stop stripping MTP tensors
+3. Ensure norm weight shifting fires correctly
+"""
+
+from __future__ import annotations
+
+import json
+from dataclasses import dataclass
+from enum import Enum, auto
+from pathlib import Path
+from typing import Any, Iterable, cast
+
+from loguru import logger
+
+# pyright: reportAny=false
+# This module probes safetensors/JSON data which are untyped.
+
+
+class MtpFormat(Enum):
+ """How MTP weights are stored in this model."""
+
+ ORIGINAL_EMBEDDED = auto()
+ """Original HuggingFace format: ``mtp.*`` prefix, embedded in main shards (15 tensors, BF16)."""
+
+ MTPLX_SEPARATE_FILE = auto()
+ """MTPLX format: ``mtp.*`` prefix, separate ``mtp.safetensors`` file (29 tensors, quantized)."""
+
+ OMLX_EMBEDDED = auto()
+ """oMLX format: ``language_model.mtp.*`` prefix, embedded in main shards (29 tensors, quantized)."""
+
+ STRIPPED = auto()
+ """MTP weights were stripped during quantization (e.g. mlx-community). Unrecoverable."""
+
+
+@dataclass(frozen=True)
+class MtpProbeResult:
+ """Result of probing a model directory for MTP weights."""
+
+ model_declares_mtp: bool
+ """Whether ``config.json`` declares ``mtp_num_hidden_layers > 0``."""
+
+ mtp_tensors_found: bool
+ """Whether MTP weight tensors were found on disk."""
+
+ mtp_format: MtpFormat | None
+ """Detected storage format, or ``None`` if model doesn't declare MTP."""
+
+ mtp_count: int
+ """Number of MTP tensors found."""
+
+ mtp_path: str | None
+ """Path to MTP weights (either ``mtp.safetensors`` for separate file,
+ or description of embedded location)."""
+
+ mtp_tensor_keys: tuple[str, ...]
+ """Names of MTP tensor keys found (empty if STRIPPED or not found)."""
+
+ @property
+ def is_recoverable(self) -> bool:
+ """Whether MTP weights can be loaded and used."""
+ return (
+ self.mtp_tensors_found
+ and self.mtp_format is not None
+ and self.mtp_format != MtpFormat.STRIPPED
+ )
+
+
+def probe_mtp_weights(model_path: Path | str) -> MtpProbeResult:
+ """Probe a model directory for MTP weights in all known locations.
+
+ Checks in order:
+ 1. ``config.json`` → ``mlx_lm_extra_tensors.mtp_file`` (MTPLX separate file)
+ 2. ``model.safetensors.index.json`` → ``mtp.*`` prefix (Original HuggingFace)
+ 3. ``model.safetensors.index.json`` → ``language_model.mtp.*`` prefix (oMLX)
+ 4. ``mtp.safetensors`` file on disk (fallback, no config declaration)
+
+ Args:
+ model_path: Path to the model directory.
+
+ Returns:
+ ``MtpProbeResult`` describing what was found.
+ """
+ model_dir = Path(model_path)
+ config_path = model_dir / "config.json"
+
+ # Step 1: Check config for MTP declaration
+ model_declares_mtp = False
+ if config_path.exists():
+ try:
+ cfg = json.loads(config_path.read_text())
+ tc = cfg.get("text_config", {})
+ mtp_layers = tc.get("mtp_num_hidden_layers", 0)
+ model_declares_mtp = bool(mtp_layers and mtp_layers > 0)
+ except (json.JSONDecodeError, OSError):
+ pass
+
+ if not model_declares_mtp:
+ return MtpProbeResult(
+ model_declares_mtp=False,
+ mtp_tensors_found=False,
+ mtp_format=None,
+ mtp_count=0,
+ mtp_path=None,
+ mtp_tensor_keys=(),
+ )
+
+ # Step 2: Check for MTPLX separate file via mlx_lm_extra_tensors
+ try:
+ cfg = json.loads(config_path.read_text())
+ extra = cfg.get("mlx_lm_extra_tensors", {})
+ mtp_file_name = extra.get("mtp_file")
+ if mtp_file_name:
+ mtp_file = model_dir / mtp_file_name
+ if mtp_file.exists():
+ keys = _safetensors_keys(mtp_file)
+ mtp_keys = tuple(k for k in keys if "mtp." in k)
+ if mtp_keys:
+ return MtpProbeResult(
+ model_declares_mtp=True,
+ mtp_tensors_found=True,
+ mtp_format=MtpFormat.MTPLX_SEPARATE_FILE,
+ mtp_count=len(mtp_keys),
+ mtp_path=str(mtp_file),
+ mtp_tensor_keys=mtp_keys,
+ )
+ except (json.JSONDecodeError, OSError):
+ pass
+
+ # Step 3: Check weight map index for embedded MTP tensors
+ index_path = model_dir / "model.safetensors.index.json"
+ if index_path.exists():
+ try:
+ idx = json.loads(index_path.read_text())
+ weight_map = idx.get("weight_map", {})
+ all_keys = tuple(weight_map.keys())
+
+ # Check for oMLX prefix first (more specific)
+ omlx_keys = tuple(k for k in all_keys if "language_model.mtp." in str(k))
+ if omlx_keys:
+ return MtpProbeResult(
+ model_declares_mtp=True,
+ mtp_tensors_found=True,
+ mtp_format=MtpFormat.OMLX_EMBEDDED,
+ mtp_count=len(omlx_keys),
+ mtp_path="embedded in main shards (language_model.mtp.* prefix)",
+ mtp_tensor_keys=omlx_keys,
+ )
+
+ # Check for original prefix
+ orig_keys = tuple(k for k in all_keys if str(k).startswith("mtp."))
+ if orig_keys:
+ return MtpProbeResult(
+ model_declares_mtp=True,
+ mtp_tensors_found=True,
+ mtp_format=MtpFormat.ORIGINAL_EMBEDDED,
+ mtp_count=len(orig_keys),
+ mtp_path="embedded in main shards (mtp.* prefix)",
+ mtp_tensor_keys=orig_keys,
+ )
+ except (json.JSONDecodeError, OSError):
+ pass
+
+ # Step 4: Fallback — check for mtp.safetensors without config declaration
+ mtp_file = model_dir / "mtp.safetensors"
+ if mtp_file.exists():
+ keys = _safetensors_keys(mtp_file)
+ mtp_keys = tuple(k for k in keys if "mtp." in k)
+ if mtp_keys:
+ return MtpProbeResult(
+ model_declares_mtp=True,
+ mtp_tensors_found=True,
+ mtp_format=MtpFormat.MTPLX_SEPARATE_FILE,
+ mtp_count=len(mtp_keys),
+ mtp_path=str(mtp_file),
+ mtp_tensor_keys=mtp_keys,
+ )
+
+ # Step 5: Model declares MTP but no weights found — stripped during quantization
+ logger.warning(
+ f"Model at {model_dir} declares mtp_num_hidden_layers > 0 but no MTP "
+ "weights were found on disk. MTP weights were likely stripped during "
+ "quantization (e.g. mlx-community format). This model will produce "
+ "gibberish because the norm weight shift depends on MTP detection."
+ )
+ return MtpProbeResult(
+ model_declares_mtp=True,
+ mtp_tensors_found=False,
+ mtp_format=MtpFormat.STRIPPED,
+ mtp_count=0,
+ mtp_path=None,
+ mtp_tensor_keys=(),
+ )
+
+
+def _safetensors_keys(path: Path) -> tuple[str, ...]:
+ """Return tensor keys from a safetensors file without loading data."""
+ try:
+ from safetensors import safe_open
+
+ with safe_open(str(path), framework="numpy") as f:
+ return tuple(f.keys())
+ except Exception:
+ return ()
+
+
+def load_mtp_weights(model_path: Path | str) -> dict[str, Any] | None:
+ """Load MTP weights from a separate file (MTPLX format).
+
+ Only works for MTPLX_SEPARATE_FILE format. For embedded formats
+ (Original, oMLX), the MTP weights are already in the main shards
+ and will be loaded by ``mlx_lm``'s normal weight loading.
+
+ Args:
+ model_path: Path to the model directory.
+
+ Returns:
+ Dict of MTP tensor name → array, or ``None`` if not in separate file.
+ """
+ result = probe_mtp_weights(model_path)
+ if result.mtp_format != MtpFormat.MTPLX_SEPARATE_FILE or result.mtp_path is None:
+ return None
+
+ try:
+ from safetensors import safe_open
+
+ mtp_weights: dict[str, Any] = {}
+ with safe_open(result.mtp_path, framework="mlx") as f:
+ for key in cast(Iterable[str], f.keys()):
+ if "mtp." in key:
+ mtp_weights[key] = f.get_tensor(key)
+ return mtp_weights if mtp_weights else None
+ except Exception as e:
+ logger.warning(f"Failed to load MTP weights from {result.mtp_path}: {e}")
+ return None
diff --git a/src/exo/worker/engines/mlx/native_mtp_config.py b/src/exo/worker/engines/mlx/native_mtp_config.py
new file mode 100644
index 0000000000..645d1e88fc
--- /dev/null
+++ b/src/exo/worker/engines/mlx/native_mtp_config.py
@@ -0,0 +1,41 @@
+import os
+
+EXO_NATIVE_MTP_ENABLED_ENV = "EXO_NATIVE_MTP_ENABLED"
+
+
+def native_mtp_enabled_from_env() -> bool:
+ """Return whether native-MTP dispatch is enabled for supported cards."""
+ raw = os.environ.get(EXO_NATIVE_MTP_ENABLED_ENV)
+ if raw is None:
+ return True
+ return raw.strip().lower() not in {"0", "false", "no", "off"}
+
+
+def clamp_native_mtp_num_draft_tokens(
+ num_draft_tokens: int, *, max_k: int
+) -> tuple[int, bool]:
+ """Clamp native-MTP K to the card-declared runtime budget."""
+ bounded_max = max(1, max_k)
+ if num_draft_tokens < 1:
+ return 1, True
+ if num_draft_tokens > bounded_max:
+ return bounded_max, True
+ return num_draft_tokens, False
+
+
+def resolve_native_mtp_num_draft_tokens(
+ *,
+ request_num_draft_tokens: int | None,
+ configured_num_draft_tokens: int | None,
+ card_default_k: int,
+ card_max_k: int,
+) -> tuple[int, bool]:
+ """Resolve native-MTP K with request > startup config > card default."""
+ chosen = (
+ request_num_draft_tokens
+ if request_num_draft_tokens is not None
+ else configured_num_draft_tokens
+ )
+ if chosen is None:
+ chosen = card_default_k
+ return clamp_native_mtp_num_draft_tokens(chosen, max_k=card_max_k)
diff --git a/src/exo/worker/engines/mlx/spec_cache.py b/src/exo/worker/engines/mlx/spec_cache.py
new file mode 100644
index 0000000000..451537ba99
--- /dev/null
+++ b/src/exo/worker/engines/mlx/spec_cache.py
@@ -0,0 +1,178 @@
+"""Snapshot/rollback helpers for speculative-decode cache state.
+
+MLX's ``trim_prompt_cache`` only rolls back attention KV offsets, but the
+Qwen3.5/3.6 GatedDeltaNet ``ArraysCache`` recurrent state cannot be
+reconstructed from the offset alone. The native-MTP draft+verify loop
+therefore clones that recurrent state before each batched verify forward
+and restores it when a draft is rejected. Without this, a batched verify
+followed by an offset-only trim produces materially wrong logits.
+
+Adapted from MTPLX's ``snapshot_untrimmable_cache`` /
+``rollback_after_verify`` (Apache 2.0).
+"""
+
+# pyright: reportUnknownMemberType=false, reportUnknownArgumentType=false, reportUnknownVariableType=false, reportAttributeAccessIssue=false, reportAny=false
+
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import Any
+
+import mlx.core as mx
+
+
+@dataclass(frozen=True)
+class CacheSnapshot:
+ """Frozen pair of (state, meta_state) tuples — one slot per cache entry.
+
+ Trimmable entries (KV caches) carry `None` in both tuples; non-trimmable
+ entries (GatedDeltaNet `ArraysCache`) carry a deep-cloned snapshot.
+ """
+
+ states: tuple[Any, ...]
+ meta_states: tuple[Any, ...]
+
+
+def _is_trimmable(entry: Any) -> bool:
+ """Mirror MTPLX's check: call `entry.is_trimmable()`, treat raise/missing as
+ non-trimmable. mlx-lm 0.31.3's `_BaseCache.is_trimmable` returns False;
+ KV-bearing subclasses override to True. `ArraysCache` (GDN recurrent
+ state) inherits the base False.
+ """
+ method = getattr(entry, "is_trimmable", None)
+ if not callable(method):
+ return False
+ try:
+ return bool(method())
+ except Exception:
+ return False
+
+
+def _clone_tree(value: Any) -> Any:
+ """Recursive deep clone of containers + mx.array leaves.
+
+ For mx.array leaves we materialize a fresh contiguous expression and
+ force evaluation so subsequent in-place writes by the cache cannot
+ mutate the snapshot through aliasing. Cost: one `mx.eval` per leaf.
+ """
+ if value is None:
+ return None
+ if isinstance(value, mx.array):
+ leaf = mx.contiguous(value)
+ mx.eval(leaf)
+ return leaf
+ if isinstance(value, tuple):
+ return tuple(_clone_tree(v) for v in value)
+ if isinstance(value, list):
+ return [_clone_tree(v) for v in value]
+ if isinstance(value, dict):
+ return {k: _clone_tree(v) for k, v in value.items()}
+ return value
+
+
+def _clone_tree_lazy(value: Any) -> Any:
+ """Recursive clone expression without forcing leaf evaluation.
+
+ MLX arrays are immutable values, and cache updates replace cache leaves
+ rather than mutating the old array storage in place. For speculative
+ snapshots this lets accept rounds discard the snapshot without paying the
+ full synchronization cost.
+ """
+ if value is None:
+ return None
+ if isinstance(value, mx.array):
+ return mx.contiguous(value)
+ if isinstance(value, tuple):
+ return tuple(_clone_tree_lazy(v) for v in value)
+ if isinstance(value, list):
+ return [_clone_tree_lazy(v) for v in value]
+ if isinstance(value, dict):
+ return {k: _clone_tree_lazy(v) for k, v in value.items()}
+ return value
+
+
+def snapshot_untrimmable_cache_lazy(cache: list[Any]) -> CacheSnapshot:
+ """Clone state + meta_state of every non-trimmable cache entry.
+
+ Trimmable entries get ``None`` (their offset trim is sufficient
+ rollback). Call this BEFORE the batched verify forward each spec-decode
+ round. The clone is lazy: MLX arrays are immutable and cache updates
+ replace leaves rather than mutating storage in place, so an accept round
+ can discard the snapshot without paying the full synchronization cost.
+ Measured logit/GDN-exact for K=1 accept and reject transactions.
+ """
+ states: list[Any] = []
+ meta_states: list[Any] = []
+ for entry in cache:
+ if _is_trimmable(entry):
+ states.append(None)
+ meta_states.append(None)
+ else:
+ states.append(_clone_tree_lazy(getattr(entry, "state", None)))
+ meta_states.append(_clone_tree_lazy(getattr(entry, "meta_state", None)))
+ return CacheSnapshot(states=tuple(states), meta_states=tuple(meta_states))
+
+
+def _restore_state_preserving_container(entry: Any, state: Any) -> None:
+ """Write `state` back into `entry` without breaking container identity.
+
+ `ArraysCache.state` is the same `list` the cache reads from internally.
+ Mutating it in place preserves identity; falling back to the property
+ setter is fine for entries that own one.
+ """
+ cloned = _clone_tree(state)
+ if hasattr(entry, "replace_state"):
+ entry.replace_state(cloned)
+ return
+ current = getattr(entry, "state", None)
+ if (
+ isinstance(current, list)
+ and isinstance(cloned, list)
+ and len(current) == len(cloned)
+ ):
+ current[:] = cloned
+ return
+ entry.state = cloned
+
+
+def restore_cache(cache: list[Any], snapshot: CacheSnapshot) -> None:
+ """Restore non-trimmable cache state from a snapshot. Trimmable entries
+ untouched. Use directly when you've already done your own trim, or call
+ `rollback_after_verify` to combine trim + restore.
+ """
+ for entry, state, meta_state in zip(
+ cache, snapshot.states, snapshot.meta_states, strict=True
+ ):
+ if state is not None:
+ _restore_state_preserving_container(entry, state)
+ if meta_state is not None:
+ entry.meta_state = _clone_tree(meta_state)
+
+
+def rollback_after_verify(
+ cache: list[Any],
+ snapshot: CacheSnapshot,
+ verified_tokens: int,
+) -> None:
+ """Undo a speculative target verify pass.
+
+ `verified_tokens` is the count of tokens to TRIM from trimmable entries
+ (matches MTPLX's API and mlx-lm's `trim(n)` "remove n from end"
+ convention). For a K=1 verify of `[primary, draft]` advancing the
+ cache by 2:
+ accept: rollback_after_verify(cache, snap, verified_tokens=0)
+ — keep the full 2-step advance, just rewrite GDN state.
+ But the GDN state at pos+2 IS what we want post-accept, so
+ callers typically do NOT call rollback on accept; they re-run
+ the snapshot at the top of the next round.
+ reject: rollback_after_verify(cache, snap, verified_tokens=2)
+ — trim KV by 2, restore GDN state, then re-forward `[primary]`
+ alone to advance back to pos+1.
+
+ Mirrors `mtplx/cache_state.py:rollback_after_verify`.
+ """
+ if verified_tokens > 0:
+ for entry in cache:
+ if _is_trimmable(entry) and hasattr(entry, "trim"):
+ entry.trim(verified_tokens)
+ restore_cache(cache, snapshot)
diff --git a/src/exo/worker/engines/mlx/utils_mlx.py b/src/exo/worker/engines/mlx/utils_mlx.py
index 730abf64e3..6b458d0e05 100644
--- a/src/exo/worker/engines/mlx/utils_mlx.py
+++ b/src/exo/worker/engines/mlx/utils_mlx.py
@@ -27,7 +27,7 @@
from mlx_lm.models.deepseek_v3 import DeepseekV3Model
from mlx_lm.tokenizer_utils import TokenizerWrapper
-from exo.shared.models.model_cards import ModelId
+from exo.shared.models.model_cards import ModelCard, ModelId
from exo.worker.engines.mlx.constants import TRUST_REMOTE_CODE
try:
@@ -160,6 +160,45 @@ def initialize_mlx(
return mlx_distributed_init(bound_instance)
+def _native_mtp_loader_eligible(target_card: ModelCard, model_path: Path) -> bool:
+ """Decide whether ``load_mlx_items`` should dispatch the native MTP loader.
+
+ Three conditions, all required:
+
+ 1. The card declares ``native_mtp`` (the operator opted in).
+ 2. The on-disk checkpoint actually exposes recoverable MTP weights
+ (probe at :mod:`exo.worker.engines.mlx.mtp_probe` returns
+ ``is_recoverable=True``). MTPLX sidecar layout, original-HF
+ embedded layout, and oMLX embedded layout all qualify.
+ 3. (Implicit at the call site) The placement is single-node --
+ ``load_mlx_items`` only calls this from the ``group is None``
+ branch. Native MTP is structurally single-node-only: the verify
+ forward through a TP-sharded target would amortise K+1 tokens
+ over K+1 compute units, eating the MTP speedup.
+
+ Returns ``False`` -- and the caller falls back to the stock loader --
+ when any condition fails. The fallback is conservative: operators who
+ declared ``native_mtp`` on a card whose weights happen to be
+ unrecoverable get a quiet downgrade instead of a hard load failure.
+ """
+ if target_card.native_mtp is None:
+ return False
+ # Local import to avoid pulling the probe (and safetensors) into the
+ # module-import path on workers that never load Qwen3.5/3.6.
+ from exo.worker.engines.mlx.mtp_probe import probe_mtp_weights
+
+ probe = probe_mtp_weights(model_path)
+ if not probe.is_recoverable:
+ logger.warning(
+ f"Card {target_card.model_id} declares native_mtp but the "
+ f"on-disk checkpoint at {model_path} has no recoverable MTP "
+ f"weights (probe verdict: format={probe.mtp_format}, "
+ f"count={probe.mtp_count}). Falling back to stock loader."
+ )
+ return False
+ return True
+
+
def load_mlx_items(
bound_instance: BoundInstance,
group: mx.distributed.Group | None,
@@ -170,9 +209,27 @@ def load_mlx_items(
if group is None:
logger.info(f"Single device used for {bound_instance.instance}")
- model_path = build_model_path(bound_instance.bound_shard.model_card.model_id)
+ target_card = bound_instance.bound_shard.model_card
+ model_path = build_model_path(target_card.model_id)
start_time = time.perf_counter()
- model, _ = load_model(model_path, lazy=True, strict=False)
+ # Native MTP path: only when (a) the card declares it AND (b) the
+ # placement is single-node (this branch) AND (c) the on-disk
+ # checkpoint actually exposes recoverable MTP weights. The vendored
+ # loader builds an MTP-aware model that the generator dispatches a
+ # native draft+verify loop against. Otherwise fall back to the
+ # stock load path unchanged.
+ if _native_mtp_loader_eligible(target_card, model_path):
+ from exo.worker.engines.mlx.vendor.qwen3_5_mtp_loader import (
+ load_mtp_model,
+ )
+
+ logger.info(
+ f"Loading {target_card.model_id} via native MTP loader "
+ f"(native_mtp=True, world_size=1, MTP weights recoverable)"
+ )
+ model, _ = load_mtp_model(model_path, lazy=True, strict=True)
+ else:
+ model, _ = load_model(model_path, lazy=True, strict=False)
# Eval layers one by one for progress reporting
try:
inner = get_inner_model(model)
diff --git a/src/exo/worker/engines/mlx/vendor/qwen3_5_mtp.py b/src/exo/worker/engines/mlx/vendor/qwen3_5_mtp.py
new file mode 100644
index 0000000000..c0bade044a
--- /dev/null
+++ b/src/exo/worker/engines/mlx/vendor/qwen3_5_mtp.py
@@ -0,0 +1,753 @@
+"""Native MTP (Multi-Token Prediction) sidecar model class for Qwen3.6/Qwen3.5.
+
+This module provides a self-contained ``Model`` class that composes over
+stock ``mlx_lm.models.qwen3_5`` and ``mlx_lm.models.qwen3_next`` with
+first-class support for Qwen3.6's MTP head. It is intended to graduate
+upstream as a new mlx-lm model class (in the same shape as the
+``gemma4_assistant`` mlx-lm PR #1276), but lives here in exo's vendor
+directory first.
+
+The design corrects seven concrete bugs in the prior upstream attempt
+(``ml-explore/mlx-lm`` PR #1226 by chuaaron) that caused the original
+work to be self-abandoned:
+
+1. **Strict weight loading.** Loader keeps ``strict=True`` for all
+ weights -- main and MTP -- so missing MTP shards raise instead of
+ silently initialising to random.
+2. **MTP sidecar file.** Qwen3.6 ships MTP weights in a separate
+ ``mtp.safetensors`` file (not embedded in the main shards). The
+ loader probes that location explicitly.
+3. **Per-weight norm-shift gate.** Stock ``Qwen3_5TextModel.sanitize``
+ shifts RMSNorm weights by +1 unconditionally when it sees MTP keys.
+ This double-shifts already-shifted norms. Our sanitize gates the
+ shift on ``value.mean() < 0.5`` per weight, matching MTPLX's
+ ``_finalize_mtp_weights``.
+4. **Post-norm hidden variant.** MTP was trained to ingest POST-norm
+ hidden states (i.e. ``self.model.norm(hidden)``), not the pre-norm
+ variant PR #1226 fed in. The ``pre_fc_norm_hidden`` RMSNorm inside
+ the MTP module assumes post-norm input.
+5. **MTP-specific quantization policies.** Cyankiwi prequantized
+ shards quantize only the attention/MLP linears inside MTP; ``fc``
+ and the norms stay in BF16. We detect the policy from the key set
+ and apply it. "all" prequantized shards also quantize ``fc``.
+6. **MTP KV cache priming.** ``make_mtp_cache`` returns the empty
+ cache shape; ``mtp_update_cache`` walks prefill tokens through the
+ MTP layer with logits suppressed, populating K/V so the draft loop
+ inherits prefill context.
+7. **Tests check numerical correctness.** The companion test module
+ includes both synthetic structural tests (no real download) and an
+ opt-in parity test against the real MTPLX artifact that asserts
+ top-1 agreement >= 60% against the main lm_head -- this catches
+ silent random-init regressions that shape-only tests would miss.
+
+The MTPLX runtime injection in ``mtplx/mtp_patch.py`` is the reference
+behaviour we mirror; that file is Apache 2.0 and we re-implement here
+in idiomatic mlx-lm class form rather than monkey-patching at runtime.
+"""
+
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import mlx.core as mx
+import mlx.nn as nn
+from mlx.nn.layers.distributed import shard_inplace, shard_linear
+from mlx.utils import tree_map
+from mlx_lm.models.base import BaseModelArgs, create_attention_mask, create_ssm_mask
+from mlx_lm.models.cache import ArraysCache, KVCache
+from mlx_lm.models.qwen3_5 import DecoderLayer as StockDecoderLayer
+from mlx_lm.models.qwen3_5 import TextModelArgs as StockTextModelArgs
+from mlx_lm.models.qwen3_next import Qwen3NextMLP as MLP # noqa: N814
+
+# pyright: reportAny=false, reportUnknownMemberType=false, reportUnknownArgumentType=false, reportUnknownVariableType=false, reportUnknownParameterType=false, reportMissingParameterType=false, reportPrivateUsage=false, reportIncompatibleMethodOverride=false, reportArgumentType=false, reportOptionalMemberAccess=false
+# This module composes over untyped mlx-lm and mlx.nn modules; their
+# attribute surface is dynamic and the type-checker can't see through
+# nn.Module subclassing in the way mlx-lm uses it.
+
+
+# ---------------------------------------------------------------------------
+# Exceptions
+# ---------------------------------------------------------------------------
+
+
+class MTPWeightsNotFound(RuntimeError): # noqa: N818 - public API name
+ """Raised when ``mtp_num_hidden_layers > 0`` but no MTP weights are loadable.
+
+ Carries the candidate filenames probed so the operator can diagnose
+ whether the sidecar wasn't downloaded, mounted, or named as expected.
+ """
+
+ def __init__(self, message: str, *, candidates: tuple[str, ...] = ()) -> None:
+ super().__init__(message)
+ self.candidates = candidates
+
+
+# ---------------------------------------------------------------------------
+# Model args
+# ---------------------------------------------------------------------------
+
+
+_RMSNORM_SUFFIXES: tuple[str, ...] = (
+ "input_layernorm.weight",
+ "post_attention_layernorm.weight",
+ "q_norm.weight",
+ "k_norm.weight",
+ "pre_fc_norm_hidden.weight",
+ "pre_fc_norm_embedding.weight",
+ "norm.weight",
+)
+
+
+@dataclass
+class TextModelArgs(StockTextModelArgs):
+ """Stock ``Qwen3_5`` text args extended with MTP fields.
+
+ Defaults match the cyankiwi-prequantized MTPLX contract:
+ - ``mtp_num_hidden_layers``: 0 (no MTP unless config declares one)
+ - ``mtp_hidden_variant``: 'post_norm' (THE critical PR #1226 fix)
+ - ``mtp_concat_order``: 'embedding_hidden' (embedding first)
+ """
+
+ mtp_num_hidden_layers: int = 0
+ mtp_hidden_variant: str = "post_norm"
+ mtp_concat_order: str = "embedding_hidden"
+
+
+@dataclass
+class ModelArgs(BaseModelArgs):
+ model_type: str
+ text_config: Dict[str, Any]
+
+ @classmethod
+ def from_dict(cls, params: Dict[str, Any]) -> "ModelArgs":
+ if "text_config" not in params:
+ return cls(model_type=params["model_type"], text_config=params)
+ return super().from_dict(params)
+
+
+# ---------------------------------------------------------------------------
+# MTP module
+# ---------------------------------------------------------------------------
+
+
+class MTPModule(nn.Module):
+ """The MTP head: ``[pre_fc_norm_hidden, pre_fc_norm_embedding, fc, layers, norm]``.
+
+ ``layers`` is a list of stock ``qwen3_5.DecoderLayer`` instances
+ constructed with ``layer_idx = full_attention_interval - 1`` so each
+ layer takes the full-attention (not linear-attention) branch
+ (``is_linear = False``). This matches MTPLX, which uses the
+ full-attention layer for MTP regardless of the main-model layer
+ cadence.
+ """
+
+ def __init__(self, args: TextModelArgs, n_layers: int) -> None:
+ super().__init__()
+ fa_layer_idx = args.full_attention_interval - 1
+ self.pre_fc_norm_hidden = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
+ self.pre_fc_norm_embedding = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
+ self.fc = nn.Linear(args.hidden_size * 2, args.hidden_size, bias=False)
+ self.layers = [
+ StockDecoderLayer(args, layer_idx=fa_layer_idx) for _ in range(n_layers)
+ ]
+ self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
+
+
+# ---------------------------------------------------------------------------
+# Inner text model with MTP attached
+# ---------------------------------------------------------------------------
+
+
+class Qwen3_5MTPInner(nn.Module): # noqa: N801 - mirrors mlx-lm's Qwen3_5TextModel naming
+ """Mirrors stock ``Qwen3_5TextModel`` plus an attached ``self.mtp``.
+
+ ``__call__`` accepts ``return_hidden``, ``emit_logits`` and
+ ``logits_keep`` kwargs for cooperative use with the MTP draft loop.
+ """
+
+ def __init__(self, args: TextModelArgs) -> None:
+ super().__init__()
+ self.args = args
+ self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
+ self.layers = [
+ StockDecoderLayer(args=args, layer_idx=i)
+ for i in range(args.num_hidden_layers)
+ ]
+ self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
+ self.ssm_idx = 0
+ self.fa_idx = args.full_attention_interval - 1
+ self.mtp: Optional[MTPModule] = None
+ if args.mtp_num_hidden_layers > 0:
+ self.mtp = MTPModule(args, args.mtp_num_hidden_layers)
+
+ def __call__(
+ self,
+ inputs: mx.array,
+ cache: Optional[List[Any]] = None,
+ input_embeddings: Optional[mx.array] = None,
+ return_hidden: bool = False,
+ ) -> Union[mx.array, tuple[mx.array, mx.array]]:
+ hidden_states = (
+ input_embeddings
+ if input_embeddings is not None
+ else self.embed_tokens(inputs)
+ )
+ if cache is None:
+ cache = [None] * len(self.layers)
+ fa_mask = create_attention_mask(hidden_states, cache[self.fa_idx])
+ ssm_mask = create_ssm_mask(hidden_states, cache[self.ssm_idx])
+ for layer, layer_cache in zip(self.layers, cache, strict=True):
+ mask = ssm_mask if layer.is_linear else fa_mask
+ hidden_states = layer(hidden_states, mask=mask, cache=layer_cache)
+ post = self.norm(hidden_states)
+ if return_hidden:
+ return post, hidden_states
+ return post
+
+
+# ---------------------------------------------------------------------------
+# TextModel wrapper (lm_head + cache builders + sanitize)
+# ---------------------------------------------------------------------------
+
+
+class TextModel(nn.Module):
+ """Wraps ``Qwen3_5MTPInner`` + ``lm_head``; mirrors stock ``TextModel`` API."""
+
+ def __init__(self, args: TextModelArgs) -> None:
+ super().__init__()
+ self.args = args
+ self.model_type = args.model_type
+ self.model = Qwen3_5MTPInner(args)
+ if not args.tie_word_embeddings:
+ self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
+ # Storage for an optional sidecar weight loader supplied by the
+ # outer loader (qwen3_5_mtp_loader). Sanitize uses it to pick up
+ # the separate mtp.safetensors file when no MTP keys appear in
+ # the main shards. Type: Callable[[], dict[str, mx.array]] | None.
+ self._mtp_sidecar_loader: Optional[Callable[[], Dict[str, mx.array]]] = None
+
+ def __call__(
+ self,
+ inputs: mx.array,
+ cache: Optional[List[Any]] = None,
+ input_embeddings: Optional[mx.array] = None,
+ return_hidden: bool = False,
+ ) -> Union[mx.array, tuple[mx.array, mx.array]]:
+ result = self.model(
+ inputs,
+ cache,
+ input_embeddings=input_embeddings,
+ return_hidden=return_hidden,
+ )
+ if return_hidden:
+ assert isinstance(result, tuple)
+ post, _pre = result
+ else:
+ assert isinstance(result, mx.array)
+ post = result
+ if self.args.tie_word_embeddings:
+ logits = self.model.embed_tokens.as_linear(post)
+ else:
+ logits = self.lm_head(post)
+ if return_hidden:
+ return logits, post
+ return logits
+
+ @property
+ def layers(self) -> List[StockDecoderLayer]:
+ return self.model.layers
+
+ def make_cache(self) -> List[Any]:
+ return [
+ ArraysCache(size=2) if layer.is_linear else KVCache()
+ for layer in self.layers
+ ]
+
+ def make_mtp_cache(self) -> List[Any]:
+ if self.model.mtp is None:
+ return []
+ return [KVCache() for _ in self.model.mtp.layers]
+
+ # -------------------- sanitize --------------------
+
+ def sanitize(self, weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
+ """Sanitize stock+MTP weights.
+
+ The incoming weight dict is keyed as it appears in safetensors:
+ for Qwen3.6-27B that's the ``language_model.model.*`` /
+ ``language_model.lm_head.*`` form. (The outer
+ :meth:`Model.sanitize` already normalized any other-prefix
+ variants into this shape.)
+
+ Behavior:
+ 1. The MTP submodule's parameters live at
+ ``language_model.model.mtp.*``. Any incoming key in the
+ ``language_model.mtp.*`` or bare ``mtp.*`` (or
+ ``language_model.model.mtp.*``) namespaces is normalized
+ into the canonical model-tree key.
+ 2. If ``mtp_num_hidden_layers > 0`` and no MTP keys are
+ present, the sidecar loader (set by the outer loader) is
+ invoked. If that still produces no MTP keys,
+ :class:`MTPWeightsNotFound` is raised.
+ 3. Per-weight norm-shift gate: ``v + 1.0`` only when ``v.ndim
+ == 1`` and ``v.mean() < 0.5`` and the suffix matches a
+ RMSNorm weight. This prevents double-shifting weights that
+ are already in the post-shift form.
+ 4. conv1d axis fix (matches stock).
+ 5. ``lm_head.weight`` is dropped if ``tie_word_embeddings``.
+ """
+ # ----- 1. namespace normalization -----
+ normalized: Dict[str, mx.array] = {}
+ canonical_mtp_prefix = "language_model.model.mtp."
+ for key, value in weights.items():
+ if key.startswith("language_model.model.mtp."):
+ normalized[key] = value
+ elif key.startswith("language_model.mtp."):
+ normalized[canonical_mtp_prefix + key[len("language_model.mtp.") :]] = (
+ value
+ )
+ elif key.startswith("mtp."):
+ normalized[canonical_mtp_prefix + key[len("mtp.") :]] = value
+ else:
+ normalized[key] = value
+
+ embedded_mtp_keys = [
+ k for k in normalized if k.startswith(canonical_mtp_prefix)
+ ]
+ wants_mtp = self.args.mtp_num_hidden_layers > 0
+
+ # ----- 2. sidecar fallback -----
+ if wants_mtp and not embedded_mtp_keys:
+ loader = self._mtp_sidecar_loader
+ if loader is not None:
+ extra = loader()
+ if extra:
+ for k, v in extra.items():
+ # Sidecar keys come in bare ``mtp.*`` form.
+ if k.startswith("mtp."):
+ normalized[canonical_mtp_prefix + k[len("mtp.") :]] = v
+ elif k.startswith(canonical_mtp_prefix):
+ normalized[k] = v
+ else:
+ # Unexpected; let load_weights complain.
+ normalized[k] = v
+ embedded_mtp_keys = [
+ k for k in normalized if k.startswith(canonical_mtp_prefix)
+ ]
+
+ if wants_mtp and not embedded_mtp_keys:
+ raise MTPWeightsNotFound(
+ "MTP declared in config (mtp_num_hidden_layers="
+ f"{self.args.mtp_num_hidden_layers}) but no MTP weights were "
+ "found in the main shards and no sidecar loader produced any. "
+ "Candidate sidecar files (relative to model dir): "
+ "mtp.safetensors, mtp/weights.safetensors, model-mtp.safetensors, "
+ "or config.mlx_lm_extra_tensors.mtp_file",
+ candidates=(
+ "mtp.safetensors",
+ "mtp/weights.safetensors",
+ "model-mtp.safetensors",
+ ),
+ )
+
+ # ----- 3-5. tie/conv/norm fixes -----
+ if self.args.tie_word_embeddings:
+ normalized.pop("language_model.lm_head.weight", None)
+ normalized.pop("lm_head.weight", None)
+
+ for k, v in list(normalized.items()):
+ if "conv1d.weight" in k and v.shape[-1] != 1:
+ normalized[k] = v.moveaxis(2, 1)
+ if v.ndim == 1 and any(k.endswith(sfx) for sfx in _RMSNORM_SUFFIXES):
+ mean = float(v.astype(mx.float32).mean().item())
+ if mean < 0.5:
+ normalized[k] = v + 1.0
+
+ return normalized
+
+ # -------------------- quant predicates --------------------
+
+ @property
+ def quant_predicate(self) -> Optional[Callable[[str, nn.Module], Any]]:
+ """Main-model quant predicate (delegates to MoE-style if used).
+
+ For Qwen3.6-27B (no MoE), this returns None and per-layer quant
+ comes from the ``quantization`` dict in config.json.
+ """
+ if self.args.num_experts <= 0:
+ return None
+
+ def predicate(path: str, _: nn.Module) -> Any:
+ if path.endswith(("mlp.gate", "shared_expert_gate")):
+ return {"group_size": 64, "bits": 8}
+ return True
+
+ return predicate
+
+ @property
+ def cast_predicate(self) -> Callable[[str], bool]:
+ def predicate(path: str) -> bool:
+ return not path.endswith("A_log")
+
+ return predicate
+
+
+# ---------------------------------------------------------------------------
+# Outer Model
+# ---------------------------------------------------------------------------
+
+
+def _classify_mtp_key_set(keys: tuple[str, ...]) -> str:
+ """Return one of ``'unquantized'``, ``'cyankiwi'``, ``'all'``.
+
+ The classification is exactly mirrored from MTPLX constants.py:
+ - 'cyankiwi': attention + mlp quantized; ``fc`` + norms BF16
+ - 'all': everything quantized including ``fc``
+ - 'unquantized': bf16 throughout (only ``.weight`` keys, no
+ ``.scales`` / ``.biases``)
+ """
+ has_fc_scales = any(k == "mtp.fc.scales" for k in keys)
+ has_attn_scales = any(k.endswith(".self_attn.q_proj.scales") for k in keys)
+ if has_fc_scales:
+ return "all"
+ if has_attn_scales:
+ return "cyankiwi"
+ return "unquantized"
+
+
+def _quantize_mtp_module(
+ mtp: MTPModule,
+ *,
+ policy: str,
+ bits: int,
+ group_size: int,
+ mode: str = "affine",
+ quant_overrides: Optional[Dict[str, Dict[str, Any]]] = None,
+) -> None:
+ """Apply MTP-specific quantization according to ``policy``.
+
+ - ``'all'``: quantize every quantizable module in the MTP subtree.
+ - ``'cyankiwi'``: quantize attention/MLP linears only; ``fc``,
+ ``pre_fc_norm_*``, ``norm`` stay unquantized.
+ - ``'unquantized'``: no-op.
+ """
+ if policy == "unquantized":
+ return
+ if policy == "all":
+ nn.quantize(mtp, group_size=group_size, bits=bits, mode=mode)
+ return
+ if policy != "cyankiwi":
+ raise ValueError(f"Unsupported MTP quantization policy: {policy!r}")
+
+ def predicate(path: str, module: nn.Module) -> Any:
+ if path == "fc" or path.startswith("pre_fc_norm") or path == "norm":
+ return False
+ if path.endswith("mlp.gate"):
+ return False
+ if path.startswith("layers.") and hasattr(module, "to_quantized"):
+ override = (quant_overrides or {}).get(f"language_model.mtp.{path}")
+ if override is not None:
+ return {
+ "group_size": int(override.get("group_size", group_size)),
+ "bits": int(override.get("bits", bits)),
+ "mode": str(override.get("mode", mode)),
+ }
+ return {"group_size": group_size, "bits": bits, "mode": mode}
+ return False
+
+ nn.quantize(mtp, class_predicate=predicate)
+
+
+class Model(nn.Module):
+ """Outer model class. Exposes ``language_model: TextModel`` + ``mtp_forward``."""
+
+ def __init__(self, args: ModelArgs) -> None:
+ super().__init__()
+ self.args = args
+ self.model_type = args.model_type
+ self.language_model = TextModel(TextModelArgs.from_dict(args.text_config))
+
+ # -------------------- forward --------------------
+
+ def __call__(
+ self,
+ inputs: mx.array,
+ cache: Optional[List[Any]] = None,
+ input_embeddings: Optional[mx.array] = None,
+ return_hidden: bool = False,
+ ) -> Union[mx.array, tuple[mx.array, mx.array]]:
+ return self.language_model(
+ inputs,
+ cache=cache,
+ input_embeddings=input_embeddings,
+ return_hidden=return_hidden,
+ )
+
+ # -------------------- MTP forward --------------------
+
+ def _mtp_core(
+ self,
+ hidden_states: mx.array,
+ next_token_ids: mx.array,
+ *,
+ mtp_cache: Optional[List[Any]],
+ concat_order: Optional[str],
+ mtp_hidden_variant: str,
+ emit_logits: bool,
+ ) -> tuple[Optional[mx.array], mx.array]:
+ text_model = self.language_model
+ inner = text_model.model
+ mtp = inner.mtp
+ if mtp is None:
+ raise RuntimeError("Model was constructed without an MTP head")
+
+ input_embeds = inner.embed_tokens(next_token_ids)
+ e = mtp.pre_fc_norm_embedding(input_embeds)
+ h = mtp.pre_fc_norm_hidden(hidden_states)
+ order = concat_order or text_model.args.mtp_concat_order
+ parts = [e, h] if order == "embedding_hidden" else [h, e]
+ x = mtp.fc(mx.concatenate(parts, axis=-1))
+ layer_cache = mtp_cache[0] if mtp_cache else None
+ mask = create_attention_mask(x, layer_cache)
+ x = mtp.layers[0](x, mask=mask, cache=layer_cache)
+ post_norm = mtp.norm(x)
+ if mtp_hidden_variant == "post_norm":
+ hidden = post_norm
+ elif mtp_hidden_variant == "pre_norm":
+ hidden = x
+ else:
+ raise ValueError(
+ f"Unsupported mtp_hidden_variant={mtp_hidden_variant!r}; "
+ "use 'post_norm' or 'pre_norm'"
+ )
+ if not emit_logits:
+ return None, hidden
+ if text_model.args.tie_word_embeddings:
+ logits = inner.embed_tokens.as_linear(post_norm)
+ else:
+ logits = text_model.lm_head(post_norm)
+ return logits, hidden
+
+ def mtp_forward(
+ self,
+ hidden_states: mx.array,
+ next_token_ids: mx.array,
+ *,
+ mtp_cache: Optional[List[Any]] = None,
+ concat_order: Optional[str] = None,
+ mtp_hidden_variant: str = "post_norm",
+ emit_logits: bool = True,
+ return_hidden: bool = False,
+ ) -> Union[mx.array, tuple[mx.array, mx.array]]:
+ logits, hidden = self._mtp_core(
+ hidden_states,
+ next_token_ids,
+ mtp_cache=mtp_cache,
+ concat_order=concat_order,
+ mtp_hidden_variant=mtp_hidden_variant,
+ emit_logits=emit_logits,
+ )
+ if not emit_logits:
+ # caller asked for cache-update only
+ return hidden
+ assert logits is not None
+ if return_hidden:
+ return logits, hidden
+ return logits
+
+ def mtp_update_cache(
+ self,
+ hidden_states: mx.array,
+ next_token_ids: mx.array,
+ *,
+ mtp_cache: Optional[List[Any]] = None,
+ concat_order: Optional[str] = None,
+ ) -> mx.array:
+ """Populate MTP KV cache without emitting logits.
+
+ Used during prefill to seed the MTP cache with K/V derived from
+ the prompt history, so subsequent draft calls see prefill
+ context. Returns the post-norm hidden the MTP head produced (for
+ chained ``mtp_update_cache`` calls if you want them).
+ """
+ _, hidden = self._mtp_core(
+ hidden_states,
+ next_token_ids,
+ mtp_cache=mtp_cache,
+ concat_order=concat_order,
+ mtp_hidden_variant="post_norm",
+ emit_logits=False,
+ )
+ return hidden
+
+ # -------------------- cache builders --------------------
+
+ def make_cache(self) -> List[Any]:
+ return self.language_model.make_cache()
+
+ def make_mtp_cache(self) -> List[Any]:
+ return self.language_model.make_mtp_cache()
+
+ @property
+ def layers(self) -> List[StockDecoderLayer]:
+ return self.language_model.model.layers
+
+ # -------------------- sanitize --------------------
+
+ def sanitize(self, weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
+ """Outer sanitize: normalize prefixes, then delegate to TextModel.
+
+ Stock ``qwen3_5.Model.sanitize`` returns weights with the
+ ``language_model.`` prefix intact so the outer ``load_weights``
+ can route them to ``self.language_model.*``. We keep that
+ convention. MTP keys are normalized to ``language_model.mtp.*``
+ before delegation so :meth:`TextModel.sanitize` -- which keys off
+ bare ``mtp.*`` -- can detect and process them.
+ """
+ sanitized: Dict[str, mx.array] = {}
+ for key, value in weights.items():
+ if key.startswith(("vision_tower", "model.visual")):
+ continue
+ if key.startswith("model.language_model"):
+ key = key.replace("model.language_model", "language_model.model")
+ elif key.startswith("language_model."):
+ pass
+ else:
+ key = "language_model." + key
+ sanitized[key] = value
+ # Hand off to TextModel.sanitize, which sees keys WITH the
+ # "language_model." prefix and returns them WITH the prefix
+ # preserved (except for MTP keys, which it normalizes).
+ return self.language_model.sanitize(sanitized)
+
+ # -------------------- quantize hook --------------------
+
+ @property
+ def quant_predicate(self) -> Optional[Callable[[str, nn.Module], Any]]:
+ return self.language_model.quant_predicate
+
+ @property
+ def cast_predicate(self) -> Callable[[str], bool]:
+ return self.language_model.cast_predicate
+
+ # -------------------- shard (DELIBERATELY replicates MTP) --------------------
+
+ def shard(self, group: Optional[Any] = None) -> None:
+ """Stock-style shard for main layers; MTP is REPLICATED, not sharded.
+
+ Replicating the MTP module across all nodes keeps the draft loop
+ independent of inter-node communication latency, which is the
+ whole point of speculative decoding.
+ """
+ group = group or mx.distributed.init()
+ N = group.size() # noqa: N806 - stock qwen3_5.Model.shard uses N
+ rank = group.rank()
+
+ def conv_sharding(
+ key_dim: int,
+ ) -> Callable[[str, mx.array], tuple[int, list[int]]]:
+ return lambda p, w: (0, [key_dim, 2 * key_dim])
+
+ def repeat_kv_layer_inplace(layer: nn.Module, h: int) -> None:
+ if h >= N:
+ return
+
+ def _repeat(p: mx.array) -> mx.array:
+ s = p.shape
+ p = p.reshape(h, s[0] // h, *s[1:])
+ p = mx.repeat(p, N // h, axis=0)
+ p = p.reshape(-1, *s[1:])
+ return p
+
+ layer.update(tree_map(_repeat, layer.parameters()))
+
+ for layer in self.layers:
+ if layer.is_linear:
+ kd = layer.linear_attn.key_dim
+ layer.linear_attn.sharding_group = group
+ shard_inplace(layer.linear_attn.conv1d, conv_sharding(kd), group=group)
+ layer.linear_attn.conv1d.groups //= N
+ shard_inplace(
+ layer.linear_attn.in_proj_qkv,
+ "all-to-sharded",
+ segments=[kd, 2 * kd],
+ group=group,
+ )
+ shard_inplace(
+ layer.linear_attn.in_proj_z, "all-to-sharded", group=group
+ )
+ shard_inplace(
+ layer.linear_attn.in_proj_b, "all-to-sharded", group=group
+ )
+ shard_inplace(
+ layer.linear_attn.in_proj_a, "all-to-sharded", group=group
+ )
+ layer.linear_attn.dt_bias = mx.contiguous(
+ mx.split(layer.linear_attn.dt_bias, N)[rank]
+ )
+ layer.linear_attn.A_log = mx.contiguous(
+ mx.split(layer.linear_attn.A_log, N)[rank]
+ )
+ shard_inplace(layer.linear_attn.out_proj, "sharded-to-all", group=group)
+ layer.linear_attn.num_k_heads //= N
+ layer.linear_attn.num_v_heads //= N
+ layer.linear_attn.key_dim //= N
+ layer.linear_attn.value_dim //= N
+ layer.linear_attn.conv_dim //= N
+ else:
+ layer.self_attn.o_proj = shard_linear(
+ layer.self_attn.o_proj, "sharded-to-all", group=group
+ )
+ layer.self_attn.q_proj = shard_linear(
+ layer.self_attn.q_proj, "all-to-sharded", group=group
+ )
+ repeat_kv_layer_inplace(
+ layer.self_attn.k_proj, layer.self_attn.num_key_value_heads
+ )
+ repeat_kv_layer_inplace(
+ layer.self_attn.v_proj, layer.self_attn.num_key_value_heads
+ )
+ layer.self_attn.k_proj = shard_linear(
+ layer.self_attn.k_proj, "all-to-sharded", group=group
+ )
+ layer.self_attn.v_proj = shard_linear(
+ layer.self_attn.v_proj, "all-to-sharded", group=group
+ )
+ layer.self_attn.num_attention_heads //= N
+ layer.self_attn.num_key_value_heads = max(
+ 1, layer.self_attn.num_key_value_heads // N
+ )
+
+ if isinstance(layer.mlp, MLP):
+ layer.mlp.gate_proj = shard_linear(
+ layer.mlp.gate_proj, "all-to-sharded", group=group
+ )
+ layer.mlp.down_proj = shard_linear(
+ layer.mlp.down_proj, "sharded-to-all", group=group
+ )
+ layer.mlp.up_proj = shard_linear(
+ layer.mlp.up_proj, "all-to-sharded", group=group
+ )
+
+ # MTP is intentionally NOT sharded -- it's replicated on every
+ # node so the draft loop is fully local. See class docstring.
+
+
+# ---------------------------------------------------------------------------
+# Public helpers (used by the loader and tests)
+# ---------------------------------------------------------------------------
+
+
+__all__ = [
+ "MTPWeightsNotFound",
+ "TextModelArgs",
+ "ModelArgs",
+ "MTPModule",
+ "Qwen3_5MTPInner",
+ "TextModel",
+ "Model",
+ "_classify_mtp_key_set",
+ "_quantize_mtp_module",
+ "_RMSNORM_SUFFIXES",
+]
diff --git a/src/exo/worker/engines/mlx/vendor/qwen3_5_mtp_loader.py b/src/exo/worker/engines/mlx/vendor/qwen3_5_mtp_loader.py
new file mode 100644
index 0000000000..9a7d7fbcc9
--- /dev/null
+++ b/src/exo/worker/engines/mlx/vendor/qwen3_5_mtp_loader.py
@@ -0,0 +1,303 @@
+"""Loader that swaps in the vendored Qwen3.5/3.6 MTP-aware model class.
+
+If the model declares MTP (``text_config.mtp_num_hidden_layers > 0``)
+and a usable weight source exists (either a separate ``mtp.safetensors``
+sidecar, or ``mtp.*`` keys in the main shards), the loader dispatches
+``model_type='qwen3_5'`` / ``'qwen3_5_moe'`` to
+:class:`vendor.qwen3_5_mtp.Model` and threads a sidecar-weights callable
+into the model so ``sanitize`` can pick it up.
+
+If MTP is not declared, or no weights are available, the loader falls
+through to stock ``mlx_lm.utils.load_model`` (which produces stock
+``mlx_lm.models.qwen3_5.Model``).
+
+Strict-load is the only mode supported -- random init of MTP weights
+would silently regress to the PR #1226 failure mode.
+"""
+
+from __future__ import annotations
+
+import json
+from pathlib import Path
+from typing import Any, Callable, Dict, Optional, Tuple, Type
+
+import mlx.core as mx
+import mlx.nn as nn
+from mlx_lm import utils as _mlx_lm_utils
+from mlx_lm.utils import load_model
+
+from .qwen3_5_mtp import (
+ Model as MtpModel,
+)
+from .qwen3_5_mtp import (
+ ModelArgs as MtpModelArgs,
+)
+from .qwen3_5_mtp import (
+ MTPWeightsNotFound,
+ _classify_mtp_key_set,
+ _quantize_mtp_module,
+)
+
+_get_classes: Callable[..., Tuple[Type[nn.Module], Type[Any]]] = (
+ _mlx_lm_utils._get_classes # pyright: ignore[reportAttributeAccessIssue]
+)
+_SUPPORTED_NATIVE_MTP_MODEL_TYPES = frozenset({"qwen3_5", "qwen3_5_moe"})
+
+
+# pyright: reportAny=false, reportUnknownMemberType=false, reportUnknownArgumentType=false, reportUnknownVariableType=false, reportUnknownParameterType=false, reportMissingParameterType=false, reportPrivateUsage=false, reportAttributeAccessIssue=false, reportUnnecessaryTypeIgnoreComment=false
+# Loader operates on untyped safetensors / JSON metadata and patches
+# mlx-lm classes whose surface is dynamic.
+
+
+# ---------------------------------------------------------------------------
+# Candidate sidecar locations
+# ---------------------------------------------------------------------------
+
+
+_DEFAULT_SIDECAR_CANDIDATES: tuple[str, ...] = (
+ "mtp.safetensors",
+ "mtp/weights.safetensors",
+ "model-mtp.safetensors",
+)
+
+
+def _config_sidecar_filename(config: Dict[str, Any]) -> Optional[str]:
+ extra = config.get("mlx_lm_extra_tensors") or {}
+ if isinstance(extra, dict):
+ name = extra.get("mtp_file")
+ if isinstance(name, str) and name:
+ return name
+ return None
+
+
+def _resolve_sidecar_path(model_path: Path, config: Dict[str, Any]) -> Optional[Path]:
+ """Return the first existing sidecar path or ``None``."""
+ configured = _config_sidecar_filename(config)
+ if configured is not None:
+ candidate = model_path / configured
+ if candidate.exists():
+ return candidate
+ for rel in _DEFAULT_SIDECAR_CANDIDATES:
+ candidate = model_path / rel
+ if candidate.exists():
+ return candidate
+ return None
+
+
+def _embedded_mtp_keys(model_path: Path) -> Tuple[str, ...]:
+ """Probe ``model.safetensors.index.json`` for embedded MTP keys.
+
+ Embedded layouts use either the ``mtp.*`` prefix (original HF
+ format) or ``language_model.mtp.*`` (oMLX). We treat both as
+ embedded and let the sanitizer normalize the namespace.
+ """
+ index = model_path / "model.safetensors.index.json"
+ if not index.exists():
+ return ()
+ try:
+ payload = json.loads(index.read_text(encoding="utf-8"))
+ except (json.JSONDecodeError, OSError):
+ return ()
+ weight_map = payload.get("weight_map")
+ if not isinstance(weight_map, dict):
+ return ()
+ return tuple(
+ sorted(
+ str(k)
+ for k in weight_map
+ if str(k).startswith("mtp.") or str(k).startswith("language_model.mtp.")
+ )
+ )
+
+
+def _sidecar_keys(sidecar: Path) -> Tuple[str, ...]:
+ try:
+ from safetensors import safe_open # type: ignore[import-not-found]
+ except ImportError:
+ return ()
+ try:
+ with safe_open(str(sidecar), framework="numpy") as handle: # type: ignore[no-untyped-call]
+ # safe_open is not iterable; .keys() is the supported API.
+ return tuple(sorted(str(k) for k in handle.keys())) # noqa: SIM118
+ except Exception:
+ return ()
+
+
+# ---------------------------------------------------------------------------
+# Weight finalization
+# ---------------------------------------------------------------------------
+
+
+_RMSNORM_SUFFIXES: tuple[str, ...] = (
+ "input_layernorm.weight",
+ "post_attention_layernorm.weight",
+ "q_norm.weight",
+ "k_norm.weight",
+ "pre_fc_norm_hidden.weight",
+ "pre_fc_norm_embedding.weight",
+ "norm.weight",
+)
+
+
+def _strip_mtp_prefix(key: str) -> str:
+ if key.startswith("language_model.mtp."):
+ return "mtp." + key[len("language_model.mtp.") :]
+ return key
+
+
+def _load_sidecar_weights(sidecar: Path) -> Dict[str, mx.array]:
+ """Load all MTP weights from the sidecar with the ``mtp.`` prefix kept."""
+ raw = mx.load(str(sidecar))
+ out: Dict[str, mx.array] = {}
+ for key, value in raw.items():
+ normalized = _strip_mtp_prefix(str(key))
+ if normalized.startswith("mtp."):
+ out[normalized] = value
+ return out
+
+
+# ---------------------------------------------------------------------------
+# Public loader entry point
+# ---------------------------------------------------------------------------
+
+
+def _model_declares_mtp(config: Dict[str, Any]) -> bool:
+ tcfg = config.get("text_config", config)
+ return int(tcfg.get("mtp_num_hidden_layers", 0) or 0) > 0
+
+
+def load_mtp_model(
+ model_path: Path,
+ *,
+ lazy: bool = False,
+ strict: bool = True,
+) -> Tuple[nn.Module, Dict[str, Any]]:
+ """Load a Qwen3.5/3.6 model, attaching native MTP support if available.
+
+ Returns the loaded model and the (possibly updated) config dict,
+ matching ``mlx_lm.utils.load_model``'s signature.
+
+ Raises :class:`MTPWeightsNotFound` if MTP is declared but no weight
+ source can be found.
+ """
+ if not strict:
+ raise ValueError(
+ "load_mtp_model only supports strict=True; non-strict loading "
+ "would silently allow random-initialized MTP weights, "
+ "reproducing the PR #1226 failure mode."
+ )
+
+ with open(model_path / "config.json", encoding="utf-8") as f:
+ config = json.load(f)
+
+ if config.get(
+ "model_type"
+ ) not in _SUPPORTED_NATIVE_MTP_MODEL_TYPES or not _model_declares_mtp(config):
+ # No MTP declared (or wrong model type): fall through to stock.
+ return load_model(model_path, lazy=lazy, strict=strict)
+
+ sidecar = _resolve_sidecar_path(model_path, config)
+ embedded_keys = _embedded_mtp_keys(model_path)
+ has_embedded_mtp = bool(embedded_keys)
+
+ if sidecar is None and not has_embedded_mtp:
+ candidates_msg = ", ".join(
+ (str(_config_sidecar_filename(config) or ""),) + _DEFAULT_SIDECAR_CANDIDATES
+ )
+ raise MTPWeightsNotFound(
+ "Qwen3.5/3.6 model declares MTP but no MTP weights are present in "
+ f"{model_path}. Probed: {candidates_msg} and "
+ "model.safetensors.index.json for embedded mtp.* / "
+ "language_model.mtp.* keys.",
+ candidates=(str(sidecar) if sidecar else "",) + _DEFAULT_SIDECAR_CANDIDATES,
+ )
+
+ # Decide quantization policy from the available key set (sidecar
+ # takes precedence; if absent, use embedded keys).
+ keys_for_policy: tuple[str, ...] = ()
+ if sidecar is not None:
+ keys_for_policy = _sidecar_keys(sidecar)
+ if not keys_for_policy and embedded_keys:
+ keys_for_policy = tuple(_strip_mtp_prefix(k) for k in embedded_keys)
+ policy = _classify_mtp_key_set(keys_for_policy)
+
+ # Pull bits/group_size from explicit MTP quant config if present,
+ # otherwise fall back to the main quantization block.
+ mtp_quant = config.get("mtplx_mtp_quantization") or {}
+ quant_overrides: Dict[str, Dict[str, Any]] = {}
+ if isinstance(mtp_quant, dict) and mtp_quant.get("prequantized"):
+ bits = int(mtp_quant.get("bits", 8))
+ group_size = int(mtp_quant.get("group_size", 64))
+ mode = str(mtp_quant.get("mode", "affine"))
+ else:
+ main_quant = (
+ config.get("quantization") or config.get("quantization_config") or {}
+ )
+ bits = int(main_quant.get("bits", 4))
+ group_size = int(main_quant.get("group_size", 64))
+ mode = str(main_quant.get("mode", "affine"))
+ if isinstance(main_quant, dict):
+ quant_overrides = {
+ str(key): value
+ for key, value in main_quant.items()
+ if isinstance(key, str)
+ and key.startswith("language_model.mtp.")
+ and isinstance(value, dict)
+ }
+
+ # Build a sidecar-loader callable; the TextModel.sanitize calls it
+ # when it can't find embedded MTP keys.
+ sidecar_loader: Optional[Callable[[], Dict[str, mx.array]]] = None
+ if sidecar is not None:
+ captured = sidecar # avoid late-binding
+
+ def _loader() -> Dict[str, mx.array]:
+ return _load_sidecar_weights(captured)
+
+ sidecar_loader = _loader
+
+ # Custom get_model_classes that returns our MTP-aware classes for
+ # qwen3_5 and arms the model instance with the sidecar loader and
+ # the MTP quant policy/params before sanitize runs.
+ def get_classes(config: Dict[str, Any]) -> Tuple[Type[nn.Module], Type[Any]]:
+ cfg = config
+ if cfg.get("model_type") not in _SUPPORTED_NATIVE_MTP_MODEL_TYPES:
+ return _get_classes(cfg)
+
+ original_init = MtpModel.__init__
+
+ def patched_init(self_: MtpModel, args: MtpModelArgs) -> None:
+ original_init(self_, args)
+ self_.language_model._mtp_sidecar_loader = sidecar_loader
+ # Quantize MTP module per the detected policy BEFORE
+ # load_weights runs in mlx_lm.utils.load_model. The main
+ # model's per-layer quantization dict in config["quantization"]
+ # does not cover the MTP submodule (no entries with
+ # 'mtp.*' paths), so we have to do it ourselves.
+ mtp_submodule = self_.language_model.model.mtp
+ if mtp_submodule is not None and policy != "unquantized":
+ _quantize_mtp_module(
+ mtp_submodule,
+ policy=policy,
+ bits=bits,
+ group_size=group_size,
+ mode=mode,
+ quant_overrides=quant_overrides,
+ )
+
+ # Use a tiny subclass so we don't permanently mutate MtpModel.
+ class _PatchedMtpModel(MtpModel):
+ __init__ = patched_init # type: ignore[assignment]
+
+ return _PatchedMtpModel, MtpModelArgs
+
+ model, updated_config = load_model(
+ model_path,
+ lazy=lazy,
+ strict=True,
+ get_model_classes=get_classes,
+ )
+ return model, updated_config
+
+
+__all__ = ["load_mtp_model"]
diff --git a/src/exo/worker/engines/mlx/vendor/tests/__init__.py b/src/exo/worker/engines/mlx/vendor/tests/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/src/exo/worker/engines/mlx/vendor/tests/test_qwen3_5_mtp.py b/src/exo/worker/engines/mlx/vendor/tests/test_qwen3_5_mtp.py
new file mode 100644
index 0000000000..e48c39e3f4
--- /dev/null
+++ b/src/exo/worker/engines/mlx/vendor/tests/test_qwen3_5_mtp.py
@@ -0,0 +1,472 @@
+"""Synthetic unit tests for the native Qwen3.5/3.6 MTP model class.
+
+These tests build tiny configurations and synthetic weight dicts so the
+whole suite runs in seconds without requiring any model download.
+The opt-in parity test against the real MTPLX artifact lives in a
+separate file (test_qwen3_5_mtp_parity.py).
+"""
+
+# pyright: reportAny=false, reportUnknownMemberType=false, reportUnknownArgumentType=false, reportUnknownVariableType=false, reportUnknownParameterType=false, reportMissingParameterType=false, reportPrivateUsage=false, reportArgumentType=false, reportOperatorIssue=false
+
+from __future__ import annotations
+
+from typing import Any, Dict
+
+import mlx.core as mx
+import mlx.nn as nn
+import pytest
+
+from exo.worker.engines.mlx.vendor.qwen3_5_mtp import (
+ Model,
+ ModelArgs,
+ MTPModule,
+ MTPWeightsNotFound,
+ TextModelArgs,
+ _classify_mtp_key_set,
+ _quantize_mtp_module,
+)
+
+# ---------------------------------------------------------------------------
+# Test fixtures: minimal "tiny" Qwen3.6-shaped config.
+# ---------------------------------------------------------------------------
+
+
+def _tiny_text_config(*, with_mtp: bool = True) -> Dict[str, Any]:
+ """Tiny config (32 hidden, 1 layer per group, 1 MTP layer) for fast tests.
+
+ The shapes are chosen to be self-consistent with the Qwen3.6 architecture
+ constraints (head dims, GQA, full_attention_interval) while staying
+ extremely small.
+ """
+ cfg: Dict[str, Any] = {
+ "model_type": "qwen3_5_text",
+ "hidden_size": 64,
+ "intermediate_size": 128,
+ "num_hidden_layers": 4,
+ "num_attention_heads": 4,
+ "rms_norm_eps": 1e-6,
+ "vocab_size": 256,
+ "num_key_value_heads": 2,
+ "max_position_embeddings": 256,
+ "linear_num_value_heads": 4,
+ "linear_num_key_heads": 2,
+ "linear_key_head_dim": 16,
+ "linear_value_head_dim": 16,
+ "linear_conv_kernel_dim": 4,
+ "tie_word_embeddings": False,
+ "attention_bias": False,
+ "head_dim": 16,
+ "full_attention_interval": 2,
+ "num_experts": 0,
+ "rope_parameters": {
+ "type": "default",
+ "rope_theta": 10000.0,
+ "partial_rotary_factor": 1.0,
+ },
+ }
+ if with_mtp:
+ cfg["mtp_num_hidden_layers"] = 1
+ return cfg
+
+
+def _tiny_model_args(*, with_mtp: bool = True) -> ModelArgs:
+ return ModelArgs(
+ model_type="qwen3_5",
+ text_config=_tiny_text_config(with_mtp=with_mtp),
+ )
+
+
+# ---------------------------------------------------------------------------
+# Shape tests
+# ---------------------------------------------------------------------------
+
+
+def test_forward_shapes() -> None:
+ """``Model(...)(tokens)`` returns logits of shape (B, T, V)."""
+ args = _tiny_model_args(with_mtp=True)
+ model = Model(args)
+ mx.eval(model.parameters())
+ tokens = mx.array([[1, 2, 3, 4, 5]], dtype=mx.int32)
+ logits = model(tokens)
+ assert isinstance(logits, mx.array)
+ mx.eval(logits)
+ assert logits.shape == (1, 5, 256)
+
+
+def test_forward_return_hidden_shapes() -> None:
+ """``return_hidden=True`` returns (logits, hidden) with hidden = post-norm."""
+ args = _tiny_model_args(with_mtp=True)
+ model = Model(args)
+ mx.eval(model.parameters())
+ tokens = mx.array([[1, 2, 3, 4, 5]], dtype=mx.int32)
+ result = model(tokens, return_hidden=True)
+ assert isinstance(result, tuple)
+ logits, hidden = result
+ mx.eval(logits, hidden)
+ assert logits.shape == (1, 5, 256)
+ assert hidden.shape == (1, 5, 64)
+
+
+def test_mtp_forward_shape() -> None:
+ """``mtp_forward`` returns logits of shape (B, T, V) from hidden + token."""
+ args = _tiny_model_args(with_mtp=True)
+ model = Model(args)
+ mx.eval(model.parameters())
+ hidden = mx.random.uniform(shape=(1, 1, 64))
+ next_token = mx.array([[42]], dtype=mx.int32)
+ cache = model.make_mtp_cache()
+ logits = model.mtp_forward(hidden, next_token, mtp_cache=cache)
+ assert isinstance(logits, mx.array)
+ mx.eval(logits)
+ assert logits.shape == (1, 1, 256)
+
+
+def test_make_mtp_cache_length_matches_mtp_layers() -> None:
+ args = _tiny_model_args(with_mtp=True)
+ model = Model(args)
+ cache = model.make_mtp_cache()
+ assert len(cache) == args.text_config["mtp_num_hidden_layers"]
+
+
+# ---------------------------------------------------------------------------
+# Norm-shift gate idempotence
+# ---------------------------------------------------------------------------
+
+
+def _build_synthetic_main_weights(args: TextModelArgs) -> Dict[str, mx.array]:
+ """Build a synthetic weight dict that matches the main model shape.
+
+ Norms start with mean ~0 (the unshifted form) so we can observe the
+ +1.0 shift kicking in.
+ """
+ hidden = args.hidden_size
+ vocab = args.vocab_size
+ inter = args.intermediate_size
+ head_dim = args.head_dim
+ n_heads = args.num_attention_heads
+ n_kv = args.num_key_value_heads
+ # Use float32 with mean 0 so the shift gate fires deterministically.
+ rng = mx.random.uniform
+ weights: Dict[str, mx.array] = {
+ "model.embed_tokens.weight": rng(shape=(vocab, hidden)) - 0.5,
+ "model.norm.weight": mx.zeros((hidden,), dtype=mx.float32),
+ "lm_head.weight": rng(shape=(vocab, hidden)) - 0.5,
+ }
+ for li in range(args.num_hidden_layers):
+ is_linear = (li + 1) % args.full_attention_interval != 0
+ prefix = f"model.layers.{li}"
+ weights[f"{prefix}.input_layernorm.weight"] = mx.zeros(
+ (hidden,), dtype=mx.float32
+ )
+ weights[f"{prefix}.post_attention_layernorm.weight"] = mx.zeros(
+ (hidden,), dtype=mx.float32
+ )
+ weights[f"{prefix}.mlp.gate_proj.weight"] = rng(shape=(inter, hidden)) - 0.5
+ weights[f"{prefix}.mlp.down_proj.weight"] = rng(shape=(hidden, inter)) - 0.5
+ weights[f"{prefix}.mlp.up_proj.weight"] = rng(shape=(inter, hidden)) - 0.5
+ if is_linear:
+ # GatedDeltaNet weights -- shapes mirror what stock code builds.
+ key_dim = args.linear_num_key_heads * args.linear_key_head_dim
+ value_dim = args.linear_num_value_heads * args.linear_value_head_dim
+ conv_dim = key_dim * 2 + value_dim
+ weights[f"{prefix}.linear_attn.conv1d.weight"] = (
+ rng(shape=(conv_dim, 1, args.linear_conv_kernel_dim)) - 0.5
+ )
+ weights[f"{prefix}.linear_attn.in_proj_qkv.weight"] = (
+ rng(shape=(conv_dim, hidden)) - 0.5
+ )
+ weights[f"{prefix}.linear_attn.in_proj_z.weight"] = (
+ rng(shape=(value_dim, hidden)) - 0.5
+ )
+ weights[f"{prefix}.linear_attn.in_proj_b.weight"] = (
+ rng(shape=(args.linear_num_value_heads, hidden)) - 0.5
+ )
+ weights[f"{prefix}.linear_attn.in_proj_a.weight"] = (
+ rng(shape=(args.linear_num_value_heads, hidden)) - 0.5
+ )
+ weights[f"{prefix}.linear_attn.out_proj.weight"] = (
+ rng(shape=(hidden, value_dim)) - 0.5
+ )
+ weights[f"{prefix}.linear_attn.dt_bias"] = mx.ones(
+ (args.linear_num_value_heads,)
+ )
+ weights[f"{prefix}.linear_attn.A_log"] = mx.zeros(
+ (args.linear_num_value_heads,)
+ )
+ weights[f"{prefix}.linear_attn.norm.weight"] = mx.ones(
+ (args.linear_value_head_dim,)
+ )
+ else:
+ weights[f"{prefix}.self_attn.q_proj.weight"] = (
+ rng(shape=(n_heads * head_dim * 2, hidden)) - 0.5
+ )
+ weights[f"{prefix}.self_attn.k_proj.weight"] = (
+ rng(shape=(n_kv * head_dim, hidden)) - 0.5
+ )
+ weights[f"{prefix}.self_attn.v_proj.weight"] = (
+ rng(shape=(n_kv * head_dim, hidden)) - 0.5
+ )
+ weights[f"{prefix}.self_attn.o_proj.weight"] = (
+ rng(shape=(hidden, n_heads * head_dim)) - 0.5
+ )
+ weights[f"{prefix}.self_attn.q_norm.weight"] = mx.zeros(
+ (head_dim,), dtype=mx.float32
+ )
+ weights[f"{prefix}.self_attn.k_norm.weight"] = mx.zeros(
+ (head_dim,), dtype=mx.float32
+ )
+ return weights
+
+
+def _build_synthetic_mtp_weights(args: TextModelArgs) -> Dict[str, mx.array]:
+ """Build a synthetic MTP weight dict with unshifted norms."""
+ hidden = args.hidden_size
+ inter = args.intermediate_size
+ head_dim = args.head_dim
+ n_heads = args.num_attention_heads
+ n_kv = args.num_key_value_heads
+ rng = mx.random.uniform
+ weights: Dict[str, mx.array] = {
+ "mtp.fc.weight": rng(shape=(hidden, 2 * hidden)) - 0.5,
+ "mtp.pre_fc_norm_hidden.weight": mx.zeros((hidden,), dtype=mx.float32),
+ "mtp.pre_fc_norm_embedding.weight": mx.zeros((hidden,), dtype=mx.float32),
+ "mtp.norm.weight": mx.zeros((hidden,), dtype=mx.float32),
+ "mtp.layers.0.input_layernorm.weight": mx.zeros((hidden,), dtype=mx.float32),
+ "mtp.layers.0.post_attention_layernorm.weight": mx.zeros(
+ (hidden,), dtype=mx.float32
+ ),
+ "mtp.layers.0.self_attn.q_proj.weight": rng(
+ shape=(n_heads * head_dim * 2, hidden)
+ )
+ - 0.5,
+ "mtp.layers.0.self_attn.k_proj.weight": rng(shape=(n_kv * head_dim, hidden))
+ - 0.5,
+ "mtp.layers.0.self_attn.v_proj.weight": rng(shape=(n_kv * head_dim, hidden))
+ - 0.5,
+ "mtp.layers.0.self_attn.o_proj.weight": rng(shape=(hidden, n_heads * head_dim))
+ - 0.5,
+ "mtp.layers.0.self_attn.q_norm.weight": mx.zeros((head_dim,), dtype=mx.float32),
+ "mtp.layers.0.self_attn.k_norm.weight": mx.zeros((head_dim,), dtype=mx.float32),
+ "mtp.layers.0.mlp.gate_proj.weight": rng(shape=(inter, hidden)) - 0.5,
+ "mtp.layers.0.mlp.down_proj.weight": rng(shape=(hidden, inter)) - 0.5,
+ "mtp.layers.0.mlp.up_proj.weight": rng(shape=(inter, hidden)) - 0.5,
+ }
+ return weights
+
+
+def test_norm_shift_idempotent() -> None:
+ """Sanitize shifts mean=0 norms to mean=1; a second sanitize does NOT shift again."""
+ args = _tiny_model_args(with_mtp=True)
+ model = Model(args)
+ main = _build_synthetic_main_weights(model.language_model.args)
+ mtp = _build_synthetic_mtp_weights(model.language_model.args)
+ raw: Dict[str, mx.array] = {}
+ for k, v in main.items():
+ raw["language_model." + k] = v
+ for k, v in mtp.items():
+ raw["language_model." + k] = v
+
+ sanitized = model.sanitize(raw)
+ # After first sanitize: all norm.weights should have mean ~1.
+ norm_keys = [
+ k
+ for k in sanitized
+ if k.endswith(("input_layernorm.weight", "norm.weight", "q_norm.weight"))
+ and sanitized[k].ndim == 1
+ ]
+ assert len(norm_keys) > 0
+ for k in norm_keys:
+ m = float(sanitized[k].astype(mx.float32).mean().item())
+ assert abs(m - 1.0) < 0.05, f"first sanitize on {k}: mean={m}"
+
+ # Second sanitize on the already-shifted dict should NOT shift again.
+ # The sanitized output is already keyed in the canonical
+ # ``language_model.model.X`` form, so we can feed it straight back
+ # to ``model.sanitize`` (the outer Model.sanitize tolerates the
+ # ``language_model.*`` prefix as a pass-through).
+ sanitized2 = model.sanitize(sanitized)
+ for k in norm_keys:
+ m = float(sanitized2[k].astype(mx.float32).mean().item())
+ assert abs(m - 1.0) < 0.05, (
+ f"second sanitize on {k}: mean={m} (double-shifted!)"
+ )
+
+
+# ---------------------------------------------------------------------------
+# Missing-weights diagnostics
+# ---------------------------------------------------------------------------
+
+
+def test_missing_weights_raises() -> None:
+ """If MTP is declared but no MTP weights and no loader: ``MTPWeightsNotFound``."""
+ args = _tiny_model_args(with_mtp=True)
+ model = Model(args)
+ main = _build_synthetic_main_weights(model.language_model.args)
+ raw: Dict[str, mx.array] = {"language_model." + k: v for k, v in main.items()}
+ # Crucially: do NOT install a sidecar loader, do NOT include mtp.* keys.
+ with pytest.raises(MTPWeightsNotFound) as excinfo:
+ model.sanitize(raw)
+ assert "mtp.safetensors" in excinfo.value.candidates
+
+
+def test_no_mtp_declared_does_not_raise() -> None:
+ """If config has ``mtp_num_hidden_layers=0``, sanitize is happy without MTP keys."""
+ args = _tiny_model_args(with_mtp=False)
+ model = Model(args)
+ main = _build_synthetic_main_weights(model.language_model.args)
+ raw: Dict[str, mx.array] = {"language_model." + k: v for k, v in main.items()}
+ sanitized = model.sanitize(raw)
+ assert not any(k.startswith("mtp.") for k in sanitized)
+
+
+# ---------------------------------------------------------------------------
+# Cache-share differential (proves cache seeding matters)
+# ---------------------------------------------------------------------------
+
+
+def test_cache_share_differential() -> None:
+ """Primed-cache MTP differs from fresh-cache MTP after multiple steps."""
+ args = _tiny_model_args(with_mtp=True)
+ model = Model(args)
+ mx.eval(model.parameters())
+
+ # Walk 4 hidden+token pairs through the MTP head with two cache strategies:
+ rng_h1 = mx.random.uniform(shape=(1, 1, 64), key=mx.random.key(7))
+ rng_h2 = mx.random.uniform(shape=(1, 1, 64), key=mx.random.key(11))
+ rng_h3 = mx.random.uniform(shape=(1, 1, 64), key=mx.random.key(13))
+ rng_h4 = mx.random.uniform(shape=(1, 1, 64), key=mx.random.key(17))
+ hiddens = [rng_h1, rng_h2, rng_h3, rng_h4]
+ tokens = [mx.array([[i + 5]], dtype=mx.int32) for i in range(4)]
+
+ # Primed: same cache for all steps
+ primed = model.make_mtp_cache()
+ primed_outputs = []
+ for h, t in zip(hiddens, tokens, strict=True):
+ out = model.mtp_forward(h, t, mtp_cache=primed)
+ mx.eval(out)
+ primed_outputs.append(out)
+
+ # Fresh: new cache each step
+ fresh_outputs = []
+ for h, t in zip(hiddens, tokens, strict=True):
+ fresh = model.make_mtp_cache()
+ out = model.mtp_forward(h, t, mtp_cache=fresh)
+ mx.eval(out)
+ fresh_outputs.append(out)
+
+ # The very first step must agree (both have empty cache).
+ first_diff = float(mx.max(mx.abs(primed_outputs[0] - fresh_outputs[0])).item())
+ assert first_diff < 1e-4, f"first-step primed/fresh diverge: {first_diff}"
+
+ # Later steps must diverge -- proves cache priming is load-bearing.
+ last_diff = float(mx.max(mx.abs(primed_outputs[-1] - fresh_outputs[-1])).item())
+ assert last_diff > 1e-3, (
+ f"last-step primed/fresh agree (diff={last_diff}); cache priming "
+ "isn't actually changing MTP outputs"
+ )
+
+
+# ---------------------------------------------------------------------------
+# Hidden-variant correctness (post-norm vs pre-norm)
+# ---------------------------------------------------------------------------
+
+
+def test_post_vs_pre_norm() -> None:
+ """The two hidden variants give different logits and hidden."""
+ args = _tiny_model_args(with_mtp=True)
+ model = Model(args)
+ mx.eval(model.parameters())
+ hidden = mx.random.uniform(shape=(1, 1, 64))
+ token = mx.array([[42]], dtype=mx.int32)
+ cache_post = model.make_mtp_cache()
+ cache_pre = model.make_mtp_cache()
+ out_post = model.mtp_forward(
+ hidden,
+ token,
+ mtp_cache=cache_post,
+ mtp_hidden_variant="post_norm",
+ return_hidden=True,
+ )
+ out_pre = model.mtp_forward(
+ hidden,
+ token,
+ mtp_cache=cache_pre,
+ mtp_hidden_variant="pre_norm",
+ return_hidden=True,
+ )
+ assert isinstance(out_post, tuple) and isinstance(out_pre, tuple)
+ # Logits are computed from POST-norm in both cases (the lm_head only
+ # ever sees post-norm). What changes is the *returned* hidden which
+ # downstream draft-loops use as the input for the next MTP step.
+ h_post = out_post[1]
+ h_pre = out_pre[1]
+ diff = float(mx.max(mx.abs(h_post - h_pre)).item())
+ assert diff > 1e-3, (
+ f"post-norm and pre-norm hidden variants are identical (diff={diff}); "
+ "variant switch is a no-op"
+ )
+
+
+# ---------------------------------------------------------------------------
+# Quant-policy classification
+# ---------------------------------------------------------------------------
+
+
+def test_classify_unquantized() -> None:
+ keys = (
+ "mtp.fc.weight",
+ "mtp.layers.0.self_attn.q_proj.weight",
+ "mtp.layers.0.self_attn.k_proj.weight",
+ "mtp.norm.weight",
+ )
+ assert _classify_mtp_key_set(keys) == "unquantized"
+
+
+def test_classify_cyankiwi() -> None:
+ keys = (
+ "mtp.fc.weight", # fc unquantized
+ "mtp.layers.0.self_attn.q_proj.weight",
+ "mtp.layers.0.self_attn.q_proj.scales", # presence of attn scales -> cyankiwi
+ "mtp.layers.0.self_attn.q_proj.biases",
+ "mtp.norm.weight",
+ )
+ assert _classify_mtp_key_set(keys) == "cyankiwi"
+
+
+def test_classify_all_quantized() -> None:
+ keys = (
+ "mtp.fc.weight",
+ "mtp.fc.scales", # presence of fc scales -> "all"
+ "mtp.fc.biases",
+ "mtp.layers.0.self_attn.q_proj.weight",
+ "mtp.layers.0.self_attn.q_proj.scales",
+ "mtp.layers.0.self_attn.q_proj.biases",
+ )
+ assert _classify_mtp_key_set(keys) == "all"
+
+
+def test_quantize_mtp_module_cyankiwi_leaves_fc_unquantized() -> None:
+ args = TextModelArgs.from_dict(_tiny_text_config(with_mtp=True))
+ mtp = MTPModule(args, 1)
+ _quantize_mtp_module(mtp, policy="cyankiwi", bits=8, group_size=32)
+ # fc should still be a plain Linear (NOT QuantizedLinear)
+ assert isinstance(mtp.fc, nn.Linear)
+ assert not isinstance(mtp.fc, nn.QuantizedLinear)
+ # attention q_proj should be QuantizedLinear
+ qproj = mtp.layers[0].self_attn.q_proj
+ assert isinstance(qproj, nn.QuantizedLinear)
+
+
+def test_quantize_mtp_module_all_quantizes_fc() -> None:
+ args = TextModelArgs.from_dict(_tiny_text_config(with_mtp=True))
+ mtp = MTPModule(args, 1)
+ _quantize_mtp_module(mtp, policy="all", bits=8, group_size=32)
+ assert isinstance(mtp.fc, nn.QuantizedLinear)
+
+
+def test_quantize_mtp_module_unquantized_is_noop() -> None:
+ args = TextModelArgs.from_dict(_tiny_text_config(with_mtp=True))
+ mtp = MTPModule(args, 1)
+ _quantize_mtp_module(mtp, policy="unquantized", bits=8, group_size=32)
+ assert isinstance(mtp.fc, nn.Linear) and not isinstance(mtp.fc, nn.QuantizedLinear)
+ assert not isinstance(mtp.layers[0].self_attn.q_proj, nn.QuantizedLinear)
diff --git a/src/exo/worker/engines/mlx/vendor/tests/test_qwen3_5_mtp_parity.py b/src/exo/worker/engines/mlx/vendor/tests/test_qwen3_5_mtp_parity.py
new file mode 100644
index 0000000000..640f27abd8
--- /dev/null
+++ b/src/exo/worker/engines/mlx/vendor/tests/test_qwen3_5_mtp_parity.py
@@ -0,0 +1,126 @@
+"""Opt-in numerical parity test against a real Qwen MTP artifact.
+
+Gated on ``MLX_LM_RUN_NETWORK_TESTS=1`` because it requires the model
+checkpoint on disk (~30 GB) and ~10 s of compute per run.
+
+Pass criterion: top-1 agreement between MTP's next-next-token
+prediction and the main lm_head's prediction over a fixed
+post-norm hidden context is >= 60%. This is the regression that
+prevents PR-#1226-style breakage.
+
+The probe walks the model incrementally rather than via a batched
+prefill -- the Qwen3.5 GatedDeltaNet implementation in mlx-lm is not
+strictly per-position causal during batched prefill, which would
+distort the per-position parity numbers (see
+scripts/mtp_parity_probe.py for the falsification probe that
+identified this).
+"""
+
+# pyright: reportAny=false, reportUnknownMemberType=false, reportUnknownArgumentType=false, reportUnknownVariableType=false, reportUnknownParameterType=false, reportMissingParameterType=false, reportUnnecessaryTypeIgnoreComment=false
+
+from __future__ import annotations
+
+import os
+from pathlib import Path
+
+import mlx.core as mx
+import pytest
+
+from exo.worker.engines.mlx.vendor.qwen3_5_mtp_loader import load_mtp_model
+
+_NETWORK_TESTS_FLAG = "MLX_LM_RUN_NETWORK_TESTS"
+_MODEL_PATH_ENV = "EXO_NATIVE_MTP_PARITY_MODEL_PATH"
+_MODEL_DIR_RAW = os.environ.get(_MODEL_PATH_ENV)
+_MODEL_DIR = Path(_MODEL_DIR_RAW) if _MODEL_DIR_RAW else None
+
+# 48 tokens of natural English; incremental forward over this is ~5 s on M5 Max.
+_PROMPT = (
+ "The quick brown fox jumps over the lazy dog. In a small village by the "
+ "river there lived an old clockmaker whose hands could repair any "
+ "broken timepiece. Each morning before the sun rose he would walk "
+ "through the cobbled streets carrying a leather satchel full of tiny "
+ "tools, and the children would wave from their windows as he passed."
+)
+
+_TOP1_FLOOR = 0.60
+
+
+@pytest.mark.slow
+@pytest.mark.skipif(
+ os.environ.get(_NETWORK_TESTS_FLAG) != "1",
+ reason=f"set {_NETWORK_TESTS_FLAG}=1 to run the MTP parity regression",
+)
+@pytest.mark.skipif(
+ _MODEL_DIR is None,
+ reason=f"set {_MODEL_PATH_ENV} to a local Qwen MTP model directory",
+)
+@pytest.mark.skipif(
+ _MODEL_DIR is not None and not _MODEL_DIR.exists(),
+ reason=f"configured Qwen MTP model directory does not exist: {_MODEL_DIR}",
+)
+def test_mtp_top1_agreement_against_lm_head() -> None:
+ """Top-1 agreement with a primed MTP cache must stay above the floor."""
+ try:
+ from transformers import AutoTokenizer # type: ignore[import-untyped]
+ except ImportError:
+ pytest.skip("transformers not installed")
+
+ assert _MODEL_DIR is not None
+ model, _cfg = load_mtp_model(_MODEL_DIR, lazy=False, strict=True)
+ tok = AutoTokenizer.from_pretrained(str(_MODEL_DIR))
+
+ token_ids = tok(_PROMPT, return_tensors="np")["input_ids"][0].tolist()
+ token_ids = token_ids[: min(len(token_ids), 48)]
+ assert len(token_ids) >= 16, f"too few tokens to test: {len(token_ids)}"
+ tokens = mx.array([token_ids], dtype=mx.int32)
+ seq_len = tokens.shape[1]
+
+ # Incremental main forward to get truly causal per-position post-norm
+ # hidden + lm_head logits.
+ text_model = model.language_model
+ inner = text_model.model
+ cache = text_model.make_cache()
+ post_per_pos: list[mx.array] = []
+ logits_per_pos: list[mx.array] = []
+ from mlx_lm.models.base import create_attention_mask, create_ssm_mask
+
+ for t in range(seq_len):
+ one = tokens[:, t : t + 1]
+ h = inner.embed_tokens(one)
+ fa_mask = create_attention_mask(h, cache[inner.fa_idx])
+ ssm_mask = create_ssm_mask(h, cache[inner.ssm_idx])
+ for layer, layer_cache in zip(inner.layers, cache, strict=True):
+ mask = ssm_mask if layer.is_linear else fa_mask
+ h = layer(h, mask=mask, cache=layer_cache)
+ post = inner.norm(h)
+ if text_model.args.tie_word_embeddings:
+ lg = inner.embed_tokens.as_linear(post)
+ else:
+ lg = text_model.lm_head(post)
+ mx.eval(post, lg)
+ post_per_pos.append(post)
+ logits_per_pos.append(lg)
+ post_norm_hidden = mx.concatenate(post_per_pos, axis=1)
+ main_logits = mx.concatenate(logits_per_pos, axis=1)
+ mx.eval(post_norm_hidden, main_logits)
+
+ # Primed-cache MTP walk
+ mtp_cache = model.make_mtp_cache()
+ matches = 0
+ total = 0
+ for t in range(seq_len - 1):
+ h_t = post_norm_hidden[:, t : t + 1, :]
+ next_tok = tokens[:, t + 1 : t + 2]
+ mtp_logits = model.mtp_forward(h_t, next_tok, mtp_cache=mtp_cache)
+ assert isinstance(mtp_logits, mx.array)
+ mx.eval(mtp_logits)
+ mtp_pred = int(mx.argmax(mtp_logits[0, -1]).item())
+ main_pred = int(mx.argmax(main_logits[0, t + 1]).item())
+ if mtp_pred == main_pred:
+ matches += 1
+ total += 1
+ top1 = matches / total
+ assert top1 >= _TOP1_FLOOR, (
+ f"MTP top-1 agreement {top1:.4f} < floor {_TOP1_FLOOR:.2f} "
+ f"(matches={matches}/{total}) -- this is the PR #1226 regression"
+ )
diff --git a/src/exo/worker/runner/llm_inference/batch_generator.py b/src/exo/worker/runner/llm_inference/batch_generator.py
index 098c829e9a..53865d02d1 100644
--- a/src/exo/worker/runner/llm_inference/batch_generator.py
+++ b/src/exo/worker/runner/llm_inference/batch_generator.py
@@ -99,6 +99,12 @@ class SequentialGenerator(Engine):
event_sender: MpSender[Event]
vision_processor: VisionProcessor | None = None
check_for_cancel_every: int = 50
+ # Native-MTP K bounds from the model card (``NativeMTPConfig``). Set by
+ # the builder only when the target loaded as a vendored MTP-aware model;
+ # ``None`` on every other model. ``mlx_generate`` resolves the per-request
+ # K from these.
+ native_mtp_default_k: int | None = None
+ native_mtp_max_k: int | None = None
_cancelled_tasks: set[TaskId] = field(default_factory=set, init=False)
_maybe_queue: list[TextGeneration] = field(default_factory=list, init=False)
@@ -295,6 +301,8 @@ def on_generation_token() -> None:
on_generation_token=on_generation_token,
group=self.group,
vision_processor=self.vision_processor,
+ native_mtp_default_k=self.native_mtp_default_k,
+ native_mtp_max_k=self.native_mtp_max_k,
)
def close(self) -> None:
diff --git a/src/exo/worker/tests/unittests/test_mlx/test_native_mtp_drafter.py b/src/exo/worker/tests/unittests/test_mlx/test_native_mtp_drafter.py
new file mode 100644
index 0000000000..e2c0e4e540
--- /dev/null
+++ b/src/exo/worker/tests/unittests/test_mlx/test_native_mtp_drafter.py
@@ -0,0 +1,551 @@
+"""Unit tests for :mod:`exo.worker.engines.mlx.generator.native_mtp_drafter`.
+
+Exercises:
+
+- ``is_native_mtp_dispatchable`` correctly disambiguates the vendored
+ MTP-capable :class:`vendor.qwen3_5_mtp.Model` from plain ``mx``
+ modules and from ``unittest.mock.MagicMock`` (which auto-creates
+ any attribute on access).
+- :class:`NativeMTPDrafter` constructor validation (rejects ``k<1``)
+ and trivial property surface (``mode``, ``num_draft_tokens``).
+- :func:`prime_mtp_cache_from_prompt` returns ``N-1`` positions
+ primed for prompts of size ``>=2`` and ``0`` otherwise.
+- End-to-end ``stream`` smoke against the tiny synthetic Qwen3.5/6
+ model fixture (mirrors the pattern used in
+ ``vendor/tests/test_qwen3_5_mtp.py``): runs K=1 / K=2 / K=3 and
+ asserts the drafter emits tokens, populates metrics, and respects
+ ``max_tokens``.
+
+These tests use the synthetic tiny model so the whole suite runs in
+seconds without requiring any real model download. The opt-in parity
+test against the real MTPLX artifact lives separately.
+"""
+
+# pyright: reportPrivateUsage=false, reportUnknownMemberType=false, reportAttributeAccessIssue=false, reportUnknownArgumentType=false, reportUnknownVariableType=false, reportAny=false
+
+from __future__ import annotations
+
+from collections.abc import Iterable
+from typing import Any, Callable, Dict, cast
+from unittest.mock import MagicMock
+
+import mlx.core as mx
+import pytest
+
+from exo.worker.engines.mlx.generator.native_mtp_drafter import (
+ NativeMTPDrafter,
+ _eos_ids_from_tokenizer,
+ _gdn_state_history_commit_default,
+ _moe_verifier_policy,
+ _target_post_norm_hidden,
+ is_native_mtp_dispatchable,
+ prime_mtp_cache_from_prompt,
+ prime_mtp_cache_from_prompt_incremental,
+ rebuild_prompt_cache_and_prime_mtp_cache_incremental,
+ rebuild_prompt_cache_incremental,
+)
+from exo.worker.engines.mlx.types import Model as ExoModel
+from exo.worker.engines.mlx.vendor.qwen3_5_mtp import Model, ModelArgs
+
+
+def _tiny_text_config(*, with_mtp: bool = True) -> Dict[str, Any]:
+ """Mirror the tiny config used by ``vendor/tests/test_qwen3_5_mtp.py``."""
+ cfg: Dict[str, Any] = {
+ "model_type": "qwen3_5_text",
+ "hidden_size": 64,
+ "intermediate_size": 128,
+ "num_hidden_layers": 4,
+ "num_attention_heads": 4,
+ "rms_norm_eps": 1e-6,
+ "vocab_size": 256,
+ "num_key_value_heads": 2,
+ "max_position_embeddings": 256,
+ "linear_num_value_heads": 4,
+ "linear_num_key_heads": 2,
+ "linear_key_head_dim": 16,
+ "linear_value_head_dim": 16,
+ "linear_conv_kernel_dim": 4,
+ "tie_word_embeddings": False,
+ "attention_bias": False,
+ "head_dim": 16,
+ "full_attention_interval": 2,
+ "num_experts": 0,
+ "rope_parameters": {
+ "type": "default",
+ "rope_theta": 10000.0,
+ "partial_rotary_factor": 1.0,
+ },
+ }
+ if with_mtp:
+ cfg["mtp_num_hidden_layers"] = 1
+ return cfg
+
+
+def _tiny_model(*, with_mtp: bool = True) -> ExoModel:
+ args = ModelArgs(
+ model_type="qwen3_5", text_config=_tiny_text_config(with_mtp=with_mtp)
+ )
+ model = Model(args)
+ mx.eval(model.parameters())
+ # The vendored ``Model`` subclass satisfies the exo Model Protocol's
+ # ``__call__`` shape modulo the optional ``return_hidden`` kwarg the
+ # native-MTP path uses. Cast through ``object`` to defuse the
+ # basedpyright ``reportInvalidCast`` for not-sufficiently-overlapping
+ # types; runtime stream tests below catch any real Protocol drift.
+ return cast(ExoModel, cast(object, model))
+
+
+class _FakeDetokenizer:
+ """Minimal detokenizer that records tokens added between yields.
+
+ Stream consumers walk ``last_segment`` after every ``add_token`` and
+ then ``finalize`` is called once before the closing yield. We only
+ need an empty ``last_segment`` for the drafter to construct
+ :class:`GenerationResponse` -- the test asserts emitted tokens via
+ ``GenerationResponse.token`` directly, not via the segment string.
+ """
+
+ def __init__(self) -> None:
+ self.tokens: list[int] = []
+ self.last_segment: str = ""
+ self.finalized: bool = False
+
+ def reset(self) -> None:
+ self.tokens = []
+ self.last_segment = ""
+ self.finalized = False
+
+ def add_token(self, token: int) -> None:
+ self.tokens.append(int(token))
+
+ def finalize(self) -> None:
+ self.finalized = True
+
+
+class _FakeTokenizer:
+ """Fake :class:`mlx_lm.tokenizer_utils.TokenizerWrapper` minimal shim."""
+
+ def __init__(self, eos_token_ids: Iterable[int] | None = None) -> None:
+ self.detokenizer: _FakeDetokenizer = _FakeDetokenizer()
+ self.eos_token_ids: Iterable[int] = list(eos_token_ids or [])
+
+
+def _identity_sampler(logits: mx.array) -> mx.array:
+ """Greedy sampler used by smoke tests."""
+ return mx.argmax(logits, axis=-1)
+
+
+def _empty_processors() -> list[Callable[[mx.array, mx.array], mx.array]]:
+ return []
+
+
+# --------------------------------------------------------------------------- #
+# is_native_mtp_dispatchable
+# --------------------------------------------------------------------------- #
+
+
+class TestIsNativeMtpDispatchable:
+ def test_vendored_model_is_dispatchable(self) -> None:
+ model = _tiny_model(with_mtp=True)
+ assert is_native_mtp_dispatchable(model) is True
+
+ def test_vendored_model_without_mtp_layers_is_still_class_dispatchable(
+ self,
+ ) -> None:
+ """The class-level marker doesn't depend on whether MTP layers exist.
+
+ The dispatch-side gate only checks the model class; the runtime
+ check for actual MTP availability is the loader / probe gate.
+ A card without MTP would never produce this class via the loader,
+ so the gate is conservative in the right direction.
+ """
+ model = _tiny_model(with_mtp=False)
+ assert is_native_mtp_dispatchable(model) is True
+
+ def test_magicmock_is_not_dispatchable(self) -> None:
+ """``MagicMock`` auto-creates attributes; the marker check rejects it.
+
+ Pre-fix the dispatcher used ``hasattr(model, "mtp_forward")`` which
+ returned ``True`` for any ``MagicMock`` instance because attribute
+ access auto-creates a child mock. That falsely engaged the
+ NativeMTPDrafter in unrelated routing tests.
+ """
+ fake = MagicMock()
+ assert is_native_mtp_dispatchable(fake) is False
+
+ def test_plain_object_is_not_dispatchable(self) -> None:
+ assert is_native_mtp_dispatchable(object()) is False
+
+ def test_none_is_not_dispatchable(self) -> None:
+ assert is_native_mtp_dispatchable(None) is False
+
+
+# --------------------------------------------------------------------------- #
+# NativeMTPDrafter constructor / property surface
+# --------------------------------------------------------------------------- #
+
+
+class TestNativeMTPDrafterConstructor:
+ def test_rejects_zero_k(self) -> None:
+ with pytest.raises(ValueError, match="k must be >= 1"):
+ NativeMTPDrafter(k=0)
+
+ def test_rejects_negative_k(self) -> None:
+ with pytest.raises(ValueError, match="k must be >= 1"):
+ NativeMTPDrafter(k=-1)
+
+ def test_mode_is_model(self) -> None:
+ drafter = NativeMTPDrafter(k=1)
+ assert drafter.mode == "model"
+
+ def test_num_draft_tokens_matches_k(self) -> None:
+ for k in (1, 2, 3):
+ assert NativeMTPDrafter(k=k).num_draft_tokens == k
+
+ def test_initial_metrics_are_zeroed(self) -> None:
+ drafter = NativeMTPDrafter(k=2)
+ assert drafter.metrics() == {
+ "proposed_draft_tokens": 0,
+ "accepted_draft_tokens": 0,
+ "spec_decode_rounds": 0,
+ }
+
+ def test_metrics_is_a_copy(self) -> None:
+ """``metrics()`` returns a fresh dict so callers can't mutate state."""
+ drafter = NativeMTPDrafter(k=1)
+ snap = drafter.metrics()
+ snap["proposed_draft_tokens"] = 99
+ assert drafter.metrics()["proposed_draft_tokens"] == 0
+
+
+def test_eos_ids_from_tokenizer_accepts_set_eos_ids() -> None:
+ """Real MLX TokenizerWrapper exposes Qwen EOS IDs as a set."""
+ tokenizer = _FakeTokenizer(eos_token_ids={248044, 248046})
+
+ assert sorted(_eos_ids_from_tokenizer(cast(Any, tokenizer))) == [248044, 248046]
+
+
+@pytest.mark.parametrize("model_type", ["qwen3_5", "qwen3_5_text"])
+def test_gdn_state_history_commit_defaults_on_for_dense_qwen3_5(
+ model_type: str,
+) -> None:
+ assert (
+ _gdn_state_history_commit_default(
+ model_type=model_type,
+ moe_verifier_policy="safe",
+ )
+ is True
+ )
+
+
+def test_gdn_state_history_commit_defaults_on_for_route_locked_moe_only() -> None:
+ assert (
+ _gdn_state_history_commit_default(
+ model_type="qwen3_5_moe",
+ moe_verifier_policy="route_locked",
+ )
+ is True
+ )
+ assert (
+ _gdn_state_history_commit_default(
+ model_type="qwen3_5_moe",
+ moe_verifier_policy="safe",
+ )
+ is False
+ )
+
+
+def test_moe_verifier_policy_defaults_to_route_locked(
+ monkeypatch: pytest.MonkeyPatch,
+) -> None:
+ monkeypatch.delenv("EXO_NATIVE_MTP_MOE_VERIFY", raising=False)
+
+ assert _moe_verifier_policy() == "route_locked"
+
+
+def test_moe_verifier_policy_can_fall_back_to_safe(
+ monkeypatch: pytest.MonkeyPatch,
+) -> None:
+ monkeypatch.setenv("EXO_NATIVE_MTP_MOE_VERIFY", "safe")
+
+ assert _moe_verifier_policy() == "safe"
+
+
+# --------------------------------------------------------------------------- #
+# prime_mtp_cache_from_prompt
+# --------------------------------------------------------------------------- #
+
+
+class TestPrimeMTPCacheFromPrompt:
+ def test_returns_zero_for_empty_prompt(self) -> None:
+ model = _tiny_model(with_mtp=True)
+ mtp_cache = model.make_mtp_cache()
+ assert (
+ prime_mtp_cache_from_prompt(
+ model=model, full_prompt_tokens=[], mtp_cache=mtp_cache
+ )
+ == 0
+ )
+
+ def test_returns_zero_for_single_token_prompt(self) -> None:
+ model = _tiny_model(with_mtp=True)
+ mtp_cache = model.make_mtp_cache()
+ assert (
+ prime_mtp_cache_from_prompt(
+ model=model, full_prompt_tokens=[5], mtp_cache=mtp_cache
+ )
+ == 0
+ )
+
+ def test_returns_nminus1_for_normal_prompt(self) -> None:
+ model = _tiny_model(with_mtp=True)
+ mtp_cache = model.make_mtp_cache()
+ primed = prime_mtp_cache_from_prompt(
+ model=model,
+ full_prompt_tokens=[1, 2, 3, 4, 5],
+ mtp_cache=mtp_cache,
+ )
+ assert primed == 4
+
+ def test_primes_advances_mtp_cache_offset(self) -> None:
+ """After priming N positions the MTP cache holds N entries."""
+ model = _tiny_model(with_mtp=True)
+ mtp_cache = model.make_mtp_cache()
+ prompt = [1, 2, 3, 4, 5, 6]
+ primed = prime_mtp_cache_from_prompt(
+ model=model, full_prompt_tokens=prompt, mtp_cache=mtp_cache
+ )
+ assert primed == len(prompt) - 1
+ # The MTP cache is a stock KVCache; its ``offset`` should be the
+ # number of positions written.
+ offsets = [getattr(c, "offset", -1) for c in mtp_cache]
+ assert offsets == [len(prompt) - 1]
+
+
+class TestHiddenWithoutLogits:
+ def test_matches_outer_model_return_hidden_and_cache(self) -> None:
+ model = _tiny_model(with_mtp=True)
+ token = mx.array([[7]], dtype=mx.int32)
+ next_token = mx.array([[8]], dtype=mx.int32)
+
+ outer_cache = list(cast(Any, model).make_cache())
+ _outer_logits, outer_hidden = cast(Any, model)(
+ token,
+ cache=outer_cache,
+ return_hidden=True,
+ )
+ body_cache = list(cast(Any, model).make_cache())
+ body_hidden = _target_post_norm_hidden(
+ model=model,
+ inputs=token,
+ cache=cast(Any, body_cache),
+ )
+ mx.eval(outer_hidden, body_hidden)
+ assert float(mx.max(mx.abs(outer_hidden - body_hidden)).item()) == 0.0
+
+ outer_next_logits, _outer_next_hidden = cast(Any, model)(
+ next_token,
+ cache=outer_cache,
+ return_hidden=True,
+ )
+ body_next_logits, _body_next_hidden = cast(Any, model)(
+ next_token,
+ cache=body_cache,
+ return_hidden=True,
+ )
+ mx.eval(outer_next_logits, body_next_logits)
+ assert float(mx.max(mx.abs(outer_next_logits - body_next_logits)).item()) == 0.0
+
+
+class TestFusedPromptAndMtpPriming:
+ def test_matches_separate_incremental_paths(self) -> None:
+ model = _tiny_model(with_mtp=True)
+ prompt = [1, 2, 3, 4, 5, 6]
+
+ separate_prompt_cache = list(cast(Any, model).make_cache())
+ separate_mtp_cache = list(cast(Any, model).make_mtp_cache())
+ separate_rebuilt = rebuild_prompt_cache_incremental(
+ model=model,
+ full_prompt_tokens=prompt,
+ prompt_cache=cast(Any, separate_prompt_cache),
+ )
+ separate_primed = prime_mtp_cache_from_prompt_incremental(
+ model=model,
+ full_prompt_tokens=prompt,
+ mtp_cache=separate_mtp_cache,
+ )
+
+ fused_prompt_cache = list(cast(Any, model).make_cache())
+ fused_mtp_cache = list(cast(Any, model).make_mtp_cache())
+ fused_rebuilt, fused_primed = (
+ rebuild_prompt_cache_and_prime_mtp_cache_incremental(
+ model=model,
+ full_prompt_tokens=prompt,
+ prompt_cache=cast(Any, fused_prompt_cache),
+ mtp_cache=fused_mtp_cache,
+ )
+ )
+
+ assert fused_rebuilt == separate_rebuilt == len(prompt) - 2
+ assert fused_primed == separate_primed == len(prompt) - 1
+ assert [getattr(c, "offset", -1) for c in fused_prompt_cache] == [
+ getattr(c, "offset", -1) for c in separate_prompt_cache
+ ]
+ assert [getattr(c, "offset", -1) for c in fused_mtp_cache] == [
+ getattr(c, "offset", -1) for c in separate_mtp_cache
+ ]
+
+ def next_logits_and_draft(
+ prompt_cache: list[Any], mtp_cache: list[Any]
+ ) -> tuple[mx.array, mx.array]:
+ prompt_tail = mx.array([prompt[-2:]], dtype=mx.int32)
+ _tail_logits, _tail_hidden = cast(Any, model)(
+ prompt_tail[:, :-1], cache=prompt_cache, return_hidden=True
+ )
+ first_logits, first_hidden = cast(Any, model)(
+ prompt_tail[:, -1:], cache=prompt_cache, return_hidden=True
+ )
+ current_token_arr = mx.argmax(first_logits, axis=-1).astype(mx.int32)
+ mtp_logits, _mtp_hidden = cast(Any, model).mtp_forward(
+ first_hidden[:, -1:, :],
+ current_token_arr,
+ mtp_cache=mtp_cache,
+ return_hidden=True,
+ )
+ mx.eval(first_logits, mtp_logits)
+ return first_logits, mtp_logits
+
+ separate_first, separate_mtp = next_logits_and_draft(
+ separate_prompt_cache, separate_mtp_cache
+ )
+ fused_first, fused_mtp = next_logits_and_draft(
+ fused_prompt_cache, fused_mtp_cache
+ )
+
+ assert float(mx.max(mx.abs(separate_first - fused_first)).item()) == 0.0
+ assert float(mx.max(mx.abs(separate_mtp - fused_mtp)).item()) == 0.0
+
+
+# --------------------------------------------------------------------------- #
+# End-to-end stream smoke tests
+# --------------------------------------------------------------------------- #
+
+
+def _drive_stream(
+ *,
+ drafter: NativeMTPDrafter,
+ model: ExoModel,
+ tokenizer: _FakeTokenizer,
+ prompt_full: list[int],
+ max_tokens: int,
+) -> list[Any]:
+ """Run the K-step drafter end-to-end and collect every yielded response.
+
+ Mirrors what ``mlx_generate`` does at the dispatch site: prefills
+ the prompt cache aligned to ``full_prompt[:-2]``, then hands the
+ last two tokens to ``drafter.stream``.
+ """
+ cache: list[Any] = list(cast(Any, model).make_cache())
+ # exo.prefill equivalent: feed prompt[:-2] (so cache holds N-2 positions).
+ if len(prompt_full) >= 3:
+ prefill = mx.array(prompt_full[:-2], dtype=mx.int32)[None]
+ _ = model(prefill, cache=cache)
+ mx.eval([c.state for c in cache if hasattr(c, "state")])
+ prompt_tail = mx.array(prompt_full[-2:], dtype=mx.int32)
+ responses: list[Any] = []
+ for response in drafter.stream(
+ model=model,
+ tokenizer=cast(Any, tokenizer),
+ prompt=prompt_tail,
+ context_tokens=prompt_full,
+ prompt_cache=cast(Any, cache),
+ max_tokens=max_tokens,
+ sampler=_identity_sampler,
+ logits_processors=_empty_processors(),
+ ):
+ responses.append(response)
+ if response.finish_reason is not None:
+ break
+ return responses
+
+
+@pytest.mark.parametrize("k", [1, 2, 3])
+def test_stream_emits_tokens_and_populates_metrics(k: int) -> None:
+ """K=1/2/3 streams emit tokens, terminate, and stamp metrics."""
+ model = _tiny_model(with_mtp=True)
+ tokenizer = _FakeTokenizer(eos_token_ids=[])
+ drafter = NativeMTPDrafter(k=k)
+ prompt = [1, 2, 3, 4, 5, 6, 7, 8]
+ max_tokens = 8
+
+ responses = _drive_stream(
+ drafter=drafter,
+ model=model,
+ tokenizer=tokenizer,
+ prompt_full=prompt,
+ max_tokens=max_tokens,
+ )
+
+ # At least one response is yielded (the first emitted token from the
+ # initial forward) plus a closing chunk.
+ assert len(responses) >= 2
+ final = responses[-1]
+ assert final.finish_reason in {"stop", "length"}
+ assert final.generation_tokens <= max_tokens
+ assert final.generation_tokens >= 1
+
+ metrics = drafter.metrics()
+ # spec_decode_rounds is 0 when the first emitted token alone covered
+ # max_tokens; for max_tokens=8 with random init the loop should run
+ # at least one round.
+ assert metrics["spec_decode_rounds"] >= 0
+ assert metrics["proposed_draft_tokens"] == k * metrics["spec_decode_rounds"]
+ assert (
+ metrics["accepted_draft_tokens"] >= 0
+ and metrics["accepted_draft_tokens"] <= metrics["proposed_draft_tokens"]
+ )
+
+
+def test_stream_raises_on_non_dispatchable_model() -> None:
+ """A bare MagicMock target fails the dispatch-side gate inside ``stream``."""
+ drafter = NativeMTPDrafter(k=1)
+ fake_model = MagicMock()
+ fake_model.make_mtp_cache.return_value = [MagicMock()]
+ tokenizer = _FakeTokenizer()
+ with pytest.raises(RuntimeError, match="is_native_mtp_dispatchable"):
+ # ``next`` to actually enter the generator body; the dispatchable
+ # check happens before any yield.
+ next(
+ drafter.stream(
+ model=cast(Any, fake_model),
+ tokenizer=cast(Any, tokenizer),
+ prompt=mx.array([1, 2], dtype=mx.int32),
+ context_tokens=[1, 2],
+ prompt_cache=cast(Any, []),
+ max_tokens=4,
+ sampler=_identity_sampler,
+ logits_processors=_empty_processors(),
+ )
+ )
+
+
+def test_stream_stops_at_immediate_eos() -> None:
+ """When the first emitted token IS an EOS, the stream terminates immediately."""
+ model = _tiny_model(with_mtp=True)
+ # Pin every vocab id as EOS so the first emitted token (whatever it
+ # is) terminates the stream. This is the cheapest way to exercise
+ # the immediate-EOS branch without needing to control the model's
+ # output distribution.
+ tokenizer = _FakeTokenizer(eos_token_ids=set(range(256)))
+ drafter = NativeMTPDrafter(k=1)
+ responses = _drive_stream(
+ drafter=drafter,
+ model=model,
+ tokenizer=tokenizer,
+ prompt_full=[1, 2, 3, 4],
+ max_tokens=8,
+ )
+ # The immediate-EOS branch yields a SINGLE response with finish_reason="stop".
+ assert len(responses) == 1
+ assert responses[0].finish_reason == "stop"
+ assert responses[0].generation_tokens == 1
diff --git a/src/exo/worker/tests/unittests/test_mlx/test_spec_cache.py b/src/exo/worker/tests/unittests/test_mlx/test_spec_cache.py
new file mode 100644
index 0000000000..88227234ed
--- /dev/null
+++ b/src/exo/worker/tests/unittests/test_mlx/test_spec_cache.py
@@ -0,0 +1,116 @@
+"""Unit tests for the speculative-decode snapshot/rollback helpers.
+
+These exercise :mod:`exo.worker.engines.mlx.spec_cache` directly with fake
+cache entries that mimic the mlx-lm cache API (``is_trimmable`` / ``trim`` /
+``state`` / ``meta_state``), so no real model or GatedDeltaNet is needed.
+"""
+
+# The snapshot stores cloned cache state as ``tuple[Any, ...]`` by design,
+# so reads off it are Any -- mirror the pragma used by the mlx modules.
+# pyright: reportAny=false
+
+from __future__ import annotations
+
+from typing import Any
+
+import mlx.core as mx
+
+from exo.worker.engines.mlx.spec_cache import (
+ CacheSnapshot,
+ rollback_after_verify,
+ snapshot_untrimmable_cache_lazy,
+)
+
+
+class _TrimmableEntry:
+ """Stands in for an attention KV cache: trimmable, no recurrent state."""
+
+ def __init__(self) -> None:
+ self.trim_calls: list[int] = []
+
+ def is_trimmable(self) -> bool:
+ return True
+
+ def trim(self, n: int) -> None:
+ self.trim_calls.append(n)
+
+
+class _NonTrimmableEntry:
+ """Stands in for a GatedDeltaNet ``ArraysCache``: not trimmable, carries
+ recurrent state in a list that the cache reads from in place."""
+
+ def __init__(self, state: list[Any], meta_state: Any) -> None:
+ self.state = state
+ self.meta_state = meta_state
+
+ def is_trimmable(self) -> bool:
+ return False
+
+
+def _arrays_equal(a: mx.array, b: mx.array) -> bool:
+ return bool(mx.array_equal(a, b))
+
+
+def test_snapshot_skips_trimmable_and_clones_non_trimmable() -> None:
+ trimmable = _TrimmableEntry()
+ non_trimmable = _NonTrimmableEntry(
+ state=[mx.array([1, 2, 3])], meta_state=mx.array([9])
+ )
+ snap = snapshot_untrimmable_cache_lazy([trimmable, non_trimmable])
+
+ assert isinstance(snap, CacheSnapshot)
+ # Trimmable slot carries None (offset trim is sufficient rollback).
+ assert snap.states[0] is None
+ assert snap.meta_states[0] is None
+ # Non-trimmable slot carries a clone of the recurrent state.
+ assert snap.states[1] is not None
+ assert _arrays_equal(snap.states[1][0], mx.array([1, 2, 3]))
+
+
+def test_snapshot_clone_is_isolated_from_later_mutation() -> None:
+ """A snapshot must not alias the live cache: replacing the cache's state
+ leaves (as the verify forward does) must not change the snapshot."""
+ entry = _NonTrimmableEntry(state=[mx.array([1, 2, 3])], meta_state=None)
+ snap = snapshot_untrimmable_cache_lazy([entry])
+
+ # Simulate the verify forward replacing the recurrent-state leaf.
+ entry.state = [mx.array([7, 8, 9])]
+
+ assert _arrays_equal(snap.states[0][0], mx.array([1, 2, 3]))
+
+
+def test_rollback_trims_trimmable_and_restores_non_trimmable() -> None:
+ trimmable = _TrimmableEntry()
+ original_state = [mx.array([1, 2, 3])]
+ entry = _NonTrimmableEntry(state=original_state, meta_state=mx.array([5]))
+ cache: list[Any] = [trimmable, entry]
+
+ snap = snapshot_untrimmable_cache_lazy(cache)
+
+ # Verify forward advances state past the speculative trajectory.
+ entry.state[:] = [mx.array([7, 8, 9])]
+ entry.meta_state = mx.array([6])
+
+ rollback_after_verify(cache, snap, verified_tokens=2)
+
+ # Trimmable entry trimmed by the requested count.
+ assert trimmable.trim_calls == [2]
+ # Non-trimmable recurrent state + meta restored from the snapshot.
+ assert _arrays_equal(entry.state[0], mx.array([1, 2, 3]))
+ assert _arrays_equal(entry.meta_state, mx.array([5]))
+ # Container identity preserved (cache reads from the same list object).
+ assert entry.state is original_state
+
+
+def test_rollback_with_zero_verified_tokens_skips_trim_but_restores() -> None:
+ trimmable = _TrimmableEntry()
+ entry = _NonTrimmableEntry(state=[mx.array([1, 2, 3])], meta_state=None)
+ cache: list[Any] = [trimmable, entry]
+
+ snap = snapshot_untrimmable_cache_lazy(cache)
+ entry.state[:] = [mx.array([7, 8, 9])]
+
+ rollback_after_verify(cache, snap, verified_tokens=0)
+
+ assert trimmable.trim_calls == []
+ assert _arrays_equal(entry.state[0], mx.array([1, 2, 3]))
diff --git a/src/exo/worker/tests/unittests/test_mlx/test_utils_mlx_native_mtp.py b/src/exo/worker/tests/unittests/test_mlx/test_utils_mlx_native_mtp.py
new file mode 100644
index 0000000000..c94e3f2da2
--- /dev/null
+++ b/src/exo/worker/tests/unittests/test_mlx/test_utils_mlx_native_mtp.py
@@ -0,0 +1,302 @@
+"""Tests for the native-MTP loader branch in ``utils_mlx.load_mlx_items``.
+
+The native-MTP path was added so cards declaring
+:class:`exo.shared.models.model_cards.NativeMTPConfig` get loaded
+through :func:`exo.worker.engines.mlx.vendor.qwen3_5_mtp_loader.load_mtp_model`
+instead of the stock ``mlx_lm.utils.load_model``.
+
+The tests use ``monkeypatch`` to stub the heavy MLX call sites so we can
+drive the loader without a real checkpoint or filesystem. Coverage:
+
+- Card declares ``native_mtp`` AND placement is single-node AND probe
+ says "recoverable": ``load_mtp_model`` is called, stock ``load_model``
+ is NOT.
+- Card has no ``native_mtp``: stock ``load_model`` is called.
+- Card declares ``native_mtp`` but the probe returns "stripped" (e.g.
+ mlx-community-style stripped quants): we degrade to the stock loader.
+
+We do NOT exercise the real Qwen3.5/3.6 MTP loader here -- the vendor
+module has its own parity tests. The goal is to verify routing.
+"""
+
+# pyright: reportPrivateUsage=false
+
+from __future__ import annotations
+
+from collections.abc import Generator
+from pathlib import Path
+from typing import cast
+
+import pytest
+
+from exo.shared.models.model_cards import (
+ ModelCard,
+ ModelId,
+ ModelTask,
+ NativeMTPConfig,
+)
+from exo.shared.types.backends import Backend
+from exo.shared.types.common import NodeId
+from exo.shared.types.memory import Memory
+from exo.shared.types.worker.instances import (
+ BoundInstance,
+ InstanceId,
+ MlxRingInstance,
+)
+from exo.shared.types.worker.runners import (
+ RunnerId,
+ ShardAssignments,
+)
+from exo.shared.types.worker.shards import (
+ PipelineShardMetadata,
+ ShardMetadata,
+)
+from exo.worker.engines.mlx import mtp_probe, utils_mlx
+
+
+def _target_card(*, native_mtp: NativeMTPConfig | None = None) -> ModelCard:
+ return ModelCard(
+ model_id=ModelId("Youssofal/Qwen3.6-27B-MTPLX-Optimized-Quality"),
+ storage_size=Memory.from_gb(1.0),
+ n_layers=64,
+ hidden_size=5120,
+ supports_tensor=True,
+ tasks=[ModelTask.TextGeneration],
+ backends=[Backend.MlxMetal],
+ native_mtp=native_mtp,
+ )
+
+
+def _make_single_target_bound_instance(card: ModelCard) -> BoundInstance:
+ target_node = NodeId()
+ target_runner_id = RunnerId()
+ shard = PipelineShardMetadata(
+ model_card=card,
+ device_rank=0,
+ world_size=1,
+ start_layer=0,
+ end_layer=64,
+ n_layers=64,
+ )
+ instance = MlxRingInstance(
+ instance_id=InstanceId(),
+ shard_assignments=ShardAssignments(
+ model_id=card.model_id,
+ runner_to_shard={target_runner_id: cast(ShardMetadata, shard)},
+ node_to_runner={target_node: target_runner_id},
+ ),
+ hosts_by_node={target_node: []},
+ ephemeral_port=60000,
+ )
+ return BoundInstance(
+ instance=instance,
+ bound_runner_id=target_runner_id,
+ bound_node_id=target_node,
+ )
+
+
+_LoadResult = tuple[object, object, object]
+
+
+def _consume_generator(
+ gen: Generator[object, None, _LoadResult],
+) -> _LoadResult:
+ while True:
+ try:
+ next(gen)
+ except StopIteration as stop:
+ return cast(_LoadResult, stop.value)
+
+
+def _recoverable_probe() -> mtp_probe.MtpProbeResult:
+ return mtp_probe.MtpProbeResult(
+ model_declares_mtp=True,
+ mtp_tensors_found=True,
+ mtp_format=mtp_probe.MtpFormat.MTPLX_SEPARATE_FILE,
+ mtp_count=29,
+ mtp_path="/tmp/fake/mtp.safetensors",
+ mtp_tensor_keys=(),
+ )
+
+
+def _stripped_probe() -> mtp_probe.MtpProbeResult:
+ return mtp_probe.MtpProbeResult(
+ model_declares_mtp=True,
+ mtp_tensors_found=False,
+ mtp_format=mtp_probe.MtpFormat.STRIPPED,
+ mtp_count=0,
+ mtp_path=None,
+ mtp_tensor_keys=(),
+ )
+
+
+def _patch_loader_routing(
+ monkeypatch: pytest.MonkeyPatch, *, probe_recoverable: bool
+) -> dict[str, bool]:
+ """Stub the heavy load call sites and return a flag dict recording
+ which loader fired."""
+ calls = {"stock_load_model": False, "native_load_mtp_model": False}
+
+ fake_model: object = object()
+ fake_inner: object = object()
+ fake_tokenizer: object = object()
+
+ def fake_stock_load(
+ _path: object, **_kwargs: object
+ ) -> tuple[object, dict[str, object]]:
+ calls["stock_load_model"] = True
+ return fake_model, {}
+
+ monkeypatch.setattr(utils_mlx, "load_model", fake_stock_load)
+
+ # The native loader is imported lazily inside ``load_mlx_items``; patch
+ # it on the source module so the lazy import sees our stub.
+ from exo.worker.engines.mlx.vendor import qwen3_5_mtp_loader as _mtp_loader_mod
+
+ def fake_native_load(
+ _path: object, **_kwargs: object
+ ) -> tuple[object, dict[str, object]]:
+ calls["native_load_mtp_model"] = True
+ return fake_model, {}
+
+ monkeypatch.setattr(_mtp_loader_mod, "load_mtp_model", fake_native_load)
+
+ def fake_probe(_path: object) -> mtp_probe.MtpProbeResult:
+ return _recoverable_probe() if probe_recoverable else _stripped_probe()
+
+ monkeypatch.setattr(mtp_probe, "probe_mtp_weights", fake_probe)
+
+ def fake_inner_model(_model: object) -> object:
+ return fake_inner
+
+ def fake_layers(_inner: object) -> list[object]:
+ return []
+
+ def fake_get_tokenizer(_path: object, _shard: object) -> object:
+ return fake_tokenizer
+
+ def fake_set_wired_limit(_size: object) -> None:
+ return None
+
+ def fake_build_model_path(_model_id: object) -> Path:
+ return Path("/tmp/fake-model-path")
+
+ monkeypatch.setattr(utils_mlx, "get_inner_model", fake_inner_model)
+ monkeypatch.setattr(utils_mlx, "get_layers", fake_layers)
+ monkeypatch.setattr(utils_mlx, "get_tokenizer", fake_get_tokenizer)
+ monkeypatch.setattr(utils_mlx, "set_wired_limit_for_model", fake_set_wired_limit)
+ monkeypatch.setattr(utils_mlx, "build_model_path", fake_build_model_path)
+
+ import mlx.core as mx_core
+
+ def fake_eval(*_args: object, **_kwargs: object) -> None:
+ return None
+
+ def fake_clear_cache() -> None:
+ return None
+
+ monkeypatch.setattr(mx_core, "eval", fake_eval)
+ monkeypatch.setattr(mx_core, "clear_cache", fake_clear_cache)
+
+ return calls
+
+
+class TestNativeMTPLoaderRouting:
+ """``load_mlx_items`` dispatches the native MTP loader iff the card
+ declares ``native_mtp``, the probe says recoverable, and the
+ placement is single-node.
+ """
+
+ def test_native_path_when_card_declares_and_single_node(
+ self, monkeypatch: pytest.MonkeyPatch
+ ) -> None:
+ calls = _patch_loader_routing(monkeypatch, probe_recoverable=True)
+ card = _target_card(native_mtp=NativeMTPConfig(num_layers=1))
+ bound = _make_single_target_bound_instance(card)
+ _consume_generator(
+ cast(
+ Generator[object, None, _LoadResult],
+ utils_mlx.load_mlx_items(bound, group=None),
+ )
+ )
+ assert calls["native_load_mtp_model"] is True, (
+ "card declares native_mtp and probe says recoverable: the "
+ "loader MUST dispatch through load_mtp_model"
+ )
+ assert calls["stock_load_model"] is False, (
+ "native loader handles weight loading end-to-end; stock "
+ "load_model must not be called when the native path fires"
+ )
+
+ def test_stock_path_when_card_lacks_native_mtp(
+ self, monkeypatch: pytest.MonkeyPatch
+ ) -> None:
+ calls = _patch_loader_routing(monkeypatch, probe_recoverable=True)
+ card = _target_card(native_mtp=None)
+ bound = _make_single_target_bound_instance(card)
+ _consume_generator(
+ cast(
+ Generator[object, None, _LoadResult],
+ utils_mlx.load_mlx_items(bound, group=None),
+ )
+ )
+ assert calls["stock_load_model"] is True
+ assert calls["native_load_mtp_model"] is False, (
+ "the native MTP loader must never fire for cards that don't opt in"
+ )
+
+ def test_stock_fallback_when_probe_says_stripped(
+ self, monkeypatch: pytest.MonkeyPatch
+ ) -> None:
+ """Card declares native_mtp but probe returns STRIPPED → stock path
+ (silent degrade so the request still completes)."""
+ calls = _patch_loader_routing(monkeypatch, probe_recoverable=False)
+ card = _target_card(native_mtp=NativeMTPConfig(num_layers=1))
+ bound = _make_single_target_bound_instance(card)
+ _consume_generator(
+ cast(
+ Generator[object, None, _LoadResult],
+ utils_mlx.load_mlx_items(bound, group=None),
+ )
+ )
+ assert calls["stock_load_model"] is True, (
+ "probe says STRIPPED → fall back to stock loader so the "
+ "request still completes (silent degrade, not a hard fail)"
+ )
+ assert calls["native_load_mtp_model"] is False
+
+
+class TestNativeMtpLoaderEligible:
+ """``_native_mtp_loader_eligible`` is the loader-side gate. It is only
+ called from the single-node (``group is None``) branch."""
+
+ def test_true_for_declared_card_with_recoverable_probe(
+ self, monkeypatch: pytest.MonkeyPatch
+ ) -> None:
+ def fake_probe(_path: object) -> mtp_probe.MtpProbeResult:
+ return _recoverable_probe()
+
+ monkeypatch.setattr(mtp_probe, "probe_mtp_weights", fake_probe)
+ card = _target_card(native_mtp=NativeMTPConfig(num_layers=1))
+ assert utils_mlx._native_mtp_loader_eligible(card, Path("/tmp/fake")) is True
+
+ def test_false_without_native_mtp_skips_probe(self) -> None:
+ """Cards that don't declare native_mtp short-circuit before the
+ probe (which would touch disk)."""
+ card = _target_card(native_mtp=None)
+ assert (
+ utils_mlx._native_mtp_loader_eligible(
+ card, Path("/nonexistent/never-probed")
+ )
+ is False
+ )
+
+ def test_false_when_probe_says_stripped(
+ self, monkeypatch: pytest.MonkeyPatch
+ ) -> None:
+ def fake_probe(_path: object) -> mtp_probe.MtpProbeResult:
+ return _stripped_probe()
+
+ monkeypatch.setattr(mtp_probe, "probe_mtp_weights", fake_probe)
+ card = _target_card(native_mtp=NativeMTPConfig(num_layers=1))
+ assert utils_mlx._native_mtp_loader_eligible(card, Path("/tmp/fake")) is False