Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ Changelog
- Add the ``day0-release`` agent skill (``.agents/skills/day0-release/``), a deterministic end-to-end driver that chains the PTQ → evaluation → comparison skills (the evaluation stage deploys the checkpoint itself) with an enforced gate after each stage and returns a publish decision (ACCEPT / REGRESSION / ANOMALOUS / INFEASIBLE). Ships three GPU-free, unit-tested gate scripts (``gate_ptq.py``, ``gate_run.py``, ``gate_compare.py``) that validate checkpoint coverage, evaluation-run completeness, and baseline-vs-candidate accuracy threshold. v1 reports and stops on regression; the recipe-search loop is deferred.
- Add **streaming** speculative-decoding training (EAGLE3 / DFlash): the draft trains on base-model hidden states produced on the fly by a co-located ``vllm serve`` (no disk dump), moved trainer-side over NIXL RDMA, scaling to multi-node (dedicated serve replicas + DDP trainers). New launcher examples for NVFP4 Kimi-K2.5 / K2.6 on GB200/aarch64 under ``tools/launcher/examples/moonshotai/``.
- Add a fused Triton fast path for ``local_hessian`` NVFP4 weight-scale search (the Hessian-weighted FP8-E4M3 scale sweep). For each NVFP4 block it minimizes ``dwᵀ H dw`` over the 126 candidate scales using the per-cin-block local Hessian on tensor cores, replacing the per-weight Python reference sweep — roughly **34x** faster on a single 8192x4096 weight and bit-exact with the reference for fp32/fp16 weights. Used automatically during ``local_hessian`` calibration for both dense and fused-MoE expert weights; falls back to the reference sweep on CPU, when Triton is unavailable, or via ``MODELOPT_NVFP4_TRITON_SWEEP=0``.
- Add **Domino** speculative-decoding training: the parallel DFlash draft backbone plus a lightweight GRU causal correction head, selected via ``dflash_architecture_config.projector_type=domino``. Trained with a base/final dual loss whose ``dflash_lambda_base_start``/``dflash_lambda_base_decay_ratio`` curriculum decays the base-loss weight 1→0. Exports in the z-lab drafter format; recipe at ``modelopt_recipes/general/speculative_decoding/domino.yaml``. Training only — the inference path is not wired up yet.

0.45 (2026-07-02)
^^^^^^^^^^^^^^^^^
Expand Down
8 changes: 8 additions & 0 deletions examples/speculative_decoding/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
ModelOptMedusaRecipe,
ModelOptSpeculativeRecipeBase,
)
from modelopt.torch.speculative.plugins.hf_domino import DominoLambdaCallback
from modelopt.torch.speculative.plugins.hf_training_args import (
TrainingArguments as SpecTrainingArgs,
)
Expand Down Expand Up @@ -295,6 +296,13 @@ def train():
and recipe.eagle.eagle_base_lora_warmup_steps > 0
):
callbacks.append(LoRAWarmupCallback(recipe.eagle.eagle_base_lora_warmup_steps))
# Domino (dflash recipe with projector_type=domino) needs the lambda_base
# curriculum schedule driven by the trainer's global step.
if (
isinstance(recipe, ModelOptDFlashRecipe)
and recipe.dflash.dflash_architecture_config.get("projector_type") == "domino"
):
callbacks.append(DominoLambdaCallback())
# Leave training_args.ignore_data_skip at its default (False). The dataset is
# map-style, so HF Trainer's resume skips consumed indices at the batch-sampler
# level (accelerate.skip_first_batches) without re-fetching them, landing at the
Expand Down
32 changes: 32 additions & 0 deletions modelopt/torch/export/plugins/hf_spec_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,3 +445,35 @@ def export(self, export_dir: Path | str, dtype: torch.dtype | None = None):
f"Exported DFlash draft model: {len(drafter_sd)} tensors, "
f"config keys: {list(drafter_config.keys())[:5]}..."
)


class DominoExporter(DFlashExporter):
    """Exporter for Domino draft models (DFlash backbone + causal correction head).

    The output uses the same z-lab-compatible layout as the DFlash exporter.
    The Domino head weights (``prefix_gru.*`` / ``embed_proj.*``) need no extra
    handling here: they are already picked up by the inherited
    ``dflash_module.`` prefix stripping. What this class adds is the set of
    config fields a loader needs to rebuild the head: ``projector_type``,
    ``emb_dim``, ``gru_hidden_dim``, ``pure_draft_prefix_len`` and
    ``shift_label``.
    """

    def _export_config(self):
        """Return the DFlash export config augmented with the Domino head fields."""
        exported = super()._export_config()
        domino_cfg = self.model.dflash_config

        # Both sizes are read as plain attributes (no fallback): they are
        # expected to exist because HFDominoModel.modify validates them at
        # convert time.
        emb_dim = domino_cfg.emb_dim
        gru_hidden_dim = domino_cfg.gru_hidden_dim

        # The reference checkpoint also carries emb_dim at the top level, so
        # mirror that in addition to the nested copy below.
        exported["emb_dim"] = emb_dim

        # Optional fields fall back to the Domino defaults when absent.
        head_fields = {
            "projector_type": getattr(domino_cfg, "projector_type", "domino"),
            "shift_label": getattr(domino_cfg, "shift_label", True),
            "pure_draft_prefix_len": getattr(domino_cfg, "pure_draft_prefix_len", 1),
            "gru_hidden_dim": gru_hidden_dim,
            "emb_dim": emb_dim,
        }
        exported["dflash_config"].update(head_fields)
        return exported
