From f383ef0ad9457f30df008182cc1b0145d9477096 Mon Sep 17 00:00:00 2001 From: jw-wcv <101585096+jw-wcv@users.noreply.github.com> Date: Wed, 6 May 2026 17:19:30 -0700 Subject: [PATCH 1/2] Add drafter_model_id to ModelCard; plumb draft_model through mlx_generate MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds the surface-level support for speculative decoding via mlx_lm's stream_generate(draft_model=...) on the single-device generation path: - `ModelCard.drafter_model_id: ModelId | None`: declarative pointer to a drafter model that runners may load alongside the target. The drafter must share a tokenizer with the target; this is the caller's responsibility to enforce. - `mlx_generate(draft_model=...)`: forwarded to `stream_generate` when `group is None` (single-device). Distributed-mode draft is dropped explicitly, since mlx_lm's speculative decoding does not yet plumb through tensor-parallel groups. - Eight Gemma 4 model cards (gemma-4-26b-a4b-it and gemma-4-31b-it, 4bit/6bit/8bit/bf16) declare gemma-4-e2b-it (matching quant) as their drafter. The Gemma 4 family shares a tokenizer across e2b/e4b/26b/31b, so e2b is a valid drafter. Drafter loading at builder/runner bootstrap is intentionally not in this patch — keeping the diff focused on the model-card schema and the single-device generate plumbing. Wiring drafter download and load_drafter() into MlxBuilder is straightforward follow-up work. Tests: - test_model_cards_drafter.py: 4 tests covering default-None, Gemma 4 31b/26b drafter pointers, and round-trip of an explicit value. --- ...lx-community--gemma-4-26b-a4b-it-4bit.toml | 1 + ...lx-community--gemma-4-26b-a4b-it-6bit.toml | 1 + ...lx-community--gemma-4-26b-a4b-it-8bit.toml | 1 + ...lx-community--gemma-4-26b-a4b-it-bf16.toml | 1 + .../mlx-community--gemma-4-31b-it-4bit.toml | 1 + .../mlx-community--gemma-4-31b-it-6bit.toml | 1 + .../mlx-community--gemma-4-31b-it-8bit.toml | 1 + .../mlx-community--gemma-4-31b-it-bf16.toml | 1 + src/exo/shared/models/model_cards.py | 5 ++ .../shared/tests/test_model_cards_drafter.py | 72 +++++++++++++++++++ .../worker/engines/mlx/generator/generate.py | 8 +++ 11 files changed, 93 insertions(+) create mode 100644 src/exo/shared/tests/test_model_cards_drafter.py diff --git a/resources/inference_model_cards/mlx-community--gemma-4-26b-a4b-it-4bit.toml b/resources/inference_model_cards/mlx-community--gemma-4-26b-a4b-it-4bit.toml index 51be323ec2..863203b743 100644 --- a/resources/inference_model_cards/mlx-community--gemma-4-26b-a4b-it-4bit.toml +++ b/resources/inference_model_cards/mlx-community--gemma-4-26b-a4b-it-4bit.toml @@ -8,6 +8,7 @@ family = "gemma" quantization = "4bit" base_model = "Gemma 4 26B A4B" capabilities = ["text", "vision"] +drafter_model_id = "mlx-community/gemma-4-e2b-it-4bit" context_length = 262144 diff --git a/resources/inference_model_cards/mlx-community--gemma-4-26b-a4b-it-6bit.toml b/resources/inference_model_cards/mlx-community--gemma-4-26b-a4b-it-6bit.toml index c984d44b7d..32a0a84d56 100644 --- a/resources/inference_model_cards/mlx-community--gemma-4-26b-a4b-it-6bit.toml +++ b/resources/inference_model_cards/mlx-community--gemma-4-26b-a4b-it-6bit.toml @@ -8,6 +8,7 @@ family = "gemma" quantization = "6bit" base_model = "Gemma 4 26B A4B" capabilities = ["text", "vision"] +drafter_model_id = "mlx-community/gemma-4-e2b-it-6bit" context_length = 262144 diff --git a/resources/inference_model_cards/mlx-community--gemma-4-26b-a4b-it-8bit.toml b/resources/inference_model_cards/mlx-community--gemma-4-26b-a4b-it-8bit.toml index fe2583668c..3201ec8283 100644 --- a/resources/inference_model_cards/mlx-community--gemma-4-26b-a4b-it-8bit.toml +++ b/resources/inference_model_cards/mlx-community--gemma-4-26b-a4b-it-8bit.toml @@ -8,6 +8,7 @@ family = "gemma" quantization = "8bit" base_model = "Gemma 4 26B A4B" capabilities = ["text", "vision"] +drafter_model_id = "mlx-community/gemma-4-e2b-it-8bit" context_length = 262144 diff --git a/resources/inference_model_cards/mlx-community--gemma-4-26b-a4b-it-bf16.toml b/resources/inference_model_cards/mlx-community--gemma-4-26b-a4b-it-bf16.toml index ea4dbbfc59..39ea210a64 100644 --- a/resources/inference_model_cards/mlx-community--gemma-4-26b-a4b-it-bf16.toml +++ b/resources/inference_model_cards/mlx-community--gemma-4-26b-a4b-it-bf16.toml @@ -8,6 +8,7 @@ family = "gemma" quantization = "bf16" base_model = "Gemma 4 26B A4B" capabilities = ["text", "vision"] +drafter_model_id = "mlx-community/gemma-4-e2b-it-bf16" context_length = 262144 diff --git a/resources/inference_model_cards/mlx-community--gemma-4-31b-it-4bit.toml b/resources/inference_model_cards/mlx-community--gemma-4-31b-it-4bit.toml index cb8e63580f..87a7584cbb 100644 --- a/resources/inference_model_cards/mlx-community--gemma-4-31b-it-4bit.toml +++ b/resources/inference_model_cards/mlx-community--gemma-4-31b-it-4bit.toml @@ -8,6 +8,7 @@ family = "gemma" quantization = "4bit" base_model = "Gemma 4 31B" capabilities = ["text", "vision"] +drafter_model_id = "mlx-community/gemma-4-e2b-it-4bit" context_length = 262144 diff --git a/resources/inference_model_cards/mlx-community--gemma-4-31b-it-6bit.toml b/resources/inference_model_cards/mlx-community--gemma-4-31b-it-6bit.toml index 845620626d..0e0314e119 100644 --- a/resources/inference_model_cards/mlx-community--gemma-4-31b-it-6bit.toml +++ b/resources/inference_model_cards/mlx-community--gemma-4-31b-it-6bit.toml @@ -8,6 +8,7 @@ family = "gemma" quantization = "6bit" base_model = "Gemma 4 31B" capabilities = ["text", "vision"] +drafter_model_id = "mlx-community/gemma-4-e2b-it-6bit" context_length = 262144 diff --git a/resources/inference_model_cards/mlx-community--gemma-4-31b-it-8bit.toml b/resources/inference_model_cards/mlx-community--gemma-4-31b-it-8bit.toml index 332a9b0053..0e33f6ff58 100644 --- a/resources/inference_model_cards/mlx-community--gemma-4-31b-it-8bit.toml +++ b/resources/inference_model_cards/mlx-community--gemma-4-31b-it-8bit.toml @@ -8,6 +8,7 @@ family = "gemma" quantization = "8bit" base_model = "Gemma 4 31B" capabilities = ["text", "vision"] +drafter_model_id = "mlx-community/gemma-4-e2b-it-8bit" context_length = 262144 diff --git a/resources/inference_model_cards/mlx-community--gemma-4-31b-it-bf16.toml b/resources/inference_model_cards/mlx-community--gemma-4-31b-it-bf16.toml index 6fc0a2dcaa..1da7e56e9d 100644 --- a/resources/inference_model_cards/mlx-community--gemma-4-31b-it-bf16.toml +++ b/resources/inference_model_cards/mlx-community--gemma-4-31b-it-bf16.toml @@ -8,6 +8,7 @@ family = "gemma" quantization = "bf16" base_model = "Gemma 4 31B" capabilities = ["text", "vision"] +drafter_model_id = "mlx-community/gemma-4-e2b-it-bf16" context_length = 262144 diff --git a/src/exo/shared/models/model_cards.py b/src/exo/shared/models/model_cards.py index 0d1648a7b1..e6c6a7cef9 100644 --- a/src/exo/shared/models/model_cards.py +++ b/src/exo/shared/models/model_cards.py @@ -171,6 +171,11 @@ class ModelCard(FrozenModel): is_custom: bool = False vision: VisionCardConfig | None = None sampling_defaults: SamplingDefaults = Field(default_factory=SamplingDefaults) + # Optional speculative-decoding draft model. When set, runners will load the + # named model alongside the target and pass it as `draft_model` to mlx_lm's + # `stream_generate`, enabling MLX-side speculative decoding. The drafter MUST + # share a tokenizer with the target. + drafter_model_id: ModelId | None = None @model_validator(mode="after") def _autodetect_vision(self) -> "ModelCard": diff --git a/src/exo/shared/tests/test_model_cards_drafter.py b/src/exo/shared/tests/test_model_cards_drafter.py new file mode 100644 index 0000000000..302bcd3368 --- /dev/null +++ b/src/exo/shared/tests/test_model_cards_drafter.py @@ -0,0 +1,72 @@ +"""Tests for the optional `drafter_model_id` field on ModelCard. + +The field declares a speculative-decoding draft model that runners may load +alongside the target. Coverage: +- ModelCard accepts and serialises the field. +- Cards with no drafter declared default to `None`. +- The Gemma 4 large-instruct cards point to the e2b drafter. +""" + +import pytest + +from exo.shared.models.model_cards import ModelCard, ModelId, get_model_cards +from exo.shared.types.memory import Memory + + +@pytest.mark.asyncio +async def test_drafter_model_id_defaults_to_none() -> None: + cards = {card.model_id: card for card in await get_model_cards()} + qwen_id = ModelId("mlx-community/Qwen3-30B-A3B-4bit") + if qwen_id in cards: + assert cards[qwen_id].drafter_model_id is None + + +@pytest.mark.asyncio +async def test_gemma4_31b_cards_declare_e2b_drafter() -> None: + cards = {card.model_id: card for card in await get_model_cards()} + expectations = { + "mlx-community/gemma-4-31b-it-4bit": "mlx-community/gemma-4-e2b-it-4bit", + "mlx-community/gemma-4-31b-it-6bit": "mlx-community/gemma-4-e2b-it-6bit", + "mlx-community/gemma-4-31b-it-8bit": "mlx-community/gemma-4-e2b-it-8bit", + "mlx-community/gemma-4-31b-it-bf16": "mlx-community/gemma-4-e2b-it-bf16", + } + for target_str, expected_drafter_str in expectations.items(): + target_id = ModelId(target_str) + assert target_id in cards, f"{target_id} card missing" + card = cards[target_id] + assert card.drafter_model_id == ModelId(expected_drafter_str), ( + f"{target_id} drafter mismatch: got {card.drafter_model_id!r}" + ) + + +@pytest.mark.asyncio +async def test_gemma4_26b_cards_declare_e2b_drafter() -> None: + cards = {card.model_id: card for card in await get_model_cards()} + expectations = { + "mlx-community/gemma-4-26b-a4b-it-4bit": "mlx-community/gemma-4-e2b-it-4bit", + "mlx-community/gemma-4-26b-a4b-it-6bit": "mlx-community/gemma-4-e2b-it-6bit", + "mlx-community/gemma-4-26b-a4b-it-8bit": "mlx-community/gemma-4-e2b-it-8bit", + "mlx-community/gemma-4-26b-a4b-it-bf16": "mlx-community/gemma-4-e2b-it-bf16", + } + for target_str, expected_drafter_str in expectations.items(): + target_id = ModelId(target_str) + assert target_id in cards, f"{target_id} card missing" + card = cards[target_id] + assert card.drafter_model_id == ModelId(expected_drafter_str), ( + f"{target_id} drafter mismatch: got {card.drafter_model_id!r}" + ) + + +def test_model_card_explicit_drafter_round_trip() -> None: + card = ModelCard( + model_id=ModelId("mlx-community/test-target"), + storage_size=Memory.from_gb(1.0), + n_layers=12, + hidden_size=768, + supports_tensor=True, + tasks=["TextGeneration"], # pyright: ignore[reportArgumentType] + drafter_model_id=ModelId("mlx-community/test-drafter"), + ) + assert card.drafter_model_id == ModelId("mlx-community/test-drafter") + dump = card.model_dump(exclude_none=True) + assert dump["drafter_model_id"] == "mlx-community/test-drafter" diff --git a/src/exo/worker/engines/mlx/generator/generate.py b/src/exo/worker/engines/mlx/generator/generate.py index 2e3d051251..c7a7612693 100644 --- a/src/exo/worker/engines/mlx/generator/generate.py +++ b/src/exo/worker/engines/mlx/generator/generate.py @@ -540,6 +540,7 @@ def mlx_generate( distributed_prompt_progress_callback: Callable[[], None] | None = None, on_generation_token: Callable[[], None] | None = None, vision_processor: VisionProcessor | None = None, + draft_model: Model | None = None, ) -> Generator[GenerationResponse]: # Ensure that generation stats only contains peak memory for this generation mx.reset_peak_memory() @@ -717,6 +718,12 @@ def mlx_generate( logger.info("Starting decode") mx_barrier(group) + # Speculative decoding via mlx_lm: only enabled in the single-device path + # (group is None). Distributed speculative is not yet plumbed; passing a + # draft_model alongside a non-trivial group would be a no-op, so we drop + # it explicitly to make the caller contract clear. + effective_draft_model = draft_model if group is None else None + for completion_tokens, out in enumerate( stream_generate( model=model, @@ -729,6 +736,7 @@ def mlx_generate( prefill_step_size=1, kv_group_size=KV_GROUP_SIZE, kv_bits=KV_BITS, + draft_model=effective_draft_model, ), start=1, ): From 5dae97de8539e490f59ef2bec166a23088c16423 Mon Sep 17 00:00:00 2001 From: jw-wcv <101585096+jw-wcv@users.noreply.github.com> Date: Sun, 10 May 2026 15:03:29 -0700 Subject: [PATCH 2/2] Drafter abstraction + Gemma 4 MTP + Qwen 3.5/3.6 DFlash + multi-device coupled drafter MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Lands the full speculative-decoding stack on top of the ``drafter_model_id`` ModelCard foundation: 1. Drafter abstraction (``Drafter`` Protocol with ``stream`` / ``metrics`` / ``DraftMode``) and the ``CoupledModelDrafter`` shim around mlx-vlm's ``_mtp_rounds`` / ``_dflash_rounds``. ``GenerationStats.drafter_kind`` ∈ {standard, mtp, dflash, ngram, none} so OpenAI ``CompletionTokensDetails`` + the dashboard surface which speculative path actually dispatched. 2. In-process drafter tuning: K, warmup, KV cache, n-gram strategy. 3. Asymmetric pipelined drafter for uneven-memory clusters -- ``DrafterRunner`` + mx.distributed / socket transports + concurrency. 4. Production hardening: resilience, TP fanout, telemetry, bench. 5. Gemma 4 MTP coupled drafter (Phase 1-3). New ``ModelCard.coupled_drafter`` field; ``mlx-vlm>=0.5.0`` loader + per-kind target-side hook attachment (``attach_mtp_hooks`` for Gemma 4). 31B and 26B-A4B at all four quants declare the coupled MTP drafter. Headline: Gemma 4 31B 4bit + MTP drafter at T=0 jumps from 13.8 t/s to 24.7 t/s with byte-identical output (single M3 Ultra). 6. Qwen 3.5 / 3.6 DFlash coupled drafter. Vendored ``forward_with_capture`` + ``rollback_speculative_cache`` for the hybrid attention / gated-delta-net architecture. The drafter consumes captured hidden states + an 11-tuple ``GdnState`` and replays them on rejection. Headlines (median over 10 runs per A/B side, T=0): Qwen 3.5 4B 8bit (dense, wc-smbp) 97.24 -> 404.38 t/s 4.16x Qwen 3.6 27B 8bit (dense, wc-smbpt) 14.98 -> 49.13 t/s 3.28x Qwen 3.6 35B-A3B 8bit (MoE, wc-smbpt) 87.70 -> 377.49 t/s 4.30x Qwen 3.5 122B-A10B 8bit (MoE, TP2 RDMA) 52.61 -> 159.00 t/s 3.02x 7. Multi-device coupled drafter dispatch (tensor-parallel). The previous loader hard-coded ``if group is None`` and the generator hard-coded ``draft_mode = "none"`` whenever ``group is not None``, so the coupled drafter never ran on TP placements -- exactly the regime 122B-class targets live in. Lifted via: * ``_try_load_collocated_drafter`` is now called from both the single-device and the symmetric multi-rank branches. The multi-device call passes ``allow_standard_drafter_fallback= False`` because the generator still can't dispatch standard drafters through ``group``, so a loaded standard drafter would only waste memory. * ``mlx_generate`` only forces ``draft_mode = "none"`` for multi-device when ``coupled_drafter_eligible`` is false. * ``builder.py`` selects ``SequentialGenerator`` (speculative-capable) when ``coupled_drafter_dispatchable`` is true, even with ``group is not None``. Correctness: each TP rank's per-rank ``__call__`` reduces its output to the full hidden state (via the in-layer ``ShardedToAllLinear`` / ``ShardedMoE`` all-sums), so the replicated drafter consumes an identical hidden state and produces identical draft tokens / bonus samples under the shared ``mx.random.seed(seed)`` set at the top of each generation step. 122B-A10B + JACCL/RDMA across two MacBook Pros validates the path end-to-end. 8. Single-file ``safetensors.index.json`` bootstrap. DFlash drafters that ship with just ``model.safetensors`` no longer trip the shard downloader. 9. Bench results + reports. ``bench/results/{mtp,dflash}/REPORT.md`` document the A/B methodology and headline numbers. Raw per-request gen_tps + acceptance JSON committed for reproducibility. Tests: 1056 passing, basedpyright 0 errors project-wide, ruff clean. --- .gitignore | 1 - .mlx_typings/mlx_lm/models/cache.pyi | 24 +- .mlx_typings/mlx_lm/models/gemma4_text.pyi | 62 +- Cargo.lock | 42 +- bench/eval_tool_calls.py | 5 +- bench/exo_bench.py | 5 +- bench/exo_eval.py | 5 +- bench/harness.py | 625 ++++ bench/prefill_decode_bench.py | 5 +- bench/results/dflash/REPORT.md | 490 +++ ...5-122b-a10b-mlx-8bit-tp2-jaccl-dflash.json | 213 ++ ...b-a10b-mlx-8bit-tp2-jaccl-target-only.json | 213 ++ .../dflash/qwen3.5-4b-mlx-8bit-dflash.json | 213 ++ .../qwen3.5-4b-mlx-8bit-target-only.json | 213 ++ .../dflash/qwen3.6-27b-mlx-8bit-dflash.json | 213 ++ .../qwen3.6-27b-mlx-8bit-target-only.json | 213 ++ .../qwen3.6-35b-a3b-mlx-8bit-dflash.json | 213 ++ .../qwen3.6-35b-a3b-mlx-8bit-target-only.json | 213 ++ bench/results/mtp/REPORT.md | 121 + .../src/exo_bench}/__init__.py | 0 pyproject.toml | 23 +- ...mlx-community--Qwen3.5-122B-A10B-8bit.toml | 11 + .../mlx-community--Qwen3.5-4B-MLX-8bit.toml | 41 + .../mlx-community--Qwen3.6-27B-8bit.toml | 16 +- .../mlx-community--Qwen3.6-35B-A3B-8bit.toml | 21 +- ...lx-community--gemma-4-26b-a4b-it-4bit.toml | 3 +- ...lx-community--gemma-4-26b-a4b-it-6bit.toml | 3 +- ...lx-community--gemma-4-26b-a4b-it-8bit.toml | 3 +- ...lx-community--gemma-4-26b-a4b-it-bf16.toml | 3 +- .../mlx-community--gemma-4-31b-it-4bit.toml | 3 +- .../mlx-community--gemma-4-31b-it-6bit.toml | 3 +- .../mlx-community--gemma-4-31b-it-8bit.toml | 3 +- .../mlx-community--gemma-4-31b-it-bf16.toml | 3 +- rust/exo_pyo3_bindings/Cargo.toml | 3 - rust/exo_pyo3_bindings/exo_pyo3_bindings.pyi | 44 - rust/exo_pyo3_bindings/pyproject.toml | 2 +- rust/exo_pyo3_bindings/src/lib.rs | 3 - rust/exo_pyo3_bindings/src/pidfile.rs | 87 - rust/exo_pyo3_bindings/tests/test_python.py | 13 - src/exo/api/adapters/chat_completions.py | 13 + src/exo/api/adapters/responses.py | 160 +- src/exo/api/main.py | 414 ++- src/exo/api/tests/test_agent_endpoints.py | 424 +++ ...test_chat_completion_request_validation.py | 108 + .../tests/test_chat_completions_adapter.py | 93 + src/exo/api/types/__init__.py | 2 + src/exo/api/types/api.py | 202 +- src/exo/api/types/openai_responses.py | 1 + .../api/types/tests/test_generation_stats.py | 96 + src/exo/diagnostics.py | 194 ++ src/exo/download/coordinator.py | 889 ++++- src/exo/download/download_utils.py | 158 +- src/exo/download/impl_shard_downloader.py | 16 +- src/exo/download/peer_download.py | 271 ++ src/exo/download/peer_file_server.py | 376 +++ src/exo/download/peer_shard_downloader.py | 510 +++ src/exo/download/peer_state.py | 129 + .../tests/test_download_status_not_lost.py | 30 +- .../download/tests/test_drafter_download.py | 2333 +++++++++++++ src/exo/download/tests/test_model_dirs.py | 98 + src/exo/download/tests/test_peer_download.py | 1759 ++++++++++ src/exo/download/tests/test_peer_state.py | 142 + src/exo/main.py | 391 ++- src/exo/master/main.py | 252 +- src/exo/master/placement.py | 1479 ++++++++- src/exo/master/placement_utils.py | 151 +- src/exo/master/tests/test_master.py | 237 +- src/exo/master/tests/test_placement.py | 2951 +++++++++++++++-- .../tests/test_placement_auto_prefill.py | 490 +++ .../test_placement_drafter_asymmetric.py | 1566 +++++++++ .../tests/test_placement_drafter_warning.py | 141 + src/exo/routing/event_router.py | 36 +- src/exo/routing/mdns_announcer.py | 95 + src/exo/routing/router.py | 225 +- .../routing/tests/test_node_id_migration.py | 533 +++ src/exo/shared/apply.py | 59 +- src/exo/shared/constants.py | 21 +- src/exo/shared/election.py | 73 +- src/exo/shared/logging.py | 96 +- src/exo/shared/models/model_cards.py | 227 +- .../test_apply_custom_model_cards.py | 44 - .../test_apply/test_apply_runner_deleted.py | 59 +- .../tests/test_diagnostic_snapshot_config.py | 42 + .../test_drafter_placement_wire_compat.py | 124 + src/exo/shared/tests/test_election.py | 198 ++ .../shared/tests/test_model_cards_drafter.py | 330 +- src/exo/shared/tests/test_xdg_paths.py | 22 +- src/exo/shared/topology.py | 2 +- src/exo/shared/types/commands.py | 12 + src/exo/shared/types/events.py | 56 +- src/exo/shared/types/state.py | 6 +- src/exo/shared/types/text_generation.py | 38 +- src/exo/shared/types/thunderbolt.py | 31 +- src/exo/shared/types/worker/instances.py | 216 +- src/exo/shared/types/worker/shards.py | 31 +- src/exo/utils/async_process.py | 290 -- src/exo/utils/daemon.py | 28 - .../info_gatherer/tests/test_tb_parsing.py | 62 + src/exo/utils/keyed_backoff.py | 4 + src/exo/utils/pidfile.py | 28 - src/exo/utils/ports.py | 114 +- src/exo/utils/power_sampler.py | 53 +- src/exo/utils/tests/conftest.py | 8 - src/exo/utils/tests/test_async_process.py | 515 --- src/exo/utils/tests/test_daemon.py | 168 - src/exo/utils/tests/test_keyed_backoff.py | 13 + src/exo/utils/tests/test_pidfile.py | 84 - src/exo/utils/tests/test_ports.py | 58 + src/exo/utils/tests/test_power_sampler.py | 30 - src/exo/worker/engines/image/builder.py | 4 +- .../worker/engines/mlx/asymmetric_parallel.py | 375 +++ src/exo/worker/engines/mlx/builder.py | 418 ++- src/exo/worker/engines/mlx/cache.py | 127 +- src/exo/worker/engines/mlx/constants.py | 11 +- .../engines/mlx/generator/coupled_drafter.py | 1117 +++++++ .../worker/engines/mlx/generator/drafter.py | 1433 ++++++++ .../engines/mlx/generator/drafter_socket.py | 269 ++ .../mlx/generator/drafter_transport.py | 437 +++ .../worker/engines/mlx/generator/generate.py | 1372 +++++++- .../mlx/generator/pipelined_drafter.py | 1277 +++++++ .../engines/mlx/generator/remote_drafter.py | 986 ++++++ .../mlx/generator/target_peer_socket.py | 189 ++ .../engines/mlx/tests/test_batched_prefill.py | 270 ++ src/exo/worker/engines/mlx/utils_mlx.py | 1775 +++++++++- .../engines/mlx/vendor/gemma4_mtp_hooks.py | 463 +++ .../mlx/vendor/qwen3_5_dflash_hooks.py | 815 +++++ src/exo/worker/main.py | 134 +- src/exo/worker/plan.py | 221 +- src/exo/worker/runner/bootstrap.py | 60 +- src/exo/worker/runner/drafter_runner.py | 350 ++ .../runner/llm_inference/batch_generator.py | 715 +++- .../llm_inference/model_output_parsers.py | 25 +- .../runner/llm_inference/tool_parsers.py | 89 +- src/exo/worker/runner/runner.py | 255 +- src/exo/worker/runner/supervisor.py | 251 +- .../unittests/test_drafter_task_routing.py | 233 ++ .../test_mlx/test_asymmetric_parallel.py | 120 + .../test_coupled_drafter_dflash_dispatch.py | 460 +++ .../test_mlx/test_coupled_drafter_dispatch.py | 801 +++++ .../test_mlx/test_coupled_drafter_loader.py | 397 +++ .../test_coupled_drafter_multi_device.py | 498 +++ .../test_coupled_drafter_round_loop.py | 344 ++ .../test_mlx/test_drafter_abstraction.py | 1001 ++++++ .../test_mlx/test_drafter_builder.py | 477 +++ .../unittests/test_mlx/test_drafter_loader.py | 195 ++ .../unittests/test_mlx/test_drafter_socket.py | 223 ++ .../unittests/test_mlx/test_drafter_tuning.py | 255 ++ .../unittests/test_mlx/test_eos_token_ids.py | 20 + .../test_mlx/test_gemma4_mtp_hooks.py | 331 ++ .../test_load_mlx_items_drafter_id.py | 351 ++ .../test_num_draft_tokens_consensus.py | 171 + .../test_mlx/test_pipelined_drafter.py | 1220 +++++++ .../test_mlx/test_qwen3_5_dflash_hooks.py | 397 +++ .../unittests/test_mlx/test_remote_drafter.py | 709 ++++ .../test_mlx/test_spec_diag_gating.py | 83 + .../unittests/test_mlx/test_tokenizers.py | 8 +- .../test_mlx/test_utils_mlx_bind_retry.py | 138 + .../test_mlx/test_utils_mlx_broadcast.py | 581 ++++ .../test_plan/test_runner_lifecycle.py | 143 + .../test_plan/test_task_forwarding.py | 59 + .../tests/unittests/test_plan/test_warmup.py | 54 + .../test_runner/test_adaptive_k_gate.py | 197 ++ .../test_batch_generator_errors.py | 88 + .../test_runner/test_event_ordering.py | 15 +- .../test_runner/test_parse_gpt_oss.py | 74 +- .../test_runner/test_responses_tool_compat.py | 303 ++ .../test_runner/test_runner_supervisor.py | 19 +- ...test_sequential_generator_batch_prefill.py | 356 ++ .../test_sequential_generator_errors.py | 428 +++ .../unittests/test_worker_instance_backoff.py | 36 + tests/auto_bench.sh | 55 + tests/conftest.py | 181 - tests/eval_tool_calls.sh | 55 + tests/framework.py | 199 -- tests/get_all_models_on_cluster.py | 36 + tests/headless_runner.py | 264 ++ tests/run_exo_on.sh | 53 + tests/start_distributed_test.py | 85 + tests/test_1node.py | 75 - tests/test_2node.py | 49 - tests/test_4node.py | 32 - tests/test_dashboard.py | 102 - tests/test_resilience.py | 56 - tests/test_vision_cache.py | 63 + tools/pyproject.toml | 10 - tools/src/exo_tools/client.py | 117 - tools/src/exo_tools/cluster.py | 243 -- uv.lock | 299 +- 188 files changed, 45653 insertions(+), 4028 deletions(-) create mode 100644 bench/harness.py create mode 100644 bench/results/dflash/REPORT.md create mode 100644 bench/results/dflash/qwen3.5-122b-a10b-mlx-8bit-tp2-jaccl-dflash.json create mode 100644 bench/results/dflash/qwen3.5-122b-a10b-mlx-8bit-tp2-jaccl-target-only.json create mode 100644 bench/results/dflash/qwen3.5-4b-mlx-8bit-dflash.json create mode 100644 bench/results/dflash/qwen3.5-4b-mlx-8bit-target-only.json create mode 100644 bench/results/dflash/qwen3.6-27b-mlx-8bit-dflash.json create mode 100644 bench/results/dflash/qwen3.6-27b-mlx-8bit-target-only.json create mode 100644 bench/results/dflash/qwen3.6-35b-a3b-mlx-8bit-dflash.json create mode 100644 bench/results/dflash/qwen3.6-35b-a3b-mlx-8bit-target-only.json create mode 100644 bench/results/mtp/REPORT.md rename {tools/src/exo_tools => bench/src/exo_bench}/__init__.py (100%) create mode 100644 resources/inference_model_cards/mlx-community--Qwen3.5-4B-MLX-8bit.toml delete mode 100644 rust/exo_pyo3_bindings/src/pidfile.rs create mode 100644 src/exo/api/tests/test_agent_endpoints.py create mode 100644 src/exo/api/tests/test_chat_completion_request_validation.py create mode 100644 src/exo/api/tests/test_chat_completions_adapter.py create mode 100644 src/exo/api/types/tests/test_generation_stats.py create mode 100644 src/exo/diagnostics.py create mode 100644 src/exo/download/peer_download.py create mode 100644 src/exo/download/peer_file_server.py create mode 100644 src/exo/download/peer_shard_downloader.py create mode 100644 src/exo/download/peer_state.py create mode 100644 src/exo/download/tests/test_drafter_download.py create mode 100644 src/exo/download/tests/test_peer_download.py create mode 100644 src/exo/download/tests/test_peer_state.py create mode 100644 src/exo/master/tests/test_placement_auto_prefill.py create mode 100644 src/exo/master/tests/test_placement_drafter_asymmetric.py create mode 100644 src/exo/master/tests/test_placement_drafter_warning.py create mode 100644 src/exo/routing/mdns_announcer.py create mode 100644 src/exo/routing/tests/test_node_id_migration.py delete mode 100644 src/exo/shared/tests/test_apply/test_apply_custom_model_cards.py create mode 100644 src/exo/shared/tests/test_diagnostic_snapshot_config.py create mode 100644 src/exo/shared/tests/test_drafter_placement_wire_compat.py delete mode 100644 src/exo/utils/async_process.py delete mode 100644 src/exo/utils/daemon.py delete mode 100644 src/exo/utils/pidfile.py delete mode 100644 src/exo/utils/tests/conftest.py delete mode 100644 src/exo/utils/tests/test_async_process.py delete mode 100644 src/exo/utils/tests/test_daemon.py create mode 100644 src/exo/utils/tests/test_keyed_backoff.py delete mode 100644 src/exo/utils/tests/test_pidfile.py create mode 100644 src/exo/utils/tests/test_ports.py create mode 100644 src/exo/worker/engines/mlx/asymmetric_parallel.py create mode 100644 src/exo/worker/engines/mlx/generator/coupled_drafter.py create mode 100644 src/exo/worker/engines/mlx/generator/drafter.py create mode 100644 src/exo/worker/engines/mlx/generator/drafter_socket.py create mode 100644 src/exo/worker/engines/mlx/generator/drafter_transport.py create mode 100644 src/exo/worker/engines/mlx/generator/pipelined_drafter.py create mode 100644 src/exo/worker/engines/mlx/generator/remote_drafter.py create mode 100644 src/exo/worker/engines/mlx/generator/target_peer_socket.py create mode 100644 src/exo/worker/engines/mlx/tests/test_batched_prefill.py create mode 100644 src/exo/worker/engines/mlx/vendor/gemma4_mtp_hooks.py create mode 100644 src/exo/worker/engines/mlx/vendor/qwen3_5_dflash_hooks.py create mode 100644 src/exo/worker/runner/drafter_runner.py create mode 100644 src/exo/worker/tests/unittests/test_drafter_task_routing.py create mode 100644 src/exo/worker/tests/unittests/test_mlx/test_asymmetric_parallel.py create mode 100644 src/exo/worker/tests/unittests/test_mlx/test_coupled_drafter_dflash_dispatch.py create mode 100644 src/exo/worker/tests/unittests/test_mlx/test_coupled_drafter_dispatch.py create mode 100644 src/exo/worker/tests/unittests/test_mlx/test_coupled_drafter_loader.py create mode 100644 src/exo/worker/tests/unittests/test_mlx/test_coupled_drafter_multi_device.py create mode 100644 src/exo/worker/tests/unittests/test_mlx/test_coupled_drafter_round_loop.py create mode 100644 src/exo/worker/tests/unittests/test_mlx/test_drafter_abstraction.py create mode 100644 src/exo/worker/tests/unittests/test_mlx/test_drafter_builder.py create mode 100644 src/exo/worker/tests/unittests/test_mlx/test_drafter_loader.py create mode 100644 src/exo/worker/tests/unittests/test_mlx/test_drafter_socket.py create mode 100644 src/exo/worker/tests/unittests/test_mlx/test_drafter_tuning.py create mode 100644 src/exo/worker/tests/unittests/test_mlx/test_eos_token_ids.py create mode 100644 src/exo/worker/tests/unittests/test_mlx/test_gemma4_mtp_hooks.py create mode 100644 src/exo/worker/tests/unittests/test_mlx/test_load_mlx_items_drafter_id.py create mode 100644 src/exo/worker/tests/unittests/test_mlx/test_num_draft_tokens_consensus.py create mode 100644 src/exo/worker/tests/unittests/test_mlx/test_pipelined_drafter.py create mode 100644 src/exo/worker/tests/unittests/test_mlx/test_qwen3_5_dflash_hooks.py create mode 100644 src/exo/worker/tests/unittests/test_mlx/test_remote_drafter.py create mode 100644 src/exo/worker/tests/unittests/test_mlx/test_spec_diag_gating.py create mode 100644 src/exo/worker/tests/unittests/test_mlx/test_utils_mlx_bind_retry.py create mode 100644 src/exo/worker/tests/unittests/test_mlx/test_utils_mlx_broadcast.py create mode 100644 src/exo/worker/tests/unittests/test_runner/test_adaptive_k_gate.py create mode 100644 src/exo/worker/tests/unittests/test_runner/test_batch_generator_errors.py create mode 100644 src/exo/worker/tests/unittests/test_runner/test_responses_tool_compat.py create mode 100644 src/exo/worker/tests/unittests/test_runner/test_sequential_generator_batch_prefill.py create mode 100644 src/exo/worker/tests/unittests/test_runner/test_sequential_generator_errors.py create mode 100644 src/exo/worker/tests/unittests/test_worker_instance_backoff.py create mode 100755 tests/auto_bench.sh delete mode 100644 tests/conftest.py create mode 100755 tests/eval_tool_calls.sh delete mode 100644 tests/framework.py create mode 100755 tests/get_all_models_on_cluster.py create mode 100644 tests/headless_runner.py create mode 100755 tests/run_exo_on.sh create mode 100755 tests/start_distributed_test.py delete mode 100644 tests/test_1node.py delete mode 100644 tests/test_2node.py delete mode 100644 tests/test_4node.py delete mode 100644 tests/test_dashboard.py delete mode 100644 tests/test_resilience.py create mode 100644 tests/test_vision_cache.py delete mode 100644 tools/pyproject.toml delete mode 100644 tools/src/exo_tools/client.py delete mode 100644 tools/src/exo_tools/cluster.py diff --git a/.gitignore b/.gitignore index a73d27afa2..b162de342c 100644 --- a/.gitignore +++ b/.gitignore @@ -40,4 +40,3 @@ bench/**/*.json tmp/models /build/exo /.claude/skills -/.claude diff --git a/.mlx_typings/mlx_lm/models/cache.pyi b/.mlx_typings/mlx_lm/models/cache.pyi index 8641815ee4..3a05c86bb5 100644 --- a/.mlx_typings/mlx_lm/models/cache.pyi +++ b/.mlx_typings/mlx_lm/models/cache.pyi @@ -148,18 +148,21 @@ class QuantizedKVCache(_BaseCache): ... class KVCache(_BaseCache): - step = ... + step: int + keys: mx.array | None + values: mx.array | None + _idx: int def __init__(self) -> None: ... - def update_and_fetch(self, keys, values): # -> tuple[array | Any, array | Any]: - ... + def update_and_fetch( + self, keys: mx.array, values: mx.array + ) -> tuple[mx.array, mx.array]: ... @property def state( self, ) -> tuple[mx.array | None, mx.array | None]: ... @state.setter - def state(self, v) -> None: ... - def is_trimmable(self): # -> Literal[True]: - ... + def state(self, v: tuple[mx.array | None, mx.array | None]) -> None: ... + def is_trimmable(self) -> bool: ... def trim(self, n: int) -> int: ... def to_quantized( self, group_size: int = ..., bits: int = ... @@ -169,20 +172,19 @@ class KVCache(_BaseCache): ) -> mx.array | Literal["causal"] | None: ... class RotatingKVCache(_BaseCache): - step = ... + step: int keys: mx.array | None values: mx.array | None keep: int max_size: int _idx: int - def __init__(self, max_size, keep=...) -> None: ... + def __init__(self, max_size: int, keep: int = ...) -> None: ... def _trim( self, trim_size: int, v: mx.array, append: mx.array | None = ... ) -> mx.array: ... def update_and_fetch( - self, keys, values - ): # -> tuple[array | Any, array | Any] | tuple[array | Any, array | Any | None]: - ... + self, keys: mx.array, values: mx.array + ) -> tuple[mx.array, mx.array]: ... @property def state( self, diff --git a/.mlx_typings/mlx_lm/models/gemma4_text.pyi b/.mlx_typings/mlx_lm/models/gemma4_text.pyi index 728d91c108..a7ae787d59 100644 --- a/.mlx_typings/mlx_lm/models/gemma4_text.pyi +++ b/.mlx_typings/mlx_lm/models/gemma4_text.pyi @@ -10,37 +10,37 @@ from .switch_layers import SwitchGLU @dataclass class ModelArgs(BaseModelArgs): - model_type: str - hidden_size: int - num_hidden_layers: int - intermediate_size: int - num_attention_heads: int - head_dim: int - global_head_dim: int - global_partial_rotary_factor: float - rms_norm_eps: float - vocab_size: int - vocab_size_per_layer_input: int - num_key_value_heads: int - num_global_key_value_heads: Optional[int] - num_kv_shared_layers: int - pad_token_id: int - hidden_size_per_layer_input: int - rope_traditional: bool - partial_rotary_factor: float - rope_parameters: Optional[Dict[str, Any]] - sliding_window: int - sliding_window_pattern: int - max_position_embeddings: int - attention_k_eq_v: bool - final_logit_softcapping: float - use_double_wide_mlp: bool - enable_moe_block: bool - num_experts: Optional[int] - top_k_experts: Optional[int] - moe_intermediate_size: Optional[int] - layer_types: Optional[List[str]] - tie_word_embeddings: bool + model_type: str = ... + hidden_size: int = ... + num_hidden_layers: int = ... + intermediate_size: int = ... + num_attention_heads: int = ... + head_dim: int = ... + global_head_dim: int = ... + global_partial_rotary_factor: float = ... + rms_norm_eps: float = ... + vocab_size: int = ... + vocab_size_per_layer_input: int = ... + num_key_value_heads: int = ... + num_global_key_value_heads: Optional[int] = ... + num_kv_shared_layers: int = ... + pad_token_id: int = ... + hidden_size_per_layer_input: int = ... + rope_traditional: bool = ... + partial_rotary_factor: float = ... + rope_parameters: Optional[Dict[str, Any]] = ... + sliding_window: int = ... + sliding_window_pattern: int = ... + max_position_embeddings: int = ... + attention_k_eq_v: bool = ... + final_logit_softcapping: float = ... + use_double_wide_mlp: bool = ... + enable_moe_block: bool = ... + num_experts: Optional[int] = ... + top_k_experts: Optional[int] = ... + moe_intermediate_size: Optional[int] = ... + layer_types: Optional[List[str]] = ... + tie_word_embeddings: bool = ... def __post_init__(self) -> None: ... diff --git a/Cargo.lock b/Cargo.lock index d0ab25d7d4..96819c8216 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -916,13 +916,11 @@ dependencies = [ "libp2p", "log", "networking", - "pidfile-rs", "pin-project", "pyo3", "pyo3-async-runtimes", "pyo3-log", "pyo3-stub-gen", - "thiserror 2.0.17", "tokio", "util", ] @@ -966,16 +964,6 @@ version = "0.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3a3076410a55c90011c298b04d0cfa770b00fa04e1e3c97d3f6c9de105a03844" -[[package]] -name = "flopen" -version = "0.1.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fbfb8b5fbd1f27929f216650081a07b6ceb0741f0542c8c43ff7ef8e93a35a5d" -dependencies = [ - "libc", - "nix 0.31.2", -] - [[package]] name = "fnv" version = "1.0.7" @@ -1801,9 +1789,9 @@ checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" [[package]] name = "libc" -version = "0.2.186" +version = "0.2.178" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "68ab91017fe16c622486840e4c83c9a37afeff978bd239b5293d61ece587de66" +checksum = "37c93d8daa9d8a012fd8ab92f088405fb202ea0b6ab73ee2482ae66af4f42091" [[package]] name = "libp2p" @@ -2819,18 +2807,6 @@ dependencies = [ "libc", ] -[[package]] -name = "nix" -version = "0.31.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5d6d0705320c1e6ba1d912b5e37cf18071b6c2e9b7fa8215a1e8a7651966f5d3" -dependencies = [ - "bitflags 2.10.0", - "cfg-if", - "cfg_aliases", - "libc", -] - [[package]] name = "nohash-hasher" version = "0.2.0" @@ -3084,18 +3060,6 @@ dependencies = [ "siphasher", ] -[[package]] -name = "pidfile-rs" -version = "0.3.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d1a8aa9a30b1b65ef48b333931b80f2324a14e00208eb2b8f5788f1180791bcc" -dependencies = [ - "flopen", - "libc", - "log", - "thiserror 1.0.69", -] - [[package]] name = "pin-project" version = "1.1.10" @@ -3704,7 +3668,7 @@ dependencies = [ "netlink-packet-utils", "netlink-proto", "netlink-sys", - "nix 0.26.4", + "nix", "thiserror 1.0.69", "tokio", ] diff --git a/bench/eval_tool_calls.py b/bench/eval_tool_calls.py index 7b219bc92a..c2839fcf96 100644 --- a/bench/eval_tool_calls.py +++ b/bench/eval_tool_calls.py @@ -15,8 +15,9 @@ from typing import Any, Literal import httpx -from exo_tools.client import ExoClient, ExoHttpError -from exo_tools.harness import ( +from harness import ( + ExoClient, + ExoHttpError, add_common_instance_args, capture_cluster_snapshot, instance_id_from_instance, diff --git a/bench/exo_bench.py b/bench/exo_bench.py index 3322402b5e..50d835a290 100644 --- a/bench/exo_bench.py +++ b/bench/exo_bench.py @@ -30,8 +30,9 @@ from statistics import mean from typing import Any -from exo_tools.client import ExoClient, ExoHttpError -from exo_tools.harness import ( +from harness import ( + ExoClient, + ExoHttpError, add_common_instance_args, capture_cluster_snapshot, find_existing_instance, diff --git a/bench/exo_eval.py b/bench/exo_eval.py index 04b14e2090..6e0c1b403a 100644 --- a/bench/exo_eval.py +++ b/bench/exo_eval.py @@ -42,8 +42,9 @@ from typing import Any import httpx -from exo_tools.client import ExoClient, ExoHttpError -from exo_tools.harness import ( +from harness import ( + ExoClient, + ExoHttpError, add_common_instance_args, capture_cluster_snapshot, find_existing_instance, diff --git a/bench/harness.py b/bench/harness.py new file mode 100644 index 0000000000..ba6d0a7745 --- /dev/null +++ b/bench/harness.py @@ -0,0 +1,625 @@ +# type: ignore +from __future__ import annotations + +import argparse +import http.client +import json +import os +import time +from collections.abc import Iterator +from typing import Any +from urllib.parse import urlencode + +from loguru import logger + +_SETTLE_INITIAL_BACKOFF_S = 1.0 +_SETTLE_MAX_BACKOFF_S = 60.0 +_SETTLE_BACKOFF_MULTIPLIER = 2.0 + + +class ExoHttpError(RuntimeError): + def __init__(self, status: int, reason: str, body_preview: str): + super().__init__(f"HTTP {status} {reason}: {body_preview}") + self.status = status + + +class ExoClient: + def __init__(self, host: str, port: int, timeout_s: float = 7200.0): + self.host = host + self.port = port + self.timeout_s = timeout_s + + def request_json( + self, + method: str, + path: str, + params: dict[str, Any] | None = None, + body: dict[str, Any] | None = None, + headers: dict[str, str] | None = None, + ) -> Any: + if not path.startswith("/"): + path = "/" + path + if params: + path = path + "?" + urlencode(params) + + conn = http.client.HTTPConnection(self.host, self.port, timeout=self.timeout_s) + try: + payload: bytes | None = None + hdrs: dict[str, str] = {"Accept": "application/json"} + + if body is not None: + payload = json.dumps(body).encode("utf-8") + hdrs["Content-Type"] = "application/json" + if headers: + hdrs.update(headers) + + conn.request(method.upper(), path, body=payload, headers=hdrs) + resp = conn.getresponse() + raw = resp.read() + text = raw.decode("utf-8", errors="replace") if raw else "" + + if resp.status >= 400: + raise ExoHttpError(resp.status, resp.reason, text[:300]) + + if not text: + return None + return json.loads(text) + finally: + conn.close() + + def post_bench_chat_completions(self, payload: dict[str, Any]) -> dict[str, Any]: + return self.request_json("POST", "/bench/chat/completions", body=payload) + + def stream_bench_chat_completions(self, payload: dict[str, Any]) -> Iterator[str]: + """POST /bench/chat/completions with stream=True, yielding raw SSE lines.""" + payload = {**payload, "stream": True} + data = json.dumps(payload).encode("utf-8") + conn = http.client.HTTPConnection(self.host, self.port, timeout=self.timeout_s) + try: + conn.request( + "POST", + "/bench/chat/completions", + body=data, + headers={ + "Content-Type": "application/json", + "Accept": "text/event-stream", + }, + ) + resp = conn.getresponse() + if resp.status >= 400: + raw = resp.read().decode("utf-8", errors="replace") + raise ExoHttpError(resp.status, resp.reason, raw[:300]) + for line in resp: + yield line.decode("utf-8", errors="replace") + finally: + conn.close() + + def get_state_path(self, path: str) -> Any: + try: + return self.request_json("GET", f"/state/{path}") + except ExoHttpError as e: + if e.status == 404: + return None + raise + + def get_instance(self, instance_id: str) -> dict[str, Any] | None: + return self.get_state_path(f"instances/{instance_id}") + + def get_runner(self, runner_id: str) -> dict[str, Any] | None: + return self.get_state_path(f"runners/{runner_id}") + + def get_node_downloads(self, node_id: str) -> list[dict[str, Any]] | None: + return self.get_state_path(f"downloads/{node_id}") + + def get_node_disk(self, node_id: str) -> dict[str, Any] | None: + return self.get_state_path(f"nodeDisk/{node_id}") + + def get_node_system(self, node_id: str) -> dict[str, Any] | None: + return self.get_state_path(f"nodeSystem/{node_id}") + + def get_node_identities(self) -> dict[str, Any] | None: + return self.get_state_path("nodeIdentities") + + def get_topology(self) -> dict[str, Any] | None: + return self.get_state_path("topology") + + +def unwrap_instance(instance: dict[str, Any]) -> dict[str, Any]: + if len(instance) != 1: + raise KeyError(f"Expected 1 key, got keys={list(instance.keys())}") + + tag = next(iter(instance)) + inner = instance[tag] + if not isinstance(inner, dict): + raise TypeError(f"payload for {tag} must be dict, got {type(inner)}") + return inner + + +def instance_id_from_instance(instance: dict[str, Any]) -> str: + inner = unwrap_instance(instance) + return str(inner["instanceId"]) + + +def nodes_used_in_instance(instance: dict[str, Any]) -> int: + inner = unwrap_instance(instance) + return len(inner["shardAssignments"]["nodeToRunner"]) + + +def runner_ids_from_instance(instance: dict[str, Any]) -> list[str]: + inner = unwrap_instance(instance) + runner_to_shard = inner["shardAssignments"]["runnerToShard"] + return list(runner_to_shard.keys()) + + +def node_ids_from_instance(instance: dict[str, Any]) -> list[str]: + inner = unwrap_instance(instance) + return list(inner["shardAssignments"]["nodeToRunner"].keys()) + + +def runner_ready(runner: dict[str, Any]) -> bool: + return "RunnerReady" in runner + + +def runner_failed(runner: dict[str, Any]) -> bool: + return "RunnerFailed" in runner + + +def get_runner_failed_message(runner: dict[str, Any]) -> str | None: + if "RunnerFailed" in runner: + return runner["RunnerFailed"].get("errorMessage") + return None + + +def wait_for_instance_ready( + client: ExoClient, instance_id: str, timeout: float = 24000.0 +) -> None: + start_time = time.time() + instance_existed = False + last_loaded: dict[str, int] = {} + while time.time() - start_time < timeout: + instance = client.get_instance(instance_id) + + if instance is None: + if instance_existed: + raise RuntimeError( + f"Instance {instance_id} was deleted (runner may have failed)" + ) + time.sleep(0.1) + continue + + instance_existed = True + rids = runner_ids_from_instance(instance) + + all_ready = True + for rid in rids: + runner = client.get_runner(rid) or {} + if runner_failed(runner): + error_msg = get_runner_failed_message(runner) or "Unknown error" + raise RuntimeError(f"Runner {rid} failed: {error_msg}") + if "RunnerLoading" in runner: + loading = runner["RunnerLoading"] + loaded = loading.get("layersLoaded", 0) + total = loading.get("totalLayers", 0) + if total > 0 and last_loaded.get(rid) != loaded: + last_loaded[rid] = loaded + logger.debug(f"Runner {rid}: loading layers {loaded}/{total}") + if not runner_ready(runner): + all_ready = False + + if all_ready: + return + + time.sleep(0.1) + + raise TimeoutError(f"Instance {instance_id} did not become ready within {timeout=}") + + +def wait_for_instance_gone( + client: ExoClient, instance_id: str, timeout: float = 3.0 +) -> None: + start_time = time.time() + while time.time() - start_time < timeout: + try: + client.request_json("GET", f"/instance/{instance_id}") + time.sleep(0.4) + except ExoHttpError as e: + if e.status == 404: + return + raise + + raise TimeoutError(f"Instance {instance_id} did not get deleted within {timeout=}") + + +def capture_cluster_snapshot(client: ExoClient) -> dict[str, Any]: + snapshot: dict[str, Any] = {} + identities = client.get_node_identities() + if identities: + snapshot["nodeIdentities"] = identities + topology = client.get_topology() + if topology: + snapshot["topology"] = topology + node_memory = client.get_state_path("nodeMemory") + if node_memory: + snapshot["nodeMemory"] = node_memory + node_system = client.get_state_path("nodeSystem") + if node_system: + snapshot["nodeSystem"] = node_system + return snapshot + + +def resolve_model_short_id( + client: ExoClient, model_arg: str, *, force_download: bool = False +) -> tuple[str, str]: + models = client.request_json("GET", "/models") or {} + data = models.get("data") or [] + + for m in data: + if (m.get("name") or "").lower() == model_arg.lower(): + short_id = str(m["name"]) + full_id = str(m.get("hugging_face_id") or m["name"]) + return short_id, full_id + + for m in data: + if m.get("hugging_face_id") == model_arg: + short_id = str(m["name"]) + full_id = str(m["hugging_face_id"]) + return short_id, full_id + + if force_download and "/" in model_arg: + logger.info(f"Model not in /models, adding from HuggingFace: {model_arg}") + result = client.request_json( + "POST", "/models/add", body={"model_id": model_arg} + ) + if result: + short_id = str(result.get("name") or model_arg.rsplit("/", 1)[-1]) + full_id = str(result.get("hugging_face_id") or model_arg) + return short_id, full_id + + raise ValueError(f"Model not found in /models: {model_arg}") + + +def placement_filter(instance_meta: str, wanted: str) -> bool: + s = (instance_meta or "").lower() + if wanted == "both": + return ("ring" in s) or ("jaccl" in s) + return wanted in s + + +def sharding_filter(sharding: str, wanted: str) -> bool: + s = (sharding or "").lower() + if wanted == "both": + return ("pipeline" in s) or ("tensor" in s) + return wanted in s + + +def fetch_and_filter_placements( + client: ExoClient, + full_model_id: str, + args: argparse.Namespace, + node_id: str | None = None, +) -> list[dict[str, Any]]: + params: dict[str, str] = {"model_id": full_model_id} + if node_id is not None: + params["node_ids"] = node_id + previews_resp = client.request_json("GET", "/instance/previews", params=params) + previews = previews_resp.get("previews") or [] + + selected: list[dict[str, Any]] = [] + for p in previews: + if p.get("error") is not None: + continue + if not placement_filter(str(p.get("instance_meta", "")), args.instance_meta): + continue + if not sharding_filter(str(p.get("sharding", "")), args.sharding): + continue + + instance = p.get("instance") + if not isinstance(instance, dict): + continue + + n = nodes_used_in_instance(instance) + # Skip tensor ring single node as it is pointless when pipeline ring + if n == 1 and ( + (args.sharding == "both" and "tensor" in p.get("sharding", "").lower()) + or ( + args.instance_meta == "both" + and "jaccl" in p.get("instance_meta", "").lower() + ) + ): + continue + + if ( + args.skip_pipeline_jaccl + and ( + args.instance_meta == "both" + and "jaccl" in p.get("instance_meta", "").lower() + ) + and ( + args.sharding == "both" and "pipeline" in p.get("sharding", "").lower() + ) + ): + continue + + if ( + args.skip_tensor_ring + and ( + args.instance_meta == "both" + and "ring" in p.get("instance_meta", "").lower() + ) + and (args.sharding == "both" and "tensor" in p.get("sharding", "").lower()) + ): + continue + + if args.min_nodes <= n <= args.max_nodes: + selected.append(p) + + return selected + + +def settle_and_fetch_placements( + client: ExoClient, + full_model_id: str, + args: argparse.Namespace, + settle_timeout: float = 0, + node_id: str | None = None, +) -> list[dict[str, Any]]: + selected = fetch_and_filter_placements(client, full_model_id, args, node_id=node_id) + + if not selected and settle_timeout > 0: + backoff = _SETTLE_INITIAL_BACKOFF_S + deadline = time.monotonic() + settle_timeout + while not selected and time.monotonic() < deadline: + remaining = deadline - time.monotonic() + logger.warning( + f"No valid placements yet (cluster may still be settling). " + f"Retrying in {backoff:.1f}s ({remaining:.0f}s remaining)..." + ) + time.sleep(min(backoff, remaining)) + backoff = min(backoff * _SETTLE_BACKOFF_MULTIPLIER, _SETTLE_MAX_BACKOFF_S) + selected = fetch_and_filter_placements( + client, full_model_id, args, node_id=node_id + ) + + return selected + + +def run_planning_phase( + client: ExoClient, + full_model_id: str, + preview: dict[str, Any], + danger_delete: bool, + timeout: float, + settle_deadline: float | None, +) -> float | None: + """Check disk space and ensure model is downloaded before benchmarking. + + Returns the wall-clock download duration in seconds if a fresh download + was needed, or None if the model was already cached on all nodes. + """ + # Get model size from /models + models = client.request_json("GET", "/models") or {} + model_bytes = 0 + for m in models.get("data", []): + if m.get("hugging_face_id") == full_model_id: + model_bytes = m.get("storage_size_megabytes", 0) * 1024 * 1024 + break + + if not model_bytes: + logger.warning( + f"Could not determine size for {full_model_id}, skipping disk check" + ) + return None + + # Get nodes from preview + inner = unwrap_instance(preview["instance"]) + node_ids = list(inner["shardAssignments"]["nodeToRunner"].keys()) + runner_to_shard = inner["shardAssignments"]["runnerToShard"] + + needs_download = False + + for node_id in node_ids: + node_downloads = client.get_node_downloads(node_id) or [] + + already_downloaded = any( + "DownloadCompleted" in p + and unwrap_instance(p["DownloadCompleted"]["shardMetadata"])["modelCard"][ + "modelId" + ] + == full_model_id + for p in node_downloads + ) + if already_downloaded: + continue + + needs_download = True + + disk_info = client.get_node_disk(node_id) or {} + backoff = _SETTLE_INITIAL_BACKOFF_S + while not disk_info and settle_deadline and time.monotonic() < settle_deadline: + remaining = settle_deadline - time.monotonic() + logger.info( + f"Waiting for disk info on {node_id} ({remaining:.0f}s remaining)..." + ) + time.sleep(min(backoff, remaining)) + backoff = min(backoff * _SETTLE_BACKOFF_MULTIPLIER, _SETTLE_MAX_BACKOFF_S) + disk_info = client.get_node_disk(node_id) or {} + + if not disk_info: + logger.warning(f"No disk info for {node_id}, skipping space check") + continue + + avail = disk_info.get("available", {}).get("inBytes", 0) + if avail >= model_bytes: + continue + + if not danger_delete: + raise RuntimeError( + f"Insufficient disk on {node_id}: need {model_bytes // (1024**3)}GB, " + f"have {avail // (1024**3)}GB. Use --danger-delete-downloads to free space." + ) + + completed = [ + ( + unwrap_instance(p["DownloadCompleted"]["shardMetadata"])["modelCard"][ + "modelId" + ], + p["DownloadCompleted"]["total"]["inBytes"], + ) + for p in node_downloads + if "DownloadCompleted" in p + and not p["DownloadCompleted"].get("readOnly", False) + ] + for del_model, size in sorted(completed, key=lambda x: x[1]): + logger.info(f"Deleting {del_model} from {node_id} ({size // (1024**2)}MB)") + client.request_json("DELETE", f"/download/{node_id}/{del_model}") + avail += size + if avail >= model_bytes: + break + + if avail < model_bytes: + raise RuntimeError(f"Could not free enough space on {node_id}") + + # Start downloads (idempotent) + download_t0 = time.perf_counter() if needs_download else None + for node_id in node_ids: + runner_id = inner["shardAssignments"]["nodeToRunner"][node_id] + shard = runner_to_shard[runner_id] + client.request_json( + "POST", + "/download/start", + body={ + "targetNodeId": node_id, + "shardMetadata": shard, + }, + ) + logger.info(f"Started download on {node_id}") + + # Wait for downloads (no timeout — poll until complete or failed) + while True: + all_done = True + for node_id in node_ids: + node_downloads = client.get_node_downloads(node_id) or [] + done = any( + "DownloadCompleted" in p + and unwrap_instance(p["DownloadCompleted"]["shardMetadata"])[ + "modelCard" + ]["modelId"] + == full_model_id + for p in node_downloads + ) + failed = [ + p["DownloadFailed"]["errorMessage"] + for p in node_downloads + if "DownloadFailed" in p + and unwrap_instance(p["DownloadFailed"]["shardMetadata"])["modelCard"][ + "modelId" + ] + == full_model_id + ] + if failed: + raise RuntimeError(f"Download failed on {node_id}: {failed[0]}") + if not done: + all_done = False + ongoing = [ + p + for p in node_downloads + if "DownloadOngoing" in p + and unwrap_instance(p["DownloadOngoing"]["shardMetadata"])[ + "modelCard" + ]["modelId"] + == full_model_id + ] + if ongoing: + prog = ongoing[0]["DownloadOngoing"]["downloadProgress"] + speed_mb = prog.get("speed", 0) / (1024 * 1024) + eta_s = prog.get("etaMs", 0) / 1000 + dl_bytes = prog.get("downloaded", {}).get("inBytes", 0) + total_bytes = prog.get("total", {}).get("inBytes", 0) + pct = (dl_bytes / total_bytes * 100) if total_bytes else 0 + logger.info( + f"Downloading on {node_id}: {pct:.1f}% @ {speed_mb:.1f} MB/s, " + f"ETA {eta_s:.0f}s " + f"({prog.get('completedFiles', 0)}/{prog.get('totalFiles', 0)} files)" + ) + if all_done: + if download_t0 is not None: + return time.perf_counter() - download_t0 + return None + time.sleep(10) + + +def find_existing_instance(client: ExoClient, model_id: str) -> str | None: + """Find an existing running instance for the given model.""" + try: + state = client.request_json("GET", "/state") + except Exception: + return None + for inst_id, inst in state.get("instances", {}).items(): + # Instance structure is nested: {"MlxJacclInstance": {"shardAssignments": {"modelId": ...}}} + for _inst_type, inner in inst.items(): + if not isinstance(inner, dict): + continue + sa = inner.get("shardAssignments", {}) + if sa.get("modelId") == model_id: + return inst_id + return None + + +def add_common_instance_args(ap: argparse.ArgumentParser) -> None: + ap.add_argument("--host", default=os.environ.get("EXO_HOST", "localhost")) + ap.add_argument( + "--port", type=int, default=int(os.environ.get("EXO_PORT", "52415")) + ) + ap.add_argument("--model", required=True, help="Model short id or huggingface id") + ap.add_argument( + "--force-download", + action="store_true", + help="If model not in /models, add it from HuggingFace via exo and download.", + ) + ap.add_argument( + "--max-nodes", + type=int, + default=4, + help="Only consider placements using <= this many nodes.", + ) + ap.add_argument( + "--min-nodes", + type=int, + default=1, + help="Only consider placements using >= this many nodes.", + ) + ap.add_argument( + "--instance-meta", choices=["ring", "jaccl", "both"], default="both" + ) + ap.add_argument( + "--sharding", choices=["pipeline", "tensor", "both"], default="both" + ) + ap.add_argument( + "--skip-pipeline-jaccl", + action="store_true", + help="Skip pipeline+jaccl placements, as it's often pointless.", + ) + ap.add_argument( + "--skip-tensor-ring", + action="store_true", + help="Skip tensor+ring placements, as it's so slow.", + ) + ap.add_argument( + "--timeout", type=float, default=7200.0, help="HTTP timeout (seconds)." + ) + ap.add_argument( + "--settle-timeout", + type=float, + default=60.0, + help="Max seconds to wait for the cluster to produce valid placements (0 = try once).", + ) + ap.add_argument( + "--danger-delete-downloads", + action="store_true", + help="Delete existing models from smallest to largest to make room for benchmark model.", + ) + ap.add_argument( + "--reuse-instance", + action="store_true", + help="Reuse an existing running instance for this model instead of creating a new one.", + ) diff --git a/bench/prefill_decode_bench.py b/bench/prefill_decode_bench.py index 588ddb0631..375b66e847 100644 --- a/bench/prefill_decode_bench.py +++ b/bench/prefill_decode_bench.py @@ -35,8 +35,9 @@ load_tokenizer_for_bench, parse_int_list, ) -from exo_tools.client import ExoClient, ExoHttpError -from exo_tools.harness import ( +from harness import ( + ExoClient, + ExoHttpError, add_common_instance_args, instance_id_from_instance, node_ids_from_instance, diff --git a/bench/results/dflash/REPORT.md b/bench/results/dflash/REPORT.md new file mode 100644 index 0000000000..f66755b1d6 --- /dev/null +++ b/bench/results/dflash/REPORT.md @@ -0,0 +1,490 @@ +# DFlash coupled-drafter benchmarks (Qwen 3.5 + Qwen 3.6) + +A/B benchmarks of z-lab's DFlash block-diffusion coupled drafters +against the corresponding MLX-quantized targets on Apple Silicon. +Numerical validations of the DFlash dispatch path +(`CoupledDrafterKind="dflash"`) on real hybrid Qwen targets +(gated-delta-net + full-attention, `full_attention_interval=4`) at +single-device, multi-device tensor-parallel, and the headline +122B-A10B MoE scaled across two nodes via JACCL over a Thunderbolt- +bridge RDMA edge. + +## Headlines across four targets + +| Target | Quant | Arch | host | Target gen_tps | DFlash gen_tps | Speedup | Accept | +|---|---|---|---|---:|---:|---:|---:| +| Qwen3.5 4B | 8bit | dense | wc-smbp | 97.24 | 404.38 | **4.16x** | 93.2% | +| Qwen3.6 27B | 8bit | dense | wc-smbpt | 14.98 | 49.13 | **3.28x** | 92.6% | +| Qwen3.6 35B-A3B | 8bit | MoE | wc-smbpt | 87.70 | 377.49 | **4.30x** | 92.6% | +| Qwen3.5 122B-A10B (TP2) | 8bit | MoE | smbp+smbpt | 52.61 | 159.00 | **3.02x** | 93.75% | + +All medians are over 10 runs per A/B side (5 scenarios × 2 runs). +The +316% Qwen 3.5 4B result was **not** a sweet spot — DFlash holds +above 3.02x at every scale tested, including the 122B-A10B MoE +running across two nodes with tensor parallelism and RDMA. + +The MoE 35B-A3B is particularly striking: it's the second-fastest +target-only generation of the three (because only ~3B params are +active per token), yet DFlash still delivers a 4.30x speedup on top +of that fast baseline. The combination yields **377 t/s steady-state +generation on a 35B-class model on a single MacBook Pro M5 Max**. + +The 122B-A10B result is the first end-to-end DFlash measurement on a +multi-node tensor-parallel placement. The coupled-drafter dispatch +now works through the `Sharding.Tensor` + `InstanceMeta.MlxJaccl` +loader path: each TP rank replicates the (small) DFlash drafter +weights and consumes the post-all-reduce hidden state in-process, +producing identical draft tokens + bonus samples in lockstep across +ranks under the shared `mx.random.seed(seed)` set at the top of each +generation step. 122B-class steady-state generation thus jumps from +**~53 t/s → ~159 t/s** without sacrificing accuracy. + +## Qwen 3.6 27B (dense) — 3.28x + +Target: `mlx-community/Qwen3.6-27B-8bit` (28 GB on disk, 64 layers, +hidden_size 5120, 48 linear-attn + 16 full-attn, +`full_attention_interval=4`, `head_dim=256`). + +Drafter: `z-lab/Qwen3.6-27B-DFlash` (3.2 GB, 6-layer +block-diffusion drafter, `block_size=16`, 60 target layers indexed). + +Per-scenario gen_tps is the mean of the 2 runs per scenario; +DFlash columns exclude one 0-token factual_qa run and one 0-token +short_repetitive run on the DFlash side from the *mean* but they're +still counted in the all-scenario median (see "Bench harness +flakiness" below). The all-scenario median row mirrors what the +harness reported live (`runs=8` for DFlash after auto-filtering +zero-token rows, `runs=10` for target-only). + +| Scenario | Target gen_tps | DFlash gen_tps | Speedup | Accept | +|------------------------|---------------:|---------------:|--------:|-------:| +| short_repetitive | 17.90 | 51.43 | 2.87x | 93.0% | +| code_completion | 16.72 | 33.45 | 2.00x | 86.5% | +| creative_prose | 14.98 | 55.16 | 3.68x | 92.2% | +| factual_qa | 12.72 | 24.60 | 1.93x | 82.0% | +| long_context_summary | 10.73 | 56.21 | 5.24x | 92.8% | +| **all-scenario median**| **14.98** | **49.13** | **3.28x** | **92.8%** | + +`long_context_summary` is the standout: DFlash recovers ~5.2x on +long-context generation, because the target spends a lot of wall +time per token at this scale and the speculation has more head room +to mask the per-token cost. + +`factual_qa` and `code_completion` were noisier this run with a few +80-87% acceptance pockets that dropped scenario throughput. With +larger N (more runs per scenario) the per-scenario speedup would +likely tighten back into the 3-4x band the other scenarios sit in. + +## Qwen 3.6 35B-A3B (MoE) — 4.30x + +Target: `mlx-community/Qwen3.6-35B-A3B-8bit` (35 GB on disk, 40 layers, +256 experts × 8 active per token, hidden_size 2048, +`moe_intermediate_size=512`, `head_dim=256`). + +Drafter: `z-lab/Qwen3.6-35B-A3B-DFlash` (905 MB, 8-layer dense +block-diffusion drafter, `block_size=16`, +`target_layer_ids=[1, 10, 19, 28, 37]`). + +| Scenario | Target gen_tps | DFlash gen_tps | Speedup | Accept | +|------------------------|---------------:|---------------:|--------:|-------:| +| short_repetitive | 89.91 | 256.96 | 2.86x | 90.4% | +| code_completion | 88.19 | 413.88 | 4.69x | 93.0% | +| creative_prose | 87.52 | 213.86 | 2.44x | 46.5%* | +| factual_qa | 86.82 | 287.39 | 3.31x | 89.8% | +| long_context_summary | 85.68 | 411.02 | 4.80x | 93.8% | +| **all-scenario median**| **87.70** | **377.49** | **4.30x** | **92.4%** | + +*creative_prose run 1 collapsed to 0% acceptance (23.57 t/s) on a +single run while run 0 stayed at 93.0% acceptance (404.15 t/s). The +mean is dragged down. Re-running with more samples per scenario +would tighten this. The median over the **9 healthy runs out of 10** +remains 388.67 t/s — i.e. the median is ~4.4x. + +short_repetitive's first DFlash run came in at 125 t/s (Metal kernel +cold compile, same pattern as the 4B bench); run 2 jumped to 388 t/s. +The cold run pulls the mean down. Excluding it, the steady-state +speedup is closer to **4.5x**. + +**Architectural note:** the MoE wires through our existing +`Qwen3_5DFlashTargetAdapter` with zero MoE-specific vendor work. +`mlx_lm.models.qwen3_5_moe` is a thin sanitize-wrapper around +`qwen3_5.Model`; MoE-vs-dense routing happens inside +`qwen3_5.DecoderLayer` via `SparseMoeBlock` vs `MLP` on `layer.mlp`, +and the vendored `_decoder_layer_forward_with_capture` already calls +`layer.mlp` polymorphically. The 4.30x speedup is the same code path, +unchanged. + +## Qwen 3.5 4B (dense) — 4.16x (previously reported) + +For completeness; full per-scenario breakdown elided here, see the +raw JSON next to this report. + +| Scenario | Target gen_tps | DFlash gen_tps | Speedup | Accept | +|------------------------|---------------:|---------------:|--------:|-------:| +| short_repetitive | 97.24 | 310.57 | 3.19x | 93.2% | +| code_completion | 97.19 | 371.43 | 3.82x | 92.0% | +| creative_prose | 97.52 | 407.37 | 4.18x | 93.2% | +| factual_qa | 95.80 | 449.87 | 4.70x | 93.4% | +| long_context_summary | 94.28 | 396.04 | 4.20x | 93.2% | +| **all-scenario median**| **97.24** | **404.38** | **4.16x** | **93.2%** | + +## Qwen 3.5 122B-A10B (MoE) — multi-node tensor parallel, DFlash A/B — 3.02x + +Target: `mlx-community/Qwen3.5-122B-A10B-8bit` (130 GB on disk, +48 layers, hidden_size 3072, 128 experts × 8 active per token, +~10B active params / 122B total, `num_key_value_heads=2`, +`full_attention_interval=4`). + +Drafter: `z-lab/Qwen3.5-122B-A10B-DFlash` (~0.5 GB, replicated on +each TP rank). Numerical validation of the multi-device coupled- +drafter dispatch path landed in commit `worker: lift single-device +gate on coupled-drafter loader + dispatch` — the loader now resolves +`coupled_drafter` for `Sharding.Tensor` placements and the generator +routes `draft_mode="model"` through the coupled adapter on every +rank. + +Placement: `Sharding.Tensor` + `InstanceMeta.MlxJaccl`, 2 nodes +(`wc-smbp` + `wc-smbpt`, both Apple M5 Max MacBook Pros, 128 GB +unified memory each). The two machines auto-discovered each other +via mDNS on the shared `192.168.1.0/24` LAN and established a direct +RDMA edge over their thunderbolt-bridge interfaces +(`rdma_en1 ⇌ rdma_en2`, ~4 ms ping). exo's JACCL backend used the +RDMA edge for tensor-parallel all-reduces during decode. + +| Scenario | Target gen_tps | DFlash gen_tps | Speedup | Accept | +|------------------------|---------------:|---------------:|--------:|-------:| +| short_repetitive | 53.98 | 138.84 | 2.57x | 90.8% | +| code_completion | 52.67 | 148.50 | 2.82x | 93.8% | +| creative_prose | 52.32 | 162.92 | 3.11x | 93.8% | +| factual_qa | 52.29 | 163.53 | 3.13x | 93.8% | +| long_context_summary | 52.22 | 158.18 | 3.03x | 93.8% | +| **all-scenario median**| **52.61** | **159.00** | **3.02x** | **93.75%** | + +The DFlash band is tight (138-168 t/s across 10 runs), and the +target-only band is even tighter (49.52-54.42 t/s). The MoE sparsity +(~10B active params per token) plus JACCL's RDMA all-reduce keep +per-token wall time consistent regardless of prompt shape. TTFT was +~750-870 ms for short prompts and 2.6 s for the 2 K-token +`long_context_summary` prompt — prefill all-reduce overhead scales +with prompt length but disappears once decode starts. + +For context against the single-node DFlash benches above: + +| Comparison row | Target gen_tps | DFlash gen_tps | Speedup | Notes | +|-----------------------------------|---------------:|---------------:|--------:|-------| +| 122B-A10B TP2 (this) | **52.61** | **159.00** | **3.02x** | 2 nodes via JACCL/RDMA | +| 35B-A3B single-node | 87.70 | 377.49 | 4.30x | 1 node, smaller MoE | +| 27B single-node | 14.98 | 49.13 | 3.28x | 1 node, dense | +| 4B single-node | 97.24 | 404.38 | 4.16x | 1 node, dense | + +**159 t/s steady-state on a 122B-class MoE running across two +consumer MacBook Pros over RDMA** is the headline. The DFlash speedup +ratio (3.02x) is slightly below the single-node range (3.28-4.30x) +because the per-round TP all-reduce now sits on a 4 ms RDMA hop +rather than within-chip GPU shared memory, which raises the +verifier's serial overhead per spec round. Acceptance stays at 93.75% +across the same five scenarios as single-node DFlash, confirming the +multi-rank coupled-drafter dispatch is numerically equivalent +(byte-identical draft tokens across ranks under the shared +`mx.random.seed(seed)`). + +### How the multi-device coupled-drafter path stays correct + +Three guarantees keep the per-rank coupled drafters in lockstep: + +1. **Identical hidden states.** TP shards within-layer matmuls but + reduces the output before the residual stream. Every rank ends up + with the same hidden state after each layer's `ShardedToAllLinear` + / `ShardedMoE` all-sum (and the captured `GdnState` shards rewind + identically per rank because each rank captured its own head + slice). + +2. **Identical drafter state.** The DFlash drafter (~0.5 GB) is + replicated on every TP rank — same weights, same per-step inputs, + same deterministic forward pass. + +3. **Identical sampling.** `mx.random.seed(task.seed or 42)` is set + once at the top of `_mlx_generate` on every rank, so the drafter + token-by-token sampling and the verifier's bonus sampling + advance the PRNG state in lockstep across ranks. Same RNG draws, + same accept/reject decisions, same KV trim / SSM rewind sequence + on every round. + +The result: target rank 0 and target rank 1 produce a byte-identical +output token stream under TP2 DFlash, exactly matching what a single- +node DFlash placement would produce if the 122B-A10B fit in 128 GB +(it doesn't — that's the whole reason for the TP2 placement). + +## Reading the numbers + +DFlash's speedup ratio holds remarkably steady across a **17.5x** target +size range (4B → 35B) and across architectures (dense → MoE): + +- 4B dense: 4.16x +- 27B dense: 3.28x +- 35B-A3B MoE: 4.30x + +The 27B dense is the lowest in the band, and the explanation is +simple: it's the **most memory-bound** of the three (largest weights +in active path per token), so target-only is already drag-limited; +DFlash speeds up the wall-clock but the absolute headroom is smaller +in tokens/sec terms. + +Acceptance lands at ~92-93% across all three targets, which is the +real story: DFlash's block-diffusion drafting strategy is robust +enough that the verifier accepts ~14-15 of every 16 drafted tokens +regardless of target scale or sparsity pattern. **Speedup ≈ accept × +block_size / serial-overhead**, and the accept rate is the dominant +term that DFlash optimizes against. + +### Compared to MTP on Gemma 4 (bench/results/mtp/REPORT.md) + +| Target | Drafter | Median speedup | Best-scenario speedup | +|-------------------|---------|---------------:|----------------------:| +| Gemma 4 26B-A4B | MTP | -1.6% | +22.1% (code) | +| Gemma 4 31B | MTP | +5.4% | +13.2% (code) | +| Qwen 3.5 4B | DFlash | +316% | +370% (factual_qa) | +| Qwen 3.6 27B | DFlash | +228% | +424% (long_context) | +| Qwen 3.6 35B-A3B | DFlash | +330% | +380% (long_context) | + +MTP appends a single drafter MLP head and proposes the next K tokens +autoregressively, so acceptance falls off quickly with prompt entropy +and worst-case scenarios actually regress (the 26B-A4B summary). DFlash +drafts the **entire block of 16 tokens in parallel** via block +diffusion, which is why acceptance stays consistently high across all +scenarios — every DFlash bench above stayed within a narrow 88-94% +acceptance band, while MTP on Gemma 4's `long_context_summary` fell +into single-digit acceptance. + +## Bench harness flakiness + +Across the 30 DFlash runs in this report (10 each for 4B / 27B / +35B-A3B), 3 runs returned `generation_tokens=0` and 1 run returned +0% acceptance: + +- 4B: 0 hiccups +- 27B: 2 hiccups (short_repetitive run 1, factual_qa run 1) — both + `error: null` in the harness but the server returned no body +- 35B-A3B: 1 hiccup (creative_prose run 1 collapsed to 0% accept) + +These are bench-harness / chat-completion-streaming hiccups, not +DFlash failures — the chat-completion request returned an empty +response or a partial one without an error code. The runs adjacent to +each hiccup on the *same scenario* completed normally at the expected +speedup. The all-scenario median treats the hiccup runs as data +points (i.e. doesn't filter them), so the reported median is a +*lower-bound* estimate of true steady-state speedup. + +For a publication-grade headline number, future benches should use +`--runs 5` (or `--runs 10`) instead of `--runs 2` to smooth out these +outliers. The current `--runs 2` was chosen for fast feedback during +implementation. + +## Setup + +- Hosts: + - 4B bench: **wc-smbp** (Apple M5 Max MacBook Pro, 128 GB unified memory) + - 27B + 35B-A3B benches: **wc-smbpt** (Apple M5 Max MacBook Pro, 128 GB + unified memory, ~83 GB free vs ~13 GB on wc-smbp during the 4B run) + - 122B-A10B TP2 bench: **wc-smbp + wc-smbpt** (both M5 Max, ~100 GB + free per node after `sudo purge`, JACCL RDMA over thunderbolt-bridge + `rdma_en1 ⇌ rdma_en2`, mDNS auto-discovery on shared 192.168.1.0/24 + LAN, ~4 ms RTT) +- Stack: MLX 0.32.0.dev, mlx_vlm 0.5.0, mlx_lm 0.31.3 +- exo branch: `team-wcv/bench/gemma4-mtp-coupled-results`, + including the dtype + first-bonus shape fixes documented inline below +- Harness: `bench/drafter_bench.py`, `--runs 2 --max-tokens 256`, + 5 scenarios (short_repetitive, code_completion, creative_prose, + factual_qa, long_context_summary) +- Modes: `EXO_DRAFT_MODE=none` (target-only) vs `EXO_DRAFT_MODE=model` + (DFlash coupled; auto-detected via `mlx_vlm.speculative.drafters. + load_drafter(..., kind=None)` → `kind="dflash"`) +- Model cards (declaring `coupled_drafter=...`): + - `mlx-community--Qwen3.5-4B-MLX-8bit.toml` + - `mlx-community--Qwen3.6-27B-8bit.toml` + - `mlx-community--Qwen3.6-35B-A3B-8bit.toml` + +## How to reproduce + +### Single-node DFlash A/B (4B / 27B / 35B-A3B) + +```bash +# 1. Download target + drafter pairs (first run only). Token required +# for z-lab/Qwen3.6-27B-DFlash (gated; click "agree" on HF first). +uv run python -c ' +from huggingface_hub import snapshot_download +for repo in [ + "mlx-community/Qwen3.5-4B-MLX-8bit", + "z-lab/Qwen3.5-4B-DFlash", + "mlx-community/Qwen3.6-27B-8bit", + "z-lab/Qwen3.6-27B-DFlash", + "mlx-community/Qwen3.6-35B-A3B-8bit", + "z-lab/Qwen3.6-35B-A3B-DFlash", +]: + snapshot_download(repo)' + +# 2. Symlink into ~/.exo/models/ — see /tmp/qwen36_dflash_bench.sh +# on either host for the exact ln -sfn invocations. + +# 3. Run the A/B harness per target +/tmp/qwen36_dflash_bench.sh "mlx-community/Qwen3.6-27B-8bit" "qwen3.6-27b-mlx-8bit" +/tmp/qwen36_dflash_bench.sh "mlx-community/Qwen3.6-35B-A3B-8bit" "qwen3.6-35b-a3b-mlx-8bit" +``` + +The bench script alternates `EXO_DRAFT_MODE=none` and +`EXO_DRAFT_MODE=model`, restarting exo between scenarios, and writes +per-request JSON to `bench/results/dflash/