# Composite action: keep the torch JIT extension build directory in the GitHub Actions
# cache so JIT-compiled extensions (e.g. CUDA kernels) can be reused by later runs.
name: Cache extensions
description: Cache and reuse JIT-compiled extensions (e.g. CUDA extensions) across runs.

inputs:
  cache-key:
    description: Environment discriminator for the cache key (e.g. GPU arch + container image).
    required: true

runs:
  using: composite
  steps:
    # Exported through GITHUB_ENV so every later step of the calling job sees it. The
    # directory must stay identical to the `path` cached below.
    - name: Point torch at the cached extensions directory
      shell: bash
      run: echo "TORCH_EXTENSIONS_DIR=/root/.cache/torch_extensions" >> "$GITHUB_ENV"
    - name: Restore or save compiled extensions
      id: cache
      uses: actions/cache@v4
      with:
        path: /root/.cache/torch_extensions
        # Explicit folded scalar: the lines below are joined with single spaces into one
        # key. The key changes when the caller's discriminator changes or when any of the
        # hashed extension sources / build helpers change.
        key: >-
          torch-ext-${{ inputs.cache-key }}-${{ hashFiles('modelopt/torch/kernels/quantization/**',
          'modelopt/torch/quantization/extensions.py',
          'modelopt/torch/utils/cpp_extension.py') }}
    # On a cache hit, backdate sources below the cached objects so ninja reuses them (touching
    # the objects instead would desync ninja's .ninja_deps and force a rebuild).
    - name: Backdate kernel sources on cache hit
      if: steps.cache.outputs.cache-hit == 'true'
      shell: bash
      run: find modelopt/torch/kernels/quantization -exec touch -d '2000-01-01' {} +
nv-gha-runners/setup-proxy-cache@main + - uses: ./.github/actions/cache-extensions + with: + cache-key: rtxpro6000-${{ matrix.container_image }} - name: Setup environment variables run: | echo "LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/usr/include:/usr/lib/x86_64-linux-gnu" >> $GITHUB_ENV diff --git a/.github/workflows/unit_tests.yml b/.github/workflows/unit_tests.yml index fc2ae364cba..0a67f5f0b1d 100644 --- a/.github/workflows/unit_tests.yml +++ b/.github/workflows/unit_tests.yml @@ -1,4 +1,3 @@ -# NOTE: Make sure this file is consistent with .gitlab/tests.yml name: Unit tests on: @@ -73,6 +72,10 @@ jobs: token: ${{ secrets.CODECOV_TOKEN }} flags: unit fail_ci_if_error: true + # Skip GPG/SHASUM integrity check of the Codecov CLI: its key import + # intermittently fails (codecov/codecov-action#1876), which would + # otherwise hard-fail this required job on a transient infra blip. + skip_validation: true verbose: true windows: if: needs.check-file-changes.outputs.any_changed == 'true' diff --git a/CHANGELOG.rst b/CHANGELOG.rst index da02b315f67..3448d4cee73 100755 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -24,16 +24,19 @@ Changelog **New Features** - Extend Claude Code agent skills for PTQ, deployment, evaluation, monitoring, and baseline-vs-quantized result comparison. Adds evaluation task references for additional benchmarks, stronger PTQ checkpoint validation gates, and session-scoped workspace/job tracking. +- Add ``examples/alpamayo`` showing FP8, NVFP4, and AutoQuantize (mixed-precision) quantization of the Alpamayo (formerly Alpamayo-R1) ~10B vision-language-action model, with a joint VLM + diffusion calibration loop and both fake-quant and ``--real-quant`` packed-checkpoint export. See `examples/alpamayo/README.md `_ for details. - Add SLURM Quality of Service (QoS) support to the ModelOpt launcher. Users can set QoS via ``slurm_config.qos`` or ``SLURM_QOS`` and the value is forwarded to ``nemo_run.SlurmExecutor``. 
- Add composable ``$import`` system for recipe YAML configs, enabling reusable config snippets referenced via ``{$import: name}`` markers. All built-in PTQ recipes converted to use imports with shared snippets under ``modelopt_recipes/configs/`` (numeric formats, quant_cfg building blocks, presets). See :ref:`composable-imports`. - Add offline DFlash speculative decoding training. Train the draft module from pre-computed base-model hidden states dumped by ``examples/speculative_decoding/collect_hidden_states/compute_hidden_states_hf.py``; base-model transformer layers are deleted after conversion to save memory. Controlled by the auto-derived ``dflash_offline`` flag on ``DFlashConfig`` (derived from ``data_args.offline_data_path``). The dump scripts now share ``collect_hidden_states/common.py`` for aux-layer selection (``--aux-layers eagle|dflash|``) and optional assistant-token ``loss_mask`` for answer-only-loss training. - Add support for ``active_params`` (for MoE models) and ``memory_mb`` constraints in Minitron pruning on top of existing ``params`` constraint. You can also provide multiple constraints. See `examples/pruning/README.md `_ for more details. The underlying utility functions ``mcore_param_count``, ``mcore_memory_footprint_mb``, and ``print_mcore_model_stats`` in ``modelopt.torch.nas.plugins.megatron_model_stats`` are also available for standalone use to compute parameter counts and memory footprints (weights + KV-cache + Mamba state) for any Megatron-Core model. - Add Minitron pruning support for Megatron-Bridge Gemma3 models. - Add quantization examples for the Megatron-Bridge framework: post-training quantization (`quantize.py `_), export to a deployable HuggingFace checkpoint (`export.py `_), and Quantization Aware Distillation (extend existing `distill.py `_). 
-- Add end-to-end tutorial for Minitron pruning + two-phase distillation (80B @ 8K + 20B @ 32K long-context = 100B tokens) + FP8 PTQ + vLLM deployment for Nemotron-3-Nano-30B-A3B-BF16 (MoE + Mamba-Transformer hybrid) → Pruned 22B/A3.0B active params, along with data blend preparation steps (with tool-calling data) and detailed pruning / data-blend / long-context ablations. See `examples/pruning/minitron/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16/README.md `_ for details. +- Add end-to-end optimization tutorial for Minitron pruning + two-phase distillation (80B @ 8K + 20B @ 32K long-context = 100B tokens) + FP8 PTQ + vLLM deployment for Nemotron-3-Nano-30B-A3B-BF16 (MoE + Mamba-Transformer hybrid) → Pruned 22B/A3.0B active params, along with data blend preparation steps (with tool-calling data) and detailed pruning / data-blend / long-context ablations. See `examples/megatron_bridge/tutorials/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16/README.md `_ for details. - Add ``--cast_mxfp4_to_nvfp4`` flag to ``examples/llm_ptq/hf_ptq.py`` for closed-form, bit-exact MXFP4 → NVFP4 weight conversion. Supports the GPT-OSS family (``openai/gpt-oss-20b``, ``openai/gpt-oss-120b``). See `examples/llm_ptq/README.md `__ for usage. +- Add ``--cast_mxfp4_to_nvfp4`` flag to ``examples/deepseek/deepseek_v4/quantize_to_nvfp4.py`` for closed-form, bit-exact MXFP4 → NVFP4 conversion of DeepSeek V4 routed-expert weights (mirrors the GPT-OSS cast; w1/w3 share one per-tensor ``scale_2`` for the fused GEMM1). Activation ``input_scale`` still comes from ``--amax_path`` calibration. - DeepSeek PTQ (``examples/deepseek/ptq.py``) now defaults to native top-k calibration with post-hoc per-layer peer-max sync of expert ``input_quantizer.amax``; the all-experts path is preserved behind ``--calib_all_experts``. - Add NVFP4 W4A16 weight-only quantization (``w4a16_nvfp4``): FP4 weights with group_size=16, BF16 activations, no calibration forward pass required. 
Use ``mtq.W4A16_NVFP4_CFG`` or ``--qformat w4a16_nvfp4`` in ``hf_ptq.py``. vLLM deployment support is in progress. +- Add FP8 KV-cache cast variants for the partial-NVFP4 and weight-only general PTQ recipes: ``general/ptq/nvfp4_mlp_only-kv_fp8_cast``, ``general/ptq/nvfp4_experts_only-kv_fp8_cast``, ``general/ptq/nvfp4_omlp_only-kv_fp8_cast``, and ``general/ptq/nvfp4_weight_only-kv_fp8_cast``. These compose the same model-quant configs as their ``-kv_fp8`` siblings with the ``kv_fp8_cast`` unit (constant-amax FP8 KV cache, no KV calibration forward pass). - Add Megatron Core export/import mapping for Qwen3-VL (``Qwen3VLForConditionalGeneration``) vision-language models. The mapping handles the ``model.language_model.`` weight prefix used by Qwen3-VL. - Add active-MoE cost accounting for ``mtq.auto_quantize`` effective-bits search. Set ``constraints={"effective_bits": ..., "cost_model": "active_moe", "cost": {"active_moe_expert_ratio": ...}}`` to weight routed MoE expert costs by active experts per token while keeping shared experts fully counted. The ``hf_ptq.py`` AutoQuant path exposes this via ``--auto_quantize_cost_model active_moe`` and ``--auto_quantize_active_moe_expert_ratio``. - Add ``DATASET_COMBOS`` to ``modelopt.torch.utils.dataset_utils`` — single ``--dataset`` tokens that fan out to multiple registered datasets; per-entry ``num_samples`` is split evenly across the members. Initial combos: ``cnn_nemotron_v2_mix`` (``cnn_dailymail`` + ``nemotron-post-training-dataset-v2``, used by ``hf_ptq.py`` when no ``--dataset`` is provided) and ``nemotron-post-training-v3`` (the seven ``nvidia/Nemotron-*`` SFT datasets added in #1498, mirroring the `nemotron-post-training-v3 collection `_). Combo names are listed by ``get_supported_datasets()`` and surfaced in ``--dataset`` help. ``get_dataset_dataloader`` rejects inputs that mix a combo with one of its member datasets (e.g. 
``cnn_dailymail,cnn_nemotron_v2_mix``) to avoid double-sampling, and ``get_dataset_samples`` rejects combo names so callers route through the dataloader. ``hf_ptq.py`` default ``--calib_size`` is bumped from ``512`` to ``1024`` so the total calibration sample count under the new default combo matches the previous two-dataset fallback. @@ -50,12 +53,16 @@ Changelog The legacy FSDP1 accelerate config is removed; ``llm_qat`` now documents FSDP2, DeepSpeed, and DDP backends. - The PTQ example scripts ``examples/llm_ptq/hf_ptq.py``, ``examples/llm_ptq/multinode_ptq.py`` and ``examples/megatron_bridge/quantize.py`` now derive their ``--qformat`` / ``--kv_cache_qformat`` (``--quant_cfg`` / ``--kv_cache_quant`` for Megatron-Bridge) CLI vocabularies by discovering the YAML presets under ``modelopt_recipes/configs/ptq/presets/{model,kv}/`` rather than carrying hardcoded ``QUANT_CFG_CHOICES`` / ``KV_QUANT_CFG_CHOICES`` tables. The discovery helper, alias table and ready-built ``QUANT_CFG_CHOICES`` / ``KV_QUANT_CFG_CHOICES`` mappings now live in ``modelopt.recipe.presets`` and are shared by all three scripts. Presets are loaded eagerly into a plain dict at import. Adding a new preset YAML makes it available on the CLI of all three with no script change — note this means each script now accepts every preset under those directories, not just a previously curated subset. All previously-supported short names (``int8_sq``, ``nvfp4_awq``, ``fp8_pb_wo``, ``nvfp4_mse``, ``w4a8_awq``, ``nvfp4_local_hessian``, ``fp8_pc_pt``, ``int8_wo``) keep working via a small deprecation alias table; new formats should be exposed as preset YAMLs (or, longer term, as full ``--recipe`` recipes). - Add ``configs/ptq/presets/kv/fp8_cast.yaml`` and ``configs/ptq/presets/kv/nvfp4_cast.yaml``, promoting ``fp8_cast`` / ``nvfp4_cast`` to first-class KV presets composed from the existing ``kv_fp8_cast`` / ``kv_nvfp4_cast`` unit fragments. 
The previous runtime ``use_constant_amax`` post-edit in ``hf_ptq.py`` is removed; ``use_constant_amax: true`` now lives in the YAML and is therefore authoritative. **Custom (out-of-tree) recipes that target a cast KV format must set ``use_constant_amax: true`` themselves on the ``[kv]_bmm_quantizer`` config** — in-tree recipes already do via the ``kv_*_cast`` units. +- Add DMD2 distillation for few-step diffusion models in ``examples/diffusers/fastgen/``: distill Qwen-Image into a 4/8-step student via Distribution Matching Distillation. See `examples/diffusers/fastgen/README.md `_ for details. **Bug Fixes** - In Megatron-Core only do EP amax sync for routed expert weights if ``sync_expert_weight_amax=True``. Previously EP amax sync would sync routed expert weights across EP ranks even when ``sync_expert_weight_amax`` was False. - Fix Megatron-Core HF importer to load fused ``TELayerNormColumnParallelLinear.layer_norm_weight`` from HF for GPT-family models (Qwen3 etc.) under ``--export-default-te-spec``. Importer now prefers per-context keys ``fused_input_layernorm`` / ``fused_pre_mlp_layernorm`` (fallback ``fused_norm`` for Nemotron-H backward compatibility); ``mcore_qwen.py`` provides the new rules. Without this fix, post-prune MMLU sat at chance. - Fix ONNX AutoCast ``keep_io_types=True`` sanity-check failure (``Unexpected type in I/O tensor ...``) when a network input/output is an empty tensor (a dimension of size 0). Such tensors were "fake-cast" (retyped in place) to the low precision type; because the value-info map aliases the ``graph.input``/``graph.output`` ``ValueInfoProto``, this silently changed the model's I/O type. AutoCast now inserts a real ``Cast`` for protected I/O tensors instead. +- Fix INT8 entropy calibration of fp16 ONNX models raising ``ValueError: Too many bins for data range`` on numpy >= 2.0. 
``_collect_value`` in ``modelopt.onnx.quantization.ort_patching`` now casts the histogram range endpoints to Python float so bin edges are computed in float64, instead of inheriting the fp16 dtype of an activation tensor with a small range (which collapsed the 128-bin linspace under NEP-50 promotion). +- Fix the GPT-OSS MXFP4 → NVFP4 PTQ path in ``examples/llm_ptq/hf_ptq.py`` (used with ``--cast_mxfp4_to_nvfp4``). ``get_model`` now loads native MXFP4 checkpoints (``openai/gpt-oss-*``) dequantized to BF16 ``GptOssExperts`` via ``Mxfp4Config(dequantize=True)`` on a sequential device map. This fixes a CUDA illegal-memory access during the multi-GPU dequant load and the ``NotImplementedError`` for experts type ``Mxfp4GptOssExperts`` during unified HF export (the packed-kernel experts wrapper, used when the optional ``kernels`` package is installed, is unsupported by export); ``kernels`` is no longer required. The ``--cast_mxfp4_to_nvfp4`` step now also resolves a HF Hub ID ``--pyt_ckpt_path`` to its local snapshot directory instead of failing with ``FileNotFoundError``. +- Fix ``_QuantGptOssExperts`` / ``_QuantLlama4TextExperts`` static-block NVFP4 weight calibration raising ``ValueError: Input shape has changed`` during the calibration forward. These experts quantize their weights transposed (``_transposed_quantize``); ``iter_weights_for_calibration`` now yields the same transposed view so weight-only calibration and the forward agree on the block-quant shape (and the export ``_amax`` orientation). **Deprecations** @@ -74,7 +81,7 @@ Changelog - Add Puzzletron - a new algorithm for heterogeneous pruning of LLM and VLM models. See `examples/puzzletron/README.md `_ for more details. - Added iterator interface using CalibrationDataReader in ONNX quantization workflow. - Add N:M sparse softmax support to the Triton flash attention kernel (``modelopt.torch.kernels.common.attention.triton_fa``). See `examples/llm_sparsity/attention_sparsity/README.md `_ for usage. 
-- Add skip-softmax skipping to the Triton flash attention kernel (``modelopt.torch.kernels.common.attention.triton_fa``). See `examples/llm_sparsity/attention_sparsity/README.md `_ for usage. +- Add skip-softmax skipping to the Triton flash attention kernel for both language models and video diffusion models (``modelopt.torch.kernels.common.attention.triton_fa``). See `examples/llm_sparsity/attention_sparsity/README.md `_ and `examples/diffusers/sparsity/ `_ for usage. - Add Video Sparse Attention (VSA) method for video diffusion models (``modelopt.torch.sparsity.attention_sparsity``). VSA uses 3D block tiling with a two-branch architecture for attention speedup. - Enable PTQ workflow for the Step3.5-Flash MoE model with NVFP4 W4A4 + FP8 KV cache quantization. See `modelopt_recipes/models/Step3.5-Flash/nvfp4-mlp-only.yaml `_ for more details. - Add support for vLLM fakequant reload using ModelOpt state for HF models. See `examples/vllm_serve/README.md `_ for more details. diff --git a/README.md b/README.md index b3de87a9e4f..4852bfa7884 100644 --- a/README.md +++ b/README.md @@ -26,7 +26,7 @@ Model Optimizer is also integrated with [NVIDIA Megatron-Bridge](https://github. ## Latest News -- [2026/05/27] [**End-to-end Minitron workflow for Nemotron-3-Nano-30B-A3B**](./examples/pruning/minitron/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16): Pruning + two-phase distillation + FP8 quantization achieving 1.64× vLLM throughput and 2.6× memory reduction. +- [2026/05/27] [**End-to-end optimization tutorial for Nemotron-3-Nano-30B-A3B**](./examples/megatron_bridge/tutorials/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16): Pruning + distillation (with long context extension) + FP8 quantization achieving 2.6× vLLM throughput and 2.6× memory reduction. - [2026/05/13] [**Puzzletron**](./examples/puzzletron): A new algorithm for heterogeneous pruning & NAS of LLM and VLM models. 
- [2026/04/15] Customer story: [Domyn compresses Colosseum-355B → 260B using ModelOpt's Minitron pruning + distillation](https://www.domyn.com/blog/domyn-large-the-journey-of-a-european-sovereign-ai-model-for-regulated-industries) - [2026/03/17] Customer story: [Bielik.AI builds Bielik Minitron 7B (33% smaller, 50% faster, 90% quality retained) using ModelOpt's Minitron pruning + distillation](https://bielik.ai/en/nvidia-gtc-bielik-minitron-premiere/) diff --git a/docs/source/guides/10_recipes.rst b/docs/source/guides/10_recipes.rst index ed2473c0233..7b9180c52d6 100644 --- a/docs/source/guides/10_recipes.rst +++ b/docs/source/guides/10_recipes.rst @@ -499,14 +499,22 @@ General PTQ recipes are model-agnostic and apply to any supported architecture: - NVFP4 W4A4, FP8 KV cache with data-driven calibration * - ``general/ptq/nvfp4_default-kv_nvfp4_cast`` - NVFP4 W4A4, NVFP4 KV cache with constant amax, max calibration + * - ``general/ptq/nvfp4_mlp_only-kv_fp8_cast`` + - NVFP4 for MLP layers only, FP8 KV cache with constant amax * - ``general/ptq/nvfp4_mlp_only-kv_fp8`` - NVFP4 for MLP layers only, FP8 KV cache + * - ``general/ptq/nvfp4_experts_only-kv_fp8_cast`` + - NVFP4 for MoE expert layers only, FP8 KV cache with constant amax * - ``general/ptq/nvfp4_experts_only-kv_fp8`` - NVFP4 for MoE expert layers only, FP8 KV cache * - ``general/ptq/nvfp4_experts_only-kv_fp8_layerwise`` - NVFP4 for MoE expert layers only, FP8 KV cache, layerwise calibration + * - ``general/ptq/nvfp4_omlp_only-kv_fp8_cast`` + - NVFP4 for output projection + MLP layers, FP8 KV cache with constant amax * - ``general/ptq/nvfp4_omlp_only-kv_fp8`` - NVFP4 for output projection + MLP layers, FP8 KV cache + * - ``general/ptq/nvfp4_weight_only-kv_fp8_cast`` + - NVFP4 W4A16 weight-only, FP8 KV cache with constant amax Model-specific recipes ---------------------- @@ -668,10 +676,14 @@ The ``modelopt_recipes/`` package is organized as follows: | +-- nvfp4_default-kv_fp8_cast.yaml | +-- 
nvfp4_default-kv_fp8.yaml | +-- nvfp4_default-kv_nvfp4_cast.yaml + | +-- nvfp4_mlp_only-kv_fp8_cast.yaml | +-- nvfp4_mlp_only-kv_fp8.yaml + | +-- nvfp4_experts_only-kv_fp8_cast.yaml | +-- nvfp4_experts_only-kv_fp8.yaml | +-- nvfp4_experts_only-kv_fp8_layerwise.yaml + | +-- nvfp4_omlp_only-kv_fp8_cast.yaml | +-- nvfp4_omlp_only-kv_fp8.yaml + | +-- nvfp4_weight_only-kv_fp8_cast.yaml +-- huggingface/ # Model-specific recipes | +-- / # see modelopt_recipes/huggingface/README.md | +-- / diff --git a/examples/alpamayo/0417_16rows_train_set_for_calibration_25.10.parquet b/examples/alpamayo/0417_16rows_train_set_for_calibration_25.10.parquet new file mode 100644 index 00000000000..c16c7d2b98c Binary files /dev/null and b/examples/alpamayo/0417_16rows_train_set_for_calibration_25.10.parquet differ diff --git a/examples/alpamayo/README.md b/examples/alpamayo/README.md new file mode 100644 index 00000000000..a87c85d5188 --- /dev/null +++ b/examples/alpamayo/README.md @@ -0,0 +1,72 @@ +# Quantizing Alpamayo 1 + +[Alpamayo 1](https://github.com/nvlabs/alpamayo) (formerly Alpamayo-R1) is a +~10B vision-language-action model trained by NVIDIA for autonomous vehicle +research. It takes multi-camera video and egomotion history as input and +produces a Chain-of-Causation reasoning trace plus a future driving trajectory. +See the paper, [*Alpamayo-R1: Bridging Reasoning and Action Prediction for +Generalizable Autonomous Driving in the Long +Tail*](https://arxiv.org/abs/2511.00088), and the +[nvlabs/alpamayo](https://github.com/nvlabs/alpamayo) repository for details. + +This example produces FP8, NVFP4, and mixed-precision quantized checkpoints of +Alpamayo using ModelOpt. Quantization calibration runs on a small dataset of 16 +AV clips (`0417_16rows_train_set_for_calibration_25.10.parquet`). 
+ +## Setup + +Clone Alpamayo and install it into the current environment so `alpamayo_r1` is +importable: + +```bash +git clone https://github.com/nvlabs/alpamayo # tested @ 4cda35d +pip install ./alpamayo +``` + +Follow the Alpamayo README to request access to the gated model weights and the +Physical AI AV dataset, then authenticate with `hf auth login`. + +## Usage + +`quantize.py` loads an Alpamayo checkpoint, calibrates it on the 16 clips, and +exports an HF-style quantized checkpoint. + +### FP8 / NVFP4 + +By default the script saves **fake-quantized** weights (fp16 weights plus +quantizer state) — useful for accuracy evaluation: + +```bash +python quantize.py --ckpt nvidia/Alpamayo-R1-10B --output-dir ./alpamayo-fp8 --quantize fp8 +``` + +Pass `--real-quant` to save **real-quantized** weights packed into the +low-precision storage format (NVFP4 = E2M1 nibbles + per-block FP8 scales), +which run on the hardware low-precision GEMM path: + +```bash +python quantize.py --ckpt nvidia/Alpamayo-R1-10B --output-dir ./alpamayo-nvfp4 --quantize nvfp4 --real-quant +``` + +The vision tower is always kept in high precision, and small action-projection +heads whose dimensions are not multiples of 16 are left unquantized (they break +the real-quant GEMM backends). + +### AutoQuantize (mixed precision) + +`--quantize auto` runs ModelOpt's AutoQuantize, which searches per layer between +NVFP4 and FP8 under an effective-bits budget (`--auto_quantize_bits`, default +6.5): + +```bash +python quantize.py --ckpt nvidia/Alpamayo-R1-10B --output-dir ./alpamayo-auto --quantize auto --auto_quantize_bits 6.5 +``` + +AutoQuantize chooses a per-layer format using a **gradient-based sensitivity +score**: it backpropagates a loss through the model and estimates how much each +candidate format perturbs that loss, then picks the cheapest assignment that +stays within the bit budget. 
Here the loss is the flow-matching objective — an +MSE between the action expert's predicted velocity field `v_pred` and the +target `v_target = x_1 - x_0` from a teacher-forced forward pass on the +calibration clips. Layers the loss is sensitive to keep more bits (FP8); the +rest go to NVFP4. diff --git a/examples/alpamayo/quantize.py b/examples/alpamayo/quantize.py new file mode 100644 index 00000000000..6a7b3d56f61 --- /dev/null +++ b/examples/alpamayo/quantize.py @@ -0,0 +1,652 @@ +#!/usr/bin/env python3 +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Quantize AlpamayoR1 and export as an HF-style checkpoint. 
def create_message(frames: torch.Tensor):
    """Build the chat-template message list for one sample.

    The returned conversation has three turns:

    * a ``system`` turn with a fixed instruction,
    * a ``user`` turn holding one ``image`` entry per frame, followed by a single
      ``text`` entry made of the trajectory-history placeholder and the task prompt,
    * an ``assistant`` turn whose text is ``<|cot_start|>``.

    Args:
        frames: Image frames as a 4-D tensor, ``(N, C, H, W)``. Iterating over the
            first dimension yields the per-frame tensors stored in the ``image`` entries.

    Returns:
        A list of three message dicts, each with ``role`` and ``content`` keys.

    Raises:
        AssertionError: If ``frames`` is not 4-dimensional.
    """
    assert frames.ndim == 4, f"{frames.ndim=}, expected (N, C, H, W)"

    # NOTE: we expand the padding tokens to match training, so we can directly apply native
    # processor from VLM.
    num_traj_token = 48
    hist_traj_placeholder = (
        f"<|traj_history_start|>{'<|traj_history|>' * num_traj_token}<|traj_history_end|>"
    )

    # Adjacent string literals are concatenated at compile time. A backslash continuation
    # *inside* the literal would instead embed the next source line's indentation into the
    # prompt text, so the prompt is split across literals rather than across a `\`.
    user_text = (
        f"{hist_traj_placeholder}output the chain-of-thought reasoning of the "
        "driving process, then output the future trajectory."
    )

    image_entries = [{"type": "image", "image": frame} for frame in frames]

    return [
        {
            "role": "system",
            "content": [
                {
                    "type": "text",
                    "text": "You are a driving assistant that generates safe and accurate actions.",
                }
            ],
        },
        {
            "role": "user",
            "content": image_entries + [{"type": "text", "text": user_text}],
        },
        {
            "role": "assistant",
            "content": [
                {
                    "type": "text",
                    "text": "<|cot_start|>",
                }
            ],
        },
    ]
+ The VLM runs in a single non-sampling forward pass (with ``<|traj_future_start|>`` + appended to the prompt) to build the prompt KV cache; the expert then runs once + on a linearly-interpolated noisy action and returns the predicted velocity field. + + Args: + data: dict with ``tokenized_data`` (input_ids + other processor outputs), + ``ego_history_xyz``, ``ego_history_rot``, ``ego_future_xyz``, + ``ego_future_rot``. + + Returns: + dict with keys ``v_pred`` and ``v_target``, both shape + ``(b,n_diffusion_tokens, action_dim)``. Callers compute MSE between them. + """ + ego_history_xyz = data["ego_history_xyz"] + ego_history_rot = data["ego_history_rot"] + ego_future_xyz = data["ego_future_xyz"] + ego_future_rot = data["ego_future_rot"] + b, n_traj_group, _, _ = ego_history_xyz.shape + assert n_traj_group == 1, "Only one trajectory group is supported." + + tokenized_data = dict(data["tokenized_data"]) + input_ids = tokenized_data.pop("input_ids") + traj_data_vlm = { + "ego_history_xyz": ego_history_xyz, + "ego_history_rot": ego_history_rot, + } + input_ids = self.fuse_traj_tokens(input_ids, traj_data_vlm) + device = input_ids.device + + # Append <|traj_future_start|> so the expert attends through the full prompt. 
+ traj_future_start_id = self.tokenizer.convert_tokens_to_ids( + to_special_token("traj_future_start") + ) + start_col = torch.full( + (input_ids.shape[0], 1), + traj_future_start_id, + dtype=input_ids.dtype, + device=device, + ) + input_ids = torch.cat([input_ids, start_col], dim=1) + if "attention_mask" in tokenized_data and tokenized_data["attention_mask"] is not None: + am = tokenized_data["attention_mask"] + tokenized_data["attention_mask"] = torch.cat( + [am, torch.ones((am.shape[0], 1), dtype=am.dtype, device=am.device)], dim=1 + ) + + vlm_outputs = self.vlm( + input_ids=input_ids, + use_cache=True, + return_dict=True, + **tokenized_data, + ) + prompt_cache = vlm_outputs.past_key_values + prefill_seq_len = prompt_cache.get_seq_length() + rope_deltas = self.vlm.model.rope_deltas + + n_diffusion_tokens = self.action_space.get_action_space_dims()[0] + offset = torch.full((b,), prefill_seq_len, device=device, dtype=torch.long) + + position_ids = torch.arange(n_diffusion_tokens, device=device) + position_ids = einops.repeat(position_ids, "l -> 3 b l", b=b).clone() + delta = rope_deltas + offset[:, None] + position_ids += delta.to(position_ids.device) + + # No padding between prompt cache and action block: full attention mask. + attention_mask = torch.zeros( + (b, 1, n_diffusion_tokens, prefill_seq_len + n_diffusion_tokens), + dtype=torch.float32, + device=device, + ) + + forward_kwargs = {} + if self.config.expert_non_causal_attention: + forward_kwargs["is_causal"] = False + + # Build flow-matching target: x_1 = GT action, x_0 ~ N(0, I). 
+ x_1 = self.action_space.traj_to_action( + traj_history_xyz=ego_history_xyz[:, 0], + traj_history_rot=ego_history_rot[:, 0], + traj_future_xyz=ego_future_xyz[:, 0], + traj_future_rot=ego_future_rot[:, 0], + ) # (b,n_diffusion_tokens, 2) + x_1 = x_1.to(device=device, dtype=torch.float32) + + x_0 = torch.randn_like(x_1) + t = torch.rand(b, 1, 1, device=device, dtype=x_1.dtype) + x_t = (1.0 - t) * x_0 + t * x_1 + v_target = x_1 - x_0 + + # Cast to action-module dtype to match action_in_proj / expert weights. + proj_dtype = next(self.action_in_proj.parameters()).dtype + x_t_cast = x_t.to(dtype=proj_dtype) + t_cast = t.to(dtype=proj_dtype) + + future_token_embeds = self.action_in_proj(x_t_cast, t_cast) + if future_token_embeds.dim() == 2: + future_token_embeds = future_token_embeds.view(b, n_diffusion_tokens, -1) + + expert_out = self.expert( + inputs_embeds=future_token_embeds, + position_ids=position_ids, + past_key_values=prompt_cache, + attention_mask=attention_mask, + use_cache=True, + **forward_kwargs, + ) + prompt_cache.crop(prefill_seq_len) + last_hidden = expert_out.last_hidden_state[:, -n_diffusion_tokens:] + v_pred = self.action_out_proj(last_hidden).view(b, *self.action_space.get_action_space_dims()) + + return {"v_pred": v_pred.to(torch.float32), "v_target": v_target} + + +def patch_teacher_forced_flow_loss_forward() -> None: + """Attach teacher_forced_flow_loss_forward to AlpamayoR1 if missing. + + The public OSS AlpamayoR1 (github.com/nvlabs/alpamayo) does not define this + method; it exists only on the internal training fork. The body ported above + is the calibration path used by auto_quantize_model. 
def make_joint_calibration_forward_loop(
    *,
    clip_ids: list[str],
    processor,
    t0_us: int,
    top_p: float,
    temperature: float,
    max_generation_length: int,
    calibration_traj_samples: int,
    device: str,
):
    """
    Create a forward loop that calibrates through the VLM-rollout sampling path.

    Rather than feeding text-only batches, the returned loop loads each clip,
    tokenizes its chat-template prompt and calls
    ``sample_trajectories_from_data_with_vlm_rollout`` so quantizers along the
    rollout path (vlm/expert/diffusion-related modules) see representative
    activations.

    Args:
        clip_ids: Clip ids to run calibration on, one forward pass per clip.
        processor: Object exposing ``apply_chat_template`` used for tokenization.
        t0_us: Value forwarded to ``load_physical_aiavdataset`` as ``t0_us``.
        top_p: Forwarded to the rollout call.
        temperature: Forwarded to the rollout call.
        max_generation_length: Forwarded to the rollout call.
        calibration_traj_samples: Forwarded to the rollout call as ``num_traj_samples``.
        device: Device the per-clip inputs are moved to via ``to_device``.

    Returns:
        A callable taking the model and running one calibration pass over ``clip_ids``.
    """
    # The sampling arguments are the same for every clip, so assemble them once.
    rollout_kwargs = {
        "top_p": top_p,
        "temperature": temperature,
        "num_traj_samples": calibration_traj_samples,
        "max_generation_length": max_generation_length,
    }

    def _load_clip_inputs(clip_id):
        # Load one clip and turn it into the dict the rollout entry point consumes.
        sample = load_physical_aiavdataset(clip_id, t0_us=t0_us)
        chat = create_message(sample["image_frames"].flatten(0, 1))
        tokenized = processor.apply_chat_template(
            chat,
            tokenize=True,
            add_generation_prompt=False,
            continue_final_message=True,
            return_dict=True,
            return_tensors="pt",
        )
        batch = {
            "tokenized_data": tokenized,
            "ego_history_xyz": sample["ego_history_xyz"],
            "ego_history_rot": sample["ego_history_rot"],
        }
        return to_device(batch, device)

    def _calibration_loop(runtime_model):
        runtime_model.eval()
        with torch.no_grad():
            for clip_id in tqdm(clip_ids, desc="Calibration"):
                batch = _load_clip_inputs(clip_id)
                with torch.autocast("cuda", dtype=torch.float16):
                    runtime_model.sample_trajectories_from_data_with_vlm_rollout(
                        data=batch,
                        **rollout_kwargs,
                    )

    return _calibration_loop
def quantize_model(model, args, tokenizer=None, calibration_forward_loop=None):
    """
    Quantize a PyTorch model using ModelOpt post-training quantization (PTQ).

    Calibration is driven by ``calibration_forward_loop`` when given; otherwise a
    forward loop is built from tokenizer-based calibration data.

    Supported quantization formats:
        - fp8: 8-bit floating point quantization
        - nvfp4: 4-bit NVIDIA floating point quantization

    Args:
        model: PyTorch model to quantize. Must be in evaluation mode.
        args: Command line arguments containing ``quant_format``. An optional
            truthy ``real_quant`` attribute also enables the Linear-shape
            exclusions below.
        tokenizer: Hugging Face tokenizer for creating calibration data.
            Required only when ``calibration_forward_loop`` is not provided.
        calibration_forward_loop: Optional callable taking ``model`` and running
            calibration forward passes. Use this for non-text modules whose
            forward signature is not compatible with dataset_utils batches.

    Returns:
        Quantized model

    Raises:
        ValueError: If neither ``tokenizer`` nor ``calibration_forward_loop`` is given.
        RuntimeError: If ``args.quant_format`` is not a supported format.
    """
    if calibration_forward_loop is None and tokenizer is None:
        raise ValueError("tokenizer must be provided when calibration_forward_loop is None")

    # Resolve the format before building any calibration data, so an unsupported
    # format fails immediately instead of after the dataloader has been created.
    if args.quant_format == "fp8":
        quant_cfg = copy.deepcopy(mtq.FP8_DEFAULT_CFG)
    elif args.quant_format == "nvfp4":
        quant_cfg = copy.deepcopy(mtq.NVFP4_DEFAULT_CFG)
    else:
        raise RuntimeError(f"Unsupported quantization format: {args.quant_format!r}")

    # Create calibration forward loop. For standard text models we can build
    # it from tokenizer-based data, but vision modules often need custom args.
    if calibration_forward_loop is None:
        calib_dataloader = get_dataset_dataloader(
            tokenizer=tokenizer,
            batch_size=32,
            num_samples=512,
            device="cuda:0",
        )
        calibrate_loop = create_forward_loop(dataloader=calib_dataloader)
    else:
        calibrate_loop = calibration_forward_loop

    # Keep the vision tower in high precision. Pass a non-NVFP4 cfg (num_bits=8) with
    # enable=False, not just enable=False: an NVFP4-typed QuantConv3d routes to a JIT
    # implicit-GEMM CUDA kernel (needs CUDA_HOME) even when disabled.
    quant_cfg["quant_cfg"].append(
        {"quantizer_name": "*vlm.model.visual*", "enable": False, "cfg": {"num_bits": 8}}
    )

    if args.quant_format == "nvfp4" or getattr(args, "real_quant", False):
        # Keep Linear layers whose in/out features aren't multiples of 16 in high precision:
        # they break the real-quant GEMM backends (NVFP4 block packing, FP8 torch._scaled_mm).
        # In AlpamayoR1 these are the small action-projection heads, so the impact is negligible.
        for _name, _module in model.named_modules():
            if isinstance(_module, torch.nn.Linear) and (
                _module.in_features % 16 != 0 or _module.out_features % 16 != 0
            ):
                quant_cfg["quant_cfg"].append({"quantizer_name": f"{_name}.*", "enable": False})

    model = mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop)

    print("================== quantize_model summary ==================")
    mtq.print_quant_summary(model)

    return model
+ + Searches per-layer across [NVFP4_DEFAULT_CFG, FP8_DEFAULT_CFG] under the + effective-bits budget in args.auto_quantize_bits. Calibration runs the + teacher-forced flow-matching forward (teacher_forced_flow_loss_forward) on + the calibration clips; the MSE between v_pred and v_target is the search loss. + + Args: + model: PyTorch model to quantize. Must be in eval mode. + args: Namespace with `auto_quantize_bits` (float). + clip_ids: Iterable of clip_ids for calibration. + processor: HF processor used for chat-template tokenization. + t0_us: Trajectory anchor timestamp passed to load_physical_aiavdataset. + device: Device to place calibration tensors on. + + Returns: + Quantized model (the search_state from mtq.auto_quantize is discarded). + """ + + def _one_epoch(): + for clip_id in clip_ids: + data = load_physical_aiavdataset(clip_id, t0_us=t0_us) + messages = create_message(data["image_frames"].flatten(0, 1)) + inputs = processor.apply_chat_template( + messages, + tokenize=True, + add_generation_prompt=False, + continue_final_message=True, + return_dict=True, + return_tensors="pt", + ) + model_inputs = { + "tokenized_data": inputs, + "ego_history_xyz": data["ego_history_xyz"], + "ego_history_rot": data["ego_history_rot"], + "ego_future_xyz": data["ego_future_xyz"], + "ego_future_rot": data["ego_future_rot"], + } + yield to_device(model_inputs, device) + + class _ReusableLoader: + """Re-iterable wrapper so modelopt can run calibration + scoring passes.""" + + def __iter__(self): + return _one_epoch() + + data_loader = _ReusableLoader() + + def forward_step(runtime_model, data): + with torch.autocast("cuda", dtype=torch.bfloat16): + out = runtime_model.teacher_forced_flow_loss_forward(data=data) + v_pred, v_target = out["v_pred"], out["v_target"] + print( + f"[autoquant-fwd] v_pred: finite={torch.isfinite(v_pred).all().item()} " + f"min={v_pred.min().item():.4g} max={v_pred.max().item():.4g} " + f"abs_mean={v_pred.abs().mean().item():.4g} | " + f"v_target: 
def main():
    """Quantize an AlpamayoR1 checkpoint and save it as an HF-style checkpoint.

    Steps:
      1. Parse CLI arguments and read calibration clip ids from the parquet file
         (resolved relative to this script's directory), truncated to ``--limit``.
      2. Load the model in fp16 on CUDA and build its processor.
      3. Quantize: ``--quantize auto`` runs ``auto_quantize_model``; ``fp8`` /
         ``nvfp4`` run ``quantize_model`` with the joint VLM-rollout calibration loop.
      4. Optionally pack the weights with ``mtq.compress`` (``--real-quant``).
      5. Save the model, processor, config and ``hf_quant_config.json`` to
         ``--output-dir``.
    """
    ap = argparse.ArgumentParser(description="Quantize AlpamayoR1 and export as HF checkpoint")
    ap.add_argument(
        "--ckpt",
        type=str,
        default="nvidia/Alpamayo-R1-10B",
        help="HF hub id or local path of the input checkpoint",
    )
    ap.add_argument(
        "--output-dir",
        type=str,
        required=True,
        help="Directory to save the quantized HF checkpoint",
    )
    ap.add_argument(
        "--quantize",
        type=str,
        required=True,
        choices=["fp8", "nvfp4", "auto"],
        help="Quantization format",
    )
    ap.add_argument(
        "--auto_quantize_bits",
        type=float,
        default=6.5,
        help="Effective-bits budget for AutoQuantize (only used when --quantize auto)",
    )
    ap.add_argument(
        "--parquet",
        type=str,
        default="0417_16rows_train_set_for_calibration_25.10.parquet",
        help="Parquet file with clip_ids for calibration",
    )
    ap.add_argument(
        "--t0_us",
        type=int,
        default=5_100_000,
        help="Trajectory anchor timestamp passed to load_physical_aiavdataset",
    )
    ap.add_argument("--top_p", type=float, default=0.98)
    ap.add_argument("--temperature", type=float, default=0.6)
    ap.add_argument("--max_generation_length", type=int, default=256)
    ap.add_argument("--num_traj_samples", type=int, default=6)
    ap.add_argument(
        "--limit", type=int, default=16, help="How many clip_ids to use for calibration"
    )
    ap.add_argument(
        "--real-quant",
        action="store_true",
        # The help text must describe what the code does: weights are packed with
        # mtq.compress and written by the ModelOpt-patched save_pretrained.
        help="Pack weights into real low-precision storage (fp8 / NVFP4) via "
        "mtq.compress and save them with the ModelOpt-patched save_pretrained, "
        "instead of saving fake-quant fp16 weights with quantizer state.",
    )
    args = ap.parse_args()

    script_dir = Path(__file__).resolve().parent
    parquet_path = (script_dir / args.parquet).resolve()

    clip_ids = read_clip_ids_from_parquet(str(parquet_path))
    # A non-positive (or None) limit means "use every clip id".
    if args.limit is not None and args.limit > 0:
        clip_ids = clip_ids[: args.limit]
    print(f"Loaded {len(clip_ids)} clip_ids from: {parquet_path}")

    # Patch PreTrainedModel.from_pretrained / save_pretrained so ModelOpt state is saved with the
    # checkpoint (and restored when AlpamayoR1.from_pretrained later loads the quantized weights).
    mto.enable_huggingface_checkpointing()

    device = "cuda"
    print(f"Loading model from {args.ckpt!r} ...")
    model = AlpamayoR1.from_pretrained(args.ckpt, dtype=torch.float16).to(
        device=device, dtype=torch.float16
    )
    model.eval()

    processor = get_processor(model.tokenizer)

    # Quantize using existing recipe
    print(f"Quantizing model ({args.quantize}) ...")
    quantization_args = argparse.Namespace(
        quant_format=args.quantize,
        quant_algo="max",
        weight_only=False,
        auto_quantize_bits=args.auto_quantize_bits,
        real_quant=args.real_quant,
    )
    if args.quantize == "auto":
        model = auto_quantize_model(
            model,
            quantization_args,
            clip_ids=clip_ids,
            processor=processor,
            t0_us=args.t0_us,
            device=device,
        )
    else:
        # Build calibration loop
        calibration_forward_loop = make_joint_calibration_forward_loop(
            clip_ids=clip_ids,
            processor=processor,
            t0_us=args.t0_us,
            top_p=args.top_p,
            temperature=args.temperature,
            max_generation_length=args.max_generation_length,
            calibration_traj_samples=args.num_traj_samples,
            device=device,
        )
        model = quantize_model(
            model,
            quantization_args,
            calibration_forward_loop=calibration_forward_loop,
        )
    model.eval()

    # Save as HF-style checkpoint
    os.makedirs(args.output_dir, exist_ok=True)
    print(f"Saving quantized checkpoint to {args.output_dir!r} ...")

    if args.real_quant:
        # Real (packed) quantization. `mtq.compress` packs weights into the low-precision
        # storage format and enables ModelOpt's real-quant GEMM kernels. The ModelOpt-patched
        # `save_pretrained` writes the packed weights plus a `modelopt_state.pth`, which
        # `AlpamayoR1.from_pretrained` replays to reload and run real-quantized.
        #
        # NOTE: `export_hf_checkpoint` (the vLLM/TRT-LLM deployment format) isn't used here: it
        # has no `modelopt_state.pth`, so a custom model class can't reload it via from_pretrained.
        mtq.compress(model)
        model.eval()

    # Both modes save the same artifacts; write each exactly once.
    with torch.inference_mode():
        model.save_pretrained(args.output_dir)
    processor.save_pretrained(args.output_dir)
    model.config.save_pretrained(args.output_dir)

    quant_cfg = get_quant_config(model)
    with open(os.path.join(args.output_dir, "hf_quant_config.json"), "w") as f:
        json.dump(quant_cfg, f)

    print(f"Quantized checkpoint saved to {args.output_dir}")
With `--cast_mxfp4_to_nvfp4`, the per-tensor `scale_2` is +pinned to `2^(k_max - 8)` and each per-block E4M3 scale to `2^(k_j - m)` (with `m = k_max - 8`) straight +from the source E8M0 scales, so `per_block_scale * scale_2 = 2^k_j` and the NVFP4 +nibbles equal the source MXFP4 nibbles bit-for-bit (for every block whose `k_j` +lands in E4M3's representable window; the rare out-of-range block falls back to a +data-derived scale). The flag only affects routed-expert **weights** — activation +`input_scale` still comes from `${AMAX}` calibration — and the run prints a +`[cast] lossless MXFP4->NVFP4 blocks: …` summary. This mirrors the GPTOSS cast in +[`examples/llm_ptq/cast_mxfp4_to_nvfp4.py`](../llm_ptq/cast_mxfp4_to_nvfp4.py); the +V4 twist is that w1/w3 share one `scale_2` (fused GEMM1), so `k_max` is taken over +both projections. diff --git a/examples/deepseek/deepseek_v3/ptq.py b/examples/deepseek/deepseek_v3/ptq.py index 437fbdeb155..50bd87ca819 100644 --- a/examples/deepseek/deepseek_v3/ptq.py +++ b/examples/deepseek/deepseek_v3/ptq.py @@ -66,17 +66,18 @@ from modelopt.torch.utils.dataset_utils import get_dataset_dataloader from modelopt.torch.utils.distributed import ParallelState -DS_V3_PATH = Path(__file__).resolve().parent / "DeepSeek-V3/inference" -DS_V3_2_PATH = Path(__file__).resolve().parent / "DeepSeek-V3.2-Exp/inference" +# The DeepSeek-V3 / DeepSeek-V3.2-Exp inference repos are cloned into the parent +# `examples/deepseek` directory (see README), one level up from this script. 
+DEEPSEEK_DIR = Path(__file__).resolve().parent.parent +DS_V3_PATH = DEEPSEEK_DIR / "DeepSeek-V3/inference" +DS_V3_2_PATH = DEEPSEEK_DIR / "DeepSeek-V3.2-Exp/inference" if DS_V3_2_PATH.exists(): sys.path.append(str(DS_V3_2_PATH)) elif DS_V3_PATH.exists(): sys.path.append(str(DS_V3_PATH)) else: - raise ValueError( - f"DeepSeek-V3 or DeepSeek-V3.2-Exp not found in {Path(__file__).resolve().parent}" - ) + raise ValueError(f"DeepSeek-V3 or DeepSeek-V3.2-Exp not found in {DEEPSEEK_DIR}") import model as deekseep_model # noqa: E402 from kernel import act_quant, fp8_gemm # noqa: E402 diff --git a/examples/deepseek/deepseek_v4/quantize_to_nvfp4.py b/examples/deepseek/deepseek_v4/quantize_to_nvfp4.py index 6906df7ab45..0f4d0198676 100644 --- a/examples/deepseek/deepseek_v4/quantize_to_nvfp4.py +++ b/examples/deepseek/deepseek_v4/quantize_to_nvfp4.py @@ -63,6 +63,20 @@ for the same projection. If no calibrated expert exists for that projection, export fails. +Lossless weight cast (``--cast_mxfp4_to_nvfp4``): the source routed experts are +already MXFP4 (E2M1 nibbles + a power-of-two E8M0 scale per 32-element block). +By default this script dequantizes them to BF16 and re-quantizes to NVFP4 with +the calibrated per-tensor weight amax, which re-derives per-block scales from +the data and is therefore lossy. With ``--cast_mxfp4_to_nvfp4`` we instead pin +``scale_2 = 2^(k_max - 8)`` and the per-block E4M3 scale to ``2^(k_j - m)`` +straight from the source E8M0 scales, so ``per_block_scale * scale_2 = 2^k_j`` +and the NVFP4 nibbles equal the source MXFP4 nibbles bit-for-bit (for every +block whose ``k_j`` lands in E4M3's representable window). The flag only affects +routed-expert *weights*; activation ``input_scale`` still comes from +``--amax_path`` calibration. This mirrors the GPTOSS cast in +``examples/llm_ptq/cast_mxfp4_to_nvfp4.py`` (PR #1372); the V4 twist is that +w1/w3 share one ``scale_2`` (fused GEMM1), so ``k_max`` is taken over both. 
+ Usage (single compute node, CPU-default; dequant+requant math is cheap relative to shard I/O): @@ -91,6 +105,17 @@ from modelopt.torch.quantization.qtensor import MXFP4QTensor, NVFP4QTensor +# Closed-form MXFP4 -> NVFP4 numerics shared with the GPT-OSS cast (PR #1372). +from modelopt.torch.quantization.utils.numeric_utils import ( + E2M1_MAX, + E4M3_KMAX, + E4M3_KMIN, + E4M3_MAX, + E8M0_BIAS, + mxfp4_to_nvfp4_global_amax, + mxfp4_to_nvfp4_per_block_amax, +) + # Routed-expert weights in regular MoE layers. MTP experts remain in source format. _EXPERT_WEIGHT_RE = re.compile(r"^layers\.\d+\.ffn\.experts\.\d+\.w[123]\.weight$") _EXPERT_PROJ_RE = re.compile(r"^(?Players\.\d+\.ffn\.experts)\.\d+\.w[123]$") @@ -233,6 +258,98 @@ def _quantize_weight_nvfp4( return q_tensor._quantized_data, weight_scale, weight_scale_2, synthesized +# --------------------------------------------------------------------------- +# Lossless MXFP4 -> NVFP4 weight cast (``--cast_mxfp4_to_nvfp4``). +# +# NVFP4 uses the same E2M1 nibble grid as MXFP4 with 16-element blocks and a +# two-level scale ``per_block_scale (E4M3) * scale_2 (fp32)``. Pinning +# ``scale_2 = 2^m`` (``m = k_max - 8``) and ``per_block_scale = 2^(k_j - m)`` +# makes ``per_block_scale * scale_2 = 2^k_j`` exactly, so each NVFP4 nibble +# equals the source MXFP4 nibble verbatim — bit-exact for every block whose +# ``k_j`` lands in E4M3's window (``k_max - k_j <= 17``). The closed-form +# per-block amax and the format constants are reused from the GPT-OSS cast +# (``cast_mxfp4_to_nvfp4``, PR #1372); the V4 twist is that w1/w3 share one +# ``scale_2`` (fused GEMM1), so ``k_max`` is taken over both projections. 
+# --------------------------------------------------------------------------- +_NVFP4_BLOCK = 16 # NVFP4 block size (elements) +_MXFP4_BYTES_PER_BLOCK = 16 # 32 E2M1 nibbles packed 2-per-byte + + +def _kmax_from_mxfp4_scale(mxfp4_scale: torch.Tensor, device: str = "cpu") -> int: + """Largest non-zero E8M0 exponent ``k_j = e8m0 - 127`` (0 if all-zero). + + Delegates to the GPT-OSS cast's ``k_max`` logic, which excludes the + all-zero sentinel (``e8m0 == 0`` => ``k == -127``). + """ + e8m0 = mxfp4_scale.to(device).contiguous().view(torch.uint8) + return mxfp4_to_nvfp4_global_amax(e8m0)[1]["k_max"] + + +def _build_w13_kmax_overrides(f, expert_weight_keys: list[str], device: str) -> dict[str, int]: + """Shared ``k_max`` per w1/w3 pair so the fused GEMM1 gets one ``scale_2``.""" + groups: dict[str, dict[str, str]] = defaultdict(dict) + for key in expert_weight_keys: + expert_path = key[: -len(".weight")] + base, proj = expert_path.rsplit(".", 1) + if proj in {"w1", "w3"}: + groups[base][proj] = expert_path + + overrides: dict[str, int] = {} + for paths in groups.values(): + if "w1" not in paths or "w3" not in paths: + continue + k1 = _kmax_from_mxfp4_scale(f.get_tensor(paths["w1"] + ".scale"), device) + k3 = _kmax_from_mxfp4_scale(f.get_tensor(paths["w3"] + ".scale"), device) + shared = max(k1, k3) + overrides[paths["w1"]] = shared + overrides[paths["w3"]] = shared + return overrides + + +def _quantize_weight_nvfp4_lossless( + mxfp4_weight: torch.Tensor, + mxfp4_scale: torch.Tensor, + k_max: int, + device: str, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int]: + """Closed-form bit-exact MXFP4 -> NVFP4 weight conversion. + + Pins ``scale_2 = 2^(k_max - 8)`` and the per-block E4M3 scale to + ``2^(k_j - m)`` so the NVFP4 nibbles equal the source MXFP4 nibbles for + every in-range block. ``k_max`` is shared across w1/w3 (fused GEMM1), so it + is passed in rather than derived per tensor. 
The closed-form per-block amax + (``6 * 2^k_j`` in range, data-derived out of range) is independent of + ``k_max``, so we reuse the GPT-OSS helper directly. Returns + ``(packed, weight_scale, weight_scale_2, n_blocks, n_lossless)``. + """ + bf16 = _dequantize_mxfp4_to_bf16(mxfp4_weight, mxfp4_scale, device) + e8m0 = mxfp4_scale.to(bf16.device).contiguous().view(torch.uint8) # (out, nblk32) + packed = mxfp4_weight.to(bf16.device).contiguous().view(torch.uint8) # (out, nblk32*16) + blocks = packed.view(*packed.shape[:-1], e8m0.shape[-1], _MXFP4_BYTES_PER_BLOCK) + per_block_amax = mxfp4_to_nvfp4_per_block_amax(blocks, e8m0) # (out, nblk16) fp32 + + m = k_max - E4M3_KMAX + weight_scale_2 = torch.tensor(2.0**m, dtype=torch.float32, device=bf16.device).reshape(()) + per_block_scale = ( + (per_block_amax / (E2M1_MAX * weight_scale_2)) + .clamp(min=2**-9, max=E4M3_MAX) + .to(torch.float8_e4m3fn) + ) + + # Lossless accounting against the (possibly shared) k_max. A block is lossy + # only if k_max - k_j > 17; all-zero blocks (e8m0 == 0) reconstruct to 0 + # regardless of scale and so are always lossless. 
+ k = e8m0.to(torch.int32) - E8M0_BIAS + lossless = (k >= (k_max - (E4M3_KMAX - E4M3_KMIN))) | (e8m0 == 0) + n_blocks = k.numel() + n_lossless = int(lossless.sum().item()) + + q_tensor, weight_scale, _ = NVFP4QTensor.quantize( + bf16, _NVFP4_BLOCK, per_block_scale, weight_scale_2, try_tensorrt=False + ) + return q_tensor._quantized_data, weight_scale, weight_scale_2, n_blocks, n_lossless + + def _build_w13_weight_amax_overrides( f, expert_weight_keys: list[str], @@ -279,6 +396,7 @@ def convert_shard( input_fallback: dict[str, torch.Tensor], device: str, stats: dict[str, int], + cast: bool = False, ) -> tuple[list[str], list[str]]: """Rewrite one HF-style shard and return index deltas.""" out: dict[str, torch.Tensor] = {} @@ -289,9 +407,16 @@ def convert_shard( all_keys = list(f.keys()) expert_weight_keys = [k for k in all_keys if _EXPERT_WEIGHT_RE.match(k)] expert_weight_key_set = set(expert_weight_keys) - w13_weight_amax, w13_synth_paths = _build_w13_weight_amax_overrides( - f, expert_weight_keys, amax, device - ) + if cast: + # Closed-form weight cast derives scales from the source E8M0 + # exponents, not from calibrated weight amax. w1/w3 share k_max. 
+ w13_kmax = _build_w13_kmax_overrides(f, expert_weight_keys, device) + w13_weight_amax, w13_synth_paths = {}, set() + else: + w13_kmax = {} + w13_weight_amax, w13_synth_paths = _build_w13_weight_amax_overrides( + f, expert_weight_keys, amax, device + ) scale_siblings = { k[: -len(".weight")] + ".scale" for k in expert_weight_keys @@ -335,9 +460,22 @@ def convert_shard( w = f.get_tensor(key) s = f.get_tensor(scale_key) - packed, weight_scale, weight_scale_2, weight_synth = _quantize_weight_nvfp4( - w, s, weight_amax, device=device - ) + if cast: + k_max = w13_kmax.get(expert_path) + if k_max is None: + k_max = _kmax_from_mxfp4_scale(s, device) + packed, weight_scale, weight_scale_2, n_blk, n_lossless = ( + _quantize_weight_nvfp4_lossless(w, s, k_max, device) + ) + weight_synth = False + stats["cast_blocks_total"] += n_blk + stats["cast_blocks_lossless"] += n_lossless + if n_lossless < n_blk: + stats[f"cast_oor_tensors_{block_kind}"] += 1 + else: + packed, weight_scale, weight_scale_2, weight_synth = _quantize_weight_nvfp4( + w, s, weight_amax, device=device + ) input_scale = _amax_to_nvfp4_scale_2(input_amax).to(weight_scale_2.device) out[key] = packed.cpu() @@ -607,6 +745,17 @@ def main(): action="store_true", help="replace an existing non-empty output checkpoint directory", ) + p.add_argument( + "--cast_mxfp4_to_nvfp4", + action="store_true", + help=( + "losslessly cast the source MXFP4 routed-expert weights to NVFP4 " + "(pin scale_2 = 2^(k_max-8) and per-block scale = 2^(k_j-m) from the " + "source E8M0 scales) instead of dequant + re-quant with calibrated " + "weight amax. Only affects weights; input_scale still comes from " + "--amax_path calibration." 
+ ), + ) args = p.parse_args() _validate_paths(args.source_ckpt, args.output_ckpt) @@ -639,6 +788,7 @@ def main(): input_fallback, args.device, stats, + args.cast_mxfp4_to_nvfp4, ) shard_updates[src.name] = (added, removed) @@ -647,6 +797,12 @@ def main(): for k in sorted(stats.keys()): _log(f" {k:40s} {stats[k]}") + if args.cast_mxfp4_to_nvfp4: + tot = stats.get("cast_blocks_total", 0) + loss = stats.get("cast_blocks_lossless", 0) + pct = 100.0 * loss / tot if tot else 100.0 + _log(f"[cast] lossless MXFP4->NVFP4 blocks: {loss}/{tot} ({pct:.4f}%)") + quantized: set[str] = set() for _added, _removed in shard_updates.values(): for a in _added: diff --git a/examples/diffusers/README.md b/examples/diffusers/README.md index 56d9eb481bf..8fc32d7a324 100644 --- a/examples/diffusers/README.md +++ b/examples/diffusers/README.md @@ -314,7 +314,7 @@ By following these steps, your PEFT LoRA model should be efficiently quantized u ## Sparse Attention (Skip-Softmax) -Skip-softmax sparse attention skips KV tiles whose attention scores are negligible during the softmax computation, reducing FLOPs without retraining. An exponential model (`scale_factor = a * exp(b * target_sparsity)`) is calibrated once, then the target sparsity can be adjusted at runtime without recalibration. +Skip-softmax sparse attention skips KV tiles whose attention scores are negligible during the softmax computation, reducing FLOPs without retraining. An exponential model (`scale_factor = a * exp(b * target_sparsity)`) is calibrated once, then the target sparsity can be adjusted at runtime without recalibration. Calibrated coefficients can be exported as a Hugging Face checkpoint (embedded in each component's `config.json` under `sparse_attention_config`) consumed directly by TRT-LLM's `SkipSoftmaxAttentionConfig.resolve_for_target_sparsity` — no extra conversion needed downstream. ### Getting Started @@ -358,10 +358,11 @@ The 14B model automatically sparsifies both `transformer` and `transformer_2`. 
```bash -# 5B/14B model +# 5B/14B model — calibrate and export a TRT-LLM-ready checkpoint python sparsity/wan22_skip_softmax.py \ --model-path Wan-AI/Wan2.2-T2V-A14B-Diffusers|Wan-AI/Wan2.2-TI2V-5B-Diffusers \ --calibrate --target-sparsity 0.5 --calib-size 4 \ + --export-dir /path/to/wan22-skip-softmax-ckpt \ --prompt "A sunset over mountains" --output out.mp4 ``` diff --git a/examples/diffusers/fastgen/README.md b/examples/diffusers/fastgen/README.md new file mode 100644 index 00000000000..b3c3bc780f9 --- /dev/null +++ b/examples/diffusers/fastgen/README.md @@ -0,0 +1,180 @@ +# DMD2 distillation for Qwen-Image + +Distill [`Qwen/Qwen-Image`](https://huggingface.co/Qwen/Qwen-Image) into a **few-step +generator** with DMD2 (Distribution Matching Distillation). The distilled student +produces images in as few as **1–4 sampling steps** while matching the base model's +output distribution. Built on `modelopt.torch.fastgen` and NeMo AutoModel's +[`TrainDiffusionRecipe`](https://github.com/NVIDIA-NeMo/Automodel/blob/main/nemo_automodel/recipes/diffusion/train.py). + +> [!NOTE] +> Qwen-Image is a third-party model with its own license terms. Review the +> [Qwen-Image model card](https://huggingface.co/Qwen/Qwen-Image) before downloading or +> redistributing weights or derivatives. + +## How DMD2 works + +DMD2 trains three networks together: + +| Model | Role | +|---|---| +| **Student** | the few-step generator you keep | +| **Fake-score** | a diffusion model that tracks the *student's* current output distribution | +| **Teacher** | the frozen base Qwen-Image model (the *target* distribution) | + +The distribution-matching gradient pushes the student toward the teacher and away from +the fake-score. 
Training alternates between two phases, controlled by `student_update_freq`: + +```text +each step: + if step % student_update_freq == 0: # student phase + update the student (distribution-matching [+ optional GAN] loss) + update the student EMA + else: # fake-score phase + update the fake-score network to track the student +``` + +The canonical config additionally enables **CFG** (classifier-free guidance on the +teacher) and a lightweight **GAN** branch (a discriminator head on a teacher feature +block, plus an R1 gradient penalty) for sharper samples. + +## Install + +From the repo root: + +```bash +pip install -e ".[all]" # ModelOpt + torch + diffusers +pip install -r examples/diffusers/fastgen/requirements.txt # nemo_automodel +``` + +`nemo_automodel[diffusion]` pulls in diffusers, accelerate, and the `TrainDiffusionRecipe` +this example subclasses. + +## Quick start — mock data (no dataset needed) + +The smoke config feeds random tensors at Qwen-Image's shapes, so it runs end-to-end with +**no dataset to prepare** — it exercises the full training loop (FSDP2 sharding, phase +alternation, checkpoint save/restore). Use it to validate your environment: + +```bash +torchrun --nproc-per-node=8 \ + examples/diffusers/fastgen/dmd2_finetune.py \ + --config examples/diffusers/fastgen/configs/dmd2_qwen_image_smoke.yaml +``` + +Scale `--fsdp.dp_size` to your GPU count. You'll see alternating `phase=student` / +`phase=fake_score` log lines and a checkpoint written at the last step. + +> The mock loop validates wiring only — it does **not** produce meaningful images. For +> that, train on real data (below). + +## Real-data training + +`configs/dmd2_qwen_image.yaml` is the canonical config: 4-step student, CFG, and the +GAN + R1 branch, trained on a preprocessed latent cache. Before launching, provide: + +- **A preprocessed Qwen-Image latent cache** — set `data.dataloader.cache_dir`. 
+- **A precomputed negative-prompt embedding** (required for CFG) — set + `data.dataloader.negative_prompt_embedding_path`. +- **An output directory** — set `checkpoint.checkpoint_dir`. + +The model path defaults to `Qwen/Qwen-Image`; point it at a local snapshot to avoid +re-downloading on every job. Then: + +```bash +torchrun --nproc-per-node=8 \ + examples/diffusers/fastgen/dmd2_finetune.py \ + --config examples/diffusers/fastgen/configs/dmd2_qwen_image.yaml \ + --step_scheduler.max_steps=5000 +``` + +Any `DMDConfig` field can be overridden on the CLI (e.g. `--dmd2.guidance_scale=3.5`). + +### Checkpoints & resuming + +Checkpoints land under `checkpoint.checkpoint_dir`. Alongside the student, the recipe +saves the DMD2 sidecars needed to resume exactly: the fake-score model + optimizer, the +student EMA (`ema_shadow.pt`), and the DMD iteration counter (`dmd_state.pt`). With +`restore_from: LATEST` a re-launch auto-resumes from the newest checkpoint; pin a +specific one with `--checkpoint.restore_from=epoch_0_step_500`. + +## Inference + +After training, sample from the distilled student. 
The pipeline loads your consolidated +student transformer plus the base Qwen-Image VAE / text encoder / tokenizer: + +```python +import torch +from inference_dmd2_qwen_image import QwenImageDMDInferencePipeline + +pipe = QwenImageDMDInferencePipeline.from_pretrained( + student_path="/path/to/checkpoint/epoch_0_step_500/model/consolidated", + base_pipeline_path="Qwen/Qwen-Image", + ema_path=None, # or ".../ema_shadow.pt" to sample the EMA weights + torch_dtype=torch.bfloat16, +).to("cuda") + +image = pipe( + prompt="a small red cube on a white table", + num_inference_steps=4, # match the student_sample_steps you trained with + height=1024, width=1024, + generator=torch.Generator("cuda").manual_seed(42), +).images[0] +image.save("sample.png") +``` + +Or run the bundled CLI for a quick check: + +```bash +python examples/diffusers/fastgen/inference_dmd2_qwen_image.py \ + --student_path /path/to/checkpoint/.../model/consolidated \ + --base_pipeline_path Qwen/Qwen-Image \ + --prompt "a small red cube on a white table" \ + --height 512 --width 512 +``` + +Set `num_inference_steps` to the number of steps the student was trained for +(`dmd2.student_sample_steps` — e.g. 4 for the canonical config, or 1 for a single-step +student). + +## Config reference + +| Section | Key | Role | +|---|---|---| +| `model` | `pretrained_model_name_or_path` | Qwen-Image HF id or local snapshot. | +| `model` | `mode` | `finetune` — loads the pretrained weights. | +| `step_scheduler` | `global_batch_size`, `local_batch_size`, `max_steps`, `ckpt_every_steps`, `log_every` | Standard AutoModel scheduling knobs. | +| `dmd2` | `recipe_path` | Built-in fastgen recipe to hydrate `DMDConfig` from (`general/distillation/dmd2_qwen_image`). | +| `dmd2` | `pipeline_plugin` | `qwen_image` — selects `QwenImageDMDPipeline` (2×2 patch packing / img_shapes). | +| `dmd2` | `student_sample_steps` | Number of student sampling steps (e.g. 4). 
| +| `dmd2` | `guidance_scale` | CFG strength on the teacher (`null` disables CFG; requires a negative-prompt embedding when set). | +| `dmd2` | `gan_loss_weight_gen`, `gan_r1_reg_weight`, `gan_feature_indices`, … | GAN branch (set `gan_loss_weight_gen: 0` to disable). | +| `dmd2` | `fake_score_lr`, `discriminator_lr` | Separate LRs for the fake-score / discriminator optimizers. | +| `dmd2` | `sample_t_cfg`, `ema` | Timestep sampling + student EMA settings. | +| `optim` | `learning_rate`, `optimizer.*` | Student AdamW knobs. | +| `fsdp` | `dp_size`, `tp_size`, `activation_checkpointing`, … | FSDP2 parallelism (set `dp_size` to your GPU count). | +| `data` | `dataloader._target_`, `cache_dir`, `negative_prompt_embedding_path` | Real latent cache vs. `build_mock_t2i_dataloader`. | +| `checkpoint` | `checkpoint_dir`, `model_save_format`, `restore_from` | Output dir, save format, resume behavior. | + +## Troubleshooting + +**`CUDA out of memory`.** Training holds three Qwen-Image transformers (student, teacher, +and fake-score) plus optimizer state. Shard across more GPUs (raise `--fsdp.dp_size`), +enable `--fsdp.activation_checkpointing=true`, or use the mock smoke for wiring checks. + +**Loss is `NaN` on step 0.** Almost always an out-of-range timestep — confirm you haven't +overridden `dmd2.pred_type` away from `flow` (Qwen-Image is a rectified-flow model) or +changed the timestep schedule. + +**`guidance_scale is set but negative_encoder_hidden_states was not provided`.** CFG needs +a precomputed negative-prompt embedding. Set `data.dataloader.negative_prompt_embedding_path`, +or set `dmd2.guidance_scale: null` to disable CFG. + +**Dataloader yields empty batches.** Ensure your cache has at least +`local_batch_size * fsdp.dp_size` items; the distributed sampler drops incomplete batches. 
+ +## Reference + +- Fastgen library: [`modelopt/torch/fastgen/`](../../../modelopt/torch/fastgen/) +- Built-in recipe: [`modelopt_recipes/general/distillation/dmd2_qwen_image.yaml`](../../../modelopt_recipes/general/distillation/dmd2_qwen_image.yaml) +- AutoModel recipe this example subclasses: + [`nemo_automodel/recipes/diffusion/train.py`](https://github.com/NVIDIA-NeMo/Automodel/blob/main/nemo_automodel/recipes/diffusion/train.py) diff --git a/examples/diffusers/fastgen/configs/dmd2_qwen_image.yaml b/examples/diffusers/fastgen/configs/dmd2_qwen_image.yaml new file mode 100644 index 00000000000..791b0efbaa9 --- /dev/null +++ b/examples/diffusers/fastgen/configs/dmd2_qwen_image.yaml @@ -0,0 +1,173 @@ +# Qwen-Image DMD2 — canonical real-data training config. +# +# Enables the full Qwen-Image DMD2 setup: 4-step student, CFG, and the GAN + R1 +# branch. This is the real-data training config; the mock-data wiring smoke +# (no dataset required) lives in ``dmd2_qwen_image_smoke.yaml``. +# +# Launch with torchrun, scaling ``--fsdp.dp_size`` to your GPU count: +# +# torchrun --nproc-per-node=GPUS_PER_NODE \ +# examples/diffusers/fastgen/dmd2_finetune.py \ +# --config examples/diffusers/fastgen/configs/dmd2_qwen_image.yaml \ +# --step_scheduler.max_steps=5000 +# +# The data.* and checkpoint.* paths below are placeholders — point them at your +# own preprocessed latent cache and output directory before launching. + +seed: 42 + +wandb: + project: fastgen-dmd2-qwen-image + mode: online + name: qwen_image_dmd2 + +dist_env: + backend: nccl + timeout_minutes: 60 + +model: + pretrained_model_name_or_path: Qwen/Qwen-Image + mode: finetune + +step_scheduler: + # Must be divisible by local_batch_size * dp_size. For the canonical + # 32-node GB200 run below, dp_size=128 and local_batch_size=1, so GBS=128 + # gives one micro-batch per rank and no gradient accumulation. 
+ global_batch_size: 128 + local_batch_size: 1 + ckpt_every_steps: 500 + # With 40K cached samples and GBS=128, one epoch is ~313 optimizer steps. + # 16 epochs gives >=5000 optimizer steps; max_steps below stops exactly at 5000. + num_epochs: 16 + log_every: 1 + # Production placeholder; CLI override expected for the long run. + max_steps: 5000 + +# ─── DMD2 block ───────────────────────────────────────────────────────────────── +# ``recipe_path`` is required by ``_resolve_dmd_config`` in dmd2_recipe.py. +# Every actual DMDConfig knob is explicitly pinned below so this YAML is the +# single source of truth for the formal run — the recipe defaults at +# ``modelopt_recipes/general/distillation/dmd2_qwen_image.yaml`` are NOT +# consulted for anything set here. Pydantic's ``model_copy(update=...)`` is +# a shallow merge, so when overriding any ``sample_t_cfg`` or ``ema`` +# sub-field we re-list every field in that block to avoid silent drops. +dmd2: + recipe_path: general/distillation/dmd2_qwen_image + pipeline_plugin: qwen_image + qwen_image_guidance: + + # ── DMD2 method core ── + pred_type: flow # Qwen-Image is rectified flow + num_train_timesteps: # continuous t ∈ [0, 1]; QwenImageDMDPipeline forwards t verbatim + guidance_scale: 4.0 # CFG strength on teacher during the student update + student_sample_steps: 4 # 4-step student (Phase 2) + student_sample_type: ode # Euler integration when unrolling the student + # Default keeps FastGen Qwen parity: train each rung from noised real latents. + # Override with ``--dmd2.backward_simulation=true`` to no-grad unroll the + # current student from the first rung before training the selected rung. 
+ backward_simulation: false + student_update_freq: 5 # 1 student step in every 5; the other 4 are fake-score / discriminator steps + fake_score_pred_type: x0 # fake_score regresses x0; teacher/student live in flow space + + # ── GAN branch (Phase 2) ── + gan_loss_weight_gen: 0.03 # generator weight in student loss + gan_use_same_t_noise: true # share (t, noise) between generator + discriminator updates + gan_r1_reg_weight: 0.1 # R1 gradient penalty on real samples + gan_r1_reg_alpha: 0.1 # R1 EMA coefficient + + # ── Optimizer LRs (student LR comes from ``optim.learning_rate`` below) ── + fake_score_lr: 2.0e-6 + discriminator_lr: 2.0e-6 + + # ── GAN discriminator placement & dim ── + gan_feature_indices: [30] # tap transformer_blocks[30] for the feature head + gan_num_blocks: 60 # Qwen-Image has 60 transformer blocks total + gan_inner_dim: 3072 # Qwen-Image hidden_size (matches FastGen reference) + + # ── 4-step student timestep schedule ── + # Exact ``torch.linspace(max_t=0.999, 0.0, 5).tolist()``, which is also the + # inference pipeline's default schedule when no t_list is passed (see + # ``inference_dmd2_qwen_image.py:259``). Training draws t uniformly from + # t_list[:-1], so the 4 trained timesteps exactly match the 4 inference + # sample points. The earlier ``[0.999, 0.75, 0.5, 0.25, 0.0]`` was + # ``linspace(1.0, 0, 5)`` with t=1 shaved to 0.999, leaving a silent ~0.1% + # train↔inference skew on each non-endpoint timestep. + sample_t_cfg: + # ``time_dist_type`` governs the *perturbation* timestep ``t`` that gets + # sampled on every loss path — VSD perturbation in compute_student_loss + # (dmd.py:417), fake-score DSM perturbation in compute_fake_score_loss + # (dmd.py:529), and GAN/discriminator perturbation in + # compute_discriminator_loss (dmd.py:605). All three call + # ``self.sample_timesteps`` → ``sample_t`` → reads ``time_dist_type``. 
+ # + # It does NOT govern the student's *starting* timestep ``t_student``: + # under student_sample_steps > 1, ``_build_student_input`` calls + # ``sample_from_t_list`` (dmd.py:346) which samples uniformly from + # ``t_list[:-1]`` regardless of ``time_dist_type``. + # + # ``uniform`` matches FastGen's reference Qwen DMD2 config. Earlier formal + # runs used ``logitnormal`` (concentrates ``t`` toward the middle of + # [min_t, max_t]) to follow Automodel's Qwen finetune reference; we now + # default to uniform here to keep launchers parity-aligned without an + # EXTRA_ARGS override. Flip to ``logitnormal`` on the CLI for an ablation. + time_dist_type: uniform + min_t: 0.001 + max_t: 0.999 + p_mean: 0.0 + p_std: 1.0 + t_list: [0.999, 0.74925, 0.4995, 0.24975, 0.0] + + # ── Student EMA ── + ema: + decay: 0.9999 + type: constant + start_iter: 0 + fsdp2: true + mode: full_tensor + +optim: + learning_rate: 2.0e-6 + optimizer: + weight_decay: 0.01 + betas: [0.9, 0.999] + +lr_scheduler: + lr_decay_style: constant + lr_warmup_steps: 0 + min_lr: 2.0e-6 + +fsdp: + tp_size: 1 + cp_size: 1 + pp_size: 1 + dp_replicate_size: 1 + dp_size: 128 # 32 nodes × 4 GPUs/node = 128 GPUs + activation_checkpointing: true + +# Real cached latents (40K items, 1024x1024). Negative-prompt embedding for +# CFG is loaded inside the dataloader via ``negative_prompt_embedding_path``. +data: + dataloader: + _target_: nemo_automodel.components.datasets.diffusion.build_text_to_image_multiresolution_dataloader + cache_dir: /path/to/preprocessed/qwen_image_1024p + base_resolution: [1024, 1024] + batch_size: 1 + drop_last: false + shuffle: true + num_workers: 0 + negative_prompt_embedding_path: /path/to/preprocessed/qwen_image_1024p/negative_prompt_embedding.pt + +# Inference-loadable safetensors saves so checkpoints are usable without a +# secondary export pass. 
+checkpoint: + enabled: true + checkpoint_dir: /path/to/output/qwen_image_dmd2/checkpoints + model_save_format: safetensors + save_consolidated: true + v4_compatible: true + diffusers_compatible: true + # ``LATEST`` auto-resumes from the most recent checkpoint in checkpoint_dir + # if one exists, and starts fresh otherwise (safe on first launch). To pin a + # specific checkpoint, override on the CLI: + # --checkpoint.restore_from=epoch_0_step_500 + restore_from: LATEST diff --git a/examples/diffusers/fastgen/configs/dmd2_qwen_image_smoke.yaml b/examples/diffusers/fastgen/configs/dmd2_qwen_image_smoke.yaml new file mode 100644 index 00000000000..2d8f2ad3581 --- /dev/null +++ b/examples/diffusers/fastgen/configs/dmd2_qwen_image_smoke.yaml @@ -0,0 +1,109 @@ +# DMD2 on Qwen-Image — mock-data wiring smoke (NOT for real training). +# +# Feeds random tensors at Qwen-Image's shapes/dtypes via +# ``build_mock_t2i_dataloader``. Useful for end-to-end wiring tests (FSDP2, +# phase routing, checkpoint save/restore) but useless for image quality — real +# training uses ``dmd2_qwen_image.yaml`` (real cache + CFG + GAN). + +seed: 42 + +wandb: + project: fastgen-dmd2-qwen-image + mode: online + name: phase1_smoke + +dist_env: + backend: nccl + timeout_minutes: 60 + +model: + # Qwen-Image text-to-image checkpoint (HF id, or a local snapshot path to + # avoid hitting HF on every job). + pretrained_model_name_or_path: Qwen/Qwen-Image + mode: finetune + +step_scheduler: + global_batch_size: 8 + local_batch_size: 1 + ckpt_every_steps: 100 + num_epochs: 1 + log_every: 1 + # Hard cap the Phase 1 smoke at 100 optimizer steps. Flip to null for full runs. 
+ max_steps: 100 + +# ─── DMD2-specific block ──────────────────────────────────────────────────────────── +dmd2: + # Built-in fastgen recipe for Qwen-Image: pred_type=flow, num_train_timesteps=null + # (Qwen normalises t internally), logit_normal time sampling, student_update_freq=5, + # fake_score_pred_type=x0, gan_loss_weight_gen=0 (Phase 1), EMA on. + recipe_path: general/distillation/dmd2_qwen_image + + # Phase 1 overrides: + # * GAN disabled — no discriminator shipped in Phase 1. + # * CFG disabled — no negative-prompt embedding precompute yet; guidance_scale=null + # short-circuits the negative-conditioning branch inside compute_student_loss. + gan_loss_weight_gen: 0.0 + guidance_scale: + + # Phase-1-only knobs (NOT on DMDConfig — consumed directly by the recipe): + # LR for the fake-score AdamW; matches the student LR below. + fake_score_lr: 1.0e-5 + + # Explicit pipeline plugin selector. Auto-detect via model_id substring works for a + # local Qwen-Image snapshot path, but spelling it out keeps the choice visible. + pipeline_plugin: qwen_image + + # Optional guidance scalar forwarded to the transformer's ``guidance`` kwarg every + # call. The shipped ``Qwen/Qwen-Image`` checkpoint has guidance_embeds=false, so + # leave this null. Set to e.g. 3.5 only if you've fine-tuned a guidance-embed + # variant. + qwen_image_guidance: + +# Student LR + optimizer. +optim: + learning_rate: 1.0e-5 + optimizer: + weight_decay: 0.01 + betas: [0.9, 0.999] + +# Constant LR for the smoke — flip to cosine/linear once the loop is validated. +lr_scheduler: + lr_decay_style: constant + lr_warmup_steps: 0 + min_lr: 1.0e-5 + +# FSDP2 config. Scale dp_size to match the GPU count of the run. +fsdp: + tp_size: 1 + cp_size: 1 + pp_size: 1 + dp_replicate_size: 1 + dp_size: 8 + activation_checkpointing: true + +# Mock data by default — the debug target is the DMD2 loop, not the data pipeline. 
+# Swap _target_ to build_text_to_image_multiresolution_dataloader once a real +# preprocessed cache is available. +data: + dataloader: + _target_: nemo_automodel.components.datasets.diffusion.build_mock_t2i_dataloader + # Qwen-Image VAE: 16 latent channels, 8x spatial downsample. + # 256x256 image -> 32x32 latent. Must be even (2x2 patch packing). + num_channels: 16 + spatial_h: 32 + spatial_w: 32 + # Qwen2.5-VL hidden_dim = 3584 (text_encoder/config.json: hidden_size=3584). + text_seq_len: 512 + text_embed_dim: 3584 + length: 256 + num_workers: 0 + shuffle: true + +checkpoint: + enabled: true + checkpoint_dir: /path/to/output/qwen_image_dmd2_smoke/checkpoints + model_save_format: torch_save + save_consolidated: false + diffusers_compatible: false + # Set to LATEST or epoch_0_step_100 (etc.) to resume. Null = start fresh. + restore_from: diff --git a/examples/diffusers/fastgen/dmd2_finetune.py b/examples/diffusers/fastgen/dmd2_finetune.py new file mode 100644 index 00000000000..0f7936ef1bb --- /dev/null +++ b/examples/diffusers/fastgen/dmd2_finetune.py @@ -0,0 +1,38 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Entrypoint for the DMD2 Qwen-Image AutoModel example. + +Parses the YAML config + CLI overrides with AutoModel's argument parser, then hands +control to :class:`DMD2DiffusionRecipe`. 
+""" + +from __future__ import annotations + +from dmd2_recipe import DMD2DiffusionRecipe +from nemo_automodel.components.config._arg_parser import parse_args_and_load_config + + +def main( + default_config_path: str = "examples/diffusers/fastgen/configs/dmd2_qwen_image_smoke.yaml", +) -> None: + cfg = parse_args_and_load_config(default_config_path) + recipe = DMD2DiffusionRecipe(cfg) + recipe.setup() + recipe.run_train_validation_loop() + + +if __name__ == "__main__": + main() diff --git a/examples/diffusers/fastgen/dmd2_recipe.py b/examples/diffusers/fastgen/dmd2_recipe.py new file mode 100644 index 00000000000..9614d4283a5 --- /dev/null +++ b/examples/diffusers/fastgen/dmd2_recipe.py @@ -0,0 +1,1317 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""DMD2 distillation recipe built on NeMo AutoModel. + +This recipe subclasses :class:`nemo_automodel.recipes.diffusion.train.TrainDiffusionRecipe` +so it inherits AutoModel's student + optimizer + dataloader + checkpoint plumbing, then +drives ``modelopt.torch.fastgen.DMDPipeline`` (or a plugin subclass) through the +three-phase DMD2 alternation (student update / fake-score update / EMA step). + +Backbone: **Qwen-Image** (``Qwen/Qwen-Image``) — 4D ``image_latents``, +:class:`QwenImageDMDPipeline` handles 2x2 patch packing / img_shapes / +unpacking. 
Configs: ``configs/dmd2_qwen_image.yaml`` for the canonical +real-data run (4-step + CFG + GAN); ``configs/dmd2_qwen_image_smoke.yaml`` +for the mock-data wiring smoke (no dataset required). + +Launch:: + + # Mock-data wiring smoke (no real cache required). + torchrun --nproc-per-node=8 \\ + examples/diffusers/fastgen/dmd2_finetune.py \\ + --config examples/diffusers/fastgen/configs/dmd2_qwen_image_smoke.yaml + # Real-data formal training (canonical). + torchrun --nproc-per-node=8 \\ + examples/diffusers/fastgen/dmd2_finetune.py \\ + --config examples/diffusers/fastgen/configs/dmd2_qwen_image.yaml + +See ``examples/diffusers/fastgen/README.md`` for the three-phase +alternation diagram + troubleshooting notes. +""" + +from __future__ import annotations + +import json +import logging +import os +import re +import shutil +from typing import Any + +import torch +import torch.distributed as dist + +# nemo_automodel is required to run this example (installed via requirements.txt). Wrap +# the import in a clear, actionable error, but still re-raise so it fails loudly with a +# real stack — a previous gate that fell back to ``object`` silently masked missing deps +# and surfaced as a downstream ``TypeError: takes no arguments``. +try: + from nemo_automodel._diffusers.auto_diffusion_pipeline import NeMoAutoDiffusionPipeline + from nemo_automodel.recipes.diffusion.train import TrainDiffusionRecipe, is_main_process +except ImportError as exc: + raise ImportError( + "The DMD2 fastgen example requires `nemo_automodel`. 
Install the example " + "dependencies with:\n" + " pip install -r examples/diffusers/fastgen/requirements.txt" + ) from exc +from torch import nn + +import modelopt.torch.fastgen as mtf +from modelopt.torch.fastgen.config import DMDConfig +from modelopt.torch.fastgen.discriminators import Discriminator_ImageDiT +from modelopt.torch.fastgen.methods.dmd import DMDPipeline +from modelopt.torch.fastgen.plugins import qwen_image as qwen_image_plugin + +# Keys under the ``dmd2:`` YAML block that shadow fields on :class:`DMDConfig`. The +# recipe deep-merges these on top of the loaded built-in recipe so users can tweak DMD2 +# hyperparameters without editing the shared +# ``modelopt_recipes/general/distillation/dmd2_qwen_image.yaml`` file. +_DMD_CONFIG_OVERRIDE_KEYS = frozenset(DMDConfig.model_fields.keys()) + + +def _deep_merge_dicts(base: dict, override: dict) -> dict: + """Recursively merge ``override`` onto ``base`` and return a new dict. + + Nested dicts (e.g. the ``sample_t_cfg`` / ``ema`` sub-configs) are merged key-by-key + rather than replaced wholesale, so a YAML block that overrides a single sub-field + keeps the recipe's other sub-fields instead of silently resetting them to + :class:`DMDConfig` defaults. + """ + merged = dict(base) + for key, value in override.items(): + existing = merged.get(key) + if isinstance(value, dict) and isinstance(existing, dict): + merged[key] = _deep_merge_dicts(existing, value) + else: + merged[key] = value + return merged + + +# Auto-detect substrings (matched case-insensitively against ``model_id``) that map to +# DMDPipeline plugin subclasses. Keep this list small — adding a new entry is only the +# right move when the model has a non-diffusers transformer signature that requires a +# pack/unpack wrapper. Models with the standard ``(hidden_states, timestep, +# encoder_hidden_states)`` signature work with the base :class:`DMDPipeline`. 
+_PIPELINE_PLUGIN_BY_MODEL_SUBSTR = ( + ("qwen-image", "qwen_image"), + ("qwen_image", "qwen_image"), +) + +_DMD_COMPLETE_MARKER = "dmd2_complete.marker" + + +class DMD2DiffusionRecipe(TrainDiffusionRecipe): + """DMD2 recipe that reuses ``TrainDiffusionRecipe`` for the student path. + + What the superclass gives us (reused unchanged): + + - Student transformer + AdamW optimizer + LR scheduler, loaded via + :class:`NeMoAutoDiffusionPipeline` with FSDP2 sharding. + - ``self.dataloader`` / ``self.sampler`` (swapped to AutoModel's mock dataloader + when ``data.use_mock: true`` — see :meth:`_build_dataloader`). + - ``self.step_scheduler`` (gradient accumulation + checkpoint cadence). + - ``self.checkpointer`` (DCP student weights + optimizer). + - ``self.device`` / ``self.bf16`` / ``self.clip_grad_max_norm`` / etc. + + What this recipe adds: + + - A frozen teacher loaded via a second :meth:`NeMoAutoDiffusionPipeline.from_pretrained` + call with the same ``parallel_scheme`` so it lands with the same FSDP2 sharding + as the student. + - A trainable fake-score transformer loaded the same way (weights identical to the + teacher on step 0). + - A separate AdamW optimizer for the fake-score phase. + - An :class:`mtf.DMDPipeline` driving VSD + DSM + EMA. + - Sidecar checkpoint save / restore for fake-score weights, fake-score optimizer, + EMA shadow, and DMDPipeline iteration counters. + + Classifier-free guidance, the GAN discriminator branch, and real-data training are + configurable via the ``dmd2:`` / ``data:`` YAML blocks — all enabled in the canonical + ``configs/dmd2_qwen_image.yaml`` and off in the mock-data smoke. See + ``examples/diffusers/fastgen/README.md`` for details. + """ + + # ------------------------------------------------------------------ # + # Setup # + # ------------------------------------------------------------------ # + + def setup(self) -> None: + """Build the student via ``super()``, then add teacher / fake_score / DMDPipeline. 
+ + The extras (``_teacher``, ``_fake_score``, ``_fake_score_optimizer``, + ``_dmd_pipeline``, ``_dmd_config``) are assigned through ``self.__dict__[...]`` + to bypass :meth:`BaseRecipe.__setattr__`'s auto-tracking — otherwise they'd be + added to ``__state_tracked`` and clobber the superclass's single-model / + single-optimizer checkpoint loop. + """ + # 1. Run the parent setup. Builds self.model / self.optimizer / self.lr_scheduler / + # self.dataloader / self.step_scheduler / self.checkpointer / etc. The parent's + # trailing call to self.load_checkpoint(self.restore_from) runs BEFORE our + # extras exist, so it only restores the student — that is intentional and safe. + # + # For the mock-data smoke, ``data.dataloader._target_`` in the YAML points at + # ``nemo_automodel.components.datasets.diffusion.build_mock_t2i_dataloader`` so the + # parent wires up the mock dataloader for us — no swap needed. + super().setup() + + # 2. Load the frozen teacher. Same from_pretrained path, same parallel_scheme, but + # ``load_for_training=False`` so the transformer comes back in eval mode with + # requires_grad=False. Bypass __setattr__ to stay invisible to the parent's + # __state_tracked loop. + self.__dict__["_teacher"] = self._load_frozen_teacher() + + # 4. Load the trainable fake-score. Third from_pretrained call — weights start + # identical to the teacher (both come from the same HF checkpoint). + self.__dict__["_fake_score"] = self._load_fake_score() + + # 5. Resolve the DMDConfig: load the fastgen built-in recipe, then apply any + # inline overrides under the YAML ``dmd2:`` block. + self.__dict__["_dmd_config"] = self._resolve_dmd_config() + + # 6. Optimizer for the fake-score phase. LR defaults to student LR when + # ``dmd2.fake_score_lr`` isn't set. + self.__dict__["_fake_score_optimizer"] = self._build_fake_score_optimizer() + + # 7. Optional GAN discriminator. 
Built when ``gan_loss_weight_gen > 0`` so the + # DMDPipeline constructor's assert is satisfied; otherwise ``discriminator=None`` + # and that assert fires if a YAML enables GAN for an unsupported backbone. + self.__dict__["_discriminator"] = self._build_discriminator() + self.__dict__["_discriminator_optimizer"] = self._build_discriminator_optimizer() + if self._discriminator is not None: + self._attach_gan_feature_capture() + + # 8. DMDPipeline. + # + # Dispatch to a plugin subclass when the backbone needs a non-diffusers call + # signature (e.g. Qwen-Image's packed-latents path). Default = base pipeline. + pipeline_cls = self._resolve_pipeline_cls() + pipeline_kwargs = self._resolve_pipeline_kwargs(pipeline_cls) + self.__dict__["_dmd_pipeline"] = pipeline_cls( + student=self.model, + teacher=self._teacher, + fake_score=self._fake_score, + config=self._dmd_config, + discriminator=self._discriminator, + **pipeline_kwargs, + ) + + # 9. Drop the parent's flow_matching_pipeline — we replace the training loop, + # so keeping it around is pure deadweight. The attribute is not tracked by + # ``__state_tracked`` (FlowMatchingPipeline is a plain class), so ``del`` is + # safe. + if hasattr(self, "flow_matching_pipeline"): + del self.flow_matching_pipeline + + # 10. Extend the student-only restore that super().setup() already ran: also + # restore the fake_score / fake_score_optimizer / EMA / DMD state from the + # same checkpoint directory. 
+ self._restore_dmd_extras(getattr(self, "_dmd2_resolved_restore_from", self.restore_from)) + + if is_main_process(): + logging.info("[DMD2] recipe initialized: %s", self._dmd_config_summary()) + logging.info("[DMD2] full configuration:\n%s", self._dmd_full_config_log()) + + # ------------------------------------------------------------------ # + # Training loop # + # ------------------------------------------------------------------ # + + def run_train_validation_loop(self) -> None: + """Three-phase DMD2 alternation driven by ``step_scheduler``. + + Each outer iteration picks either the student or fake-score phase based on + ``global_step % student_update_freq``. The student phase runs + ``compute_student_loss`` + ``update_ema``. The fake-score phase runs + ``compute_fake_score_loss`` and, when a discriminator is configured + (``gan_loss_weight_gen > 0``), ``compute_discriminator_loss``. + + Mirrors the gating in ``FastGen/fastgen/methods/distribution_matching/dmd2.py`` + (``_student_update_step`` / ``_fake_score_discriminator_update_step``). + """ + dmd = self._dmd_pipeline + cfg = self._dmd_config + + logging.info( + "[DMD2] Starting DMD2 training on %s (pipeline=%s)", + self.model_id, + type(self._dmd_pipeline).__name__, + ) + # Dataloader target (mock vs real cache) is non-obvious from the per-step + # logs; surface it explicitly here so §16's "mock or real dataloader + # target" bullet is checkable from the startup log. 
+ try: + dl_target = type(self.dataloader.dataset).__name__ + logging.info("[DMD2] Dataloader dataset class: %s", dl_target) + except Exception: + pass + logging.info( + "[DMD2] Global batch size: %s; local batch size: %s; DP size: %s", + self.global_batch_size, + self.local_batch_size, + self.dp_size, + ) + logging.info( + "[DMD2] student_update_freq=%d; fake_score_pred_type=%s; guidance_scale=%s;" + " gan_loss_weight_gen=%s", + cfg.student_update_freq, + cfg.fake_score_pred_type, + cfg.guidance_scale, + cfg.gan_loss_weight_gen, + ) + + global_step = int(self.step_scheduler.step) + + for epoch in self.step_scheduler.epochs: + if self.sampler is not None and hasattr(self.sampler, "set_epoch"): + self.sampler.set_epoch(epoch) + + # On resume, the diffusion sampler's load_state_dict primes + # ``_batches_to_skip`` so the next ``__iter__`` skips already-yielded + # batches. Forward that to tqdm's ``initial=`` so the progress bar + # reads e.g. ``187/313`` instead of the misleading ``0/313`` (the + # sampler resets the counter to 0 on the next ``__iter__`` call, + # so reading it here is a one-shot for the resumed epoch only). 
+ tqdm_initial = int(getattr(self.sampler, "_batches_to_skip", 0) or 0) + + if is_main_process(): + from tqdm import tqdm + + self.step_scheduler.dataloader = tqdm( + self.dataloader, + desc=f"Epoch {epoch + 1}/{self.num_epochs}", + initial=tqdm_initial, + ) + else: + self.step_scheduler.dataloader = self.dataloader + + epoch_student_loss = 0.0 + epoch_fake_score_loss = 0.0 + student_steps = 0 + fake_score_steps = 0 + + for batch_group in self.step_scheduler: + is_student_phase = (global_step % cfg.student_update_freq) == 0 + + if is_student_phase: + self.optimizer.zero_grad(set_to_none=True) + else: + self._fake_score_optimizer.zero_grad(set_to_none=True) + + self._set_grad_requirements(is_student_phase) + + micro_losses: list[float] = [] + micro_vsd_losses: list[float] = [] + micro_disc_losses: list[float] = [] + for micro_batch in batch_group: + ( + latents, + noise, + text_embeds, + text_mask, + neg_text_embeds, + neg_text_mask, + ) = self._prepare_micro_batch(micro_batch) + + if is_student_phase: + # ``compute_student_loss`` reads ``guidance_scale`` from the + # DMDConfig when this kwarg is None. We pass the negative + # embedding unconditionally — the function ignores it when + # CFG is disabled, and raises a clear ValueError when CFG + # is enabled but no negative was supplied. 
+ losses = dmd.compute_student_loss( + latents, + noise, + encoder_hidden_states=text_embeds, + encoder_hidden_states_mask=text_mask, + negative_encoder_hidden_states=neg_text_embeds, + negative_encoder_hidden_states_mask=neg_text_mask, + guidance_scale=None, + ) + micro_vsd_losses.append(float(losses["vsd"].item())) + else: + losses = dmd.compute_fake_score_loss( + latents, + noise, + encoder_hidden_states=text_embeds, + encoder_hidden_states_mask=text_mask, + ) + + (losses["total"] / len(batch_group)).backward() + micro_losses.append(float(losses["total"].item())) + + # GAN: in the fake-score phase, also update the discriminator + # on the same batch (FastGen pattern: + # _fake_score_discriminator_update_step). + if ( + not is_student_phase + and self._discriminator is not None + and self._discriminator_optimizer is not None + ): + self._discriminator_optimizer.zero_grad(set_to_none=True) + disc_losses = dmd.compute_discriminator_loss( + latents, + noise, + encoder_hidden_states=text_embeds, + encoder_hidden_states_mask=text_mask, + ) + (disc_losses["total"] / len(batch_group)).backward() + # Manual gradient all-reduce across DP ranks (the + # discriminator is replicated, not FSDP-sharded). + if dist.is_initialized(): + for p in self._discriminator.parameters(): + if p.grad is not None: + dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) + torch.nn.utils.clip_grad_norm_( + self._discriminator.parameters(), + max_norm=self.clip_grad_max_norm, + ) + self._discriminator_optimizer.step() + micro_disc_losses.append(float(disc_losses["total"].item())) + + # Grad clip on whichever module is the active trainable. + if is_student_phase: + grad_norm = torch.nn.utils.clip_grad_norm_( + self.model.parameters(), max_norm=self.clip_grad_max_norm + ) + else: + grad_norm = torch.nn.utils.clip_grad_norm_( + self._fake_score.parameters(), max_norm=self.clip_grad_max_norm + ) + grad_norm = float(grad_norm) if torch.is_tensor(grad_norm) else grad_norm + + # Step. 
+ if is_student_phase: + self.optimizer.step() + dmd.update_ema() + if self.lr_scheduler is not None: + self.lr_scheduler[0].step(1) + else: + self._fake_score_optimizer.step() + + group_loss_mean = float(sum(micro_losses) / len(micro_losses)) + if is_student_phase: + epoch_student_loss += group_loss_mean + student_steps += 1 + else: + epoch_fake_score_loss += group_loss_mean + fake_score_steps += 1 + + global_step = int(self.step_scheduler.step) + + if ( + self.log_every + and self.log_every > 0 + and is_main_process() + and (global_step % self.log_every == 0) + ): + self._log_step( + global_step=global_step, + is_student_phase=is_student_phase, + group_loss=group_loss_mean, + grad_norm=grad_norm, + vsd_loss=(sum(micro_vsd_losses) / len(micro_vsd_losses)) + if micro_vsd_losses + else None, + disc_loss=(sum(micro_disc_losses) / len(micro_disc_losses)) + if micro_disc_losses + else None, + ) + + if self.step_scheduler.is_ckpt_step: + # Use the group mean of the active phase as the reported train loss. + self.save_checkpoint(epoch, global_step, group_loss_mean) + + # End-of-epoch logging. + if is_main_process(): + avg_student = ( + (epoch_student_loss / student_steps) if student_steps else float("nan") + ) + avg_fake = ( + epoch_fake_score_loss / fake_score_steps if fake_score_steps else float("nan") + ) + logging.info( + "[DMD2] Epoch %d complete. student_avg=%.6f (%d steps) " + "fake_score_avg=%.6f (%d steps)", + epoch + 1, + avg_student, + student_steps, + avg_fake, + fake_score_steps, + ) + + if torch.cuda.is_available(): + peak = torch.cuda.max_memory_allocated() / (1024**3) + reserved = torch.cuda.max_memory_reserved() / (1024**3) + rank = dist.get_rank() if dist.is_initialized() else 0 + logging.info( + "[DMD2] PEAK_MEM rank=%d max_allocated=%.2fGiB max_reserved=%.2fGiB", + rank, + peak, + reserved, + ) + + if is_main_process(): + logging.info("[DMD2] Training complete. 
Final step: %s", global_step) + + # ------------------------------------------------------------------ # + # Checkpoint save / restore (sidecars next to student DCP) # + # ------------------------------------------------------------------ # + + def load_checkpoint(self, restore_from: str | None = None): + """Load only from checkpoints whose DMD2 sidecars are complete. + + ``TrainDiffusionRecipe.setup()`` calls this before the DMD2-only objects exist, + so this method only resolves the path and delegates the student restore to the + parent. The sidecars are restored later by ``_restore_dmd_extras``. + """ + resolved = self._resolve_complete_dmd_checkpoint(restore_from) + self.__dict__["_dmd2_resolved_restore_from"] = resolved + + if resolved is None: + if ( + restore_from is not None + and str(restore_from).upper() == "LATEST" + and is_main_process() + ): + logging.warning( + "[DMD2] restore_from=LATEST but no complete DMD2 checkpoint was found in %s. " + "Starting fresh.", + self.checkpointer.config.checkpoint_dir, + ) + return + + super().load_checkpoint(resolved) + + def save_checkpoint( + self, + epoch: int, + step: int, + train_loss: float, + val_loss: dict[str, float] | None = None, + best_metric_key: str = "default", + ) -> None: + """Delegate student save to ``super()``, then sidecar the DMD2 extras.""" + # Recover from a partial save from a previous run (e.g. SLURM time + # limit killed the job between super().save_checkpoint() — which writes + # the model + step_scheduler + dataloader — and our DMD2 sidecar + # writes below). The parent's save_checkpoint refuses to overwrite an + # existing directory and raises FileExistsError, so without this we'd + # need a manual cleanup every time a SLURM kill landed mid-save. 
+ path = os.path.join(self.checkpointer.config.checkpoint_dir, f"epoch_{epoch}_step_{step}") + if is_main_process() and self.checkpointer.config.enabled and os.path.exists(path): + if not self._is_dmd_checkpoint_complete(path): + logging.warning( + "[DMD2] cleaning up incomplete checkpoint directory left by a previous run: %s", + path, + ) + shutil.rmtree(path) + if dist.is_initialized(): + dist.barrier() + + previous_complete = None + if self.checkpointer.config.enabled: + previous_complete = self._find_latest_complete_dmd_checkpoint( + self.checkpointer.config.checkpoint_dir + ) + + super().save_checkpoint(epoch, step, train_loss, val_loss, best_metric_key) + + if not self.checkpointer.config.enabled: + return + + # The parent save updates LATEST before DMD2 sidecars exist. Until the marker is + # written below, keep LATEST on the previous complete DMD2 checkpoint. + if is_main_process(): + if previous_complete is not None: + self._update_latest_symlink(previous_complete) + else: + self._remove_checkpoint_pointer("LATEST") + if dist.is_initialized(): + dist.barrier() + + self._save_dmd_extras(path) + + if dist.is_initialized(): + dist.barrier() + + if is_main_process(): + self._write_dmd_complete_marker(path) + self._update_checkpoint_symlink("DMD2_LATEST", path) + self._update_latest_symlink(path) + if dist.is_initialized(): + dist.barrier() + + def _save_dmd_extras(self, path: str) -> None: + """Write fake_score DCP + fake_score_optimizer DCP + ema_shadow.pt + dmd_state.pt.""" + # fake_score weights — DCP sharded save via the same Checkpointer the parent uses + # for the student. Each rank writes its own shard. + fs_weights_dir = os.path.join(path, "fake_score") + os.makedirs(fs_weights_dir, exist_ok=True) + self.checkpointer.save_model( + model=self._fake_score, + weights_path=fs_weights_dir, + peft_config=None, + tokenizer=None, + ) + # fake_score optimizer — also DCP sharded. 
``save_optimizer`` takes the optimizer + # and its owning model in order to rebuild the parameter mapping. + fs_opt_dir = os.path.join(path, "fake_score_optimizer") + os.makedirs(fs_opt_dir, exist_ok=True) + self.checkpointer.save_optimizer( + self._fake_score_optimizer, self._fake_score, fs_opt_dir, None + ) + + # EMA shadow + DMD scalar state — rank-0 torch.save. EMA's ``state_dict`` already + # materialises full tensors via ``DTensor.full_tensor()`` under FSDP2 full_tensor + # mode, so this is a single unsharded file. + if is_main_process(): + logging.info("[DMD2] saved fake_score weights -> %s", fs_weights_dir) + logging.info("[DMD2] saved fake_score optimizer -> %s", fs_opt_dir) + if self._dmd_pipeline.ema is not None: + ema_path = os.path.join(path, "ema_shadow.pt") + torch.save(self._dmd_pipeline.ema.state_dict(), ema_path) + logging.info("[DMD2] saved ema_shadow -> %s", ema_path) + state_path = os.path.join(path, "dmd_state.pt") + torch.save({"iteration": self._dmd_pipeline._iteration}, state_path) + logging.info( + "[DMD2] saved dmd_state (iteration=%d) -> %s", + int(self._dmd_pipeline._iteration), + state_path, + ) + # Discriminator + its optimizer — replicated across ranks (no FSDP), + # so rank-0 torch.save of the canonical state_dict suffices. 
+ if self._discriminator is not None: + disc_path = os.path.join(path, "discriminator.pt") + torch.save(self._discriminator.state_dict(), disc_path) + logging.info("[DMD2] saved discriminator -> %s", disc_path) + if self._discriminator_optimizer is not None: + disc_opt_path = os.path.join(path, "discriminator_optimizer.pt") + torch.save(self._discriminator_optimizer.state_dict(), disc_opt_path) + logging.info("[DMD2] saved discriminator optimizer -> %s", disc_opt_path) + + def _write_dmd_complete_marker(self, path: str) -> None: + marker_path = os.path.join(path, _DMD_COMPLETE_MARKER) + payload = { + "checkpoint": os.path.basename(os.path.realpath(path)), + "dmd_iteration": int(self._dmd_pipeline._iteration), + } + with open(marker_path, "w") as f: + json.dump(payload, f) + f.write("\n") + logging.info("[DMD2] marked checkpoint complete -> %s", marker_path) + + def _remove_checkpoint_pointer(self, link_name: str) -> None: + ckpt_root = self.checkpointer.config.checkpoint_dir + for path in ( + os.path.join(ckpt_root, link_name), + os.path.join(ckpt_root, f"{link_name}.txt"), + ): + if os.path.lexists(path): + os.remove(path) + + def _restore_dmd_extras(self, restore_from: str | None) -> None: + """Restore fake_score + fake_score optimizer + EMA + DMD scalar state. + + No-op when no checkpoint is being restored. ``load_checkpoint`` resolves + ``LATEST`` to the latest complete DMD2 checkpoint before this method runs. 
+ """ + if restore_from is None: + return + + ckpt_dir = self._resolve_extras_dir(restore_from) + if ckpt_dir is None or not os.path.isdir(ckpt_dir): + return + + fs_weights_dir = os.path.join(ckpt_dir, "fake_score") + fs_opt_dir = os.path.join(ckpt_dir, "fake_score_optimizer") + ema_path = os.path.join(ckpt_dir, "ema_shadow.pt") + state_path = os.path.join(ckpt_dir, "dmd_state.pt") + + # Checkpointer.save_model writes DCP shards to ``/model/``; + # load_model expects that *inner* ``model/`` dir as ``model_path`` (see + # ``BaseRecipe.load_checkpoint`` which passes ``os.path.join(ckpt_dir, "model")``). + # The kwarg name differs between save (``weights_path``) and load (``model_path``). + fs_weights_model_dir = os.path.join(fs_weights_dir, "model") + if os.path.isdir(fs_weights_model_dir): + self.checkpointer.load_model(model=self._fake_score, model_path=fs_weights_model_dir) + if is_main_process(): + logging.info("[DMD2] restored fake_score weights <- %s", fs_weights_model_dir) + elif is_main_process(): + logging.info( + "[DMD2] WARN: fake_score weights dir missing at %s -- skipping", + fs_weights_model_dir, + ) + # load_optimizer, in contrast, appends ``optim/`` internally — pass the base dir. 
+ if os.path.isdir(os.path.join(fs_opt_dir, "optim")): + self.checkpointer.load_optimizer( + self._fake_score_optimizer, self._fake_score, fs_opt_dir, None + ) + if is_main_process(): + logging.info("[DMD2] restored fake_score optimizer <- %s", fs_opt_dir) + elif is_main_process(): + logging.info( + "[DMD2] WARN: fake_score optimizer dir missing at %s -- skipping", + fs_opt_dir, + ) + + if os.path.isfile(ema_path) and self._dmd_pipeline.ema is not None: + ema_state = torch.load(ema_path, map_location="cpu", weights_only=False) + self._dmd_pipeline.ema.load_state_dict(ema_state) + if is_main_process(): + logging.info("[DMD2] restored ema_shadow <- %s", ema_path) + if os.path.isfile(state_path): + state = torch.load(state_path, map_location="cpu", weights_only=False) + self._dmd_pipeline._iteration = int(state.get("iteration", 0)) + if is_main_process(): + logging.info( + "[DMD2] restored dmd_state (iteration=%d) <- %s", + self._dmd_pipeline._iteration, + state_path, + ) + + # Discriminator + its optimizer. 
+ if self._discriminator is not None: + disc_path = os.path.join(ckpt_dir, "discriminator.pt") + if os.path.isfile(disc_path): + disc_state = torch.load(disc_path, map_location="cpu", weights_only=False) + self._discriminator.load_state_dict(disc_state) + if is_main_process(): + logging.info("[DMD2] restored discriminator <- %s", disc_path) + elif is_main_process(): + logging.info("[DMD2] WARN: discriminator file missing at %s -- skipping", disc_path) + if self._discriminator_optimizer is not None: + disc_opt_path = os.path.join(ckpt_dir, "discriminator_optimizer.pt") + if os.path.isfile(disc_opt_path): + disc_opt_state = torch.load(disc_opt_path, map_location="cpu", weights_only=False) + self._discriminator_optimizer.load_state_dict(disc_opt_state) + if is_main_process(): + logging.info("[DMD2] restored discriminator optimizer <- %s", disc_opt_path) + elif is_main_process(): + logging.info( + "[DMD2] WARN: discriminator optimizer file missing at %s -- skipping", + disc_opt_path, + ) + + def _resolve_extras_dir(self, restore_from: str) -> str | None: + """Best-effort resolve of the checkpoint dir, matching BaseRecipe's convention. + + For explicit paths we pass through; for ``"LATEST"`` we look under + ``checkpointer.config.checkpoint_dir``. This keeps resolution simple and delegates + the hard cases (async symlinks, cross-node shared filesystems) to the user. + """ + if os.path.isabs(restore_from): + return restore_from + # Try the checkpoint_dir-relative form first (matches the parent's symlink + # naming — "LATEST" or an explicit ``epoch_N_step_M`` subdir). 
+ candidate = os.path.join(self.checkpointer.config.checkpoint_dir, restore_from) + if os.path.exists(candidate): + return os.path.realpath(candidate) + return None + + def _resolve_complete_dmd_checkpoint(self, restore_from: str | None) -> str | None: + ckpt_root = self.checkpointer.config.checkpoint_dir + + if restore_from is None or str(restore_from).upper() in {"LATEST", "DMD2_LATEST"}: + return self._find_latest_complete_dmd_checkpoint(ckpt_root) + + if os.path.isabs(restore_from): + candidate = restore_from + else: + candidate = os.path.join(ckpt_root, restore_from) + candidate = os.path.realpath(candidate) + + if not os.path.isdir(candidate): + return candidate + if not self._is_dmd_checkpoint_complete(candidate): + raise RuntimeError( + f"DMD2 checkpoint is incomplete and cannot be restored: {candidate}. " + "Use a complete older checkpoint or remove the partial directory." + ) + return candidate + + def _find_latest_complete_dmd_checkpoint(self, ckpt_root: str) -> str | None: + dmd2_latest = os.path.join(ckpt_root, "DMD2_LATEST") + for pointer in (dmd2_latest, os.path.join(ckpt_root, "LATEST")): + resolved = self._resolve_checkpoint_pointer(pointer) + if resolved is not None and self._is_dmd_checkpoint_complete(resolved): + return resolved + + candidates = [] + if os.path.isdir(ckpt_root): + for name in os.listdir(ckpt_root): + path = os.path.join(ckpt_root, name) + if ( + os.path.isdir(path) + and "_step_" in name + and self._is_dmd_checkpoint_complete(path) + ): + candidates.append(os.path.realpath(path)) + if not candidates: + return None + return max(candidates, key=self._checkpoint_step) + + def _resolve_checkpoint_pointer(self, pointer: str) -> str | None: + resolved = None + if os.path.islink(pointer): + try: + resolved = os.readlink(pointer) + except OSError: + return None + elif os.path.isfile(pointer + ".txt"): + try: + with open(pointer + ".txt") as f: + resolved = f.read().strip() + except OSError: + return None + if not resolved: + return None 
+ if not os.path.isabs(resolved): + resolved = os.path.abspath(os.path.join(os.path.dirname(pointer), resolved)) + return os.path.realpath(resolved) if os.path.isdir(resolved) else None + + def _is_dmd_checkpoint_complete(self, path: str) -> bool: + path = os.path.realpath(path) + if not os.path.isdir(path): + return False + if os.path.isfile(os.path.join(path, _DMD_COMPLETE_MARKER)): + return True + + fs_model_dir = os.path.join(path, "fake_score", "model") + fs_opt_metadata = os.path.join(path, "fake_score_optimizer", "optim", ".metadata") + dmd_state = os.path.join(path, "dmd_state.pt") + complete = ( + self._dir_has_regular_file(fs_model_dir) + and os.path.isfile(fs_opt_metadata) + and os.path.isfile(dmd_state) + ) + if not complete: + return False + + if self._cfg_gan_enabled(): + return os.path.isfile(os.path.join(path, "discriminator.pt")) and os.path.isfile( + os.path.join(path, "discriminator_optimizer.pt") + ) + return True + + def _cfg_gan_enabled(self) -> bool: + cfg = getattr(self, "cfg", None) + if cfg is None: + return False + try: + return float(cfg.get("dmd2.gan_loss_weight_gen", 0.0) or 0.0) > 0 + except (TypeError, ValueError): + return False + + @staticmethod + def _dir_has_regular_file(path: str) -> bool: + if not os.path.isdir(path): + return False + try: + with os.scandir(path) as entries: + return any(entry.is_file() for entry in entries) + except OSError: + return False + + @staticmethod + def _checkpoint_step(path: str) -> int: + match = re.search(r"_step_(\d+)$", os.path.basename(os.path.realpath(path))) + return int(match.group(1)) if match else -1 + + # ------------------------------------------------------------------ # + # Helpers — teacher / fake_score loading, DMDConfig resolution # + # ------------------------------------------------------------------ # + + def _load_frozen_teacher(self) -> nn.Module: + """Load a second copy of the pretrained transformer, frozen + FSDP2-sharded. 
+ + The same pretrained path + ``parallel_scheme`` as the student. Setting + ``load_for_training=False`` walks the parameters once and flips + ``requires_grad=False`` after FSDP2 wrapping; we also call ``.eval()`` on the + returned module just to be defensive. + """ + parallel_scheme = self._build_parallel_scheme_snapshot() + pipe, _ = NeMoAutoDiffusionPipeline.from_pretrained( + self.model_id, + torch_dtype=self.bf16, + device=self.device, + parallel_scheme=parallel_scheme, + components_to_load=["transformer"], + load_for_training=False, + low_cpu_mem_usage=True, + ) + teacher = pipe.transformer + teacher.eval() + for p in teacher.parameters(): + p.requires_grad_(False) + return teacher + + def _build_discriminator(self) -> nn.Module | None: + """Construct the Discriminator_ImageDiT when GAN is enabled. + + Returns ``None`` when ``dmd2.gan_loss_weight_gen`` is zero so the + DMDPipeline runs without a discriminator (any run with the GAN branch disabled). + """ + gan_weight = float(self.cfg.get("dmd2.gan_loss_weight_gen", 0.0) or 0.0) + if gan_weight <= 0.0: + return None + + # GAN-specific knobs read directly from the YAML so callers don't have to + # touch the built-in DMDConfig recipe just to flip feature indices. 
+ feature_indices = self.cfg.get( + "dmd2.gan_feature_indices", [30] + ) # middle of Qwen-Image's 60 blocks + num_blocks = int(self.cfg.get("dmd2.gan_num_blocks", 60)) + inner_dim = int(self.cfg.get("dmd2.gan_inner_dim", 3072)) + + disc = Discriminator_ImageDiT( + feature_indices={int(i) for i in feature_indices}, + num_blocks=num_blocks, + inner_dim=inner_dim, + ) + disc.to(device=self.device, dtype=self.bf16) + disc.train() + for p in disc.parameters(): + p.requires_grad_(True) + if is_main_process(): + logging.info( + "[DMD2] Built discriminator: %s | num_features=%d num_blocks=%d inner_dim=%d " + "params=%d", + type(disc).__name__, + disc.num_features, + num_blocks, + inner_dim, + sum(p.numel() for p in disc.parameters()), + ) + return disc + + def _build_discriminator_optimizer(self) -> torch.optim.Optimizer | None: + """AdamW on the discriminator. No FSDP wrap — manual grad all-reduce keeps it simple.""" + if self._discriminator is None: + return None + lr = float(self.cfg.get("dmd2.discriminator_lr", 1.0e-5) or 1.0e-5) + opt = torch.optim.AdamW( + self._discriminator.parameters(), + lr=lr, + weight_decay=0.01, + betas=(0.9, 0.999), # FastGen Qwen-Image DMD2 inherits BaseOptimizerConfig betas + ) + if is_main_process(): + logging.info("[DMD2] Built discriminator optimizer: AdamW lr=%g betas=(0.9, 0.999)", lr) + return opt + + def _attach_gan_feature_capture(self) -> None: + """Install Qwen-Image feature-capture hooks on the teacher when GAN is enabled. + + Reads the latent resolution from the dataloader so the hook can reshape + ``[B, num_image_patches, 3072]`` into ``[B, 3072, H_lat//2, W_lat//2]``. + Mock dataloader → spatial_h/spatial_w from the YAML. Real dataloader → + base_resolution / vae_scale. + """ + feature_indices = list(self.cfg.get("dmd2.gan_feature_indices", [30])) + + # Resolve h_lat / w_lat. Mock has it in the YAML; real cache uses the + # configured base_resolution divided by the VAE 8x downsample. 
+ # Convert the dataloader subtree to a plain dict — AutoModel's ConfigNode + # doesn't expose deep dotted paths like ``data.dataloader.spatial_h``. + dl_node = self.cfg.get("data.dataloader", None) + if dl_node is not None and hasattr(dl_node, "to_dict"): + dl_dict = dl_node.to_dict() + elif dl_node is not None: + try: + dl_dict = dict(dl_node) + except (TypeError, ValueError): + dl_dict = {} + else: + dl_dict = {} + + spatial_h = dl_dict.get("spatial_h") + spatial_w = dl_dict.get("spatial_w") + base_resolution = dl_dict.get("base_resolution") + if spatial_h is not None and spatial_w is not None: + h_lat = int(spatial_h) + w_lat = int(spatial_w) + elif base_resolution is not None: + h_lat = int(base_resolution[0]) // 8 + w_lat = int(base_resolution[1]) // 8 + else: + # Fallback: hope it's 64x64 (512px image). Smoke tests pin this explicitly. + h_lat, w_lat = 64, 64 + if is_main_process(): + logging.warning( + "[DMD2] Could not infer h_lat/w_lat from data.dataloader; defaulting to 64x64." + ) + + qwen_image_plugin.attach_feature_capture( + self._teacher, + feature_indices=feature_indices, + h_lat=h_lat, + w_lat=w_lat, + ) + if is_main_process(): + logging.info( + "[DMD2] Attached GAN feature capture: indices=%s h_lat=%d w_lat=%d", + feature_indices, + h_lat, + w_lat, + ) + + def _load_fake_score(self) -> nn.Module: + """Load a third copy, trainable. Weights start identical to the teacher.""" + parallel_scheme = self._build_parallel_scheme_snapshot() + pipe, _ = NeMoAutoDiffusionPipeline.from_pretrained( + self.model_id, + torch_dtype=self.bf16, + device=self.device, + parallel_scheme=parallel_scheme, + components_to_load=["transformer"], + load_for_training=True, + low_cpu_mem_usage=True, + ) + fake_score = pipe.transformer + fake_score.train() + for p in fake_score.parameters(): + p.requires_grad_(True) + return fake_score + + def _build_parallel_scheme_snapshot(self) -> dict[str, dict[str, Any]]: + """Reconstruct the FSDP2 manager_args used for the student. 
+ + Mirrors ``build_model_and_optimizer`` in ``nemo_automodel.recipes.diffusion.train``. + We can't capture the student's ``parallel_scheme`` directly (the parent doesn't + stash it), so we rebuild it from the same YAML knobs the parent consumed. + """ + from torch.distributed.fsdp import MixedPrecisionPolicy + + fsdp_cfg = self.cfg.get("fsdp", None) or {} + ddp_cfg = self.cfg.get("ddp", None) + + world_size = dist.get_world_size() if dist.is_initialized() else 1 + + if ddp_cfg is not None: + return { + "transformer": { + "_manager_type": "ddp", + "backend": ddp_cfg.get("backend", "nccl"), + "world_size": world_size, + "activation_checkpointing": ddp_cfg.get("activation_checkpointing", False), + } + } + + dp_size = fsdp_cfg.get("dp_size") + tp_size = fsdp_cfg.get("tp_size", 1) + cp_size = fsdp_cfg.get("cp_size", 1) + pp_size = fsdp_cfg.get("pp_size", 1) + if dp_size is None: + denom = max(1, tp_size * cp_size * pp_size) + dp_size = max(1, world_size // denom) + + return { + "transformer": { + "_manager_type": "fsdp2", + "dp_size": dp_size, + "dp_replicate_size": fsdp_cfg.get("dp_replicate_size", None), + "tp_size": tp_size, + "cp_size": cp_size, + "pp_size": pp_size, + "backend": "nccl", + "world_size": world_size, + "use_hf_tp_plan": fsdp_cfg.get("use_hf_tp_plan", False), + "activation_checkpointing": fsdp_cfg.get("activation_checkpointing", True), + "mp_policy": MixedPrecisionPolicy( + param_dtype=self.bf16, + reduce_dtype=torch.float32, + output_dtype=self.bf16, + ), + } + } + + def _resolve_pipeline_cls(self) -> type[DMDPipeline]: + """Pick the DMDPipeline subclass for the current backbone. + + Resolution order: + + 1. ``dmd2.pipeline_plugin`` in the YAML (explicit override, ``null`` for base). + 2. Substring match on ``model.pretrained_model_name_or_path`` + (e.g. ``Qwen-Image`` -> ``qwen_image`` plugin). + 3. Fall back to :class:`DMDPipeline`. 
+ """ + explicit = self.cfg.get("dmd2.pipeline_plugin", None) + if explicit is None: + model_id_lc = (self.model_id or "").lower() + for needle, plugin_name in _PIPELINE_PLUGIN_BY_MODEL_SUBSTR: + if needle in model_id_lc: + explicit = plugin_name + break + if explicit in (None, "base", "DMDPipeline"): + return DMDPipeline + if explicit == "qwen_image": + # Imported lazily so ``base`` users don't pay the import cost. + from modelopt.torch.fastgen.plugins.qwen_image import QwenImageDMDPipeline + + return QwenImageDMDPipeline + raise ValueError( + f"Unknown dmd2.pipeline_plugin={explicit!r}. Supported: null/'base', 'qwen_image'." + ) + + def _resolve_pipeline_kwargs(self, pipeline_cls: type[DMDPipeline]) -> dict[str, Any]: + """Extra kwargs to forward to the pipeline subclass constructor (plugin-specific).""" + if pipeline_cls.__name__ == "QwenImageDMDPipeline": + # Optional ``guidance`` value passed to the transformer's guidance kwarg every + # call. Independent of DMDConfig.guidance_scale (which drives the negative- + # prompt CFG path on the teacher). Leave ``None`` to skip the embedding when + # the transformer was built with ``guidance_embeds=false`` (default for + # ``Qwen/Qwen-Image``). + return {"guidance": self.cfg.get("dmd2.qwen_image_guidance", None)} + return {} + + def _resolve_dmd_config(self) -> DMDConfig: + """Load the built-in fastgen recipe, then apply inline YAML overrides.""" + dmd_cfg_node = self.cfg.get("dmd2", None) + if dmd_cfg_node is None: + raise ValueError( + "Missing ``dmd2:`` block in the YAML config. Expected at minimum " + "``dmd2.recipe_path`` pointing at a fastgen DMDConfig recipe " + "(e.g. ``general/distillation/dmd2_qwen_image``)." 
+ ) + dmd_dict = ( + dmd_cfg_node.to_dict() if hasattr(dmd_cfg_node, "to_dict") else dict(dmd_cfg_node) + ) + + recipe_path = dmd_dict.pop("recipe_path", None) + if recipe_path is None: + raise ValueError( + "``dmd2.recipe_path`` is required — Phase 1 relies on the built-in " + "``modelopt_recipes`` path resolver to hydrate the full DMDConfig." + ) + base_config = mtf.load_dmd_config(recipe_path) + + # Filter overrides to the subset that actually corresponds to DMDConfig fields. + # Non-matching keys (e.g. ``fake_score_lr``, ``cfg_mode``) are kept as top-level + # recipe knobs and read via ``self.cfg.get("dmd2.")``. + overrides = {k: v for k, v in dmd_dict.items() if k in _DMD_CONFIG_OVERRIDE_KEYS} + if not overrides: + return base_config + # Deep-merge so a YAML block that overrides a single ``sample_t_cfg`` / ``ema`` + # sub-field keeps the recipe's other sub-fields — a shallow ``dict.update`` would + # replace the whole sub-config and silently reset its siblings to defaults. + # Re-validate the merged dict so the nested blocks become their Pydantic config + # objects instead of raw dicts. + merged = _deep_merge_dicts(base_config.model_dump(), overrides) + return DMDConfig.model_validate(merged) + + def _build_fake_score_optimizer(self) -> torch.optim.Optimizer: + """AdamW on fake_score params. 
LR defaults to student LR; overridable via YAML.""" + fs_lr = self.cfg.get("dmd2.fake_score_lr", None) + if fs_lr is None: + fs_lr = self.learning_rate + optimizer_cfg = self.cfg.get("optim.optimizer", {}) or {} + optimizer_cfg = ( + optimizer_cfg.to_dict() if hasattr(optimizer_cfg, "to_dict") else dict(optimizer_cfg) + ) + weight_decay = optimizer_cfg.get("weight_decay", 0.01) + betas = tuple(optimizer_cfg.get("betas", (0.9, 0.999))) + + trainable_params = [p for p in self._fake_score.parameters() if p.requires_grad] + if not trainable_params: + raise RuntimeError("No trainable parameters found in fake_score.") + return torch.optim.AdamW(trainable_params, lr=fs_lr, weight_decay=weight_decay, betas=betas) + + # ------------------------------------------------------------------ # + # Inner helpers # + # ------------------------------------------------------------------ # + + def _set_grad_requirements(self, is_student_phase: bool) -> None: + """Toggle train/eval + requires_grad across modules for the active phase. + + Mirrors FastGen's ``_setup_grad_requirements`` (``dmd2.py`` lines 67-77), + INCLUDING the discriminator toggle that was previously omitted. + + Why the discriminator toggle matters: ``compute_student_loss`` calls + ``self.discriminator(fake_feat)`` for the ``gan_gen`` term, so the + discriminator is in the student-phase backward graph. With its params + left at ``requires_grad=True``, ``total.backward()`` allocates and + fills ``.grad`` for every discriminator parameter — gradients which + the student optimizer never consumes and which the next discriminator + ``zero_grad(set_to_none=True)`` simply wipes. Freezing the discriminator + during the student phase skips that wasted memory + backward compute + without changing any numerics (the student still receives the GAN + signal through the discriminator's input-side gradient, which doesn't + require the discriminator's own params to be in the autograd graph). 
+ + Called every step; cheap enough that we don't bother caching the last state. + """ + if is_student_phase: + self.model.train() + for p in self.model.parameters(): + p.requires_grad_(True) + self._fake_score.eval() + for p in self._fake_score.parameters(): + p.requires_grad_(False) + if self._discriminator is not None: + self._discriminator.eval() + for p in self._discriminator.parameters(): + p.requires_grad_(False) + else: + self.model.eval() + for p in self.model.parameters(): + p.requires_grad_(False) + self._fake_score.train() + for p in self._fake_score.parameters(): + p.requires_grad_(True) + if self._discriminator is not None: + self._discriminator.train() + for p in self._discriminator.parameters(): + p.requires_grad_(True) + + def _prepare_micro_batch( + self, micro_batch: dict[str, Any] + ) -> tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor | None, + torch.Tensor | None, + torch.Tensor | None, + ]: + """Extract latents, noise, text conditioning, and optional masks from a batch. + + Accepts both 5D ``video_latents`` and 4D ``image_latents`` + (Qwen-Image / Flux / SD3). Mirrors the key dispatch in + ``nemo_automodel.components.flow_matching.pipeline.FlowMatchingPipeline.step``. + + ``negative_text_embeddings`` is optional — present when the dataloader + supplies it (mock T2I, real cache with precomputed empty-prompt + embedding) and consumed by ``compute_student_loss`` only when CFG is + enabled (``dmd2.guidance_scale is not None``). + """ + if "image_latents" in micro_batch: + latents = micro_batch["image_latents"].to(self.device, dtype=self.bf16) + elif "video_latents" in micro_batch: + latents = micro_batch["video_latents"].to(self.device, dtype=self.bf16) + else: + raise KeyError( + "Batch must contain either 'image_latents' (4D) or 'video_latents' (5D). " + f"Got keys: {sorted(micro_batch.keys())}." 
+ ) + text_embeds = micro_batch["text_embeddings"].to(self.device, dtype=self.bf16) + if text_embeds.ndim == 2: + text_embeds = text_embeds.unsqueeze(0) + text_mask = micro_batch.get("text_embeddings_mask") + if text_mask is not None: + text_mask = text_mask.to(self.device) + if text_mask.ndim == 1: + text_mask = text_mask.unsqueeze(0) + negative_text_embeds = micro_batch.get("negative_text_embeddings") + if negative_text_embeds is not None: + negative_text_embeds = negative_text_embeds.to(self.device, dtype=self.bf16) + if negative_text_embeds.ndim == 2: + negative_text_embeds = negative_text_embeds.unsqueeze(0) + negative_text_mask = micro_batch.get("negative_text_embeddings_mask") + if negative_text_mask is not None: + negative_text_mask = negative_text_mask.to(self.device) + if negative_text_mask.ndim == 1: + negative_text_mask = negative_text_mask.unsqueeze(0) + # Fresh noise per micro-batch — DMD2 samples noise independently at each loss call. + noise = torch.randn_like(latents) + return latents, noise, text_embeds, text_mask, negative_text_embeds, negative_text_mask + + def _log_step( + self, + *, + global_step: int, + is_student_phase: bool, + group_loss: float, + grad_norm: float, + vsd_loss: float | None, + disc_loss: float | None = None, + ) -> None: + """Log a single step. 
Stdout always; wandb when the parent set it up.""" + phase = "student" if is_student_phase else "fake_score" + + # Stdout + suffix = f" vsd={vsd_loss:.4f}" if vsd_loss is not None else "" + if disc_loss is not None: + suffix += f" disc={disc_loss:.4f}" + logging.info( + "[STEP %d] phase=%s loss=%.4f grad_norm=%.4f%s lr=%.2e", + global_step, + phase, + group_loss, + grad_norm, + suffix, + self.optimizer.param_groups[0]["lr"], + ) + + # wandb + try: + import wandb + + if wandb.run is not None: + log_dict: dict[str, Any] = { + f"{phase}/loss": group_loss, + f"{phase}/grad_norm": grad_norm, + "global_step": global_step, + "lr_student": self.optimizer.param_groups[0]["lr"], + "lr_fake_score": self._fake_score_optimizer.param_groups[0]["lr"], + } + if vsd_loss is not None: + log_dict["student/vsd"] = vsd_loss + if disc_loss is not None: + log_dict["discriminator/loss"] = disc_loss + wandb.log(log_dict, step=global_step) + except Exception: + # wandb not installed or not initialised — silent no-op. + pass + + def _dmd_config_summary(self) -> str: + """Compact one-line summary of the active DMDConfig for startup logging.""" + cfg = self._dmd_config + t_list = cfg.sample_t_cfg.t_list if cfg.sample_t_cfg is not None else None + return ( + f"pred_type={cfg.pred_type} fake_score_pred_type={cfg.fake_score_pred_type} " + f"num_train_timesteps={cfg.num_train_timesteps} " + f"student_update_freq={cfg.student_update_freq} " + f"student_sample_steps={cfg.student_sample_steps} " + f"student_sample_type={cfg.student_sample_type} " + f"backward_simulation={cfg.backward_simulation} " + f"t_list={t_list} " + f"gan_loss_weight_gen={cfg.gan_loss_weight_gen} " + f"guidance_scale={cfg.guidance_scale} ema={'on' if cfg.ema is not None else 'off'}" + ) + + def _dmd_full_config_log(self) -> str: + """Full multi-line dump of every DMD2 parameter for startup tracing. 
+ + Two sections: the resolved DMDConfig (every Pydantic field, including + nested ``sample_t_cfg`` and ``ema`` blocks) and the recipe-side keys + under ``dmd2:`` that aren't DMDConfig fields (e.g. ``fake_score_lr``, + ``gan_feature_indices``, ``pipeline_plugin``). Combined they cover + every knob that ends up driving the DMD2 method at runtime. + """ + cfg = self._dmd_config + dmd_node = self.cfg.get("dmd2", {}) or {} + if hasattr(dmd_node, "to_dict"): + dmd_node = dmd_node.to_dict() + else: + dmd_node = dict(dmd_node) + recipe_extras = { + k: v + for k, v in dmd_node.items() + if k not in _DMD_CONFIG_OVERRIDE_KEYS and k != "recipe_path" + } + combined = { + "DMDConfig_resolved": cfg.model_dump(), + "recipe_side_extras": recipe_extras, + } + return json.dumps(combined, indent=2, default=str) diff --git a/examples/diffusers/fastgen/export_diffusers_qwen_image.py b/examples/diffusers/fastgen/export_diffusers_qwen_image.py new file mode 100644 index 00000000000..65a17c9c3c0 --- /dev/null +++ b/examples/diffusers/fastgen/export_diffusers_qwen_image.py @@ -0,0 +1,202 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Export a DMD2 student as a full diffusers-loadable QwenImagePipeline dir. 
+ +The §10 safetensors addendum produces a transformer-only dir: + + epoch_0_step_N/model/consolidated/ + config.json + diffusion_pytorch_model.safetensors.index.json + model-00001-of-00001.safetensors + +That dir is loadable by ``QwenImageTransformer2DModel.from_pretrained`` but NOT +by ``QwenImagePipeline.from_pretrained`` or ``DiffusionPipeline.from_pretrained`` +because it lacks ``model_index.json`` and the sibling component dirs +(``vae/``, ``text_encoder/``, ``tokenizer/``, ``scheduler/``). + +This utility assembles a full pipeline dir by: + + / + model_index.json (copied from base) + transformer/ (symlinked or copied from the consolidated student) + vae/ (symlinked from base Qwen-Image) + text_encoder/ (symlinked from base Qwen-Image) + tokenizer/ (symlinked from base Qwen-Image) + scheduler/ (symlinked from base Qwen-Image) + +After this runs, the dir loads with ``QwenImagePipeline.from_pretrained(output_dir)`` +or ``DiffusionPipeline.from_pretrained(output_dir)``. + +Symlinks are the default — the base Qwen-Image components are huge (text encoder +alone is ~12 GB) and never change between DMD2 students, so copying them per +checkpoint wastes disk. Use ``--copy`` if the output dir must be portable. + +Usage:: + + python export_diffusers_qwen_image.py \\ + --student_path /path/to/checkpoint/epoch_0_step_500/model/consolidated \\ + --base_pipeline_path Qwen/Qwen-Image \\ + --output_dir /path/to/output/qwen_image_dmd2 \\ + [--copy] + +Smoke test (``--verify``) loads the assembled dir via QwenImagePipeline and +checks the transformer config matches the student's. +""" + +from __future__ import annotations + +import argparse +import json +import logging +import os +import shutil +import sys +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from pathlib import Path + +logger = logging.getLogger(__name__) + +# Components borrowed from the base Qwen-Image checkpoint. The trained +# transformer replaces the corresponding entry. 
+BASE_COMPONENTS = ("vae", "text_encoder", "tokenizer", "scheduler") + + +def _link_or_copy(src: str, dst: str, copy: bool) -> None: + if os.path.lexists(dst): + if os.path.islink(dst) or os.path.isfile(dst): + os.remove(dst) + else: + shutil.rmtree(dst) + if copy: + if os.path.isdir(src): + shutil.copytree(src, dst, symlinks=True) + else: + shutil.copy2(src, dst) + else: + os.symlink(os.path.abspath(src), dst) + + +def export_diffusers( + student_path: str | Path, + base_pipeline_path: str | Path, + output_dir: str | Path, + copy: bool = False, +) -> None: + student_path = str(student_path) + base_pipeline_path = str(base_pipeline_path) + output_dir = str(output_dir) + + if not os.path.isdir(student_path): + raise FileNotFoundError(f"student_path is not a directory: {student_path}") + if not os.path.isdir(base_pipeline_path): + raise FileNotFoundError(f"base_pipeline_path is not a directory: {base_pipeline_path}") + base_index = os.path.join(base_pipeline_path, "model_index.json") + if not os.path.isfile(base_index): + raise FileNotFoundError(f"base pipeline missing model_index.json: {base_index}") + + os.makedirs(output_dir, exist_ok=True) + logger.info( + "[Diffusers-Export] Output dir: %s (mode=%s)", output_dir, "copy" if copy else "symlink" + ) + + # 1. model_index.json — copy verbatim (the class registry is the same + # whether the transformer weights are live or DMD-distilled). + dst_index = os.path.join(output_dir, "model_index.json") + with open(base_index) as f: + index = json.load(f) + with open(dst_index, "w") as f: + json.dump(index, f, indent=2) + logger.info("[Diffusers-Export] Wrote %s", dst_index) + + # 2. transformer/ — link/copy from the consolidated student. + dst_transformer = os.path.join(output_dir, "transformer") + _link_or_copy(student_path, dst_transformer, copy=copy) + logger.info("[Diffusers-Export] transformer/ <- %s", student_path) + + # 3. vae / text_encoder / tokenizer / scheduler — link/copy from base. 
+ for comp in BASE_COMPONENTS: + src = os.path.join(base_pipeline_path, comp) + dst = os.path.join(output_dir, comp) + if not os.path.isdir(src): + raise FileNotFoundError(f"base pipeline missing component: {src}") + _link_or_copy(src, dst, copy=copy) + logger.info("[Diffusers-Export] %s/ <- %s", comp, src) + + logger.info( + "[Diffusers-Export] Done. Load via QwenImagePipeline.from_pretrained(%r)", output_dir + ) + + +def _verify(output_dir: str) -> dict: + """Load via QwenImagePipeline.from_pretrained and report a small status dict.""" + import torch + from diffusers import QwenImagePipeline + + logger.info("[Diffusers-Export-Verify] Loading %s via QwenImagePipeline", output_dir) + pipe = QwenImagePipeline.from_pretrained(output_dir, torch_dtype=torch.bfloat16) + n_transformer_params = sum(p.numel() for p in pipe.transformer.parameters()) + n_vae_params = sum(p.numel() for p in pipe.vae.parameters()) + n_text_params = sum(p.numel() for p in pipe.text_encoder.parameters()) + stats = { + "loaded_class": type(pipe).__name__, + "transformer_class": type(pipe.transformer).__name__, + "transformer_params": int(n_transformer_params), + "vae_class": type(pipe.vae).__name__, + "vae_params": int(n_vae_params), + "text_encoder_class": type(pipe.text_encoder).__name__, + "text_encoder_params": int(n_text_params), + "tokenizer_class": type(pipe.tokenizer).__name__, + "scheduler_class": type(pipe.scheduler).__name__, + } + return stats + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--student_path", + required=True, + help="Consolidated student dir (e.g. 
.../epoch_0_step_5/model/consolidated)", + ) + parser.add_argument( + "--base_pipeline_path", + required=True, + help="Base Qwen-Image pipeline dir (vae / text_encoder / tokenizer / scheduler source)", + ) + parser.add_argument("--output_dir", required=True, help="Where to write the full pipeline dir") + parser.add_argument( + "--copy", + action="store_true", + help="Copy components instead of symlinking. Off by default to save disk.", + ) + parser.add_argument( + "--verify", action="store_true", help="Re-load via QwenImagePipeline.from_pretrained" + ) + args = parser.parse_args() + + logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s") + export_diffusers( + student_path=args.student_path, + base_pipeline_path=args.base_pipeline_path, + output_dir=args.output_dir, + copy=args.copy, + ) + if args.verify: + stats = _verify(args.output_dir) + print(json.dumps(stats, indent=2)) + sys.exit(0) diff --git a/examples/diffusers/fastgen/inference_dmd2_qwen_image.py b/examples/diffusers/fastgen/inference_dmd2_qwen_image.py new file mode 100644 index 00000000000..7748eb15ec3 --- /dev/null +++ b/examples/diffusers/fastgen/inference_dmd2_qwen_image.py @@ -0,0 +1,528 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Inference pipeline for DMD2-trained Qwen-Image students. 
+ +Loads the consolidated safetensors transformer (from a §10 checkpoint with +``model_save_format=safetensors`` + ``save_consolidated=true``) plus the base +Qwen-Image VAE / text-encoder / tokenizer, and exposes a diffusers-style +``pipe(prompt=...).images[0]`` call that runs the DMD few-step sampler. + +Math is bit-aligned with the training-time ``_build_student_input`` in +``modelopt/torch/fastgen/methods/dmd.py``: + + Single-step (Phase 1 default): + x_T = noise * max_t # initial latent + v = student(x_T, t=max_t, text_emb) # one forward + x_0 = x_T - max_t * v # RF identity + image = vae.decode(x_0) + + Multi-step (Phase 2, ``num_inference_steps > 1``): + x = noise * t_list[0] # initial latent at t_max + for (t_cur, t_next) in zip(t_list[:-1], t_list[1:]): + v = student(x, t=t_cur, text_emb) # flow at t_cur + x_0 = x - t_cur * v # RF identity → x_0 estimate + if t_next > 0: + eps = (x - (1 - t_cur) * x_0) / t_cur # ODE: invert RF forward + x = (1 - t_next) * x_0 + t_next * eps # re-noise to t_next + else: + x = x_0 # final step + image = vae.decode(x) + +Usage:: + + from inference_dmd2_qwen_image import QwenImageDMDInferencePipeline + import torch + + pipe = QwenImageDMDInferencePipeline.from_pretrained( + student_path="/path/to/checkpoint/epoch_0_step_500/model/consolidated", + base_pipeline_path="Qwen/Qwen-Image", + ema_path=None, # or "…/epoch_0_step_5/ema_shadow.pt" + torch_dtype=torch.bfloat16, + ).to("cuda") + + image = pipe( + prompt="a small red cube on a white table", + num_inference_steps=1, + height=512, + width=512, + generator=torch.Generator("cuda").manual_seed(42), + ).images[0] + image.save("dmd2_smoke.png") +""" + +from __future__ import annotations + +import itertools +import logging +import os +from dataclasses import dataclass +from typing import TYPE_CHECKING + +import torch +from diffusers import QwenImagePipeline, QwenImageTransformer2DModel +from diffusers.utils.torch_utils import randn_tensor + +if TYPE_CHECKING: + from pathlib import 
logger = logging.getLogger(__name__)


@dataclass
class QwenImageDMDOutput:
    """Container for inference outputs — mirrors diffusers' pipeline outputs."""

    # Result of ``image_processor.postprocess`` for the requested ``output_type``.
    images: list


class QwenImageDMDInferencePipeline:
    """Thin inference wrapper around a DMD2-trained Qwen-Image student.

    Wraps a stock ``diffusers.QwenImagePipeline`` whose ``transformer`` field
    has been swapped for the trained student. Re-uses the pipeline's VAE,
    text encoder, tokenizer, and image-processor for everything *except* the
    denoising loop, which is replaced by the DMD few-step sampler.
    """

    def __init__(
        self,
        base_pipeline: QwenImagePipeline,
        max_t: float = 0.999,
    ) -> None:
        """Store the wrapped pipeline and the default initial timestep ``max_t``."""
        self._pipe = base_pipeline
        self.max_t = max_t

    # ------------------------------------------------------------------ #
    # Loading                                                            #
    # ------------------------------------------------------------------ #

    @staticmethod
    def _looks_like_hub_id(path: str) -> bool:
        """Return True when ``path`` has the ``name`` / ``namespace/name`` shape of a Hub repo id."""
        if not path or os.path.isabs(path) or path.startswith((".", "~")):
            return False
        return "\\" not in path and path.count("/") <= 1

    @classmethod
    def from_pretrained(
        cls,
        student_path: str | Path,
        base_pipeline_path: str | Path,
        ema_path: str | Path | None = None,
        torch_dtype: torch.dtype = torch.bfloat16,
        max_t: float = 0.999,
    ) -> QwenImageDMDInferencePipeline:
        """Load the student + base Qwen-Image components.

        Args:
            student_path: A consolidated dir from the safetensors save — must
                contain ``config.json``, ``diffusion_pytorch_model.safetensors.index.json``,
                and one or more ``*.safetensors`` shards. Loadable directly via
                ``QwenImageTransformer2DModel.from_pretrained``.
            base_pipeline_path: The base Qwen-Image checkpoint: either a local
                directory or a Hub repo id such as ``Qwen/Qwen-Image``. Used for
                every component except the transformer, which is replaced.
            ema_path: Optional ``ema_shadow.pt`` produced by ``_save_dmd_extras``.
                If provided, the EMA shadow weights are overlaid onto the
                student after the safetensors load.
            torch_dtype: dtype passed to both ``from_pretrained`` calls.
            max_t: Initial timestep for the 1-step sampler. Must match the
                recipe's ``sample_t_cfg.max_t`` (default 0.999).

        Raises:
            FileNotFoundError: If ``student_path`` is not a directory, or if
                ``base_pipeline_path`` is neither a directory nor shaped like a
                Hub repo id.
            ValueError: If the EMA file does not contain a dict of weights.
        """
        student_path = str(student_path)
        base_pipeline_path = str(base_pipeline_path)
        if not os.path.isdir(student_path):
            raise FileNotFoundError(f"student_path is not a directory: {student_path}")
        # A Hub repo id (the documented usage and the CLI default) is not a local
        # directory; let diffusers resolve it instead of rejecting it up front.
        if not os.path.isdir(base_pipeline_path) and not cls._looks_like_hub_id(
            base_pipeline_path
        ):
            raise FileNotFoundError(f"base_pipeline_path is not a directory: {base_pipeline_path}")

        logger.info("[DMD2-Inference] Loading trained student from %s", student_path)
        student = QwenImageTransformer2DModel.from_pretrained(student_path, torch_dtype=torch_dtype)

        if ema_path is not None:
            logger.info("[DMD2-Inference] Overlaying EMA shadow from %s", ema_path)
            # SECURITY NOTE(review): weights_only=False unpickles arbitrary objects.
            # Only pass EMA files from a trusted training run; switch to
            # weights_only=True if the file is known to hold plain tensors.
            ema_state = torch.load(str(ema_path), map_location="cpu", weights_only=False)
            # Accept either a bare state dict or one nested under a "shadow" key.
            shadow = (
                ema_state.get("shadow", ema_state) if isinstance(ema_state, dict) else ema_state
            )
            if not isinstance(shadow, dict):
                raise ValueError(
                    f"ema_shadow.pt content has unexpected type {type(shadow).__name__}; "
                    "expected dict[str, Tensor]."
                )
            missing, unexpected = student.load_state_dict(shadow, strict=False)
            if unexpected:
                logger.warning(
                    "[DMD2-Inference] EMA overlay had %d unexpected keys (first: %s)",
                    len(unexpected),
                    unexpected[:3],
                )
            if missing:
                logger.warning(
                    "[DMD2-Inference] EMA overlay missed %d student keys (first: %s)",
                    len(missing),
                    missing[:3],
                )

        student.eval()

        logger.info(
            "[DMD2-Inference] Loading base Qwen-Image pipeline from %s (transformer replaced)",
            base_pipeline_path,
        )
        # Passing transformer= supplies the student in place of the base
        # transformer; the remaining components load from base_pipeline_path.
        pipe = QwenImagePipeline.from_pretrained(
            base_pipeline_path,
            transformer=student,
            torch_dtype=torch_dtype,
        )

        return cls(base_pipeline=pipe, max_t=max_t)

    def to(self, device: str | torch.device) -> QwenImageDMDInferencePipeline:
        """Move the wrapped pipeline to ``device`` and return ``self`` for chaining."""
        self._pipe.to(device)
        return self

    @property
    def device(self) -> torch.device:
        """Device reported by the student transformer."""
        return self._pipe.transformer.device

    @property
    def dtype(self) -> torch.dtype:
        """dtype of the first parameter of the student transformer."""
        return next(self._pipe.transformer.parameters()).dtype

    # ------------------------------------------------------------------ #
    # Inference                                                          #
    # ------------------------------------------------------------------ #

    @staticmethod
    def _resolve_schedule(num_inference_steps: int, max_t: float, t_list: list | None) -> list:
        """Return the ``num_inference_steps + 1`` timesteps, from the start value down to 0."""
        if num_inference_steps == 1:
            # Single-step sampler: t_list is not consulted.
            return [max_t, 0.0]
        if t_list is None:
            # Default: linear schedule from max_t to 0 (matches FastGen's
            # torch.linspace(max_t, 0, sample_steps + 1) fallback).
            return torch.linspace(max_t, 0.0, num_inference_steps + 1).tolist()
        if len(t_list) != num_inference_steps + 1:
            raise ValueError(
                f"t_list must have num_inference_steps+1 entries "
                f"(got {len(t_list)} for num_inference_steps={num_inference_steps})"
            )
        schedule = [float(t) for t in t_list]
        if abs(schedule[-1]) > 1e-6:
            raise ValueError(
                f"t_list must end at 0.0 (got {schedule[-1]}); the final step lands on x_0."
            )
        return schedule

    def _predict_flow(self, x_packed, embeds, embeds_mask, timestep, img_shapes, txt_seq_lens):
        """Run one student forward pass and return its first output (the predicted flow)."""
        return self._pipe.transformer(
            hidden_states=x_packed,
            encoder_hidden_states=embeds,
            encoder_hidden_states_mask=embeds_mask,
            timestep=timestep,
            img_shapes=img_shapes,
            txt_seq_lens=txt_seq_lens,
            guidance=None,
            return_dict=False,
        )[0]

    @torch.no_grad()
    def __call__(
        self,
        prompt: str | list,
        negative_prompt: str | list | None = None,
        num_inference_steps: int = 1,
        guidance_scale: float = 1.0,
        height: int = 1024,
        width: int = 1024,
        num_images_per_prompt: int = 1,
        generator: torch.Generator | None = None,
        max_t: float | None = None,
        t_list: list | None = None,
        sample_type: str = "ode",
        output_type: str = "pil",
        max_sequence_length: int = 512,
    ) -> QwenImageDMDOutput:
        """Generate image(s) from the trained DMD2 student.

        ``num_inference_steps == 1`` runs the single-step sampler
        (``x_0 = x_T - max_t * v``). ``num_inference_steps > 1`` runs the
        multi-step DMD unroll using ``t_list`` (or
        ``linspace(max_t, 0, num_inference_steps + 1)`` if ``t_list`` is None).

        ``sample_type`` selects between deterministic (``"ode"``, recovers eps
        from x_0 via the RF identity) and stochastic (``"sde"``, fresh noise per
        step). Matches FastGen's ``student_sample_type``.

        ``guidance_scale != 1.0`` activates inference-time classifier-free
        guidance: the student is called twice (positive + negative prompt) per
        step and the two flows are blended as ``v = v_neg + s*(v_pos - v_neg)``.
        ``negative_prompt`` defaults to ``""`` when unset and CFG is engaged.

        **Note on DMD2 students trained with CFG**: a student trained with
        ``dmd2.guidance_scale=4.0`` has *already* internalised CFG. Pass
        ``guidance_scale=1.0`` at inference to avoid double-CFG. Use
        ``guidance_scale > 1.0`` only for students that were trained without
        CFG (``dmd2.guidance_scale=null``).

        Raises:
            ValueError: For an unknown ``sample_type``, ``num_inference_steps < 1``,
                or a ``t_list`` with the wrong length or a non-zero last entry.
        """
        if sample_type not in ("ode", "sde"):
            raise ValueError(f"sample_type must be 'ode' or 'sde', got {sample_type!r}")
        if num_inference_steps < 1:
            # Zero steps would skip the loop and decode the raw initial noise.
            raise ValueError(f"num_inference_steps must be >= 1, got {num_inference_steps}")
        do_cfg = guidance_scale != 1.0
        if do_cfg and negative_prompt is None:
            # Default to the empty unconditional prompt.
            negative_prompt = ""

        # ---- 1. Resolve the timestep schedule ---------------------------------
        max_t = float(max_t) if max_t is not None else float(self.max_t)
        schedule = self._resolve_schedule(num_inference_steps, max_t, t_list)

        pipe = self._pipe
        device = self.device
        dtype = self.dtype

        # ---- 2. Encode prompt(s) ----------------------------------------------
        prompt_embeds, prompt_embeds_mask = pipe.encode_prompt(
            prompt=prompt,
            device=device,
            num_images_per_prompt=num_images_per_prompt,
            max_sequence_length=max_sequence_length,
        )
        neg_prompt_embeds = None
        neg_prompt_embeds_mask = None
        if do_cfg:
            neg_prompt_embeds, neg_prompt_embeds_mask = pipe.encode_prompt(
                prompt=negative_prompt,
                device=device,
                num_images_per_prompt=num_images_per_prompt,
                max_sequence_length=max_sequence_length,
            )
        txt_seq_lens = (
            prompt_embeds_mask.sum(dim=1).int().tolist() if prompt_embeds_mask is not None else None
        )
        neg_txt_seq_lens = (
            neg_prompt_embeds_mask.sum(dim=1).int().tolist()
            if neg_prompt_embeds_mask is not None
            else None
        )

        # ---- 3. Build initial noisy latents at t = schedule[0] ----------------
        prompt_count = 1 if isinstance(prompt, str) else len(prompt)
        batch_size = prompt_count * num_images_per_prompt

        num_channels_latents = pipe.transformer.config.in_channels // 4
        h_lat = 2 * (height // (pipe.vae_scale_factor * 2))
        w_lat = 2 * (width // (pipe.vae_scale_factor * 2))
        latent_shape = (batch_size, 1, num_channels_latents, h_lat, w_lat)

        # DMD initial latents: noise * schedule[0] (RF: sigma(t0) = t0).
        noise = randn_tensor(latent_shape, generator=generator, device=device, dtype=dtype)
        latents_5d = noise * schedule[0]

        x_packed = pipe._pack_latents(latents_5d, batch_size, num_channels_latents, h_lat, w_lat)
        img_shapes = [[(1, h_lat // 2, w_lat // 2)]] * batch_size

        # ---- 4. DMD few-step unroll -------------------------------------------
        for t_cur, t_next in itertools.pairwise(schedule):
            timestep = torch.tensor([t_cur], device=device, dtype=dtype).expand(batch_size)
            flow_packed = self._predict_flow(
                x_packed, prompt_embeds, prompt_embeds_mask, timestep, img_shapes, txt_seq_lens
            )
            if do_cfg:
                # CFG two-pass: ``v_cfg = v_neg + s*(v_pos - v_neg)``, which
                # expands to ``v_pos + (s-1) * (v_pos - v_neg)``.
                neg_flow_packed = self._predict_flow(
                    x_packed,
                    neg_prompt_embeds,
                    neg_prompt_embeds_mask,
                    timestep,
                    img_shapes,
                    neg_txt_seq_lens,
                )
                flow_packed = (
                    neg_flow_packed.to(torch.float64)
                    + float(guidance_scale)
                    * (flow_packed.to(torch.float64) - neg_flow_packed.to(torch.float64))
                ).to(dtype)

            # RF identity: x_0 = x_t - t_cur * v (computed in fp64 for stability).
            x0_packed = (
                x_packed.to(torch.float64) - float(t_cur) * flow_packed.to(torch.float64)
            ).to(dtype)

            if t_next <= 1e-6:
                # Last step: x_0 is the output.
                x_packed = x0_packed
                continue

            # Re-noise x_0 forward to t_next.
            if sample_type == "ode":
                # Deterministic: invert the RF forward to recover the implied eps.
                # eps = (x_t - (1 - t_cur) * x_0) / t_cur
                alpha_cur = 1.0 - float(t_cur)
                eps_packed = (
                    (x_packed.to(torch.float64) - alpha_cur * x0_packed.to(torch.float64))
                    / max(float(t_cur), 1e-6)
                ).to(dtype)
            else:
                # Stochastic: fresh Gaussian noise. randn_tensor (as used for the
                # initial latents) also copes with a generator that lives on a
                # different device than the latents.
                eps_packed = randn_tensor(
                    tuple(x_packed.shape), generator=generator, device=device, dtype=dtype
                )
            # RF forward: x_{t_next} = (1 - t_next) * x_0 + t_next * eps.
            alpha_next = 1.0 - float(t_next)
            x_packed = (
                alpha_next * x0_packed.to(torch.float64)
                + float(t_next) * eps_packed.to(torch.float64)
            ).to(dtype)

        # Unpack to 5D for VAE.
        x0_5d = pipe._unpack_latents(x_packed, height, width, pipe.vae_scale_factor)

        # ---- 5. VAE decode ----------------------------------------------------
        # Undo the latent normalisation using the VAE's configured mean / std.
        latents_mean = (
            torch.tensor(pipe.vae.config.latents_mean)
            .view(1, pipe.vae.config.z_dim, 1, 1, 1)
            .to(device=device, dtype=dtype)
        )
        # NOTE: despite its name this holds 1/std, so dividing by it multiplies by std.
        latents_std = 1.0 / torch.tensor(pipe.vae.config.latents_std).view(
            1, pipe.vae.config.z_dim, 1, 1, 1
        ).to(device=device, dtype=dtype)
        x0_scaled = x0_5d / latents_std + latents_mean

        # vae.decode returns 5D; [:, :, 0] drops the temporal dim since
        # Qwen-Image treats images as 1-frame videos.
        image_5d = pipe.vae.decode(x0_scaled, return_dict=False)[0]
        image_4d = image_5d[:, :, 0]  # [B, C, H, W]

        images = pipe.image_processor.postprocess(image_4d, output_type=output_type)
        return QwenImageDMDOutput(images=images)
def _smoke_test(
    student_path: str,
    base_pipeline_path: str,
    output_png: str,
    ema_path: str | None = None,
    prompt: str = "a small red cube on a white table",
    height: int = 512,
    width: int = 512,
    seed: int = 42,
) -> None:
    """Run a one-shot, single-step inference and dump the image plus a stats file.

    Standalone smoke-test driver: it validates end-to-end wiring only, so the
    pass criterion is that an image is produced and written, not that it is
    coherent. A small JSON sidecar (``<image stem>_stats.json``) with
    shape/dtype/range stats is written next to the image so the result can be
    machine-checked.

    Args:
        student_path: Consolidated student checkpoint directory.
        base_pipeline_path: Base Qwen-Image pipeline (Hub id or local snapshot).
        output_png: Where to save the generated image.
        ema_path: Optional EMA shadow file to overlay on the student.
        prompt: Text prompt to render.
        height: Output height in pixels.
        width: Output width in pixels.
        seed: Seed for the noise generator.
    """
    import json

    import numpy as np

    logging.basicConfig(level=logging.INFO)
    pipe = QwenImageDMDInferencePipeline.from_pretrained(
        student_path=student_path,
        base_pipeline_path=base_pipeline_path,
        ema_path=ema_path,
        torch_dtype=torch.bfloat16,
    )
    device = "cuda" if torch.cuda.is_available() else "cpu"
    pipe = pipe.to(device)
    gen = torch.Generator(device=device).manual_seed(seed)

    out = pipe(
        prompt=prompt,
        num_inference_steps=1,
        height=height,
        width=width,
        generator=gen,
    )
    image = out.images[0]

    # PIL image; sanity-check shape + range.
    arr = np.array(image)
    stats = {
        "prompt": prompt,
        "height": height,
        "width": width,
        "seed": seed,
        "ema_overlay": ema_path is not None,
        "image_shape": list(arr.shape),
        "image_dtype": str(arr.dtype),
        "image_min": int(arr.min()),
        "image_max": int(arr.max()),
        "image_mean": float(arr.mean()),
        "image_std": float(arr.std()),
        "is_finite": bool(np.isfinite(arr).all()),
        "is_not_constant": bool(arr.std() > 0),
    }

    # A bare file name has an empty dirname, which os.makedirs rejects.
    out_dir = os.path.dirname(output_png)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    image.save(output_png)

    # Derive the sidecar from the file stem only: str.replace(".png", ...) would
    # also rewrite ".png" inside directory names and, for a non-".png" output,
    # leave the path unchanged so the stats would overwrite the image.
    stem, _ext = os.path.splitext(output_png)
    sidecar = f"{stem}_stats.json"
    with open(sidecar, "w") as f:
        json.dump(stats, f, indent=2)
    print(json.dumps(stats, indent=2))
    print(f"\nImage saved to: {output_png}")
    print(f"Stats sidecar: {sidecar}")


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--student_path",
        required=True,
        help="Path to the consolidated safetensors student checkpoint "
        "(e.g. .../epoch_0_step_500/model/consolidated).",
    )
    parser.add_argument(
        "--base_pipeline_path",
        default="Qwen/Qwen-Image",
        help="Base Qwen-Image pipeline (HF id or local snapshot) for the VAE / text-encoder / tokenizer.",
    )
    parser.add_argument(
        "--output_png",
        default="./outputs/dmd2_sample.png",
    )
    parser.add_argument("--ema_path", default=None)
    parser.add_argument("--prompt", default="a small red cube on a white table")
    parser.add_argument("--height", type=int, default=512)
    parser.add_argument("--width", type=int, default=512)
    parser.add_argument("--seed", type=int, default=42)
    args = parser.parse_args()

    _smoke_test(
        student_path=args.student_path,
        base_pipeline_path=args.base_pipeline_path,
        output_png=args.output_png,
        ema_path=args.ema_path,
        prompt=args.prompt,
        height=args.height,
        width=args.width,
        seed=args.seed,
    )
+wandb diff --git a/examples/diffusers/sparsity/README.md b/examples/diffusers/sparsity/README.md index a2b071d6c0e..e102f1b6d5d 100644 --- a/examples/diffusers/sparsity/README.md +++ b/examples/diffusers/sparsity/README.md @@ -43,10 +43,11 @@ python wan22_skip_softmax.py \ --skip-softmax-threshold 0.61557 \ --prompt "A cat playing piano" --output out.mp4 -# With calibration +# Calibrate + export for TRT-LLM deployment (typical flow) python wan22_skip_softmax.py \ --model-path /path/to/Wan2.2-T2V-A14B-Diffusers \ - --calibrate --target-sparsity 0.5 \ + --calibrate --target-sparsity 0.5 --calib-size 4 \ + --export-dir /path/to/wan22-skip-softmax-ckpt \ --prompt "A cat playing piano" --output out.mp4 # Dense baseline (no sparsity, for comparison) @@ -62,6 +63,13 @@ python wan22_skip_softmax.py \ --prompt "A cat playing piano" --output out.mp4 ``` +`--export-dir` writes a Hugging Face checkpoint with the calibrated +`threshold_scale_factor` block embedded in each component's `config.json` +(under the `sparse_attention_config` key). TRT-LLM's +`SkipSoftmaxAttentionConfig.resolve_for_target_sparsity` reads the +`(a, b)` directly via `coeffs['a'] * math.exp(coeffs['b'] * sparsity)` — +no extra conversion needed downstream. + ## Threshold Modes | Mode | How threshold reaches the kernel | Use case | @@ -73,4 +81,5 @@ python wan22_skip_softmax.py \ ## Known Issues - **14B dual transformer calibration**: Transformers are calibrated sequentially — transformer_2's calibration runs while transformer_1 is already sparsified, introducing asymmetric calibration conditions. +- **Resolution-dependent fit**: The kernel's intrinsic `S(λ)` curve depends on the spatial resolution (different attention statistics at 480×832 vs 720×1280). For tightest target ≈ achieved alignment, calibrate at the deployment `(height, width, frames)`. Within a fixed spatial resolution, achieved sparsity stays roughly aligned across frame counts. 
- **Minimum achievable sparsity**: Even the strictest threshold may yield 30-40% sparsity on diffusion models (many tiles are inherently negligible). Targets below this floor cause extrapolation; an inference-time warning is emitted. diff --git a/examples/diffusers/sparsity/wan22_skip_softmax.py b/examples/diffusers/sparsity/wan22_skip_softmax.py index 3f4447e0bad..890cec547ff 100644 --- a/examples/diffusers/sparsity/wan22_skip_softmax.py +++ b/examples/diffusers/sparsity/wan22_skip_softmax.py @@ -59,6 +59,7 @@ from diffusers.utils import export_to_video import modelopt.torch.sparsity.attention_sparsity as mtsa +from modelopt.torch.export import export_hf_checkpoint from modelopt.torch.sparsity.attention_sparsity.sparse_attention import SparseAttentionModule DEFAULT_MODEL_PATH = os.environ.get("WAN22_MODEL_PATH", "Wan-AI/Wan2.2-TI2V-5B-Diffusers") @@ -190,8 +191,8 @@ def parse_args() -> argparse.Namespace: parser.add_argument( "--calib-frames", type=int, - default=151, - help="Number of frames for calibration", + default=None, + help="Number of frames for calibration (default: same as --num-frames).", ) parser.add_argument( "--calib-size", @@ -199,6 +200,24 @@ def parse_args() -> argparse.Namespace: default=4, help="Number of calibration prompts from OpenVid-1M dataset", ) + + # Export options + parser.add_argument( + "--export-dir", + type=str, + default=None, + help="Export sparsified model as a HuggingFace checkpoint to this directory. " + "The sparse_attention_config (calibration params, disabled layers, etc.) " + "is written into each component's config.json.", + ) + parser.add_argument( + "--initial-disabled-steps", + type=int, + default=0, + help="User-specified initial-disabled-steps value carried into the exported " + "sparse attention config (config.json). 
Only emitted when > 0; not interpreted " + "at sparsify/calibration time.", + ) return parser.parse_args() @@ -233,6 +252,10 @@ def build_sparse_config(args: argparse.Namespace, num_blocks: int) -> dict: if args.skip_softmax_threshold is not None: attn_cfg["skip_softmax_threshold"] = args.skip_softmax_threshold + # Opt-in initial-disabled-steps metadata — carried through to the exported config when > 0. + if args.initial_disabled_steps > 0: + attn_cfg["initial_disabled_steps"] = args.initial_disabled_steps + sparse_cfg: dict = { "*.attn1*": attn_cfg, # Self-attention only "*.attn2*": {"enable": False}, # Text cross-attention @@ -415,11 +438,12 @@ def main() -> None: if args.calibrate: print("Warning: --calibrate is ignored when --skip-softmax-threshold is set") elif args.calibrate: + calib_frames = args.calib_frames if args.calib_frames is not None else args.num_frames forward_loop = build_calibration_forward_loop( pipe, calib_size=args.calib_size, num_steps=args.calib_steps, - num_frames=args.calib_frames, + num_frames=calib_frames, height=args.height, width=args.width, seed=args.seed, @@ -445,6 +469,11 @@ def main() -> None: torch.cuda.empty_cache() print("Cleared CUDA cache after calibration") + # ---- Export (optional) ---- + if args.export_dir and not args.baseline: + print(f"Exporting sparsified checkpoint to {args.export_dir}...") + export_hf_checkpoint(pipe, export_dir=args.export_dir) + # ---- Generate (optional) ---- if args.prompt: # Enable runtime sparsity measurement before generation diff --git a/examples/llm_eval/lm_eval_tensorrt_llm.py b/examples/llm_eval/lm_eval_tensorrt_llm.py index 181fc9c79f1..f65dad53655 100644 --- a/examples/llm_eval/lm_eval_tensorrt_llm.py +++ b/examples/llm_eval/lm_eval_tensorrt_llm.py @@ -64,6 +64,10 @@ def __init__( tokenizer=self.tokenizer, max_batch_size=int(batch_size), max_seq_len=max_length, + # Loglikelihood tasks request context logits. 
KV cache prefix reuse would return + # logits only for the recomputed suffix on shared-prefix requests (e.g. hellaswag), + # truncating context_logits and breaking parse_logprobs. Disable it. + enable_kv_cache_reuse=False, ) self.max_length = max_length - 1 logger.info("Loaded TRT-LLM") diff --git a/examples/llm_eval/mmlu.py b/examples/llm_eval/mmlu.py index a3047fcc9e1..3d03240c408 100755 --- a/examples/llm_eval/mmlu.py +++ b/examples/llm_eval/mmlu.py @@ -183,7 +183,10 @@ def gen_prompt(train_df, subject, k=-1): def evaluate(args, subject, model: EvalModel | LLM, dev_df, test_df): cors = [] all_probs = [] - for i in range(test_df.shape[0]): + num_examples = test_df.shape[0] + if args.limit is not None: + num_examples = min(num_examples, args.limit) + for i in range(num_examples): # get prompt and make sure it fits k = args.ntrain prompt_end = format_example(test_df, i, include_answer=False) @@ -201,6 +204,12 @@ def check_valid_length(model, prompt): train_prompt = gen_prompt(dev_df, subject, k) prompt = train_prompt + prompt_end + # Skip examples that do not fit even at zero-shot, otherwise the backend rejects + # prompts longer than max_seq_len and aborts the whole evaluation. + if not check_valid_length(model, prompt): + print(f"Skipping {subject} example {i}: prompt exceeds max_seq_len even at 0-shot.") + continue + label = test_df.iloc[i, test_df.shape[1] - 1] if isinstance(model, EvalModel): pred = model.run(prompt) @@ -212,7 +221,11 @@ def check_valid_length(model, prompt): cors.append(cor) all_probs.append(probs) - acc = np.mean(cors) + if not cors: + # Every example was skipped (all prompts exceeded max_seq_len). Surface it instead of + # silently producing a nan accuracy downstream. 
+ print(f"WARNING: all {subject} examples were skipped; reporting accuracy as nan.") + acc = np.mean(cors) if cors else float("nan") cors = np.array(cors) all_probs = np.array(all_probs) @@ -233,8 +246,12 @@ def main( auto_quantize_score_size: int = 128, auto_quantize_checkpoint: str | None = None, sparse_cfg: str | None = None, + limit: int | None = None, **kwargs, ): + if limit is not None and limit <= 0: + raise ValueError(f"limit must be a positive integer when provided, got {limit}.") + random.seed(RAND_SEED) np.random.seed(RAND_SEED) diff --git a/examples/llm_eval/run_simple_eval.sh b/examples/llm_eval/run_simple_eval.sh index 5f40b4ce8b3..763e88432f8 100644 --- a/examples/llm_eval/run_simple_eval.sh +++ b/examples/llm_eval/run_simple_eval.sh @@ -28,11 +28,23 @@ if [ ! -d "human-eval" ]; then git clone https://github.com/openai/human-eval.git fi +# Pin to a known commit for reproducibility (and so the entry-point patch below matches), forcing +# it every run so a reused checkout cannot drift to an arbitrary revision. -f discards the patch +# applied to setup.py on a previous run before re-applying it below. +git -C human-eval checkout -q -f 6d43fb980f9fee3c892a914eda09951f772ad10d + +# human-eval's console_scripts entry point lacks the ":callable" suffix, which newer pip/setuptools +# reject ("A callable suffix is required"). The target module defines main(), so point at it. +sed -i 's|human_eval\.evaluate_functional_correctness"|human_eval.evaluate_functional_correctness:main"|' human-eval/setup.py + if [ ! -d "simple-evals" ]; then git clone https://github.com/openai/simple-evals.git fi -pip install -e human-eval +# --no-build-isolation: human-eval's legacy setup.py imports pkg_resources at build time, +# which pip's isolated build env does not provide with newer setuptools. Build against the +# base environment (which has setuptools/pkg_resources) instead. 
+pip install -e human-eval --no-build-isolation pip install openai pushd simple-evals diff --git a/examples/llm_ptq/cast_mxfp4_to_nvfp4.py b/examples/llm_ptq/cast_mxfp4_to_nvfp4.py index 90b8521e5d7..e8e86966f3b 100644 --- a/examples/llm_ptq/cast_mxfp4_to_nvfp4.py +++ b/examples/llm_ptq/cast_mxfp4_to_nvfp4.py @@ -41,6 +41,12 @@ from safetensors import safe_open from modelopt.torch.quantization.nn.modules.tensor_quantizer import NVFP4StaticQuantizer +from modelopt.torch.quantization.utils.numeric_utils import ( + E2M1_MAX, + E8M0_BIAS, + mxfp4_to_nvfp4_global_amax, + mxfp4_to_nvfp4_per_block_amax, +) @contextmanager @@ -62,150 +68,6 @@ def read(key: str, shard: Path) -> torch.Tensor: yield read -E8M0_BIAS = 127 # E8M0 stores k_j as uint8 with bias 127 -E2M1_MAX = 6.0 -E4M3_MAX = 448.0 -E4M3_KMAX = 8 -E4M3_KMIN = -9 # E4M3 represents 2^k exactly for k in [-9, 8] -# E2M1 magnitude grid indexed by the low 3 bits of an FP4 nibble. -_E2M1_MAGNITUDE = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0] -# Cache of the E2M1 magnitude lookup table per (device, dtype) so we don't -# rebuild it for every layer in a batched cast. -_E2M1_MAG_CACHE: "dict[tuple, torch.Tensor]" = {} - - -def _e2m1_magnitude_table(device: torch.device, dtype: torch.dtype = torch.float32) -> torch.Tensor: - """Return ``_E2M1_MAGNITUDE`` as a tensor on the requested device, cached.""" - key = (device, dtype) - cached = _E2M1_MAG_CACHE.get(key) - if cached is None: - cached = torch.tensor(_E2M1_MAGNITUDE, dtype=dtype, device=device) - _E2M1_MAG_CACHE[key] = cached - return cached - - -def compute_global_amax_for_scales(e8m0_scales: torch.Tensor) -> tuple[float, dict]: - """Closed-form per-tensor ``global_amax``: ``m = k_max - 8``, ``global_amax = 6 * 448 * 2^m``. - - Args: - e8m0_scales: uint8 tensor of E8M0 scales for one MXFP4 source layer. - - Returns: - global_amax: scalar (float) — pins NVFP4 scale_2 to 2^m. - info: diagnostic dict with k_min, k_max, m, lossless-block stats. - """ - # k_j = e8m0 - 127. 
MXFP4 quantize emits e8m0=0 (=> k=-127) for all-zero - # blocks; treat those as "ignore me" when computing k_max. - k = e8m0_scales.to(torch.int32) - E8M0_BIAS - nonzero_mask = e8m0_scales > 0 - if nonzero_mask.any(): - k_nonzero = k[nonzero_mask] - k_min = int(k_nonzero.min().item()) - k_max = int(k_nonzero.max().item()) - else: - k_min = k_max = 0 - - m = k_max - E4M3_KMAX - global_amax = E2M1_MAX * E4M3_MAX * float(2.0**m) - - # A block is lossless under this cast iff k_max - k_j <= 17 (its k_j - m sits - # in E4M3's [-9, 8] window). All-zero blocks are trivially lossless because - # their reconstruction is 0 regardless of the snapped scale. - n_total = e8m0_scales.numel() - in_range = (k >= (k_max - 17)) | (~nonzero_mask) - n_lossless = int(in_range.sum().item()) - pct_lossless = 100.0 * n_lossless / n_total if n_total else 100.0 - - return global_amax, { - "k_min": k_min, - "k_max": k_max, - "m": m, - "n_total_blocks": n_total, - "n_lossless_blocks": n_lossless, - "pct_lossless": pct_lossless, - "n_zero_blocks": int((~nonzero_mask).sum().item()), - } - - -def compute_per_block_amax_for_mxfp4( - blocks: torch.Tensor, e8m0_scales: torch.Tensor -) -> torch.Tensor: - """Hybrid per-NVFP4-block amax for MXFP4 -> NVFP4 cast. - - Each MXFP4 block of 32 elements has one E8M0 exponent ``k_j``. Two cases - based on whether ``k_j`` fits in NVFP4's E4M3 scale grid (with - ``m = k_max - 8`` chosen by ``compute_global_amax_for_scales``): - - - **In-range** (``k_j - m`` in ``[-9, 8]``): ``6 * 2^k_j`` (closed-form - ideal). The resulting per-block scale ``2^(k_j - m)`` is exactly - representable in E4M3 — no rounding loss — and - ``round_to_E2M1(value / 2^k_j)`` yields the original MXFP4 nibble - verbatim. Bit-exact reconstruction. - - - **Out of range** (``|k_j - m| > 8/9``): ``max_nibble * 2^k_j``, i.e. - ``max(|w_block|)`` where ``w`` is the MXFP4-dequantized block. This is - the data-derived per-block amax. 
The per-block scale will still get - clamped at the E4M3 boundary, but data-derived amax keeps the post-clamp - scale closer to the block's actual magnitude than the closed-form ideal - would, which reduces re-bucketing error for OOR blocks where - ``max_nibble < 6``. - - Two NVFP4 blocks of 16 share each MXFP4 block's ``k_j``, so the result is - expanded by ``repeat_interleave(2, dim=-1)``. - - Args: - blocks: uint8 tensor of packed E2M1 nibbles, shape - ``(..., num_mxfp4_blocks, 16)`` (16 bytes per 32-element MXFP4 block). - e8m0_scales: uint8 tensor of E8M0 scales, shape - ``(..., num_mxfp4_blocks)``. - - Returns: - float32 tensor of shape ``(..., 2 * num_mxfp4_blocks)``. - """ - if blocks.shape[-1] != 16 or blocks.shape[:-1] != e8m0_scales.shape: - raise ValueError( - f"shape mismatch: blocks {tuple(blocks.shape)} " - "(expected (..., num_mxfp4_blocks, 16)) " - f"vs scales {tuple(e8m0_scales.shape)}" - ) - - k = e8m0_scales.to(torch.int32) - E8M0_BIAS # (..., num_mxfp4_blocks) - pow2_k = torch.exp2(k.float()) - closed_form_ideal = E2M1_MAX * pow2_k # (..., num_mxfp4_blocks) - - # ``m = k_max - 8`` over non-zero blocks. Compute via masked ``amax`` so - # ``m`` stays a 0-d tensor and we avoid a GPU->CPU sync just to get a - # Python int. All-zero scales fall through with the -E8M0_BIAS sentinel, - # which leaves every block trivially in-range (closed_form_ideal == 0 there). - nonzero = e8m0_scales > 0 - sentinel = torch.full_like(k, -E8M0_BIAS) - k_max = torch.where(nonzero, k, sentinel).amax() - delta = k - (k_max - E4M3_KMAX) - in_range = (delta >= E4M3_KMIN) & (delta <= E4M3_KMAX) - - # Fast path: if every block fits E4M3's [-9, 8] window the per-block amax - # is just the closed-form ideal, and we can skip the per-byte nibble scan - # over the block tensor (which is 16x larger than the scales). For typical - # MXFP4 checkpoints (e.g. gpt-oss-20b) this is the only path ever taken. 
- if bool(in_range.all()): - return closed_form_ideal.repeat_interleave(2, dim=-1) - - # OOR fallback: data-derived per-block amax = max(|w_block|) after MXFP4 - # dequant = ``max_nibble * 2^k_j``. The MXFP4 nibble is sign-magnitude with - # sign in bit 3 and magnitude index in bits 0-2; we extract per-byte - # magnitudes, take the byte-wise max, then reduce across the 16 bytes to - # get the largest magnitude index in the 32-element block. - low = blocks & 0x07 - high = (blocks >> 4) & 0x07 - max_idx = torch.maximum(low, high).amax(dim=-1).long() - max_nibble = _e2m1_magnitude_table(blocks.device)[max_idx] - data_derived = max_nibble * pow2_k - - per_block_amax_mxfp4 = torch.where(in_range, closed_form_ideal, data_derived) - # Each MXFP4 block of 32 splits into two NVFP4 blocks of 16 sharing k_j. - return per_block_amax_mxfp4.repeat_interleave(2, dim=-1) - - def quantizer_name_from_blocks_key(blocks_key: str) -> str: """Map ``_blocks`` -> ``_weight_quantizer``. @@ -282,7 +144,7 @@ def build_amax_map(checkpoint_dir: str | Path) -> dict[str, dict]: for tensor_key, shard in sorted(scales_keys.items()): scales = read(tensor_key, shard) - global_amax, info = compute_global_amax_for_scales(scales) + global_amax, info = mxfp4_to_nvfp4_global_amax(scales) blocks_key = tensor_key[: -len("_scales")] + "_blocks" qname = quantizer_name_from_blocks_key(blocks_key) @@ -357,7 +219,7 @@ def apply_to_model( for tensor_key, shard in sorted(scales_keys.items()): scales = read(tensor_key, shard) - global_amax_value, info = compute_global_amax_for_scales(scales) + global_amax_value, info = mxfp4_to_nvfp4_global_amax(scales) n_total_layers += 1 if info["pct_lossless"] >= 100.0: n_lossless_layers += 1 @@ -410,7 +272,7 @@ def apply_to_model( ) else: blocks = read(blocks_key, blocks_shard) - per_block_amax = compute_per_block_amax_for_mxfp4(blocks, scales).to( + per_block_amax = mxfp4_to_nvfp4_per_block_amax(blocks, scales).to( dtype=torch.float32, device=device ) # Numel must match — 
calibration may store ``_amax`` flat (e.g. (N, 1)) diff --git a/examples/llm_ptq/example_utils.py b/examples/llm_ptq/example_utils.py index 57d9bebef43..65363a1460e 100755 --- a/examples/llm_ptq/example_utils.py +++ b/examples/llm_ptq/example_utils.py @@ -533,6 +533,25 @@ def _resolve_file(filename): module.__dict__.pop("weight", None) +def get_original_hf_quant_method(config) -> str | None: + """Return the checkpoint's original ``quantization_config.quant_method``, if any. + + Returns e.g. ``"mxfp4"`` for native MXFP4 checkpoints (OpenAI's gpt-oss family), or + ``None`` for unquantized models. Handles ``quantization_config`` stored as a dict or a + config object, and the nested ``text_config`` of multi-modal models. + """ + for cfg in (config, getattr(config, "text_config", None)): + quant_cfg = getattr(cfg, "quantization_config", None) + method = ( + quant_cfg.get("quant_method") + if isinstance(quant_cfg, dict) + else getattr(quant_cfg, "quant_method", None) + ) + if method: + return str(method) + return None + + def get_model( ckpt_path, device="cuda", @@ -636,6 +655,29 @@ def has_pack_quantized_config(config): trust_remote_code=trust_remote_code, dtype="auto", ) + elif get_original_hf_quant_method(hf_config) == "mxfp4": + # Native MXFP4 checkpoints (e.g. openai/gpt-oss-*) must be dequantized to + # plain BF16 experts (``GptOssExperts``) so ModelOpt can insert and export + # quantizers: the packed-kernel experts wrapper (``Mxfp4GptOssExperts``, + # used when the optional ``kernels`` package is present) is not supported by + # the unified HF export. Force dequantization regardless of whether + # ``kernels`` is installed. + # Local import: ``Mxfp4Config`` only exists in newer Transformers (gpt-oss support); + # importing it at module scope would break example_utils for users on older + # Transformers running unrelated (non-MXFP4) models. 
+ from transformers import Mxfp4Config + + # Load with a *sequential* device map (not "auto"): the MXFP4->BF16 dequant + # runs inside Transformers' threaded weight loader, and an "auto"/balanced + # split across multiple GPUs trips a CUDA illegal-memory access during dequant + # materialization. Sequential keeps each shard's dequant on a single device + # (the whole model lands on one GPU when it fits there). + model_kwargs["quantization_config"] = Mxfp4Config(dequantize=True) + model = AutoModelForCausalLM.from_pretrained( + ckpt_path, + device_map="cpu" if device == "cpu" else "sequential", + **model_kwargs, + ) else: architecture = hf_config.architectures[0] diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py index 3fe3f3ceb03..407d735b921 100755 --- a/examples/llm_ptq/hf_ptq.py +++ b/examples/llm_ptq/hf_ptq.py @@ -27,6 +27,7 @@ from cast_mxfp4_to_nvfp4 import apply_to_model as apply_cast_mxfp4_to_nvfp4 from cast_mxfp4_to_nvfp4 import force_weight_quantizers_static from example_utils import ( + _resolve_model_path, build_quant_cfg, copy_custom_model_files, create_vlm_calibration_loop, @@ -816,13 +817,15 @@ def pre_quantize( """ # Offline specdec models skip pre-quantize preview (no tokenizer or standard dataloader) if args.specdec_offline_dataset is not None: - return None, None + return None, None, None # Only run single sample for preview assert calib_dataloader is not None, "calib_dataloader is required for pre-quantize preview" - preview_input_ids = next(iter(calib_dataloader))[ - "input_features" if model_type == "whisper" else "input_ids" - ][0:1] + batch = next(iter(calib_dataloader)) + input_key = "input_features" if model_type == "whisper" else "input_ids" + preview_input_ids = batch[input_key][0:1] + # Pass attention_mask to generate(): HF cannot infer it when pad_token == eos_token. 
+ preview_attention_mask = batch["attention_mask"][0:1] if "attention_mask" in batch else None # Generate preview before quantization if args.skip_generate: @@ -846,9 +849,13 @@ def pre_quantize( trust_remote_code=args.trust_remote_code, ) else: - generated_ids_before_ptq = full_model.generate(preview_input_ids, max_new_tokens=100) + generated_ids_before_ptq = full_model.generate( + preview_input_ids, + attention_mask=preview_attention_mask, + max_new_tokens=100, + ) - return preview_input_ids, generated_ids_before_ptq + return preview_input_ids, preview_attention_mask, generated_ids_before_ptq def post_quantize( @@ -859,6 +866,7 @@ def post_quantize( tokenizer: PreTrainedTokenizerBase | None, processor: ProcessorMixin | None, preview_input_ids, + preview_attention_mask, generated_ids_before_ptq, is_nemotron_vl_model, first_text_speech_dataset, @@ -903,7 +911,11 @@ def post_quantize( pass elif model_type != "llama4" and not is_nemotron_vl_model: # Our fake quantizer may not be fully compatible with torch.compile. - generated_ids_after_ptq = full_model.generate(preview_input_ids, max_new_tokens=100) + generated_ids_after_ptq = full_model.generate( + preview_input_ids, + attention_mask=preview_attention_mask, + max_new_tokens=100, + ) elif is_nemotron_vl_model and tokenizer is not None: generated_ids_after_ptq = run_nemotron_vl_preview( full_model, @@ -1061,7 +1073,7 @@ def _is_layerwise(obj): # Detect if this is a Nemotron VL model using architecture-based detection is_nemotron_vl_model = is_nemotron_vl(full_model) - preview_input_ids, generated_ids_before_ptq = pre_quantize( + preview_input_ids, preview_attention_mask, generated_ids_before_ptq = pre_quantize( args, full_model, model_type, tokenizer, calib_dataloader, is_nemotron_vl_model ) @@ -1155,7 +1167,12 @@ def _is_layerwise(obj): # to NVFP4StaticQuantizer with a data-derived ``_global_amax``); we just # override that scalar with the closed-form value before export. 
if args.cast_mxfp4_to_nvfp4: - apply_cast_mxfp4_to_nvfp4(language_model, args.pyt_ckpt_path) + # The cast reads the source MXFP4 ``*_scales``/``*_blocks`` tensors from a local + # checkpoint directory. ``--pyt_ckpt_path`` may be a HF Hub ID (e.g. + # ``openai/gpt-oss-20b``); resolve it to the local snapshot dir that load_model's + # ``from_pretrained`` already populated so the cast works with the documented command. + source_ckpt_dir = _resolve_model_path(args.pyt_ckpt_path, args.trust_remote_code) + apply_cast_mxfp4_to_nvfp4(language_model, source_ckpt_dir) post_quantize( args, @@ -1165,6 +1182,7 @@ def _is_layerwise(obj): tokenizer, processor, preview_input_ids, + preview_attention_mask, generated_ids_before_ptq, is_nemotron_vl_model, first_text_speech_dataset, diff --git a/examples/llm_ptq/run_tensorrt_llm.py b/examples/llm_ptq/run_tensorrt_llm.py index 56d25df709c..f7cc588f40a 100644 --- a/examples/llm_ptq/run_tensorrt_llm.py +++ b/examples/llm_ptq/run_tensorrt_llm.py @@ -66,7 +66,14 @@ def run(args): print("TensorRT-LLM example outputs:") - llm = LLM(args.checkpoint_dir, tokenizer=tokenizer, max_batch_size=len(input_texts)) + # generate_context_logits() below requires KV cache reuse disabled: with prefix block reuse, + # shared-prefix inputs return truncated (silently incorrect) context logits. 
+ llm = LLM( + args.checkpoint_dir, + tokenizer=tokenizer, + max_batch_size=len(input_texts), + enable_kv_cache_reuse=False, + ) torch.cuda.cudart().cudaProfilerStart() outputs = llm.generate_text(input_texts, args.max_output_len) torch.cuda.cudart().cudaProfilerStop() diff --git a/examples/llm_ptq/scripts/huggingface_example.sh b/examples/llm_ptq/scripts/huggingface_example.sh index 541c349da08..3f51e5b73f3 100755 --- a/examples/llm_ptq/scripts/huggingface_example.sh +++ b/examples/llm_ptq/scripts/huggingface_example.sh @@ -29,6 +29,10 @@ for i in $(env | grep ^SLURM_ | cut -d"=" -f 1); do unset -v $i; done for i in $(env | grep ^PMI_ | cut -d"=" -f 1); do unset -v $i; done for i in $(env | grep ^PMIX_ | cut -d"=" -f 1); do unset -v $i; done +# Fail on errors inside pipelines (e.g. `python eval.py | tee result.txt`), otherwise a crashing +# eval is masked by tee's exit code and the script passes silently. +set -o pipefail + if [ -z "$MODEL_PATH" ]; then echo "Unsupported model argument: Expected a huggingface model path or model name" >&2 exit 1 @@ -216,7 +220,11 @@ if [[ $TASKS =~ "quant" ]] || [[ ! -d "$SAVE_PATH" ]] || [[ ! $(ls -A $SAVE_PATH RUN_ARGS+=" --trust_remote_code " fi - python run_tensorrt_llm.py --checkpoint_dir=$SAVE_PATH $RUN_ARGS + # Only run the deploy+generate smoke test when "quant" is explicitly requested. Eval tasks + # (lm_eval/mmlu/simple_eval) deploy the checkpoint themselves, so it is redundant there. 
+ if [[ $TASKS =~ "quant" ]]; then + python run_tensorrt_llm.py --checkpoint_dir=$SAVE_PATH $RUN_ARGS + fi fi if [[ -d "${MODEL_PATH}" ]]; then @@ -285,11 +293,16 @@ if [[ $TASKS =~ "mmlu" ]]; then tar -xf /tmp/mmlu.tar -C data && mv data/data $MMLU_DATA_PATH fi + mmlu_flags="" + if [ -n "$MMLU_LIMIT" ]; then + mmlu_flags+=" --limit $MMLU_LIMIT " + fi + python mmlu.py \ --model_name causal \ --model_path $MODEL_ABS_PATH \ --checkpoint_dir $SAVE_PATH \ - --data_dir $MMLU_DATA_PATH | tee $MMLU_RESULT + --data_dir $MMLU_DATA_PATH $mmlu_flags | tee $MMLU_RESULT popd fi @@ -304,16 +317,16 @@ if [[ $TASKS =~ "livecodebench" || $TASKS =~ "simple_eval" ]]; then trtllm-serve $SAVE_PATH --host 0.0.0.0 --port $PORT >$SAVE_PATH/serve.txt 2>&1 & SERVE_PID=$! - tail -f $SAVE_PATH/serve.txt | while read line; do - if echo "$line" | grep -q "Application startup complete"; then - echo "Application startup complete." - break - fi + # Poll the log instead of `tail -f | while ... break`: under `set -o pipefail` (set above), + # breaking out of that pipeline leaves tail to die by SIGPIPE, which would abort the script. + while ! grep -q "Application startup complete" $SAVE_PATH/serve.txt 2>/dev/null; do if ! kill -0 $SERVE_PID 2>/dev/null; then echo "trtllm-serve has exited." exit 1 fi + sleep 2 done + echo "Application startup complete." 
pushd ../llm_eval/ diff --git a/examples/llm_ptq/scripts/parser.sh b/examples/llm_ptq/scripts/parser.sh index 3be7706c4e6..3efed91bc32 100644 --- a/examples/llm_ptq/scripts/parser.sh +++ b/examples/llm_ptq/scripts/parser.sh @@ -28,6 +28,7 @@ parse_options() { LM_EVAL_TASKS="mmlu,gsm8k" LM_EVAL_LIMIT= SIMPLE_EVAL_TASKS="mmlu" + MMLU_LIMIT= TASKS="quant" @@ -38,7 +39,7 @@ parse_options() { CAST_MXFP4_TO_NVFP4=false # Parse command-line options - ARGS=$(getopt -o "" -l "model:,quant:,recipe:,kv_cache_quant:,tp:,pp:,sparsity:,awq_block_size:,calib:,calib_batch_size:,auto_quantize_bits:,output:,batch:,tasks:,lm_eval_tasks:,lm_eval_limit:,simple_eval_tasks:,simple_eval_limit:,trust_remote_code,use_seq_device_map,gpu_max_mem_percentage:,kv_cache_free_gpu_memory_fraction:,low_memory_mode,no-verbose,calib_dataset:,calib_seq:,auto_quantize_method:,auto_quantize_score_size:,auto_quantize_checkpoint:,moe_calib_experts_ratio:,cast_mxfp4_to_nvfp4" -n "$0" -- "$@") + ARGS=$(getopt -o "" -l "model:,quant:,recipe:,kv_cache_quant:,tp:,pp:,sparsity:,awq_block_size:,calib:,calib_batch_size:,auto_quantize_bits:,output:,batch:,tasks:,lm_eval_tasks:,lm_eval_limit:,simple_eval_tasks:,simple_eval_limit:,mmlu_limit:,trust_remote_code,use_seq_device_map,gpu_max_mem_percentage:,kv_cache_free_gpu_memory_fraction:,low_memory_mode,no-verbose,calib_dataset:,calib_seq:,auto_quantize_method:,auto_quantize_score_size:,auto_quantize_checkpoint:,moe_calib_experts_ratio:,cast_mxfp4_to_nvfp4" -n "$0" -- "$@") eval set -- "$ARGS" while true; do @@ -61,6 +62,7 @@ parse_options() { --lm_eval_limit ) LM_EVAL_LIMIT="$2"; shift 2;; --simple_eval_tasks ) SIMPLE_EVAL_TASKS="$2"; shift 2;; --simple_eval_limit ) SIMPLE_EVAL_LIMIT="$2"; shift 2;; + --mmlu_limit ) MMLU_LIMIT="$2"; shift 2;; --trust_remote_code ) TRUST_REMOTE_CODE=true; shift;; --use_seq_device_map ) USE_SEQ_DEVICE_MAP=true; shift;; --gpu_max_mem_percentage ) GPU_MAX_MEM_PERCENTAGE="$2"; shift 2;; @@ -161,6 +163,7 @@ parse_options() { echo 
"lm_eval_limit: $LM_EVAL_LIMIT" echo "simple_eval_tasks: $SIMPLE_EVAL_TASKS" echo "simple_eval_limit: $SIMPLE_EVAL_LIMIT" + echo "mmlu_limit: $MMLU_LIMIT" echo "num_sample: $NUM_SAMPLES" echo "use_seq_device_map: $USE_SEQ_DEVICE_MAP" echo "gpu_max_mem_percentage: $GPU_MAX_MEM_PERCENTAGE" diff --git a/examples/llm_qat/ARGUMENTS.md b/examples/llm_qat/ARGUMENTS.md index 579e235c346..0a3e2b7a12d 100644 --- a/examples/llm_qat/ARGUMENTS.md +++ b/examples/llm_qat/ARGUMENTS.md @@ -50,7 +50,7 @@ | Argument | Type | Default | Description | |----------|------|---------|-------------| | `--recipe` | `str` | `None` | Path to a quantization recipe YAML file (built-in or custom). Built-in recipes can be specified by relative path, e.g. 'general/ptq/nvfp4_default-kv_fp8'. Replaces the deprecated --quant_cfg flag. | -| `--quant_cfg` | `modelopt.torch.quantization.config.QuantizeConfig` | `None` | Deprecated: pre-quantize the model with a separate quantization step instead. Specify the quantization format for PTQ/QAT by name (e.g. NVFP4_DEFAULT_CFG). | +| `--quant_cfg` | `str` | `None` | Deprecated: pre-quantize the model with a separate quantization step instead. Specify the quantization format for PTQ/QAT by name (e.g. NVFP4_DEFAULT_CFG). | | `--calib_size` | `int` | `512` | Specify the calibration size for quantization. The calibration dataset is used to setup the quantization scale parameters for PTQ/QAT. | | `--compress` | `bool` | `False` | Whether to compress the model weights after quantization for QLoRA. This is useful for reducing the model size. | | `--calib_batch_size` | `int` | `1` | Batch size for calibration data during quantization. 
| diff --git a/examples/megatron_bridge/README.md b/examples/megatron_bridge/README.md index 715a9aefc3d..6bb1f19be5c 100644 --- a/examples/megatron_bridge/README.md +++ b/examples/megatron_bridge/README.md @@ -16,7 +16,7 @@ This directory contains examples of using Model Optimizer with [NeMo Megatron-Br > [!TIP] -> Checkout the [Nemotron-3-Nano-30B-A3B pruning + distillation (with data blend prep) + quantization tutorial](../pruning/minitron/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16/README.md) for a complete end-to-end workflow using Megatron-Bridge! +> Checkout the [Nemotron-3-Nano-30B-A3B pruning + distillation (with data blend prep) + quantization tutorial](tutorials/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16/README.md) for a complete end-to-end workflow using Megatron-Bridge! ## Pre-Requisites @@ -47,12 +47,6 @@ docker run \ > [!WARNING] > Use `python -m pip` instead of `pip` to avoid conflicts with the system-wide installed packages in the NeMo containers. You may also refer to this [doc](https://github.com/NVIDIA-NeMo/Megatron-Bridge/blob/main/docker/common/README.md#installing-packages-inside-the-container) on how to correctly install packages in the NeMo containers without breaking existing torch installation. -Also install additional dependencies from the [requirements.txt](./requirements.txt) file. - -```bash -python -m pip install -r requirements.txt -``` - You also need to login with your HuggingFace token to download gated datasets / models. Note that the default dataset for pruning and quantization is [`nemotron-post-training-dataset-v2`](https://huggingface.co/datasets/nvidia/Nemotron-Post-Training-Dataset-v2), which is gated. @@ -307,6 +301,9 @@ torchrun --nproc_per_node 1 prune_minitron.py --help > uneven PP by setting `--num_layers_in_first_pipeline_stage` and `--num_layers_in_last_pipeline_stage`. > E.g. for Qwen3-8B with 36 layers and 8 GPUs, you can set both to 3 to get 3-5-5-5-5-5-5-3 layers per GPU. 
+> [!NOTE] +> If you are pruning a Nemotron model and want to save the pruned model back in HF format, downgrade to `transformers<5` via `python -m pip install "transformers<5"` before pruning. + ## Resources - 📅 [Roadmap](https://github.com/NVIDIA/Model-Optimizer/issues/146) diff --git a/examples/megatron_bridge/requirements.txt b/examples/megatron_bridge/requirements.txt deleted file mode 100644 index ec38c2f7ee7..00000000000 --- a/examples/megatron_bridge/requirements.txt +++ /dev/null @@ -1,2 +0,0 @@ -# Saving some pruned models (e.g. Nemotron-3-Nano-30B-A3B-BF16) have issues with transformers>=5.0 -transformers<5.0 diff --git a/examples/pruning/minitron/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16/ABLATIONS.md b/examples/megatron_bridge/tutorials/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16/ABLATIONS.md similarity index 100% rename from examples/pruning/minitron/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16/ABLATIONS.md rename to examples/megatron_bridge/tutorials/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16/ABLATIONS.md diff --git a/examples/megatron_bridge/tutorials/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16/README.md b/examples/megatron_bridge/tutorials/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16/README.md new file mode 100644 index 00000000000..1c7729da51a --- /dev/null +++ b/examples/megatron_bridge/tutorials/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16/README.md @@ -0,0 +1,545 @@ +# Nemotron-3-Nano-30B-A3B: Prune + Distill + Quantize + vLLM Deployment + +End-to-end optimization of [NVIDIA-Nemotron-3-Nano-30B-A3B-BF16](https://huggingface.co/nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16) demonstrating how ModelOpt techniques stack: Minitron structured pruning → Megatron-Bridge knowledge distillation to recover accuracy → evaluation benchmarking → FP8 quantization → vLLM deployment and throughput benchmarking. This document covers: + +1. **[Data Preparation](#1-data-preparation)** — tokenizing the training blend for distillation +2. **[Pruning](#2-pruning)** — Minitron structured pruning +3. 
**[Distillation](#3-distillation)** — recovering accuracy via Megatron-Bridge knowledge distillation +4. **[Evaluation](#4-evaluation)** — benchmarking with NeMo Evaluator across MMLU Pro, GPQA Diamond, AIME, and more +5. **[Quantization](#5-quantization)** — FP8 PTQ on the distilled checkpoint using ModelOpt's `examples/megatron_bridge/quantize.py` script +6. **[vLLM Inference Benchmarking](#6-vllm-inference-benchmarking)** — throughput comparison of BF16 vs FP8 on a single H100 + +## Results + +![Benchmark Recovery During Knowledge Distillation](figures/learning_curves.png) + +| Model | MMLU Pro | GPQA Diamond | LiveCodeBench v6 | AIME 2025 | IFBench | SciCode (Subtask) | Average | +| --- | --- | --- | --- | --- | --- | --- | --- | +| Pruned 22B/A3.0B (no distillation) | 47.1 | 33.5 | 27.4 | 15.5 | 36.9 | 12.1 | 28.8 | +| Distill @ 2.5B tokens (100 iters at 8K SeqLen) | 73.3 | 63.7 | 55.3 | 77.6 | 59.1 | 25.1 | 59.0 | +| Distill @ 20B tokens (800 iters at 8K SeqLen) | 74.8 | 66.0 | 62.3 | 79.6 | 65.4 | 26.1 | 62.4 | +| Distill @ 40B tokens (1600 iters at 8K SeqLen) | 76.4 | 67.2 | 62.3 | 79.8 | 66.0 | 26.6 | 63.1 | +| Distill @ 60B tokens (2400 iters at 8K SeqLen) | 76.1 | 68.1 | 63.6 | 78.8 | 67.3 | 27.0 | 63.5 | +| Distill @ 80B tokens (3200 iters at 8K SeqLen) | 76.5 | 69.1 | 63.9 | 80.7 | 66.5 | 29.0 | 64.3 | +| Distill @ 82.5B tokens (+100 iters at 32K SeqLen) | 76.2 | 69.8 | 64.8 | 87.0 | 68.2 | 27.0 | 65.5 | +| Distill @ 100B tokens (+800 iters at 32K SeqLen) - **BF16** | 76.6 | 69.6 | 66.1 | 87.3 | 68.9 | 28.4 | 66.2 | +| Distill @ 100B tokens + **FP8 Quantize** | 76.7 | 70.7 | 65.5 | 87.3 | 69.0 | 28.5 | 66.3 | +| Nemotron-3-Nano-30B-A3B-BF16 (official, 31.6B/A3.6B) | 78.0 | 70.3 | 67.9 | 87.1 | 69.1 | 31.8 | 67.4 | + +### vLLM Throughput (single H100, ISL=32768, OSL=1024) + +| Checkpoint | Model loading memory | Output tokens/s | Speedup vs Nemotron-3-Nano-30B-A3B-BF16 | +| --- | --- | --- | --- | +| Nemotron-3-Nano-30B-A3B-BF16 (official, 31.6B/A3.6B) 
| 58.9 GiB | 598 | 1.0× | +| Nemotron-3-Nano-30B-A3B-FP8 (official) | 31.4 GiB | 1,323 | 2.2× | +| Nemotron-3-Nano-Pruned-22B-A3.0B-BF16 | 41.5 GiB | 1,190 | 2.0× | +| Nemotron-3-Nano-Pruned-22B-A3.0B-FP8 | 22.8 GiB | 1,576 | 2.6× | + +Pruning alone (BF16 → Pruned-A3.0B BF16) gives a **2.0×** throughput speedup with a 30% memory reduction (58.9 → 41.5 GiB), and FP8 quantization alone (BF16 → FP8) gives a **2.2×** speedup with a 47% memory reduction. Stacking both — pruning + FP8 — compounds to a **2.6×** throughput speedup and a **2.6× memory reduction** (58.9 → 22.8 GiB) relative to the original 30B BF16 model, while preserving most of the benchmark accuracy. See [Section 6](#6-vllm-inference-benchmarking) for the benchmark command. + +Distillation uses the **30% Pretraining (Code 5, General 20, MATH 5) + 70% Post-training v1/v3 (Math 27, Coding 20, Science 13, IF 5, Tool calling 5)** blend (see [Data Blend](#data-blend) below) with an **80B @ 8K + 20B @ 32K = 100B token** schedule. Blend ablations and long-context phase ablations are in [ABLATIONS.md](ABLATIONS.md). + +> [!TIP] +> The benchmark numbers above suggest that the model is still learning at 100B tokens and that further training (or a higher-quality data blend) would continue to close the gap to the original 31.6B/A3.6B model. + +> [!NOTE] +> Exact numbers may vary depending on deployment and evaluation setup. All models above (including the official model) were evaluated with the same [evaluation setup](#4-evaluation) for fair comparison. These numbers may differ from those reported on the official [Nemotron-3-Nano-30B-A3B-BF16](https://huggingface.co/nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16) HuggingFace model card. + +--- + +## Steps to Reproduce + +**Environment:** Container `nvcr.io/nvidia/nemo:26.04`, ModelOpt 0.45.0. See the [Megatron-Bridge README](../../README.md) for environment setup (including ModelOpt mount path) and container usage. 
Pruning Nemotron models requires `transformers<5` via `python -m pip install "transformers<5"` else saving pruned model as HF checkpoint may fail. + +### 1. Data Preparation + +See [examples/dataset/MEGATRON_DATA_PREP.md](../../../dataset/MEGATRON_DATA_PREP.md) for tokenization commands for all datasets used in this blend. + +For this experiment: `TOKENIZER=nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16`, `OUTPUT_DIR=tokenized_nemotron_3`. + +#### Data Blend + +**30% Pretraining (Code 5, General 20, MATH 5) + 70% Post-training v1/v3 (Math 27, Coding 20, Science 13, IF 5, Tool calling 5)** + +| Dataset | Tokens | Weight | Notes | +| ---------------------------------------------------------- | ------ | ------ | ---------------------------------------------- | +| Nemotron-Pretraining-SFT-v1 / Code (10M samples) | 7B | 5 | Pretraining code | +| Nemotron-Pretraining-SFT-v1 / General (10M samples) | 16B | 20 | Upweighted to close MMLU gap | +| Nemotron-Pretraining-SFT-v1 / MATH (10M samples) | 13B | 5 | Pretraining math | +| Nemotron-Math-v2 / high_part00 | 13B | 10 | Hard math reasoning | +| Nemotron-SFT-Math-v3 / train | 52B | 17 | Hard math reasoning with full reasoning traces | +| Nemotron-SFT-Competitive-Programming-v2 / python_00 | 7B | 15 | Python reasoning traces | +| Nemotron-SFT-Competitive-Programming-v2 / cpp_00 | 7B | 5 | C++ reasoning traces | +| Nemotron-Post-Training-Dataset-v1 / stem (5M samples) | 22B | 8 | Broad STEM | +| Nemotron-Science-v1 / MCQ | 0.5B | 3 | GPQA MCQ format alignment | +| Nemotron-Science-v1 / RQA | 0.3B | 2 | GPQA format diversity | +| Nemotron-SFT-Instruction-Following-Chat-v2 / reasoning_on | 2B | 3 | Instruction following (thinking on) | +| Nemotron-SFT-Instruction-Following-Chat-v2 / reasoning_off | 1B | 2 | Instruction following (thinking off) | +| Nemotron-Agentic-v1 / tool_calling | 1B | 5 | Tool-use scaffolding; helps SciCode / GPQA | + +
+Data blend for distillation (click to expand) + +```bash +DATA_BLEND=" \ +5 tokenized_nemotron_3/nvidia--Nemotron-Pretraining-SFT-v1_Nemotron-SFT-Code_train_text_max10000000 \ +20 tokenized_nemotron_3/nvidia--Nemotron-Pretraining-SFT-v1_Nemotron-SFT-General_train_text_max10000000 \ +5 tokenized_nemotron_3/nvidia--Nemotron-Pretraining-SFT-v1_Nemotron-SFT-MATH_train_text_max10000000 \ +10 tokenized_nemotron_3/nvidia--Nemotron-Math-v2_default_high_part00_messages \ +17 tokenized_nemotron_3/nvidia--Nemotron-SFT-Math-v3_default_train_messages \ +15 tokenized_nemotron_3/competitive_programming_python_00_messages \ +5 tokenized_nemotron_3/competitive_programming_cpp_00_messages \ +8 tokenized_nemotron_3/nvidia--Nemotron-Post-Training-Dataset-v1_default_stem_messages_max5000000 \ +3 tokenized_nemotron_3/MCQ_messages \ +2 tokenized_nemotron_3/RQA_messages \ +3 tokenized_nemotron_3/reasoning_on_messages \ +2 tokenized_nemotron_3/reasoning_off_messages \ +5 tokenized_nemotron_3/nvidia--Nemotron-Agentic-v1_tool_calling_messages \ +" +``` + +
+ +#### General Guidelines + +The optimal blend is 30% pretraining and 70% post-training data. Exact proportions may vary depending on the benchmarks you care about. The blend above was designed to maximize recovery on popular General Knowledge, Reasoning, Instruction Following, and Tool Calling benchmarks. The key design decisions were: + +- **30% pretraining data** closes the MMLU gap that arises from training exclusively on reasoning-heavy post-training data. The General split (20%) is upweighted specifically to recover general knowledge recall. +- **Math (27%)** is the largest post-training category because AIME and MMLU Pro respond strongly to more math reasoning tokens. We use a mix of `Nemotron-Math-v2` and `Nemotron-SFT-Math-v3` for higher quality math reasoning signal with full reasoning traces. +- **Science (13%)** uses `Nemotron-Post-Training-Dataset-v1 / stem` as the primary source for volume and GPQA stability, with small allocations to `Nemotron-Science-v1` MCQ/RQA subsets for format alignment with GPQA's multiple-choice structure. +- **Instruction following (5%)** saturates quickly so a small allocation is sufficient. +- **Tool calling (5%)** uses `Nemotron-Agentic-v1 / tool_calling`. Our evals run with `--enable-auto-tool-choice`, so the student needs explicit exposure to function-call schemas; this helps SciCode (heavy Python tool use) and GPQA Diamond (which can benefit from calculator tools). + +This blend intentionally omits capabilities not targeted in this experiment (e.g. multilingual, SWE). 
Depending on what benchmarks matter for your use case, you can substitute or add datasets from the [Nemotron Post-Training v3 collection](https://huggingface.co/collections/nvidia/nemotron-post-training-v3), for example: + +| Capability | Relevant datasets | +| --- | --- | +| Multilingual | `Nemotron-SFT-Multilingual-v1` | +| Software engineering (SWE) | `Nemotron-SFT-SWE-v2` | +| Safety / alignment | `Nemotron-SFT-Safety-v1` | + +When adding new datasets, reduce weights of lower-priority categories proportionally to keep the total at 100%. + +--- + +### 2. Pruning + +Here we prune the [NVIDIA-Nemotron-3-Nano-30B-A3B-BF16](https://huggingface.co/nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16) HuggingFace checkpoint from 31.6B/A3.6B to 3.0B active parameters. The output is a pruned HuggingFace checkpoint that feeds into the distillation step. + +Run on **1 node with 8x H100** (~1 hour) + +
+Pruning command (click to expand) + +```bash +torchrun --nproc_per_node 8 /opt/Model-Optimizer/examples/megatron_bridge/prune_minitron.py \ + --pp_size 8 \ + --hf_model_name_or_path nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16 \ + --trust_remote_code \ + --prune_target_params 28e9 \ + --prune_target_active_params 3e9 \ + --hparams_to_skip num_attention_heads \ + --seq_length 8192 \ + --output_hf_path /path/to/Nemotron-3-Nano-30B-A3B-Pruned-A3.0B \ + --top_k 20 \ + --max_depth_pruning 0.15 \ + --max_width_pruning 0.30 \ + --prune_score_func mmlu_10pct_bs32 \ + --num_layers_in_first_pipeline_stage 5 \ + --num_layers_in_last_pipeline_stage 5 +``` + +Non-default arguments: + +- `--hparams_to_skip num_attention_heads` (default: none) — attention heads pruning is harder to recover, hence skipped +- `--seq_length 8192` (default: 4096) — dataset has longer sequences +- `--prune_target_active_params 3e9` — MoE-specific; the **primary** pruning constraint — targets active params rather than total params, which is what matters for MoE inference cost +- `--prune_target_params 28e9` — upper bound on total params only; the actual pruned model total can range anywhere from ~20B to 28B depending on which architecture wins — see pruning logs below for the top 20 candidates. You may also skip this argument altogether for simplicity. 
+- `--top_k 20` (default: 10) — larger candidate pool for better architecture search +- `--max_depth_pruning 0.15` (default: 0.20) — tighter constraint since candidates with 42–46 layers universally fail for this model +- `--max_width_pruning 0.30` (default: 0.40) — tighter constraint to prevent head_dim≤48 and hidden=2048 dead zones +- `--prune_score_func mmlu_10pct_bs32` (default: `mmlu_10pct_bs1`) — batch_size=32 for ~3–4× faster candidate scoring +- `--num_layers_in_first_pipeline_stage 5 --num_layers_in_last_pipeline_stage 5` — Uneven pipeline parallelism since 52 layers is not divisible by 8 GPUs + +**NOTE**: The tighter search space constraints here (`--max_depth_pruning`, `--max_width_pruning`) are specific to Nemotron hybrid models (Mamba + Attention + MoE). Unlike standard transformers which expose only layers/hidden/attention/FFN dimensions, these models add Mamba-specific dimensions (`mamba_num_heads`, `mamba_head_dim`) and MoE dimensions (`num_moe_experts`, `moe_ffn_hidden_size`, `moe_shared_expert_intermediate_size`), making the combined search space much larger. The default 40%/20% bounds cast too wide a net and waste compute on dead-zone architectures. + +See [ABLATIONS.md](ABLATIONS.md#pruning) for the full architecture search analysis across various candidates. +
+ +
+Pruning logs (top 20 candidates, best subnet, layer patterns) (click to expand) + +```text +╭──────────────────────────────────────────────────── Original Model Stats ─────────────────────────────────────────────────────╮ +│ Total Parameters 31.58B │ +│ Active Parameters 3.58B │ +│ Memory (BF16, seq_length=8192, batch_size=1) weights: 60230.1 MB, kv_cache: 48.0 MB, mamba_state: 23.8 MB, Total: 60301.9 MB │ +╰───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯ + + Search Space + (≤30% width / ≤15% depth pruning) +┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ +┃ Hyperparameter ┃ Choices ┃ +┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩ +│ num_layers │ [46, 48, 50, 52] │ +│ hidden_size │ [2048, 2304, 2560, 2688] │ +│ mamba_num_heads │ [48, 56, 64] │ +│ mamba_head_dim │ [48, 56, 64] │ +│ num_moe_experts │ [96, 104, 112, 120, 128] │ +│ moe_ffn_hidden_size │ [1536, 1792, 1856] │ +│ moe_shared_expert_intermediate_size │ [2816, 3072, 3328, 3584, 3712] │ +├─────────────────────────────────────┼────────────────────────────────┤ +│ Search space size │ 10800 │ +└─────────────────────────────────────┴────────────────────────────────┘ + +Top 20 Candidates with Scores +┏━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┳━━━━━━━━┳━━━━━━━━┓ +┃ # ┃ export_config ┃ active_params ┃ params ┃ score ┃ +┡━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━╇━━━━━━━━╇━━━━━━━━┩ +│ 1 │ {'num_layers': 46, 'hidden_size': 2560, 'mamba_num_heads': 56, 'mamba_head_dim': 64, 'num_moe_experts': 120, │ 3.00B │ 27.06B │ 0.3399 │ +│ │ 'moe_ffn_hidden_size': 1792, 'moe_shared_expert_intermediate_size': 3072} │ │ │ │ +│ 2 │ {'num_layers': 48, 'hidden_size': 2560, 'mamba_num_heads': 56, 
'mamba_head_dim': 56, 'num_moe_experts': 112, │ 3.00B │ 25.37B │ 0.4650 │ +│ │ 'moe_ffn_hidden_size': 1792, 'moe_shared_expert_intermediate_size': 3072} │ │ │ │ +│ 3 │ {'num_layers': 46, 'hidden_size': 2560, 'mamba_num_heads': 64, 'mamba_head_dim': 56, 'num_moe_experts': 112, │ 3.00B │ 25.37B │ 0.2343 │ +│ │ 'moe_ffn_hidden_size': 1792, 'moe_shared_expert_intermediate_size': 3072} │ │ │ │ +│ 4 │ {'num_layers': 52, 'hidden_size': 2688, 'mamba_num_heads': 56, 'mamba_head_dim': 48, 'num_moe_experts': 96, │ 3.00B │ 20.09B │ 0.2552 │ +│ │ 'moe_ffn_hidden_size': 1536, 'moe_shared_expert_intermediate_size': 3072} │ │ │ │ +│ 5 │ {'num_layers': 52, 'hidden_size': 2688, 'mamba_num_heads': 48, 'mamba_head_dim': 56, 'num_moe_experts': 104, │ 3.00B │ 21.61B │ 0.2601 │ +│ │ 'moe_ffn_hidden_size': 1536, 'moe_shared_expert_intermediate_size': 3072} │ │ │ │ +│ 6 │ {'num_layers': 52, 'hidden_size': 2560, 'mamba_num_heads': 48, 'mamba_head_dim': 64, 'num_moe_experts': 96, │ 3.00B │ 19.28B │ 0.3762 │ +│ │ 'moe_ffn_hidden_size': 1536, 'moe_shared_expert_intermediate_size': 3712} │ │ │ │ +│ 7 │ {'num_layers': 52, 'hidden_size': 2304, 'mamba_num_heads': 64, 'mamba_head_dim': 64, 'num_moe_experts': 104, │ 3.00B │ 22.28B │ 0.4783 │ +│ │ 'moe_ffn_hidden_size': 1856, 'moe_shared_expert_intermediate_size': 3072} │ │ │ │ +│ 8 │ {'num_layers': 52, 'hidden_size': 2560, 'mamba_num_heads': 48, 'mamba_head_dim': 48, 'num_moe_experts': 96, │ 3.00B │ 21.99B │ 0.2420 │ +│ │ 'moe_ffn_hidden_size': 1792, 'moe_shared_expert_intermediate_size': 3328} │ │ │ │ +│ 9 │ {'num_layers': 50, 'hidden_size': 2560, 'mamba_num_heads': 48, 'mamba_head_dim': 48, 'num_moe_experts': 112, │ 3.00B │ 25.37B │ 0.2399 │ +│ │ 'moe_ffn_hidden_size': 1792, 'moe_shared_expert_intermediate_size': 3712} │ │ │ │ +│ 10 │ {'num_layers': 50, 'hidden_size': 2560, 'mamba_num_heads': 48, 'mamba_head_dim': 48, 'num_moe_experts': 112, │ 3.00B │ 26.17B │ 0.2601 │ +│ │ 'moe_ffn_hidden_size': 1856, 'moe_shared_expert_intermediate_size': 3328} 
│ │ │ │ +│ 11 │ {'num_layers': 46, 'hidden_size': 2560, 'mamba_num_heads': 56, 'mamba_head_dim': 64, 'num_moe_experts': 112, │ 3.00B │ 25.37B │ 0.2503 │ +│ │ 'moe_ffn_hidden_size': 1792, 'moe_shared_expert_intermediate_size': 3072} │ │ │ │ +│ 12 │ {'num_layers': 48, 'hidden_size': 2560, 'mamba_num_heads': 56, 'mamba_head_dim': 56, 'num_moe_experts': 104, │ 3.00B │ 23.68B │ 0.4329 │ +│ │ 'moe_ffn_hidden_size': 1792, 'moe_shared_expert_intermediate_size': 3072} │ │ │ │ +│ 13 │ {'num_layers': 46, 'hidden_size': 2688, 'mamba_num_heads': 64, 'mamba_head_dim': 64, 'num_moe_experts': 128, │ 3.00B │ 26.17B │ 0.2587 │ +│ │ 'moe_ffn_hidden_size': 1536, 'moe_shared_expert_intermediate_size': 2816} │ │ │ │ +│ 14 │ {'num_layers': 46, 'hidden_size': 2560, 'mamba_num_heads': 64, 'mamba_head_dim': 56, 'num_moe_experts': 104, │ 3.00B │ 23.68B │ 0.2336 │ +│ │ 'moe_ffn_hidden_size': 1792, 'moe_shared_expert_intermediate_size': 3072} │ │ │ │ +│ 15 │ {'num_layers': 52, 'hidden_size': 2688, 'mamba_num_heads': 48, 'mamba_head_dim': 56, 'num_moe_experts': 96, │ 3.00B │ 20.09B │ 0.2559 │ +│ │ 'moe_ffn_hidden_size': 1536, 'moe_shared_expert_intermediate_size': 3072} │ │ │ │ +│ 16 │ {'num_layers': 52, 'hidden_size': 2304, 'mamba_num_heads': 64, 'mamba_head_dim': 64, 'num_moe_experts': 96, │ 3.00B │ 20.70B │ 0.4608 │ +│ │ 'moe_ffn_hidden_size': 1856, 'moe_shared_expert_intermediate_size': 3072} │ │ │ │ +│ 17 │ {'num_layers': 50, 'hidden_size': 2560, 'mamba_num_heads': 48, 'mamba_head_dim': 48, 'num_moe_experts': 104, │ 3.00B │ 23.68B │ 0.2455 │ +│ │ 'moe_ffn_hidden_size': 1792, 'moe_shared_expert_intermediate_size': 3712} │ │ │ │ +│ 18 │ {'num_layers': 50, 'hidden_size': 2560, 'mamba_num_heads': 48, 'mamba_head_dim': 48, 'num_moe_experts': 104, │ 3.00B │ 24.42B │ 0.2503 │ +│ │ 'moe_ffn_hidden_size': 1856, 'moe_shared_expert_intermediate_size': 3328} │ │ │ │ +│ 19 │ {'num_layers': 48, 'hidden_size': 2560, 'mamba_num_heads': 48, 'mamba_head_dim': 48, 'num_moe_experts': 120, │ 3.00B │ 27.92B │ 
0.2587 │ +│ │ 'moe_ffn_hidden_size': 1856, 'moe_shared_expert_intermediate_size': 3712} │ │ │ │ +│ 20 │ {'num_layers': 46, 'hidden_size': 2560, 'mamba_num_heads': 56, 'mamba_head_dim': 64, 'num_moe_experts': 104, │ 3.00B │ 23.68B │ 0.2469 │ +│ │ 'moe_ffn_hidden_size': 1792, 'moe_shared_expert_intermediate_size': 3072} │ │ │ │ +└────┴───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┴───────────────┴────────┴────────┘ + +╭──────────────────────────────────────────────────────────────────────── Best Subnet ─────────────────────────────────────────────────────────────────────────╮ +│ export_config {'num_layers': 52, 'hidden_size': 2304, 'mamba_num_heads': 64, 'mamba_head_dim': 64, 'num_moe_experts': 104, 'moe_ffn_hidden_size': 1856, │ +│ 'moe_shared_expert_intermediate_size': 3072} │ +│ active_params 3.00B │ +│ params 22.28B │ +│ score 0.4783 │ +╰──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯ + +Original hybrid_layer_pattern: MEMEM*EMEMEM*EMEMEM*EMEMEM*EMEMEM*EMEMEMEM*EMEMEMEME +Pruned hybrid_layer_pattern: MEMEM*EMEMEM*EMEMEM*EMEMEM*EMEMEM*EMEMEMEM*EMEMEMEME + +╭───────────────────────────────────────────────────── Pruned Model Stats ──────────────────────────────────────────────────────╮ +│ Total Parameters 22.28B │ +│ Active Parameters 3.00B │ +│ Memory (BF16, seq_length=8192, batch_size=1) weights: 42489.7 MB, kv_cache: 48.0 MB, mamba_state: 23.8 MB, Total: 42561.6 MB │ +╰───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯ +``` + +
+ +> [!TIP] +> Candidate selection above relies on the pruning score alone — it does not run a short KD trial per candidate to pick the winner. The main post-pruning distillation in [Section 3](#3-distillation) is still performed on the selected candidate. If you want a stronger pick, take a few top candidates' `export_config` from the logs above (where the score is similar to the best subnet), export them separately, run KD for ~2B tokens on each, and pick the best on your target metrics. See [ABLATIONS.md — 1st vs 2nd best candidate](ABLATIONS.md#distillation-results-1st-best-vs-2nd-best-pruning-candidate) for a concrete comparison. + +--- + +### 3. Distillation + +Distillation is run in two phases: an 80B-token phase at 8K sequence length, followed by a 20B-token long-context phase at 32K sequence length. The two phases are launched as separate runs with an intermediate Megatron→HF checkpoint conversion, because the long-context phase changes `seq_length`, `gbs`, and `cp_size` — Megatron's checkpoint resume bookkeeping (sample counter is in absolute samples, iteration counter is in iter-units tied to `gbs`) does not handle a mid-run `gbs` change cleanly. + +Minimum hardware: **4 nodes × 8x H100 (32 GPUs)** for the 8K phase — required by `TP=4 × EP=8`. The 32K phase additionally requires context parallel to fit the longer sequence, doubling the minimum to **8 nodes × 8x H100 (64 GPUs)**. On **96 nodes × 8x H100 (768 GPUs total)**, it takes ~900 H100 GPU-hours per 10B tokens (400 iters), i.e. ~70 min wall-clock per 10B tokens on 96 nodes. Full schedule (80B @ 8K + 20B @ 32K = 100B tokens, 4k total steps) takes ~9k H100 GPU-hours (~12 hours wall-clock). + +#### 3a. Phase 1 — 80B tokens @ 8K seq length + +
+Phase 1 distillation command (click to expand) + +> NOTE: We use `python -u` for slurm multi-node run here. + +```bash +python -u /opt/Model-Optimizer/examples/megatron_bridge/distill.py \ + --teacher_hf_path nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16 \ + --student_hf_path /path/to/Nemotron-3-Nano-30B-A3B-Pruned-A3.0B \ + --trust_remote_code \ + --tp_size 4 \ + --ep_size 8 \ + --data_paths "${DATA_BLEND}" \ + --data_path_to_cache /path/to/cache \ + --seq_length 8192 \ + --mbs 1 \ + --gbs 3072 \ + --train_iters 3200 \ + --lr 1e-4 \ + --min_lr 1e-5 \ + --lr_warmup_iters 25 \ + --eval_interval 200 \ + --eval_iters 8 \ + --log_interval 10 \ + --output_dir /path/to/distill_output_phase1_8k + +# Optional: Weights & Biases logging +# --wandb_project \ +# --wandb_entity \ +# --wandb_exp_name +``` + +Non-default arguments: + +- `--seq_length 8192` (default: 4096) +- `--gbs 3072` (default: 768) — matches the original Nemotron-3-Nano-30B training GBS from the paper, kept to preserve the training distribution +- `--train_iters 3200` — 80B tokens at GBS 3072 × seq_length 8192 +- `--lr 1e-4 --min_lr 1e-5 --lr_warmup_iters 25` — cosine fully decays over 3200 iters; the model is approaching saturation at 8K by this point (see [ABLATIONS.md — 8K trajectory](ABLATIONS.md#effect-of-data-blend-tool_calling)). +- `--eval_interval 200` (default: 100) — less frequent eval to save compute +- `--eval_iters 8` (default: 32) — since GBS is 4× larger than default + +All other arguments use defaults. +
+ +#### 3b. Convert Phase 1 final checkpoint to HuggingFace format + +Phase 2 starts as a separate run from a fresh HuggingFace student checkpoint, so the final Phase 1 Megatron checkpoint must be exported to HF first using the Megatron-Bridge conversion script (see [Megatron-Bridge README](../../README.md) for full details). You can also use this same script to convert any intermediate Phase 1 checkpoint to HF format for evaluation along the way. + +
+Checkpoint conversion command (click to expand) + +> NOTE: The command below only works for non-quantized checkpoints. For quantized checkpoints, we use the `export.py` script in Section 5 to directly export the quantized checkpoint to Unified HF format for deployment. + +```bash +python /opt/Megatron-Bridge/examples/conversion/convert_checkpoints.py export \ + --hf-model /path/to/Nemotron-3-Nano-30B-A3B-Pruned-A3.0B \ + --megatron-path /path/to/distill_output_phase1_8k/checkpoints/iter_0003200 \ + --hf-path /path/to/distill_output_phase1_8k/checkpoints/hf_iter_0003200 +``` + +
+ +#### 3c. Phase 2 — 20B tokens @ 32K seq length + +Phase 2 is a **fresh run** with the Phase 1 final checkpoint as the new student. It uses a different `--seed` so the data blend reshuffles (otherwise the model would see an overlapping prefix of the same samples it already saw at 8K). The LR is bumped back up modestly to capture the rapid long-context adaptation observed in [ABLATIONS.md — Effect of long context training](ABLATIONS.md#effect-of-long-context-training). + +
+Phase 2 distillation command (click to expand) + +```bash +python -u /opt/Model-Optimizer/examples/megatron_bridge/distill.py \ + --teacher_hf_path nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16 \ + --student_hf_path /path/to/distill_output_phase1_8k/checkpoints/hf_iter_0003200 \ + --trust_remote_code \ + --tp_size 4 \ + --cp_size 2 \ + --ep_size 8 \ + --seed 5678 \ + --data_paths "${DATA_BLEND}" \ + --data_path_to_cache /path/to/cache \ + --seq_length 32768 \ + --mbs 1 \ + --gbs 768 \ + --train_iters 800 \ + --lr 2e-5 \ + --min_lr 1e-5 \ + --lr_warmup_iters 10 \ + --recompute_granularity selective \ + --recompute_modules moe \ + --eval_interval 200 \ + --eval_iters 8 \ + --log_interval 10 \ + --output_dir /path/to/distill_output_phase2_32k +``` + +Changed arguments from Phase 1: + +- `--student_hf_path` — points at the HF export of the Phase 1 final checkpoint +- `--seq_length 32768` — long-context phase +- `--gbs 768` — `seq_length × gbs` product unchanged, so each iter still processes the same number of tokens +- `--cp_size 2` — context parallel is needed to fit the longer sequence; doubles the minimum-hardware footprint to 8 nodes +- `--train_iters 800` — 20B tokens at GBS 768 × seq_length 32768 +- `--lr 2e-5 --min_lr 1e-5 --lr_warmup_iters 10` — modest LR bump for the long-context adaptation (Phase 1 ended at fully-decayed LR 1e-5); the 10-iter warmup re-populates Adam moment estimates which restart from zero in a fresh run +- `--recompute_granularity selective --recompute_modules moe` — selective MoE recompute further reduces activation memory at 32K. You may skip this if you have more memory. +- `--seed 5678` — different from the Phase 1 seed (default 1234) so the data blend reshuffles +- `--output_dir /path/to/distill_output_phase2_32k` — must be a **fresh directory** different from Phase 1's, so distill.py's resume mechanism (which auto-loads from `/checkpoints` if it exists) does not pull in stale state + +
+ +For multi-node Slurm runs, see the [Megatron-Bridge README](../../README.md#slurm-usage) for details. + +> [!NOTE] +> This is pure SFT-style distillation — no RL or online reward signal is used. Adding an RL-based post-training step after distillation is a natural next step that could further improve some of these benchmarks. + +--- + +### 4. Evaluation + +The eval config in [nemo_evaluator.yaml](nemo_evaluator.yaml) is for Slurm-based evaluation — it submits a vLLM serving job (with tool calling enabled via `--enable-auto-tool-choice --tool-call-parser qwen3_coder`) and runs evals against it. For local model execution and evaluation, refer to the [NeMo Evaluator documentation](https://docs.nvidia.com/nemo/evaluator/latest/) or this [blog](https://huggingface.co/blog/nvidia/nemotron-3-nano-evaluation-recipe). + +
+Evaluation launch steps (click to expand) + +Before running, update the following fields in the `nemo_evaluator.yaml` file or overwrite them in the command line with `-o
+ +**Tasks and exact metric names reported in the results table:** + +| Benchmark | Library | num_repeats | Metric name | +| --- | --- | --- | --- | +| MMLU Pro | NeMo Evaluator | 1 | `mmlu-pro_pass_at_1_symbolic_correct` | +| GPQA Diamond | NeMo Evaluator | 8 | `gpqa_pass_at_1_avg-of-8_symbolic_correct` | +| LiveCodeBench v6 | NeMo Evaluator | 4 | `livecodebench_pass_at_1_avg-of-4_accuracy` | +| AIME 2025 | NeMo Evaluator | 32 | `aime25_pass_at_1_avg-of-32_symbolic_correct` | +| IFBench | NeMo Evaluator | 8 | `ifbench_pass_at_1_avg-of-8_average_score` | +| SciCode (Subtask) | NeMo Evaluator | 8 | `scicode_pass_at_1_avg-of-8_subtask_accuracy` | + +For more details on NeMo Evaluator, see the [GitHub repo](https://github.com/NVIDIA-NeMo/evaluator) and [documentation](https://docs.nvidia.com/nemo/evaluator/latest/). + +--- + +### 5. Quantization + +ModelOpt allows stacking multiple optimization techniques. Here we stack FP8 quantization on top of the pruned and distilled model to get an even more optimized model. See [examples/megatron_bridge/README.md](../../README.md) for the full Megatron-Bridge PTQ documentation. + +Similar to the official [Nemotron-3-Nano-30B-A3B-FP8](https://huggingface.co/nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-FP8) model, if you want to quantize the pruned 22B/A3.0B model to FP8, the Mamba, MoE, and MLP layers are quantized to FP8, while the attention layers and the Conv1d components within the Mamba layers are kept in BF16 to avoid accuracy degradation. + +This is done with the `MAMBA_MOE_FP8_CONSERVATIVE_CFG` config defined in [`modelopt/torch/quantization/config.py`](../../../../modelopt/torch/quantization/config.py), which you select by passing `--quant_cfg MAMBA_MOE_FP8_CONSERVATIVE_CFG` below. For a faster model at the cost of a larger accuracy drop, you can use `MAMBA_MOE_FP8_AGGRESSIVE_CFG` instead. 
+ +> [!NOTE] +> You can also quantize to NVFP4 using `--quant_cfg MAMBA_MOE_NVFP4_CONSERVATIVE_CFG` or `MAMBA_MOE_NVFP4_AGGRESSIVE_CFG` (faster, more accuracy drop), which may require further distillation (QAD) to recover accuracy and a Blackwell GPU for deployment. + +Quantization is a two-step flow: `quantize.py` calibrates and saves a Megatron checkpoint, then `export.py` converts it to a deployable HuggingFace checkpoint (the unified HF exporter loads at TP=1, so pipeline parallelism is used to shard across GPUs). Both steps take a few minutes on 8x H100. + +**Step 1 — calibrate and save the quantized Megatron checkpoint:** + +
+FP8 PTQ command (click to expand) + +```bash +torchrun --nproc_per_node 8 /opt/Model-Optimizer/examples/megatron_bridge/quantize.py \ + --hf_model_name_or_path /path/to/distill_output_phase2_32k/checkpoints/hf_iter_0000800 \ + --trust_remote_code \ + --tp_size 8 \ + --quant_cfg MAMBA_MOE_FP8_CONSERVATIVE_CFG \ + --calib_batch_size 32 \ + --seq_length 512 \ + --export_megatron_path /path/to/distill_output_phase2_32k/checkpoints/iter_0000800_fp8_megatron \ + --skip_generate +``` + +
+ +**Step 2 — export the Megatron checkpoint to a deployable HuggingFace checkpoint:** + +
+Export command (click to expand) + +```bash +torchrun --nproc_per_node 1 /opt/Model-Optimizer/examples/megatron_bridge/export.py \ + --hf_model_name_or_path /path/to/distill_output_phase2_32k/checkpoints/hf_iter_0000800 \ + --megatron_path /path/to/distill_output_phase2_32k/checkpoints/iter_0000800_fp8_megatron \ + --trust_remote_code \ + --pp_size 1 \ + --export_unified_hf_path /path/to/distill_output_phase2_32k/checkpoints/hf_iter_0000800_fp8 +``` + +
+ +The exported HuggingFace checkpoint is directly deployable with [vLLM](https://github.com/vllm-project/vllm), [TensorRT-LLM](https://github.com/NVIDIA/TensorRT-LLM) and [SGLang](https://github.com/sgl-project/sglang). + +> [!TIP] +> Run text generation on sample prompts to sanity-check that the quantized checkpoint generates reasonable output: +> +> ```bash +> python /opt/Model-Optimizer/examples/megatron_bridge/generate_vllm.py \ +> --model /path/to/distill_output_phase2_32k/checkpoints/hf_iter_0000800_fp8 \ +> --trust_remote_code +> ``` + +> [!TIP] +> You can run the evaluation for the quantized checkpoint as well, using the same `nemo_evaluator.yaml` file — just apply the FP8 deployment tweaks documented at the top of the yaml. + +See FP8 vs BF16 results in the [Results](#results) section above. + +--- + +### 6. vLLM Inference Benchmarking + +Benchmark throughput using [vLLM](https://github.com/vllm-project/vllm) on a single H100 GPU. + +
+vLLM benchmark commands (ISL=32768, OSL=1024) (click to expand) + +```bash +# BF16 (original or pruned) +vllm bench throughput \ + --model \ + --random-input-len 32768 \ + --random-output-len 1024 \ + --trust-remote-code \ + --mamba_ssm_cache_dtype float32 \ + --load-format safetensors + +# FP8 (Hopper GPU) +VLLM_USE_FLASHINFER_MOE_FP8=1 VLLM_FLASHINFER_MOE_BACKEND=throughput \ +vllm bench throughput \ + --model \ + --random-input-len 32768 \ + --random-output-len 1024 \ + --trust-remote-code \ + --mamba_ssm_cache_dtype float32 \ + --kv-cache-dtype fp8 \ + --load-format safetensors +``` + +
+ +See the [vLLM Throughput table in Results](#vllm-throughput-single-h100-isl32768-osl1024) for measured numbers. + +> [!TIP] +> To deploy the model with vLLM, you can refer to the [vLLM Quickstart documentation](https://docs.vllm.ai/en/stable/getting_started/quickstart/). diff --git a/examples/pruning/minitron/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16/figures/learning_curves.png b/examples/megatron_bridge/tutorials/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16/figures/learning_curves.png similarity index 100% rename from examples/pruning/minitron/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16/figures/learning_curves.png rename to examples/megatron_bridge/tutorials/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16/figures/learning_curves.png diff --git a/examples/pruning/minitron/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16/nemo_evaluator.yaml b/examples/megatron_bridge/tutorials/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16/nemo_evaluator.yaml similarity index 76% rename from examples/pruning/minitron/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16/nemo_evaluator.yaml rename to examples/megatron_bridge/tutorials/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16/nemo_evaluator.yaml index cacbc078807..b2f63f17e86 100644 --- a/examples/pruning/minitron/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16/nemo_evaluator.yaml +++ b/examples/megatron_bridge/tutorials/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16/nemo_evaluator.yaml @@ -3,11 +3,15 @@ # Before running, update the following fields in the yaml: # - `execution.hostname` — your Slurm login node hostname # - `execution.account` — your Slurm account -# - `deployment.checkpoint_path` — Hugging Face checkpoint path (original, pruned or quantized) -# - `evaluation.nemo_evaluator_config.config.params.extra.tokenizer` — same path as `checkpoint_path` +# - `deployment.checkpoint_path` — Hugging Face checkpoint path (original, pruned, or quantized) +# +# This config is set up for a BF16 checkpoint. 
For an FP8 checkpoint, also apply the deployment changes below: +# - FP8: add `--kv-cache-dtype fp8` to `deployment.extra_args`, and set in `deployment.env_vars`: +# VLLM_USE_FLASHINFER_MOE_FP8: "1" +# VLLM_FLASHINFER_MOE_BACKEND: throughput # # Usage: -# pip install "nemo-evaluator-launcher[all]==0.1.90" +# pip install "nemo-evaluator-launcher[all]==0.1.82" # # # Set required environment variables: # export HF_TOKEN= @@ -31,9 +35,9 @@ defaults: execution: type: slurm - hostname: + hostname: ??? username: ${oc.env:USER} - account: + account: ??? partition: batch num_nodes: 1 ntasks_per_node: 1 @@ -54,17 +58,22 @@ execution: # Note: Only tp=1 works for Nano (Mamba-based hybrid architecture) deployment: # Update this to your Hugging Face checkpoint path (original, pruned or quantized) - checkpoint_path: + checkpoint_path: ??? served_model_name: Nemotron-3-Nano-30B-A3B port: 8000 tensor_parallel_size: 1 pipeline_parallel_size: 1 data_parallel_size: 8 gpu_memory_utilization: 0.90 - extra_args: "--max-model-len 262144 --enable-log-requests --no-enable-prefix-caching --trust-remote-code --mamba_ssm_cache_dtype float32 --enable-auto-tool-choice\ - \ --tool-call-parser qwen3_coder --reasoning-parser-plugin /checkpoint/nano_v3_reasoning_parser.py --reasoning-parser nano_v3" - env_vars: - VLLM_FLASHINFER_MOE_BACKEND: throughput + # extra_args is for the BF16 checkpoint. For an FP8 checkpoint, append `--kv-cache-dtype fp8`. + extra_args: "--max-model-len 262144 --max-num-seqs 8 --enable-log-requests --no-enable-prefix-caching --trust-remote-code --mamba_ssm_cache_dtype float32\ + \ --enable-auto-tool-choice --tool-call-parser qwen3_coder --reasoning-parser-plugin /checkpoint/nano_v3_reasoning_parser.py --reasoning-parser nano_v3" + # env_vars is for the BF16 checkpoint (no MoE backend flags needed). 
For an FP8 + # checkpoint, replace `{}` with the block below (see notes at top of file): + # FP8: + # VLLM_USE_FLASHINFER_MOE_FP8: "1" + # VLLM_FLASHINFER_MOE_BACKEND: throughput + env_vars: {} endpoints: chat: /v1/chat/completions completions: /v1/completions @@ -99,8 +108,8 @@ evaluation: max_retries: 10 extra: tokenizer_backend: huggingface - # Update tokenizer path to match checkpoint_path above - tokenizer: + # Auto-derived from deployment.checkpoint_path above + tokenizer: ${deployment.checkpoint_path} env_vars: HF_TOKEN: HF_TOKEN HF_HOME: HF_HOME @@ -118,7 +127,6 @@ evaluation: nemo_evaluator_config: config: params: - # limit_samples: 8 extra: num_repeats: 1 args: "++prompt_config=eval/aai/mcq-10choices-boxed" @@ -130,7 +138,6 @@ evaluation: nemo_evaluator_config: config: params: - # limit_samples: 8 extra: num_repeats: 8 args: "++prompt_config=eval/aai/mcq-4choices" @@ -142,9 +149,8 @@ evaluation: nemo_evaluator_config: config: params: - # limit_samples: 8 extra: - num_repeats: 8 + num_repeats: 4 dataset_split: test_v6_2408_2505 # 4. AIME 2025 @@ -154,9 +160,8 @@ evaluation: nemo_evaluator_config: config: params: - # limit_samples: 8 extra: - num_repeats: 64 + num_repeats: 32 # 5. IFBench - name: ns_ifbench @@ -165,7 +170,6 @@ evaluation: nemo_evaluator_config: config: params: - # limit_samples: 8 extra: num_repeats: 8 @@ -176,6 +180,5 @@ evaluation: nemo_evaluator_config: config: params: - # limit_samples: 8 extra: num_repeats: 8 diff --git a/examples/megatron_bridge/tutorials/README.md b/examples/megatron_bridge/tutorials/README.md new file mode 100644 index 00000000000..9a27c6914d9 --- /dev/null +++ b/examples/megatron_bridge/tutorials/README.md @@ -0,0 +1,10 @@ +# Megatron-Bridge Tutorials + +End-to-end tutorials that combine ModelOpt optimization techniques on [NVIDIA Megatron-Bridge](https://github.com/NVIDIA-NeMo/Megatron-Bridge) models. 
+Each one walks through a complete workflow using the scripts in [examples/megatron_bridge](../README.md) (`prune_minitron.py`, `distill.py`, `quantize.py`, `export.py`). + +## Available tutorials + +| Tutorial | What it covers | +| --- | --- | +| [NVIDIA-Nemotron-3-Nano-30B-A3B-BF16](NVIDIA-Nemotron-3-Nano-30B-A3B-BF16/README.md) | End-to-end optimization of the Nemotron-3-Nano-30B-A3B-BF16 (MoE + Mamba-Transformer hybrid) model: Minitron structured pruning (31.6B/A3.6B → 22B/A3.0B) → two-phase knowledge distillation (100B tokens, 8K then 32K seq length) → quantization → vLLM deployment. Includes data-blend preparation, evaluation setup, and detailed pruning / data-blend / long-context ablations. | diff --git a/examples/pruning/README.md b/examples/pruning/README.md index 0bf8d904ecf..025c61a6712 100644 --- a/examples/pruning/README.md +++ b/examples/pruning/README.md @@ -294,7 +294,7 @@ After pruning, distillation is required to recover model accuracy. Below are rec End-to-end distillation results with Megatron-Bridge after Minitron and Puzzletron pruning: -- **[Minitron — Nemotron-3-Nano-30B-A3B-BF16](minitron/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16/README.md)** ⭐ *recommended — newer and most comprehensive*: End-to-end tutorial of structured pruning for Nemotron-3-Nano-30B-A3B-BF16 (31.6B/A3.6B) to 22B/A3.0B active parameters followed by two-phase knowledge distillation (80B tokens @ 8K seq length + 20B tokens @ 32K seq length = 100B tokens total), quantization, and vLLM deployment. Covers MoE + Mamba-Transformer hybrid, tool-calling data, and a long-context fine-tuning phase. Achieves near-parity with the official 30B model across popular pretraining and reasoning benchmarks while delivering up to 1.64× throughput speedup and 2.6× memory reduction when combined with FP8 quantization. 
+- **[Minitron — Nemotron-3-Nano-30B-A3B-BF16](../megatron_bridge/tutorials/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16/README.md)** ⭐ *recommended — newer and most comprehensive*: End-to-end tutorial of structured pruning for Nemotron-3-Nano-30B-A3B-BF16 (31.6B/A3.6B) to 22B/A3.0B active parameters followed by two-phase knowledge distillation (80B tokens @ 8K seq length + 20B tokens @ 32K seq length = 100B tokens total), quantization, and vLLM deployment. Covers MoE + Mamba-Transformer hybrid, tool-calling data, and a long-context fine-tuning phase. Achieves near-parity with the official 30B model across popular pretraining and reasoning benchmarks while delivering up to 2.6× throughput speedup and 2.6× memory reduction when combined with FP8 quantization. - **[Minitron — Nemotron-Nano-9B-v2](minitron/NVIDIA-Nemotron-Nano-9B-v2/README.md)**: Earlier end-to-end tutorial covering structured pruning of the dense Mamba-Transformer Nemotron-Nano-9B-v2 to 7B followed by knowledge distillation up to 80B tokens, quantization, and vLLM deployment. Simpler architecture, single-phase 8K seq length distillation, no tool-calling or long-context phase. - **[Puzzletron — Qwen3-8B and Llama-3.1-8B-Instruct](puzzletron/Llama-3.1-8B-Instruct.md)**: MIP-based compression followed by short distillation runs on WikiText-103. Shows MMLU recovery and illustrates the importance of using larger datasets to avoid overfitting. 
diff --git a/examples/pruning/minitron/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16/README.md b/examples/pruning/minitron/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16/README.md index 59737f26206..338c7973f0d 100644 --- a/examples/pruning/minitron/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16/README.md +++ b/examples/pruning/minitron/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16/README.md @@ -1,490 +1,4 @@ -# Nemotron-3-Nano-30B-A3B: Prune + Distill + Quantize + vLLM Deployment +# This tutorial has moved -End-to-end optimization of [NVIDIA-Nemotron-3-Nano-30B-A3B-BF16](https://huggingface.co/nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16) demonstrating how ModelOpt techniques stack: Minitron structured pruning → Megatron-Bridge knowledge distillation to recover accuracy → evaluation benchmarking → FP8 quantization → vLLM deployment and throughput benchmarking. This document covers: - -1. **[Data Preparation](#1-data-preparation)** — tokenizing the training blend for distillation -2. **[Pruning](#2-pruning)** — Minitron structured pruning -3. **[Distillation](#3-distillation)** — recovering accuracy via Megatron-Bridge knowledge distillation -4. **[Evaluation](#4-evaluation)** — benchmarking with NeMo Evaluator across MMLU Pro, GPQA Diamond, AIME, and more -5. **[Quantization](#5-quantization)** — FP8 PTQ on the distilled checkpoint using ModelOpt's `examples/llm_ptq/hf_ptq.py` script -6. 
**[vLLM Inference Benchmarking](#6-vllm-inference-benchmarking)** — throughput comparison of BF16 vs FP8 on a single H100 - -## Results - -![Benchmark Recovery During Knowledge Distillation](figures/learning_curves.png) - -| Model | MMLU Pro | GPQA Diamond | LiveCodeBench v6 | AIME 2025 | IFBench | SciCode (Subtask) | Average | -| --- | --- | --- | --- | --- | --- | --- | --- | -| Pruned 22B/A3.0B (no distillation) | 47.1 | 33.5 | 27.4 | 15.5 | 36.9 | 12.1 | 28.8 | -| Distill @ 2.5B tokens (100 iters at 8K Seq Length) | 73.3 | 63.7 | 55.3 | 77.6 | 59.1 | 25.1 | 59.0 | -| Distill @ 20B tokens (800 iters at 8K Seq Length) | 74.8 | 66.0 | 62.3 | 79.6 | 65.4 | 26.1 | 62.4 | -| Distill @ 40B tokens (1600 iters at 8K Seq Length) | 76.4 | 67.2 | 62.3 | 79.8 | 66.0 | 26.6 | 63.1 | -| Distill @ 60B tokens (2400 iters at 8K Seq Length) | 76.1 | 68.1 | 63.6 | 78.8 | 67.3 | 27.0 | 63.5 | -| Distill @ 80B tokens (3200 iters at 8K Seq Length) | 76.5 | 69.1 | 63.9 | 80.7 | 66.5 | 29.0 | 64.3 | -| Distill @ 82.5B tokens (+100 iters at 32K Seq Length) | 76.2 | 69.8 | 64.8 | 87.0 | 68.2 | 27.0 | 65.5 | -| Distill @ 100B tokens (+800 iters at 32K Seq Length) - **BF16** | 76.6 | 69.6 | 66.1 | 87.3 | 68.9 | 28.4 | 66.2 | -| Distill @ 100B tokens + **FP8 Quantize** | 76.3 | 69.8 | 65.5 | 86.0 | 69.7 | 27.9 | 65.9 | -| NVIDIA-Nemotron-3-Nano-30B-A3B-BF16 (official, 31.6B/A3.6B) | 78.0 | 70.3 | 67.9 | 87.1 | 69.1 | 31.8 | 67.4 | - -### vLLM Throughput (single H100, ISL=32768, OSL=1024) - -| Checkpoint | Model loading memory | Output tokens/s | Speedup vs Nemotron-3-Nano-30B-A3B BF16 | -| --- | --- | --- | --- | -| Nemotron-3-Nano-30B-A3B-BF16 (official, 31.6B/A3.6B) | 58.9 GiB | 1,006 | 1.00× | -| Nemotron-3-Nano-30B-A3B-FP8 (official) | 31.4 GiB | 1,404 | 1.40× | -| Nemotron-3-Nano-30B-A3B-Pruned-A3.0B (22B/A3.0B) | 41.5 GiB | 1,301 | 1.29× | -| Nemotron-3-Nano-30B-A3B-Pruned-A3.0B-FP8 | 22.8 GiB | 1,653 | 1.64× | - -Pruning alone (BF16 → Pruned-A3.0B BF16) gives a **1.29×** throughput 
speedup with a 30% memory reduction (58.9 → 41.5 GiB), and FP8 quantization alone (BF16 → FP8) gives a **1.40×** speedup with a 47% memory reduction. Stacking both — pruning + FP8 — compounds to a **1.64×** throughput speedup and a **2.6× memory reduction** (58.9 → 22.8 GiB) relative to the original 30B BF16 model, while preserving most of the benchmark accuracy. The NemotronH hybrid architecture (Mamba + Attention + MoE) moderates the FP8 gain relative to pure-transformer models, since Attention and Conv1d layers are not quantized. See [Section 6](#6-vllm-inference-benchmarking) for the benchmark command. - -Distillation uses the **30% Pretraining (Code 5, General 20, MATH 5) + 70% Post-training v1/v3 (Math 27, Coding 20, Science 13, IF 5, Tool calling 5)** blend (see [Data Blend](#data-blend) below) with an **80B @ 8K + 20B @ 32K = 100B token** schedule. Blend ablations and long-context phase ablations are in [ABLATIONS.md](ABLATIONS.md). - -> [!TIP] -> From the benchmark numbers above, the model is still learning at 100B tokens and that further training (or a higher-quality data blend) would continue to close the gap to the original 31.6B/A3.6B model. - -> [!NOTE] -> Exact numbers may vary depending on deployment and evaluation setup. All models above (including the official model) were evaluated once with the same [evaluation setup](#4-evaluation) for fair comparison. These numbers may differ from those reported on the official [Nemotron-3-Nano-30B-A3B-BF16](https://huggingface.co/nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16) HuggingFace model card. - ---- - -## Steps to Reproduce - -**Environment:** Container `nvcr.io/nvidia/nemo:26.04`, ModelOpt 0.45.0. See the [Megatron-Bridge README](../../../megatron_bridge/README.md) for environment setup (including ModelOpt mount path) and container usage. - -### 1. 
Data Preparation - -See [examples/dataset/MEGATRON_DATA_PREP.md](../../../dataset/MEGATRON_DATA_PREP.md) for tokenization commands for all datasets used in this blend. - -For this experiment: `TOKENIZER=nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16`, `OUTPUT_DIR=tokenized_nemotron_3`. - -#### Data Blend - -**30% Pretraining (Code 5, General 20, MATH 5) + 70% Post-training v1/v3 (Math 27, Coding 20, Science 13, IF 5, Tool calling 5)** - -| Dataset | Tokens | Weight | Notes | -| ---------------------------------------------------------- | ------ | ------ | ---------------------------------------------- | -| Nemotron-Pretraining-SFT-v1 / Code (10M samples) | 7B | 5 | Pretraining code | -| Nemotron-Pretraining-SFT-v1 / General (10M samples) | 16B | 20 | Upweighted to close MMLU gap | -| Nemotron-Pretraining-SFT-v1 / MATH (10M samples) | 13B | 5 | Pretraining math | -| Nemotron-Math-v2 / high_part00 | 13B | 10 | Hard math reasoning | -| Nemotron-SFT-Math-v3 / train | 52B | 17 | Hard math reasoning with full reasoning traces | -| Nemotron-SFT-Competitive-Programming-v2 / python_00 | 7B | 15 | Python reasoning traces | -| Nemotron-SFT-Competitive-Programming-v2 / cpp_00 | 7B | 5 | C++ reasoning traces | -| Nemotron-Post-Training-Dataset-v1 / stem (5M samples) | 22B | 8 | Broad STEM | -| Nemotron-Science-v1 / MCQ | 0.5B | 3 | GPQA MCQ format alignment | -| Nemotron-Science-v1 / RQA | 0.3B | 2 | GPQA format diversity | -| Nemotron-SFT-Instruction-Following-Chat-v2 / reasoning_on | 2B | 3 | Instruction following (thinking on) | -| Nemotron-SFT-Instruction-Following-Chat-v2 / reasoning_off | 1B | 2 | Instruction following (thinking off) | -| Nemotron-Agentic-v1 / tool_calling | 1B | 5 | Tool-use scaffolding; helps SciCode / GPQA | - -
-Data blend for distillation (click to expand) - -```bash -DATA_BLEND=" \ -5 tokenized_nemotron_3/nvidia--Nemotron-Pretraining-SFT-v1_Nemotron-SFT-Code_train_text_max10000000 \ -20 tokenized_nemotron_3/nvidia--Nemotron-Pretraining-SFT-v1_Nemotron-SFT-General_train_text_max10000000 \ -5 tokenized_nemotron_3/nvidia--Nemotron-Pretraining-SFT-v1_Nemotron-SFT-MATH_train_text_max10000000 \ -10 tokenized_nemotron_3/nvidia--Nemotron-Math-v2_default_high_part00_messages \ -17 tokenized_nemotron_3/nvidia--Nemotron-SFT-Math-v3_default_train_messages \ -15 tokenized_nemotron_3/competitive_programming_python_00_messages \ -5 tokenized_nemotron_3/competitive_programming_cpp_00_messages \ -8 tokenized_nemotron_3/nvidia--Nemotron-Post-Training-Dataset-v1_default_stem_messages_max5000000 \ -3 tokenized_nemotron_3/MCQ_messages \ -2 tokenized_nemotron_3/RQA_messages \ -3 tokenized_nemotron_3/reasoning_on_messages \ -2 tokenized_nemotron_3/reasoning_off_messages \ -5 tokenized_nemotron_3/nvidia--Nemotron-Agentic-v1_tool_calling_messages \ -" -``` - -
- -#### General Guidelines - -The optimal blend is 30% pretraining and 70% post-training data. Exact proportions may vary depending on the benchmarks you care about. The blend above was designed to maximize recovery on popular General Knowledge, Reasoning, Instruction Following, and Tool Calling benchmarks. The key design decisions were: - -- **30% pretraining data** closes the MMLU gap that arises from training exclusively on reasoning-heavy post-training data. The General split (20%) is upweighted specifically to recover general knowledge recall. -- **Math (27%)** is the largest post-training category because AIME and MMLU Pro respond strongly to more math reasoning tokens. We use a mix of `Nemotron-Math-v2` and `Nemotron-SFT-Math-v3` for higher quality math reasoning signal with full reasoning traces. -- **Science (13%)** uses `Nemotron-Post-Training-Dataset-v1 / stem` as the primary source for volume and GPQA stability, with small allocations to `Nemotron-Science-v1` MCQ/RQA subsets for format alignment with GPQA's multiple-choice structure. -- **Instruction following (5%)** saturates quickly so a small allocation is sufficient. -- **Tool calling (5%)** uses `Nemotron-Agentic-v1 / tool_calling`. Our evals run with `--enable-auto-tool-choice`, so the student needs explicit exposure to function-call schemas; this helps SciCode (heavy Python tool use) and GPQA Diamond (which can benefit from calculator tools). - -This blend intentionally omits capabilities not targeted in this experiment (e.g. multilingual, SWE). 
Depending on what benchmarks matter for your use case, you can substitute or add datasets from the [Nemotron Post-Training v3 collection](https://huggingface.co/collections/nvidia/nemotron-post-training-v3), for example: - -| Capability | Relevant datasets | -| --- | --- | -| Multilingual | `Nemotron-SFT-Multilingual-v1` | -| Software engineering (SWE) | `Nemotron-SFT-SWE-v2` | -| Safety / alignment | `Nemotron-SFT-Safety-v1` | - -When adding new datasets, reduce weights of lower-priority categories proportionally to keep the total at 100%. - ---- - -### 2. Pruning - -Here we prune the [NVIDIA-Nemotron-3-Nano-30B-A3B-BF16](https://huggingface.co/nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16) HuggingFace checkpoint from 31.6B/A3.6B to 3.0B active parameters. The output is a pruned HuggingFace checkpoint that feeds into the distillation step. - -Run on **1 node with 8x H100** (~1 hour) - -
-Pruning command (click to expand) - -```bash -torchrun --nproc_per_node 8 /opt/Model-Optimizer/examples/megatron_bridge/prune_minitron.py \ - --pp_size 8 \ - --hf_model_name_or_path nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16 \ - --trust_remote_code \ - --prune_target_params 28e9 \ - --prune_target_active_params 3e9 \ - --hparams_to_skip num_attention_heads \ - --seq_length 8192 \ - --output_hf_path /path/to/Nemotron-3-Nano-30B-A3B-Pruned-A3.0B \ - --top_k 20 \ - --max_depth_pruning 0.15 \ - --max_width_pruning 0.30 \ - --prune_score_func mmlu_10pct_bs32 \ - --num_layers_in_first_pipeline_stage 5 \ - --num_layers_in_last_pipeline_stage 5 -``` - -Non-default arguments: - -- `--hparams_to_skip num_attention_heads` (default: none) — attention heads pruning is harder to recover, hence skipped -- `--seq_length 8192` (default: 4096) — dataset has longer sequences -- `--prune_target_active_params 3e9` — MoE-specific; the **primary** pruning constraint — targets active params rather than total params, which is what matters for MoE inference cost -- `--prune_target_params 28e9` — upper bound on total params only; the actual pruned model total can range anywhere from ~20B to 28B depending on which architecture wins — see pruning logs below for the top 20 candidates. You may also skip this argument all together for simplicity. 
-- `--top_k 20` (default: 10) — larger candidate pool for better architecture search -- `--max_depth_pruning 0.15` (default: 0.20) — tighter constraint since candidates with 42–46 layers universally fail for this model -- `--max_width_pruning 0.30` (default: 0.40) — tighter constraint to prevent head_dim≤48 and hidden=2048 dead zones -- `--prune_score_func mmlu_10pct_bs32` (default: `mmlu_10pct_bs1`) — batch_size=32 for ~3–4× faster candidate scoring -- `--num_layers_in_first_pipeline_stage 5 --num_layers_in_last_pipeline_stage 5` — Uneven pipeline parallelism since 52 layers is not divisible by 8 GPUs - -**NOTE**: The tighter search space constraints here (`--max_depth_pruning`, `--max_width_pruning`) are specific to Nemotron hybrid models (Mamba + Attention + MoE). Unlike standard transformers which expose only layers/hidden/attention/FFN dimensions, these models add Mamba-specific dimensions (`mamba_num_heads`, `mamba_head_dim`) and MoE dimensions (`num_moe_experts`, `moe_ffn_hidden_size`, `moe_shared_expert_intermediate_size`), making the combined search space much larger. The default 40%/20% bounds cast too wide a net and waste compute on dead-zone architectures. - -See [ABLATIONS.md](ABLATIONS.md#pruning) for the full architecture search analysis across various candidates. -
- -
-Pruning logs (top 20 candidates, best subnet, layer patterns) (click to expand) - -```text -╭──────────────────────────────────────────────────── Original Model Stats ─────────────────────────────────────────────────────╮ -│ Total Parameters 31.58B │ -│ Active Parameters 3.58B │ -│ Memory (BF16, seq_length=8192, batch_size=1) weights: 60230.1 MB, kv_cache: 48.0 MB, mamba_state: 23.8 MB, Total: 60301.9 MB │ -╰───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯ - - Search Space - (≤30% width / ≤15% depth pruning) -┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ -┃ Hyperparameter ┃ Choices ┃ -┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩ -│ num_layers │ [46, 48, 50, 52] │ -│ hidden_size │ [2048, 2304, 2560, 2688] │ -│ mamba_num_heads │ [48, 56, 64] │ -│ mamba_head_dim │ [48, 56, 64] │ -│ num_moe_experts │ [96, 104, 112, 120, 128] │ -│ moe_ffn_hidden_size │ [1536, 1792, 1856] │ -│ moe_shared_expert_intermediate_size │ [2816, 3072, 3328, 3584, 3712] │ -├─────────────────────────────────────┼────────────────────────────────┤ -│ Search space size │ 10800 │ -└─────────────────────────────────────┴────────────────────────────────┘ - -Top 20 Candidates with Scores -┏━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┳━━━━━━━━┳━━━━━━━━┓ -┃ # ┃ export_config ┃ active_params ┃ params ┃ score ┃ -┡━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━╇━━━━━━━━╇━━━━━━━━┩ -│ 1 │ {'num_layers': 46, 'hidden_size': 2560, 'mamba_num_heads': 56, 'mamba_head_dim': 64, 'num_moe_experts': 120, │ 3.00B │ 27.06B │ 0.3399 │ -│ │ 'moe_ffn_hidden_size': 1792, 'moe_shared_expert_intermediate_size': 3072} │ │ │ │ -│ 2 │ {'num_layers': 48, 'hidden_size': 2560, 'mamba_num_heads': 56, 
'mamba_head_dim': 56, 'num_moe_experts': 112, │ 3.00B │ 25.37B │ 0.4650 │ -│ │ 'moe_ffn_hidden_size': 1792, 'moe_shared_expert_intermediate_size': 3072} │ │ │ │ -│ 3 │ {'num_layers': 46, 'hidden_size': 2560, 'mamba_num_heads': 64, 'mamba_head_dim': 56, 'num_moe_experts': 112, │ 3.00B │ 25.37B │ 0.2343 │ -│ │ 'moe_ffn_hidden_size': 1792, 'moe_shared_expert_intermediate_size': 3072} │ │ │ │ -│ 4 │ {'num_layers': 52, 'hidden_size': 2688, 'mamba_num_heads': 56, 'mamba_head_dim': 48, 'num_moe_experts': 96, │ 3.00B │ 20.09B │ 0.2552 │ -│ │ 'moe_ffn_hidden_size': 1536, 'moe_shared_expert_intermediate_size': 3072} │ │ │ │ -│ 5 │ {'num_layers': 52, 'hidden_size': 2688, 'mamba_num_heads': 48, 'mamba_head_dim': 56, 'num_moe_experts': 104, │ 3.00B │ 21.61B │ 0.2601 │ -│ │ 'moe_ffn_hidden_size': 1536, 'moe_shared_expert_intermediate_size': 3072} │ │ │ │ -│ 6 │ {'num_layers': 52, 'hidden_size': 2560, 'mamba_num_heads': 48, 'mamba_head_dim': 64, 'num_moe_experts': 96, │ 3.00B │ 19.28B │ 0.3762 │ -│ │ 'moe_ffn_hidden_size': 1536, 'moe_shared_expert_intermediate_size': 3712} │ │ │ │ -│ 7 │ {'num_layers': 52, 'hidden_size': 2304, 'mamba_num_heads': 64, 'mamba_head_dim': 64, 'num_moe_experts': 104, │ 3.00B │ 22.28B │ 0.4783 │ -│ │ 'moe_ffn_hidden_size': 1856, 'moe_shared_expert_intermediate_size': 3072} │ │ │ │ -│ 8 │ {'num_layers': 52, 'hidden_size': 2560, 'mamba_num_heads': 48, 'mamba_head_dim': 48, 'num_moe_experts': 96, │ 3.00B │ 21.99B │ 0.2420 │ -│ │ 'moe_ffn_hidden_size': 1792, 'moe_shared_expert_intermediate_size': 3328} │ │ │ │ -│ 9 │ {'num_layers': 50, 'hidden_size': 2560, 'mamba_num_heads': 48, 'mamba_head_dim': 48, 'num_moe_experts': 112, │ 3.00B │ 25.37B │ 0.2399 │ -│ │ 'moe_ffn_hidden_size': 1792, 'moe_shared_expert_intermediate_size': 3712} │ │ │ │ -│ 10 │ {'num_layers': 50, 'hidden_size': 2560, 'mamba_num_heads': 48, 'mamba_head_dim': 48, 'num_moe_experts': 112, │ 3.00B │ 26.17B │ 0.2601 │ -│ │ 'moe_ffn_hidden_size': 1856, 'moe_shared_expert_intermediate_size': 3328} 
│ │ │ │ -│ 11 │ {'num_layers': 46, 'hidden_size': 2560, 'mamba_num_heads': 56, 'mamba_head_dim': 64, 'num_moe_experts': 112, │ 3.00B │ 25.37B │ 0.2503 │ -│ │ 'moe_ffn_hidden_size': 1792, 'moe_shared_expert_intermediate_size': 3072} │ │ │ │ -│ 12 │ {'num_layers': 48, 'hidden_size': 2560, 'mamba_num_heads': 56, 'mamba_head_dim': 56, 'num_moe_experts': 104, │ 3.00B │ 23.68B │ 0.4329 │ -│ │ 'moe_ffn_hidden_size': 1792, 'moe_shared_expert_intermediate_size': 3072} │ │ │ │ -│ 13 │ {'num_layers': 46, 'hidden_size': 2688, 'mamba_num_heads': 64, 'mamba_head_dim': 64, 'num_moe_experts': 128, │ 3.00B │ 26.17B │ 0.2587 │ -│ │ 'moe_ffn_hidden_size': 1536, 'moe_shared_expert_intermediate_size': 2816} │ │ │ │ -│ 14 │ {'num_layers': 46, 'hidden_size': 2560, 'mamba_num_heads': 64, 'mamba_head_dim': 56, 'num_moe_experts': 104, │ 3.00B │ 23.68B │ 0.2336 │ -│ │ 'moe_ffn_hidden_size': 1792, 'moe_shared_expert_intermediate_size': 3072} │ │ │ │ -│ 15 │ {'num_layers': 52, 'hidden_size': 2688, 'mamba_num_heads': 48, 'mamba_head_dim': 56, 'num_moe_experts': 96, │ 3.00B │ 20.09B │ 0.2559 │ -│ │ 'moe_ffn_hidden_size': 1536, 'moe_shared_expert_intermediate_size': 3072} │ │ │ │ -│ 16 │ {'num_layers': 52, 'hidden_size': 2304, 'mamba_num_heads': 64, 'mamba_head_dim': 64, 'num_moe_experts': 96, │ 3.00B │ 20.70B │ 0.4608 │ -│ │ 'moe_ffn_hidden_size': 1856, 'moe_shared_expert_intermediate_size': 3072} │ │ │ │ -│ 17 │ {'num_layers': 50, 'hidden_size': 2560, 'mamba_num_heads': 48, 'mamba_head_dim': 48, 'num_moe_experts': 104, │ 3.00B │ 23.68B │ 0.2455 │ -│ │ 'moe_ffn_hidden_size': 1792, 'moe_shared_expert_intermediate_size': 3712} │ │ │ │ -│ 18 │ {'num_layers': 50, 'hidden_size': 2560, 'mamba_num_heads': 48, 'mamba_head_dim': 48, 'num_moe_experts': 104, │ 3.00B │ 24.42B │ 0.2503 │ -│ │ 'moe_ffn_hidden_size': 1856, 'moe_shared_expert_intermediate_size': 3328} │ │ │ │ -│ 19 │ {'num_layers': 48, 'hidden_size': 2560, 'mamba_num_heads': 48, 'mamba_head_dim': 48, 'num_moe_experts': 120, │ 3.00B │ 27.92B │ 
0.2587 │ -│ │ 'moe_ffn_hidden_size': 1856, 'moe_shared_expert_intermediate_size': 3712} │ │ │ │ -│ 20 │ {'num_layers': 46, 'hidden_size': 2560, 'mamba_num_heads': 56, 'mamba_head_dim': 64, 'num_moe_experts': 104, │ 3.00B │ 23.68B │ 0.2469 │ -│ │ 'moe_ffn_hidden_size': 1792, 'moe_shared_expert_intermediate_size': 3072} │ │ │ │ -└────┴───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┴───────────────┴────────┴────────┘ - -╭──────────────────────────────────────────────────────────────────────── Best Subnet ─────────────────────────────────────────────────────────────────────────╮ -│ export_config {'num_layers': 52, 'hidden_size': 2304, 'mamba_num_heads': 64, 'mamba_head_dim': 64, 'num_moe_experts': 104, 'moe_ffn_hidden_size': 1856, │ -│ 'moe_shared_expert_intermediate_size': 3072} │ -│ active_params 3.00B │ -│ params 22.28B │ -│ score 0.4783 │ -╰──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯ - -Original hybrid_layer_pattern: MEMEM*EMEMEM*EMEMEM*EMEMEM*EMEMEM*EMEMEMEM*EMEMEMEME -Pruned hybrid_layer_pattern: MEMEM*EMEMEM*EMEMEM*EMEMEM*EMEMEM*EMEMEMEM*EMEMEMEME - -╭───────────────────────────────────────────────────── Pruned Model Stats ──────────────────────────────────────────────────────╮ -│ Total Parameters 22.28B │ -│ Active Parameters 3.00B │ -│ Memory (BF16, seq_length=8192, batch_size=1) weights: 42489.7 MB, kv_cache: 48.0 MB, mamba_state: 23.8 MB, Total: 42561.6 MB │ -╰───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯ -``` - -
- -> [!TIP] -> Candidate selection above relies on the pruning score alone — it does not run a short KD trial per candidate to pick the winner. The main post-pruning distillation in [Section 3](#3-distillation) is still performed on the selected candidate. If you want a stronger pick, take a few top candidates' `export_config` from the logs above (where the score is similar to the best subnet), export them separately, run KD for ~2B tokens on each, and pick the best on your target metrics. See [ABLATIONS.md — 1st vs 2nd best candidate](ABLATIONS.md#distillation-results-1st-best-vs-2nd-best-pruning-candidate) for a concrete comparison. - ---- - -### 3. Distillation - -Distillation is run in two phases: an 80B-token phase at 8K sequence length, followed by a 20B-token long-context phase at 32K sequence length. The two phases are launched as separate runs with an intermediate Megatron→HF checkpoint conversion, because the long-context phase changes `seq_length`, `gbs`, and `cp_size` — Megatron's checkpoint resume bookkeeping (sample counter is in absolute samples, iteration counter is in iter-units tied to `gbs`) does not handle a mid-run `gbs` change cleanly. - -Minimum hardware: **4 nodes × 8x H100 (32 GPUs)** for the 8K phase — required by `TP=4 × EP=8`. The 32K phase additionally requires context parallel to fit the longer sequence, doubling the minimum to **8 nodes × 8x H100 (64 GPUs)**. On **96 nodes × 8x H100 (768 GPUs total)**, it takes ~900 H100 GPU-hours per 10B tokens (400 iters), i.e. ~70 min wall-clock per 10B tokens on 96 nodes. Full schedule (80B @ 8K + 20B @ 32K = 100B tokens, 4k total steps) takes ~9k H100 GPU-hours (~12 hours wall-clock). - -#### 3a. Phase 1 — 80B tokens @ 8K seq length - -
-Phase 1 distillation command (click to expand) - -```bash -python -u /opt/Model-Optimizer/examples/megatron_bridge/distill.py \ - --teacher_hf_path nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16 \ - --student_hf_path /path/to/Nemotron-3-Nano-30B-A3B-Pruned-A3.0B \ - --trust_remote_code \ - --tp_size 4 \ - --ep_size 8 \ - --data_paths "${DATA_BLEND}" \ - --data_path_to_cache /path/to/cache \ - --seq_length 8192 \ - --mbs 1 \ - --gbs 3072 \ - --train_iters 3200 \ - --lr 1e-4 \ - --min_lr 1e-5 \ - --lr_warmup_iters 25 \ - --eval_interval 200 \ - --eval_iters 8 \ - --log_interval 10 \ - --output_dir /path/to/distill_output_phase1_8k - -# Optional: Weights & Biases logging -# --wandb_project \ -# --wandb_entity \ -# --wandb_exp_name -``` - -Non-default arguments: - -- `--seq_length 8192` (default: 4096) -- `--gbs 3072` (default: 768) — matches the original Nemotron-3-Nano-30B training GBS from the paper, kept to preserve the training distribution -- `--train_iters 3200` — 80B tokens at GBS 3072 × seq_length 8192 -- `--lr 1e-4 --min_lr 1e-5 --lr_warmup_iters 25` — cosine fully decays over 3200 iters; the model is approaching saturation at 8K by this point (see [ABLATIONS.md — 8K trajectory](ABLATIONS.md#effect-of-data-blend-tool_calling)). -- `--eval_interval 200` (default: 100) — less frequent eval to save compute -- `--eval_iters 8` (default: 32) — since GBS is 4× larger than default - -All other arguments use defaults. -
- -#### 3b. Convert Phase 1 final checkpoint to HuggingFace format - -Phase 2 starts as a separate run from a fresh HuggingFace student checkpoint, so the final Phase 1 Megatron checkpoint must be exported to HF first using the Megatron-Bridge conversion script (see [Megatron-Bridge README](../../../megatron_bridge/README.md) for full details). You can also use this same script to convert any intermediate Phase 1 checkpoint to HF format for evaluation along the way. - -
-Checkpoint conversion command (click to expand) - -```bash -python /opt/Megatron-Bridge/examples/conversion/convert_checkpoints.py export \ - --hf-model /path/to/Nemotron-3-Nano-30B-A3B-Pruned-A3.0B \ - --megatron-path /path/to/distill_output_phase1_8k/checkpoints/iter_0003200 \ - --hf-path /path/to/distill_output_phase1_8k/checkpoints/hf_iter_0003200 -``` - -
- -#### 3c. Phase 2 — 20B tokens @ 32K seq length - -Phase 2 is a **fresh run** with the Phase 1 final checkpoint as the new student. It uses a different `--seed` so the data blend reshuffles (otherwise the model would see overlapping prefix of the same samples it already saw at 8K). The LR is bumped back up modestly to capture the rapid long-context adaptation observed in [ABLATIONS.md — Effect of long context training](ABLATIONS.md#effect-of-long-context-training). - -
-Phase 2 distillation command (click to expand) - -```bash -python -u /opt/Model-Optimizer/examples/megatron_bridge/distill.py \ - --teacher_hf_path nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16 \ - --student_hf_path /path/to/distill_output_phase1_8k/checkpoints/hf_iter_0003200 \ - --trust_remote_code \ - --tp_size 4 \ - --cp_size 2 \ - --ep_size 8 \ - --seed 5678 \ - --data_paths "${DATA_BLEND}" \ - --data_path_to_cache /path/to/cache \ - --seq_length 32768 \ - --mbs 1 \ - --gbs 768 \ - --train_iters 800 \ - --lr 2e-5 \ - --min_lr 1e-5 \ - --lr_warmup_iters 10 \ - --recompute_granularity selective \ - --recompute_modules moe \ - --eval_interval 200 \ - --eval_iters 8 \ - --log_interval 10 \ - --output_dir /path/to/distill_output_phase2_32k -``` - -Changed arguments from Phase 1: - -- `--student_hf_path` — points at the HF export of the Phase 1 final checkpoint -- `--seq_length 32768` — long-context phase -- `--gbs 768` — `seq_length × gbs` product unchanged, so each iter still processes the same number of tokens -- `--cp_size 2` — context parallel is needed to fit the longer sequence; doubles the minimum-hardware footprint to 8 nodes -- `--train_iters 800` — 20B tokens at GBS 768 × seq_length 32768 -- `--lr 2e-5 --min_lr 1e-5 --lr_warmup_iters 10` — modest LR bump for the long-context adaptation (Phase 1 ended at fully-decayed LR 1e-5); the 10-iter warmup re-populates Adam moment estimates which restart from zero in a fresh run -- `--recompute_granularity selective --recompute_modules moe` — selective MoE recompute further reduces activation memory at 32K. You may skip this if you have more memory. -- `--seed 5678` — different from the Phase 1 seed (default 1234) so the data blend reshuffles -- `--output_dir /path/to/distill_output_phase2_32k` — must be a **fresh directory** different from Phase 1's, so distill.py's resume mechanism (which auto-loads from `/checkpoints` if it exists) does not pull in stale state - -
- -For multi-node Slurm runs, see the [Megatron-Bridge README](../../../megatron_bridge/README.md#slurm-usage) for details. - -> [!NOTE] -> This is pure SFT-style distillation — no RL or online reward signal is used. Adding an RL-based post-training step after distillation is a natural next step that could further improve some of these benchmarks. - ---- - -### 4. Evaluation - -The eval config in [nemo_evaluator.yaml](nemo_evaluator.yaml) is for Slurm-based evaluation — it submits a vLLM serving job (with tool calling enabled via `--enable-auto-tool-choice --tool-call-parser qwen3_coder`) and runs evals against it. For local model execution and evaluation, refer to the [NeMo Evaluator documentation](https://docs.nvidia.com/nemo/evaluator/latest/) or this [blog](https://huggingface.co/blog/nvidia/nemotron-3-nano-evaluation-recipe). - -Before running, update the following fields in the yaml or overwrite them in the command line with `-o