Skip to content

[Feat]:Support DPace#1724

Open
h-guo18 wants to merge 1 commit into
mainfrom
haoguo/dpace
Open

[Feat]:Support DPace#1724
h-guo18 wants to merge 1 commit into
mainfrom
haoguo/dpace

Conversation

@h-guo18

@h-guo18 h-guo18 commented Jun 15, 2026

Copy link
Copy Markdown
Contributor

What does this PR do?

Type of change: New feature

Adds the D-PACE (Dynamic Position-Aware Cross-Entropy) loss objective for DFlash speculative-decoding training (arXiv:2605.18810). It replaces the static exponential position decay with per-position CE weights derived from the draft's own confidence q_i = exp(-CE_i): smoothed q̃_i = (1-α)q_i + α (Eq.7) and weighted by the suffix-sum of prefix products w_j = Σ_{m≥j} ∏_{i≤m} q̃_i (Eq.8), which directly targets expected accepted block length and shifts signal toward whichever positions currently limit acceptance.

Selected via dflash_loss_objective: dpace (default decay keeps current behavior); smoothing via dflash_dpace_alpha (default 0.5). Weights are detached from the gradient — training-only, ~2.3% overhead, no architecture or inference change. Mutually exclusive with dflash_loss_decay_factor.

Usage

# DFlash recipe / training config
dflash:
  dflash_loss_objective: dpace   # default: decay
  dflash_dpace_alpha: 0.5        # smoothing in (0, 1]; stable in [0.3, 0.7]

Testing

CPU unit tests in tests/unit/torch/speculative/plugins/test_hf_dflash.py: weights match the paper closed form, are detached and non-increasing, the α smoothing floor keeps later weights non-zero, and convert wires/validates the new fields (rejects bad objective and degenerate α). Training validated on Qwen3-8B (curve below).

image

Before your PR is "Ready for review"

  • Is this change backward compatible?: ✅ (opt-in; default dflash_loss_objective=decay is unchanged)
  • If you copied code from any other sources or added a new PIP dependency, did you follow guidance in CONTRIBUTING.md: N/A (no new dependency)
  • Did you write any new necessary tests?: ✅
  • Did you update Changelog?: ✅
  • Did you get Claude approval on this PR?: ❌

Additional Information

Reference: D-PACE, arXiv:2605.18810. See examples/speculative_decoding/doc/dflash.md for the math and tuning notes.

Summary by CodeRabbit

Release Notes

  • New Features

    • Added D-PACE (Dynamic Position-Aware Cross-Entropy) loss objective for DFlash speculative-decoding training with configurable smoothing control.
  • Documentation

    • Updated DFlash training configuration documentation with new loss-objective parameters and D-PACE behavior details.
  • Tests

    • Added comprehensive tests for D-PACE weight computation and configuration validation.

Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
@copy-pr-bot

copy-pr-bot Bot commented Jun 15, 2026

Copy link
Copy Markdown

Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.

Contributors can view more details about this message here.

@coderabbitai

coderabbitai Bot commented Jun 15, 2026

Copy link
Copy Markdown
Contributor

Review Change Stack

No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: 770841e1-ba88-441b-a636-46bddd103819

📥 Commits

Reviewing files that changed from the base of the PR and between 9f6e8fd and 1653417.

📒 Files selected for processing (6)
  • CHANGELOG.rst
  • examples/speculative_decoding/doc/dflash.md
  • modelopt/torch/speculative/config.py
  • modelopt/torch/speculative/dflash/dflash_model.py
  • modelopt/torch/speculative/plugins/hf_dflash.py
  • tests/unit/torch/speculative/plugins/test_hf_dflash.py

📝 Walkthrough

Walkthrough

Adds D-PACE (Dynamic Position-Aware Cross-Entropy) as a new DFlash training loss objective. Two new DFlashConfig fields (dflash_loss_objective, dflash_dpace_alpha) gate the feature. A new helper _dpace_position_weights computes detached per-position weights; _compute_loss applies them when "dpace" is selected. Input validation is added in DFlashModel.modify(). Tests, documentation, and a changelog entry are included.

Changes

D-PACE Loss Objective for DFlash

