diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 49c5858667..aa89d8a24f 100755 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -6,6 +6,7 @@ Changelog **New Features** +- Add the **D-PACE** loss objective for DFlash speculative-decoding training (`arXiv:2605.18810 <https://arxiv.org/abs/2605.18810>`_). Set ``dflash_loss_objective: dpace`` to replace the static exponential position decay with dynamic, confidence-derived per-position weights that adapt to whichever block positions currently limit acceptance. Smoothing is controlled by ``dflash_dpace_alpha`` (default 0.5). The weights are used only during training and are detached from the gradient (no architecture or inference change). - Add the ``day0-release`` agent skill (``.agents/skills/day0-release/``), a deterministic end-to-end driver that chains the PTQ → evaluation → comparison skills (the evaluation stage deploys the checkpoint itself) with an enforced gate after each stage and returns a publish decision (ACCEPT / REGRESSION / ANOMALOUS / INFEASIBLE). Ships three GPU-free, unit-tested gate scripts (``gate_ptq.py``, ``gate_run.py``, ``gate_compare.py``) that validate checkpoint coverage, evaluation-run completeness, and baseline-vs-candidate accuracy threshold. v1 reports and stops on regression; the recipe-search loop is deferred. - Add **streaming** speculative-decoding training (EAGLE3 / DFlash): the draft trains on base-model hidden states produced on the fly by a co-located ``vllm serve`` (no disk dump), moved trainer-side over NIXL RDMA, scaling to multi-node (dedicated serve replicas + DDP trainers). New launcher examples for NVFP4 Kimi-K2.5 / K2.6 on GB200/aarch64 under ``tools/launcher/examples/moonshotai/``. 
diff --git a/examples/speculative_decoding/doc/dflash.md b/examples/speculative_decoding/doc/dflash.md index 44db5d39e7..b5a5ae998d 100644 --- a/examples/speculative_decoding/doc/dflash.md +++ b/examples/speculative_decoding/doc/dflash.md @@ -162,6 +162,8 @@ See [`modelopt_recipes/general/speculative_decoding/dflash.yaml`](../../../model | `dflash.dflash_block_size` | 8 | Block size for parallel prediction | | `dflash.dflash_num_anchors` | 512 | Random anchor positions per sample (see below) | | `dflash.dflash_loss_decay_factor` | 4.0 | Exponential decay gamma (0 disables, see below) | +| `dflash.dflash_loss_objective` | `decay` | Position weighting: `decay` (static) or `dpace` (dynamic, see below) | +| `dflash.dflash_dpace_alpha` | 0.5 | D-PACE smoothing factor in (0, 1]; only used when objective is `dpace` | | `dflash.dflash_self_logit_distillation` | true | Use target model logits as soft labels (vs hard CE) | | `dflash.dflash_architecture_config.num_hidden_layers` | 5 | Draft decoder layers | | `dflash.dflash_architecture_config.mask_token_id` | auto | Token ID for masked positions | @@ -239,6 +241,34 @@ Note: this is different from EAGLE3's `eagle_loss_decay_factor` which multiplies `alpha^step` across TTT steps. DFlash decay operates within a single block, weighting early positions higher because they gate acceptance of all later positions. +### D-PACE (Dynamic Position-Aware Cross-Entropy) + +Set `dflash.dflash_loss_objective: dpace` to replace the static decay with **D-PACE** +([arXiv:2605.18810](https://arxiv.org/abs/2605.18810)), which derives per-position weights +from a differentiable surrogate of expected accepted block length. Where static decay uses +a fixed schedule, D-PACE adapts to the draft's own per-position confidence and shifts +training signal toward whichever positions currently limit acceptance as the drafter improves. + +For each block, let `q_i = exp(-CE_i)` be the draft confidence on the target token at +predicted position `i`. 
D-PACE smooths it (Eq.7) and weights each position by the suffix-sum +of prefix products (Eq.8): + +```text +q~_i = (1 - alpha) * q_i + alpha +w_j = sum_{m >= j} prod_{i <= m} q~_i # detached; multiplies the per-token CE +``` + +The weight factors into the prefix-acceptance probability (`prod_{i<=j} q~_i`) times the +remaining accepted-length value, so it directly targets expected accepted length. The +weights are detached from the gradient — D-PACE only reshapes credit assignment and adds +~2.3% training overhead with no change to the draft architecture or inference. + +- `dflash_dpace_alpha` is the asymmetric smoothing floor (`q~_i >= alpha`) that keeps later + weights from vanishing. Stable in `[0.3, 0.7]`; `alpha=0` is rejected (cumulative product + collapses), and `alpha = 1` removes the adaptive signal (every `q~_i = 1`, so the weights become the static linear ramp `w_j = L - j + 1` over `L` predicted positions). Default `0.5`. +- When the objective is `dpace`, `dflash_loss_decay_factor` is not used: a non-zero decay + factor is ignored and a warning is logged. + ### Checkpoint Resume DFlash supports checkpoint resume transparently. Rotary embeddings are lazily diff --git a/modelopt/torch/speculative/config.py b/modelopt/torch/speculative/config.py index 7649b2d035..56eab51fc7 100644 --- a/modelopt/torch/speculative/config.py +++ b/modelopt/torch/speculative/config.py @@ -103,7 +103,23 @@ class DFlashConfig(ModeloptBaseConfig): dflash_loss_decay_factor: float = ModeloptField( default=0.0, description="Gamma for exponential loss decay weighting (paper Eq.4). " - "Suggested: 7 for block_size=16, 5 for 10, 4 for 8. 0 disables.", + "Suggested: 7 for block_size=16, 5 for 10, 4 for 8. 0 disables. " + "Only used when dflash_loss_objective='decay'.", + ) + + dflash_loss_objective: str = ModeloptField( + default="decay", + description="Block-position loss weighting objective. 'decay' uses the static " + "exponential decay of dflash_loss_decay_factor (DFlash, arXiv:2602.06036 Eq.4). 
" + "'dpace' uses dynamic, confidence-derived per-position weights " + "(D-PACE, arXiv:2605.18810 Eq.8).", + ) + + dflash_dpace_alpha: float = ModeloptField( + default=0.5, + description="D-PACE asymmetric smoothing factor alpha in (0, 1] (paper Eq.7). Used only " + "when dflash_loss_objective='dpace'. Stable in [0.3, 0.7]; alpha=0 is degenerate " + "(cumulative product vanishes) and alpha->1 removes the adaptive signal.", ) dflash_num_anchors: int = ModeloptField( diff --git a/modelopt/torch/speculative/dflash/dflash_model.py b/modelopt/torch/speculative/dflash/dflash_model.py index a99e93c816..f9a55f437e 100644 --- a/modelopt/torch/speculative/dflash/dflash_model.py +++ b/modelopt/torch/speculative/dflash/dflash_model.py @@ -15,8 +15,12 @@ """DFlash model to support block-wise parallel speculative decoding.""" +import logging + from modelopt.torch.opt.dynamic import DynamicModule +logger = logging.getLogger(__name__) + class DFlashModel(DynamicModule): """Base DFlash Model.""" @@ -31,6 +35,24 @@ def modify(self, config): self.dflash_block_size = config.dflash_block_size self.dflash_freeze_base_model = config.dflash_freeze_base_model self.dflash_loss_decay_factor = config.dflash_loss_decay_factor + self.dflash_loss_objective = config.dflash_loss_objective + self.dflash_dpace_alpha = config.dflash_dpace_alpha + if self.dflash_loss_objective not in ("decay", "dpace"): + raise ValueError( + f"dflash_loss_objective must be 'decay' or 'dpace', got " + f"{self.dflash_loss_objective!r}" + ) + if self.dflash_loss_objective == "dpace" and not 0.0 < self.dflash_dpace_alpha <= 1.0: + raise ValueError( + f"dflash_dpace_alpha must be in (0, 1] for the D-PACE objective, got " + f"{self.dflash_dpace_alpha}" + ) + if self.dflash_loss_objective == "dpace" and self.dflash_loss_decay_factor > 0: + logger.warning( + "dflash_loss_decay_factor=%s is ignored when dflash_loss_objective='dpace'; " + "D-PACE derives per-position weights dynamically from draft confidence.", + 
def _dpace_position_weights(confidences: torch.Tensor, alpha: float) -> torch.Tensor:
    """Compute detached D-PACE per-position weights from draft confidences.

    Implements D-PACE (arXiv:2605.18810) Eq.7-8. Each draft confidence ``q_i`` is
    smoothed toward 1 with ``q~_i = (1 - alpha) * q_i + alpha`` (Eq.7); the weight of
    position ``j`` is then the suffix-sum of the prefix products of the smoothed
    confidences, ``w_j = sum_{m >= j} prod_{i <= m} q~_i`` (Eq.8).

    Args:
        confidences: ``[..., L]`` draft confidence ``q_i = exp(-CE)`` per position;
            the last dimension is the block-position axis.
        alpha: Smoothing factor in ``(0, 1]``.

    Returns:
        Weights with the same shape and dtype as ``confidences``. They are computed
        under ``torch.no_grad()`` and therefore never carry gradient.

    Raises:
        ValueError: If ``alpha`` is outside ``(0, 1]``.
    """
    if not 0.0 < alpha <= 1.0:
        raise ValueError(f"dflash_dpace_alpha must be in (0, 1], got {alpha}")

    # Accumulate in at least float32 so half-precision inputs keep their accuracy
    # through the running product and sum, but leave float64 inputs in float64
    # instead of silently rounding them through float32.
    compute_dtype = torch.promote_types(confidences.dtype, torch.float32)

    with torch.no_grad():
        smoothed = (1.0 - alpha) * confidences.to(compute_dtype) + alpha
        prefix_products = torch.cumprod(smoothed, dim=-1)
        # Suffix sum over positions: reverse -> cumsum -> reverse.
        weights = torch.flip(
            torch.cumsum(torch.flip(prefix_products, dims=[-1]), dim=-1), dims=[-1]
        )
    return weights.to(dtype=confidences.dtype)
class TestDPaceWeights:
    """Test the D-PACE position-weighting objective (arXiv:2605.18810)."""

    @staticmethod
    def _reference_weights(row, alpha):
        """Evaluate Eq.7-8 with plain Python loops for one block of confidences.

        Deliberately avoids the cumprod/flip/cumsum pipeline of the implementation
        so that a mistake in the vectorized code cannot cancel out in the test.
        """
        prefix_products = []
        running = 1.0
        for q in row:
            running *= (1.0 - alpha) * q + alpha
            prefix_products.append(running)
        return [sum(prefix_products[j:]) for j in range(len(prefix_products))]

    def test_weights_match_paper_formula(self):
        """w_j = sum_{m>=j} prod_{i<=m} q~_i with q~_i = (1-a)q_i + a (Eq.7-8)."""
        alpha = 0.5
        conf = torch.tensor([[0.9, 0.6, 0.3, 0.8]])
        weights = _dpace_position_weights(conf, alpha)

        # Hand-computed: q~ = [0.95, 0.8, 0.65, 0.9],
        # prefix products = [0.95, 0.76, 0.494, 0.4446], then suffix sums.
        expected = torch.tensor([[2.6486, 1.6986, 0.9386, 0.4446]])
        assert torch.allclose(weights, expected, atol=1e-6)

    def test_weights_match_loop_reference_batched(self):
        """Every block of a batched input is weighted independently along the last dim."""
        alpha = 0.3
        conf = torch.rand(2, 3, 5)
        weights = _dpace_position_weights(conf, alpha)

        expected = torch.tensor(
            [[self._reference_weights(row, alpha) for row in sample] for sample in conf.tolist()]
        )
        assert torch.allclose(weights, expected, atol=1e-5)

    def test_weights_preserve_shape_and_dtype(self):
        """The result has the same shape and dtype as the input confidences."""
        for dtype in (torch.float16, torch.float32, torch.float64):
            conf = torch.full((2, 4), 0.5, dtype=dtype)
            weights = _dpace_position_weights(conf, 0.5)
            assert weights.shape == conf.shape
            assert weights.dtype == dtype

    def test_weights_are_detached(self):
        """Weights must carry no gradient (paper Eq.9 detaches them)."""
        conf = torch.rand(2, 3, 5, requires_grad=True)
        weights = _dpace_position_weights(conf, 0.5)
        assert not weights.requires_grad

    def test_weights_monotonic_nonincreasing(self):
        """Suffix-sum of positive prefix products is non-increasing along the block."""
        conf = torch.rand(4, 8).clamp(0.05, 0.99)
        weights = _dpace_position_weights(conf, 0.5)
        assert torch.all(weights[:, :-1] >= weights[:, 1:] - 1e-6)

    def test_smoothing_keeps_later_weights_nonzero(self):
        """With alpha>0, q~_i >= alpha so cumulative products cannot vanish."""
        conf = torch.zeros(1, 6)  # worst case: zero confidence everywhere
        weights = _dpace_position_weights(conf, alpha=0.5)
        assert torch.all(weights > 0)

    def test_alpha_one_gives_static_ramp(self):
        """alpha=1 forces q~_i = 1, so the weights ignore the confidences entirely."""
        weights = _dpace_position_weights(torch.rand(1, 4), alpha=1.0)
        assert torch.equal(weights, torch.tensor([[4.0, 3.0, 2.0, 1.0]]))

    def test_invalid_alpha_raises(self):
        """Both ends of the (0, 1] range are enforced, including the degenerate alpha=0."""
        for bad_alpha in (0.0, -0.1, 1.5):
            with pytest.raises(ValueError, match="dflash_dpace_alpha"):
                _dpace_position_weights(torch.rand(1, 4), alpha=bad_alpha)

    def test_convert_with_dpace_objective(self):
        """Convert with the dpace objective wires the attributes onto the model."""
        model = get_tiny_llama(num_hidden_layers=4)
        config = _get_dflash_config()
        config["dflash_loss_objective"] = "dpace"
        config["dflash_dpace_alpha"] = 0.3
        mtsp.convert(model, [("dflash", config)])
        assert model.dflash_loss_objective == "dpace"
        assert model.dflash_dpace_alpha == 0.3

    def test_convert_rejects_bad_objective(self):
        """An unknown objective name is rejected at convert time."""
        model = get_tiny_llama(num_hidden_layers=4)
        config = _get_dflash_config()
        config["dflash_loss_objective"] = "nope"
        with pytest.raises(ValueError, match="dflash_loss_objective"):
            mtsp.convert(model, [("dflash", config)])

    def test_convert_rejects_degenerate_alpha(self):
        """alpha=0 with the dpace objective is rejected at convert time."""
        model = get_tiny_llama(num_hidden_layers=4)
        config = _get_dflash_config()
        config["dflash_loss_objective"] = "dpace"
        config["dflash_dpace_alpha"] = 0.0
        with pytest.raises(ValueError, match="dflash_dpace_alpha"):
            mtsp.convert(model, [("dflash", config)])