21 changes: 21 additions & 0 deletions modelopt/torch/speculative/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,27 @@ class DFlashConfig(ModeloptBaseConfig):
),
)

dflash_lambda_base_start: float = ModeloptField(
default=1.0,
ge=0.0,
le=1.0,
description=(
"Domino only: initial weight of the base (backbone-only) loss in the "
"loss = (1 - lambda)*final + lambda*base mixture; linearly decayed to 0. "
"Ignored unless dflash_architecture_config.projector_type == 'domino'."
),
)

dflash_lambda_base_decay_ratio: float = ModeloptField(
default=1.0,
gt=0.0,
le=1.0,
description=(
"Domino only: fraction of total training steps over which lambda_base "
"decays from dflash_lambda_base_start to 0."
),
)
Comment thread
coderabbitai[bot] marked this conversation as resolved.


class MedusaConfig(ModeloptBaseConfig):
"""Medusa config."""
Expand Down
35 changes: 26 additions & 9 deletions modelopt/torch/speculative/dflash/conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,26 +24,43 @@
from ..config import DFlashConfig

DFlashDMRegistry = _DMRegistryCls(prefix="DFlash") # global instance for the registry
# Domino reuses the dflash mode/config/recipe but converts the base model to a
# DFlash module augmented with a causal correction head. It is selected via
# ``dflash_architecture_config.projector_type == "domino"`` and lives in its own
# registry so its wrapper (HFDominoModel) does not overwrite HFDFlashModel.
DominoDMRegistry = _DMRegistryCls(prefix="Domino")


def convert_to_dflash_model(model: nn.Module, config: DFlashConfig) -> ConvertReturnType:
"""Convert the model to a DFlash model as per `config`."""
"""Convert the model to a DFlash (or Domino) model as per `config`."""
model = model.init_modellike() if isinstance(model, ModelLikeModule) else model

original_cls = type(model)
if original_cls not in DFlashDMRegistry:
for cls in DFlashDMRegistry._registry:
if issubclass(original_cls, cls):
DFlashDMRegistry.register({original_cls: "base_model_class"})(DFlashDMRegistry[cls])
break

# merge custom config with default config (lazy import to avoid circular)
from .default_config import default_dflash_config

custom_config = config.dflash_architecture_config
config.dflash_architecture_config = {**default_dflash_config, **custom_config}

dflash_model = DFlashDMRegistry.convert(model)
# Route to the Domino registry when the architecture asks for the Domino head.
projector_type = config.dflash_architecture_config.get("projector_type")
if projector_type == "domino":
registry = DominoDMRegistry
elif projector_type in (None, "dflash"):
registry = DFlashDMRegistry
else:
raise ValueError(
f"Unsupported dflash_architecture_config.projector_type: {projector_type!r}. "
"Expected 'dflash' (default) or 'domino'."
)

original_cls = type(model)
if original_cls not in registry:
for cls in registry._registry:
if issubclass(original_cls, cls):
registry.register({original_cls: "base_model_class"})(registry[cls])
break

dflash_model = registry.convert(model)
dflash_model.modify(config)

metadata = {}
Expand Down
1 change: 1 addition & 0 deletions modelopt/torch/speculative/plugins/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,5 +31,6 @@

with import_plugin("transformers"):
from .hf_dflash import *
from .hf_domino import *
from .hf_eagle import *
from .hf_medusa import *
8 changes: 7 additions & 1 deletion modelopt/torch/speculative/plugins/hf_dflash.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,9 @@ def modify(self, config):

self._find_base_model_parts()

self.dflash_module = DFlashModule(self.dflash_config)
# Factory hook: subclasses (e.g. Domino) override to build an augmented
# draft module while reusing all of DFlash's modify() setup.
self.dflash_module = self._build_draft_module(self.dflash_config)
# Match base model dtype/device. Skip if base is on meta (during from_pretrained
# restore — the model will be moved to the correct device after weight loading).
if self.dflash_offline:
Expand All @@ -216,6 +218,10 @@ def modify(self, config):
self.is_quantized = False
self._num_anchors = self.dflash_num_anchors

def _build_draft_module(self, dflash_config):
    """Construct the draft module from ``dflash_config``.

    Factory hook called by ``modify()``. This base implementation returns a
    plain ``DFlashModule``; subclasses override it to supply an augmented
    draft module while reusing the rest of the ``modify()`` setup.
    """
    draft_module = DFlashModule(dflash_config)
    return draft_module

def get_exporter(self):
"""Get the exporter for the DFlash draft model."""
from modelopt.torch.export.plugins.hf_spec_export import DFlashExporter
Expand Down
Loading
Loading