Layer / File(s) Summary
DFlashConfig new fields
modelopt/torch/speculative/config.py
Adds dflash_loss_objective (str, default "decay") and dflash_dpace_alpha (float, default 0.5) to DFlashConfig; updates dflash_loss_decay_factor docstring to note it is only used with the "decay" objective.
D-PACE weight helper and _compute_loss integration
modelopt/torch/speculative/plugins/hf_dflash.py, modelopt/torch/speculative/dflash/dflash_model.py
_dpace_position_weights computes detached per-position weights from draft confidences using smoothed cumprod and reverse cumsum. DFlashModel.modify() validates both new config fields. _compute_loss adds a "dpace" branch that derives confidences via exp(-CE), calls the helper under no_grad, and multiplies the resulting weights into weight_mask.
Tests, documentation, and changelog
tests/unit/torch/speculative/plugins/test_hf_dflash.py, examples/speculative_decoding/doc/dflash.md, CHANGELOG.rst
TestDPaceWeights tests formula correctness, gradient detachment, monotonicity, smoothing, error handling, and mtsp.convert wiring. Documentation adds the D-PACE subsection with formulas and parameter descriptions. Changelog records the new feature.

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

🚥 Pre-merge checks | ✅ 5 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 75.00% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (5 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The title '[Feat]:Support DPace' directly references the main feature being introduced (D-PACE), which is the core change across all modified files. The title accurately summarizes the primary objective of enabling D-PACE as a new loss objective for DFlash speculative-decoding training.
Linked Issues check ✅ Passed Check skipped because no linked issues were found for this pull request.
Out of Scope Changes check ✅ Passed Check skipped because no linked issues were found for this pull request.
Security Anti-Patterns ✅ Passed PR reviewed against SECURITY.md guidelines. No torch.load(weights_only=False), numpy.load(allow_pickle=True), hardcoded trust_remote_code, eval/exec on external input, # nosec comments, or new non-...

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
📝 Generate docstrings
  • Create stacked PR
  • Commit on current branch
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Commit unit tests in branch haoguo/dpace

Comment @coderabbitai help to get the list of available commands and usage tips.

@github-actions

Copy link
Copy Markdown
Contributor
PR Preview Action v1.8.1

QR code for preview link

🚀 View preview at
https://NVIDIA.github.io/Model-Optimizer/pr-preview/pr-1724/

Built to branch gh-pages at 2026-06-15 07:04 UTC.
Preview will be ready when the GitHub Pages deployment is complete.

@codecov

codecov Bot commented Jun 15, 2026

Copy link
Copy Markdown

Codecov Report

❌ Patch coverage is 71.42857% with 6 lines in your changes missing coverage. Please review.
✅ Project coverage is 76.55%. Comparing base (9f6e8fd) to head (1653417).
⚠️ Report is 8 commits behind head on main.

Files with missing lines Patch % Lines
modelopt/torch/speculative/plugins/hf_dflash.py 60.00% 6 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1724      +/-   ##
==========================================
- Coverage   77.12%   76.55%   -0.58%     
==========================================
  Files         511      511              
  Lines       56247    56267      +20     
==========================================
- Hits        43381    43073     -308     
- Misses      12866    13194     +328     
Flag Coverage Δ
examples 41.83% <14.28%> (-0.12%) ⬇️
gpu 57.77% <42.85%> (-0.60%) ⬇️
regression 14.70% <42.85%> (+0.07%) ⬆️
unit 54.39% <61.90%> (-0.01%) ⬇️

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Harness.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

@h-guo18 h-guo18 marked this pull request as ready for review June 15, 2026 18:59
@h-guo18 h-guo18 requested a review from a team as a code owner June 15, 2026 18:59
@h-guo18 h-guo18 requested a review from yeyu-nvidia June 15, 2026 18:59
@h-guo18

h-guo18 commented Jun 15, 2026

Copy link
Copy Markdown
Contributor Author

/claude review

@ChenhanYu

Copy link
Copy Markdown
Collaborator

/claude review

1 similar comment
@h-guo18

h-guo18 commented Jun 16, 2026

Copy link
Copy Markdown
Contributor Author

/claude review

Comment on lines +388 to 399
with torch.no_grad():
conf_ce = F.cross_entropy(
logits.view(-1, logits.size(-1)), target_ids.view(-1), reduction="none"
).view(bsz, n_blocks, block_size)
confidences = torch.exp(-conf_ce[..., 1:].float())
dpace = torch.ones_like(weight_mask)
dpace[..., 1:] = _dpace_position_weights(confidences, self.dflash_dpace_alpha)
weight_mask = weight_mask * dpace
elif self.dflash_loss_decay_factor > 0:
k = torch.arange(block_size, device=device).view(1, 1, -1)
decay = torch.exp(-(k - 1).clamp(min=0).float() / self.dflash_loss_decay_factor)
weight_mask = weight_mask * decay

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[SUGGESTION] When base_logits is None (the non-KD path), the per-token cross-entropy is computed twice — once here under no_grad to derive confidences, and again at line 421 to compute the actual loss. Since the second computation is exactly the per-token CE you already have, you could reuse it (compute once with grad enabled, take .detach().exp() for the confidences). The PR description already acknowledges the ~2.3% overhead — eliminating this duplication would close most of that gap. The KD path correctly remains separate because its actual loss is KL, not CE.

Why it matters: small but free win on training throughput; CE is one of the more expensive ops in the inner training loop because of the vocab-size matmul.

How to apply: hoist a single loss_per_token = F.cross_entropy(...) computation, derive confidences = torch.exp(-loss_per_token.detach()).view(bsz, n_blocks, block_size)[..., 1:].float(), then later use the same loss_per_token in the loss reduction. Keep the no-grad CE only for the KD branch.

Returns:
Detached weights with the same shape and dtype as ``confidences``.
"""
if not 0.0 <= alpha <= 1.0:

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[SUGGESTION] Docstring says alpha is in (0, 1] but the validation accepts [0, 1] (closed at 0). The user-facing path through DFlashModel.modify() correctly rejects alpha=0, but a direct caller of _dpace_position_weights with alpha=0 would silently get all-zero weights (the cumulative product collapses on the first position) instead of an error. Tighten the check to 0.0 < alpha <= 1.0 to match the docstring, or relax the docstring to [0, 1].

raise ValueError(
f"dflash_dpace_alpha must be in (0, 1] for the D-PACE objective, got "
f"{self.dflash_dpace_alpha}"
)

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[SUGGESTION] Consider warning (or rejecting) when dflash_loss_objective == "dpace" and dflash_loss_decay_factor != 0.0 (i.e. the user has explicitly set both). The default recipe modelopt_recipes/general/speculative_decoding/dflash.yaml already sets dflash_loss_decay_factor: 4.0, so a user who only flips dflash_loss_objective: dpace won't realize their non-default decay value is silently ignored (the doc notes the mutual exclusion, but the runtime is silent). A logger.warning(...) here would surface the misconfiguration without blocking the run.

@claude claude Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Claude review passed — no blocking issues found.

Summary (CRITICAL: 0, IMPORTANT: 0, SUGGESTION: 3)

The D-PACE implementation is correct and well-scoped:

  • Algorithm matches paper Eq.7-8: smoothing q~_i = (1-α)q_i + α, prefix-product, suffix-sum (reverse-cumsum-reverse), all under no_grad and explicitly detached.
  • Opt-in (default dflash_loss_objective='decay' preserves prior behavior) — no backward-compat concern.
  • Validation lives in DFlashModel.modify() so bad configs fail at convert-time, not deep in the training loop.
  • Tests cover formula correctness, detachment, monotonicity, smoothing floor, error paths, and mtsp.convert wiring.
  • Position-0 (anchor) is correctly excluded from D-PACE weights via [..., 1:] slicing.
  • No mode-state schema or export-path changes — purely a training-loss feature.

Inline suggestions (non-blocking):

  1. CE on the predicted positions is computed twice in the non-KD path (once for confidences, once for the loss). Reusing one computation would close most of the documented ~2.3% overhead.
  2. _dpace_position_weights accepts alpha=0 (silently zero weights) while its docstring claims (0, 1] — tighten the runtime check to match.
  3. When dflash_loss_objective='dpace' and dflash_loss_decay_factor is non-default, the decay value is silently ignored. A logger.warning would surface the misconfiguration since the default recipe already sets decay_factor: 4.0.

LGTM.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants