Add native MTP for Qwen3.6 MLX models #2110
Open
ffrappo wants to merge 1 commit into
Summary
This PR adds native multi-token prediction for Qwen3.6/Qwen3.5-style MLX checkpoints that include in-checkpoint MTP layers. The MTP heads draft candidate tokens, and the target model verifies those candidates before any token is emitted.
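The draft-then-verify flow described above can be illustrated with a minimal greedy sketch. It is not the PR's implementation: `verify_draft` and `target_next` are hypothetical names, the target is queried token by token for clarity rather than in one batched pass, and sampling controls such as temperature are ignored.

```python
from typing import Callable, List, Sequence


def verify_draft(
    context: List[int],
    draft: Sequence[int],
    target_next: Callable[[List[int]], int],
) -> List[int]:
    """Return the tokens to emit for one draft/verify step (greedy case).

    `target_next(tokens)` is a hypothetical stand-in for the target model's
    greedy next-token choice given a token prefix.
    """
    emitted: List[int] = []
    prefix = list(context)
    for candidate in draft:
        expected = target_next(prefix)
        if candidate != expected:
            # First mismatch: emit the target's own token, discard the rest of the draft.
            emitted.append(expected)
            return emitted
        # The target agrees, so the drafted token is safe to emit.
        emitted.append(candidate)
        prefix.append(candidate)
    # Every drafted token was accepted; the target still supplies one more token.
    emitted.append(target_next(prefix))
    return emitted


def toy_target(tokens: List[int]) -> int:
    """Toy target model: always predicts (last token + 1) mod 10."""
    return (tokens[-1] + 1) % 10


print(verify_draft([1, 2, 3], [4, 5, 7], toy_target))  # [4, 5, 6]
print(verify_draft([1, 2, 3], [4, 5], toy_target))     # [4, 5, 6]
```

In both calls the emitted sequence is exactly what the toy target would have produced on its own, which is the property the verification step is meant to preserve.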
What ships:
temperature/top_p/top_k/min_p distribution.
Behavior
Native MTP is enabled by default for supported single-node model cards. It dispatches only when the card declares native_mtp, the local checkpoint has recoverable MTP weights, and the model is placed on one node. Multi-node placements continue to use the normal path in this PR. Operators can disable the feature with EXO_NATIVE_MTP_ENABLED=0 or the macOS setting.
Correctness
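As a rough sketch of the dispatch rule from the Behavior section, the conditions could be combined as below. Only the EXO_NATIVE_MTP_ENABLED variable comes from this PR. The `Placement` dataclass, its field names, and both function names are hypothetical, and treating an unset variable as enabled is an assumption based on the "enabled by default" wording.

```python
import os
from dataclasses import dataclass
from typing import Mapping, Optional


@dataclass
class Placement:
    """Hypothetical summary of the facts the dispatch decision depends on."""
    card_declares_native_mtp: bool
    checkpoint_has_mtp_weights: bool
    node_count: int


def native_mtp_enabled(env: Optional[Mapping[str, str]] = None) -> bool:
    """Assumption: unset means enabled; the literal value "0" disables."""
    env = os.environ if env is None else env
    return env.get("EXO_NATIVE_MTP_ENABLED", "1") != "0"


def should_dispatch_native_mtp(
    placement: Placement, env: Optional[Mapping[str, str]] = None
) -> bool:
    """All conditions must hold; otherwise the normal generation path is used."""
    return (
        native_mtp_enabled(env)
        and placement.card_declares_native_mtp
        and placement.checkpoint_has_mtp_weights
        and placement.node_count == 1
    )


single_node = Placement(True, True, node_count=1)
print(should_dispatch_native_mtp(single_node, env={}))                               # True
print(should_dispatch_native_mtp(single_node, env={"EXO_NATIVE_MTP_ENABLED": "0"}))  # False
print(should_dispatch_native_mtp(Placement(True, True, node_count=2), env={}))       # False
```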
Native MTP keeps the target model in charge of emission:
temperature/top_p/top_k/min_p.
Current scope:
native_mtp.default_k/native_mtp.max_k.
logits_processors such as repetition/presence/frequency penalties are not routed through native MTP yet.
Performance
Broad prompt set, max_prompt_tokens=32, max_tokens=64:
Summary: 27B reaches +97.2% / 1.97x at K=2. 35B-A3B reaches +15.8% / 1.16x at K=1 in the broad sweep, and higher K is not automatically better on that path.
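The two figures quoted for each model are the same measurement in two forms: a percentage throughput gain and the equivalent speedup ratio. A small check, using a hypothetical helper name, shows how they relate:

```python
def speedup_ratio(percent_gain: float) -> float:
    """Convert a percentage throughput gain into a speedup multiplier."""
    return 1.0 + percent_gain / 100.0


print(f"{speedup_ratio(97.2):.2f}x")  # 1.97x
print(f"{speedup_ratio(15.8):.2f}x")  # 1.16x
```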
Implementation Notes
EXO_NATIVE_MTP_ASYNC_DRAFT_EVAL=0 disables this path.
Validation
uv run basedpyright
uv run ruff check
nix fmt
uv run pytest src -q