Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions app/EXO/EXO/ExoProcessController.swift
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ private let hfEndpointKey = "EXOHFEndpoint"
private let enableImageModelsKey = "EXOEnableImageModels"
private let offlineModeKey = "EXOOfflineMode"
private let fastSynchEnabledKey = "EXOFastSynchEnabled"
private let nativeMTPEnabledKey = "EXONativeMTPEnabled"
private let onboardingCompletedKey = "EXOOnboardingCompleted"
private let defaultModelsDirKey = "EXODefaultModelsDir"
private let additionalModelsDirsKey = "EXOAdditionalModelsDirs"
Expand Down Expand Up @@ -109,6 +110,17 @@ final class ExoProcessController: ObservableObject {
UserDefaults.standard.set(fastSynchEnabled, forKey: fastSynchEnabledKey)
}
}
/// Whether native MTP is enabled.
///
/// The initial value is read from `UserDefaults` under `nativeMTPEnabledKey`;
/// when nothing has been stored for that key yet, the setting starts out as
/// `true`. Every assignment writes the new value back to `UserDefaults`.
@Published var nativeMTPEnabled: Bool = {
    let defaults = UserDefaults.standard
    // A missing entry means the preference was never stored: default to on.
    guard defaults.object(forKey: nativeMTPEnabledKey) != nil else {
        return true
    }
    return defaults.bool(forKey: nativeMTPEnabledKey)
}()
{
    didSet {
        UserDefaults.standard.set(nativeMTPEnabled, forKey: nativeMTPEnabledKey)
    }
}
@Published var defaultModelsDir: String = {
return UserDefaults.standard.string(forKey: defaultModelsDirKey) ?? ""
}()
Expand Down Expand Up @@ -366,6 +378,7 @@ final class ExoProcessController: ObservableObject {
environment["EXO_OFFLINE"] = "true"
}
environment["EXO_FAST_SYNCH"] = fastSynchEnabled ? "true" : "false"
environment["EXO_NATIVE_MTP_ENABLED"] = nativeMTPEnabled ? "1" : "0"

