1 change: 1 addition & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
@@ -6,6 +6,7 @@ Changelog

**New Features**

- Add the **D-PACE** loss objective for DFlash speculative-decoding training (`arXiv:2605.18810 <https://arxiv.org/abs/2605.18810>`_). Set ``dflash_loss_objective: dpace`` to replace the static exponential position decay with dynamic, confidence-derived per-position weights that adapt to whichever block positions currently limit acceptance. Smoothing is controlled by ``dflash_dpace_alpha`` (default 0.5). The weights are used only during training and are detached from the gradient, so the draft architecture and inference path are unchanged.
- Add the ``day0-release`` agent skill (``.agents/skills/day0-release/``), a deterministic end-to-end driver that chains the PTQ → evaluation → comparison skills (the evaluation stage deploys the checkpoint itself) with an enforced gate after each stage and returns a publish decision (ACCEPT / REGRESSION / ANOMALOUS / INFEASIBLE). Ships three GPU-free, unit-tested gate scripts (``gate_ptq.py``, ``gate_run.py``, ``gate_compare.py``) that validate checkpoint coverage, evaluation-run completeness, and baseline-vs-candidate accuracy threshold. v1 reports and stops on regression; the recipe-search loop is deferred.
- Add **streaming** speculative-decoding training (EAGLE3 / DFlash): the draft trains on base-model hidden states produced on the fly by a co-located ``vllm serve`` (no disk dump), moved trainer-side over NIXL RDMA, scaling to multi-node (dedicated serve replicas + DDP trainers). New launcher examples for NVFP4 Kimi-K2.5 / K2.6 on GB200/aarch64 under ``tools/launcher/examples/moonshotai/``.

30 changes: 30 additions & 0 deletions examples/speculative_decoding/doc/dflash.md
@@ -162,6 +162,8 @@ See [`modelopt_recipes/general/speculative_decoding/dflash.yaml`](../../../model
| `dflash.dflash_block_size` | 8 | Block size for parallel prediction |
| `dflash.dflash_num_anchors` | 512 | Random anchor positions per sample (see below) |
| `dflash.dflash_loss_decay_factor` | 4.0 | Exponential decay gamma (0 disables, see below) |
| `dflash.dflash_loss_objective` | `decay` | Position weighting: `decay` (static) or `dpace` (dynamic, see below) |
| `dflash.dflash_dpace_alpha` | 0.5 | D-PACE smoothing factor in (0, 1]; only used when objective is `dpace` |
| `dflash.dflash_self_logit_distillation` | true | Use target model logits as soft labels (vs hard CE) |
| `dflash.dflash_architecture_config.num_hidden_layers` | 5 | Draft decoder layers |
| `dflash.dflash_architecture_config.mask_token_id` | auto | Token ID for masked positions |
@@ -239,6 +241,34 @@ Note: this is different from EAGLE3's `eagle_loss_decay_factor` which multiplies
`alpha^step` across TTT steps. DFlash decay operates within a single block, weighting
early positions higher because they gate acceptance of all later positions.
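
For comparison with the dynamic objective introduced below, the static schedule can be sketched in plain Python. This mirrors the decay expression `exp(-(k - 1).clamp(min=0) / gamma)` that appears in `_compute_loss` in this PR; the function name `static_decay_weights` is hypothetical and the sketch assumes `gamma > 0` (the real code skips decay entirely when the factor is 0).

```python
import math


def static_decay_weights(block_size: int, gamma: float) -> list[float]:
    """Hypothetical helper: per-slot weight exp(-max(k - 1, 0) / gamma) for k = 0..block_size-1."""
    # Slot 0 is the anchor and slot 1 is the first predicted position; both get weight 1.
    return [math.exp(-max(k - 1, 0) / gamma) for k in range(block_size)]


weights = static_decay_weights(block_size=8, gamma=4.0)
```

With `block_size=8` and `gamma=4.0` (the recipe values in the table above), slots 0 and 1 both get weight 1.0 and each later slot is scaled by a further factor of `exp(-1/4)`, roughly 0.78. The schedule depends only on the position, never on how confident the draft currently is.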

### D-PACE (Dynamic Position-Aware Cross-Entropy)

Set `dflash.dflash_loss_objective: dpace` to replace the static decay with **D-PACE**
([arXiv:2605.18810](https://arxiv.org/abs/2605.18810)), which derives per-position weights
from a differentiable surrogate of expected accepted block length. Where static decay uses
a fixed schedule, D-PACE adapts to the draft's own per-position confidence and shifts
training signal toward whichever positions currently limit acceptance as the drafter improves.

For each block, let `q_i = exp(-CE_i)` be the draft confidence on the target token at
predicted position `i`. D-PACE smooths it (Eq.7) and weights each position by the suffix-sum
of prefix products (Eq.8):

```text
q~_i = (1 - alpha) * q_i + alpha
w_j = sum_{m >= j} prod_{i <= m} q~_i # detached; multiplies the per-token CE
```

The weight factors into the prefix-acceptance probability (`prod_{i<=j} q~_i`) times the
remaining accepted-length value, so it directly targets expected accepted length. The
weights are detached from the gradient — D-PACE only reshapes credit assignment and adds
~2.3% training overhead with no change to the draft architecture or inference.
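
The factorization can be made explicit by pulling the common prefix out of every term of the sum. This is a sketch derived from the Eq.8 form shown above, not a quotation from the paper; `L` denotes the number of predicted positions in a block and is a symbol introduced here for illustration.

```latex
w_j = \sum_{m=j}^{L} \prod_{i=1}^{m} \tilde{q}_i
    = \underbrace{\prod_{i=1}^{j} \tilde{q}_i}_{\text{prefix acceptance}}
      \cdot
      \underbrace{\Bigl( 1 + \sum_{m=j+1}^{L} \prod_{i=j+1}^{m} \tilde{q}_i \Bigr)}_{\text{remaining accepted-length value}}
```

For a block with two predicted positions this gives `w_1 = q~_1 * (1 + q~_2)` and `w_2 = q~_1 * q~_2`, so the first position always carries at least as much weight as the second.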

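- `dflash_dpace_alpha` is the asymmetric smoothing floor (`q~_i >= alpha`) that keeps later
weights from vanishing. Stable in `[0.3, 0.7]`; `alpha=0` is rejected (without the floor the
cumulative product can vanish), and `alpha → 1` removes the adaptive signal: every `q~_i`
approaches 1, so the weights approach the fixed schedule `w_j = L - j + 1` for a block of
`L` predicted positions. Default `0.5`.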
- D-PACE is mutually exclusive with `dflash_loss_decay_factor`; when objective is `dpace`,
the decay factor is ignored.

### Checkpoint Resume

DFlash supports checkpoint resume transparently. Rotary embeddings are lazily
18 changes: 17 additions & 1 deletion modelopt/torch/speculative/config.py
@@ -103,7 +103,23 @@ class DFlashConfig(ModeloptBaseConfig):
dflash_loss_decay_factor: float = ModeloptField(
default=0.0,
description="Gamma for exponential loss decay weighting (paper Eq.4). "
"Suggested: 7 for block_size=16, 5 for 10, 4 for 8. 0 disables.",
"Suggested: 7 for block_size=16, 5 for 10, 4 for 8. 0 disables. "
"Only used when dflash_loss_objective='decay'.",
)

dflash_loss_objective: str = ModeloptField(
default="decay",
description="Block-position loss weighting objective. 'decay' uses the static "
"exponential decay of dflash_loss_decay_factor (DFlash, arXiv:2602.06036 Eq.4). "
"'dpace' uses dynamic, confidence-derived per-position weights "
"(D-PACE, arXiv:2605.18810 Eq.8).",
)

dflash_dpace_alpha: float = ModeloptField(
default=0.5,
description="D-PACE asymmetric smoothing factor alpha in (0, 1] (paper Eq.7). Used only "
"when dflash_loss_objective='dpace'. Stable in [0.3, 0.7]; alpha=0 is degenerate "
"(cumulative product vanishes) and alpha->1 removes the adaptive signal.",
)

dflash_num_anchors: int = ModeloptField(
12 changes: 12 additions & 0 deletions modelopt/torch/speculative/dflash/dflash_model.py
@@ -31,6 +31,18 @@ def modify(self, config):
self.dflash_block_size = config.dflash_block_size
self.dflash_freeze_base_model = config.dflash_freeze_base_model
self.dflash_loss_decay_factor = config.dflash_loss_decay_factor
self.dflash_loss_objective = config.dflash_loss_objective
self.dflash_dpace_alpha = config.dflash_dpace_alpha
if self.dflash_loss_objective not in ("decay", "dpace"):
raise ValueError(
f"dflash_loss_objective must be 'decay' or 'dpace', got "
f"{self.dflash_loss_objective!r}"
)
if self.dflash_loss_objective == "dpace" and not 0.0 < self.dflash_dpace_alpha <= 1.0:
raise ValueError(
f"dflash_dpace_alpha must be in (0, 1] for the D-PACE objective, got "
f"{self.dflash_dpace_alpha}"
)

[SUGGESTION] Consider warning (or rejecting) when dflash_loss_objective == "dpace" and dflash_loss_decay_factor != 0.0 (i.e. the user has explicitly set both). The default recipe modelopt_recipes/general/speculative_decoding/dflash.yaml already sets dflash_loss_decay_factor: 4.0, so a user who only flips dflash_loss_objective: dpace won't realize their non-default decay value is silently ignored (the doc notes the mutual exclusion, but the runtime is silent). A logger.warning(...) here would surface the misconfiguration without blocking the run.
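
A minimal sketch of what such a check could look like, assuming the three settings are passed in as plain values. The function name `check_dflash_loss_config` is hypothetical, and the sketch uses the standard-library `warnings` module rather than the project's logger so that it runs on its own.

```python
import warnings


def check_dflash_loss_config(objective: str, decay_factor: float, alpha: float) -> None:
    """Hypothetical helper: validate the DFlash loss settings and warn about an ignored decay."""
    if objective not in ("decay", "dpace"):
        raise ValueError(f"dflash_loss_objective must be 'decay' or 'dpace', got {objective!r}")
    if objective == "dpace":
        if not 0.0 < alpha <= 1.0:
            raise ValueError(f"dflash_dpace_alpha must be in (0, 1], got {alpha}")
        if decay_factor != 0.0:
            # The decay branch is skipped under D-PACE, so say so instead of staying silent.
            warnings.warn(
                f"dflash_loss_decay_factor={decay_factor} is ignored when "
                "dflash_loss_objective='dpace'.",
                stacklevel=2,
            )
```

Calling `check_dflash_loss_config("dpace", 4.0, 0.5)` emits one warning and returns normally, so a user who only flips the objective in the default recipe still gets a run, plus a visible note that the decay value has no effect.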

self.dflash_self_logit_distillation = config.dflash_self_logit_distillation
self.dflash_num_anchors = config.dflash_num_anchors
self.dflash_report_acc = config.dflash_report_acc
47 changes: 45 additions & 2 deletions modelopt/torch/speculative/plugins/hf_dflash.py
@@ -74,6 +74,36 @@
__all__ = ["HFDFlashModel"]


def _dpace_position_weights(confidences: torch.Tensor, alpha: float) -> torch.Tensor:
"""Compute detached D-PACE per-position weights from draft confidences.

Implements D-PACE (arXiv:2605.18810) Eq.7-8: each draft confidence ``q_i`` is
smoothed toward 1 with ``q~_i = (1 - alpha) * q_i + alpha`` (Eq.7), then the
per-position weight is the suffix-sum of the prefix products of the smoothed
confidences, ``w_j = sum_{m >= j} prod_{i <= m} q~_i`` (Eq.8). This factors into
the prefix acceptance probability times the remaining accepted-length value, so
the loss tracks each position's contribution to expected accepted block length.

Args:
confidences: ``[..., L]`` draft confidence ``q_i = exp(-CE)`` per position.
alpha: smoothing factor in (0, 1]; raises if outside [0, 1].

Returns:
Detached weights with the same shape and dtype as ``confidences``.
"""
if not 0.0 <= alpha <= 1.0:

[SUGGESTION] Docstring says alpha is in (0, 1] but the validation accepts [0, 1] (closed at 0). The user-facing path through DFlashModel.modify() correctly rejects alpha=0, but a direct caller of _dpace_position_weights with alpha=0 would silently get unsmoothed weights (with no floor, the cumulative product vanishes from the first near-zero confidence onward) instead of an error. Tighten the check to 0.0 < alpha <= 1.0 to match the docstring, or relax the docstring to [0, 1].
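
To illustrate the tightened check, here is a pure-Python stand-in for the helper. It is a sketch only: `dpace_weights_strict` is a hypothetical name, it works on a flat list rather than a tensor, and it does not model detaching since there is no autograd involved.

```python
from itertools import accumulate
from operator import mul


def dpace_weights_strict(confidences: list[float], alpha: float) -> list[float]:
    """Hypothetical stand-in for the D-PACE weight helper, with alpha=0 rejected."""
    if not 0.0 < alpha <= 1.0:
        raise ValueError(f"dflash_dpace_alpha must be in (0, 1], got {alpha}")
    # Eq.7: smooth each confidence toward 1, which floors it at alpha.
    smoothed = [(1.0 - alpha) * q + alpha for q in confidences]
    # Running product of the smoothed confidences.
    prefix_products = list(accumulate(smoothed, mul))
    # Eq.8: suffix sum, computed as reverse -> running sum -> reverse.
    return list(accumulate(reversed(prefix_products)))[::-1]


worst_case = dpace_weights_strict([0.0, 0.0, 0.0], alpha=0.5)  # [0.875, 0.375, 0.125]
```

Even with zero confidence at every position the weights stay positive for `alpha=0.5`, whereas `alpha=0.0` now raises `ValueError` instead of returning weights with no floor.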

raise ValueError(f"dflash_dpace_alpha must be in [0, 1], got {alpha}")

with torch.no_grad():
smoothed = (1.0 - alpha) * confidences.float() + alpha
prefix_products = torch.cumprod(smoothed, dim=-1)
# Suffix sum over positions: reverse -> cumsum -> reverse.
weights = torch.flip(
torch.cumsum(torch.flip(prefix_products, dims=[-1]), dim=-1), dims=[-1]
)
return weights.to(dtype=confidences.dtype)


@DFlashDMRegistry.register({PreTrainedModel: "hf.PreTrainedModel"})
class HFDFlashModel(DFlashModel):
"""DFlash Model for HuggingFace transformers."""
@@ -349,8 +379,21 @@ def _compute_loss(

binary_eval_mask = weight_mask.view(-1)

# Optional loss decay
if self.dflash_loss_decay_factor > 0:
# Block-position loss weighting: dynamic D-PACE weights or static exponential decay.
if self.dflash_loss_objective == "dpace" and block_size > 1:
# Draft confidence q_i = exp(-CE) on the target-selected token, over the
# predicted positions (slot 0 is the given anchor, already masked above).
# Weights are detached (paper Eq.9), so this adds the documented ~2.3%
# training overhead without altering the cross-entropy gradient.
with torch.no_grad():
conf_ce = F.cross_entropy(
logits.view(-1, logits.size(-1)), target_ids.view(-1), reduction="none"
).view(bsz, n_blocks, block_size)
confidences = torch.exp(-conf_ce[..., 1:].float())
dpace = torch.ones_like(weight_mask)
dpace[..., 1:] = _dpace_position_weights(confidences, self.dflash_dpace_alpha)
weight_mask = weight_mask * dpace
elif self.dflash_loss_decay_factor > 0:
k = torch.arange(block_size, device=device).view(1, 1, -1)
decay = torch.exp(-(k - 1).clamp(min=0).float() / self.dflash_loss_decay_factor)
weight_mask = weight_mask * decay
Comment on lines +388 to 399

[SUGGESTION] When base_logits is None (the non-KD path), the per-token cross-entropy is computed twice — once here under no_grad to derive confidences, and again at line 421 to compute the actual loss. Since the second computation is exactly the per-token CE you already have, you could reuse it (compute once with grad enabled, take .detach().exp() for the confidences). The PR description already acknowledges the ~2.3% overhead — eliminating this duplication would close most of that gap. The KD path correctly remains separate because its actual loss is KL, not CE.

Why it matters: small but free win on training throughput; CE is one of the more expensive ops in the inner training loop because of the vocab-size matmul.

How to apply: hoist a single loss_per_token = F.cross_entropy(...) computation, derive confidences = torch.exp(-loss_per_token.detach()).view(bsz, n_blocks, block_size)[..., 1:].float(), then later use the same loss_per_token in the loss reduction. Keep the no-grad CE only for the KD branch.

64 changes: 64 additions & 0 deletions tests/unit/torch/speculative/plugins/test_hf_dflash.py
@@ -24,6 +24,7 @@
from types import SimpleNamespace
from unittest.mock import MagicMock

import pytest
import torch
from _test_utils.torch.transformers_models import (
get_tiny_llama,
@@ -38,6 +39,7 @@
DFlashAttention,
DFlashModule,
HFDFlashModel,
_dpace_position_weights,
build_target_layer_ids,
)
from modelopt.torch.speculative.utils import AcceptanceRateValidation
@@ -116,6 +118,68 @@ def test_convert_sets_mask_token_id(self):
assert model.mask_token_id == 0


class TestDPaceWeights:
"""Test the D-PACE position-weighting objective (arXiv:2605.18810)."""

def test_weights_match_paper_formula(self):
"""w_j = sum_{m>=j} prod_{i<=m} q~_i with q~_i = (1-a)q_i + a (Eq.7-8)."""
alpha = 0.5
conf = torch.tensor([[0.9, 0.6, 0.3, 0.8]])
weights = _dpace_position_weights(conf, alpha)

smoothed = (1.0 - alpha) * conf + alpha
prefix = torch.cumprod(smoothed, dim=-1)
expected = torch.flip(torch.cumsum(torch.flip(prefix, [-1]), dim=-1), [-1])
assert torch.allclose(weights, expected, atol=1e-6)

def test_weights_are_detached(self):
"""Weights must carry no gradient (paper Eq.9 detaches them)."""
conf = torch.rand(2, 3, 5, requires_grad=True)
weights = _dpace_position_weights(conf, 0.5)
assert not weights.requires_grad

def test_weights_monotonic_nonincreasing(self):
"""Suffix-sum of positive prefix products is non-increasing along the block."""
conf = torch.rand(4, 8).clamp(0.05, 0.99)
weights = _dpace_position_weights(conf, 0.5)
assert torch.all(weights[:, :-1] >= weights[:, 1:] - 1e-6)

def test_smoothing_keeps_later_weights_nonzero(self):
"""With alpha>0, q~_i >= alpha so cumulative products cannot vanish."""
conf = torch.zeros(1, 6) # worst case: zero confidence everywhere
weights = _dpace_position_weights(conf, alpha=0.5)
assert torch.all(weights > 0)

def test_invalid_alpha_raises(self):
with pytest.raises(ValueError, match="dflash_dpace_alpha"):
_dpace_position_weights(torch.rand(1, 4), alpha=1.5)

def test_convert_with_dpace_objective(self):
"""Convert with the dpace objective wires the attributes onto the model."""
model = get_tiny_llama(num_hidden_layers=4)
config = _get_dflash_config()
config["dflash_loss_objective"] = "dpace"
config["dflash_dpace_alpha"] = 0.3
mtsp.convert(model, [("dflash", config)])
assert model.dflash_loss_objective == "dpace"
assert model.dflash_dpace_alpha == 0.3

def test_convert_rejects_bad_objective(self):
model = get_tiny_llama(num_hidden_layers=4)
config = _get_dflash_config()
config["dflash_loss_objective"] = "nope"
with pytest.raises(ValueError, match="dflash_loss_objective"):
mtsp.convert(model, [("dflash", config)])

def test_convert_rejects_degenerate_alpha(self):
model = get_tiny_llama(num_hidden_layers=4)
config = _get_dflash_config()
config["dflash_loss_objective"] = "dpace"
config["dflash_dpace_alpha"] = 0.0
with pytest.raises(ValueError, match="dflash_dpace_alpha"):
mtsp.convert(model, [("dflash", config)])


class TestDFlashSaveRestore:
"""Test DFlash model save and restore."""
