diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 721619dd2e..efda549a9f 100755 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -9,6 +9,7 @@ Changelog - Add the ``day0-release`` agent skill (``.agents/skills/day0-release/``), a deterministic end-to-end driver that chains the PTQ → evaluation → comparison skills (the evaluation stage deploys the checkpoint itself) with an enforced gate after each stage and returns a publish decision (ACCEPT / REGRESSION / ANOMALOUS / INFEASIBLE). Ships three GPU-free, unit-tested gate scripts (``gate_ptq.py``, ``gate_run.py``, ``gate_compare.py``) that validate checkpoint coverage, evaluation-run completeness, and baseline-vs-candidate accuracy threshold. v1 reports and stops on regression; the recipe-search loop is deferred. - Add **streaming** speculative-decoding training (EAGLE3 / DFlash): the draft trains on base-model hidden states produced on the fly by a co-located ``vllm serve`` (no disk dump), moved trainer-side over NIXL RDMA, scaling to multi-node (dedicated serve replicas + DDP trainers). New launcher examples for NVFP4 Kimi-K2.5 / K2.6 on GB200/aarch64 under ``tools/launcher/examples/moonshotai/``. - Add a fused Triton fast path for ``local_hessian`` NVFP4 weight-scale search (the Hessian-weighted FP8-E4M3 scale sweep). For each NVFP4 block it minimizes ``dwᵀ H dw`` over the 126 candidate scales using the per-cin-block local Hessian on tensor cores, replacing the per-weight Python reference sweep — roughly **34x** faster on a single 8192x4096 weight and bit-exact with the reference for fp32/fp16 weights. Used automatically during ``local_hessian`` calibration for both dense and fused-MoE expert weights; falls back to the reference sweep on CPU, when Triton is unavailable, or via ``MODELOPT_NVFP4_TRITON_SWEEP=0``. +- Add tied-weight PTQ and HF-checkpoint export support for block-diffusion encoder-decoder LLMs (e.g. DiffusionGemma) whose encoder/decoder stacks share parameters via HF ``_tied_weights_keys``. ``_export_quantized_weight`` and ``_export_fused_experts`` now alias bit-identical packed ``weight`` / ``weight_scale`` / ``weight_scale_2`` buffers across modules sharing a source weight ``data_ptr()`` so the downstream ``postprocess_state_dict`` dedup catches them (~42% storage reduction on ``nvfp4_experts_only`` for tied 26B MoE checkpoints). New ``sync_tied_input_amax`` helper max-merges per-side ``input_quantizer.amax`` across tied modules before export so single-backbone consumers that load one ``input_scale`` per parameter don't clip either side. Opt-in ``--canonical_tied_naming`` flag (default off) reorders the state_dict so canonical-side keys per HF's ``_tied_weights_keys`` declaration win the data_ptr dedup. New DiffusionGemma model-specific recipe under ``modelopt_recipes/huggingface/diffusiongemma/ptq/`` (``nvfp4_experts_only.yaml`` + its ``disabled_quantizers.yaml`` unit) adds the ``*self_conditioning*`` exclude on top of the standard default, leaving the shared ``default_disabled_quantizers`` unit clean for non-diffusion models — pattern matches the existing ``phi4mm`` / ``nemotron_vl`` model-specific recipes. ``hf_ptq.py`` also unwraps ``ModelOutput`` dataclasses from ``.generate()`` so the preview decode works on diffusion models. Non-tied models see no behavioral change. 0.45 (2026-07-02) ^^^^^^^^^^^^^^^^^ diff --git a/examples/llm_ptq/example_utils.py b/examples/llm_ptq/example_utils.py index d36754a8d4..1d3c196d42 100755 --- a/examples/llm_ptq/example_utils.py +++ b/examples/llm_ptq/example_utils.py @@ -806,7 +806,13 @@ def is_model_on_gpu(model) -> bool: def is_enc_dec(model_type) -> bool: - """Return if the model is a encoder-decoder model.""" + """Return whether the model_type uses encoder-decoder-style preview decode. + + Controls whether ``hf_ptq.py`` slices off the prompt prefix from + ``.generate()`` output. ``diffusion_gemma`` is structurally encoder-decoder + but returns prompt+canvas concatenated, so it stays OFF this list (AR-style + decode applies). + """ return model_type in ["t5", "bart", "whisper"] diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py index afb725988c..14215853ea 100755 --- a/examples/llm_ptq/hf_ptq.py +++ b/examples/llm_ptq/hf_ptq.py @@ -774,6 +774,7 @@ def export_quantized( full_model, export_dir=export_path, extra_state_dict=mtp_state_dict, + canonical_tied_naming=args.canonical_tied_naming, ) if args.qformat == "w4a16_nvfp4": @@ -941,6 +942,11 @@ def input_decode(input_ids): raise ValueError("The processor or tokenizer must be set") def output_decode(generated_ids, input_shape): + # Some `.generate()` returns a ModelOutput dataclass (e.g. DiffusionGemma); + # unwrap to the token tensor so downstream slicing works uniformly. + if hasattr(generated_ids, "sequences"): + generated_ids = generated_ids.sequences + if is_enc_dec(model_type): if processor is not None and isinstance(processor, WhisperProcessor): return processor.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0] @@ -1252,6 +1258,19 @@ def parse_args() -> argparse.Namespace: default=512, ) parser.add_argument("--export_path", default="exported_model") + parser.add_argument( + "--canonical_tied_naming", + type=lambda s: s.lower() in ("1", "true", "yes"), + default=False, + help=( + "If True, reorder the exported state_dict so tied-weight aliases " + "dedup to the canonical side declared in the model's HF " + "_tied_weights_keys (e.g. decoder-side for DiffusionGemma4). Off " + "by default to avoid renaming exported keys for models whose " + "downstream consumers expect the legacy (registration-order) " + "winner." + ), + ) parser.add_argument( "--dataset", help=( diff --git a/modelopt/torch/export/model_utils.py b/modelopt/torch/export/model_utils.py index 3bd72d9de9..9c49cae0cf 100755 --- a/modelopt/torch/export/model_utils.py +++ b/modelopt/torch/export/model_utils.py @@ -33,6 +33,9 @@ "Qwen3Next": "qwen3next", "QWen": "qwen", "RecurrentGemma": "recurrentgemma", + # DiffusionGemma must come before "Gemma" — get_model_type substring-matches + # in order, and "gemma" is a substring of "diffusiongemma". + "DiffusionGemma": "diffusion_gemma", "Gemma3": "gemma3", "Gemma2": "gemma2", "Gemma": "gemma", diff --git a/modelopt/torch/export/moe_utils.py b/modelopt/torch/export/moe_utils.py index e325e5346f..73f7391c32 100644 --- a/modelopt/torch/export/moe_utils.py +++ b/modelopt/torch/export/moe_utils.py @@ -23,7 +23,12 @@ import torch.nn as nn -def _export_fused_experts(module: nn.Module, dtype: torch.dtype) -> None: +def _export_fused_experts( + module: nn.Module, + dtype: torch.dtype, + _moe_tied_cache: dict[tuple[int, int], nn.Module] | None = None, + _tied_cache: dict[int, nn.Module] | None = None, +) -> None: """Split fused MoE expert weights and export per-expert quantization scales. Works with any module wrapped by ``_QuantFusedExperts`` — i.e. any HF @@ -42,6 +47,20 @@ def _export_fused_experts(module: nn.Module, dtype: torch.dtype) -> None: {E}.gate_proj.weight, {E}.gate_proj.weight_scale, ... {E}.up_proj.weight, {E}.up_proj.weight_scale, ... {E}.down_proj.weight, {E}.down_proj.weight_scale, ... + + Tied-experts dedup is opt-in via ``_moe_tied_cache``: when multiple + fused-expert modules share their 3-D source params via HF + ``_tied_weights_keys``, the unpacking creates fresh per-expert tensors + that break the tie. With ``_moe_tied_cache`` provided (tuple-keyed by + ``(gate_up_proj.data_ptr(), down_proj.data_ptr())``), the alias step + at the end re-points the per-expert ``weight`` / ``weight_scale`` / + ``weight_scale_2`` / ``input_scale`` buffers at a previously-processed + module sharing the same source memory. ``_tied_cache`` (int-keyed) is + threaded through to the per-projection ``_export_quantized_weight`` + calls so wrapper-level dedup uses the same scope as standalone Linears. + Both caches are owned by the caller (typically + ``_export_transformers_checkpoint``) and scoped to one export + invocation; when ``None`` the corresponding alias step is skipped. """ from modelopt.torch.export.unified_export_hf import _export_quantized_weight from modelopt.torch.quantization.plugins.huggingface import _get_fused_expert_intermediate_dim @@ -49,6 +68,54 @@ def _export_fused_experts(module: nn.Module, dtype: torch.dtype) -> None: n = module.num_experts expert_dim = _get_fused_expert_intermediate_dim(module) + # Capture source tensor identities BEFORE unpacking (the source + # attrs are deleted at the end of this function). + _source_key = (module.gate_up_proj.data_ptr(), module.down_proj.data_ptr()) + + # Tied-experts fast path: if this exact (gate_up, down) source-tensor pair + # has been processed before, build the per-expert subtree by aliasing the + # prior module's already-packed buffers directly — no unpacking, no per- + # expert packing, no transient buffers thrown away. Functionally equivalent + # to the bottom alias step below; that step exists for the cache-miss path + # only (where we register and become the prior for any later tied module). + if _moe_tied_cache is not None: + _prior = _moe_tied_cache.get(_source_key) + if _prior is not None and _prior is not module: + for _idx in range(n): + _prior_expert = getattr(_prior, str(_idx), None) + if _prior_expert is None: + continue + _cur_expert = nn.Module() + for _proj_name in ("gate_proj", "up_proj", "down_proj"): + _prior_proj = getattr(_prior_expert, _proj_name, None) + if _prior_proj is None: + continue + _cur_proj = nn.Module() + if hasattr(_prior_proj, "weight"): + # Alias the Parameter — same data_ptr as the prior side's + # packed bytes; downstream postprocess_state_dict dedup + # collapses the duplicate at write time. + _cur_proj.weight = _prior_proj.weight + for _attr in ("weight_scale", "weight_scale_2", "input_scale"): + if hasattr(_prior_proj, _attr): + _cur_proj.register_buffer(_attr, getattr(_prior_proj, _attr)) + _cur_expert.add_module(_proj_name, _cur_proj) + module.add_module(str(_idx), _cur_expert) + # Source-tensor cleanup mirrors the normal path's end-of-function + # step so the source 3-D Parameters and quantizer ModuleLists don't + # land in the exported state_dict. + for attr in ( + "gate_up_proj", + "down_proj", + "gate_up_proj_weight_quantizers", + "gate_up_proj_input_quantizer", + "down_proj_weight_quantizers", + "down_proj_input_quantizer", + ): + if hasattr(module, attr): + delattr(module, attr) + return + # 1. Shared input quantizers — one per projection type, shared across all experts. gate_up_input_q = module.gate_up_proj_input_quantizer down_input_q = module.down_proj_input_quantizer @@ -154,7 +221,7 @@ def _export_fused_experts(module: nn.Module, dtype: torch.dtype) -> None: wrapper.weight_quantizer = w_quantizer wrapper.input_quantizer = i_quantizer - _export_quantized_weight(wrapper, dtype) + _export_quantized_weight(wrapper, dtype, _tied_cache=_tied_cache) proj = nn.Module() proj.weight = wrapper.weight @@ -178,6 +245,13 @@ def _export_fused_experts(module: nn.Module, dtype: torch.dtype) -> None: if hasattr(module, attr): delattr(module, attr) + # 5. Register this module in the dedup cache so any later tied module + # (same source data_ptr pair) takes the fast path at the top of this + # function. Reached only on cache miss; cache-hit modules early-exited + # above before any unpack work. + if _moe_tied_cache is not None: + _moe_tied_cache[_source_key] = module + def save_expert_token_count_table(model: nn.Module, output_dir: str | Path | None = None): """Collect expert_token_count from all quantized MoE layers and save as an HTML table. diff --git a/modelopt/torch/export/unified_export_hf.py b/modelopt/torch/export/unified_export_hf.py index 892b8c42ca..56cf7cd44e 100644 --- a/modelopt/torch/export/unified_export_hf.py +++ b/modelopt/torch/export/unified_export_hf.py @@ -514,12 +514,27 @@ def llm_dummy_forward(): def _export_quantized_weight( - sub_module: nn.Module, dtype: torch.dtype, weight_name: str = "weight" + sub_module: nn.Module, + dtype: torch.dtype, + weight_name: str = "weight", + _tied_cache: dict[int, nn.Module] | None = None, ): """For the given weight attr of the sub_module, export the quantization info of it. The export includes converting weight tensor to correct quantized values and quantized dtype, and registering scaling factors. + + Tied-weight dedup is opt-in via ``_tied_cache``: the setattr below replaces + ``.weight`` with a fresh ``nn.Parameter`` wrapping packed bytes, breaking + any HF-level tie. When the caller passes a ``_tied_cache`` dict (keyed by + the pre-pack ``weight.data_ptr()``), the alias step at the end re-points + ``weight`` / ``weight_scale`` / ``weight_scale_2`` at a previously-processed + module sharing the same source memory so the downstream data_ptr dedup can + collapse them. The cache is owned by the caller (typically + ``_export_transformers_checkpoint``) and scoped to one export invocation; + when ``_tied_cache`` is ``None`` (the default) the alias step is skipped + entirely. Uses memory identity only — no ``_tied_weights_keys`` lookup, + no-op for non-tied modules. """ quantization_format = get_quantization_format(sub_module) if quantization_format == QUANTIZATION_NONE: @@ -528,6 +543,13 @@ def _export_quantized_weight( block_size = get_weight_block_size(sub_module, weight_name) quantizer_attrs = quantizer_attr_names(weight_name) weight: nn.Parameter = getattr(sub_module, weight_name) + + # Capture source identity BEFORE any tensor-creating operation below. + # For HF-tied weights this matches across all modules sharing the + # underlying Parameter; the cache lookup at the end of this function + # uses it to detect ties whose Python identity is about to be broken + # by the setattr on `weight_name` further down. + _tied_source_data_ptr = weight.data_ptr() weight_quantizer: TensorQuantizer | SequentialQuantizer = getattr( sub_module, quantizer_attrs.weight_quantizer ) @@ -703,14 +725,197 @@ def _export_quantized_weight( if weight_scale is not None: sub_module.register_buffer(quantizer_attrs.weight_scale, weight_scale) + # Tied-weight dedup: if a previously-processed module shared the same + # source weight memory, alias the packed weight + scale buffers so the + # downstream data_ptr dedup in postprocess_state_dict can collapse them. + # input_scale is safe to alias because sync_tied_input_amax (earlier in + # this export) already max-merged the per-side amaxes. Gated on the + # caller-owned _tied_cache so the dedup state is scoped to one export. + if _tied_cache is not None: + _prior = _tied_cache.get(_tied_source_data_ptr) + if _prior is not None and _prior is not sub_module: + if hasattr(_prior, weight_name): + setattr(sub_module, weight_name, getattr(_prior, weight_name)) + for _attr in ( + quantizer_attrs.weight_scale, + quantizer_attrs.weight_scale_2, + quantizer_attrs.input_scale, + ): + if not hasattr(_prior, _attr): + continue + if _attr in sub_module._buffers: + del sub_module._buffers[_attr] + elif hasattr(sub_module, _attr): + delattr(sub_module, _attr) + sub_module.register_buffer(_attr, getattr(_prior, _attr)) + else: + _tied_cache[_tied_source_data_ptr] = sub_module + torch.cuda.empty_cache() +def _collect_canonical_tied_patterns( + model: nn.Module, +) -> tuple[list[re.Pattern], list[str]]: + """Walk the model and collect canonical-side tied-weight matchers. + + Patterns are submodule-prefixed regexes from each module's + ``_tied_weights_keys`` dict-style declaration (the prefix matters + for nested models where the dict lives on an inner submodule). + Side substrings are dot-separated tokens that appear only on the + canonical side of those declarations — needed because modelopt's + per-expert unpacking creates post-export keys (e.g. + ``…experts.Y.gate_proj.input_scale``) that HF's regexes never knew + about. List-style (legacy) declarations are skipped. + """ + patterns: list[re.Pattern] = [] + alias_token_set: set[str] = set() + canonical_token_set: set[str] = set() + + def _tokens(s: str) -> set[str]: + """Identifiers in a regex string, with regex specials as separators.""" + return {tok for tok in re.split(r"[^A-Za-z0-9_]+", s) if tok} + + for name, submodule in model.named_modules(): + tied = getattr(submodule, "_tied_weights_keys", None) + if not isinstance(tied, dict) or not tied: + continue + prefix = f"{name}." if name else "" + for alias_pat, canonical_pat in tied.items(): + patterns.append(re.compile(prefix + canonical_pat)) + alias_token_set.update(_tokens(prefix + alias_pat)) + canonical_token_set.update(_tokens(prefix + canonical_pat)) + + # Tokens unique to the canonical side become substring matchers. + side_substrings = sorted(canonical_token_set - alias_token_set) + return patterns, side_substrings + + +def _reorder_canonical_first(state_dict: dict, model: nn.Module) -> dict: + r"""Reorder ``state_dict`` so canonical-side tied keys iterate first. + + Lets the downstream first-wins data_ptr dedup keep canonical names. + Uses both regex patterns and substring matchers from + :func:`_collect_canonical_tied_patterns`. No-op when the model + declares no dict-style ``_tied_weights_keys``. + """ + canonical_patterns, side_substrings = _collect_canonical_tied_patterns(model) + if not canonical_patterns and not side_substrings: + return state_dict + + def _has_side_substring(key: str) -> bool: + # Require the token to appear as a proper dot-separated path + # component, not just as a substring of an unrelated identifier. + for tok in side_substrings: + if ( + f".{tok}." in key + or key.startswith(f"{tok}.") + or key.endswith(f".{tok}") + or key == tok + ): + return True + return False + + head: dict = {} + tail: dict = {} + for k, v in state_dict.items(): + if any(p.search(k) for p in canonical_patterns) or _has_side_substring(k): + head[k] = v + else: + tail[k] = v + head.update(tail) + return head + + +def sync_tied_input_amax(model: nn.Module) -> int: + """Max-merge input_quantizer amaxes across modules sharing a weight ``data_ptr``. + + Mutates ``model`` in place: overwrites the ``.amax`` buffer on every + affected ``input_quantizer`` with the per-group maximum. Intended to + run as part of an export pipeline that already replaces weights with + packed bytes downstream — i.e. the model is not expected to be reused + after this helper runs. + + Closes the loop on ``input_scale`` for HF-tied modules whose forward + paths see different activation distributions (encoder vs decoder in + YOCO-style models). Must run BEFORE per-module export so the merged + amax flows into ``input_scale`` derivation. Handles both dense + Linears (keyed by ``weight.data_ptr()``) and fused MoE (keyed by + ``(gate_up_proj, down_proj)`` data_ptr tuple). Returns the number of + tied groups merged. + """ + from collections import defaultdict + + by_dp: dict = defaultdict(list) + for _, m in model.named_modules(): + # Fused MoE: 3-D source tensors with shared input quantizers + if ( + hasattr(m, "gate_up_proj_input_quantizer") + and hasattr(m, "gate_up_proj") + and hasattr(m, "down_proj") + and m.gate_up_proj.dim() == 3 + ): + key = ("moe", m.gate_up_proj.data_ptr(), m.down_proj.data_ptr()) + by_dp[key].append(m) + # Dense quantized Linear with an input_quantizer + elif ( + hasattr(m, "input_quantizer") + and hasattr(m, "weight") + and isinstance(m.weight, torch.nn.Parameter) + ): + by_dp[("dense", m.weight.data_ptr())].append(m) + + def _merge(quantizers: list) -> bool: + """Max-merge amaxes across the quantizer list. Returns True on merge.""" + valid = [ + q + for q in quantizers + if q is not None + and getattr(q, "is_enabled", False) + and getattr(q, "_amax", None) is not None + and not q._amax.is_meta + ] + if len(valid) < 2: + return False + # Require scalar (per-tensor) amax — matches preprocess_linear_fusion. + if any(q._amax.numel() != 1 for q in valid): + warnings.warn( + "sync_tied_input_amax: non-scalar input_quantizer amax encountered " + "in a tied group; skipping. Only per-tensor input quantizers are " + "supported for tied-modules merging." + ) + return False + merged = torch.max(torch.stack([q.amax for q in valid])) + for q in valid: + q.amax = merged.clone() + return True + + synced = 0 + for key, modules in by_dp.items(): + if len(modules) < 2: + continue + if key[0] == "moe": + for q_name in ("gate_up_proj_input_quantizer", "down_proj_input_quantizer"): + if _merge([getattr(m, q_name, None) for m in modules]): + synced += 1 + elif _merge([m.input_quantizer for m in modules]): + synced += 1 + return synced + + def _process_quantized_modules( model: nn.Module, dtype: torch.dtype, is_modelopt_qlora: bool = False, ) -> None: + # Per-call tied-weight dedup caches. Created fresh on every invocation + # so cache state is scoped to one export and cannot leak into a later + # call (a process-global cache would carry stale entries whose data_ptr + # keys can be recycled by PyTorch's allocator across exports — silent + # false-positive aliasing). int keys hold dense Linear / per-expert + # wrapper dedup; tuple keys hold MoE fused-experts module dedup. + _tied_cache: dict[int, nn.Module] = {} + _moe_tied_cache: dict[tuple[int, int], nn.Module] = {} """Process all quantized modules in model, export weights in-place. This function iterates through all modules in the model and exports quantized weights @@ -752,7 +957,12 @@ def _process_quantized_modules( # which get_quantization_format's singular-weight_quantizer check misses. Handle # it explicitly before the format gate so fused-experts get split + quantized. with fsdp2_aware_weight_update(model, sub_module, reshard=False): - _export_fused_experts(sub_module, dtype) + _export_fused_experts( + sub_module, + dtype, + _moe_tied_cache=_moe_tied_cache, + _tied_cache=_tied_cache, + ) elif get_quantization_format(sub_module) != QUANTIZATION_NONE: # Skip QuantMoELinear - it's handled separately in _reconstruct_fused_moe_linear if type(sub_module).__name__ == "QuantMoELinear": @@ -760,7 +970,7 @@ def _process_quantized_modules( if is_quantlinear(sub_module): try: with fsdp2_aware_weight_update(model, sub_module, reshard=False): - _export_quantized_weight(sub_module, dtype) + _export_quantized_weight(sub_module, dtype, _tied_cache=_tied_cache) except AssertionError as e: raise AssertionError( f"Failed to export module '{name}' (type={type(sub_module).__name__}): {e}" @@ -788,7 +998,7 @@ def _process_quantized_modules( else: try: with fsdp2_aware_weight_update(model, sub_module, reshard=False): - _export_quantized_weight(sub_module, dtype) + _export_quantized_weight(sub_module, dtype, _tied_cache=_tied_cache) except AssertionError as e: raise AssertionError( f"Failed to export embedding '{name}' (type={type(sub_module).__name__}): {e}" @@ -811,11 +1021,17 @@ def _process_quantized_modules( # Export the quantized weights with fsdp2_aware_weight_update(model, sub_module, reshard=False): for weight_name in ["gate_up_proj", "down_proj"]: - _export_quantized_weight(sub_module, dtype, weight_name) + _export_quantized_weight( + sub_module, dtype, weight_name, _tied_cache=_tied_cache + ) def _export_transformers_checkpoint( - model: nn.Module, dtype: torch.dtype | None = None, is_modelopt_qlora: bool = False, **kwargs + model: nn.Module, + dtype: torch.dtype | None = None, + is_modelopt_qlora: bool = False, + canonical_tied_naming: bool = False, + **kwargs, ) -> tuple[dict[str, Any], dict[str, Any]]: """Exports the torch model to the packed checkpoint with original HF naming. @@ -940,6 +1156,15 @@ def _export_transformers_checkpoint( f"Taking element-wise max of amaxes for serving-engine fusion." ) + # Merge per-side input_quantizer amaxes BEFORE _process_quantized_modules, + # so the merged value flows into input_scale derivation downstream. + synced_input = sync_tied_input_amax(model) + if synced_input: + print( + f"sync_tied_input_amax: max-merged input_quantizer amaxes across " + f"{synced_input} tied module group(s)" + ) + # Process all quantized modules and export weights _process_quantized_modules(model, dtype, is_modelopt_qlora) @@ -957,6 +1182,16 @@ def _export_transformers_checkpoint( # We define kv cache scale as amax / 448 for both FP8 and NVFP4 KV cache quantization. kv_cache_max_bound = 448 kv_cache_format = quant_config["quantization"]["kv_cache_quant_algo"] + + # Optionally reorder so canonical-side tied keys (per HF's + # _tied_weights_keys) iterate first into postprocess_state_dict's + # first-wins data_ptr dedup. Off by default to avoid renaming exported + # keys for models whose downstream consumers expect the legacy + # (registration-order) winner; opt in for models where matching HF's + # own naming convention matters (e.g. DiffusionGemma4 → decoder names). + if canonical_tied_naming: + quantized_state_dict = _reorder_canonical_first(quantized_state_dict, model) + quantized_state_dict = postprocess_state_dict( quantized_state_dict, kv_cache_max_bound, kv_cache_format, is_modelopt_qlora ) @@ -1294,6 +1529,7 @@ def export_hf_checkpoint( components: list[str] | None = None, extra_state_dict: dict[str, torch.Tensor] | None = None, max_shard_size: int | str = "10GB", + canonical_tied_naming: bool = False, **kwargs, ): """Export quantized HuggingFace model checkpoint (transformers or diffusers). @@ -1313,6 +1549,11 @@ def export_hf_checkpoint( to export. If None, all quantized components are exported. extra_state_dict: Extra state dictionary to add to the exported model. max_shard_size: Maximum size of each safetensors shard file. Defaults to "10GB". + canonical_tied_naming: If True, reorder the state_dict so tied-weight + aliases dedup to the canonical side declared in the model's HF + ``_tied_weights_keys`` (e.g. decoder-side for DiffusionGemma4). + Off by default to avoid renaming exported keys for models whose + downstream consumers expect the legacy (registration-order) winner. **kwargs: Runtime-specific post-processing options forwarded to :func:`_postprocess_safetensors` for diffusion model exports. See its docstring for supported keys. @@ -1335,7 +1576,9 @@ def export_hf_checkpoint( return try: - post_state_dict, hf_quant_config = _export_transformers_checkpoint(model, dtype) + post_state_dict, hf_quant_config = _export_transformers_checkpoint( + model, dtype, canonical_tied_naming=canonical_tied_naming, **kwargs + ) # Only treat the export as quantized when at least one quant_algo field is set. # get_quant_config always returns a dict (even for sparsity-only or unmodified models), diff --git a/modelopt/torch/utils/dataset_utils.py b/modelopt/torch/utils/dataset_utils.py index 8b3fd0b067..318273d9b7 100644 --- a/modelopt/torch/utils/dataset_utils.py +++ b/modelopt/torch/utils/dataset_utils.py @@ -1231,7 +1231,17 @@ def create_forward_loop( def model_type_is_enc_dec(model): - enc_dec_model_list = ["t5", "bart", "whisper"] + # Substring match against `model.__class__.__name__.lower()` — entries are + # the lowercased class-name form (no underscores). Calibration then uses + # `model.generate` to run the full denoising loop. + # + # Note: this list intentionally diverges from ``is_enc_dec`` in + # ``examples/llm_ptq/example_utils.py`` (which keys by ``model_type`` + # string and is used for preview-decode slicing). DiffusionGemma is + # included here so calibration uses ``.generate()`` end-to-end, but + # deliberately excluded there so the preview decode treats its + # prompt+canvas output as AR-style. + enc_dec_model_list = ["t5", "bart", "whisper", "diffusiongemma"] return any(model_name in model.__class__.__name__.lower() for model_name in enc_dec_model_list) diff --git a/modelopt_recipes/huggingface/diffusiongemma/ptq/README.md b/modelopt_recipes/huggingface/diffusiongemma/ptq/README.md new file mode 100644 index 0000000000..78b1a428bc --- /dev/null +++ b/modelopt_recipes/huggingface/diffusiongemma/ptq/README.md @@ -0,0 +1,14 @@ +# DiffusionGemma PTQ recipes + +DiffusionGemma is a block-diffusion encoder-decoder text LLM with a Gemma4 MoE +backbone shared between an encoder pass and a 48-step iterative decoder. +Quantization targets the MoE experts; the self-conditioning network is +text-only and is not exercised by standard PTQ calibration data, so its +``TensorQuantizer`` observers never see input and ``_export_quantized_weight`` +crashes on the missing ``_amax``. These recipes apply the model-specific +``*self_conditioning*`` exclude on top of the standard default exclusions. + +| File | What's model-specific | +|------|-----------------------| +| `disabled_quantizers.yaml` | Reusable unit (`QuantizerCfgListConfig`). Merges the standard `default_disabled_quantizers` exclusions with the DiffusionGemma-specific `*self_conditioning*` exclude. Imported by the recipe below as the single `disabled_quantizers` slot so it doesn't pull in two disabled-quantizer sets. | +| `nvfp4_experts_only.yaml` | Dynamic W4A4 NVFP4 quantization on MoE experts only (attention and dense MLP stay bf16). Identical numerics to the general `nvfp4_experts_only` preset; what makes it model-specific is that it imports `disabled_quantizers.yaml` from this folder to skip the self-conditioning network. | diff --git a/modelopt_recipes/huggingface/diffusiongemma/ptq/disabled_quantizers.yaml b/modelopt_recipes/huggingface/diffusiongemma/ptq/disabled_quantizers.yaml new file mode 100644 index 0000000000..9879c607fd --- /dev/null +++ b/modelopt_recipes/huggingface/diffusiongemma/ptq/disabled_quantizers.yaml @@ -0,0 +1,30 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# QuantizerCfgList snippet of disabled quantizers for DiffusionGemma. Splices +# in the standard ``default_disabled_quantizers`` exclusions and appends the +# DiffusionGemma-specific self-conditioning network so that text-only PTQ +# calibration data doesn't trip on its uncalibrated TensorQuantizers (export +# would otherwise crash with "AttributeError: 'TensorQuantizer' object has no +# attribute '_amax'"). Recipes that import this should NOT also import +# ``default_disabled_quantizers``. + +# modelopt-schema: modelopt.torch.quantization.config.QuantizerCfgListConfig +imports: + default_disabled_quantizers: configs/ptq/units/default_disabled_quantizers +--- + - $import: default_disabled_quantizers + - quantizer_name: '*self_conditioning*' + enable: false diff --git a/modelopt_recipes/huggingface/diffusiongemma/ptq/nvfp4_experts_only.yaml b/modelopt_recipes/huggingface/diffusiongemma/ptq/nvfp4_experts_only.yaml new file mode 100644 index 0000000000..a9e1072504 --- /dev/null +++ b/modelopt_recipes/huggingface/diffusiongemma/ptq/nvfp4_experts_only.yaml @@ -0,0 +1,42 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# DiffusionGemma-specific PTQ recipe for the ``nvfp4_experts_only`` qformat. +# Equivalent to the general ``nvfp4_experts_only`` preset +# (``configs/ptq/presets/model/nvfp4_experts_only``) with the +# self-conditioning network additionally disabled, applied via the local +# ``disabled_quantizers`` unit that splices in the standard default plus +# ``*self_conditioning*``. + +imports: + base_disable_all: configs/ptq/units/base_disable_all + experts_nvfp4: configs/ptq/units/experts_nvfp4 + disabled_quantizers: huggingface/diffusiongemma/ptq/disabled_quantizers + +metadata: + recipe_type: ptq + description: >- + DiffusionGemma PTQ recipe (nvfp4_experts_only): same numerics as the + general nvfp4_experts_only preset (dynamic W4A4 NVFP4 on MoE experts only; + attention and dense MLP stay bf16), with the self-conditioning network + additionally disabled so text-only calibration doesn't trip on its + uncalibrated quantizers. + +quantize: + algorithm: max + quant_cfg: + - $import: base_disable_all + - $import: experts_nvfp4 + - $import: disabled_quantizers diff --git a/tests/_test_utils/torch/quantization/tied_modules.py b/tests/_test_utils/torch/quantization/tied_modules.py new file mode 100644 index 0000000000..8ea76d2d45 --- /dev/null +++ b/tests/_test_utils/torch/quantization/tied_modules.py @@ -0,0 +1,115 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Factories for tied-weight test scenarios. + +These build small synthetic modules whose ``.weight`` :class:`nn.Parameter` is +shared between two sibling modules — mimicking HuggingFace's +``_tied_weights_keys`` machinery — for unit-testing the export-time dedup, +canonical-side naming, and per-side ``input_quantizer.amax`` merge logic in +the HF export path. + +Every factory returns CPU-resident, float32-default modules; no GPU required. +Each factory asserts its own post-conditions before returning, so a broken +tie surfaces as a clear factory-side error rather than as a downstream test +failure with an ambiguous cause. +""" + +import re + +import torch.nn as nn + + +def make_tied_linear_pair( + in_features: int = 16, + out_features: int = 32, + bias: bool = False, +) -> tuple[nn.Linear, nn.Linear]: + """Two :class:`nn.Linear` modules whose ``.weight`` Parameter is shared. + + Mimics what HuggingFace's :meth:`PreTrainedModel.tie_weights` does after + ``__init__``: one extra ``setattr`` so that both modules' ``.weight`` + attributes resolve to the same :class:`nn.Parameter` and therefore the + same underlying storage. The modules are otherwise independent — separate + biases (if requested), separate forward/training state, separate + quantizer slots when ``mtq.quantize`` inserts them later. + """ + enc = nn.Linear(in_features, out_features, bias=bias) + dec = nn.Linear(in_features, out_features, bias=bias) + dec.weight = enc.weight # mimics HF tie_weights() + + # Post-conditions — fail loudly if the tie was somehow lost. + assert enc.weight is dec.weight, "Linear weights not tied (object identity)" + assert enc.weight.data_ptr() == dec.weight.data_ptr(), ( + "Linear weights tied at object level but storage diverged" + ) + return enc, dec + + +def tie_fused_experts_3d_params(enc: nn.Module, dec: nn.Module) -> None: + """Tie ``gate_up_proj`` and ``down_proj`` between two fused-experts modules. + + Mutates ``dec`` in place. After calling, ``dec.gate_up_proj`` IS + ``enc.gate_up_proj`` (same :class:`nn.Parameter`) and likewise for + ``down_proj``. Used by MoE-dedup tests together with the + ``_SyntheticFusedExperts`` fixture defined in + ``tests/unit/torch/quantization/plugins/test_fused_experts.py``. + """ + dec.gate_up_proj = enc.gate_up_proj + dec.down_proj = enc.down_proj + + assert enc.gate_up_proj is dec.gate_up_proj, "gate_up_proj not tied" + assert enc.down_proj is dec.down_proj, "down_proj not tied" + assert enc.gate_up_proj.data_ptr() == dec.gate_up_proj.data_ptr() + assert enc.down_proj.data_ptr() == dec.down_proj.data_ptr() + + +def wrap_in_parent_with_tied_keys( + enc: nn.Module, + dec: nn.Module, + *, + decoder_canonical: bool = True, + weight_attr: str = "weight", +) -> nn.Module: + """Wrap two tied modules in a parent that declares HF ``_tied_weights_keys``. + + Returns a parent :class:`nn.Module` with: + + - ``parent.encoder = enc`` — registered as a submodule (alias side). + - ``parent.decoder = dec`` — registered as a submodule (canonical side + when ``decoder_canonical=True``, the default and DiffusionGemma-like case). + - ``parent._tied_weights_keys``: dict-style ``{alias_regex: canonical}`` + when ``decoder_canonical=True``, list-style (legacy, no canonical/alias + distinction) when ``decoder_canonical=False``. + + Used by tests for :func:`_collect_canonical_tied_patterns` and + :func:`_reorder_canonical_first`. The legacy list-style branch exercises + the "no patterns extracted" negative case. + """ + parent = nn.Module() + parent.encoder = enc + parent.decoder = dec + + if decoder_canonical: + # Dict-style: regex pattern → canonical path. Mimics HF's per-class + # ``_tied_weights_keys`` declaration for an encoder/decoder model. + parent._tied_weights_keys = { + rf"^encoder\.{re.escape(weight_attr)}$": f"decoder.{weight_attr}", + } + else: + # Legacy list-style: just a list of tied paths, no canonical info. + parent._tied_weights_keys = [f"encoder.{weight_attr}"] + + return parent diff --git a/tests/unit/torch/export/test_unified_export_hf.py b/tests/unit/torch/export/test_unified_export_hf.py new file mode 100644 index 0000000000..bc856883b7 --- /dev/null +++ b/tests/unit/torch/export/test_unified_export_hf.py @@ -0,0 +1,182 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for tied-weight helpers in unified_export_hf.""" + +from collections import OrderedDict + +import torch +from _test_utils.torch.quantization.tied_modules import ( + make_tied_linear_pair, + wrap_in_parent_with_tied_keys, +) + +import modelopt.torch.quantization as mtq +from modelopt.torch.export.unified_export_hf import ( + _collect_canonical_tied_patterns, + _export_quantized_weight, + _reorder_canonical_first, + sync_tied_input_amax, +) + + +def test_collect_canonical_tied_patterns_dict_style(): + """Dict-style _tied_weights_keys yields regex patterns + canonical-side substrings.""" + enc, dec = make_tied_linear_pair() + parent = wrap_in_parent_with_tied_keys(enc, dec, decoder_canonical=True) + + patterns, side_substrings = _collect_canonical_tied_patterns(parent) + + assert len(patterns) >= 1 + # "decoder" is in the canonical RHS but not the alias LHS — must auto-derive. + # "encoder" is alias-only and must NOT be returned as canonical (would invert dedup). + assert "decoder" in side_substrings + assert "encoder" not in side_substrings + + +def test_collect_canonical_tied_patterns_list_style_yields_no_canonical_info(): + """Legacy list-style _tied_weights_keys carries no canonical/alias info — returns empty.""" + enc, dec = make_tied_linear_pair() + parent = wrap_in_parent_with_tied_keys(enc, dec, decoder_canonical=False) + + patterns, side_substrings = _collect_canonical_tied_patterns(parent) + + assert patterns == [] + assert side_substrings == [] + + +def test_reorder_canonical_first_puts_decoder_keys_before_encoder_keys(): + """_reorder_canonical_first moves canonical-side state_dict keys ahead of alias-side keys.""" + enc, dec = make_tied_linear_pair() + parent = wrap_in_parent_with_tied_keys(enc, dec, decoder_canonical=True) + + sd = OrderedDict( + [ + ("encoder.weight", torch.zeros(1)), + ("unrelated.foo", torch.zeros(1)), + ("decoder.weight", torch.zeros(1)), + ] + ) + + reordered = _reorder_canonical_first(sd, parent) + keys = list(reordered.keys()) + + assert keys.index("decoder.weight") < keys.index("encoder.weight") + assert set(reordered) == set(sd) # no drops or additions + + +def _quantize_and_get_input_quantizers(parent): + """Insert FP8 quantizers via no-op forward_loop and return both input_quantizers.""" + mtq.quantize(parent, mtq.FP8_DEFAULT_CFG, forward_loop=lambda m: None) + return parent.encoder.input_quantizer, parent.decoder.input_quantizer + + +def test_sync_tied_input_amax_max_merges_tied_module_amaxes_in_place(): + """Tied Linears with divergent input_quantizer.amax get both sides overwritten with the max.""" + enc, dec = make_tied_linear_pair() + parent = wrap_in_parent_with_tied_keys(enc, dec, decoder_canonical=True) + enc_q, dec_q = _quantize_and_get_input_quantizers(parent) + + enc_q.amax = torch.tensor(2.0) + dec_q.amax = torch.tensor(5.0) + + sync_tied_input_amax(parent) + + expected = torch.tensor(5.0) + assert torch.allclose(enc_q.amax, expected) + assert torch.allclose(dec_q.amax, expected) + + +def test_sync_tied_input_amax_no_op_for_untied_modules(): + """Untied Linears keep their per-side amaxes — the helper is a no-op when there's no tie.""" + parent = torch.nn.Module() + parent.encoder = torch.nn.Linear(16, 32, bias=False) + parent.decoder = torch.nn.Linear(16, 32, bias=False) + enc_q, dec_q = _quantize_and_get_input_quantizers(parent) + + enc_q.amax = torch.tensor(2.0) + dec_q.amax = torch.tensor(5.0) + + sync_tied_input_amax(parent) + + assert torch.allclose(enc_q.amax, torch.tensor(2.0)) + assert torch.allclose(dec_q.amax, torch.tensor(5.0)) + + +def _calibrate_through_both_children(parent): + """Insert NVFP4 quantizers and run a one-shot forward through both children for calibration.""" + + def forward_loop(m): + x = torch.randn(2, 16) + m.encoder(x) + m.decoder(x) + + mtq.quantize(parent, mtq.NVFP4_DEFAULT_CFG, forward_loop=forward_loop) + + +def test_export_quantized_weight_aliases_packed_weight_for_tied_linears(): + """Tied Linears share data_ptr for packed .weight and scale buffers after export.""" + enc, dec = make_tied_linear_pair() + parent = wrap_in_parent_with_tied_keys(enc, dec) + _calibrate_through_both_children(parent) + + # Per-call dedup cache (the production pattern: caller owns the cache, scoped + # to one export invocation). Threaded through both sides of the tied pair so + # the alias step at the end of _export_quantized_weight catches the dedup. + tied_cache: dict = {} + _export_quantized_weight(enc, torch.float16, "weight", _tied_cache=tied_cache) + _export_quantized_weight(dec, torch.float16, "weight", _tied_cache=tied_cache) + + assert enc.weight.data_ptr() == dec.weight.data_ptr() + for scale_attr in ("weight_scale", "weight_scale_2"): + if hasattr(enc, scale_attr) and hasattr(dec, scale_attr): + assert getattr(enc, scale_attr).data_ptr() == getattr(dec, scale_attr).data_ptr() + + +def test_export_quantized_weight_no_alias_for_untied_linears(): + """Untied Linears keep independent data_ptrs after export — no false-positive aliasing.""" + parent = torch.nn.Module() + parent.encoder = torch.nn.Linear(16, 32, bias=False) + parent.decoder = torch.nn.Linear(16, 32, bias=False) + assert parent.encoder.weight.data_ptr() != parent.decoder.weight.data_ptr() + _calibrate_through_both_children(parent) + + # Same fresh cache shape as the positive case — confirms that even with + # dedup enabled, untied modules with distinct source data_ptrs do not get + # falsely aliased. + tied_cache: dict = {} + _export_quantized_weight(parent.encoder, torch.float16, "weight", _tied_cache=tied_cache) + _export_quantized_weight(parent.decoder, torch.float16, "weight", _tied_cache=tied_cache) + + assert parent.encoder.weight.data_ptr() != parent.decoder.weight.data_ptr() + + +def test_export_quantized_weight_skips_alias_when_one_tied_side_is_unquantized(): + """Unquantized side early-returns; its .weight stays at the original shared Parameter.""" + enc, dec = make_tied_linear_pair() + parent = wrap_in_parent_with_tied_keys(enc, dec) + original_shared_data_ptr = enc.weight.data_ptr() + + _calibrate_through_both_children(parent) + # is_enabled is a read-only property; .disable() is the canonical bypass. + dec.weight_quantizer.disable() + + tied_cache: dict = {} + _export_quantized_weight(enc, torch.float16, "weight", _tied_cache=tied_cache) + _export_quantized_weight(dec, torch.float16, "weight", _tied_cache=tied_cache) + + assert enc.weight.data_ptr() != original_shared_data_ptr # encoder got fresh packed + assert dec.weight.data_ptr() == original_shared_data_ptr # decoder untouched + assert enc.weight.data_ptr() != dec.weight.data_ptr() diff --git a/tests/unit/torch/quantization/plugins/test_fused_experts.py b/tests/unit/torch/quantization/plugins/test_fused_experts.py index ce23f7a51d..550c27c46f 100644 --- a/tests/unit/torch/quantization/plugins/test_fused_experts.py +++ b/tests/unit/torch/quantization/plugins/test_fused_experts.py @@ -22,6 +22,8 @@ pytest.importorskip("transformers") +from _test_utils.torch.quantization.tied_modules import tie_fused_experts_3d_params + import modelopt.torch.quantization as mtq from modelopt.torch.export.moe_utils import _export_fused_experts from modelopt.torch.export.quant_utils import get_quant_config @@ -365,7 +367,7 @@ def test_uncalibrated_expert_gate_up_share_amax(self, monkeypatch): # FP4 quantization step. Patching here avoids needing CUDA / FP4. seen = {} # (expert_idx, proj_name) -> amax tensor - def _spy_export(wrapper, dtype): + def _spy_export(wrapper, dtype, **_kwargs): # Identify which expert/projection this wrapper belongs to by # matching the weight tensor against the fused parameters. w = wrapper.weight.data @@ -463,7 +465,7 @@ def test_per_block_amax_reshape_for_fused_export(self, monkeypatch): seen = {} - def _spy_export(wrapper, dtype): + def _spy_export(wrapper, dtype, **_kwargs): w = wrapper.weight.data wq = wrapper.weight_quantizer amax = wq._amax.detach().clone() if hasattr(wq, "_amax") else None @@ -514,6 +516,132 @@ def _spy_export(wrapper, dtype): QuantModuleRegistry.unregister(expert_type) +# --------------------------------------------------------------------------- +# Tests for tied-experts dedup in _export_fused_experts +# --------------------------------------------------------------------------- +def _build_two_moe_blocks(tie: bool) -> nn.Module: + """Build a parent with two _SyntheticSparseMoeBlock children, optionally with tied 3-D params.""" + parent = nn.Module() + parent.encoder = _SyntheticSparseMoeBlock() + parent.decoder = _SyntheticSparseMoeBlock() + if tie: + tie_fused_experts_3d_params(parent.encoder.experts, parent.decoder.experts) + return parent + + +def _moe_fp8_quant_cfg(): + """Custom inline FP8 cfg targeting the MoE-specific quantizer names.""" + return { + "quant_cfg": [ + {"quantizer_name": "*", "enable": False}, + { + "quantizer_name": "*gate_up_proj_input_quantizer", + "cfg": {"num_bits": 8, "axis": None}, + }, + {"quantizer_name": "*down_proj_input_quantizer", "cfg": {"num_bits": 8, "axis": None}}, + {"quantizer_name": "*gate_up_proj_weight_quantizer", "cfg": {"num_bits": 8, "axis": 0}}, + {"quantizer_name": "*down_proj_weight_quantizer", "cfg": {"num_bits": 8, "axis": 0}}, + ], + "algorithm": "max", + } + + +def _calibrate_two_moe_blocks(parent): + """Fire one calibration batch through both encoder.experts and decoder.experts.""" + + def forward_loop(m): + torch.manual_seed(0) + x = torch.randn(1, 4, HIDDEN_DIM) + m.encoder(x) + m.decoder(x) + + mtq.quantize(parent, _moe_fp8_quant_cfg(), forward_loop=forward_loop) + + +class TestExportFusedExpertsTiedDedup: + @staticmethod + def _cleanup_registry(mod_type): + if QuantModuleRegistry.get(mod_type) is not None: + QuantModuleRegistry.unregister(mod_type) + + def test_per_expert_buffers_share_data_ptr_for_tied_fused_experts(self): + """Two tied FusedExperts modules: every per-expert .weight + scale buffer shares data_ptr.""" + parent = _build_two_moe_blocks(tie=True) + expert_type = type(parent.encoder.experts) + self._cleanup_registry(expert_type) + try: + _calibrate_two_moe_blocks(parent) + + # Per-call dedup caches threaded through both export calls; int keys + # for per-expert wrapper dedup, tuple keys for module-level dedup. + tied_cache: dict = {} + moe_tied_cache: dict = {} + _export_fused_experts( + parent.encoder.experts, + torch.float16, + _moe_tied_cache=moe_tied_cache, + _tied_cache=tied_cache, + ) + _export_fused_experts( + parent.decoder.experts, + torch.float16, + _moe_tied_cache=moe_tied_cache, + _tied_cache=tied_cache, + ) + + for idx in range(NUM_EXPERTS): + enc_expert = getattr(parent.encoder.experts, str(idx)) + dec_expert = getattr(parent.decoder.experts, str(idx)) + for proj_name in ("gate_proj", "up_proj", "down_proj"): + enc_proj = getattr(enc_expert, proj_name) + dec_proj = getattr(dec_expert, proj_name) + assert enc_proj.weight.data_ptr() == dec_proj.weight.data_ptr() + for scale_attr in ("weight_scale", "weight_scale_2"): + if hasattr(enc_proj, scale_attr) and hasattr(dec_proj, scale_attr): + assert ( + getattr(enc_proj, scale_attr).data_ptr() + == getattr(dec_proj, scale_attr).data_ptr() + ) + finally: + self._cleanup_registry(expert_type) + + def test_per_expert_buffers_have_independent_data_ptrs_for_untied_fused_experts(self): + """Two untied FusedExperts modules: per-expert buffers stay independent (no false-positive alias).""" + parent = _build_two_moe_blocks(tie=False) + expert_type = type(parent.encoder.experts) + self._cleanup_registry(expert_type) + try: + _calibrate_two_moe_blocks(parent) + + # Same fresh caches as the positive case — confirms that even with + # dedup enabled, untied modules with distinct source data_ptrs do + # not get falsely aliased. + tied_cache: dict = {} + moe_tied_cache: dict = {} + _export_fused_experts( + parent.encoder.experts, + torch.float16, + _moe_tied_cache=moe_tied_cache, + _tied_cache=tied_cache, + ) + _export_fused_experts( + parent.decoder.experts, + torch.float16, + _moe_tied_cache=moe_tied_cache, + _tied_cache=tied_cache, + ) + + for idx in range(NUM_EXPERTS): + enc_expert = getattr(parent.encoder.experts, str(idx)) + dec_expert = getattr(parent.decoder.experts, str(idx)) + for proj_name in ("gate_proj", "up_proj", "down_proj"): + enc_proj = getattr(enc_expert, proj_name) + dec_proj = getattr(dec_expert, proj_name) + assert enc_proj.weight.data_ptr() != dec_proj.weight.data_ptr() + finally: + self._cleanup_registry(expert_type) + + # --------------------------------------------------------------------------- # Tests for force_eager_experts_impl_on_the_fly # ---------------------------------------------------------------------------