-
Notifications
You must be signed in to change notification settings - Fork 443
Add support for dLLM encoder-decoder models (DiffusionGemma) [tied-weight PTQ export support ] #1707
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Add support for dLLM encoder-decoder models (DiffusionGemma) [tied-weight PTQ export support ] #1707
Changes from 8 commits
47d4ab6
60d4ebb
225072b
a0d9b65
e351c0f
d684477
d0a735e
0543907
9f5e0e1
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -42,13 +42,24 @@ def _export_fused_experts(module: nn.Module, dtype: torch.dtype) -> None: | |
| {E}.gate_proj.weight, {E}.gate_proj.weight_scale, ... | ||
| {E}.up_proj.weight, {E}.up_proj.weight_scale, ... | ||
| {E}.down_proj.weight, {E}.down_proj.weight_scale, ... | ||
|
|
||
| Tied-experts dedup: when multiple fused-expert modules share their 3-D | ||
| source params via HF ``_tied_weights_keys``, the unpacking creates fresh | ||
| per-expert tensors that break the tie. We cache the source ``data_ptr()`` | ||
| at entry and on a later cache hit alias the per-expert ``weight`` / | ||
| ``weight_scale`` / ``weight_scale_2`` back to the prior module so | ||
| downstream dedup catches them. ``input_scale`` is left per-side. | ||
| """ | ||
| from modelopt.torch.export.unified_export_hf import _export_quantized_weight | ||
| from modelopt.torch.quantization.plugins.huggingface import _get_fused_expert_intermediate_dim | ||
|
|
||
| n = module.num_experts | ||
| expert_dim = _get_fused_expert_intermediate_dim(module) | ||
|
|
||
| # Capture source tensor identities BEFORE unpacking (the source | ||
| # attrs are deleted at the end of this function). | ||
| _source_key = (module.gate_up_proj.data_ptr(), module.down_proj.data_ptr()) | ||
|
|
||
| # 1. Shared input quantizers — one per projection type, shared across all experts. | ||
| gate_up_input_q = module.gate_up_proj_input_quantizer | ||
| down_input_q = module.down_proj_input_quantizer | ||
|
|
@@ -178,6 +189,46 @@ def _export_fused_experts(module: nn.Module, dtype: torch.dtype) -> None: | |
| if hasattr(module, attr): | ||
| delattr(module, attr) | ||
|
|
||
| # 5. Tied-experts dedup: if this module's source params have been seen | ||
| # before, alias the bit-identical per-expert buffers (weight, | ||
| # weight_scale, weight_scale_2, input_scale) to the previously-unpacked | ||
| # module. input_scale is safe to alias because sync_tied_input_amax | ||
| # runs earlier in _export_transformers_checkpoint and max-merges the | ||
| # shared input_quantizer amaxes across tied fused-experts modules, so | ||
| # both sides now derive bit-identical input_scale values. | ||
| _cache = _export_fused_experts.__dict__.setdefault("_tied_unpacked_cache", {}) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🔴 High (companion to the
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, addressed in 9f5e0e1 commit |
||
| _prior = _cache.get(_source_key) | ||
| if _prior is not None and _prior is not module: | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🟡 Medium — second tied module is fully unpacked, then discarded. On a cache hit the later tied module has already run the entire unpack/pack path above; this block then aliases all of it back to
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Makes sense, updated it, could you please review in 9f5e0e1 commit |
||
| for _idx in range(n): | ||
| _cur_expert = getattr(module, str(_idx), None) | ||
| _prior_expert = getattr(_prior, str(_idx), None) | ||
| if _cur_expert is None or _prior_expert is None: | ||
| continue | ||
| for _proj_name in ("gate_proj", "up_proj", "down_proj"): | ||
| _cur_proj = getattr(_cur_expert, _proj_name, None) | ||
| _prior_proj = getattr(_prior_expert, _proj_name, None) | ||
| if _cur_proj is None or _prior_proj is None: | ||
| continue | ||
| # Alias the weight (Parameter) so both sides reference the | ||
| # same nn.Parameter → same data_ptr() → existing dedup | ||
| # in postprocess_state_dict will drop the duplicate. | ||
| if hasattr(_prior_proj, "weight"): | ||
| _cur_proj.weight = _prior_proj.weight | ||
| # Alias the bit-identical scale buffers (including | ||
| # input_scale, made safe by sync_tied_input_amax pre-export | ||
| # merging). Re-register to ensure data_ptr() matches the | ||
| # prior side's tensor. | ||
| for _attr in ("weight_scale", "weight_scale_2", "input_scale"): | ||
| if not hasattr(_prior_proj, _attr): | ||
| continue | ||
| if _attr in _cur_proj._buffers: | ||
| del _cur_proj._buffers[_attr] | ||
| elif hasattr(_cur_proj, _attr): | ||
| delattr(_cur_proj, _attr) | ||
| _cur_proj.register_buffer(_attr, getattr(_prior_proj, _attr)) | ||
| else: | ||
| _cache[_source_key] = module | ||
|
|
||
|
|
||
| def save_expert_token_count_table(model: nn.Module, output_dir: str | Path | None = None): | ||
| """Collect expert_token_count from all quantized MoE layers and save as an HTML table. | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.