var paths: [String] = []
if let existing = environment["PATH"], !existing.isEmpty {
Expand Down
13 changes: 13 additions & 0 deletions app/EXO/EXO/Views/SettingsView.swift
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ struct SettingsView: View {
@State private var pendingEnableImageModels = false
@State private var pendingOfflineMode = false
@State private var pendingFastSynchEnabled = false
@State private var pendingNativeMTPEnabled = true
@State private var pendingDefaultModelsDir: String = ""
@State private var pendingAdditionalModelsDirs: String = ""
@State private var pendingReadOnlyModelsDirs: String = ""
Expand Down Expand Up @@ -54,6 +55,7 @@ struct SettingsView: View {
pendingEnableImageModels = controller.enableImageModels
pendingOfflineMode = controller.offlineMode
pendingFastSynchEnabled = controller.fastSynchEnabled
pendingNativeMTPEnabled = controller.nativeMTPEnabled
pendingDefaultModelsDir = controller.defaultModelsDir
pendingAdditionalModelsDirs = controller.additionalModelsDirs
pendingReadOnlyModelsDirs = controller.readOnlyModelsDirs
Expand Down Expand Up @@ -131,6 +133,15 @@ struct SettingsView: View {
.foregroundColor(.secondary)
}

Section {
Toggle("Use Native MTP", isOn: $pendingNativeMTPEnabled)
Text(
"Enable native multi-token-prediction speculative decoding for supported model checkpoints. On by default."
)
.font(.caption)
.foregroundColor(.secondary)
}

Section {
HStack {
Spacer()
Expand Down Expand Up @@ -607,6 +618,7 @@ struct SettingsView: View {

/// `true` when either model-related pending toggle (image models or native
/// MTP) differs from the value currently held by the controller.
private var hasModelChanges: Bool {
    let imageModelsChanged = pendingEnableImageModels != controller.enableImageModels
    let nativeMTPChanged = pendingNativeMTPEnabled != controller.nativeMTPEnabled
    return imageModelsChanged || nativeMTPChanged
}

private var hasAdvancedChanges: Bool {
Expand All @@ -630,6 +642,7 @@ struct SettingsView: View {

/// Copies the pending model-related toggles (image models and native MTP)
/// onto the controller, then calls `restartIfRunning()`.
private func applyModelSettings() {
controller.enableImageModels = pendingEnableImageModels
controller.nativeMTPEnabled = pendingNativeMTPEnabled
// NOTE(review): presumably the restart is what makes a running process pick
// up the new values — confirm against restartIfRunning().
restartIfRunning()
}

Expand Down
25 changes: 25 additions & 0 deletions dashboard/src/lib/components/ModelPickerGroup.svelte
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
capabilities?: string[];
family?: string;
is_custom?: boolean;
native_mtp?: { default_k: number; max_k: number } | null;
}

interface ModelGroup {
Expand Down Expand Up @@ -127,6 +128,11 @@
}
}

// Whether any variant in this group ships native MTP (informational badge).
const hasNativeMtp = $derived(
group.variants.some((v) => v.native_mtp != null),
);

// Check if this group's model is currently selected (for single-variant groups)
const isMainSelected = $derived(
!group.hasMultipleVariants &&
Expand Down Expand Up @@ -309,6 +315,15 @@
</svg>
{/if}
{/each}
<!-- Native MTP badge (tiny, informational) -->
{#if hasNativeMtp}
<span
class="px-1 text-[9px] leading-[1.1] font-mono rounded-sm bg-emerald-500/15 text-emerald-300 flex-shrink-0"
title="Native MTP speculative decoding built into this checkpoint"
>
MTP
</span>
{/if}
</div>
</div>

Expand Down Expand Up @@ -523,6 +538,16 @@
{variant.quantization || "default"}
</span>

<!-- Native MTP badge (tiny, informational) -->
{#if variant.native_mtp}
<span
class="px-1 text-[9px] leading-[1.1] font-mono rounded-sm bg-emerald-500/15 text-emerald-300 flex-shrink-0"
title={`Native MTP speculative decoding (default K${variant.native_mtp.default_k}, max K${variant.native_mtp.max_k})`}
>
MTP
</span>
{/if}

<!-- Size -->
<span
class="text-xs font-mono flex-1 {getSizeClassForFitStatus(
Expand Down
1 change: 1 addition & 0 deletions dashboard/src/lib/components/ModelPickerModal.svelte
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
is_custom?: boolean;
tasks?: string[];
hugging_face_id?: string;
native_mtp?: { default_k: number; max_k: number } | null;
}

interface ModelGroup {
Expand Down
1 change: 1 addition & 0 deletions dashboard/src/routes/+page.svelte
Original file line number Diff line number Diff line change
Expand Up @@ -783,6 +783,7 @@
quantization?: string;
base_model?: string;
capabilities?: string[];
native_mtp?: { default_k: number; max_k: number } | null;
}>
>([]);
type ModelMemoryFitStatus =
Expand Down
39 changes: 39 additions & 0 deletions resources/inference_model_cards/Jundot--Qwen3.6-27B-oQ8-mtp.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
model_id = "Jundot/Qwen3.6-27B-oQ8-mtp"
n_layers = 64
hidden_size = 5120
num_key_value_heads = 4
supports_tensor = true
tasks = ["TextGeneration"]
family = "qwen"
quantization = "oQ8 MTP"
base_model = "Qwen3.6 27B"
capabilities = ["text"]
reasoning_dialect = "none"
context_length = 262144
backends = ["MlxMetal", "MlxCuda", "MlxCpu"]
[storage_size]
in_bytes = 30001641398

# Native Qwen3.6 MTP weights are embedded in the checkpoint shards.
# Measured-best fixed K on local M5 Max: K=2, capped at K=3.
[native_mtp]
num_layers = 1
default_k = 2
max_k = 3

# Source: https://huggingface.co/Qwen/Qwen3.6-27B (best-practices)
[sampling_defaults]
temperature = 0.6
top_p = 0.95
top_k = 20
min_p = 0.0
repetition_penalty = 1.0
presence_penalty = 1.5

[sampling_defaults.non_thinking]
temperature = 0.7
top_p = 0.8
top_k = 20
min_p = 0.0
repetition_penalty = 1.0
presence_penalty = 1.5
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
model_id = "alvarolizama/Qwen3.6-35B-A3B-oQ8-mtp"
n_layers = 40
hidden_size = 2048
num_key_value_heads = 2
supports_tensor = true
tasks = ["TextGeneration"]
family = "qwen"
quantization = "oQ8 MTP"
base_model = "Qwen3.6 35B A3B"
capabilities = ["text"]
reasoning_dialect = "none"
context_length = 262144
backends = ["MlxMetal", "MlxCuda", "MlxCpu"]
[storage_size]
in_bytes = 38604987852

# Native Qwen3.6 MoE MTP weights are embedded in the checkpoint shards.
# Measured-best fixed K on local M5 Max: K=1, capped at K=3.
[native_mtp]
num_layers = 1
default_k = 1
max_k = 3

# Source: https://huggingface.co/Qwen/Qwen3.6-35B-A3B (best-practices)
[sampling_defaults]
temperature = 0.6
top_p = 0.95
top_k = 20
min_p = 0.0
repetition_penalty = 1.0
presence_penalty = 1.5

[sampling_defaults.non_thinking]
temperature = 0.7
top_p = 0.8
top_k = 20
min_p = 0.0
repetition_penalty = 1.0
presence_penalty = 1.5
9 changes: 9 additions & 0 deletions src/exo/api/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@
InstanceLinkResponse,
ModelList,
ModelListModel,
NativeMTPModelInfo,
PlaceInstanceParams,
PlacementPreview,
PlacementPreviewResponse,
Expand Down Expand Up @@ -1771,6 +1772,14 @@ async def get_models(self, status: str | None = Query(default=None)) -> ModelLis
capabilities=card.capabilities,
reasoning_dialect=card.reasoning_dialect,
context_length=card.context_length,
native_mtp=(
NativeMTPModelInfo(
default_k=card.native_mtp.default_k,
max_k=card.native_mtp.max_k,
)
if card.native_mtp is not None
else None
),
)
for card in cards
]
Expand Down
78 changes: 78 additions & 0 deletions src/exo/api/tests/test_generation_stats.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
"""Regression tests for speculative-decode generation telemetry."""

from __future__ import annotations

from exo.api.types.api import GenerationStats
from exo.shared.types.memory import Memory


def _stats(
    *,
    generation_tokens: int = 100,
    accepted: int = 25,
    proposed: int = 0,
    drafter_model_id: str | None = None,
    draft_mode: str | None = None,
    drafter_kind: str | None = None,
    num_draft_tokens: int | None = None,
) -> GenerationStats:
    """Build a ``GenerationStats`` for the acceptance-metric tests.

    Prompt/throughput/memory fields are pinned to fixed values; only the
    speculative-decoding fields are configurable, all keyword-only.
    """
    zero_memory = Memory.from_bytes(0)
    return GenerationStats(
        # Fixed fields the tests in this module do not vary.
        prompt_tps=0.0,
        prompt_tokens=10,
        generation_tps=10.0,
        peak_memory_usage=zero_memory,
        # Fields under test.
        generation_tokens=generation_tokens,
        accepted_draft_tokens=accepted,
        proposed_draft_tokens=proposed,
        num_draft_tokens=num_draft_tokens,
        drafter_model_id=drafter_model_id,
        draft_mode=draft_mode,  # pyright: ignore[reportArgumentType]
        drafter_kind=drafter_kind,  # pyright: ignore[reportArgumentType]
    )


def test_acceptance_fraction_is_none_for_explicit_none_mode() -> None:
    """A run with ``draft_mode="none"`` reports no acceptance fraction."""
    no_draft_run = _stats(accepted=0, draft_mode="none")
    assert no_draft_run.drafter_acceptance_fraction is None


def test_acceptance_fraction_reports_for_model_runs() -> None:
    """With a drafter model, 40 accepted of the default 100 generated tokens gives 0.40."""
    model_run = _stats(
        accepted=40,
        drafter_model_id="some-org/drafter-7b",
        draft_mode="model",
    )
    fraction = model_run.drafter_acceptance_fraction
    assert fraction == 0.40


def test_acceptance_metrics_report_for_native_mtp_runs_without_drafter_id() -> None:
    """Native-MTP runs have no drafter model id but still report both metrics.

    25 accepted of 100 generated tokens gives a fraction of 0.25; 25 accepted
    of 40 proposed gives a rate of 0.625.
    """
    native_mtp_run = _stats(
        accepted=25,
        proposed=40,
        num_draft_tokens=2,
        drafter_kind="native_mtp",
        drafter_model_id=None,
        draft_mode="model",
    )
    assert native_mtp_run.drafter_acceptance_fraction == 0.25
    assert native_mtp_run.drafter_acceptance_rate == 0.625


def test_acceptance_fraction_legacy_payload_without_draft_mode() -> None:
    """Payloads without ``draft_mode`` report a fraction only when a drafter id is set."""
    with_drafter = _stats(
        accepted=10,
        drafter_model_id="legacy-org/drafter",
        draft_mode=None,
    )
    without_drafter = _stats(
        accepted=0,
        drafter_model_id=None,
        draft_mode=None,
    )
    assert with_drafter.drafter_acceptance_fraction is not None
    assert without_drafter.drafter_acceptance_fraction is None


def test_acceptance_fraction_zero_generation_tokens_returns_none() -> None:
    """With zero generated tokens the acceptance fraction is ``None``."""
    empty_run = _stats(draft_mode="model", accepted=0, generation_tokens=0)
    assert empty_run.drafter_acceptance_fraction is None
63 changes: 63 additions & 0 deletions src/exo/api/tests/test_model_list_native_mtp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
import pytest

from exo.api.main import API
from exo.shared.models import model_cards
from exo.shared.models.model_cards import (
ModelCard,
ModelTask,
NativeMTPConfig,
)
from exo.shared.types.backends import Backend
from exo.shared.types.common import ModelId
from exo.shared.types.memory import Memory
from exo.shared.types.state import State


def _native_mtp_card(model_id: str, *, default_k: int, max_k: int) -> ModelCard:
    """Return a minimal text-generation ``ModelCard`` carrying a native-MTP config.

    Size and architecture fields are placeholder values; only ``model_id``,
    ``default_k`` and ``max_k`` come from the caller.
    """
    mtp_config = NativeMTPConfig(
        num_layers=1,
        default_k=default_k,
        max_k=max_k,
    )
    return ModelCard(
        model_id=ModelId(model_id),
        storage_size=Memory.from_mb(1),
        n_layers=1,
        hidden_size=1,
        supports_tensor=True,
        tasks=[ModelTask.TextGeneration],
        backends=[Backend.MlxMetal],
        native_mtp=mtp_config,
    )


@pytest.mark.asyncio
async def test_models_response_includes_native_mtp_for_native_mtp_cards(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """``get_models`` exposes each card's native-MTP ``default_k`` and ``max_k``."""
    dense_id = "Jundot/Qwen3.6-27B-oQ8-mtp"
    moe_id = "alvarolizama/Qwen3.6-35B-A3B-oQ8-mtp"
    cards = [
        _native_mtp_card(dense_id, default_k=2, max_k=3),
        _native_mtp_card(moe_id, default_k=1, max_k=3),
    ]

    async def _fake_list_all() -> list[ModelCard]:
        return cards

    # Bypass API.__init__ and set only the `state` attribute on the bare instance.
    api = API.__new__(API)
    api.state = State()
    monkeypatch.setattr(model_cards.card_cache, "list_all", _fake_list_all)

    response = await api.get_models()

    listed = {entry.id: entry for entry in response.data}
    dense_mtp = listed[dense_id].native_mtp
    moe_mtp = listed[moe_id].native_mtp
    for info, (expected_default_k, expected_max_k) in (
        (dense_mtp, (2, 3)),
        (moe_mtp, (1, 3)),
    ):
        assert info is not None
        assert info.default_k == expected_default_k
        assert info.max_k == expected_max_k
Loading