Closed
72 commits
c53193a
Migrate ModelCard.drafter_model_id -> drafter_model_ids list
jw-wcv May 7, 2026
4fe0f75
Drafter tuning: K, warmup, short-skip (items 1+3+8)
jw-wcv May 7, 2026
3cc39ab
Drafter telemetry on GenerationStats + dashboard (item 4)
jw-wcv May 7, 2026
ea131f7
Adaptive draft depth based on rolling acceptance (item 7)
jw-wcv May 7, 2026
ec2990d
Drafter KV prefix caching (item 6)
jw-wcv May 7, 2026
fb858b0
Per-request drafter overrides (item 9)
jw-wcv May 7, 2026
b79cf88
Warn when drafter-aware model gets multi-node placement (item 10)
jw-wcv May 7, 2026
7d89477
Fix drafter cache build during warmup (no prefix cache wired)
jw-wcv May 7, 2026
702dcad
Match drafter cache offset to target via stream_generate prefill
jw-wcv May 7, 2026
e96fa64
Force plain KVCache for drafter; mlx_lm spec_step crawls on RotatingK…
jw-wcv May 7, 2026
165d690
Force plain KVCache for target too when drafter active (RotatingKVCac…
jw-wcv May 7, 2026
af04a03
Spec-decoding path bypasses exo prefill+prefix cache; mlx_lm prefills…
jw-wcv May 7, 2026
e2cd6f6
Add diagnostic logs around spec decode
jw-wcv May 7, 2026
f3a009c
Use mlx_lm native caches on spec-decoding path
jw-wcv May 7, 2026
4454d0f
Wire drafter KV prefix cache + manual drafter prefill on spec path
jw-wcv May 7, 2026
2f52b82
Add Drafter abstraction + n-gram drafting strategy
jw-wcv May 7, 2026
b62c96e
Add DrafterTransport interface + PipelinedModelDrafter (Layer A)
jw-wcv May 7, 2026
5f61ad8
Implement RemoteTransport + drafter_serve_loop over mx.distributed
jw-wcv May 7, 2026
24a945c
Add asymmetric drafter placement (model card opt-in + N+1 rank topology)
jw-wcv May 7, 2026
db3c86f
Split mx.distributed group into target subgroup + drafter rank
jw-wcv May 7, 2026
93a4403
Add DrafterRunner + plan helpers for asymmetric drafter rank
jw-wcv May 7, 2026
8add455
Wire pipelined+remote drafter at target rank (asymmetric N+1 path)
jw-wcv May 7, 2026
bf5a864
Add drafter benchmark harness for A/B comparing draft modes
jw-wcv May 7, 2026
effd749
Fix drafter_bench SSE streaming to capture generation_stats
jw-wcv May 7, 2026
de220fa
Skip bound_shard for drafter rank in RunnerSupervisor
jw-wcv May 7, 2026
d54eaf8
persist node ids in .cache
Evanev7 Feb 24, 2026
2925f3e
Reconcile worker instance backoff from state
jw-wcv May 7, 2026
6fcc9c3
Tune cluster liveness polling cadence
jw-wcv May 7, 2026
4853127
fix: make darwin mdns discovery reliable
AlexCheema May 3, 2026
83fcd4c
Resolve drafter rank in worker._start_runner_task via all_node_to_runner
jw-wcv May 7, 2026
2fdf194
Auto-upgrade single-node target to MlxJaccl when drafter is asymmetric
jw-wcv May 7, 2026
69f8d0d
Avoid Group.split for V1 N=1 asymmetric drafter placement
jw-wcv May 7, 2026
adb87b3
Pass None for V1 N=1 target_subgroup instead of Python stub
jw-wcv May 7, 2026
5e17371
Diagnostic: log drafter wire frames
jw-wcv May 7, 2026
189b102
Fix drafter-rank warmup race that stalls target in RunnerLoaded
jw-wcv May 7, 2026
49611f5
Short-circuit mx_all_gather_tasks when group is None to avoid drafter…
jw-wcv May 7, 2026
4a92978
Force-flush mx.distributed.send results in RemoteTransport / drafter_…
jw-wcv May 7, 2026
e37262c
Extend bench tool + scaffold EAGLE/lookahead drafters
jw-wcv May 7, 2026
4891984
Fix K=8-cancel regression: clamp K, runner survives, peer cascade
jw-wcv May 7, 2026
5cb3eab
Document EAGLE/lookahead Apple Silicon ceiling + ship EAGLE3 converter
jw-wcv May 7, 2026
456bbb3
Round-robin SequentialGenerator: lift the singular-slot ceiling
jw-wcv May 7, 2026
3fec7ad
Lift asymmetric drafter concurrency cap via session-aware wire protocol
jw-wcv May 8, 2026
98c48d4
Auto-place prefill-only siblings on prefill_eligible_nodes
jw-wcv May 8, 2026
890c7ae
Batch K prefills into one forward to cut long-prompt TTFT outliers
jw-wcv May 8, 2026
7afb054
Coalesce burst-arriving generation tasks before first step()
jw-wcv May 8, 2026
4ac0fe9
Add diagnostic logs for burst-coalesce + admit eligibility
jw-wcv May 8, 2026
74812f2
Raise EXO_BURST_COALESCE_MS default to 200ms
jw-wcv May 8, 2026
5835160
Drop docs/loom-design.md from this branch
jw-wcv May 8, 2026
79bfaf2
Downgrade burst-coalesce/admit chatter to debug for solo-runner case
jw-wcv May 8, 2026
398440c
Drain all pending work items between step() iterations
jw-wcv May 8, 2026
70830e8
Lift per-task ack-wait gate for warm-runner generation tasks
jw-wcv May 8, 2026
8bee4dc
Batch verify-loop sampling to one host-device sync per round
jw-wcv May 8, 2026
bb2b0d3
Mark ban_token_ids position-independent so spec-decode fast path fires
jw-wcv May 8, 2026
80ab341
Decouple asymmetric drafter wire from mx.distributed (v3 socket trans…
jw-wcv May 8, 2026
3f35d8e
Fix find_ip_prioritised arg order in drafter placement
jw-wcv May 8, 2026
2c0a25f
Implement multi-target asymmetric drafter for V2 sharded targets
jw-wcv May 8, 2026
873baf2
Harden V2-multi spec loop: determinism, drift detection, validation
jw-wcv May 8, 2026
84b7888
Recover from drafter death and worker disconnects
jw-wcv May 8, 2026
6d77ac7
Fix multi-target spec loop hang on master/rank divergence
jw-wcv May 9, 2026
702a9ac
Switch task agreement to two-phase intersection consensus
jw-wcv May 9, 2026
097887f
Align distributed-helper collectives to default stream
jw-wcv May 9, 2026
8f0d7d3
Switch hot-path broadcast from all_sum to send/recv
jw-wcv May 9, 2026
e98247d
Add EXO_PROBE_BROADCAST debug flag for spec-decode wire diagnostics
jw-wcv May 9, 2026
a452e9b
Move spec-decode int broadcasts to TCP fanout
jw-wcv May 9, 2026
7e6cc36
Stringify target_peer_hosts_by_rank keys for JSON round-trip
jw-wcv May 9, 2026
bbe0e16
Add spec-diag logging to localize asymmetric drafter rank-0 hang
jw-wcv May 9, 2026
5ae282b
Add side-channel + per-round spec-diag checkpoints
jw-wcv May 9, 2026
13cb061
Force eval of verify logits to launch TP all-reduce on every target rank
jw-wcv May 9, 2026
927f6db
Gate spec-decode diagnostics behind EXO_SPEC_DIAG env var
jw-wcv May 9, 2026
67810ea
Surface drafter telemetry through the API and per-request logs
jw-wcv May 9, 2026
14b18d6
Add drafter A/B and concurrent bench scripts
jw-wcv May 9, 2026
059e6fd
Isolate test_model_cards_drafter from operator-local custom cards
jw-wcv May 9, 2026
164 changes: 164 additions & 0 deletions bench/bench_compare.py
@@ -0,0 +1,164 @@
"""Drafter vs no-drafter A/B bench for the asymmetric cluster.

For each length in --lengths, runs the same prompt twice: first with
``use_drafter=False``, then with ``use_drafter=True``. Reports per-run
TPS, drafter telemetry, and the speedup ratio.

Sleeps for --sleep-between seconds after every run so the master can
settle. There is no separate warmup pass: the no-drafter run goes first,
so the drafter run may benefit from a warm prefix cache.
"""

from __future__ import annotations

import argparse
import json
import time
import urllib.request
from typing import Final

API_URL: Final[str] = "http://192.168.1.224:52415/v1/chat/completions"
MODEL: Final[str] = "mlx-community/gemma-4-31b-it-bf16"

PROMPT: Final[str] = (
    "Write a detailed, comprehensive technical reference on distributed "
    "speculative decoding for large language models. Cover the following "
    "topics in depth, with examples, equations, and pseudocode where "
    "relevant: (1) architectural foundations of speculative decoding, "
    "(2) the role of drafter vs target models and how acceptance/rejection "
    "is computed, (3) multi-token prediction (MTP) heads vs separate drafter "
    "models, (4) tensor-parallel verification and KV cache rollback semantics, "
    "(5) asymmetric placement on heterogeneous clusters, (6) wire-protocol "
    "design for drafter/target IPC, (7) failure modes (drafter death, target "
    "rank crashes, partitions) and recovery strategies, (8) tuning K (draft "
    "depth) for different workloads, (9) integration with continuous batching "
    "and paged attention, (10) practical performance results from real "
    "deployments. Use markdown headings and detailed prose. Begin now."
)


def run_once(max_tokens: int, use_drafter: bool, timeout: int) -> dict[str, object]:
    body: dict[str, object] = {
        "model": MODEL,
        "messages": [{"role": "user", "content": PROMPT}],
        "max_tokens": max_tokens,
        "temperature": 0.0,
        "stream": False,
        "use_drafter": use_drafter,
    }
    payload = json.dumps(body).encode("utf-8")
    request = urllib.request.Request(
        API_URL,
        data=payload,
        headers={"Content-Type": "application/json"},
        method="POST",
    )
    started = time.monotonic()
    with urllib.request.urlopen(request, timeout=timeout) as resp:  # noqa: S310 - lan
        raw = resp.read().decode("utf-8")
    wall = time.monotonic() - started
    parsed = json.loads(raw)
    usage = parsed.get("usage") or {}
    completion = int(usage.get("completion_tokens", 0))
    stats = parsed.get("generation_stats") or {}
    return {
        "max_tokens": max_tokens,
        "use_drafter": use_drafter,
        "wall_s": round(wall, 2),
        "completion_tokens": completion,
        "tps_total": round(completion / wall, 2) if wall > 0 else 0.0,
        "drafter_model_id": stats.get("drafter_model_id"),
        "draft_mode": stats.get("draft_mode"),
        "num_draft_tokens": stats.get("num_draft_tokens"),
        "accepted_draft_tokens": stats.get("accepted_draft_tokens"),
        "proposed_draft_tokens": stats.get("proposed_draft_tokens"),
        "spec_decode_rounds": stats.get("spec_decode_rounds"),
        "acceptance_rate": (
            round(stats["accepted_draft_tokens"] / stats["proposed_draft_tokens"], 3)
            if stats.get("proposed_draft_tokens")
            else None
        ),
        "fraction_from_drafter": (
            round(stats["accepted_draft_tokens"] / completion, 3)
            if stats.get("accepted_draft_tokens") and completion
            else None
        ),
        "finish_reason": parsed.get("choices", [{}])[0].get("finish_reason"),
    }


def main() -> None:
    parser = argparse.ArgumentParser()
    parser.add_argument("--lengths", type=int, nargs="+", default=[256, 1024, 2048])
    parser.add_argument("--timeout", type=int, default=900)
    parser.add_argument("--out", type=str, default="/tmp/bench_compare.json")
    parser.add_argument(
        "--sleep-between",
        type=float,
        default=2.0,
        help="Seconds to sleep between runs to let the master settle.",
    )
    args = parser.parse_args()

    results: list[dict[str, object]] = []
    summary: list[dict[str, object]] = []
    for length in args.lengths:
        print(f"\n=== max_tokens={length} ===", flush=True)
        # Run no-drafter first; the drafter run inherits a warm prompt cache
        # via prefix-cache-hit, but max_tokens drives the bulk of the
        # measured time so this is fine for steady-state TPS comparison.
        for use_drafter in (False, True):
            try:
                r = run_once(length, use_drafter, args.timeout)
            except Exception as exc:  # noqa: BLE001 - report bench failure
                r = {
                    "max_tokens": length,
                    "use_drafter": use_drafter,
                    "error": f"{type(exc).__name__}: {exc}",
                }
            results.append(r)
            print(json.dumps(r, indent=2), flush=True)
            time.sleep(args.sleep_between)

        # Speedup summary for this length pair
        no_draft = next(
            (
                r
                for r in results
                if r.get("max_tokens") == length and not r.get("use_drafter")
            ),
            None,
        )
        draft = next(
            (
                r
                for r in results
                if r.get("max_tokens") == length and r.get("use_drafter")
            ),
            None,
        )
        if no_draft and draft and "error" not in no_draft and "error" not in draft:
            tps_no = float(no_draft.get("tps_total", 0.0) or 0)
            tps_yes = float(draft.get("tps_total", 0.0) or 0)
            speedup = round(tps_yes / tps_no, 3) if tps_no > 0 else None
            row = {
                "max_tokens": length,
                "tps_no_drafter": tps_no,
                "tps_drafter": tps_yes,
                "speedup_x": speedup,
                "acceptance_rate": draft.get("acceptance_rate"),
                "fraction_from_drafter": draft.get("fraction_from_drafter"),
            }
            print(f"\n>>> speedup at {length}: {json.dumps(row)}", flush=True)
            summary.append(row)

    out = {"summary": summary, "raw": results}
    with open(args.out, "w", encoding="utf-8") as fh:
        json.dump(out, fh, indent=2)
    print("\n=== overall summary ===", flush=True)
    print(json.dumps(summary, indent=2), flush=True)
    print(f"Saved: {args.out}", flush=True)


if __name__ == "__main__":
    main()
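
The telemetry fields reported by `run_once` and the `speedup_x` value in `main` reduce to three ratios. The following is a minimal sketch of that arithmetic, not code from this PR: `derive_metrics` and `speedup` are hypothetical helper names, and the token counts and TPS figures are made-up inputs rather than measured results.

```python
from __future__ import annotations


def derive_metrics(
    accepted: int, proposed: int, completion: int
) -> dict[str, float | None]:
    """Compute the two drafter ratios, returning None when a denominator is zero."""
    return {
        # Share of proposed draft tokens that the target accepted.
        "acceptance_rate": round(accepted / proposed, 3) if proposed else None,
        # Share of all completion tokens that came from accepted drafts.
        "fraction_from_drafter": (
            round(accepted / completion, 3) if accepted and completion else None
        ),
    }


def speedup(tps_drafter: float, tps_no_drafter: float) -> float | None:
    """Ratio of drafter TPS to baseline TPS; None if the baseline is zero."""
    return round(tps_drafter / tps_no_drafter, 3) if tps_no_drafter > 0 else None


# Hypothetical numbers for illustration only.
metrics = derive_metrics(accepted=300, proposed=400, completion=512)
print(metrics)
print(speedup(30.0, 20.0))
```

Returning `None` instead of `0.0` for an undefined ratio keeps a run with no drafter activity distinguishable from a run where every draft was rejected. The script itself is driven by the flags defined above, for example `--lengths 256 1024 --sleep-between 2.0 --out /tmp/bench_compare.json`.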
150 changes: 150 additions & 0 deletions bench/bench_concurrent.py
@@ -0,0 +1,150 @@
"""Concurrent overlapping spec-decode bench for the asymmetric cluster.

Fires N parallel chat-completions requests (each with the drafter
enabled) at the master and measures:
- Per-request wall time, completion tokens, individual TPS
- Aggregate cluster TPS (sum of per-request tokens / total wall time)
- Per-request start offset (time-to-first-token is not captured, since
  requests are non-streaming)

The point: validate that EXO_MAX_CONCURRENT_REQUESTS > 1 actually
overlaps spec-decode sessions correctly. Single-rank-target placements
trivially share KV; multi-rank tensor-parallel placements with an
asymmetric drafter are the interesting case here.
"""

from __future__ import annotations

import argparse
import json
import threading
import time
import urllib.request
from typing import Final

API_URL: Final[str] = "http://192.168.1.224:52415/v1/chat/completions"
MODEL: Final[str] = "mlx-community/gemma-4-31b-it-bf16"

PROMPTS: Final[list[str]] = [
    "Explain the architecture of distributed speculative decoding in "
    "one paragraph, then list six common failure modes with mitigations.",
    "Write a 400-word technical brief on tensor-parallel KV cache "
    "rollback semantics, including pseudocode for accept/reject.",
    "Outline the difference between MTP heads and external drafter "
    "models, and discuss when each is preferable for low-latency serving.",
    "Describe how an n-gram drafter integrates with a transformer "
    "target model, with attention to stateful processors and RNG.",
    "Summarize the trade-offs of pipelined vs synchronous spec-decode, "
    "including their interaction with continuous batching.",
    "Walk through the wire protocol of a drafter-target IPC channel "
    "designed for sub-millisecond round-trip on local sockets.",
    "Explain how acceptance probability is computed for vanilla "
    "speculative decoding and when greedy acceptance is sound.",
    "Discuss the engineering trade-offs between more drafter heads "
    "and more drafter depth for a fixed quality bar.",
]


def run_one(
    idx: int,
    prompt: str,
    max_tokens: int,
    timeout: int,
    results: list[dict[str, object]],
    started_at: float,
) -> None:
    body = {
        "model": MODEL,
        "messages": [{"role": "user", "content": prompt}],
        "max_tokens": max_tokens,
        "temperature": 0.0,
        "stream": False,
        "use_drafter": True,
    }
    payload = json.dumps(body).encode("utf-8")
    req = urllib.request.Request(
        API_URL,
        data=payload,
        headers={"Content-Type": "application/json"},
        method="POST",
    )
    relative_start = time.monotonic() - started_at
    t0 = time.monotonic()
    try:
        with urllib.request.urlopen(req, timeout=timeout) as resp:  # noqa: S310 - lan
            raw = resp.read().decode("utf-8")
        wall = time.monotonic() - t0
        parsed = json.loads(raw)
        # Tolerate a null "usage" field, matching bench_compare.py.
        usage = parsed.get("usage") or {}
        completion = int(usage.get("completion_tokens", 0))
        result: dict[str, object] = {
            "idx": idx,
            "relative_start_s": round(relative_start, 2),
            "wall_s": round(wall, 2),
            "completion_tokens": completion,
            "tps": round(completion / wall, 2) if wall > 0 else 0.0,
            "finish_reason": parsed.get("choices", [{}])[0].get("finish_reason"),
            "first_64": (
                parsed.get("choices", [{}])[0]
                .get("message", {})
                .get("content", "")[:64]
            ),
        }
    except Exception as exc:  # noqa: BLE001 - report bench failure
        wall = time.monotonic() - t0
        result = {
            "idx": idx,
            "relative_start_s": round(relative_start, 2),
            "wall_s": round(wall, 2),
            "error": f"{type(exc).__name__}: {exc}",
        }
    results.append(result)
    print(json.dumps(result), flush=True)


def main() -> None:
    parser = argparse.ArgumentParser()
    parser.add_argument("--concurrency", type=int, default=4)
    parser.add_argument("--max-tokens", type=int, default=512)
    parser.add_argument("--timeout", type=int, default=600)
    parser.add_argument("--out", type=str, default="/tmp/bench_concurrent.json")
    args = parser.parse_args()

    n = args.concurrency
    prompts = (PROMPTS * ((n + len(PROMPTS) - 1) // len(PROMPTS)))[:n]

    results: list[dict[str, object]] = []
    started_at = time.monotonic()
    threads: list[threading.Thread] = []
    for i, p in enumerate(prompts):
        thread = threading.Thread(
            target=run_one,
            args=(i, p, args.max_tokens, args.timeout, results, started_at),
        )
        threads.append(thread)
        thread.start()
    for thread in threads:
        thread.join()
    total_wall = time.monotonic() - started_at

    completed = [r for r in results if "error" not in r]
    total_tokens = sum(int(r.get("completion_tokens", 0)) for r in completed)
    aggregate_tps = round(total_tokens / total_wall, 2) if total_wall > 0 else 0.0
    summary = {
        "concurrency": n,
        "max_tokens": args.max_tokens,
        "total_wall_s": round(total_wall, 2),
        "total_tokens_completed": total_tokens,
        "aggregate_tps": aggregate_tps,
        "successful": len(completed),
        "failed": n - len(completed),
        "individual": sorted(results, key=lambda r: int(r["idx"])),
    }
    print(json.dumps(summary, indent=2), flush=True)

    with open(args.out, "w", encoding="utf-8") as fh:
        json.dump(summary, fh, indent=2)
    print(f"Saved: {args.out}", flush=True)


if __name__ == "__main__":
    main()
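
Two pieces of `main` are easy to misread: the prompt list is tiled with a ceiling division so that `--concurrency` may exceed `len(PROMPTS)`, and aggregate TPS divides the tokens of successful requests by the wall time of the whole batch rather than by the sum of per-request times. A small offline sketch of both, assuming nothing beyond what the script does: `tile_prompts` and `aggregate_tps` are hypothetical helper names, and the sample results are invented for illustration.

```python
from __future__ import annotations


def tile_prompts(prompts: list[str], n: int) -> list[str]:
    """Repeat the prompt list enough times to cover n requests, then truncate."""
    repeats = (n + len(prompts) - 1) // len(prompts)  # ceil(n / len(prompts))
    return (prompts * repeats)[:n]


def aggregate_tps(results: list[dict[str, object]], total_wall: float) -> float:
    """Tokens from successful requests divided by the wall time of the batch."""
    completed = [r for r in results if "error" not in r]
    total_tokens = sum(int(r.get("completion_tokens", 0)) for r in completed)
    return round(total_tokens / total_wall, 2) if total_wall > 0 else 0.0


# Invented sample data: two successful requests and one failure.
sample = [
    {"idx": 0, "completion_tokens": 512},
    {"idx": 1, "completion_tokens": 488},
    {"idx": 2, "error": "TimeoutError: timed out"},
]
print(tile_prompts(["a", "b", "c"], 5))
print(aggregate_tps(sample, total_wall=25.0))
```

Because failed requests contribute no tokens but still extend the batch wall time, a single timeout lowers the aggregate figure; the `successful` and `failed` counts in the summary are worth reading alongside `aggregate_tps`.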