diff --git a/app/EXO/EXO/ExoProcessController.swift b/app/EXO/EXO/ExoProcessController.swift index ff756379b1..c2672e4c8c 100644 --- a/app/EXO/EXO/ExoProcessController.swift +++ b/app/EXO/EXO/ExoProcessController.swift @@ -8,6 +8,7 @@ private let hfEndpointKey = "EXOHFEndpoint" private let enableImageModelsKey = "EXOEnableImageModels" private let offlineModeKey = "EXOOfflineMode" private let fastSynchEnabledKey = "EXOFastSynchEnabled" +private let nativeMTPEnabledKey = "EXONativeMTPEnabled" private let onboardingCompletedKey = "EXOOnboardingCompleted" private let defaultModelsDirKey = "EXODefaultModelsDir" private let additionalModelsDirsKey = "EXOAdditionalModelsDirs" @@ -109,6 +110,17 @@ final class ExoProcessController: ObservableObject { UserDefaults.standard.set(fastSynchEnabled, forKey: fastSynchEnabledKey) } } + @Published var nativeMTPEnabled: Bool = { + if UserDefaults.standard.object(forKey: nativeMTPEnabledKey) == nil { + return true + } + return UserDefaults.standard.bool(forKey: nativeMTPEnabledKey) + }() + { + didSet { + UserDefaults.standard.set(nativeMTPEnabled, forKey: nativeMTPEnabledKey) + } + } @Published var defaultModelsDir: String = { return UserDefaults.standard.string(forKey: defaultModelsDirKey) ?? "" }() @@ -366,6 +378,7 @@ final class ExoProcessController: ObservableObject { environment["EXO_OFFLINE"] = "true" } environment["EXO_FAST_SYNCH"] = fastSynchEnabled ? "true" : "false" + environment["EXO_NATIVE_MTP_ENABLED"] = nativeMTPEnabled ? 
"1" : "0" var paths: [String] = [] if let existing = environment["PATH"], !existing.isEmpty { diff --git a/app/EXO/EXO/Views/SettingsView.swift b/app/EXO/EXO/Views/SettingsView.swift index bb1f1bb798..0786cf8052 100644 --- a/app/EXO/EXO/Views/SettingsView.swift +++ b/app/EXO/EXO/Views/SettingsView.swift @@ -16,6 +16,7 @@ struct SettingsView: View { @State private var pendingEnableImageModels = false @State private var pendingOfflineMode = false @State private var pendingFastSynchEnabled = false + @State private var pendingNativeMTPEnabled = true @State private var pendingDefaultModelsDir: String = "" @State private var pendingAdditionalModelsDirs: String = "" @State private var pendingReadOnlyModelsDirs: String = "" @@ -54,6 +55,7 @@ struct SettingsView: View { pendingEnableImageModels = controller.enableImageModels pendingOfflineMode = controller.offlineMode pendingFastSynchEnabled = controller.fastSynchEnabled + pendingNativeMTPEnabled = controller.nativeMTPEnabled pendingDefaultModelsDir = controller.defaultModelsDir pendingAdditionalModelsDirs = controller.additionalModelsDirs pendingReadOnlyModelsDirs = controller.readOnlyModelsDirs @@ -131,6 +133,15 @@ struct SettingsView: View { .foregroundColor(.secondary) } + Section { + Toggle("Use Native MTP", isOn: $pendingNativeMTPEnabled) + Text( + "Enable native multi-token-prediction speculative decoding for supported model checkpoints. On by default." 
+ ) + .font(.caption) + .foregroundColor(.secondary) + } + Section { HStack { Spacer() @@ -607,6 +618,7 @@ struct SettingsView: View { private var hasModelChanges: Bool { pendingEnableImageModels != controller.enableImageModels + || pendingNativeMTPEnabled != controller.nativeMTPEnabled } private var hasAdvancedChanges: Bool { @@ -630,6 +642,7 @@ struct SettingsView: View { private func applyModelSettings() { controller.enableImageModels = pendingEnableImageModels + controller.nativeMTPEnabled = pendingNativeMTPEnabled restartIfRunning() } diff --git a/dashboard/src/lib/components/ModelPickerGroup.svelte b/dashboard/src/lib/components/ModelPickerGroup.svelte index c09600b076..87917aaf52 100644 --- a/dashboard/src/lib/components/ModelPickerGroup.svelte +++ b/dashboard/src/lib/components/ModelPickerGroup.svelte @@ -9,6 +9,7 @@ capabilities?: string[]; family?: string; is_custom?: boolean; + native_mtp?: { default_k: number; max_k: number } | null; } interface ModelGroup { @@ -127,6 +128,11 @@ } } + // Whether any variant in this group ships native MTP (informational badge). 
+ const hasNativeMtp = $derived( + group.variants.some((v) => v.native_mtp != null), + ); + // Check if this group's model is currently selected (for single-variant groups) const isMainSelected = $derived( !group.hasMultipleVariants && @@ -309,6 +315,15 @@ {/if} {/each} + + {#if hasNativeMtp} + + MTP + + {/if} @@ -523,6 +538,16 @@ {variant.quantization || "default"} + + {#if variant.native_mtp} + + MTP + + {/if} + ModelLis capabilities=card.capabilities, reasoning_dialect=card.reasoning_dialect, context_length=card.context_length, + native_mtp=( + NativeMTPModelInfo( + default_k=card.native_mtp.default_k, + max_k=card.native_mtp.max_k, + ) + if card.native_mtp is not None + else None + ), ) for card in cards ] diff --git a/src/exo/api/tests/test_generation_stats.py b/src/exo/api/tests/test_generation_stats.py new file mode 100644 index 0000000000..9cc467f1d5 --- /dev/null +++ b/src/exo/api/tests/test_generation_stats.py @@ -0,0 +1,78 @@ +"""Regression tests for speculative-decode generation telemetry.""" + +from __future__ import annotations + +from exo.api.types.api import GenerationStats +from exo.shared.types.memory import Memory + + +def _stats( + *, + generation_tokens: int = 100, + accepted: int = 25, + proposed: int = 0, + drafter_model_id: str | None = None, + draft_mode: str | None = None, + drafter_kind: str | None = None, + num_draft_tokens: int | None = None, +) -> GenerationStats: + return GenerationStats( + prompt_tps=0.0, + prompt_tokens=10, + generation_tps=10.0, + generation_tokens=generation_tokens, + peak_memory_usage=Memory.from_bytes(0), + accepted_draft_tokens=accepted, + proposed_draft_tokens=proposed, + drafter_model_id=drafter_model_id, + num_draft_tokens=num_draft_tokens, + draft_mode=draft_mode, # pyright: ignore[reportArgumentType] + drafter_kind=drafter_kind, # pyright: ignore[reportArgumentType] + ) + + +def test_acceptance_fraction_is_none_for_explicit_none_mode() -> None: + stats = _stats(draft_mode="none", accepted=0) + assert 
stats.drafter_acceptance_fraction is None + + +def test_acceptance_fraction_reports_for_model_runs() -> None: + stats = _stats( + draft_mode="model", + drafter_model_id="some-org/drafter-7b", + accepted=40, + ) + assert stats.drafter_acceptance_fraction == 0.40 + + +def test_acceptance_metrics_report_for_native_mtp_runs_without_drafter_id() -> None: + stats = _stats( + draft_mode="model", + drafter_model_id=None, + drafter_kind="native_mtp", + num_draft_tokens=2, + accepted=25, + proposed=40, + ) + assert stats.drafter_acceptance_fraction == 0.25 + assert stats.drafter_acceptance_rate == 0.625 + + +def test_acceptance_fraction_legacy_payload_without_draft_mode() -> None: + legacy_with_drafter = _stats( + draft_mode=None, + drafter_model_id="legacy-org/drafter", + accepted=10, + ) + legacy_without_drafter = _stats( + draft_mode=None, + drafter_model_id=None, + accepted=0, + ) + assert legacy_with_drafter.drafter_acceptance_fraction is not None + assert legacy_without_drafter.drafter_acceptance_fraction is None + + +def test_acceptance_fraction_zero_generation_tokens_returns_none() -> None: + stats = _stats(generation_tokens=0, accepted=0, draft_mode="model") + assert stats.drafter_acceptance_fraction is None diff --git a/src/exo/api/tests/test_model_list_native_mtp.py b/src/exo/api/tests/test_model_list_native_mtp.py new file mode 100644 index 0000000000..9d11d2c5af --- /dev/null +++ b/src/exo/api/tests/test_model_list_native_mtp.py @@ -0,0 +1,63 @@ +import pytest + +from exo.api.main import API +from exo.shared.models import model_cards +from exo.shared.models.model_cards import ( + ModelCard, + ModelTask, + NativeMTPConfig, +) +from exo.shared.types.backends import Backend +from exo.shared.types.common import ModelId +from exo.shared.types.memory import Memory +from exo.shared.types.state import State + + +def _native_mtp_card(model_id: str, *, default_k: int, max_k: int) -> ModelCard: + return ModelCard( + model_id=ModelId(model_id), + 
storage_size=Memory.from_mb(1), + n_layers=1, + hidden_size=1, + supports_tensor=True, + tasks=[ModelTask.TextGeneration], + backends=[Backend.MlxMetal], + native_mtp=NativeMTPConfig( + num_layers=1, + default_k=default_k, + max_k=max_k, + ), + ) + + +@pytest.mark.asyncio +async def test_models_response_includes_native_mtp_for_native_mtp_cards( + monkeypatch: pytest.MonkeyPatch, +) -> None: + cards = [ + _native_mtp_card("Jundot/Qwen3.6-27B-oQ8-mtp", default_k=2, max_k=3), + _native_mtp_card( + "alvarolizama/Qwen3.6-35B-A3B-oQ8-mtp", + default_k=1, + max_k=3, + ), + ] + + async def _fake_list_all() -> list[ModelCard]: + return cards + + api = API.__new__(API) + api.state = State() + monkeypatch.setattr(model_cards.card_cache, "list_all", _fake_list_all) + + response = await api.get_models() + + by_id = {item.id: item for item in response.data} + qwen27_native_mtp = by_id["Jundot/Qwen3.6-27B-oQ8-mtp"].native_mtp + qwen35_native_mtp = by_id["alvarolizama/Qwen3.6-35B-A3B-oQ8-mtp"].native_mtp + assert qwen27_native_mtp is not None + assert qwen27_native_mtp.default_k == 2 + assert qwen27_native_mtp.max_k == 3 + assert qwen35_native_mtp is not None + assert qwen35_native_mtp.default_k == 1 + assert qwen35_native_mtp.max_k == 3 diff --git a/src/exo/api/types/__init__.py b/src/exo/api/types/__init__.py index 9cb2f834fa..7aea349ea3 100644 --- a/src/exo/api/types/__init__.py +++ b/src/exo/api/types/__init__.py @@ -40,6 +40,7 @@ from .api import LogprobsContentItem as LogprobsContentItem from .api import ModelList as ModelList from .api import ModelListModel as ModelListModel +from .api import NativeMTPModelInfo as NativeMTPModelInfo from .api import NodePowerStats as NodePowerStats from .api import PlaceInstanceParams as PlaceInstanceParams from .api import PlacementPreview as PlacementPreview diff --git a/src/exo/api/types/api.py b/src/exo/api/types/api.py index 8cfa10dd1a..c994d3dca0 100644 --- a/src/exo/api/types/api.py +++ b/src/exo/api/types/api.py @@ -29,6 +29,16 @@ 
class ErrorResponse(BaseModel): error: ErrorInfo +class NativeMTPModelInfo(BaseModel): + """Native Multi-Token Prediction metadata surfaced for cards whose + on-disk checkpoint ships MTP weights handled by exo's in-process + draft+verify loop. ``default_k`` is the K used when a request doesn't + override; ``max_k`` is the card-declared upper bound.""" + + default_k: int + max_k: int + + class ModelListModel(BaseModel): id: str object: str = "model" @@ -49,6 +59,9 @@ class ModelListModel(BaseModel): base_model: str = Field(default="") capabilities: list[str] = Field(default_factory=list) reasoning_dialect: ReasoningDialect = "none" + # Native MTP metadata for checkpoints whose card declares exo's + # in-process MTP draft+verify path. ``None`` for all other models. + native_mtp: NativeMTPModelInfo | None = None class ModelList(BaseModel): @@ -167,6 +180,39 @@ class GenerationStats(BaseModel): generation_tokens: int peak_memory_usage: Memory prefix_cache_hit: Literal["none", "partial", "exact"] = "none" + drafter_model_id: str | None = None + accepted_draft_tokens: int = 0 + proposed_draft_tokens: int = 0 + spec_decode_rounds: int = 0 + num_draft_tokens: int | None = None + draft_mode: ( + Literal["model", "pipelined", "ngram", "eagle", "lookahead", "none"] | None + ) = None + drafter_kind: Literal["standard", "mtp", "dflash", "native_mtp"] | None = None + + @property + def drafter_acceptance_fraction(self) -> float | None: + """Fraction of generated tokens that came from accepted drafts.""" + if self.generation_tokens == 0: + return None + if self.draft_mode is None: + if self.drafter_model_id is None and self.drafter_kind is None: + return None + elif self.draft_mode == "none": + return None + return self.accepted_draft_tokens / self.generation_tokens + + @property + def drafter_acceptance_rate(self) -> float | None: + """Classical speculative acceptance rate: accepted / proposed.""" + if self.proposed_draft_tokens == 0: + return None + if self.draft_mode is None: + 
if self.drafter_model_id is None and self.drafter_kind is None: + return None + elif self.draft_mode == "none": + return None + return self.accepted_draft_tokens / self.proposed_draft_tokens class ImageGenerationStats(BaseModel): diff --git a/src/exo/shared/models/model_cards.py b/src/exo/shared/models/model_cards.py index 8911b9323c..a9c020cb0e 100644 --- a/src/exo/shared/models/model_cards.py +++ b/src/exo/shared/models/model_cards.py @@ -154,6 +154,47 @@ class SamplingDefaults(SamplingValues): non_thinking: SamplingValues | None = None +class NativeMTPConfig(FrozenModel): + """Declares that a target checkpoint ships native Multi-Token Prediction + weights handled by exo's vendored MTP-aware loader. + + Set this field on cards whose on-disk checkpoint exposes recoverable + MTP tensors -- typically MTPLX-format Qwen3.5/3.6 artefacts that ship a + separate ``mtp.safetensors`` sidecar (``mlx_lm_extra_tensors.mtp_file`` + in ``config.json``), or the original HuggingFace layout where the MTP + tensors are embedded in the main shards under the ``mtp.*`` prefix. + + The loader at :mod:`exo.worker.engines.mlx.vendor.qwen3_5_mtp_loader` + consumes this declaration through ``utils_mlx.load_mlx_items``: when + the card has ``native_mtp`` set, the placement is single-node, and + :func:`exo.worker.engines.mlx.mtp_probe.probe_mtp_weights` confirms + the tensors are present on disk, the loader dispatches via + :func:`load_mtp_model` instead of stock ``mlx_lm.utils.load_model``. + + Native MTP is structurally single-node only: the verify forward through + a TP-sharded target would amortise K+1 tokens over K+1 compute units, + eating the MTP speedup. The guard + :func:`exo.worker.engines.mlx.utils_mlx.is_native_mtp_runnable` + enforces this -- the loader silently falls back to the stock load path + on multi-rank instances. + """ + + # Number of MTP transformer layers in this checkpoint. Matches + # ``text_config.mtp_num_hidden_layers`` in ``config.json``. 
+ # Qwen3.6-27B MTPLX ships with 1. + num_layers: PositiveInt + # Default K for draft+verify when a request doesn't override. + default_k: PositiveInt = 3 + # Maximum K the runner is allowed to use. Above this, verify cost + # dominates. + max_k: PositiveInt = 3 + # Hugging Face filename of the separate MTP safetensors when shipped + # in MTPLX format. ``None`` when MTP tensors are embedded in main + # shards (original HuggingFace layout) -- the loader probes both + # paths regardless of this hint. + mtp_file: str | None = None + + class ModelCard(FrozenModel): model_id: ModelId storage_size: Memory @@ -175,6 +216,13 @@ class ModelCard(FrozenModel): is_custom: bool = False vision: VisionCardConfig | None = None sampling_defaults: SamplingDefaults = Field(default_factory=SamplingDefaults) + # Optional declaration that the target checkpoint ships native + # Multi-Token Prediction (MTP) weights handled directly by exo's + # vendored Qwen3.5/3.6 MTP-aware loader. See :class:`NativeMTPConfig` + # for the field-by-field semantics and the loader / placement gates + # that consume it. Single-node only; ``None`` (the default) preserves + # legacy behaviour and is purely additive. 
+ native_mtp: NativeMTPConfig | None = None @model_validator(mode="after") def _autodetect_vision(self) -> "ModelCard": diff --git a/src/exo/worker/engines/mlx/builder.py b/src/exo/worker/engines/mlx/builder.py index af7c75bb9d..0284fc63e8 100644 --- a/src/exo/worker/engines/mlx/builder.py +++ b/src/exo/worker/engines/mlx/builder.py @@ -21,6 +21,8 @@ from exo.worker.runner.llm_inference.tool_parsers import make_mlx_parser from .cache import KVPrefixCache +from .generator.native_mtp_drafter import is_native_mtp_dispatchable +from .native_mtp_config import native_mtp_enabled_from_env from .types import Model from .utils_mlx import ( initialize_mlx, @@ -38,6 +40,11 @@ class MlxBuilder(Builder): tokenizer: TokenizerWrapper | None = None group: mx.distributed.Group | None = None vision_processor: VisionProcessor | None = None + # Native-MTP K bounds captured from the model card at load time, used + # by ``build`` to configure the generator. ``None`` unless the card + # declares ``native_mtp``. 
+ native_mtp_default_k: int | None = None + native_mtp_max_k: int | None = None def connect(self, bound_instance: BoundInstance) -> None: self.group = initialize_mlx(bound_instance) @@ -48,6 +55,10 @@ def load(self, bound_instance: BoundInstance) -> Generator[ModelLoadingResponse] self.tokenizer, self.vision_processor, ) = yield from load_mlx_items(bound_instance, self.group) + native_mtp = bound_instance.bound_shard.model_card.native_mtp + if native_mtp is not None: + self.native_mtp_default_k = native_mtp.default_k + self.native_mtp_max_k = native_mtp.max_k def close(self) -> None: with contextlib.suppress(NameError, AttributeError): @@ -83,8 +94,19 @@ def build( kv_prefix_cache = KVPrefixCache(self.group) device_rank = 0 if self.group is None else self.group.rank() - if os.environ.get("EXO_NO_BATCH"): - logger.info("using SequentialGenerator (batching disabled)") + # Native MTP runs the single-request draft+verify loop inside + # ``mlx_generate`` (SequentialGenerator), so force sequential mode + # when the target loaded as a vendored MTP-aware model and native + # MTP is enabled. ``is_native_mtp_dispatchable`` is an isinstance + # check, so it is only ever True for single-node MTP checkpoints. 
+ native_mtp_dispatchable = ( + self.group is None + and native_mtp_enabled_from_env() + and is_native_mtp_dispatchable(self.inference_model) + ) + if native_mtp_dispatchable or os.environ.get("EXO_NO_BATCH"): + reason = "native MTP" if native_mtp_dispatchable else "batching disabled" + logger.info(f"using SequentialGenerator ({reason})") return SequentialGenerator( model=self.inference_model, tokenizer=self.tokenizer, @@ -96,6 +118,12 @@ def build( cancel_receiver=self.cancel_receiver, event_sender=self.event_sender, vision_processor=vision_processor, + native_mtp_default_k=( + self.native_mtp_default_k if native_mtp_dispatchable else None + ), + native_mtp_max_k=( + self.native_mtp_max_k if native_mtp_dispatchable else None + ), ) else: logger.info("using BatchGenerator") diff --git a/src/exo/worker/engines/mlx/generator/generate.py b/src/exo/worker/engines/mlx/generator/generate.py index 2e3d051251..ff54a1c7fb 100644 --- a/src/exo/worker/engines/mlx/generator/generate.py +++ b/src/exo/worker/engines/mlx/generator/generate.py @@ -6,6 +6,9 @@ from typing import Callable, Generator, cast, get_args import mlx.core as mx +from mlx_lm.generate import ( + GenerationResponse as MlxGenerationResponse, +) from mlx_lm.generate import ( maybe_quantize_kv_cache, stream_generate, @@ -55,7 +58,15 @@ KV_GROUP_SIZE, MAX_TOKENS, ) +from exo.worker.engines.mlx.generator.native_mtp_drafter import ( + NativeMTPDrafter, + is_native_mtp_dispatchable, +) from exo.worker.engines.mlx.generator.remote_prefill import remote_prefill +from exo.worker.engines.mlx.native_mtp_config import ( + native_mtp_enabled_from_env, + resolve_native_mtp_num_draft_tokens, +) from exo.worker.engines.mlx.types import KVCacheType, Model from exo.worker.engines.mlx.utils_mlx import ( apply_chat_template, @@ -540,6 +551,8 @@ def mlx_generate( distributed_prompt_progress_callback: Callable[[], None] | None = None, on_generation_token: Callable[[], None] | None = None, vision_processor: VisionProcessor | None = 
None, + native_mtp_default_k: int | None = None, + native_mtp_max_k: int | None = None, ) -> Generator[GenerationResponse]: # Ensure that generation stats only contains peak memory for this generation mx.reset_peak_memory() @@ -717,8 +730,59 @@ def mlx_generate( logger.info("Starting decode") mx_barrier(group) - for completion_tokens, out in enumerate( - stream_generate( + # Native MTP dispatch: when the target loaded as a vendored MTP-aware + # model (single-node, env-enabled), draft+verify through its own MTP + # head. ``NativeMTPDrafter.stream`` is a drop-in for ``stream_generate`` + # -- it yields the same ``GenerationResponse`` objects, primes its own + # MTP cache from the full prompt, and consumes the KV cache that + # ``prefill`` already aligned to ``prompt_tokens[:-2]`` (the 2-token + # ``last_token`` tail is the decode seed for both paths). + native_mtp_active = ( + group is None + and native_mtp_enabled_from_env() + and is_native_mtp_dispatchable(model) + ) + native_mtp_drafter: NativeMTPDrafter | None = None + native_k: int | None = None + if native_mtp_active: + native_k, native_k_clamped = resolve_native_mtp_num_draft_tokens( + request_num_draft_tokens=None, + configured_num_draft_tokens=None, + card_default_k=native_mtp_default_k + if native_mtp_default_k is not None + else 1, + card_max_k=native_mtp_max_k if native_mtp_max_k is not None else 1, + ) + if native_k_clamped: + logger.info(f"native MTP K clamped to {native_k}") + full_context_tokens = [ + int(t) for t in cast(list[int], all_prompt_tokens.tolist()) + ] + logger.info(f"native MTP dispatch: K={native_k}, single-node") + native_mtp_drafter = NativeMTPDrafter(k=native_k) + token_stream: Generator[MlxGenerationResponse] = native_mtp_drafter.stream( + model=model, + tokenizer=tokenizer, + prompt=last_token, + context_tokens=full_context_tokens, + prompt_cache=caches, + max_tokens=max_tokens, + sampler=sampler, + logits_processors=logits_processors, + prefill_step_size=1, + # Lossless 
speculative sampling: thread the raw sampling params so + # the drafter can reconstruct the request's distribution (the + # ``sampler`` callable alone is not enough -- it only yields a + # token, not the per-step distribution the accept test needs). + # Defaults mirror the ``make_sampler`` call above so MTP-on and + # MTP-off sample from the same distribution. + temperature=task.temperature if task.temperature is not None else 0.7, + top_p=task.top_p if task.top_p is not None else 1.0, + top_k=task.top_k if task.top_k is not None else 0, + min_p=task.min_p if task.min_p is not None else 0.05, + ) + else: + token_stream = stream_generate( model=model, tokenizer=tokenizer, prompt=last_token, @@ -729,9 +793,9 @@ def mlx_generate( prefill_step_size=1, kv_group_size=KV_GROUP_SIZE, kv_bits=KV_BITS, - ), - start=1, - ): + ) + + for completion_tokens, out in enumerate(token_stream, start=1): generated_text_parts.append(out.text) accumulated_text += out.text @@ -758,12 +822,27 @@ def mlx_generate( stats: GenerationStats | None = None if is_done: + native_mtp_metrics = ( + native_mtp_drafter.metrics() + if native_mtp_drafter is not None and native_k is not None + else {} + ) stats = GenerationStats( prompt_tps=float(prefill_tps or out.prompt_tps), generation_tps=float(out.generation_tps), prompt_tokens=int(prefill_tokens + out.prompt_tokens), generation_tokens=int(out.generation_tokens), peak_memory_usage=Memory.from_gb(out.peak_memory), + accepted_draft_tokens=native_mtp_metrics.get( + "accepted_draft_tokens", 0 + ), + proposed_draft_tokens=native_mtp_metrics.get( + "proposed_draft_tokens", 0 + ), + spec_decode_rounds=native_mtp_metrics.get("spec_decode_rounds", 0), + num_draft_tokens=native_k if native_mtp_drafter is not None else None, + draft_mode="model" if native_mtp_drafter is not None else "none", + drafter_kind="native_mtp" if native_mtp_drafter is not None else None, ) if not stop_matched and out.finish_reason not in get_args(FinishReason): logger.warning( diff 
--git a/src/exo/worker/engines/mlx/generator/native_mtp_drafter.py b/src/exo/worker/engines/mlx/generator/native_mtp_drafter.py new file mode 100644 index 0000000000..1b784ca029 --- /dev/null +++ b/src/exo/worker/engines/mlx/generator/native_mtp_drafter.py @@ -0,0 +1,1890 @@ +"""Native Multi-Token Prediction drafter for Qwen3.5/3.6. + +Native MTP drafts come from the target model's own MTP head: a single +transformer layer that consumes the target's post-norm hidden state + +the previously-sampled token and predicts the next token via the shared +``lm_head``. No sibling drafter LM is loaded and no separate model KV +cache exists -- the MTP layer has its own KV cache that's populated from +prompt prefill and extended on every accepted draft. + +**Cache rollback (Qwen3.6 GatedDeltaNet)** + +``mlx_trim_prompt_cache`` decrements KV offsets but does NOT roll back +the GatedDeltaNet linear-attention recurrent state on Qwen3.5/3.6 +targets, so a bare trim after the batched verify leaves the GDN cache +on the SPECULATIVE trajectory (the rejected drafts have been folded +into the recurrent state and can't be undone via offset +arithmetic). The ``trim_after_batched_probe`` measurement showed +54.2% logit mismatch (Logit L2 distances 300-1000) without the fix. + +The loop snapshots GDN state before the batched verify; on partial accept it +trims the main cache fully, restores GDN from the snapshot, and re-forwards +``[current_token, *drafts[:n_accepted]]`` so the cache is advanced through +retained tokens only. On full accept the cache is already on-trajectory, so +rollback is skipped. + +The MTP cache is full-attention (``vendor/qwen3_5_mtp.py`` +``fa_layer_idx``), so its trim path uses plain +:func:`mlx_trim_prompt_cache` -- no rollback needed. + +The loop: + +1. 
**Prime** the MTP KV cache by walking prompt positions through the + main model on a fresh cache to capture post-norm hiddens, then + feeding those into ``model.mtp_update_cache`` so the MTP attention + sees prompt context before the first draft fires. +2. **Draft** K tokens by chaining ``model.mtp_forward`` -- each step + consumes the previous step's post-norm hidden so the MTP layer + builds a self-referential chain. K is bounded by + ``card.native_mtp.max_k`` upstream (default 3). +3. **Verify** all K drafts in a single ``(K+1)``-token target forward. + Walk the target's preferred tokens; accept while they match drafts, + emit one bonus from the target's choice at the first mismatch (or + after a full-accept). Trim both caches by ``K - num_accepted``. + +Native MTP is single-node-only by design (the loader gate +:func:`exo.worker.engines.mlx.utils_mlx._native_mtp_loader_eligible` +only fires on single-node placements): a TP-sharded verify forward +amortises ``K+1`` tokens over ``K+1`` compute units, eating the MTP +speedup. +""" + +# pyright: reportUnknownMemberType=false, reportUnknownArgumentType=false, reportUnknownVariableType=false, reportAttributeAccessIssue=false, reportCallIssue=false, reportArgumentType=false, reportAny=false +# This module composes over untyped mlx-lm tokenizer / vendor MTP +# attributes; the type-checker can't see through nn.Module subclassing +# in the way mlx-lm uses it. Mirrors the same pragma block as +# ``vendor/qwen3_5_mtp.py``. 
+ +from __future__ import annotations + +import os +import time +from collections.abc import Callable, Generator, Iterable, Sequence +from typing import Any, Literal, cast, final + +import mlx.core as mx +import mlx.nn as nn +from mlx.nn.layers.distributed import sum_gradients +from mlx_lm.generate import GenerationResponse +from mlx_lm.models.base import create_attention_mask, create_ssm_mask +from mlx_lm.models.cache import trim_prompt_cache as mlx_trim_prompt_cache +from mlx_lm.sample_utils import apply_min_p, apply_top_k, apply_top_p +from mlx_lm.tokenizer_utils import TokenizerWrapper + +from exo.worker.engines.mlx.spec_cache import ( + restore_cache, + rollback_after_verify, + snapshot_untrimmable_cache_lazy, +) +from exo.worker.engines.mlx.types import KVCacheType, Model +from exo.worker.runner.bootstrap import logger + +_MoeVerifierPolicy = Literal[ + "safe", + "hybrid_full_accept", + "batched", + "row_moe", + "route_locked", +] +_DENSE_GDN_STATE_HISTORY_MODEL_TYPES = frozenset({"qwen3_5", "qwen3_5_text"}) +_MOE_GDN_STATE_HISTORY_MODEL_TYPES = frozenset({"qwen3_5_moe", "qwen3_5_moe_text"}) + + +def is_native_mtp_dispatchable(model: object) -> bool: + """Return ``True`` when ``model`` is a vendored native-MTP Model instance. + + The loader (:mod:`exo.worker.engines.mlx.vendor.qwen3_5_mtp_loader`) + swaps in :class:`vendor.qwen3_5_mtp.Model` (or its + ``_PatchedMtpModel`` subclass) when the card declares ``native_mtp`` + and the on-disk weights are recoverable. We ``isinstance``-check + against that concrete class rather than probe for the + ``mtp_forward`` / ``make_mtp_cache`` / ``mtp_update_cache`` methods + structurally because ``unittest.mock.MagicMock()`` auto-creates any + attribute on access -- a structural check would falsely engage the + NativeMTPDrafter for any mocked target model and break unrelated + routing tests. The import is local to keep this module's import + surface small for workers that never load Qwen3.5/3.6. 
+ + This is the dispatch-side gate. The builder-side gate + (:func:`exo.worker.engines.mlx.utils_mlx.is_native_mtp_runnable`) + consults the card + placement before the model is loaded. + """ + try: + from exo.worker.engines.mlx.vendor.qwen3_5_mtp import Model as _VendorMtpModel + except ImportError: + return False + return isinstance(model, _VendorMtpModel) + + +def _eos_ids_from_tokenizer(tokenizer: TokenizerWrapper) -> list[int]: + """Tokenizer-agnostic EOS lookup mirroring drafter.py / coupled_drafter.py.""" + eos = getattr(tokenizer, "eos_token_ids", None) + if eos is None: + eos = getattr(tokenizer, "eos_token_id", None) + if eos is None: + return [] + if isinstance(eos, int): + return [eos] + if isinstance(eos, Iterable) and not isinstance(eos, (str, bytes)): + return [int(x) for x in eos] + try: + return [int(eos)] + except (TypeError, ValueError): + return [] + + +def _token_array_to_int(token: mx.array) -> int: + """Read a single-token ``mx.array`` as a Python ``int``.""" + return int(token.reshape(()).item()) + + +@final +class _SpecSampler: + """Reconstructs the request's sampling distribution for lossless MTP. + + The drafter only receives a ``sampler`` (``logits -> token``), which is + not enough for probability-ratio (Leviathan/Chen) acceptance: that needs + the *distribution* the request samples from. We rebuild it by replicating + mlx-lm's ``make_sampler`` chain (``top_p`` / ``min_p`` / ``top_k`` masking, + then ``/temperature`` and a softmax) so the draft distribution ``q`` and + the target distribution ``p`` are computed with identical transforms. + + Mirrors ``mtp_generate_step`` in mlx-lm PR #990 (the PR this drafter + derives from): ``q``/``p`` are the temperature-adjusted, filtered + distributions; draft tokens are sampled from ``q``; a draft token ``x`` is + accepted with probability ``min(1, p(x)/q(x))``; on rejection the + replacement is drawn from the normalized residual ``(p - q)+``; after a + full accept the bonus is drawn from ``p``. 
With ``temperature == 0`` this + reduces to "accept iff ``x == argmax(p)``", recovering exact greedy. + + Note: ``logits_processors`` (repetition/presence/frequency penalties, + logit bias) are intentionally NOT applied here. They are stateful over the + emitted-token history and cannot be replayed consistently across the K + chained MTP draft steps and the batched K+1 verify, so applying them + inconsistently would BREAK the lossless guarantee rather than help it. + Penalties stay unhandled on the native-MTP path (as before); only the + request's temperature/top_p/top_k/min_p now govern the distribution. + """ + + def __init__( + self, + *, + temperature: float, + top_p: float, + top_k: int, + min_p: float, + min_tokens_to_keep: int, + ) -> None: + self._temperature: float = temperature + self._is_greedy: bool = temperature <= 0.0 + self._top_p: float = top_p + self._top_k: int = top_k + self._min_p: float = min_p + self._min_tokens_to_keep: int = max(1, min_tokens_to_keep) + + @property + def is_greedy(self) -> bool: + return self._is_greedy + + def _filter(self, logprobs: mx.array) -> mx.array: + """Apply the request's masking filters (top_p / min_p / top_k).""" + masked = logprobs + if 0.0 < self._top_p < 1.0: + masked = apply_top_p(masked, self._top_p) + if self._min_p != 0.0: + masked = apply_min_p(masked, self._min_p, self._min_tokens_to_keep) + if self._top_k > 0: + masked = apply_top_k(masked, self._top_k) + return masked + + def distribution(self, last_logits: mx.array) -> mx.array: + """Return the request's sampling distribution ``p`` for ``(1, V)`` logits. + + ``last_logits`` is the final-position logit row, shape ``(1, V)``. The + returned array is normalized probabilities over the vocab with masked + entries at exactly ``0`` (temperature-adjusted, filtered). 
+ """ + logprobs = last_logits - mx.logsumexp(last_logits, axis=-1, keepdims=True) + if self._is_greedy: + # Degenerate distribution at the argmax; the ratio test below + # reduces to "x == argmax(p)" and the residual to argmax(p). + argmax = mx.argmax(logprobs, axis=-1) + return mx.zeros_like(logprobs).at[0, argmax].add(1.0) + masked = self._filter(logprobs) + scaled = masked / self._temperature + accept_logprobs = scaled - mx.logsumexp(scaled, axis=-1, keepdims=True) + return mx.exp(accept_logprobs) + + def sample_from_distribution(self, distribution: mx.array) -> mx.array: + """Sample one token id ``(1, 1)`` int32 from a probability row ``(1, V)``.""" + if self._is_greedy: + return mx.argmax(distribution, axis=-1).reshape(1, 1).astype(mx.int32) + # ``mx.log(0) -> -inf`` which categorical treats as probability 0. + return ( + mx.random.categorical(mx.log(distribution)).reshape(1, 1).astype(mx.int32) + ) + + def residual_token(self, target_dist: mx.array, draft_dist: mx.array) -> mx.array: + """Sample the replacement token from the residual ``(p - q)+`` (or ``p``).""" + if self._is_greedy: + return mx.argmax(target_dist, axis=-1).reshape(1, 1).astype(mx.int32) + residual = mx.maximum(target_dist - draft_dist, 0.0) + normalizer = residual.sum(axis=-1, keepdims=True) + # When the residual mass is ~0 (draft already covered all target mass), + # fall back to sampling from the target distribution directly. + distribution = mx.where(normalizer > 0, residual, target_dist) + return self.sample_from_distribution(distribution) + + +def _accept_draft_token( + *, + is_greedy: bool, + draft_token: int, + target_argmax: int, + target_prob: float, + draft_prob: float, + uniform: float, +) -> bool: + """Probability-ratio acceptance for one drafted token. + + Greedy: accept iff the drafted token equals the target's argmax. + Stochastic: accept with probability ``min(1, p(x)/q(x))`` against a + pre-drawn ``uniform`` in ``[0, 1)`` (Leviathan/Chen). 
``draft_prob`` is + the draft head's probability ``q(x)`` for the drafted token; a draft that + the target assigns zero mass (``p(x) == 0``, e.g. masked out by top_k) is + always rejected. + """ + if is_greedy: + return draft_token == target_argmax + if draft_prob <= 0.0: + # The draft token was not in the draft's own support; treat as a hard + # reject so the residual replacement (which equals p here) is emitted. + return False + return uniform < (target_prob / draft_prob) + + +def _spec_accept_walk( + *, + spec_sampler: _SpecSampler, + verify_logits: mx.array, + draft_tokens: Sequence[int], + draft_distributions: Sequence[mx.array], + accept_uniforms: Sequence[float], + k: int, +) -> tuple[int, mx.array]: + """Probability-ratio accept walk over a batched ``(1, K+1, V)`` verify. + + Returns ``(num_accepted, replacement_token_arr)`` where + ``replacement_token_arr`` is the ``(1, 1)`` int32 token emitted after the + accepted prefix -- the residual sample at the first rejected position, or + the target bonus sample at position ``num_accepted == K`` (full accept). + The caller's cache machinery already handles trimming/rollback to land the + cache on ``[current_token, *drafts[:num_accepted]]``. + """ + # Per-position target distributions p_0..p_K. + target_dists = [ + spec_sampler.distribution(verify_logits[:, i, :]) for i in range(k + 1) + ] + # Consolidated readback: single sync for all K+1 argmax predictions. 
+ target_argmax_arr = mx.argmax(verify_logits, axis=-1).astype(mx.int32) + mx.eval(target_argmax_arr) + target_argmaxes: list[int] = [ + int(t) for t in list(target_argmax_arr.squeeze(0).tolist()) + ] + num_accepted = 0 + for i in range(k): + draft_token = draft_tokens[i] + target_prob = float(target_dists[i][0, draft_token].item()) + draft_prob = float(draft_distributions[i][0, draft_token].item()) + if _accept_draft_token( + is_greedy=spec_sampler.is_greedy, + draft_token=draft_token, + target_argmax=target_argmaxes[i], + target_prob=target_prob, + draft_prob=draft_prob, + uniform=accept_uniforms[i], + ): + num_accepted += 1 + continue + # Rejected at position i: replacement from residual (p_i - q_i)+. + replacement = spec_sampler.residual_token( + target_dists[i], draft_distributions[i] + ) + return num_accepted, replacement + # Full accept: bonus sampled from the target distribution at position K. + bonus = spec_sampler.sample_from_distribution(target_dists[k]) + return num_accepted, bonus + + +def _native_mtp_model_type(model: Model) -> str: + """Best-effort model_type lookup for vendored Qwen3.5/3.6 models.""" + outer_args = getattr(model, "args", None) + outer_model_type = getattr(outer_args, "model_type", None) + if isinstance(outer_model_type, str): + return outer_model_type + language_model = getattr(model, "language_model", None) + args = getattr(language_model, "args", None) + model_type = getattr(args, "model_type", None) + return model_type if isinstance(model_type, str) else "" + + +def _requires_sequential_gdn_path(model: Model) -> bool: + """Return True for 35B-A3B/MoE native MTP, where batched GDN state drifts. + + The dense 27B path has benchmarked well with the batched verifier. The + qwen3_5_moe 35B path does not: batched prefill changes the recurrent + GatedDeltaNet state enough that MTP drafts the wrong first token, and a + batched verifier can produce wrong later-position bonus tokens. 
Keep this + slow path scoped to MoE until we have a fused GDN scan that is numerically + equivalent to token-by-token recurrent updates. + """ + return _native_mtp_model_type(model) in {"qwen3_5_moe", "qwen3_5_moe_text"} + + +def _gdn_state_history_commit_default( + *, + model_type: str, + moe_verifier_policy: _MoeVerifierPolicy, +) -> bool: + """Return whether GDN prefix-state commit is default-on for this model.""" + if model_type in _DENSE_GDN_STATE_HISTORY_MODEL_TYPES: + return True + return ( + model_type in _MOE_GDN_STATE_HISTORY_MODEL_TYPES + and moe_verifier_policy == "route_locked" + ) + + +def _gdn_state_history_kernel_supported(model: Model) -> bool: + """Return True when the loaded GDN layers fit the state-history kernel.""" + language_model = getattr(model, "language_model", None) + inner = getattr(language_model, "model", None) + layers = getattr(inner, "layers", None) + if not isinstance(layers, list): + return False + for layer in layers: + if not bool(getattr(layer, "is_linear", False)): + continue + gdn = getattr(layer, "linear_attn", None) + head_k_dim = getattr(gdn, "head_k_dim", None) + head_v_dim = getattr(gdn, "head_v_dim", None) + if not isinstance(head_k_dim, int) or not isinstance(head_v_dim, int): + return False + return head_k_dim >= 32 and head_k_dim % 32 == 0 and head_v_dim > 0 + return False + + +def _moe_verifier_policy() -> _MoeVerifierPolicy: + """Return the qwen3_5_moe verifier policy for native-MTP experiments. + + ``route_locked`` is the optimized production default for Qwen3.6 MoE. + ``safe`` is the coherent reference path: token-by-token verification. + ``hybrid_full_accept`` first tries a batched verifier and commits it only + when every draft token matches; partial/miss rounds fall back to ``safe``. + ``batched`` is the research fast path for falsifying whether the batched + verifier can be made coherent under the current cache repair rules. 
+ ``row_moe`` keeps attention/GatedDeltaNet batched but evaluates sparse + MoE blocks one token row at a time, matching the token-by-token verifier + while avoiding full sequential layer execution. ``route_locked`` row-steps + only the router gate, then runs the fixed-index expert and shared-expert + work batched. + """ + raw = os.environ.get("EXO_NATIVE_MTP_MOE_VERIFY", "route_locked") + normalized = raw.strip().lower().replace("-", "_") + if normalized in { + "safe", + "hybrid_full_accept", + "batched", + "row_moe", + "route_locked", + }: + return cast(_MoeVerifierPolicy, normalized) + logger.warning( + f"[native MTP] unknown EXO_NATIVE_MTP_MOE_VERIFY={raw!r}; using 'safe'" + ) + return "safe" + + +def _env_flag(name: str, *, default: bool) -> bool: + raw = os.environ.get(name) + if raw is None: + return default + return raw.strip().lower() in {"1", "true", "yes", "on"} + + +def _env_int(name: str, *, default: int, minimum: int) -> int: + raw = os.environ.get(name) + if raw is None: + return default + try: + value = int(raw) + except ValueError: + logger.warning(f"[native MTP] invalid {name}={raw!r}; using {default}") + return default + return max(minimum, value) + + +def _env_float(name: str, *, default: float) -> float: + raw = os.environ.get(name) + if raw is None: + return default + try: + return float(raw) + except ValueError: + logger.warning(f"[native MTP] invalid {name}={raw!r}; using {default}") + return default + + +_ASYNC_DRAFT_EVAL_ENABLED = _env_flag("EXO_NATIVE_MTP_ASYNC_DRAFT_EVAL", default=True) + + +def _cache_state_arrays(cache: Sequence[Any]) -> list[mx.array]: + """Return materialisable cache state arrays without forcing evaluation.""" + state_arrays: list[mx.array] = [] + for entry in cache: + state_obj: object = getattr(entry, "state", None) + if isinstance(state_obj, mx.array): + state_arrays.append(state_obj) + elif isinstance(state_obj, (list, tuple)): + for inner in cast(Sequence[object], state_obj): + if isinstance(inner, mx.array): + 
state_arrays.append(inner) + return state_arrays + + +def _materialize_cache_state(cache: Sequence[Any]) -> None: + """Force ``mx.eval`` on every materialisable state in ``cache``. + + KVCache exposes ``.state`` as ``mx.array``; some custom caches use a + tuple/list of arrays. We walk both shapes so the cache writes + complete before the next forward consumes them. + """ + state_arrays = _cache_state_arrays(cache) + if state_arrays: + mx.eval(state_arrays) + + +def _async_materialize_cache_state( + cache: Sequence[Any], + *arrays: mx.array, +) -> None: + """Schedule draft/cache evaluation without waiting immediately. + + MLX evaluation is lazy. K-chain draft work can overlap with Python-side + verifier graph construction as long as we schedule the draft tokens and + MTP cache writes before building the target verify forward. The first + required readback still synchronizes, but it has less work left to wait on. + """ + state_arrays = _cache_state_arrays(cache) + eval_arrays = [*arrays, *state_arrays] + if eval_arrays: + mx.async_eval(eval_arrays) + + +def _is_sparse_moe_block(mlp: object) -> bool: + """Return True for Qwen3.5/Next sparse MoE blocks.""" + return ( + callable(getattr(mlp, "gate", None)) + and callable(getattr(mlp, "switch_mlp", None)) + and isinstance(getattr(mlp, "top_k", None), int) + ) + + +def _row_stepped_sparse_moe(mlp: object, mlp_input: mx.array) -> mx.array: + """Evaluate a Qwen sparse MoE block one verify token at a time.""" + if not callable(mlp): + raise TypeError(f"mlp object is not callable: {type(mlp).__name__}") + return mx.concatenate( + [mlp(mlp_input[:, i : i + 1, :]) for i in range(int(mlp_input.shape[1]))], + axis=1, + ) + + +def _route_locked_sparse_moe(mlp: object, mlp_input: mx.array) -> mx.array: + """Evaluate sparse MoE with exact row-stepped routing and batched experts.""" + gate = getattr(mlp, "gate", None) + switch_mlp = getattr(mlp, "switch_mlp", None) + shared_expert = getattr(mlp, "shared_expert", None) + 
shared_expert_gate = getattr(mlp, "shared_expert_gate", None) + top_k = getattr(mlp, "top_k", None) + if ( + not callable(gate) + or not callable(switch_mlp) + or not callable(shared_expert) + or not callable(shared_expert_gate) + or not isinstance(top_k, int) + or top_k <= 0 + ): + raise TypeError( + f"mlp object is not a supported sparse MoE: {type(mlp).__name__}" + ) + + sharding_group = getattr(mlp, "sharding_group", None) + x = ( + sum_gradients(sharding_group)(mlp_input) + if sharding_group is not None + else mlp_input + ) + + # Qwen3.6-35B's BF16 router gate takes a different MLX kernel at M>1 than + # at M=1. Route rows one-by-one to match token verification exactly, then + # keep the expensive expert matmuls batched with fixed indices. + gates = mx.concatenate( + [gate(x[:, i : i + 1, :]) for i in range(int(x.shape[1]))], + axis=1, + ) + gates = mx.softmax(gates, axis=-1, precise=True) + indices = mx.argpartition(gates, kth=-top_k, axis=-1)[..., -top_k:] + scores = mx.take_along_axis(gates, indices, axis=-1) + if bool(getattr(mlp, "norm_topk_prob", False)): + scores = scores / scores.sum(axis=-1, keepdims=True) + + expert_output = cast(mx.array, switch_mlp(x, indices)) + output = (expert_output * scores[..., None]).sum(axis=-2) + shared_gate_output = cast(mx.array, shared_expert_gate(x)) + shared_expert_output = cast(mx.array, shared_expert(x)) + shared_output = mx.sigmoid(shared_gate_output) * shared_expert_output + output = output + shared_output + + if sharding_group is not None: + output = mx.distributed.all_sum(output, group=sharding_group) + return output + + +def _target_forward_row_moe( + *, + model: Model, + inputs: mx.array, + cache: KVCacheType, +) -> tuple[mx.array, mx.array]: + """Target forward with batched non-MoE ops and row-stepped sparse MoE. 
+ + ``scripts/native_mtp_verify_parity_probe.py`` shows qwen3_5_moe's first + batched verifier mismatch occurs inside sparse MoE gate/evaluation, while + GatedDeltaNet internals and full-attention cache updates stay exact for the + K+1 verify window. This forward preserves those batched exact operations + and only splits each sparse MoE call across token rows. + """ + text_model = model.language_model + inner = text_model.model + hidden_states = inner.embed_tokens(inputs) + cache_list = cast("list[Any]", cache) + fa_mask = create_attention_mask(hidden_states, cache_list[inner.fa_idx]) + ssm_mask = create_ssm_mask(hidden_states, cache_list[inner.ssm_idx]) + + for layer, layer_cache in zip(inner.layers, cache_list, strict=True): + selected_mask = ssm_mask if bool(layer.is_linear) else fa_mask + normalized = layer.input_layernorm(hidden_states) + if bool(layer.is_linear): + residual = layer.linear_attn( + normalized, + mask=selected_mask, + cache=layer_cache, + ) + else: + residual = layer.self_attn( + normalized, + mask=selected_mask, + cache=layer_cache, + ) + mid = hidden_states + residual + mlp_input = layer.post_attention_layernorm(mid) + mlp_output = ( + _row_stepped_sparse_moe(layer.mlp, mlp_input) + if _is_sparse_moe_block(layer.mlp) + else layer.mlp(mlp_input) + ) + hidden_states = mid + mlp_output + + post = inner.norm(hidden_states) + if text_model.args.tie_word_embeddings: + logits = inner.embed_tokens.as_linear(post) + else: + logits = text_model.lm_head(post) + return logits, post + + +def _target_forward_route_locked_moe( + *, + model: Model, + inputs: mx.array, + cache: KVCacheType, +) -> tuple[mx.array, mx.array]: + """Target forward with row-locked MoE routing and batched expert work.""" + text_model = model.language_model + inner = text_model.model + hidden_states = inner.embed_tokens(inputs) + cache_list = cast("list[Any]", cache) + fa_mask = create_attention_mask(hidden_states, cache_list[inner.fa_idx]) + ssm_mask = create_ssm_mask(hidden_states, 
cache_list[inner.ssm_idx]) + + for layer, layer_cache in zip(inner.layers, cache_list, strict=True): + selected_mask = ssm_mask if bool(layer.is_linear) else fa_mask + normalized = layer.input_layernorm(hidden_states) + if bool(layer.is_linear): + residual = layer.linear_attn( + normalized, + mask=selected_mask, + cache=layer_cache, + ) + else: + residual = layer.self_attn( + normalized, + mask=selected_mask, + cache=layer_cache, + ) + mid = hidden_states + residual + mlp_input = layer.post_attention_layernorm(mid) + mlp_output = ( + _route_locked_sparse_moe(layer.mlp, mlp_input) + if _is_sparse_moe_block(layer.mlp) + else layer.mlp(mlp_input) + ) + hidden_states = mid + mlp_output + + post = inner.norm(hidden_states) + if text_model.args.tie_word_embeddings: + logits = inner.embed_tokens.as_linear(post) + else: + logits = text_model.lm_head(post) + return logits, post + + +GdnPrefixHistory = dict[int, list[Any]] + + +def _lm_head_from_hidden(model: Model, hidden: mx.array) -> mx.array: + text_model = model.language_model + inner = text_model.model + if text_model.args.tie_word_embeddings: + return inner.embed_tokens.as_linear(hidden) + return text_model.lm_head(hidden) + + +def _target_post_norm_hidden( + *, + model: Model, + inputs: mx.array, + cache: KVCacheType, +) -> mx.array: + """Run the target body and return post-norm hidden without ``lm_head``. + + Prompt-cache repair and MTP-cache priming need exact post-norm target + hidden states, not vocabulary logits. The vendored Qwen inner model exposes + that post-norm body output directly, so these internal paths can avoid a + vocabulary projection whose output would be discarded. 
+ """ + text_model = getattr(model, "language_model", None) + inner = getattr(text_model, "model", None) + if callable(inner): + result = inner( + inputs, + cache=cast("list[Any]", cache), + return_hidden=True, + ) + if isinstance(result, tuple): + post_hidden = result[0] + if isinstance(post_hidden, mx.array): + return post_hidden + if isinstance(result, mx.array): + return result + _logits, hidden = cast( + tuple[mx.array, mx.array], + model(inputs, cache=cache, return_hidden=True), + ) + return hidden + + +def _entry_is_trimmable(entry: Any) -> bool: + method = getattr(entry, "is_trimmable", None) + if not callable(method): + return False + try: + return bool(method()) + except Exception: + return False + + +def _trim_trimmable_cache_entries(cache: Sequence[Any], n_tokens: int) -> None: + if n_tokens <= 0: + return + for entry in cache: + trim = getattr(entry, "trim", None) + if _entry_is_trimmable(entry) and callable(trim): + trim(n_tokens) + + +def _set_cache_entry_state_lazy(entry: Any, state: Any) -> None: + replace_state = getattr(entry, "replace_state", None) + if callable(replace_state): + replace_state(state) + return + current = getattr(entry, "state", None) + if ( + isinstance(current, list) + and isinstance(state, list) + and len(current) == len(state) + ): + current[:] = state + return + entry.state = state + + +def _make_gated_delta_state_history_kernel() -> Any: + source = """ + auto n = thread_position_in_grid.z; + auto b_idx = n / Hv; + auto hv_idx = n % Hv; + auto hk_idx = hv_idx / (Hv / Hk); + constexpr int n_per_t = Dk / 32; + + auto q_ = q + b_idx * T * Hk * Dk + hk_idx * Dk; + auto k_ = k + b_idx * T * Hk * Dk + hk_idx * Dk; + auto v_ = v + b_idx * T * Hv * Dv + hv_idx * Dv; + y += b_idx * T * Hv * Dv + hv_idx * Dv; + + auto dk_idx = thread_position_in_threadgroup.x; + auto dv_idx = thread_position_in_grid.y; + + auto i_state = state_in + (n * Dv + dv_idx) * Dk; + auto o_state = state_out + (n * Dv + dv_idx) * Dk; + + float state[n_per_t]; + 
for (int i = 0; i < n_per_t; ++i) { + auto s_idx = n_per_t * dk_idx + i; + state[i] = static_cast<float>(i_state[s_idx]); + } + + auto a_ = a + b_idx * T * Hv; + auto b_ = b + b_idx * T * Hv; + + for (int t = 0; t < T; ++t) { + float a_val = static_cast<float>(a_[hv_idx]); + float dt_val = static_cast<float>(dt_bias[hv_idx]); + float x_g = a_val + dt_val; + float sp = (x_g > 20.0f) ? x_g : log(1.0f + exp(x_g)); + float g_val = exp(-exp(static_cast<float>(A_log[hv_idx])) * sp); + float beta_val = 1.0f / (1.0f + exp(-static_cast<float>(b_[hv_idx]))); + + float kv_mem = 0.0f; + for (int i = 0; i < n_per_t; ++i) { + auto s_idx = n_per_t * dk_idx + i; + state[i] = state[i] * g_val; + kv_mem += state[i] * k_[s_idx]; + } + kv_mem = simd_sum(kv_mem); + + auto delta = (v_[dv_idx] - kv_mem) * beta_val; + + float out = 0.0f; + for (int i = 0; i < n_per_t; ++i) { + auto s_idx = n_per_t * dk_idx + i; + state[i] = state[i] + k_[s_idx] * delta; + out += state[i] * q_[s_idx]; + } + out = simd_sum(out); + if (thread_index_in_simdgroup == 0) { + y[dv_idx] = static_cast<InT>(out); + } + + auto hist = state_history + + (((b_idx * T + t) * Hv + hv_idx) * Dv + dv_idx) * Dk; + for (int i = 0; i < n_per_t; ++i) { + auto s_idx = n_per_t * dk_idx + i; + hist[s_idx] = static_cast<StT>(state[i]); + } + + q_ += Hk * Dk; + k_ += Hk * Dk; + v_ += Hv * Dv; + y += Hv * Dv; + a_ += Hv; + b_ += Hv; + } + for (int i = 0; i < n_per_t; ++i) { + auto s_idx = n_per_t * dk_idx + i; + o_state[s_idx] = static_cast<StT>(state[i]); + } + """ + return mx.fast.metal_kernel( + name="gated_delta_step_state_history", + input_names=["q", "k", "v", "a", "b", "A_log", "dt_bias", "state_in", "T"], + output_names=["y", "state_out", "state_history"], + source=source, + ) + + +_gated_delta_state_history_kernel: Any | None = None + + +def _gated_delta_update_with_state_history( + *, + q: mx.array, + k: mx.array, + v: mx.array, + a: mx.array, + b: mx.array, + a_log: mx.array, + dt_bias: mx.array, + state: mx.array, +) -> tuple[mx.array, mx.array, mx.array]: + global
_gated_delta_state_history_kernel + kernel = _gated_delta_state_history_kernel + if kernel is None: + kernel = _make_gated_delta_state_history_kernel() + _gated_delta_state_history_kernel = kernel + typed_kernel = cast( + Callable[..., tuple[mx.array, mx.array, mx.array]], + kernel, + ) + batch_size, verify_width, num_key_heads, key_dim = k.shape + num_value_heads, value_dim = v.shape[2:] + return typed_kernel( + inputs=[q, k, v, a, b, a_log, dt_bias, state, verify_width], + template=[ + ("InT", q.dtype), + ("StT", state.dtype), + ("Dk", key_dim), + ("Dv", value_dim), + ("Hk", num_key_heads), + ("Hv", num_value_heads), + ], + grid=(32, value_dim, batch_size * num_value_heads), + threadgroup=(32, 4, 1), + output_shapes=[ + (batch_size, verify_width, num_value_heads, value_dim), + state.shape, + (batch_size, verify_width, num_value_heads, value_dim, key_dim), + ], + output_dtypes=[q.dtype, state.dtype, state.dtype], + ) + + +def _gdn_linear_attn_with_state_history( + gdn: Any, + inputs: mx.array, + *, + cache: Any, +) -> tuple[mx.array, list[Any]]: + """Run one GatedDeltaNet layer and return cache states after each prefix.""" + batch_size, verify_width, _ = inputs.shape + qkv = gdn.in_proj_qkv(inputs) + z = gdn.in_proj_z(inputs).reshape( + batch_size, verify_width, gdn.num_v_heads, gdn.head_v_dim + ) + b = gdn.in_proj_b(inputs) + a = gdn.in_proj_a(inputs) + + if cache is not None and cache[0] is not None: + conv_state = cache[0] + else: + conv_state = mx.zeros( + (batch_size, gdn.conv_kernel_size - 1, gdn.conv_dim), + dtype=inputs.dtype, + ) + + conv_input = mx.concatenate([conv_state, qkv], axis=1) + n_keep = gdn.conv_kernel_size - 1 + if cache is not None: + cache[0] = mx.contiguous(conv_input[:, -n_keep:, :]) + conv_prefix_states = [ + mx.contiguous(conv_input[:, prefix_index + 1 : prefix_index + 1 + n_keep, :]) + for prefix_index in range(verify_width) + ] + + conv_out = nn.silu(gdn.conv1d(conv_input)) + q, k_arr, v = [ + t.reshape(batch_size, verify_width, h, d) 
+ for t, h, d in zip( + mx.split(conv_out, [gdn.key_dim, 2 * gdn.key_dim], -1), + [gdn.num_k_heads, gdn.num_k_heads, gdn.num_v_heads], + [gdn.head_k_dim, gdn.head_k_dim, gdn.head_v_dim], + strict=True, + ) + ] + + state = cache[1] if cache else None + if state is None: + state = mx.zeros( + (batch_size, gdn.num_v_heads, gdn.head_v_dim, gdn.head_k_dim), + dtype=mx.float32, + ) + inv_scale = k_arr.shape[-1] ** -0.5 + q = inv_scale * q * mx.rsqrt((q * q).sum(axis=-1, keepdims=True) + 1e-6) + k_arr = k_arr * mx.rsqrt((k_arr * k_arr).sum(axis=-1, keepdims=True) + 1e-6) + + out, state, state_history = _gated_delta_update_with_state_history( + q=q, + k=k_arr, + v=v, + a=a, + b=b, + a_log=gdn.A_log, + dt_bias=gdn.dt_bias, + state=state, + ) + + if cache is not None: + cache[1] = state + cache.advance(verify_width) + + out = gdn.norm(out, z) + out = gdn.out_proj(out.reshape(batch_size, verify_width, -1)) + prefix_states = [ + [conv_prefix_states[prefix_index], state_history[:, prefix_index, ...]] + for prefix_index in range(verify_width) + ] + return out, prefix_states + + +def _target_forward_with_gdn_state_history( + *, + model: Model, + inputs: mx.array, + cache: KVCacheType, + policy: _MoeVerifierPolicy, +) -> tuple[mx.array, mx.array, GdnPrefixHistory]: + """Target forward with normal-width GDN state-history outputs.""" + text_model = model.language_model + inner = text_model.model + hidden_states = inner.embed_tokens(inputs) + cache_list = cast("list[Any]", cache) + fa_mask = create_attention_mask(hidden_states, cache_list[inner.fa_idx]) + ssm_mask = create_ssm_mask(hidden_states, cache_list[inner.ssm_idx]) + if isinstance(ssm_mask, mx.array): + raise ValueError("state-history GDN verifier does not support SSM masks") + prefix_history: GdnPrefixHistory = {} + + for layer_index, (layer, layer_cache) in enumerate( + zip(inner.layers, cache_list, strict=True) + ): + selected_mask = ssm_mask if bool(layer.is_linear) else fa_mask + normalized = 
layer.input_layernorm(hidden_states) + if bool(layer.is_linear): + if selected_mask is not None: + raise ValueError( + "state-history GDN verifier requires unmasked decode cache" + ) + residual, layer_history = _gdn_linear_attn_with_state_history( + layer.linear_attn, + normalized, + cache=layer_cache, + ) + prefix_history[layer_index] = layer_history + else: + residual = layer.self_attn( + normalized, + mask=selected_mask, + cache=layer_cache, + ) + mid = hidden_states + residual + mlp_input = layer.post_attention_layernorm(mid) + if _is_sparse_moe_block(layer.mlp): + if policy == "route_locked": + mlp_output = _route_locked_sparse_moe(layer.mlp, mlp_input) + elif policy == "row_moe": + mlp_output = _row_stepped_sparse_moe(layer.mlp, mlp_input) + else: + raise ValueError( + f"unsupported state-history MoE verifier policy: {policy}" + ) + else: + mlp_output = layer.mlp(mlp_input) + hidden_states = mid + mlp_output + + post = inner.norm(hidden_states) + logits = _lm_head_from_hidden(model, post) + return logits, post, prefix_history + + +def _commit_gdn_prefix_state( + *, + cache: KVCacheType, + prefix_history: GdnPrefixHistory, + accept_count: int, + k: int, +) -> None: + """Commit verifier cache to ``[current, drafts[:accept_count]]``.""" + cache_list = cast("list[Any]", cache) + _trim_trimmable_cache_entries(cache_list, k - accept_count) + for cache_index, states_by_prefix in prefix_history.items(): + _set_cache_entry_state_lazy( + cache_list[cache_index], + states_by_prefix[accept_count], + ) + + +def prime_mtp_cache_from_prompt( + *, + model: Model, + full_prompt_tokens: Sequence[int], + mtp_cache: list[Any], +) -> int: + """Walk prefill positions through the main model to prime the MTP cache. + + Forwards ``prompt[:-1]`` through a fresh main-model cache, captures + post-norm hidden states for every prefill position, then feeds + them through ``model.mtp_update_cache`` in a single batched call. 
+ After priming, the MTP attention cache holds K/V for positions + ``0..N-2``; the first ``mtp_forward`` in the K-chain draft loop + adds position ``N-1`` naturally. + + Returns the number of positions primed (``N-1``), or ``0`` when + the prompt is shorter than 2 tokens. + + Known cost: this re-walks the prompt through a fresh main-model cache + to prime the separate MTP cache, so the prompt is prefilled twice (once + by ``exo.prefill`` for the main KV cache, once here for the MTP cache). + Folding the two into a single capture-prefill is a future optimisation + at the :func:`exo.worker.engines.mlx.prefill` seam, not in this drafter. + """ + n = len(full_prompt_tokens) + if n < 2: + return 0 + prefill_arr = mx.array(list(full_prompt_tokens[:-1]), dtype=mx.int32) + next_arr = mx.array(list(full_prompt_tokens[1:]), dtype=mx.int32) + fresh_cache: list[Any] = cast("list[Any]", model.make_cache()) + prefill_hidden = _target_post_norm_hidden( + model=model, + inputs=prefill_arr[None], + cache=fresh_cache, + ) + mx.eval(prefill_hidden) + model.mtp_update_cache(prefill_hidden, next_arr[None], mtp_cache=mtp_cache) + _materialize_cache_state(mtp_cache) + del fresh_cache + return n - 1 + + +def prime_mtp_cache_from_prompt_incremental( + *, + model: Model, + full_prompt_tokens: Sequence[int], + mtp_cache: list[Any], +) -> int: + """Prime MTP cache by stepping target hiddens token by token. + + This is the correctness path for qwen3_5_moe. Its recurrent + GatedDeltaNet state is sensitive to batched prefill; the batched + :func:`prime_mtp_cache_from_prompt` can produce hiddens that are close + enough for ordinary logits but wrong enough for MTP acceptance. 
+ """ + n = len(full_prompt_tokens) + if n < 2: + return 0 + fresh_cache: list[Any] = cast("list[Any]", model.make_cache()) + for i in range(n - 1): + token_arr = mx.array([[int(full_prompt_tokens[i])]], dtype=mx.int32) + next_arr = mx.array([[int(full_prompt_tokens[i + 1])]], dtype=mx.int32) + hidden = _target_post_norm_hidden( + model=model, + inputs=token_arr, + cache=fresh_cache, + ) + mtp_hidden = model.mtp_update_cache( + hidden[:, -1:, :], + next_arr, + mtp_cache=mtp_cache, + ) + mx.eval(hidden, mtp_hidden) + _materialize_cache_state(mtp_cache) + del fresh_cache + return n - 1 + + +def rebuild_prompt_cache_and_prime_mtp_cache_incremental( + *, + model: Model, + full_prompt_tokens: Sequence[int], + prompt_cache: KVCacheType, + mtp_cache: list[Any], +) -> tuple[int, int]: + """Rebuild target prompt cache while priming MTP cache in one pass.""" + n = len(full_prompt_tokens) + cache_list = cast("list[Any]", prompt_cache) + if n < 2: + cache_list[:] = cast("list[Any]", model.make_cache()) + return 0, 0 + + tokens_to_prefill = max(0, n - 2) + fresh_cache: list[Any] = cast("list[Any]", model.make_cache()) + prompt_boundary_snapshot: Any | None = None + for i in range(n - 1): + token_arr = mx.array([[int(full_prompt_tokens[i])]], dtype=mx.int32) + next_arr = mx.array([[int(full_prompt_tokens[i + 1])]], dtype=mx.int32) + hidden = _target_post_norm_hidden( + model=model, + inputs=token_arr, + cache=fresh_cache, + ) + mtp_hidden = model.mtp_update_cache( + hidden[:, -1:, :], + next_arr, + mtp_cache=mtp_cache, + ) + mx.eval(hidden, mtp_hidden) + if i + 1 == tokens_to_prefill: + prompt_boundary_snapshot = snapshot_untrimmable_cache_lazy(fresh_cache) + + if tokens_to_prefill == 0: + cache_list[:] = cast("list[Any]", model.make_cache()) + else: + _trim_trimmable_cache_entries(fresh_cache, (n - 1) - tokens_to_prefill) + if prompt_boundary_snapshot is None: + raise RuntimeError("missing prompt boundary snapshot") + restore_cache(fresh_cache, prompt_boundary_snapshot) + 
cache_list[:] = fresh_cache + _materialize_cache_state(mtp_cache) + return tokens_to_prefill, n - 1 + + +def rebuild_prompt_cache_incremental( + *, + model: Model, + full_prompt_tokens: Sequence[int], + prompt_cache: KVCacheType, +) -> int: + """Replace ``prompt_cache`` with token-by-token state through prompt[:-2]. + + Exo's standard prefill uses mlx-lm's chunked prefill and then trims the + decode tail. That is fast, but qwen3_5_moe's GatedDeltaNet recurrent + state is not equivalent to a token-by-token cache under that path. Native + MTP needs the exact recurrent hidden/cache state, so the 35B MoE safe path + rebuilds it incrementally before the two-token decode tail is consumed. + """ + fresh_cache: list[Any] = cast("list[Any]", model.make_cache()) + tokens_to_prefill = max(0, len(full_prompt_tokens) - 2) + for token in full_prompt_tokens[:tokens_to_prefill]: + token_arr = mx.array([[int(token)]], dtype=mx.int32) + hidden = _target_post_norm_hidden( + model=model, + inputs=token_arr, + cache=fresh_cache, + ) + mx.eval(hidden) + cache_list = cast("list[Any]", prompt_cache) + cache_list[:] = fresh_cache + return tokens_to_prefill + + +@final +class NativeMTPDrafter: + """Drafter-protocol shim for native MTP (vendored Qwen3.5/3.6). + + Constructor takes ``k`` explicitly; the builder sources it from + :attr:`exo.shared.models.model_cards.NativeMTPConfig.default_k` + (overridable per-request via ``TaskParams.num_draft_tokens``). + + The drafter reports ``mode="model"`` so existing telemetry counts + it as drafter-active. The architecture-level distinction + (``native_mtp`` vs sibling-LM ``model``) is surfaced via the + builder log line at dispatch time; a future :data:`DraftMode` + extension can expose it via ``GenerationStats`` if needed. 
+ """ + + def __init__(self, *, k: int) -> None: + if k < 1: + raise ValueError(f"k must be >= 1, got {k}") + self._k: int = k + self._metrics: dict[str, int] = { + "proposed_draft_tokens": 0, + "accepted_draft_tokens": 0, + "spec_decode_rounds": 0, + } + + @property + def mode(self) -> str: + return "model" + + @property + def num_draft_tokens(self) -> int: + return self._k + + def metrics(self) -> dict[str, int]: + return dict(self._metrics) + + def stream( + self, + *, + model: Model, + tokenizer: TokenizerWrapper, + prompt: mx.array, + context_tokens: Sequence[int], + prompt_cache: KVCacheType, + max_tokens: int, + sampler: Callable[[mx.array], mx.array], + logits_processors: Sequence[Callable[[mx.array, mx.array], mx.array]], + prefill_step_size: int = 1, + temperature: float = 0.0, + top_p: float = 0.0, + top_k: int = 0, + min_p: float = 0.0, + min_tokens_to_keep: int = 1, + ) -> Generator[GenerationResponse, None, None]: + # LOSSLESS / distribution-preserving native MTP. The K-chain drafts + # are SAMPLED from the MTP head's per-step distribution ``q`` (after + # the request's temperature/top_p/top_k/min_p transforms) and the + # batched/sequential verify accepts each draft via probability-ratio + # (Leviathan/Chen) rejection sampling against the target distribution + # ``p`` -- so the emitted marginal matches the request's sampler at any + # temperature. With ``temperature == 0`` the scheme collapses to exact + # greedy (accept iff draft == argmax(p)), token-identical to plain + # greedy decode. See :class:`_SpecSampler` for the reconstruction of + # ``p``/``q`` from the raw sampling params (the request only hands us a + # ``sampler`` callable, which is insufficient -- we need the + # distribution, so the params are threaded in from ``mlx_generate``). 
+ # + # ``logits_processors`` (repetition/presence/frequency penalties, + # logit bias) remain UNHANDLED on this path: they are stateful over + # emitted-token history and cannot be replayed consistently across the + # chained draft steps and the batched verify without breaking the + # lossless guarantee; applying them inconsistently is worse than not at + # all. ``sampler`` is likewise unused -- the params drive the + # reconstructed distribution directly. ``prefill_step_size`` is unused + # (MTP-cache priming is batched). + del sampler, logits_processors, prefill_step_size + spec_sampler = _SpecSampler( + temperature=temperature, + top_p=top_p, + top_k=top_k, + min_p=min_p, + min_tokens_to_keep=min_tokens_to_keep, + ) + if not is_native_mtp_dispatchable(model): + raise RuntimeError( + "NativeMTPDrafter.stream called on a model without " + "mtp_forward / make_mtp_cache / mtp_update_cache. The " + "dispatch-site gate (is_native_mtp_dispatchable) should " + "have caught this." + ) + + detokenizer = tokenizer.detokenizer + detokenizer.reset() + eos_ids = _eos_ids_from_tokenizer(tokenizer) + + prompt_tail_size = int(prompt.size) + + mtp_cache: list[Any] = cast("list[Any]", model.make_mtp_cache()) + if not mtp_cache: + raise RuntimeError( + "model.make_mtp_cache() returned empty list -- model " + "has no MTP head. Loader/dispatch drift." 
+ ) + + model_type = _native_mtp_model_type(model) + sequential_gdn_path = _requires_sequential_gdn_path(model) + moe_verifier_policy = _moe_verifier_policy() if sequential_gdn_path else "safe" + if sequential_gdn_path and moe_verifier_policy != "safe": + logger.info( + f"[native MTP] qwen3_5_moe verifier policy={moe_verifier_policy!r}" + ) + gdn_state_history_default = _gdn_state_history_commit_default( + model_type=model_type, + moe_verifier_policy=moe_verifier_policy, + ) + gdn_state_history_supported = ( + _gdn_state_history_kernel_supported(model) + if gdn_state_history_default + else False + ) + gdn_state_history_commit = ( + gdn_state_history_default + and gdn_state_history_supported + # Default-on for dense qwen3_5 and for qwen3_5_moe route-locked + # native MTP after broad production stream parity. Set the env var + # to 0 as a kill switch. + and _env_flag("EXO_NATIVE_MTP_GDN_STATE_HISTORY_COMMIT", default=True) + ) + if gdn_state_history_default and not gdn_state_history_supported: + logger.warning( + f"[native MTP] {model_type or 'unknown model'} GDN " + "state-history prefix commit disabled: unsupported GDN head shape" + ) + if gdn_state_history_commit: + logger.info( + f"[native MTP] {model_type or 'unknown model'} GDN " + "state-history prefix commit enabled" + ) + row_moe_adaptive = ( + sequential_gdn_path + and moe_verifier_policy == "row_moe" + and _env_flag("EXO_NATIVE_MTP_ROW_MOE_ADAPTIVE", default=True) + ) + row_moe_adaptive_min_rounds = _env_int( + "EXO_NATIVE_MTP_ROW_MOE_ADAPTIVE_MIN_ROUNDS", + default=8, + minimum=1, + ) + row_moe_adaptive_min_acceptance = _env_float( + "EXO_NATIVE_MTP_ROW_MOE_ADAPTIVE_MIN_ACCEPTANCE", + default=0.55, + ) + row_moe_adaptive_switched = False + primed_positions = 0 + if context_tokens: + prime_start = time.perf_counter() + if sequential_gdn_path: + repaired_positions, primed_positions = ( + rebuild_prompt_cache_and_prime_mtp_cache_incremental( + model=model, + full_prompt_tokens=context_tokens, + 
prompt_cache=prompt_cache, + mtp_cache=mtp_cache, + ) + ) + prime_elapsed_ms = (time.perf_counter() - prime_start) * 1000 + logger.info( + "[native MTP] qwen3_5_moe safe path rebuilt target cache " + f"from {repaired_positions} prompt positions and primed " + f"MTP cache from {primed_positions} prompt positions in " + f"{prime_elapsed_ms:.1f}ms" + ) + else: + primed_positions = prime_mtp_cache_from_prompt( + model=model, + full_prompt_tokens=context_tokens, + mtp_cache=mtp_cache, + ) + prime_elapsed_ms = (time.perf_counter() - prime_start) * 1000 + logger.info( + f"[native MTP] primed MTP cache from {primed_positions} " + f"prompt positions in {prime_elapsed_ms:.1f}ms" + ) + + # First main-model forward: decode_in covers the 2-token prefill + # tail. The caller has aligned ``prompt_cache`` to + # ``full_prompt[:-2]`` via exo.prefill + trim(2); this brings + # the cache to ``len-1`` via position N-2 and returns logits + # at position N-1 (which we sample for the first emitted token). + decode_in = prompt.astype(mx.int32)[None] + if int(decode_in.shape[1]) > 1: + # Qwen3.5/3.6 GatedDeltaNet can produce different last-position + # logits for a small batched tail than for the same tokens stepped + # through the recurrent state. Mirror mlx-lm's prefill/decode split: + # advance every tail token except the last, then sample from the last. + tail_prefix_hidden = _target_post_norm_hidden( + model=model, + inputs=decode_in[:, :-1], + cache=prompt_cache, + ) + first_logits, first_hidden = cast( + tuple[mx.array, mx.array], + model(decode_in[:, -1:], cache=prompt_cache, return_hidden=True), + ) + mx.eval(tail_prefix_hidden, first_logits, first_hidden) + else: + first_logits, first_hidden = cast( + tuple[mx.array, mx.array], + model(decode_in, cache=prompt_cache, return_hidden=True), + ) + mx.eval(first_logits, first_hidden) + # Sample the first emitted token from the request's distribution + # (argmax when greedy) instead of a hard argmax. 
+ current_token_arr = spec_sampler.sample_from_distribution( + spec_sampler.distribution(first_logits[:, -1, :]) + ) + mx.eval(current_token_arr) + current_token = _token_array_to_int(current_token_arr) + current_hidden = first_hidden[:, -1:, :] + + proposed = 0 + accepted = 0 + rounds = 0 + ntoks = 0 + finish_reason: str | None = None + prompt_tps_local = 0.0 + + tic = time.perf_counter() + if current_token in eos_ids: + detokenizer.add_token(current_token) + detokenizer.finalize() + yield GenerationResponse( + text=detokenizer.last_segment, + token=current_token, + logprobs=mx.zeros((1,)), + from_draft=False, + prompt_tokens=prompt_tail_size, + prompt_tps=0.0, + generation_tokens=1, + generation_tps=0.0, + peak_memory=mx.get_peak_memory() / 1e9, + finish_reason="stop", + ) + return + + prompt_time = time.perf_counter() - tic + prompt_tps_local = prompt_tail_size / prompt_time if prompt_time > 0 else 0.0 + tic = time.perf_counter() + ntoks = 1 + detokenizer.add_token(current_token) + elapsed = time.perf_counter() - tic + yield GenerationResponse( + text=detokenizer.last_segment, + token=current_token, + logprobs=mx.zeros((1,)), + from_draft=False, + prompt_tokens=prompt_tail_size, + prompt_tps=prompt_tps_local, + generation_tokens=ntoks, + generation_tps=ntoks / elapsed if elapsed > 0 else 0.0, + peak_memory=mx.get_peak_memory() / 1e9, + finish_reason=None, + ) + + # Main loop: cache offset = N, current_{token, hidden} from + # most recent emission. + while ntoks < max_tokens: + # ---- Draft K via chained mtp_forward ---- + # Each draft token is SAMPLED from its MTP-step distribution ``q`` + # (argmax when greedy); ``q`` is retained per step so the verify + # step can run the probability-ratio acceptance test. 
+ draft_token_arrays: list[mx.array] = [] + draft_distributions: list[mx.array] = [] + chain_hidden = current_hidden + chain_tok_arr = current_token_arr + for _ in range(self._k): + mtp_logits, mtp_hidden = cast( + tuple[mx.array, mx.array], + model.mtp_forward( + chain_hidden, + chain_tok_arr, + mtp_cache=mtp_cache, + return_hidden=True, + ), + ) + draft_dist = spec_sampler.distribution(mtp_logits[:, -1, :]) + next_draft_arr = spec_sampler.sample_from_distribution(draft_dist) + draft_token_arrays.append(next_draft_arr) + draft_distributions.append(draft_dist) + chain_hidden = mtp_hidden[:, -1:, :] + chain_tok_arr = next_draft_arr + # Consolidated eval: single GPU dispatch for all draft arrays + all_drafts = ( + draft_token_arrays[0] + if self._k == 1 + else mx.concatenate(draft_token_arrays, axis=1) + ) + if _ASYNC_DRAFT_EVAL_ENABLED: + _async_materialize_cache_state( + mtp_cache, + all_drafts, + *draft_distributions, + ) + else: + mx.eval(all_drafts, *draft_distributions) + _materialize_cache_state(mtp_cache) + proposed += self._k + rounds += 1 + + # Pre-draw the per-position uniforms for the ratio test up-front so + # the acceptance walk needs no extra GPU sync per token. + accept_uniforms: list[float] = ( + [0.0] * self._k + if spec_sampler.is_greedy + else cast( + "list[float]", + mx.random.uniform(shape=(self._k,)).tolist(), + ) + ) + + # ---- Verify: target forward on [current_token, *drafts] ---- + # Consolidated readback: single CPU-GPU sync for all K draft tokens + # instead of K sequential .item() calls (~50-150us each). 
+ draft_tokens: list[int] = [ + int(t) for t in list(all_drafts.reshape(-1).tolist()) + ] + num_accepted_this_round = 0 + bonus_tok = current_token + next_current_hidden = current_hidden + next_current_token_arr = current_token_arr + pre_verify_snapshot: Any | None = None + prefix_history: GdnPrefixHistory | None = None + verify_hidden: mx.array | None = None + + if sequential_gdn_path: + used_batched_moe_verify = False + if moe_verifier_policy != "safe": + verify_in = mx.concatenate( + [current_token_arr, *draft_token_arrays], axis=1 + ) + if gdn_state_history_commit: + ( + verify_logits, + verify_hidden, + prefix_history, + ) = _target_forward_with_gdn_state_history( + model=model, + inputs=verify_in, + cache=prompt_cache, + policy=moe_verifier_policy, + ) + else: + pre_verify_snapshot = snapshot_untrimmable_cache_lazy( + cast("list[Any]", prompt_cache) + ) + verify_logits, verify_hidden = ( + _target_forward_row_moe( + model=model, + inputs=verify_in, + cache=prompt_cache, + ) + if moe_verifier_policy == "row_moe" + else _target_forward_route_locked_moe( + model=model, + inputs=verify_in, + cache=prompt_cache, + ) + if moe_verifier_policy == "route_locked" + else cast( + tuple[mx.array, mx.array], + model( + verify_in, + cache=prompt_cache, + return_hidden=True, + ), + ) + ) + mx.eval(verify_logits, verify_hidden) + batched_accepted, batched_replacement_arr = _spec_accept_walk( + spec_sampler=spec_sampler, + verify_logits=verify_logits, + draft_tokens=draft_tokens, + draft_distributions=draft_distributions, + accept_uniforms=accept_uniforms, + k=self._k, + ) + mx.eval(batched_replacement_arr) + + can_commit_batched = ( + moe_verifier_policy in {"batched", "row_moe", "route_locked"} + or batched_accepted == self._k + ) + if can_commit_batched: + num_accepted_this_round = batched_accepted + bonus_tok = _token_array_to_int(batched_replacement_arr) + next_current_token_arr = batched_replacement_arr + next_current_hidden = verify_hidden[ + :, + 
num_accepted_this_round : num_accepted_this_round + 1, + :, + ] + used_batched_moe_verify = True + else: + assert pre_verify_snapshot is not None + rollback_after_verify( + cast("list[Any]", prompt_cache), + pre_verify_snapshot, + verified_tokens=self._k + 1, + ) + + if not used_batched_moe_verify: + # qwen3_5_moe safe verifier: feed only the retained prefix + # token-by-token. Batched verification corrupts later- + # position GatedDeltaNet state/logits on 35B-A3B, producing + # coherent-looking but wrong bonuses such as "if items" + # instead of "if len". Each position runs the same + # probability-ratio acceptance test as the batched paths. + for i in range(self._k + 1): + verify_token_arr = ( + current_token_arr if i == 0 else draft_token_arrays[i - 1] + ) + verify_logits_i, verify_hidden_i = cast( + tuple[mx.array, mx.array], + model( + verify_token_arr, + cache=prompt_cache, + return_hidden=True, + ), + ) + target_dist_i = spec_sampler.distribution( + verify_logits_i[:, -1, :] + ) + mx.eval(target_dist_i, verify_hidden_i) + next_current_hidden = verify_hidden_i[:, -1:, :] + if i < self._k: + draft_token = draft_tokens[i] + target_argmax = int( + mx.argmax(verify_logits_i[:, -1, :], axis=-1).item() + ) + target_prob = float(target_dist_i[0, draft_token].item()) + draft_prob = float( + draft_distributions[i][0, draft_token].item() + ) + if _accept_draft_token( + is_greedy=spec_sampler.is_greedy, + draft_token=draft_token, + target_argmax=target_argmax, + target_prob=target_prob, + draft_prob=draft_prob, + uniform=accept_uniforms[i], + ): + num_accepted_this_round += 1 + continue + # Rejected: replacement from residual (p_i - q_i)+. + replacement_arr = spec_sampler.residual_token( + target_dist_i, draft_distributions[i] + ) + else: + # Full accept: bonus sampled from the target dist. 
+ replacement_arr = spec_sampler.sample_from_distribution( + target_dist_i + ) + mx.eval(replacement_arr) + bonus_tok = _token_array_to_int(replacement_arr) + next_current_token_arr = replacement_arr + break + elif num_accepted_this_round < self._k: + if gdn_state_history_commit: + assert prefix_history is not None + _commit_gdn_prefix_state( + cache=prompt_cache, + prefix_history=prefix_history, + accept_count=num_accepted_this_round, + k=self._k, + ) + else: + assert pre_verify_snapshot is not None + rollback_after_verify( + cast("list[Any]", prompt_cache), + pre_verify_snapshot, + verified_tokens=self._k + 1, + ) + repair_in = mx.concatenate( + [ + current_token_arr, + *draft_token_arrays[:num_accepted_this_round], + ], + axis=1, + ) + _repair_logits, _repair_hidden = ( + _target_forward_row_moe( + model=model, + inputs=repair_in, + cache=prompt_cache, + ) + if moe_verifier_policy == "row_moe" + else _target_forward_route_locked_moe( + model=model, + inputs=repair_in, + cache=prompt_cache, + ) + if moe_verifier_policy == "route_locked" + else cast( + tuple[mx.array, mx.array], + model( + repair_in, + cache=prompt_cache, + return_hidden=True, + ), + ) + ) + mx.eval(_repair_logits, _repair_hidden) + next_current_hidden = _repair_hidden[:, -1:, :] + accepted += num_accepted_this_round + if ( + row_moe_adaptive + and not row_moe_adaptive_switched + and rounds >= row_moe_adaptive_min_rounds + and proposed > 0 + and (accepted / proposed) < row_moe_adaptive_min_acceptance + ): + logger.info( + "[native MTP] row_moe acceptance " + f"{accepted}/{proposed}={accepted / proposed:.1%} after " + f"{rounds} rounds; switching verifier to safe for this " + "generation" + ) + moe_verifier_policy = "safe" + row_moe_adaptive_switched = True + else: + verify_in = mx.concatenate( + [current_token_arr, *draft_token_arrays], axis=1 + ) + if gdn_state_history_commit: + ( + verify_logits, + verify_hidden, + prefix_history, + ) = _target_forward_with_gdn_state_history( + model=model, + 
inputs=verify_in, + cache=prompt_cache, + policy=moe_verifier_policy, + ) + else: + # Snapshot GDN state BEFORE verify advances cache by K+1 so + # we can roll back on partial accept; see module docstring. + pre_verify_snapshot = snapshot_untrimmable_cache_lazy( + cast("list[Any]", prompt_cache) + ) + verify_logits, verify_hidden = cast( + tuple[mx.array, mx.array], + model(verify_in, cache=prompt_cache, return_hidden=True), + ) + mx.eval(verify_logits, verify_hidden) + + # Probability-ratio accept-walk (lossless speculative sampling). + num_accepted_this_round, replacement_arr = _spec_accept_walk( + spec_sampler=spec_sampler, + verify_logits=verify_logits, + draft_tokens=draft_tokens, + draft_distributions=draft_distributions, + accept_uniforms=accept_uniforms, + k=self._k, + ) + mx.eval(replacement_arr) + accepted += num_accepted_this_round + bonus_tok = _token_array_to_int(replacement_arr) + next_current_token_arr = replacement_arr + next_current_hidden = verify_hidden[ + :, num_accepted_this_round : num_accepted_this_round + 1, : + ] + + # Emit accepted drafts. + for i in range(num_accepted_this_round): + tok = draft_tokens[i] + ntoks += 1 + detokenizer.add_token(tok) + elapsed = time.perf_counter() - tic + yield GenerationResponse( + text=detokenizer.last_segment, + token=tok, + logprobs=mx.zeros((1,)), + from_draft=True, + prompt_tokens=prompt_tail_size, + prompt_tps=prompt_tps_local, + generation_tokens=ntoks, + generation_tps=ntoks / elapsed if elapsed > 0 else 0.0, + peak_memory=mx.get_peak_memory() / 1e9, + finish_reason=None, + ) + if tok in eos_ids: + finish_reason = "stop" + break + if ntoks >= max_tokens: + finish_reason = "length" + break + + # Emit bonus. 
+ if finish_reason is None: + ntoks += 1 + detokenizer.add_token(bonus_tok) + elapsed = time.perf_counter() - tic + yield GenerationResponse( + text=detokenizer.last_segment, + token=bonus_tok, + logprobs=mx.zeros((1,)), + from_draft=False, + prompt_tokens=prompt_tail_size, + prompt_tps=prompt_tps_local, + generation_tokens=ntoks, + generation_tps=ntoks / elapsed if elapsed > 0 else 0.0, + peak_memory=mx.get_peak_memory() / 1e9, + finish_reason=None, + ) + if bonus_tok in eos_ids: + finish_reason = "stop" + + # ---- Cache trims ---- + # Main cache: on partial accept (n < K) we trim back to the + # pre-verify offset and restore GDN from the snapshot, then + # re-forward [current_token, *drafts[:n]] to land cache at + # N + n + 1 with the GDN recurrent state advanced through + # the retained tokens only. On full accept (n == K) the + # batched verify's cache state at N+K+1 is already on- + # trajectory (verified by batched_vs_solo_probe in the + # research lane), so no rollback is needed. 
+ if not sequential_gdn_path and num_accepted_this_round < self._k: + try: + if gdn_state_history_commit: + assert prefix_history is not None + _commit_gdn_prefix_state( + cache=prompt_cache, + prefix_history=prefix_history, + accept_count=num_accepted_this_round, + k=self._k, + ) + else: + assert pre_verify_snapshot is not None + rollback_after_verify( + cast("list[Any]", prompt_cache), + pre_verify_snapshot, + verified_tokens=self._k + 1, + ) + repair_in = mx.concatenate( + [ + current_token_arr, + *draft_token_arrays[:num_accepted_this_round], + ], + axis=1, + ) + _repair_logits, _repair_hidden = cast( + tuple[mx.array, mx.array], + model(repair_in, cache=prompt_cache, return_hidden=True), + ) + mx.eval(_repair_logits, _repair_hidden) + next_current_hidden = _repair_hidden[:, -1:, :] + except Exception as exc: + logger.warning( + f"[native MTP] cache repair after partial accept " + f"(n={num_accepted_this_round}/K={self._k}) raised " + f"{type(exc).__name__}: {exc}; continuing with stale cache" + ) + # MTP cache: chain drafting added K positions; only the + # accepted prefix's K/V corresponds to tokens that materialised + # in the main stream. The MTP attention layer is full-attention + # (qwen3_5_mtp.py uses ``fa_layer_idx``) so the stock trim is + # correct here -- no rollback asymmetry to worry about. 
+ mtp_retained = min(self._k, num_accepted_this_round + 1) + mtp_trim = self._k - mtp_retained + if mtp_trim > 0: + try: + mlx_trim_prompt_cache(cast(list[object], mtp_cache), mtp_trim) + except Exception as exc: + logger.warning( + f"[native MTP] mtp cache trim({mtp_trim}) raised " + f"{type(exc).__name__}: {exc}; continuing with stale cache" + ) + + if finish_reason is not None or ntoks >= max_tokens: + if finish_reason is None and ntoks >= max_tokens: + finish_reason = "length" + break + + current_token = bonus_tok + current_token_arr = next_current_token_arr + current_hidden = next_current_hidden + + detokenizer.finalize() + if finish_reason is None: + finish_reason = "length" if ntoks >= max_tokens else "stop" + elapsed = time.perf_counter() - tic + acceptance = (accepted / proposed) if proposed > 0 else 0.0 + self._metrics["proposed_draft_tokens"] = proposed + self._metrics["accepted_draft_tokens"] = accepted + self._metrics["spec_decode_rounds"] = rounds + logger.info( + f"[native MTP] K={self._k} stream done: tokens={ntoks}, " + f"proposed={proposed}, accepted={accepted}, " + f"acceptance={acceptance:.1%}, rounds={rounds}, " + f"finish={finish_reason}, primed_positions={primed_positions}" + ) + yield GenerationResponse( + text=detokenizer.last_segment, + token=current_token, + logprobs=mx.zeros((1,)), + from_draft=False, + prompt_tokens=prompt_tail_size, + prompt_tps=prompt_tps_local, + generation_tokens=ntoks, + generation_tps=ntoks / elapsed if elapsed > 0 else 0.0, + peak_memory=mx.get_peak_memory() / 1e9, + finish_reason=finish_reason, + ) + + +__all__ = [ + "NativeMTPDrafter", + "is_native_mtp_dispatchable", + "prime_mtp_cache_from_prompt", +] diff --git a/src/exo/worker/engines/mlx/mtp_probe.py b/src/exo/worker/engines/mlx/mtp_probe.py new file mode 100644 index 0000000000..00226a4f26 --- /dev/null +++ b/src/exo/worker/engines/mlx/mtp_probe.py @@ -0,0 +1,244 @@ +"""MTP (Multi-Token Prediction) weight probing for Qwen3.6 models. 
+ +Detects MTP weights across all known distribution formats: +- Original HuggingFace: embedded in main shards with ``mtp.*`` prefix (15 tensors) +- MTPLX quantized: separate ``mtp.safetensors`` file with ``mtp.*`` prefix (29 tensors) +- oMLX quantized: embedded in main shards with ``language_model.mtp.*`` prefix (29 tensors) +- mlx-community: stripped during quantization (0 tensors, unrecoverable) + +The probe is called before model loading so the loader can: +1. Inject separate MTP weights into the main weight dict (MTPLX format) +2. Patch ``sanitize()`` to stop stripping MTP tensors +3. Ensure norm weight shifting fires correctly +""" + +from __future__ import annotations + +import json +from dataclasses import dataclass +from enum import Enum, auto +from pathlib import Path +from typing import Any, Iterable, cast + +from loguru import logger + +# pyright: reportAny=false +# This module probes safetensors/JSON data which are untyped. + + +class MtpFormat(Enum): + """How MTP weights are stored in this model.""" + + ORIGINAL_EMBEDDED = auto() + """Original HuggingFace format: ``mtp.*`` prefix, embedded in main shards (15 tensors, BF16).""" + + MTPLX_SEPARATE_FILE = auto() + """MTPLX format: ``mtp.*`` prefix, separate ``mtp.safetensors`` file (29 tensors, quantized).""" + + OMLX_EMBEDDED = auto() + """oMLX format: ``language_model.mtp.*`` prefix, embedded in main shards (29 tensors, quantized).""" + + STRIPPED = auto() + """MTP weights were stripped during quantization (e.g. mlx-community). 
Unrecoverable.""" + + +@dataclass(frozen=True) +class MtpProbeResult: + """Result of probing a model directory for MTP weights.""" + + model_declares_mtp: bool + """Whether ``config.json`` declares ``mtp_num_hidden_layers > 0``.""" + + mtp_tensors_found: bool + """Whether MTP weight tensors were found on disk.""" + + mtp_format: MtpFormat | None + """Detected storage format, or ``None`` if model doesn't declare MTP.""" + + mtp_count: int + """Number of MTP tensors found.""" + + mtp_path: str | None + """Path to MTP weights (either ``mtp.safetensors`` for separate file, + or description of embedded location).""" + + mtp_tensor_keys: tuple[str, ...] + """Names of MTP tensor keys found (empty if STRIPPED or not found).""" + + @property + def is_recoverable(self) -> bool: + """Whether MTP weights can be loaded and used.""" + return ( + self.mtp_tensors_found + and self.mtp_format is not None + and self.mtp_format != MtpFormat.STRIPPED + ) + + +def probe_mtp_weights(model_path: Path | str) -> MtpProbeResult: + """Probe a model directory for MTP weights in all known locations. + + Checks in order: + 1. ``config.json`` → ``mlx_lm_extra_tensors.mtp_file`` (MTPLX separate file) + 2. ``model.safetensors.index.json`` → ``mtp.*`` prefix (Original HuggingFace) + 3. ``model.safetensors.index.json`` → ``language_model.mtp.*`` prefix (oMLX) + 4. ``mtp.safetensors`` file on disk (fallback, no config declaration) + + Args: + model_path: Path to the model directory. + + Returns: + ``MtpProbeResult`` describing what was found. 
+ """ + model_dir = Path(model_path) + config_path = model_dir / "config.json" + + # Step 1: Check config for MTP declaration + model_declares_mtp = False + if config_path.exists(): + try: + cfg = json.loads(config_path.read_text()) + tc = cfg.get("text_config", {}) + mtp_layers = tc.get("mtp_num_hidden_layers", 0) + model_declares_mtp = bool(mtp_layers and mtp_layers > 0) + except (json.JSONDecodeError, OSError): + pass + + if not model_declares_mtp: + return MtpProbeResult( + model_declares_mtp=False, + mtp_tensors_found=False, + mtp_format=None, + mtp_count=0, + mtp_path=None, + mtp_tensor_keys=(), + ) + + # Step 2: Check for MTPLX separate file via mlx_lm_extra_tensors + try: + cfg = json.loads(config_path.read_text()) + extra = cfg.get("mlx_lm_extra_tensors", {}) + mtp_file_name = extra.get("mtp_file") + if mtp_file_name: + mtp_file = model_dir / mtp_file_name + if mtp_file.exists(): + keys = _safetensors_keys(mtp_file) + mtp_keys = tuple(k for k in keys if "mtp." in k) + if mtp_keys: + return MtpProbeResult( + model_declares_mtp=True, + mtp_tensors_found=True, + mtp_format=MtpFormat.MTPLX_SEPARATE_FILE, + mtp_count=len(mtp_keys), + mtp_path=str(mtp_file), + mtp_tensor_keys=mtp_keys, + ) + except (json.JSONDecodeError, OSError): + pass + + # Step 3: Check weight map index for embedded MTP tensors + index_path = model_dir / "model.safetensors.index.json" + if index_path.exists(): + try: + idx = json.loads(index_path.read_text()) + weight_map = idx.get("weight_map", {}) + all_keys = tuple(weight_map.keys()) + + # Check for oMLX prefix first (more specific) + omlx_keys = tuple(k for k in all_keys if "language_model.mtp." 
in str(k)) + if omlx_keys: + return MtpProbeResult( + model_declares_mtp=True, + mtp_tensors_found=True, + mtp_format=MtpFormat.OMLX_EMBEDDED, + mtp_count=len(omlx_keys), + mtp_path="embedded in main shards (language_model.mtp.* prefix)", + mtp_tensor_keys=omlx_keys, + ) + + # Check for original prefix + orig_keys = tuple(k for k in all_keys if str(k).startswith("mtp.")) + if orig_keys: + return MtpProbeResult( + model_declares_mtp=True, + mtp_tensors_found=True, + mtp_format=MtpFormat.ORIGINAL_EMBEDDED, + mtp_count=len(orig_keys), + mtp_path="embedded in main shards (mtp.* prefix)", + mtp_tensor_keys=orig_keys, + ) + except (json.JSONDecodeError, OSError): + pass + + # Step 4: Fallback — check for mtp.safetensors without config declaration + mtp_file = model_dir / "mtp.safetensors" + if mtp_file.exists(): + keys = _safetensors_keys(mtp_file) + mtp_keys = tuple(k for k in keys if "mtp." in k) + if mtp_keys: + return MtpProbeResult( + model_declares_mtp=True, + mtp_tensors_found=True, + mtp_format=MtpFormat.MTPLX_SEPARATE_FILE, + mtp_count=len(mtp_keys), + mtp_path=str(mtp_file), + mtp_tensor_keys=mtp_keys, + ) + + # Step 5: Model declares MTP but no weights found — stripped during quantization + logger.warning( + f"Model at {model_dir} declares mtp_num_hidden_layers > 0 but no MTP " + "weights were found on disk. MTP weights were likely stripped during " + "quantization (e.g. mlx-community format). This model will produce " + "gibberish because the norm weight shift depends on MTP detection." 
+ ) + return MtpProbeResult( + model_declares_mtp=True, + mtp_tensors_found=False, + mtp_format=MtpFormat.STRIPPED, + mtp_count=0, + mtp_path=None, + mtp_tensor_keys=(), + ) + + +def _safetensors_keys(path: Path) -> tuple[str, ...]: + """Return tensor keys from a safetensors file without loading data.""" + try: + from safetensors import safe_open + + with safe_open(str(path), framework="numpy") as f: + return tuple(f.keys()) + except Exception: + return () + + +def load_mtp_weights(model_path: Path | str) -> dict[str, Any] | None: + """Load MTP weights from a separate file (MTPLX format). + + Only works for MTPLX_SEPARATE_FILE format. For embedded formats + (Original, oMLX), the MTP weights are already in the main shards + and will be loaded by ``mlx_lm``'s normal weight loading. + + Args: + model_path: Path to the model directory. + + Returns: + Dict of MTP tensor name → array, or ``None`` if not in separate file. + """ + result = probe_mtp_weights(model_path) + if result.mtp_format != MtpFormat.MTPLX_SEPARATE_FILE or result.mtp_path is None: + return None + + try: + from safetensors import safe_open + + mtp_weights: dict[str, Any] = {} + with safe_open(result.mtp_path, framework="mlx") as f: + for key in cast(Iterable[str], f.keys()): + if "mtp." 
in key: + mtp_weights[key] = f.get_tensor(key) + return mtp_weights if mtp_weights else None + except Exception as e: + logger.warning(f"Failed to load MTP weights from {result.mtp_path}: {e}") + return None diff --git a/src/exo/worker/engines/mlx/native_mtp_config.py b/src/exo/worker/engines/mlx/native_mtp_config.py new file mode 100644 index 0000000000..645d1e88fc --- /dev/null +++ b/src/exo/worker/engines/mlx/native_mtp_config.py @@ -0,0 +1,41 @@ +import os + +EXO_NATIVE_MTP_ENABLED_ENV = "EXO_NATIVE_MTP_ENABLED" + + +def native_mtp_enabled_from_env() -> bool: + """Return whether native-MTP dispatch is enabled for supported cards.""" + raw = os.environ.get(EXO_NATIVE_MTP_ENABLED_ENV) + if raw is None: + return True + return raw.strip().lower() not in {"0", "false", "no", "off"} + + +def clamp_native_mtp_num_draft_tokens( + num_draft_tokens: int, *, max_k: int +) -> tuple[int, bool]: + """Clamp native-MTP K to the card-declared runtime budget.""" + bounded_max = max(1, max_k) + if num_draft_tokens < 1: + return 1, True + if num_draft_tokens > bounded_max: + return bounded_max, True + return num_draft_tokens, False + + +def resolve_native_mtp_num_draft_tokens( + *, + request_num_draft_tokens: int | None, + configured_num_draft_tokens: int | None, + card_default_k: int, + card_max_k: int, +) -> tuple[int, bool]: + """Resolve native-MTP K with request > startup config > card default.""" + chosen = ( + request_num_draft_tokens + if request_num_draft_tokens is not None + else configured_num_draft_tokens + ) + if chosen is None: + chosen = card_default_k + return clamp_native_mtp_num_draft_tokens(chosen, max_k=card_max_k) diff --git a/src/exo/worker/engines/mlx/spec_cache.py b/src/exo/worker/engines/mlx/spec_cache.py new file mode 100644 index 0000000000..451537ba99 --- /dev/null +++ b/src/exo/worker/engines/mlx/spec_cache.py @@ -0,0 +1,178 @@ +"""Snapshot/rollback helpers for speculative-decode cache state. 
+ +MLX's ``trim_prompt_cache`` only rolls back attention KV offsets, but the +Qwen3.5/3.6 GatedDeltaNet ``ArraysCache`` recurrent state cannot be +reconstructed from the offset alone. The native-MTP draft+verify loop +therefore clones that recurrent state before each batched verify forward +and restores it when a draft is rejected. Without this, a batched verify +followed by an offset-only trim produces materially wrong logits. + +Adapted from MTPLX's ``snapshot_untrimmable_cache`` / +``rollback_after_verify`` (Apache 2.0). +""" + +# pyright: reportUnknownMemberType=false, reportUnknownArgumentType=false, reportUnknownVariableType=false, reportAttributeAccessIssue=false, reportAny=false + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any + +import mlx.core as mx + + +@dataclass(frozen=True) +class CacheSnapshot: + """Frozen pair of (state, meta_state) tuples — one slot per cache entry. + + Trimmable entries (KV caches) carry `None` in both tuples; non-trimmable + entries (GatedDeltaNet `ArraysCache`) carry a deep-cloned snapshot. + """ + + states: tuple[Any, ...] + meta_states: tuple[Any, ...] + + +def _is_trimmable(entry: Any) -> bool: + """Mirror MTPLX's check: call `entry.is_trimmable()`, treat raise/missing as + non-trimmable. mlx-lm 0.31.3's `_BaseCache.is_trimmable` returns False; + KV-bearing subclasses override to True. `ArraysCache` (GDN recurrent + state) inherits the base False. + """ + method = getattr(entry, "is_trimmable", None) + if not callable(method): + return False + try: + return bool(method()) + except Exception: + return False + + +def _clone_tree(value: Any) -> Any: + """Recursive deep clone of containers + mx.array leaves. + + For mx.array leaves we materialize a fresh contiguous expression and + force evaluation so subsequent in-place writes by the cache cannot + mutate the snapshot through aliasing. Cost: one `mx.eval` per leaf. 
+ """ + if value is None: + return None + if isinstance(value, mx.array): + leaf = mx.contiguous(value) + mx.eval(leaf) + return leaf + if isinstance(value, tuple): + return tuple(_clone_tree(v) for v in value) + if isinstance(value, list): + return [_clone_tree(v) for v in value] + if isinstance(value, dict): + return {k: _clone_tree(v) for k, v in value.items()} + return value + + +def _clone_tree_lazy(value: Any) -> Any: + """Recursive clone expression without forcing leaf evaluation. + + MLX arrays are immutable values, and cache updates replace cache leaves + rather than mutating the old array storage in place. For speculative + snapshots this lets accept rounds discard the snapshot without paying the + full synchronization cost. + """ + if value is None: + return None + if isinstance(value, mx.array): + return mx.contiguous(value) + if isinstance(value, tuple): + return tuple(_clone_tree_lazy(v) for v in value) + if isinstance(value, list): + return [_clone_tree_lazy(v) for v in value] + if isinstance(value, dict): + return {k: _clone_tree_lazy(v) for k, v in value.items()} + return value + + +def snapshot_untrimmable_cache_lazy(cache: list[Any]) -> CacheSnapshot: + """Clone state + meta_state of every non-trimmable cache entry. + + Trimmable entries get ``None`` (their offset trim is sufficient + rollback). Call this BEFORE the batched verify forward each spec-decode + round. The clone is lazy: MLX arrays are immutable and cache updates + replace leaves rather than mutating storage in place, so an accept round + can discard the snapshot without paying the full synchronization cost. + Measured logit/GDN-exact for K=1 accept and reject transactions. 
+ """ + states: list[Any] = [] + meta_states: list[Any] = [] + for entry in cache: + if _is_trimmable(entry): + states.append(None) + meta_states.append(None) + else: + states.append(_clone_tree_lazy(getattr(entry, "state", None))) + meta_states.append(_clone_tree_lazy(getattr(entry, "meta_state", None))) + return CacheSnapshot(states=tuple(states), meta_states=tuple(meta_states)) + + +def _restore_state_preserving_container(entry: Any, state: Any) -> None: + """Write `state` back into `entry` without breaking container identity. + + `ArraysCache.state` is the same `list` the cache reads from internally. + Mutating it in place preserves identity; falling back to the property + setter is fine for entries that own one. + """ + cloned = _clone_tree(state) + if hasattr(entry, "replace_state"): + entry.replace_state(cloned) + return + current = getattr(entry, "state", None) + if ( + isinstance(current, list) + and isinstance(cloned, list) + and len(current) == len(cloned) + ): + current[:] = cloned + return + entry.state = cloned + + +def restore_cache(cache: list[Any], snapshot: CacheSnapshot) -> None: + """Restore non-trimmable cache state from a snapshot. Trimmable entries + untouched. Use directly when you've already done your own trim, or call + `rollback_after_verify` to combine trim + restore. + """ + for entry, state, meta_state in zip( + cache, snapshot.states, snapshot.meta_states, strict=True + ): + if state is not None: + _restore_state_preserving_container(entry, state) + if meta_state is not None: + entry.meta_state = _clone_tree(meta_state) + + +def rollback_after_verify( + cache: list[Any], + snapshot: CacheSnapshot, + verified_tokens: int, +) -> None: + """Undo a speculative target verify pass. + + `verified_tokens` is the count of tokens to TRIM from trimmable entries + (matches MTPLX's API and mlx-lm's `trim(n)` "remove n from end" + convention). 
For a K=1 verify of `[primary, draft]` advancing the + cache by 2: + accept: rollback_after_verify(cache, snap, verified_tokens=0) + — keep the full 2-step advance, just rewrite GDN state. + But the GDN state at pos+2 IS what we want post-accept, so + callers typically do NOT call rollback on accept; they re-run + the snapshot at the top of the next round. + reject: rollback_after_verify(cache, snap, verified_tokens=2) + — trim KV by 2, restore GDN state, then re-forward `[primary]` + alone to advance back to pos+1. + + Mirrors `mtplx/cache_state.py:rollback_after_verify`. + """ + if verified_tokens > 0: + for entry in cache: + if _is_trimmable(entry) and hasattr(entry, "trim"): + entry.trim(verified_tokens) + restore_cache(cache, snapshot) diff --git a/src/exo/worker/engines/mlx/utils_mlx.py b/src/exo/worker/engines/mlx/utils_mlx.py index 730abf64e3..6b458d0e05 100644 --- a/src/exo/worker/engines/mlx/utils_mlx.py +++ b/src/exo/worker/engines/mlx/utils_mlx.py @@ -27,7 +27,7 @@ from mlx_lm.models.deepseek_v3 import DeepseekV3Model from mlx_lm.tokenizer_utils import TokenizerWrapper -from exo.shared.models.model_cards import ModelId +from exo.shared.models.model_cards import ModelCard, ModelId from exo.worker.engines.mlx.constants import TRUST_REMOTE_CODE try: @@ -160,6 +160,45 @@ def initialize_mlx( return mlx_distributed_init(bound_instance) +def _native_mtp_loader_eligible(target_card: ModelCard, model_path: Path) -> bool: + """Decide whether ``load_mlx_items`` should dispatch the native MTP loader. + + Three conditions, all required: + + 1. The card declares ``native_mtp`` (the operator opted in). + 2. The on-disk checkpoint actually exposes recoverable MTP weights + (probe at :mod:`exo.worker.engines.mlx.mtp_probe` returns + ``is_recoverable=True``). MTPLX sidecar layout, original-HF + embedded layout, and oMLX embedded layout all qualify. + 3. 
(Implicit at the call site) The placement is single-node -- + ``load_mlx_items`` only calls this from the ``group is None`` + branch. Native MTP is structurally single-node-only: the verify + forward through a TP-sharded target would amortise K+1 tokens + over K+1 compute units, eating the MTP speedup. + + Returns ``False`` -- and the caller falls back to the stock loader -- + when any condition fails. The fallback is conservative: operators who + declared ``native_mtp`` on a card whose weights happen to be + unrecoverable get a quiet downgrade instead of a hard load failure. + """ + if target_card.native_mtp is None: + return False + # Local import to avoid pulling the probe (and safetensors) into the + # module-import path on workers that never load Qwen3.5/3.6. + from exo.worker.engines.mlx.mtp_probe import probe_mtp_weights + + probe = probe_mtp_weights(model_path) + if not probe.is_recoverable: + logger.warning( + f"Card {target_card.model_id} declares native_mtp but the " + f"on-disk checkpoint at {model_path} has no recoverable MTP " + f"weights (probe verdict: format={probe.mtp_format}, " + f"count={probe.mtp_count}). Falling back to stock loader." + ) + return False + return True + + def load_mlx_items( bound_instance: BoundInstance, group: mx.distributed.Group | None, @@ -170,9 +209,27 @@ def load_mlx_items( if group is None: logger.info(f"Single device used for {bound_instance.instance}") - model_path = build_model_path(bound_instance.bound_shard.model_card.model_id) + target_card = bound_instance.bound_shard.model_card + model_path = build_model_path(target_card.model_id) start_time = time.perf_counter() - model, _ = load_model(model_path, lazy=True, strict=False) + # Native MTP path: only when (a) the card declares it AND (b) the + # placement is single-node (this branch) AND (c) the on-disk + # checkpoint actually exposes recoverable MTP weights. 
The vendored + # loader builds an MTP-aware model that the generator dispatches a + # native draft+verify loop against. Otherwise fall back to the + # stock load path unchanged. + if _native_mtp_loader_eligible(target_card, model_path): + from exo.worker.engines.mlx.vendor.qwen3_5_mtp_loader import ( + load_mtp_model, + ) + + logger.info( + f"Loading {target_card.model_id} via native MTP loader " + f"(native_mtp=True, world_size=1, MTP weights recoverable)" + ) + model, _ = load_mtp_model(model_path, lazy=True, strict=True) + else: + model, _ = load_model(model_path, lazy=True, strict=False) # Eval layers one by one for progress reporting try: inner = get_inner_model(model) diff --git a/src/exo/worker/engines/mlx/vendor/qwen3_5_mtp.py b/src/exo/worker/engines/mlx/vendor/qwen3_5_mtp.py new file mode 100644 index 0000000000..c0bade044a --- /dev/null +++ b/src/exo/worker/engines/mlx/vendor/qwen3_5_mtp.py @@ -0,0 +1,753 @@ +"""Native MTP (Multi-Token Prediction) sidecar model class for Qwen3.6/Qwen3.5. + +This module provides a self-contained ``Model`` class that composes over +stock ``mlx_lm.models.qwen3_5`` and ``mlx_lm.models.qwen3_next`` with +first-class support for Qwen3.6's MTP head. It is intended to graduate +upstream as a new mlx-lm model class (in the same shape as the +``gemma4_assistant`` mlx-lm PR #1276), but lives here in exo's vendor +directory first. + +The design corrects seven concrete bugs in the prior upstream attempt +(``ml-explore/mlx-lm`` PR #1226 by chuaaron) that caused the original +work to be self-abandoned: + +1. **Strict weight loading.** Loader keeps ``strict=True`` for all + weights -- main and MTP -- so missing MTP shards raise instead of + silently initialising to random. +2. **MTP sidecar file.** Qwen3.6 ships MTP weights in a separate + ``mtp.safetensors`` file (not embedded in the main shards). The + loader probes that location explicitly. +3. 
**Per-weight norm-shift gate.** Stock ``Qwen3_5TextModel.sanitize`` + shifts RMSNorm weights by +1 unconditionally when it sees MTP keys. + This double-shifts already-shifted norms. Our sanitize gates the + shift on ``value.mean() < 0.5`` per weight, matching MTPLX's + ``_finalize_mtp_weights``. +4. **Post-norm hidden variant.** MTP was trained to ingest POST-norm + hidden states (i.e. ``self.model.norm(hidden)``), not the pre-norm + variant PR #1226 fed in. The ``pre_fc_norm_hidden`` RMSNorm inside + the MTP module assumes post-norm input. +5. **MTP-specific quantization policies.** Cyankiwi prequantized + shards quantize only the attention/MLP linears inside MTP; ``fc`` + and the norms stay in BF16. We detect the policy from the key set + and apply it. "all" prequantized shards also quantize ``fc``. +6. **MTP KV cache priming.** ``make_mtp_cache`` returns the empty + cache shape; ``mtp_update_cache`` walks prefill tokens through the + MTP layer with logits suppressed, populating K/V so the draft loop + inherits prefill context. +7. **Tests check numerical correctness.** The companion test module + includes both synthetic structural tests (no real download) and an + opt-in parity test against the real MTPLX artifact that asserts + top-1 agreement >= 60% against the main lm_head -- this catches + silent random-init regressions that shape-only tests would miss. + +The MTPLX runtime injection in ``mtplx/mtp_patch.py`` is the reference +behaviour we mirror; that file is Apache 2.0 and we re-implement here +in idiomatic mlx-lm class form rather than monkey-patching at runtime. 
+""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Optional, Union + +import mlx.core as mx +import mlx.nn as nn +from mlx.nn.layers.distributed import shard_inplace, shard_linear +from mlx.utils import tree_map +from mlx_lm.models.base import BaseModelArgs, create_attention_mask, create_ssm_mask +from mlx_lm.models.cache import ArraysCache, KVCache +from mlx_lm.models.qwen3_5 import DecoderLayer as StockDecoderLayer +from mlx_lm.models.qwen3_5 import TextModelArgs as StockTextModelArgs +from mlx_lm.models.qwen3_next import Qwen3NextMLP as MLP # noqa: N814 + +# pyright: reportAny=false, reportUnknownMemberType=false, reportUnknownArgumentType=false, reportUnknownVariableType=false, reportUnknownParameterType=false, reportMissingParameterType=false, reportPrivateUsage=false, reportIncompatibleMethodOverride=false, reportArgumentType=false, reportOptionalMemberAccess=false +# This module composes over untyped mlx-lm and mlx.nn modules; their +# attribute surface is dynamic and the type-checker can't see through +# nn.Module subclassing in the way mlx-lm uses it. + + +# --------------------------------------------------------------------------- +# Exceptions +# --------------------------------------------------------------------------- + + +class MTPWeightsNotFound(RuntimeError): # noqa: N818 - public API name + """Raised when ``mtp_num_hidden_layers > 0`` but no MTP weights are loadable. + + Carries the candidate filenames probed so the operator can diagnose + whether the sidecar wasn't downloaded, mounted, or named as expected. + """ + + def __init__(self, message: str, *, candidates: tuple[str, ...] = ()) -> None: + super().__init__(message) + self.candidates = candidates + + +# --------------------------------------------------------------------------- +# Model args +# --------------------------------------------------------------------------- + + +_RMSNORM_SUFFIXES: tuple[str, ...] 
= ( + "input_layernorm.weight", + "post_attention_layernorm.weight", + "q_norm.weight", + "k_norm.weight", + "pre_fc_norm_hidden.weight", + "pre_fc_norm_embedding.weight", + "norm.weight", +) + + +@dataclass +class TextModelArgs(StockTextModelArgs): + """Stock ``Qwen3_5`` text args extended with MTP fields. + + Defaults match the cyankiwi-prequantized MTPLX contract: + - ``mtp_num_hidden_layers``: 0 (no MTP unless config declares one) + - ``mtp_hidden_variant``: 'post_norm' (THE critical PR #1226 fix) + - ``mtp_concat_order``: 'embedding_hidden' (embedding first) + """ + + mtp_num_hidden_layers: int = 0 + mtp_hidden_variant: str = "post_norm" + mtp_concat_order: str = "embedding_hidden" + + +@dataclass +class ModelArgs(BaseModelArgs): + model_type: str + text_config: Dict[str, Any] + + @classmethod + def from_dict(cls, params: Dict[str, Any]) -> "ModelArgs": + if "text_config" not in params: + return cls(model_type=params["model_type"], text_config=params) + return super().from_dict(params) + + +# --------------------------------------------------------------------------- +# MTP module +# --------------------------------------------------------------------------- + + +class MTPModule(nn.Module): + """The MTP head: ``[pre_fc_norm_hidden, pre_fc_norm_embedding, fc, layers, norm]``. + + ``layers`` is a list of stock ``qwen3_5.DecoderLayer`` instances + constructed with ``layer_idx = full_attention_interval - 1`` so each + layer takes the full-attention (not linear-attention) branch + (``is_linear = False``). This matches MTPLX, which uses the + full-attention layer for MTP regardless of the main-model layer + cadence. 
+ """ + + def __init__(self, args: TextModelArgs, n_layers: int) -> None: + super().__init__() + fa_layer_idx = args.full_attention_interval - 1 + self.pre_fc_norm_hidden = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) + self.pre_fc_norm_embedding = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) + self.fc = nn.Linear(args.hidden_size * 2, args.hidden_size, bias=False) + self.layers = [ + StockDecoderLayer(args, layer_idx=fa_layer_idx) for _ in range(n_layers) + ] + self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) + + +# --------------------------------------------------------------------------- +# Inner text model with MTP attached +# --------------------------------------------------------------------------- + + +class Qwen3_5MTPInner(nn.Module): # noqa: N801 - mirrors mlx-lm's Qwen3_5TextModel naming + """Mirrors stock ``Qwen3_5TextModel`` plus an attached ``self.mtp``. + + ``__call__`` accepts a ``return_hidden`` kwarg; when set it also returns + the pre-norm hidden states, for cooperative use with the MTP draft loop. 
+ """ + + def __init__(self, args: TextModelArgs) -> None: + super().__init__() + self.args = args + self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size) + self.layers = [ + StockDecoderLayer(args=args, layer_idx=i) + for i in range(args.num_hidden_layers) + ] + self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) + self.ssm_idx = 0 + self.fa_idx = args.full_attention_interval - 1 + self.mtp: Optional[MTPModule] = None + if args.mtp_num_hidden_layers > 0: + self.mtp = MTPModule(args, args.mtp_num_hidden_layers) + + def __call__( + self, + inputs: mx.array, + cache: Optional[List[Any]] = None, + input_embeddings: Optional[mx.array] = None, + return_hidden: bool = False, + ) -> Union[mx.array, tuple[mx.array, mx.array]]: + hidden_states = ( + input_embeddings + if input_embeddings is not None + else self.embed_tokens(inputs) + ) + if cache is None: + cache = [None] * len(self.layers) + fa_mask = create_attention_mask(hidden_states, cache[self.fa_idx]) + ssm_mask = create_ssm_mask(hidden_states, cache[self.ssm_idx]) + for layer, layer_cache in zip(self.layers, cache, strict=True): + mask = ssm_mask if layer.is_linear else fa_mask + hidden_states = layer(hidden_states, mask=mask, cache=layer_cache) + post = self.norm(hidden_states) + if return_hidden: + return post, hidden_states + return post + + +# --------------------------------------------------------------------------- +# TextModel wrapper (lm_head + cache builders + sanitize) +# --------------------------------------------------------------------------- + + +class TextModel(nn.Module): + """Wraps ``Qwen3_5MTPInner`` + ``lm_head``; mirrors stock ``TextModel`` API.""" + + def __init__(self, args: TextModelArgs) -> None: + super().__init__() + self.args = args + self.model_type = args.model_type + self.model = Qwen3_5MTPInner(args) + if not args.tie_word_embeddings: + self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False) + # Storage for an optional sidecar weight loader 
supplied by the + # outer loader (qwen3_5_mtp_loader). Sanitize uses it to pick up + # the separate mtp.safetensors file when no MTP keys appear in + # the main shards. Type: Callable[[], dict[str, mx.array]] | None. + self._mtp_sidecar_loader: Optional[Callable[[], Dict[str, mx.array]]] = None + + def __call__( + self, + inputs: mx.array, + cache: Optional[List[Any]] = None, + input_embeddings: Optional[mx.array] = None, + return_hidden: bool = False, + ) -> Union[mx.array, tuple[mx.array, mx.array]]: + result = self.model( + inputs, + cache, + input_embeddings=input_embeddings, + return_hidden=return_hidden, + ) + if return_hidden: + assert isinstance(result, tuple) + post, _pre = result + else: + assert isinstance(result, mx.array) + post = result + if self.args.tie_word_embeddings: + logits = self.model.embed_tokens.as_linear(post) + else: + logits = self.lm_head(post) + if return_hidden: + return logits, post + return logits + + @property + def layers(self) -> List[StockDecoderLayer]: + return self.model.layers + + def make_cache(self) -> List[Any]: + return [ + ArraysCache(size=2) if layer.is_linear else KVCache() + for layer in self.layers + ] + + def make_mtp_cache(self) -> List[Any]: + if self.model.mtp is None: + return [] + return [KVCache() for _ in self.model.mtp.layers] + + # -------------------- sanitize -------------------- + + def sanitize(self, weights: Dict[str, mx.array]) -> Dict[str, mx.array]: + """Sanitize stock+MTP weights. + + The incoming weight dict is keyed as it appears in safetensors: + for Qwen3.6-27B that's the ``language_model.model.*`` / + ``language_model.lm_head.*`` form. (The outer + :meth:`Model.sanitize` already normalized any other-prefix + variants into this shape.) + + Behavior: + 1. The MTP submodule's parameters live at + ``language_model.model.mtp.*``. Any incoming key in the + ``language_model.mtp.*`` or bare ``mtp.*`` (or + ``language_model.model.mtp.*``) namespaces is normalized + into the canonical model-tree key. 
+ 2. If ``mtp_num_hidden_layers > 0`` and no MTP keys are + present, the sidecar loader (set by the outer loader) is + invoked. If that still produces no MTP keys, + :class:`MTPWeightsNotFound` is raised. + 3. Per-weight norm-shift gate: ``v + 1.0`` only when ``v.ndim + == 1`` and ``v.mean() < 0.5`` and the suffix matches a + RMSNorm weight. This prevents double-shifting weights that + are already in the post-shift form. + 4. conv1d axis fix (matches stock). + 5. ``lm_head.weight`` is dropped if ``tie_word_embeddings``. + """ + # ----- 1. namespace normalization ----- + normalized: Dict[str, mx.array] = {} + canonical_mtp_prefix = "language_model.model.mtp." + for key, value in weights.items(): + if key.startswith("language_model.model.mtp."): + normalized[key] = value + elif key.startswith("language_model.mtp."): + normalized[canonical_mtp_prefix + key[len("language_model.mtp.") :]] = ( + value + ) + elif key.startswith("mtp."): + normalized[canonical_mtp_prefix + key[len("mtp.") :]] = value + else: + normalized[key] = value + + embedded_mtp_keys = [ + k for k in normalized if k.startswith(canonical_mtp_prefix) + ] + wants_mtp = self.args.mtp_num_hidden_layers > 0 + + # ----- 2. sidecar fallback ----- + if wants_mtp and not embedded_mtp_keys: + loader = self._mtp_sidecar_loader + if loader is not None: + extra = loader() + if extra: + for k, v in extra.items(): + # Sidecar keys come in bare ``mtp.*`` form. + if k.startswith("mtp."): + normalized[canonical_mtp_prefix + k[len("mtp.") :]] = v + elif k.startswith(canonical_mtp_prefix): + normalized[k] = v + else: + # Unexpected; let load_weights complain. + normalized[k] = v + embedded_mtp_keys = [ + k for k in normalized if k.startswith(canonical_mtp_prefix) + ] + + if wants_mtp and not embedded_mtp_keys: + raise MTPWeightsNotFound( + "MTP declared in config (mtp_num_hidden_layers=" + f"{self.args.mtp_num_hidden_layers}) but no MTP weights were " + "found in the main shards and no sidecar loader produced any. 
" + "Candidate sidecar files (relative to model dir): " + "mtp.safetensors, mtp/weights.safetensors, model-mtp.safetensors, " + "or config.mlx_lm_extra_tensors.mtp_file", + candidates=( + "mtp.safetensors", + "mtp/weights.safetensors", + "model-mtp.safetensors", + ), + ) + + # ----- 3-5. tie/conv/norm fixes ----- + if self.args.tie_word_embeddings: + normalized.pop("language_model.lm_head.weight", None) + normalized.pop("lm_head.weight", None) + + for k, v in list(normalized.items()): + if "conv1d.weight" in k and v.shape[-1] != 1: + normalized[k] = v.moveaxis(2, 1) + if v.ndim == 1 and any(k.endswith(sfx) for sfx in _RMSNORM_SUFFIXES): + mean = float(v.astype(mx.float32).mean().item()) + if mean < 0.5: + normalized[k] = v + 1.0 + + return normalized + + # -------------------- quant predicates -------------------- + + @property + def quant_predicate(self) -> Optional[Callable[[str, nn.Module], Any]]: + """Main-model quant predicate (delegates to MoE-style if used). + + For Qwen3.6-27B (no MoE), this returns None and per-layer quant + comes from the ``quantization`` dict in config.json. + """ + if self.args.num_experts <= 0: + return None + + def predicate(path: str, _: nn.Module) -> Any: + if path.endswith(("mlp.gate", "shared_expert_gate")): + return {"group_size": 64, "bits": 8} + return True + + return predicate + + @property + def cast_predicate(self) -> Callable[[str], bool]: + def predicate(path: str) -> bool: + return not path.endswith("A_log") + + return predicate + + +# --------------------------------------------------------------------------- +# Outer Model +# --------------------------------------------------------------------------- + + +def _classify_mtp_key_set(keys: tuple[str, ...]) -> str: + """Return one of ``'unquantized'``, ``'cyankiwi'``, ``'all'``. 
+ + The classification is exactly mirrored from MTPLX constants.py: + - 'cyankiwi': attention + mlp quantized; ``fc`` + norms BF16 + - 'all': everything quantized including ``fc`` + - 'unquantized': bf16 throughout (only ``.weight`` keys, no + ``.scales`` / ``.biases``) + """ + has_fc_scales = any(k == "mtp.fc.scales" for k in keys) + has_attn_scales = any(k.endswith(".self_attn.q_proj.scales") for k in keys) + if has_fc_scales: + return "all" + if has_attn_scales: + return "cyankiwi" + return "unquantized" + + +def _quantize_mtp_module( + mtp: MTPModule, + *, + policy: str, + bits: int, + group_size: int, + mode: str = "affine", + quant_overrides: Optional[Dict[str, Dict[str, Any]]] = None, +) -> None: + """Apply MTP-specific quantization according to ``policy``. + + - ``'all'``: quantize every quantizable module in the MTP subtree. + - ``'cyankiwi'``: quantize attention/MLP linears only; ``fc``, + ``pre_fc_norm_*``, ``norm`` stay unquantized. + - ``'unquantized'``: no-op. + """ + if policy == "unquantized": + return + if policy == "all": + nn.quantize(mtp, group_size=group_size, bits=bits, mode=mode) + return + if policy != "cyankiwi": + raise ValueError(f"Unsupported MTP quantization policy: {policy!r}") + + def predicate(path: str, module: nn.Module) -> Any: + if path == "fc" or path.startswith("pre_fc_norm") or path == "norm": + return False + if path.endswith("mlp.gate"): + return False + if path.startswith("layers.") and hasattr(module, "to_quantized"): + override = (quant_overrides or {}).get(f"language_model.mtp.{path}") + if override is not None: + return { + "group_size": int(override.get("group_size", group_size)), + "bits": int(override.get("bits", bits)), + "mode": str(override.get("mode", mode)), + } + return {"group_size": group_size, "bits": bits, "mode": mode} + return False + + nn.quantize(mtp, class_predicate=predicate) + + +class Model(nn.Module): + """Outer model class. 
Exposes ``language_model: TextModel`` + ``mtp_forward``.""" + + def __init__(self, args: ModelArgs) -> None: + super().__init__() + self.args = args + self.model_type = args.model_type + self.language_model = TextModel(TextModelArgs.from_dict(args.text_config)) + + # -------------------- forward -------------------- + + def __call__( + self, + inputs: mx.array, + cache: Optional[List[Any]] = None, + input_embeddings: Optional[mx.array] = None, + return_hidden: bool = False, + ) -> Union[mx.array, tuple[mx.array, mx.array]]: + return self.language_model( + inputs, + cache=cache, + input_embeddings=input_embeddings, + return_hidden=return_hidden, + ) + + # -------------------- MTP forward -------------------- + + def _mtp_core( + self, + hidden_states: mx.array, + next_token_ids: mx.array, + *, + mtp_cache: Optional[List[Any]], + concat_order: Optional[str], + mtp_hidden_variant: str, + emit_logits: bool, + ) -> tuple[Optional[mx.array], mx.array]: + text_model = self.language_model + inner = text_model.model + mtp = inner.mtp + if mtp is None: + raise RuntimeError("Model was constructed without an MTP head") + + input_embeds = inner.embed_tokens(next_token_ids) + e = mtp.pre_fc_norm_embedding(input_embeds) + h = mtp.pre_fc_norm_hidden(hidden_states) + order = concat_order or text_model.args.mtp_concat_order + parts = [e, h] if order == "embedding_hidden" else [h, e] + x = mtp.fc(mx.concatenate(parts, axis=-1)) + layer_cache = mtp_cache[0] if mtp_cache else None + mask = create_attention_mask(x, layer_cache) + x = mtp.layers[0](x, mask=mask, cache=layer_cache) + post_norm = mtp.norm(x) + if mtp_hidden_variant == "post_norm": + hidden = post_norm + elif mtp_hidden_variant == "pre_norm": + hidden = x + else: + raise ValueError( + f"Unsupported mtp_hidden_variant={mtp_hidden_variant!r}; " + "use 'post_norm' or 'pre_norm'" + ) + if not emit_logits: + return None, hidden + if text_model.args.tie_word_embeddings: + logits = inner.embed_tokens.as_linear(post_norm) + else: + 
logits = text_model.lm_head(post_norm) + return logits, hidden + + def mtp_forward( + self, + hidden_states: mx.array, + next_token_ids: mx.array, + *, + mtp_cache: Optional[List[Any]] = None, + concat_order: Optional[str] = None, + mtp_hidden_variant: str = "post_norm", + emit_logits: bool = True, + return_hidden: bool = False, + ) -> Union[mx.array, tuple[mx.array, mx.array]]: + logits, hidden = self._mtp_core( + hidden_states, + next_token_ids, + mtp_cache=mtp_cache, + concat_order=concat_order, + mtp_hidden_variant=mtp_hidden_variant, + emit_logits=emit_logits, + ) + if not emit_logits: + # caller asked for cache-update only + return hidden + assert logits is not None + if return_hidden: + return logits, hidden + return logits + + def mtp_update_cache( + self, + hidden_states: mx.array, + next_token_ids: mx.array, + *, + mtp_cache: Optional[List[Any]] = None, + concat_order: Optional[str] = None, + ) -> mx.array: + """Populate MTP KV cache without emitting logits. + + Used during prefill to seed the MTP cache with K/V derived from + the prompt history, so subsequent draft calls see prefill + context. Returns the post-norm hidden the MTP head produced (for + chained ``mtp_update_cache`` calls if you want them). + """ + _, hidden = self._mtp_core( + hidden_states, + next_token_ids, + mtp_cache=mtp_cache, + concat_order=concat_order, + mtp_hidden_variant="post_norm", + emit_logits=False, + ) + return hidden + + # -------------------- cache builders -------------------- + + def make_cache(self) -> List[Any]: + return self.language_model.make_cache() + + def make_mtp_cache(self) -> List[Any]: + return self.language_model.make_mtp_cache() + + @property + def layers(self) -> List[StockDecoderLayer]: + return self.language_model.model.layers + + # -------------------- sanitize -------------------- + + def sanitize(self, weights: Dict[str, mx.array]) -> Dict[str, mx.array]: + """Outer sanitize: normalize prefixes, then delegate to TextModel. 
+ + Stock ``qwen3_5.Model.sanitize`` returns weights with the + ``language_model.`` prefix intact so the outer ``load_weights`` + can route them to ``self.language_model.*``. We keep that + convention. MTP keys are normalized to ``language_model.mtp.*`` + before delegation so :meth:`TextModel.sanitize` -- which recognises + that prefix -- can remap them to ``language_model.model.mtp.*``. + """ + sanitized: Dict[str, mx.array] = {} + for key, value in weights.items(): + if key.startswith(("vision_tower", "model.visual")): + continue + if key.startswith("model.language_model"): + key = key.replace("model.language_model", "language_model.model") + elif key.startswith("language_model."): + pass + else: + key = "language_model." + key + sanitized[key] = value + # Hand off to TextModel.sanitize, which sees keys WITH the + # "language_model." prefix and returns them WITH the prefix + # preserved (except for MTP keys, which it normalizes). + return self.language_model.sanitize(sanitized) + + # -------------------- quantize hook -------------------- + + @property + def quant_predicate(self) -> Optional[Callable[[str, nn.Module], Any]]: + return self.language_model.quant_predicate + + @property + def cast_predicate(self) -> Callable[[str], bool]: + return self.language_model.cast_predicate + + # -------------------- shard (DELIBERATELY replicates MTP) -------------------- + + def shard(self, group: Optional[Any] = None) -> None: + """Stock-style shard for main layers; MTP is REPLICATED, not sharded. + + Replicating the MTP module across all nodes keeps the draft loop + independent of inter-node communication latency, which is the + whole point of speculative decoding. 
+ """ + group = group or mx.distributed.init() + N = group.size() # noqa: N806 - stock qwen3_5.Model.shard uses N + rank = group.rank() + + def conv_sharding( + key_dim: int, + ) -> Callable[[str, mx.array], tuple[int, list[int]]]: + return lambda p, w: (0, [key_dim, 2 * key_dim]) + + def repeat_kv_layer_inplace(layer: nn.Module, h: int) -> None: + if h >= N: + return + + def _repeat(p: mx.array) -> mx.array: + s = p.shape + p = p.reshape(h, s[0] // h, *s[1:]) + p = mx.repeat(p, N // h, axis=0) + p = p.reshape(-1, *s[1:]) + return p + + layer.update(tree_map(_repeat, layer.parameters())) + + for layer in self.layers: + if layer.is_linear: + kd = layer.linear_attn.key_dim + layer.linear_attn.sharding_group = group + shard_inplace(layer.linear_attn.conv1d, conv_sharding(kd), group=group) + layer.linear_attn.conv1d.groups //= N + shard_inplace( + layer.linear_attn.in_proj_qkv, + "all-to-sharded", + segments=[kd, 2 * kd], + group=group, + ) + shard_inplace( + layer.linear_attn.in_proj_z, "all-to-sharded", group=group + ) + shard_inplace( + layer.linear_attn.in_proj_b, "all-to-sharded", group=group + ) + shard_inplace( + layer.linear_attn.in_proj_a, "all-to-sharded", group=group + ) + layer.linear_attn.dt_bias = mx.contiguous( + mx.split(layer.linear_attn.dt_bias, N)[rank] + ) + layer.linear_attn.A_log = mx.contiguous( + mx.split(layer.linear_attn.A_log, N)[rank] + ) + shard_inplace(layer.linear_attn.out_proj, "sharded-to-all", group=group) + layer.linear_attn.num_k_heads //= N + layer.linear_attn.num_v_heads //= N + layer.linear_attn.key_dim //= N + layer.linear_attn.value_dim //= N + layer.linear_attn.conv_dim //= N + else: + layer.self_attn.o_proj = shard_linear( + layer.self_attn.o_proj, "sharded-to-all", group=group + ) + layer.self_attn.q_proj = shard_linear( + layer.self_attn.q_proj, "all-to-sharded", group=group + ) + repeat_kv_layer_inplace( + layer.self_attn.k_proj, layer.self_attn.num_key_value_heads + ) + repeat_kv_layer_inplace( + layer.self_attn.v_proj, 
layer.self_attn.num_key_value_heads + ) + layer.self_attn.k_proj = shard_linear( + layer.self_attn.k_proj, "all-to-sharded", group=group + ) + layer.self_attn.v_proj = shard_linear( + layer.self_attn.v_proj, "all-to-sharded", group=group + ) + layer.self_attn.num_attention_heads //= N + layer.self_attn.num_key_value_heads = max( + 1, layer.self_attn.num_key_value_heads // N + ) + + if isinstance(layer.mlp, MLP): + layer.mlp.gate_proj = shard_linear( + layer.mlp.gate_proj, "all-to-sharded", group=group + ) + layer.mlp.down_proj = shard_linear( + layer.mlp.down_proj, "sharded-to-all", group=group + ) + layer.mlp.up_proj = shard_linear( + layer.mlp.up_proj, "all-to-sharded", group=group + ) + + # MTP is intentionally NOT sharded -- it's replicated on every + # node so the draft loop is fully local. See class docstring. + + +# --------------------------------------------------------------------------- +# Public helpers (used by the loader and tests) +# --------------------------------------------------------------------------- + + +__all__ = [ + "MTPWeightsNotFound", + "TextModelArgs", + "ModelArgs", + "MTPModule", + "Qwen3_5MTPInner", + "TextModel", + "Model", + "_classify_mtp_key_set", + "_quantize_mtp_module", + "_RMSNORM_SUFFIXES", +] diff --git a/src/exo/worker/engines/mlx/vendor/qwen3_5_mtp_loader.py b/src/exo/worker/engines/mlx/vendor/qwen3_5_mtp_loader.py new file mode 100644 index 0000000000..9a7d7fbcc9 --- /dev/null +++ b/src/exo/worker/engines/mlx/vendor/qwen3_5_mtp_loader.py @@ -0,0 +1,303 @@ +"""Loader that swaps in the vendored Qwen3.5/3.6 MTP-aware model class. + +If the model declares MTP (``text_config.mtp_num_hidden_layers > 0``) +and a usable weight source exists (either a separate ``mtp.safetensors`` +sidecar, or ``mtp.*`` keys in the main shards), the loader dispatches +``model_type='qwen3_5'`` / ``'qwen3_5_moe'`` to +:class:`vendor.qwen3_5_mtp.Model` and threads a sidecar-weights callable +into the model so ``sanitize`` can pick it up. 
+ +If MTP is not declared, or no weights are available, the loader falls +through to stock ``mlx_lm.utils.load_model`` (which produces stock +``mlx_lm.models.qwen3_5.Model``). + +Strict-load is the only mode supported -- random init of MTP weights +would silently regress to the PR #1226 failure mode. +""" + +from __future__ import annotations + +import json +from pathlib import Path +from typing import Any, Callable, Dict, Optional, Tuple, Type + +import mlx.core as mx +import mlx.nn as nn +from mlx_lm import utils as _mlx_lm_utils +from mlx_lm.utils import load_model + +from .qwen3_5_mtp import ( + Model as MtpModel, +) +from .qwen3_5_mtp import ( + ModelArgs as MtpModelArgs, +) +from .qwen3_5_mtp import ( + MTPWeightsNotFound, + _classify_mtp_key_set, + _quantize_mtp_module, +) + +_get_classes: Callable[..., Tuple[Type[nn.Module], Type[Any]]] = ( + _mlx_lm_utils._get_classes # pyright: ignore[reportAttributeAccessIssue] +) +_SUPPORTED_NATIVE_MTP_MODEL_TYPES = frozenset({"qwen3_5", "qwen3_5_moe"}) + + +# pyright: reportAny=false, reportUnknownMemberType=false, reportUnknownArgumentType=false, reportUnknownVariableType=false, reportUnknownParameterType=false, reportMissingParameterType=false, reportPrivateUsage=false, reportAttributeAccessIssue=false, reportUnnecessaryTypeIgnoreComment=false +# Loader operates on untyped safetensors / JSON metadata and patches +# mlx-lm classes whose surface is dynamic. + + +# --------------------------------------------------------------------------- +# Candidate sidecar locations +# --------------------------------------------------------------------------- + + +_DEFAULT_SIDECAR_CANDIDATES: tuple[str, ...] 
= ( + "mtp.safetensors", + "mtp/weights.safetensors", + "model-mtp.safetensors", +) + + +def _config_sidecar_filename(config: Dict[str, Any]) -> Optional[str]: + extra = config.get("mlx_lm_extra_tensors") or {} + if isinstance(extra, dict): + name = extra.get("mtp_file") + if isinstance(name, str) and name: + return name + return None + + +def _resolve_sidecar_path(model_path: Path, config: Dict[str, Any]) -> Optional[Path]: + """Return the first existing sidecar path or ``None``.""" + configured = _config_sidecar_filename(config) + if configured is not None: + candidate = model_path / configured + if candidate.exists(): + return candidate + for rel in _DEFAULT_SIDECAR_CANDIDATES: + candidate = model_path / rel + if candidate.exists(): + return candidate + return None + + +def _embedded_mtp_keys(model_path: Path) -> Tuple[str, ...]: + """Probe ``model.safetensors.index.json`` for embedded MTP keys. + + Embedded layouts use either the ``mtp.*`` prefix (original HF + format) or ``language_model.mtp.*`` (oMLX). We treat both as + embedded and let the sanitizer normalize the namespace. + """ + index = model_path / "model.safetensors.index.json" + if not index.exists(): + return () + try: + payload = json.loads(index.read_text(encoding="utf-8")) + except (json.JSONDecodeError, OSError): + return () + weight_map = payload.get("weight_map") + if not isinstance(weight_map, dict): + return () + return tuple( + sorted( + str(k) + for k in weight_map + if str(k).startswith("mtp.") or str(k).startswith("language_model.mtp.") + ) + ) + + +def _sidecar_keys(sidecar: Path) -> Tuple[str, ...]: + try: + from safetensors import safe_open # type: ignore[import-not-found] + except ImportError: + return () + try: + with safe_open(str(sidecar), framework="numpy") as handle: # type: ignore[no-untyped-call] + # safe_open is not iterable; .keys() is the supported API. 
+ return tuple(sorted(str(k) for k in handle.keys())) # noqa: SIM118 + except Exception: + return () + + +# --------------------------------------------------------------------------- +# Weight finalization +# --------------------------------------------------------------------------- + + +_RMSNORM_SUFFIXES: tuple[str, ...] = ( + "input_layernorm.weight", + "post_attention_layernorm.weight", + "q_norm.weight", + "k_norm.weight", + "pre_fc_norm_hidden.weight", + "pre_fc_norm_embedding.weight", + "norm.weight", +) + + +def _strip_mtp_prefix(key: str) -> str: + if key.startswith("language_model.mtp."): + return "mtp." + key[len("language_model.mtp.") :] + return key + + +def _load_sidecar_weights(sidecar: Path) -> Dict[str, mx.array]: + """Load all MTP weights from the sidecar with the ``mtp.`` prefix kept.""" + raw = mx.load(str(sidecar)) + out: Dict[str, mx.array] = {} + for key, value in raw.items(): + normalized = _strip_mtp_prefix(str(key)) + if normalized.startswith("mtp."): + out[normalized] = value + return out + + +# --------------------------------------------------------------------------- +# Public loader entry point +# --------------------------------------------------------------------------- + + +def _model_declares_mtp(config: Dict[str, Any]) -> bool: + tcfg = config.get("text_config", config) + return int(tcfg.get("mtp_num_hidden_layers", 0) or 0) > 0 + + +def load_mtp_model( + model_path: Path, + *, + lazy: bool = False, + strict: bool = True, +) -> Tuple[nn.Module, Dict[str, Any]]: + """Load a Qwen3.5/3.6 model, attaching native MTP support if available. + + Returns the loaded model and the (possibly updated) config dict, + matching ``mlx_lm.utils.load_model``'s signature. + + Raises :class:`MTPWeightsNotFound` if MTP is declared but no weight + source can be found. 
+ """ + if not strict: + raise ValueError( + "load_mtp_model only supports strict=True; non-strict loading " + "would silently allow random-initialized MTP weights, " + "reproducing the PR #1226 failure mode." + ) + + with open(model_path / "config.json", encoding="utf-8") as f: + config = json.load(f) + + if config.get( + "model_type" + ) not in _SUPPORTED_NATIVE_MTP_MODEL_TYPES or not _model_declares_mtp(config): + # No MTP declared (or wrong model type): fall through to stock. + return load_model(model_path, lazy=lazy, strict=strict) + + sidecar = _resolve_sidecar_path(model_path, config) + embedded_keys = _embedded_mtp_keys(model_path) + has_embedded_mtp = bool(embedded_keys) + + if sidecar is None and not has_embedded_mtp: + candidates_msg = ", ".join( + (str(_config_sidecar_filename(config) or ""),) + _DEFAULT_SIDECAR_CANDIDATES + ) + raise MTPWeightsNotFound( + "Qwen3.5/3.6 model declares MTP but no MTP weights are present in " + f"{model_path}. Probed: {candidates_msg} and " + "model.safetensors.index.json for embedded mtp.* / " + "language_model.mtp.* keys.", + candidates=(str(sidecar) if sidecar else "",) + _DEFAULT_SIDECAR_CANDIDATES, + ) + + # Decide quantization policy from the available key set (sidecar + # takes precedence; if absent, use embedded keys). + keys_for_policy: tuple[str, ...] = () + if sidecar is not None: + keys_for_policy = _sidecar_keys(sidecar) + if not keys_for_policy and embedded_keys: + keys_for_policy = tuple(_strip_mtp_prefix(k) for k in embedded_keys) + policy = _classify_mtp_key_set(keys_for_policy) + + # Pull bits/group_size from explicit MTP quant config if present, + # otherwise fall back to the main quantization block. 
+ mtp_quant = config.get("mtplx_mtp_quantization") or {} + quant_overrides: Dict[str, Dict[str, Any]] = {} + if isinstance(mtp_quant, dict) and mtp_quant.get("prequantized"): + bits = int(mtp_quant.get("bits", 8)) + group_size = int(mtp_quant.get("group_size", 64)) + mode = str(mtp_quant.get("mode", "affine")) + else: + main_quant = ( + config.get("quantization") or config.get("quantization_config") or {} + ) + bits = int(main_quant.get("bits", 4)) + group_size = int(main_quant.get("group_size", 64)) + mode = str(main_quant.get("mode", "affine")) + if isinstance(main_quant, dict): + quant_overrides = { + str(key): value + for key, value in main_quant.items() + if isinstance(key, str) + and key.startswith("language_model.mtp.") + and isinstance(value, dict) + } + + # Build a sidecar-loader callable; the TextModel.sanitize calls it + # when it can't find embedded MTP keys. + sidecar_loader: Optional[Callable[[], Dict[str, mx.array]]] = None + if sidecar is not None: + captured = sidecar # avoid late-binding + + def _loader() -> Dict[str, mx.array]: + return _load_sidecar_weights(captured) + + sidecar_loader = _loader + + # Custom get_model_classes that returns our MTP-aware classes for + # qwen3_5 and arms the model instance with the sidecar loader and + # the MTP quant policy/params before sanitize runs. + def get_classes(config: Dict[str, Any]) -> Tuple[Type[nn.Module], Type[Any]]: + cfg = config + if cfg.get("model_type") not in _SUPPORTED_NATIVE_MTP_MODEL_TYPES: + return _get_classes(cfg) + + original_init = MtpModel.__init__ + + def patched_init(self_: MtpModel, args: MtpModelArgs) -> None: + original_init(self_, args) + self_.language_model._mtp_sidecar_loader = sidecar_loader + # Quantize MTP module per the detected policy BEFORE + # load_weights runs in mlx_lm.utils.load_model. The main + # model's per-layer quantization dict in config["quantization"] + # does not cover the MTP submodule (no entries with + # 'mtp.*' paths), so we have to do it ourselves. 
+ mtp_submodule = self_.language_model.model.mtp + if mtp_submodule is not None and policy != "unquantized": + _quantize_mtp_module( + mtp_submodule, + policy=policy, + bits=bits, + group_size=group_size, + mode=mode, + quant_overrides=quant_overrides, + ) + + # Use a tiny subclass so we don't permanently mutate MtpModel. + class _PatchedMtpModel(MtpModel): + __init__ = patched_init # type: ignore[assignment] + + return _PatchedMtpModel, MtpModelArgs + + model, updated_config = load_model( + model_path, + lazy=lazy, + strict=True, + get_model_classes=get_classes, + ) + return model, updated_config + + +__all__ = ["load_mtp_model"] diff --git a/src/exo/worker/engines/mlx/vendor/tests/__init__.py b/src/exo/worker/engines/mlx/vendor/tests/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/exo/worker/engines/mlx/vendor/tests/test_qwen3_5_mtp.py b/src/exo/worker/engines/mlx/vendor/tests/test_qwen3_5_mtp.py new file mode 100644 index 0000000000..e48c39e3f4 --- /dev/null +++ b/src/exo/worker/engines/mlx/vendor/tests/test_qwen3_5_mtp.py @@ -0,0 +1,472 @@ +"""Synthetic unit tests for the native Qwen3.5/3.6 MTP model class. + +These tests build tiny configurations and synthetic weight dicts so the +whole suite runs in seconds without requiring any model download. +The opt-in parity test against the real MTPLX artifact lives in a +separate file (test_qwen3_5_mtp_parity.py). 
+""" + +# pyright: reportAny=false, reportUnknownMemberType=false, reportUnknownArgumentType=false, reportUnknownVariableType=false, reportUnknownParameterType=false, reportMissingParameterType=false, reportPrivateUsage=false, reportArgumentType=false, reportOperatorIssue=false + +from __future__ import annotations + +from typing import Any, Dict + +import mlx.core as mx +import mlx.nn as nn +import pytest + +from exo.worker.engines.mlx.vendor.qwen3_5_mtp import ( + Model, + ModelArgs, + MTPModule, + MTPWeightsNotFound, + TextModelArgs, + _classify_mtp_key_set, + _quantize_mtp_module, +) + +# --------------------------------------------------------------------------- +# Test fixtures: minimal "tiny" Qwen3.6-shaped config. +# --------------------------------------------------------------------------- + + +def _tiny_text_config(*, with_mtp: bool = True) -> Dict[str, Any]: + """Tiny config (32 hidden, 1 layer per group, 1 MTP layer) for fast tests. + + The shapes are chosen to be self-consistent with the Qwen3.6 architecture + constraints (head dims, GQA, full_attention_interval) while staying + extremely small. 
+ """ + cfg: Dict[str, Any] = { + "model_type": "qwen3_5_text", + "hidden_size": 64, + "intermediate_size": 128, + "num_hidden_layers": 4, + "num_attention_heads": 4, + "rms_norm_eps": 1e-6, + "vocab_size": 256, + "num_key_value_heads": 2, + "max_position_embeddings": 256, + "linear_num_value_heads": 4, + "linear_num_key_heads": 2, + "linear_key_head_dim": 16, + "linear_value_head_dim": 16, + "linear_conv_kernel_dim": 4, + "tie_word_embeddings": False, + "attention_bias": False, + "head_dim": 16, + "full_attention_interval": 2, + "num_experts": 0, + "rope_parameters": { + "type": "default", + "rope_theta": 10000.0, + "partial_rotary_factor": 1.0, + }, + } + if with_mtp: + cfg["mtp_num_hidden_layers"] = 1 + return cfg + + +def _tiny_model_args(*, with_mtp: bool = True) -> ModelArgs: + return ModelArgs( + model_type="qwen3_5", + text_config=_tiny_text_config(with_mtp=with_mtp), + ) + + +# --------------------------------------------------------------------------- +# Shape tests +# --------------------------------------------------------------------------- + + +def test_forward_shapes() -> None: + """``Model(...)(tokens)`` returns logits of shape (B, T, V).""" + args = _tiny_model_args(with_mtp=True) + model = Model(args) + mx.eval(model.parameters()) + tokens = mx.array([[1, 2, 3, 4, 5]], dtype=mx.int32) + logits = model(tokens) + assert isinstance(logits, mx.array) + mx.eval(logits) + assert logits.shape == (1, 5, 256) + + +def test_forward_return_hidden_shapes() -> None: + """``return_hidden=True`` returns (logits, hidden) with hidden = post-norm.""" + args = _tiny_model_args(with_mtp=True) + model = Model(args) + mx.eval(model.parameters()) + tokens = mx.array([[1, 2, 3, 4, 5]], dtype=mx.int32) + result = model(tokens, return_hidden=True) + assert isinstance(result, tuple) + logits, hidden = result + mx.eval(logits, hidden) + assert logits.shape == (1, 5, 256) + assert hidden.shape == (1, 5, 64) + + +def test_mtp_forward_shape() -> None: + """``mtp_forward`` 
returns logits of shape (B, T, V) from hidden + token.""" + args = _tiny_model_args(with_mtp=True) + model = Model(args) + mx.eval(model.parameters()) + hidden = mx.random.uniform(shape=(1, 1, 64)) + next_token = mx.array([[42]], dtype=mx.int32) + cache = model.make_mtp_cache() + logits = model.mtp_forward(hidden, next_token, mtp_cache=cache) + assert isinstance(logits, mx.array) + mx.eval(logits) + assert logits.shape == (1, 1, 256) + + +def test_make_mtp_cache_length_matches_mtp_layers() -> None: + args = _tiny_model_args(with_mtp=True) + model = Model(args) + cache = model.make_mtp_cache() + assert len(cache) == args.text_config["mtp_num_hidden_layers"] + + +# --------------------------------------------------------------------------- +# Norm-shift gate idempotence +# --------------------------------------------------------------------------- + + +def _build_synthetic_main_weights(args: TextModelArgs) -> Dict[str, mx.array]: + """Build a synthetic weight dict that matches the main model shape. + + Norms start with mean ~0 (the unshifted form) so we can observe the + +1.0 shift kicking in. + """ + hidden = args.hidden_size + vocab = args.vocab_size + inter = args.intermediate_size + head_dim = args.head_dim + n_heads = args.num_attention_heads + n_kv = args.num_key_value_heads + # Use float32 with mean 0 so the shift gate fires deterministically. 
+ rng = mx.random.uniform + weights: Dict[str, mx.array] = { + "model.embed_tokens.weight": rng(shape=(vocab, hidden)) - 0.5, + "model.norm.weight": mx.zeros((hidden,), dtype=mx.float32), + "lm_head.weight": rng(shape=(vocab, hidden)) - 0.5, + } + for li in range(args.num_hidden_layers): + is_linear = (li + 1) % args.full_attention_interval != 0 + prefix = f"model.layers.{li}" + weights[f"{prefix}.input_layernorm.weight"] = mx.zeros( + (hidden,), dtype=mx.float32 + ) + weights[f"{prefix}.post_attention_layernorm.weight"] = mx.zeros( + (hidden,), dtype=mx.float32 + ) + weights[f"{prefix}.mlp.gate_proj.weight"] = rng(shape=(inter, hidden)) - 0.5 + weights[f"{prefix}.mlp.down_proj.weight"] = rng(shape=(hidden, inter)) - 0.5 + weights[f"{prefix}.mlp.up_proj.weight"] = rng(shape=(inter, hidden)) - 0.5 + if is_linear: + # GatedDeltaNet weights -- shapes mirror what stock code builds. + key_dim = args.linear_num_key_heads * args.linear_key_head_dim + value_dim = args.linear_num_value_heads * args.linear_value_head_dim + conv_dim = key_dim * 2 + value_dim + weights[f"{prefix}.linear_attn.conv1d.weight"] = ( + rng(shape=(conv_dim, 1, args.linear_conv_kernel_dim)) - 0.5 + ) + weights[f"{prefix}.linear_attn.in_proj_qkv.weight"] = ( + rng(shape=(conv_dim, hidden)) - 0.5 + ) + weights[f"{prefix}.linear_attn.in_proj_z.weight"] = ( + rng(shape=(value_dim, hidden)) - 0.5 + ) + weights[f"{prefix}.linear_attn.in_proj_b.weight"] = ( + rng(shape=(args.linear_num_value_heads, hidden)) - 0.5 + ) + weights[f"{prefix}.linear_attn.in_proj_a.weight"] = ( + rng(shape=(args.linear_num_value_heads, hidden)) - 0.5 + ) + weights[f"{prefix}.linear_attn.out_proj.weight"] = ( + rng(shape=(hidden, value_dim)) - 0.5 + ) + weights[f"{prefix}.linear_attn.dt_bias"] = mx.ones( + (args.linear_num_value_heads,) + ) + weights[f"{prefix}.linear_attn.A_log"] = mx.zeros( + (args.linear_num_value_heads,) + ) + weights[f"{prefix}.linear_attn.norm.weight"] = mx.ones( + (args.linear_value_head_dim,) + ) + else: + 
weights[f"{prefix}.self_attn.q_proj.weight"] = ( + rng(shape=(n_heads * head_dim * 2, hidden)) - 0.5 + ) + weights[f"{prefix}.self_attn.k_proj.weight"] = ( + rng(shape=(n_kv * head_dim, hidden)) - 0.5 + ) + weights[f"{prefix}.self_attn.v_proj.weight"] = ( + rng(shape=(n_kv * head_dim, hidden)) - 0.5 + ) + weights[f"{prefix}.self_attn.o_proj.weight"] = ( + rng(shape=(hidden, n_heads * head_dim)) - 0.5 + ) + weights[f"{prefix}.self_attn.q_norm.weight"] = mx.zeros( + (head_dim,), dtype=mx.float32 + ) + weights[f"{prefix}.self_attn.k_norm.weight"] = mx.zeros( + (head_dim,), dtype=mx.float32 + ) + return weights + + +def _build_synthetic_mtp_weights(args: TextModelArgs) -> Dict[str, mx.array]: + """Build a synthetic MTP weight dict with unshifted norms.""" + hidden = args.hidden_size + inter = args.intermediate_size + head_dim = args.head_dim + n_heads = args.num_attention_heads + n_kv = args.num_key_value_heads + rng = mx.random.uniform + weights: Dict[str, mx.array] = { + "mtp.fc.weight": rng(shape=(hidden, 2 * hidden)) - 0.5, + "mtp.pre_fc_norm_hidden.weight": mx.zeros((hidden,), dtype=mx.float32), + "mtp.pre_fc_norm_embedding.weight": mx.zeros((hidden,), dtype=mx.float32), + "mtp.norm.weight": mx.zeros((hidden,), dtype=mx.float32), + "mtp.layers.0.input_layernorm.weight": mx.zeros((hidden,), dtype=mx.float32), + "mtp.layers.0.post_attention_layernorm.weight": mx.zeros( + (hidden,), dtype=mx.float32 + ), + "mtp.layers.0.self_attn.q_proj.weight": rng( + shape=(n_heads * head_dim * 2, hidden) + ) + - 0.5, + "mtp.layers.0.self_attn.k_proj.weight": rng(shape=(n_kv * head_dim, hidden)) + - 0.5, + "mtp.layers.0.self_attn.v_proj.weight": rng(shape=(n_kv * head_dim, hidden)) + - 0.5, + "mtp.layers.0.self_attn.o_proj.weight": rng(shape=(hidden, n_heads * head_dim)) + - 0.5, + "mtp.layers.0.self_attn.q_norm.weight": mx.zeros((head_dim,), dtype=mx.float32), + "mtp.layers.0.self_attn.k_norm.weight": mx.zeros((head_dim,), dtype=mx.float32), + "mtp.layers.0.mlp.gate_proj.weight": 
rng(shape=(inter, hidden)) - 0.5, + "mtp.layers.0.mlp.down_proj.weight": rng(shape=(hidden, inter)) - 0.5, + "mtp.layers.0.mlp.up_proj.weight": rng(shape=(inter, hidden)) - 0.5, + } + return weights + + +def test_norm_shift_idempotent() -> None: + """Sanitize shifts mean=0 norms to mean=1; a second sanitize does NOT shift again.""" + args = _tiny_model_args(with_mtp=True) + model = Model(args) + main = _build_synthetic_main_weights(model.language_model.args) + mtp = _build_synthetic_mtp_weights(model.language_model.args) + raw: Dict[str, mx.array] = {} + for k, v in main.items(): + raw["language_model." + k] = v + for k, v in mtp.items(): + raw["language_model." + k] = v + + sanitized = model.sanitize(raw) + # After first sanitize: all norm.weights should have mean ~1. + norm_keys = [ + k + for k in sanitized + if k.endswith(("input_layernorm.weight", "norm.weight", "q_norm.weight")) + and sanitized[k].ndim == 1 + ] + assert len(norm_keys) > 0 + for k in norm_keys: + m = float(sanitized[k].astype(mx.float32).mean().item()) + assert abs(m - 1.0) < 0.05, f"first sanitize on {k}: mean={m}" + + # Second sanitize on the already-shifted dict should NOT shift again. + # The sanitized output is already keyed in the canonical + # ``language_model.model.X`` form, so we can feed it straight back + # to ``model.sanitize`` (the outer Model.sanitize tolerates the + # ``language_model.*`` prefix as a pass-through). 
+ sanitized2 = model.sanitize(sanitized) + for k in norm_keys: + m = float(sanitized2[k].astype(mx.float32).mean().item()) + assert abs(m - 1.0) < 0.05, ( + f"second sanitize on {k}: mean={m} (double-shifted!)" + ) + + +# --------------------------------------------------------------------------- +# Missing-weights diagnostics +# --------------------------------------------------------------------------- + + +def test_missing_weights_raises() -> None: + """If MTP is declared but no MTP weights and no loader: ``MTPWeightsNotFound``.""" + args = _tiny_model_args(with_mtp=True) + model = Model(args) + main = _build_synthetic_main_weights(model.language_model.args) + raw: Dict[str, mx.array] = {"language_model." + k: v for k, v in main.items()} + # Crucially: do NOT install a sidecar loader, do NOT include mtp.* keys. + with pytest.raises(MTPWeightsNotFound) as excinfo: + model.sanitize(raw) + assert "mtp.safetensors" in excinfo.value.candidates + + +def test_no_mtp_declared_does_not_raise() -> None: + """If config has ``mtp_num_hidden_layers=0``, sanitize is happy without MTP keys.""" + args = _tiny_model_args(with_mtp=False) + model = Model(args) + main = _build_synthetic_main_weights(model.language_model.args) + raw: Dict[str, mx.array] = {"language_model." 
+ k: v for k, v in main.items()} + sanitized = model.sanitize(raw) + assert not any(k.startswith("mtp.") for k in sanitized) + + +# --------------------------------------------------------------------------- +# Cache-share differential (proves cache seeding matters) +# --------------------------------------------------------------------------- + + +def test_cache_share_differential() -> None: + """Primed-cache MTP differs from fresh-cache MTP after multiple steps.""" + args = _tiny_model_args(with_mtp=True) + model = Model(args) + mx.eval(model.parameters()) + + # Walk 4 hidden+token pairs through the MTP head with two cache strategies: + rng_h1 = mx.random.uniform(shape=(1, 1, 64), key=mx.random.key(7)) + rng_h2 = mx.random.uniform(shape=(1, 1, 64), key=mx.random.key(11)) + rng_h3 = mx.random.uniform(shape=(1, 1, 64), key=mx.random.key(13)) + rng_h4 = mx.random.uniform(shape=(1, 1, 64), key=mx.random.key(17)) + hiddens = [rng_h1, rng_h2, rng_h3, rng_h4] + tokens = [mx.array([[i + 5]], dtype=mx.int32) for i in range(4)] + + # Primed: same cache for all steps + primed = model.make_mtp_cache() + primed_outputs = [] + for h, t in zip(hiddens, tokens, strict=True): + out = model.mtp_forward(h, t, mtp_cache=primed) + mx.eval(out) + primed_outputs.append(out) + + # Fresh: new cache each step + fresh_outputs = [] + for h, t in zip(hiddens, tokens, strict=True): + fresh = model.make_mtp_cache() + out = model.mtp_forward(h, t, mtp_cache=fresh) + mx.eval(out) + fresh_outputs.append(out) + + # The very first step must agree (both have empty cache). + first_diff = float(mx.max(mx.abs(primed_outputs[0] - fresh_outputs[0])).item()) + assert first_diff < 1e-4, f"first-step primed/fresh diverge: {first_diff}" + + # Later steps must diverge -- proves cache priming is load-bearing. 
+ last_diff = float(mx.max(mx.abs(primed_outputs[-1] - fresh_outputs[-1])).item()) + assert last_diff > 1e-3, ( + f"last-step primed/fresh agree (diff={last_diff}); cache priming " + "isn't actually changing MTP outputs" + ) + + +# --------------------------------------------------------------------------- +# Hidden-variant correctness (post-norm vs pre-norm) +# --------------------------------------------------------------------------- + + +def test_post_vs_pre_norm() -> None: + """The two hidden variants give different logits and hidden.""" + args = _tiny_model_args(with_mtp=True) + model = Model(args) + mx.eval(model.parameters()) + hidden = mx.random.uniform(shape=(1, 1, 64)) + token = mx.array([[42]], dtype=mx.int32) + cache_post = model.make_mtp_cache() + cache_pre = model.make_mtp_cache() + out_post = model.mtp_forward( + hidden, + token, + mtp_cache=cache_post, + mtp_hidden_variant="post_norm", + return_hidden=True, + ) + out_pre = model.mtp_forward( + hidden, + token, + mtp_cache=cache_pre, + mtp_hidden_variant="pre_norm", + return_hidden=True, + ) + assert isinstance(out_post, tuple) and isinstance(out_pre, tuple) + # Logits are computed from POST-norm in both cases (the lm_head only + # ever sees post-norm). What changes is the *returned* hidden which + # downstream draft-loops use as the input for the next MTP step. 
+ h_post = out_post[1] + h_pre = out_pre[1] + diff = float(mx.max(mx.abs(h_post - h_pre)).item()) + assert diff > 1e-3, ( + f"post-norm and pre-norm hidden variants are identical (diff={diff}); " + "variant switch is a no-op" + ) + + +# --------------------------------------------------------------------------- +# Quant-policy classification +# --------------------------------------------------------------------------- + + +def test_classify_unquantized() -> None: + keys = ( + "mtp.fc.weight", + "mtp.layers.0.self_attn.q_proj.weight", + "mtp.layers.0.self_attn.k_proj.weight", + "mtp.norm.weight", + ) + assert _classify_mtp_key_set(keys) == "unquantized" + + +def test_classify_cyankiwi() -> None: + keys = ( + "mtp.fc.weight", # fc unquantized + "mtp.layers.0.self_attn.q_proj.weight", + "mtp.layers.0.self_attn.q_proj.scales", # presence of attn scales -> cyankiwi + "mtp.layers.0.self_attn.q_proj.biases", + "mtp.norm.weight", + ) + assert _classify_mtp_key_set(keys) == "cyankiwi" + + +def test_classify_all_quantized() -> None: + keys = ( + "mtp.fc.weight", + "mtp.fc.scales", # presence of fc scales -> "all" + "mtp.fc.biases", + "mtp.layers.0.self_attn.q_proj.weight", + "mtp.layers.0.self_attn.q_proj.scales", + "mtp.layers.0.self_attn.q_proj.biases", + ) + assert _classify_mtp_key_set(keys) == "all" + + +def test_quantize_mtp_module_cyankiwi_leaves_fc_unquantized() -> None: + args = TextModelArgs.from_dict(_tiny_text_config(with_mtp=True)) + mtp = MTPModule(args, 1) + _quantize_mtp_module(mtp, policy="cyankiwi", bits=8, group_size=32) + # fc should still be a plain Linear (NOT QuantizedLinear) + assert isinstance(mtp.fc, nn.Linear) + assert not isinstance(mtp.fc, nn.QuantizedLinear) + # attention q_proj should be QuantizedLinear + qproj = mtp.layers[0].self_attn.q_proj + assert isinstance(qproj, nn.QuantizedLinear) + + +def test_quantize_mtp_module_all_quantizes_fc() -> None: + args = TextModelArgs.from_dict(_tiny_text_config(with_mtp=True)) + mtp = MTPModule(args, 1) 
+ _quantize_mtp_module(mtp, policy="all", bits=8, group_size=32) + assert isinstance(mtp.fc, nn.QuantizedLinear) + + +def test_quantize_mtp_module_unquantized_is_noop() -> None: + args = TextModelArgs.from_dict(_tiny_text_config(with_mtp=True)) + mtp = MTPModule(args, 1) + _quantize_mtp_module(mtp, policy="unquantized", bits=8, group_size=32) + assert isinstance(mtp.fc, nn.Linear) and not isinstance(mtp.fc, nn.QuantizedLinear) + assert not isinstance(mtp.layers[0].self_attn.q_proj, nn.QuantizedLinear) diff --git a/src/exo/worker/engines/mlx/vendor/tests/test_qwen3_5_mtp_parity.py b/src/exo/worker/engines/mlx/vendor/tests/test_qwen3_5_mtp_parity.py new file mode 100644 index 0000000000..640f27abd8 --- /dev/null +++ b/src/exo/worker/engines/mlx/vendor/tests/test_qwen3_5_mtp_parity.py @@ -0,0 +1,126 @@ +"""Opt-in numerical parity test against a real Qwen MTP artifact. + +Gated on ``MLX_LM_RUN_NETWORK_TESTS=1`` because it requires the model +checkpoint on disk (~30 GB) and ~10 s of compute per run. + +Pass criterion: top-1 agreement between MTP's next-next-token +prediction and the main lm_head's prediction over a fixed +post-norm hidden context is >= 60%. This is the regression that +prevents PR-#1226-style breakage. + +The probe walks the model incrementally rather than via a batched +prefill -- the Qwen3.5 GatedDeltaNet implementation in mlx-lm is not +strictly per-position causal during batched prefill, which would +distort the per-position parity numbers (see +scripts/mtp_parity_probe.py for the falsification probe that +identified this). 
+""" + +# pyright: reportAny=false, reportUnknownMemberType=false, reportUnknownArgumentType=false, reportUnknownVariableType=false, reportUnknownParameterType=false, reportMissingParameterType=false, reportUnnecessaryTypeIgnoreComment=false + +from __future__ import annotations + +import os +from pathlib import Path + +import mlx.core as mx +import pytest + +from exo.worker.engines.mlx.vendor.qwen3_5_mtp_loader import load_mtp_model + +_NETWORK_TESTS_FLAG = "MLX_LM_RUN_NETWORK_TESTS" +_MODEL_PATH_ENV = "EXO_NATIVE_MTP_PARITY_MODEL_PATH" +_MODEL_DIR_RAW = os.environ.get(_MODEL_PATH_ENV) +_MODEL_DIR = Path(_MODEL_DIR_RAW) if _MODEL_DIR_RAW else None + +# 48 tokens of natural English; incremental forward over this is ~5 s on M5 Max. +_PROMPT = ( + "The quick brown fox jumps over the lazy dog. In a small village by the " + "river there lived an old clockmaker whose hands could repair any " + "broken timepiece. Each morning before the sun rose he would walk " + "through the cobbled streets carrying a leather satchel full of tiny " + "tools, and the children would wave from their windows as he passed." 
+) + +_TOP1_FLOOR = 0.60 + + +@pytest.mark.slow +@pytest.mark.skipif( + os.environ.get(_NETWORK_TESTS_FLAG) != "1", + reason=f"set {_NETWORK_TESTS_FLAG}=1 to run the MTP parity regression", +) +@pytest.mark.skipif( + _MODEL_DIR is None, + reason=f"set {_MODEL_PATH_ENV} to a local Qwen MTP model directory", +) +@pytest.mark.skipif( + _MODEL_DIR is not None and not _MODEL_DIR.exists(), + reason=f"configured Qwen MTP model directory does not exist: {_MODEL_DIR}", +) +def test_mtp_top1_agreement_against_lm_head() -> None: + """Top-1 agreement with a primed MTP cache must stay above the floor.""" + try: + from transformers import AutoTokenizer # type: ignore[import-untyped] + except ImportError: + pytest.skip("transformers not installed") + + assert _MODEL_DIR is not None + model, _cfg = load_mtp_model(_MODEL_DIR, lazy=False, strict=True) + tok = AutoTokenizer.from_pretrained(str(_MODEL_DIR)) + + token_ids = tok(_PROMPT, return_tensors="np")["input_ids"][0].tolist() + token_ids = token_ids[: min(len(token_ids), 48)] + assert len(token_ids) >= 16, f"too few tokens to test: {len(token_ids)}" + tokens = mx.array([token_ids], dtype=mx.int32) + seq_len = tokens.shape[1] + + # Incremental main forward to get truly causal per-position post-norm + # hidden + lm_head logits. 
+ text_model = model.language_model + inner = text_model.model + cache = text_model.make_cache() + post_per_pos: list[mx.array] = [] + logits_per_pos: list[mx.array] = [] + from mlx_lm.models.base import create_attention_mask, create_ssm_mask + + for t in range(seq_len): + one = tokens[:, t : t + 1] + h = inner.embed_tokens(one) + fa_mask = create_attention_mask(h, cache[inner.fa_idx]) + ssm_mask = create_ssm_mask(h, cache[inner.ssm_idx]) + for layer, layer_cache in zip(inner.layers, cache, strict=True): + mask = ssm_mask if layer.is_linear else fa_mask + h = layer(h, mask=mask, cache=layer_cache) + post = inner.norm(h) + if text_model.args.tie_word_embeddings: + lg = inner.embed_tokens.as_linear(post) + else: + lg = text_model.lm_head(post) + mx.eval(post, lg) + post_per_pos.append(post) + logits_per_pos.append(lg) + post_norm_hidden = mx.concatenate(post_per_pos, axis=1) + main_logits = mx.concatenate(logits_per_pos, axis=1) + mx.eval(post_norm_hidden, main_logits) + + # Primed-cache MTP walk + mtp_cache = model.make_mtp_cache() + matches = 0 + total = 0 + for t in range(seq_len - 1): + h_t = post_norm_hidden[:, t : t + 1, :] + next_tok = tokens[:, t + 1 : t + 2] + mtp_logits = model.mtp_forward(h_t, next_tok, mtp_cache=mtp_cache) + assert isinstance(mtp_logits, mx.array) + mx.eval(mtp_logits) + mtp_pred = int(mx.argmax(mtp_logits[0, -1]).item()) + main_pred = int(mx.argmax(main_logits[0, t + 1]).item()) + if mtp_pred == main_pred: + matches += 1 + total += 1 + top1 = matches / total + assert top1 >= _TOP1_FLOOR, ( + f"MTP top-1 agreement {top1:.4f} < floor {_TOP1_FLOOR:.2f} " + f"(matches={matches}/{total}) -- this is the PR #1226 regression" + ) diff --git a/src/exo/worker/runner/llm_inference/batch_generator.py b/src/exo/worker/runner/llm_inference/batch_generator.py index 098c829e9a..53865d02d1 100644 --- a/src/exo/worker/runner/llm_inference/batch_generator.py +++ b/src/exo/worker/runner/llm_inference/batch_generator.py @@ -99,6 +99,12 @@ class 
SequentialGenerator(Engine): event_sender: MpSender[Event] vision_processor: VisionProcessor | None = None check_for_cancel_every: int = 50 + # Native-MTP K bounds from the model card (``NativeMTPConfig``). Set by + # the builder only when the target loaded as a vendored MTP-aware model; + # ``None`` on every other model. ``mlx_generate`` resolves the per-request + # K from these. + native_mtp_default_k: int | None = None + native_mtp_max_k: int | None = None _cancelled_tasks: set[TaskId] = field(default_factory=set, init=False) _maybe_queue: list[TextGeneration] = field(default_factory=list, init=False) @@ -295,6 +301,8 @@ def on_generation_token() -> None: on_generation_token=on_generation_token, group=self.group, vision_processor=self.vision_processor, + native_mtp_default_k=self.native_mtp_default_k, + native_mtp_max_k=self.native_mtp_max_k, ) def close(self) -> None: diff --git a/src/exo/worker/tests/unittests/test_mlx/test_native_mtp_drafter.py b/src/exo/worker/tests/unittests/test_mlx/test_native_mtp_drafter.py new file mode 100644 index 0000000000..e2c0e4e540 --- /dev/null +++ b/src/exo/worker/tests/unittests/test_mlx/test_native_mtp_drafter.py @@ -0,0 +1,551 @@ +"""Unit tests for :mod:`exo.worker.engines.mlx.generator.native_mtp_drafter`. + +Exercises: + +- ``is_native_mtp_dispatchable`` correctly disambiguates the vendored + MTP-capable :class:`vendor.qwen3_5_mtp.Model` from plain ``mx`` + modules and from ``unittest.mock.MagicMock`` (which auto-creates + any attribute on access). +- :class:`NativeMTPDrafter` constructor validation (rejects ``k<1``) + and trivial property surface (``mode``, ``num_draft_tokens``). +- :func:`prime_mtp_cache_from_prompt` returns ``N-1`` positions + primed for prompts of size ``>=2`` and ``0`` otherwise. 
+- End-to-end ``stream`` smoke against the tiny synthetic Qwen3.5/6 + model fixture (mirrors the pattern used in + ``vendor/tests/test_qwen3_5_mtp.py``): runs K=1 / K=2 / K=3 and + asserts the drafter emits tokens, populates metrics, and respects + ``max_tokens``. + +These tests use the synthetic tiny model so the whole suite runs in +seconds without requiring any real model download. The opt-in parity +test against the real MTPLX artifact lives separately. +""" + +# pyright: reportPrivateUsage=false, reportUnknownMemberType=false, reportAttributeAccessIssue=false, reportUnknownArgumentType=false, reportUnknownVariableType=false, reportAny=false + +from __future__ import annotations + +from collections.abc import Iterable +from typing import Any, Callable, Dict, cast +from unittest.mock import MagicMock + +import mlx.core as mx +import pytest + +from exo.worker.engines.mlx.generator.native_mtp_drafter import ( + NativeMTPDrafter, + _eos_ids_from_tokenizer, + _gdn_state_history_commit_default, + _moe_verifier_policy, + _target_post_norm_hidden, + is_native_mtp_dispatchable, + prime_mtp_cache_from_prompt, + prime_mtp_cache_from_prompt_incremental, + rebuild_prompt_cache_and_prime_mtp_cache_incremental, + rebuild_prompt_cache_incremental, +) +from exo.worker.engines.mlx.types import Model as ExoModel +from exo.worker.engines.mlx.vendor.qwen3_5_mtp import Model, ModelArgs + + +def _tiny_text_config(*, with_mtp: bool = True) -> Dict[str, Any]: + """Mirror the tiny config used by ``vendor/tests/test_qwen3_5_mtp.py``.""" + cfg: Dict[str, Any] = { + "model_type": "qwen3_5_text", + "hidden_size": 64, + "intermediate_size": 128, + "num_hidden_layers": 4, + "num_attention_heads": 4, + "rms_norm_eps": 1e-6, + "vocab_size": 256, + "num_key_value_heads": 2, + "max_position_embeddings": 256, + "linear_num_value_heads": 4, + "linear_num_key_heads": 2, + "linear_key_head_dim": 16, + "linear_value_head_dim": 16, + "linear_conv_kernel_dim": 4, + "tie_word_embeddings": False, + 
"attention_bias": False, + "head_dim": 16, + "full_attention_interval": 2, + "num_experts": 0, + "rope_parameters": { + "type": "default", + "rope_theta": 10000.0, + "partial_rotary_factor": 1.0, + }, + } + if with_mtp: + cfg["mtp_num_hidden_layers"] = 1 + return cfg + + +def _tiny_model(*, with_mtp: bool = True) -> ExoModel: + args = ModelArgs( + model_type="qwen3_5", text_config=_tiny_text_config(with_mtp=with_mtp) + ) + model = Model(args) + mx.eval(model.parameters()) + # The vendored ``Model`` subclass satisfies the exo Model Protocol's + # ``__call__`` shape modulo the optional ``return_hidden`` kwarg the + # native-MTP path uses. Cast through ``object`` to defuse the + # basedpyright ``reportInvalidCast`` for not-sufficiently-overlapping + # types; runtime stream tests below catch any real Protocol drift. + return cast(ExoModel, cast(object, model)) + + +class _FakeDetokenizer: + """Minimal detokenizer that records tokens added between yields. + + Stream consumers walk ``last_segment`` after every ``add_token`` and + then ``finalize`` is called once before the closing yield. We only + need an empty ``last_segment`` for the drafter to construct + :class:`GenerationResponse` -- the test asserts emitted tokens via + ``GenerationResponse.token`` directly, not via the segment string. 
+ """ + + def __init__(self) -> None: + self.tokens: list[int] = [] + self.last_segment: str = "" + self.finalized: bool = False + + def reset(self) -> None: + self.tokens = [] + self.last_segment = "" + self.finalized = False + + def add_token(self, token: int) -> None: + self.tokens.append(int(token)) + + def finalize(self) -> None: + self.finalized = True + + +class _FakeTokenizer: + """Fake :class:`mlx_lm.tokenizer_utils.TokenizerWrapper` minimal shim.""" + + def __init__(self, eos_token_ids: Iterable[int] | None = None) -> None: + self.detokenizer: _FakeDetokenizer = _FakeDetokenizer() + self.eos_token_ids: Iterable[int] = list(eos_token_ids or []) + + +def _identity_sampler(logits: mx.array) -> mx.array: + """Greedy sampler used by smoke tests.""" + return mx.argmax(logits, axis=-1) + + +def _empty_processors() -> list[Callable[[mx.array, mx.array], mx.array]]: + return [] + + +# --------------------------------------------------------------------------- # +# is_native_mtp_dispatchable +# --------------------------------------------------------------------------- # + + +class TestIsNativeMtpDispatchable: + def test_vendored_model_is_dispatchable(self) -> None: + model = _tiny_model(with_mtp=True) + assert is_native_mtp_dispatchable(model) is True + + def test_vendored_model_without_mtp_layers_is_still_class_dispatchable( + self, + ) -> None: + """The class-level marker doesn't depend on whether MTP layers exist. + + The dispatch-side gate only checks the model class; the runtime + check for actual MTP availability is the loader / probe gate. + A card without MTP would never produce this class via the loader, + so the gate is conservative in the right direction. + """ + model = _tiny_model(with_mtp=False) + assert is_native_mtp_dispatchable(model) is True + + def test_magicmock_is_not_dispatchable(self) -> None: + """``MagicMock`` auto-creates attributes; the marker check rejects it. 
+ + Pre-fix the dispatcher used ``hasattr(model, "mtp_forward")`` which + returned ``True`` for any ``MagicMock`` instance because attribute + access auto-creates a child mock. That falsely engaged the + NativeMTPDrafter in unrelated routing tests. + """ + fake = MagicMock() + assert is_native_mtp_dispatchable(fake) is False + + def test_plain_object_is_not_dispatchable(self) -> None: + assert is_native_mtp_dispatchable(object()) is False + + def test_none_is_not_dispatchable(self) -> None: + assert is_native_mtp_dispatchable(None) is False + + +# --------------------------------------------------------------------------- # +# NativeMTPDrafter constructor / property surface +# --------------------------------------------------------------------------- # + + +class TestNativeMTPDrafterConstructor: + def test_rejects_zero_k(self) -> None: + with pytest.raises(ValueError, match="k must be >= 1"): + NativeMTPDrafter(k=0) + + def test_rejects_negative_k(self) -> None: + with pytest.raises(ValueError, match="k must be >= 1"): + NativeMTPDrafter(k=-1) + + def test_mode_is_model(self) -> None: + drafter = NativeMTPDrafter(k=1) + assert drafter.mode == "model" + + def test_num_draft_tokens_matches_k(self) -> None: + for k in (1, 2, 3): + assert NativeMTPDrafter(k=k).num_draft_tokens == k + + def test_initial_metrics_are_zeroed(self) -> None: + drafter = NativeMTPDrafter(k=2) + assert drafter.metrics() == { + "proposed_draft_tokens": 0, + "accepted_draft_tokens": 0, + "spec_decode_rounds": 0, + } + + def test_metrics_is_a_copy(self) -> None: + """``metrics()`` returns a fresh dict so callers can't mutate state.""" + drafter = NativeMTPDrafter(k=1) + snap = drafter.metrics() + snap["proposed_draft_tokens"] = 99 + assert drafter.metrics()["proposed_draft_tokens"] == 0 + + +def test_eos_ids_from_tokenizer_accepts_set_eos_ids() -> None: + """Real MLX TokenizerWrapper exposes Qwen EOS IDs as a set.""" + tokenizer = _FakeTokenizer(eos_token_ids={248044, 248046}) + + assert 
sorted(_eos_ids_from_tokenizer(cast(Any, tokenizer))) == [248044, 248046] + + +@pytest.mark.parametrize("model_type", ["qwen3_5", "qwen3_5_text"]) +def test_gdn_state_history_commit_defaults_on_for_dense_qwen3_5( + model_type: str, +) -> None: + assert ( + _gdn_state_history_commit_default( + model_type=model_type, + moe_verifier_policy="safe", + ) + is True + ) + + +def test_gdn_state_history_commit_defaults_on_for_route_locked_moe_only() -> None: + assert ( + _gdn_state_history_commit_default( + model_type="qwen3_5_moe", + moe_verifier_policy="route_locked", + ) + is True + ) + assert ( + _gdn_state_history_commit_default( + model_type="qwen3_5_moe", + moe_verifier_policy="safe", + ) + is False + ) + + +def test_moe_verifier_policy_defaults_to_route_locked( + monkeypatch: pytest.MonkeyPatch, +) -> None: + monkeypatch.delenv("EXO_NATIVE_MTP_MOE_VERIFY", raising=False) + + assert _moe_verifier_policy() == "route_locked" + + +def test_moe_verifier_policy_can_fall_back_to_safe( + monkeypatch: pytest.MonkeyPatch, +) -> None: + monkeypatch.setenv("EXO_NATIVE_MTP_MOE_VERIFY", "safe") + + assert _moe_verifier_policy() == "safe" + + +# --------------------------------------------------------------------------- # +# prime_mtp_cache_from_prompt +# --------------------------------------------------------------------------- # + + +class TestPrimeMTPCacheFromPrompt: + def test_returns_zero_for_empty_prompt(self) -> None: + model = _tiny_model(with_mtp=True) + mtp_cache = model.make_mtp_cache() + assert ( + prime_mtp_cache_from_prompt( + model=model, full_prompt_tokens=[], mtp_cache=mtp_cache + ) + == 0 + ) + + def test_returns_zero_for_single_token_prompt(self) -> None: + model = _tiny_model(with_mtp=True) + mtp_cache = model.make_mtp_cache() + assert ( + prime_mtp_cache_from_prompt( + model=model, full_prompt_tokens=[5], mtp_cache=mtp_cache + ) + == 0 + ) + + def test_returns_nminus1_for_normal_prompt(self) -> None: + model = _tiny_model(with_mtp=True) + mtp_cache = 
model.make_mtp_cache() + primed = prime_mtp_cache_from_prompt( + model=model, + full_prompt_tokens=[1, 2, 3, 4, 5], + mtp_cache=mtp_cache, + ) + assert primed == 4 + + def test_primes_advances_mtp_cache_offset(self) -> None: + """After priming N positions the MTP cache holds N entries.""" + model = _tiny_model(with_mtp=True) + mtp_cache = model.make_mtp_cache() + prompt = [1, 2, 3, 4, 5, 6] + primed = prime_mtp_cache_from_prompt( + model=model, full_prompt_tokens=prompt, mtp_cache=mtp_cache + ) + assert primed == len(prompt) - 1 + # The MTP cache is a stock KVCache; its ``offset`` should be the + # number of positions written. + offsets = [getattr(c, "offset", -1) for c in mtp_cache] + assert offsets == [len(prompt) - 1] + + +class TestHiddenWithoutLogits: + def test_matches_outer_model_return_hidden_and_cache(self) -> None: + model = _tiny_model(with_mtp=True) + token = mx.array([[7]], dtype=mx.int32) + next_token = mx.array([[8]], dtype=mx.int32) + + outer_cache = list(cast(Any, model).make_cache()) + _outer_logits, outer_hidden = cast(Any, model)( + token, + cache=outer_cache, + return_hidden=True, + ) + body_cache = list(cast(Any, model).make_cache()) + body_hidden = _target_post_norm_hidden( + model=model, + inputs=token, + cache=cast(Any, body_cache), + ) + mx.eval(outer_hidden, body_hidden) + assert float(mx.max(mx.abs(outer_hidden - body_hidden)).item()) == 0.0 + + outer_next_logits, _outer_next_hidden = cast(Any, model)( + next_token, + cache=outer_cache, + return_hidden=True, + ) + body_next_logits, _body_next_hidden = cast(Any, model)( + next_token, + cache=body_cache, + return_hidden=True, + ) + mx.eval(outer_next_logits, body_next_logits) + assert float(mx.max(mx.abs(outer_next_logits - body_next_logits)).item()) == 0.0 + + +class TestFusedPromptAndMtpPriming: + def test_matches_separate_incremental_paths(self) -> None: + model = _tiny_model(with_mtp=True) + prompt = [1, 2, 3, 4, 5, 6] + + separate_prompt_cache = list(cast(Any, model).make_cache()) + 
separate_mtp_cache = list(cast(Any, model).make_mtp_cache()) + separate_rebuilt = rebuild_prompt_cache_incremental( + model=model, + full_prompt_tokens=prompt, + prompt_cache=cast(Any, separate_prompt_cache), + ) + separate_primed = prime_mtp_cache_from_prompt_incremental( + model=model, + full_prompt_tokens=prompt, + mtp_cache=separate_mtp_cache, + ) + + fused_prompt_cache = list(cast(Any, model).make_cache()) + fused_mtp_cache = list(cast(Any, model).make_mtp_cache()) + fused_rebuilt, fused_primed = ( + rebuild_prompt_cache_and_prime_mtp_cache_incremental( + model=model, + full_prompt_tokens=prompt, + prompt_cache=cast(Any, fused_prompt_cache), + mtp_cache=fused_mtp_cache, + ) + ) + + assert fused_rebuilt == separate_rebuilt == len(prompt) - 2 + assert fused_primed == separate_primed == len(prompt) - 1 + assert [getattr(c, "offset", -1) for c in fused_prompt_cache] == [ + getattr(c, "offset", -1) for c in separate_prompt_cache + ] + assert [getattr(c, "offset", -1) for c in fused_mtp_cache] == [ + getattr(c, "offset", -1) for c in separate_mtp_cache + ] + + def next_logits_and_draft( + prompt_cache: list[Any], mtp_cache: list[Any] + ) -> tuple[mx.array, mx.array]: + prompt_tail = mx.array([prompt[-2:]], dtype=mx.int32) + _tail_logits, _tail_hidden = cast(Any, model)( + prompt_tail[:, :-1], cache=prompt_cache, return_hidden=True + ) + first_logits, first_hidden = cast(Any, model)( + prompt_tail[:, -1:], cache=prompt_cache, return_hidden=True + ) + current_token_arr = mx.argmax(first_logits, axis=-1).astype(mx.int32) + mtp_logits, _mtp_hidden = cast(Any, model).mtp_forward( + first_hidden[:, -1:, :], + current_token_arr, + mtp_cache=mtp_cache, + return_hidden=True, + ) + mx.eval(first_logits, mtp_logits) + return first_logits, mtp_logits + + separate_first, separate_mtp = next_logits_and_draft( + separate_prompt_cache, separate_mtp_cache + ) + fused_first, fused_mtp = next_logits_and_draft( + fused_prompt_cache, fused_mtp_cache + ) + + assert 
float(mx.max(mx.abs(separate_first - fused_first)).item()) == 0.0 + assert float(mx.max(mx.abs(separate_mtp - fused_mtp)).item()) == 0.0 + + +# --------------------------------------------------------------------------- # +# End-to-end stream smoke tests +# --------------------------------------------------------------------------- # + + +def _drive_stream( + *, + drafter: NativeMTPDrafter, + model: ExoModel, + tokenizer: _FakeTokenizer, + prompt_full: list[int], + max_tokens: int, +) -> list[Any]: + """Run the K-step drafter end-to-end and collect every yielded response. + + Mirrors what ``mlx_generate`` does at the dispatch site: prefills + the prompt cache aligned to ``full_prompt[:-2]``, then hands the + last two tokens to ``drafter.stream``. + """ + cache: list[Any] = list(cast(Any, model).make_cache()) + # exo.prefill equivalent: feed prompt[:-2] (so cache holds N-2 positions). + if len(prompt_full) >= 3: + prefill = mx.array(prompt_full[:-2], dtype=mx.int32)[None] + _ = model(prefill, cache=cache) + mx.eval([c.state for c in cache if hasattr(c, "state")]) + prompt_tail = mx.array(prompt_full[-2:], dtype=mx.int32) + responses: list[Any] = [] + for response in drafter.stream( + model=model, + tokenizer=cast(Any, tokenizer), + prompt=prompt_tail, + context_tokens=prompt_full, + prompt_cache=cast(Any, cache), + max_tokens=max_tokens, + sampler=_identity_sampler, + logits_processors=_empty_processors(), + ): + responses.append(response) + if response.finish_reason is not None: + break + return responses + + +@pytest.mark.parametrize("k", [1, 2, 3]) +def test_stream_emits_tokens_and_populates_metrics(k: int) -> None: + """K=1/2/3 streams emit tokens, terminate, and stamp metrics.""" + model = _tiny_model(with_mtp=True) + tokenizer = _FakeTokenizer(eos_token_ids=[]) + drafter = NativeMTPDrafter(k=k) + prompt = [1, 2, 3, 4, 5, 6, 7, 8] + max_tokens = 8 + + responses = _drive_stream( + drafter=drafter, + model=model, + tokenizer=tokenizer, + prompt_full=prompt, + 
max_tokens=max_tokens, + ) + + # At least one response is yielded (the first emitted token from the + # initial forward) plus a closing chunk. + assert len(responses) >= 2 + final = responses[-1] + assert final.finish_reason in {"stop", "length"} + assert final.generation_tokens <= max_tokens + assert final.generation_tokens >= 1 + + metrics = drafter.metrics() + # spec_decode_rounds is 0 when the first emitted token alone covered + # max_tokens; for max_tokens=8 with random init the loop should run + # at least one round. + assert metrics["spec_decode_rounds"] >= 0 + assert metrics["proposed_draft_tokens"] == k * metrics["spec_decode_rounds"] + assert ( + metrics["accepted_draft_tokens"] >= 0 + and metrics["accepted_draft_tokens"] <= metrics["proposed_draft_tokens"] + ) + + +def test_stream_raises_on_non_dispatchable_model() -> None: + """A bare MagicMock target fails the dispatch-side gate inside ``stream``.""" + drafter = NativeMTPDrafter(k=1) + fake_model = MagicMock() + fake_model.make_mtp_cache.return_value = [MagicMock()] + tokenizer = _FakeTokenizer() + with pytest.raises(RuntimeError, match="is_native_mtp_dispatchable"): + # ``next`` to actually enter the generator body; the dispatchable + # check happens before any yield. + next( + drafter.stream( + model=cast(Any, fake_model), + tokenizer=cast(Any, tokenizer), + prompt=mx.array([1, 2], dtype=mx.int32), + context_tokens=[1, 2], + prompt_cache=cast(Any, []), + max_tokens=4, + sampler=_identity_sampler, + logits_processors=_empty_processors(), + ) + ) + + +def test_stream_stops_at_immediate_eos() -> None: + """When the first emitted token IS an EOS, the stream terminates immediately.""" + model = _tiny_model(with_mtp=True) + # Pin every vocab id as EOS so the first emitted token (whatever it + # is) terminates the stream. This is the cheapest way to exercise + # the immediate-EOS branch without needing to control the model's + # output distribution. 
+ tokenizer = _FakeTokenizer(eos_token_ids=set(range(256))) + drafter = NativeMTPDrafter(k=1) + responses = _drive_stream( + drafter=drafter, + model=model, + tokenizer=tokenizer, + prompt_full=[1, 2, 3, 4], + max_tokens=8, + ) + # The immediate-EOS branch yields a SINGLE response with finish_reason="stop". + assert len(responses) == 1 + assert responses[0].finish_reason == "stop" + assert responses[0].generation_tokens == 1 diff --git a/src/exo/worker/tests/unittests/test_mlx/test_spec_cache.py b/src/exo/worker/tests/unittests/test_mlx/test_spec_cache.py new file mode 100644 index 0000000000..88227234ed --- /dev/null +++ b/src/exo/worker/tests/unittests/test_mlx/test_spec_cache.py @@ -0,0 +1,116 @@ +"""Unit tests for the speculative-decode snapshot/rollback helpers. + +These exercise :mod:`exo.worker.engines.mlx.spec_cache` directly with fake +cache entries that mimic the mlx-lm cache API (``is_trimmable`` / ``trim`` / +``state`` / ``meta_state``), so no real model or GatedDeltaNet is needed. +""" + +# The snapshot stores cloned cache state as ``tuple[Any, ...]`` by design, +# so reads off it are Any -- mirror the pragma used by the mlx modules. 
+# pyright: reportAny=false + +from __future__ import annotations + +from typing import Any + +import mlx.core as mx + +from exo.worker.engines.mlx.spec_cache import ( + CacheSnapshot, + rollback_after_verify, + snapshot_untrimmable_cache_lazy, +) + + +class _TrimmableEntry: + """Stands in for an attention KV cache: trimmable, no recurrent state.""" + + def __init__(self) -> None: + self.trim_calls: list[int] = [] + + def is_trimmable(self) -> bool: + return True + + def trim(self, n: int) -> None: + self.trim_calls.append(n) + + +class _NonTrimmableEntry: + """Stands in for a GatedDeltaNet ``ArraysCache``: not trimmable, carries + recurrent state in a list that the cache reads from in place.""" + + def __init__(self, state: list[Any], meta_state: Any) -> None: + self.state = state + self.meta_state = meta_state + + def is_trimmable(self) -> bool: + return False + + +def _arrays_equal(a: mx.array, b: mx.array) -> bool: + return bool(mx.array_equal(a, b)) + + +def test_snapshot_skips_trimmable_and_clones_non_trimmable() -> None: + trimmable = _TrimmableEntry() + non_trimmable = _NonTrimmableEntry( + state=[mx.array([1, 2, 3])], meta_state=mx.array([9]) + ) + snap = snapshot_untrimmable_cache_lazy([trimmable, non_trimmable]) + + assert isinstance(snap, CacheSnapshot) + # Trimmable slot carries None (offset trim is sufficient rollback). + assert snap.states[0] is None + assert snap.meta_states[0] is None + # Non-trimmable slot carries a clone of the recurrent state. + assert snap.states[1] is not None + assert _arrays_equal(snap.states[1][0], mx.array([1, 2, 3])) + + +def test_snapshot_clone_is_isolated_from_later_mutation() -> None: + """A snapshot must not alias the live cache: replacing the cache's state + leaves (as the verify forward does) must not change the snapshot.""" + entry = _NonTrimmableEntry(state=[mx.array([1, 2, 3])], meta_state=None) + snap = snapshot_untrimmable_cache_lazy([entry]) + + # Simulate the verify forward replacing the recurrent-state leaf. 
+ entry.state = [mx.array([7, 8, 9])] + + assert _arrays_equal(snap.states[0][0], mx.array([1, 2, 3])) + + +def test_rollback_trims_trimmable_and_restores_non_trimmable() -> None: + trimmable = _TrimmableEntry() + original_state = [mx.array([1, 2, 3])] + entry = _NonTrimmableEntry(state=original_state, meta_state=mx.array([5])) + cache: list[Any] = [trimmable, entry] + + snap = snapshot_untrimmable_cache_lazy(cache) + + # Verify forward advances state past the speculative trajectory. + entry.state[:] = [mx.array([7, 8, 9])] + entry.meta_state = mx.array([6]) + + rollback_after_verify(cache, snap, verified_tokens=2) + + # Trimmable entry trimmed by the requested count. + assert trimmable.trim_calls == [2] + # Non-trimmable recurrent state + meta restored from the snapshot. + assert _arrays_equal(entry.state[0], mx.array([1, 2, 3])) + assert _arrays_equal(entry.meta_state, mx.array([5])) + # Container identity preserved (cache reads from the same list object). + assert entry.state is original_state + + +def test_rollback_with_zero_verified_tokens_skips_trim_but_restores() -> None: + trimmable = _TrimmableEntry() + entry = _NonTrimmableEntry(state=[mx.array([1, 2, 3])], meta_state=None) + cache: list[Any] = [trimmable, entry] + + snap = snapshot_untrimmable_cache_lazy(cache) + entry.state[:] = [mx.array([7, 8, 9])] + + rollback_after_verify(cache, snap, verified_tokens=0) + + assert trimmable.trim_calls == [] + assert _arrays_equal(entry.state[0], mx.array([1, 2, 3])) diff --git a/src/exo/worker/tests/unittests/test_mlx/test_utils_mlx_native_mtp.py b/src/exo/worker/tests/unittests/test_mlx/test_utils_mlx_native_mtp.py new file mode 100644 index 0000000000..c94e3f2da2 --- /dev/null +++ b/src/exo/worker/tests/unittests/test_mlx/test_utils_mlx_native_mtp.py @@ -0,0 +1,302 @@ +"""Tests for the native-MTP loader branch in ``utils_mlx.load_mlx_items``. 
+ +The native-MTP path was added so cards declaring +:class:`exo.shared.models.model_cards.NativeMTPConfig` get loaded +through :func:`exo.worker.engines.mlx.vendor.qwen3_5_mtp_loader.load_mtp_model` +instead of the stock ``mlx_lm.utils.load_model``. + +The tests use ``monkeypatch`` to stub the heavy MLX call sites so we can +drive the loader without a real checkpoint or filesystem. Coverage: + +- Card declares ``native_mtp`` AND placement is single-node AND probe + says "recoverable": ``load_mtp_model`` is called, stock ``load_model`` + is NOT. +- Card has no ``native_mtp``: stock ``load_model`` is called. +- Card declares ``native_mtp`` but the probe returns "stripped" (e.g. + mlx-community-style stripped quants): we degrade to the stock loader. + +We do NOT exercise the real Qwen3.5/3.6 MTP loader here -- the vendor +module has its own parity tests. The goal is to verify routing. +""" + +# pyright: reportPrivateUsage=false + +from __future__ import annotations + +from collections.abc import Generator +from pathlib import Path +from typing import cast + +import pytest + +from exo.shared.models.model_cards import ( + ModelCard, + ModelId, + ModelTask, + NativeMTPConfig, +) +from exo.shared.types.backends import Backend +from exo.shared.types.common import NodeId +from exo.shared.types.memory import Memory +from exo.shared.types.worker.instances import ( + BoundInstance, + InstanceId, + MlxRingInstance, +) +from exo.shared.types.worker.runners import ( + RunnerId, + ShardAssignments, +) +from exo.shared.types.worker.shards import ( + PipelineShardMetadata, + ShardMetadata, +) +from exo.worker.engines.mlx import mtp_probe, utils_mlx + + +def _target_card(*, native_mtp: NativeMTPConfig | None = None) -> ModelCard: + return ModelCard( + model_id=ModelId("Youssofal/Qwen3.6-27B-MTPLX-Optimized-Quality"), + storage_size=Memory.from_gb(1.0), + n_layers=64, + hidden_size=5120, + supports_tensor=True, + tasks=[ModelTask.TextGeneration], + backends=[Backend.MlxMetal], + 
native_mtp=native_mtp, + ) + + +def _make_single_target_bound_instance(card: ModelCard) -> BoundInstance: + target_node = NodeId() + target_runner_id = RunnerId() + shard = PipelineShardMetadata( + model_card=card, + device_rank=0, + world_size=1, + start_layer=0, + end_layer=64, + n_layers=64, + ) + instance = MlxRingInstance( + instance_id=InstanceId(), + shard_assignments=ShardAssignments( + model_id=card.model_id, + runner_to_shard={target_runner_id: cast(ShardMetadata, shard)}, + node_to_runner={target_node: target_runner_id}, + ), + hosts_by_node={target_node: []}, + ephemeral_port=60000, + ) + return BoundInstance( + instance=instance, + bound_runner_id=target_runner_id, + bound_node_id=target_node, + ) + + +_LoadResult = tuple[object, object, object] + + +def _consume_generator( + gen: Generator[object, None, _LoadResult], +) -> _LoadResult: + while True: + try: + next(gen) + except StopIteration as stop: + return cast(_LoadResult, stop.value) + + +def _recoverable_probe() -> mtp_probe.MtpProbeResult: + return mtp_probe.MtpProbeResult( + model_declares_mtp=True, + mtp_tensors_found=True, + mtp_format=mtp_probe.MtpFormat.MTPLX_SEPARATE_FILE, + mtp_count=29, + mtp_path="/tmp/fake/mtp.safetensors", + mtp_tensor_keys=(), + ) + + +def _stripped_probe() -> mtp_probe.MtpProbeResult: + return mtp_probe.MtpProbeResult( + model_declares_mtp=True, + mtp_tensors_found=False, + mtp_format=mtp_probe.MtpFormat.STRIPPED, + mtp_count=0, + mtp_path=None, + mtp_tensor_keys=(), + ) + + +def _patch_loader_routing( + monkeypatch: pytest.MonkeyPatch, *, probe_recoverable: bool +) -> dict[str, bool]: + """Stub the heavy load call sites and return a flag dict recording + which loader fired.""" + calls = {"stock_load_model": False, "native_load_mtp_model": False} + + fake_model: object = object() + fake_inner: object = object() + fake_tokenizer: object = object() + + def fake_stock_load( + _path: object, **_kwargs: object + ) -> tuple[object, dict[str, object]]: + 
calls["stock_load_model"] = True + return fake_model, {} + + monkeypatch.setattr(utils_mlx, "load_model", fake_stock_load) + + # The native loader is imported lazily inside ``load_mlx_items``; patch + # it on the source module so the lazy import sees our stub. + from exo.worker.engines.mlx.vendor import qwen3_5_mtp_loader as _mtp_loader_mod + + def fake_native_load( + _path: object, **_kwargs: object + ) -> tuple[object, dict[str, object]]: + calls["native_load_mtp_model"] = True + return fake_model, {} + + monkeypatch.setattr(_mtp_loader_mod, "load_mtp_model", fake_native_load) + + def fake_probe(_path: object) -> mtp_probe.MtpProbeResult: + return _recoverable_probe() if probe_recoverable else _stripped_probe() + + monkeypatch.setattr(mtp_probe, "probe_mtp_weights", fake_probe) + + def fake_inner_model(_model: object) -> object: + return fake_inner + + def fake_layers(_inner: object) -> list[object]: + return [] + + def fake_get_tokenizer(_path: object, _shard: object) -> object: + return fake_tokenizer + + def fake_set_wired_limit(_size: object) -> None: + return None + + def fake_build_model_path(_model_id: object) -> Path: + return Path("/tmp/fake-model-path") + + monkeypatch.setattr(utils_mlx, "get_inner_model", fake_inner_model) + monkeypatch.setattr(utils_mlx, "get_layers", fake_layers) + monkeypatch.setattr(utils_mlx, "get_tokenizer", fake_get_tokenizer) + monkeypatch.setattr(utils_mlx, "set_wired_limit_for_model", fake_set_wired_limit) + monkeypatch.setattr(utils_mlx, "build_model_path", fake_build_model_path) + + import mlx.core as mx_core + + def fake_eval(*_args: object, **_kwargs: object) -> None: + return None + + def fake_clear_cache() -> None: + return None + + monkeypatch.setattr(mx_core, "eval", fake_eval) + monkeypatch.setattr(mx_core, "clear_cache", fake_clear_cache) + + return calls + + +class TestNativeMTPLoaderRouting: + """``load_mlx_items`` dispatches the native MTP loader iff the card + declares ``native_mtp``, the probe says recoverable, 
and the + placement is single-node. + """ + + def test_native_path_when_card_declares_and_single_node( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + calls = _patch_loader_routing(monkeypatch, probe_recoverable=True) + card = _target_card(native_mtp=NativeMTPConfig(num_layers=1)) + bound = _make_single_target_bound_instance(card) + _consume_generator( + cast( + Generator[object, None, _LoadResult], + utils_mlx.load_mlx_items(bound, group=None), + ) + ) + assert calls["native_load_mtp_model"] is True, ( + "card declares native_mtp and probe says recoverable: the " + "loader MUST dispatch through load_mtp_model" + ) + assert calls["stock_load_model"] is False, ( + "native loader handles weight loading end-to-end; stock " + "load_model must not be called when the native path fires" + ) + + def test_stock_path_when_card_lacks_native_mtp( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + calls = _patch_loader_routing(monkeypatch, probe_recoverable=True) + card = _target_card(native_mtp=None) + bound = _make_single_target_bound_instance(card) + _consume_generator( + cast( + Generator[object, None, _LoadResult], + utils_mlx.load_mlx_items(bound, group=None), + ) + ) + assert calls["stock_load_model"] is True + assert calls["native_load_mtp_model"] is False, ( + "the native MTP loader must never fire for cards that don't opt in" + ) + + def test_stock_fallback_when_probe_says_stripped( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + """Card declares native_mtp but probe returns STRIPPED → stock path + (silent degrade so the request still completes).""" + calls = _patch_loader_routing(monkeypatch, probe_recoverable=False) + card = _target_card(native_mtp=NativeMTPConfig(num_layers=1)) + bound = _make_single_target_bound_instance(card) + _consume_generator( + cast( + Generator[object, None, _LoadResult], + utils_mlx.load_mlx_items(bound, group=None), + ) + ) + assert calls["stock_load_model"] is True, ( + "probe says STRIPPED → fall back to stock loader 
so the " + "request still completes (silent degrade, not a hard fail)" + ) + assert calls["native_load_mtp_model"] is False + + +class TestNativeMtpLoaderEligible: + """``_native_mtp_loader_eligible`` is the loader-side gate. It is only + called from the single-node (``group is None``) branch.""" + + def test_true_for_declared_card_with_recoverable_probe( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + def fake_probe(_path: object) -> mtp_probe.MtpProbeResult: + return _recoverable_probe() + + monkeypatch.setattr(mtp_probe, "probe_mtp_weights", fake_probe) + card = _target_card(native_mtp=NativeMTPConfig(num_layers=1)) + assert utils_mlx._native_mtp_loader_eligible(card, Path("/tmp/fake")) is True + + def test_false_without_native_mtp_skips_probe(self) -> None: + """Cards that don't declare native_mtp short-circuit before the + probe (which would touch disk).""" + card = _target_card(native_mtp=None) + assert ( + utils_mlx._native_mtp_loader_eligible( + card, Path("/nonexistent/never-probed") + ) + is False + ) + + def test_false_when_probe_says_stripped( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + def fake_probe(_path: object) -> mtp_probe.MtpProbeResult: + return _stripped_probe() + + monkeypatch.setattr(mtp_probe, "probe_mtp_weights", fake_probe) + card = _target_card(native_mtp=NativeMTPConfig(num_layers=1)) + assert utils_mlx._native_mtp_loader_eligible(card, Path("/tmp/fake")) is False