From 5583159ecb62a68b922d18a098640683d7d392a6 Mon Sep 17 00:00:00 2001 From: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com> Date: Mon, 8 Jun 2026 21:26:12 +0530 Subject: [PATCH 01/24] Fix CodeCov upload issues (#1648) Disable codecov binary validation which seems to be constantly failing ``` gpg: Signature made Tue Apr 21 19:28:03 2026 UTC gpg: using RSA key 27034E7FDB850E0BBC2C62FF806BB28AED779869 gpg: Can't check signature: No public key ==> Could not verify signature. Please contact Codecov if problem continues Exiting... ``` ## Summary by CodeRabbit * **Chores** * Updated CI workflow notes and removed an outdated header comment. * Added explanatory comments to the Linux job and adjusted the code coverage upload step to use a relaxed validation mode (no other upload settings changed). --------- Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com> --- .github/workflows/unit_tests.yml | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/.github/workflows/unit_tests.yml b/.github/workflows/unit_tests.yml index fc2ae364cba..0a67f5f0b1d 100644 --- a/.github/workflows/unit_tests.yml +++ b/.github/workflows/unit_tests.yml @@ -1,4 +1,3 @@ -# NOTE: Make sure this file is consistent with .gitlab/tests.yml name: Unit tests on: @@ -73,6 +72,10 @@ jobs: token: ${{ secrets.CODECOV_TOKEN }} flags: unit fail_ci_if_error: true + # Skip GPG/SHASUM integrity check of the Codecov CLI: its key import + # intermittently fails (codecov/codecov-action#1876), which would + # otherwise hard-fail this required job on a transient infra blip. + skip_validation: true verbose: true windows: if: needs.check-file-changes.outputs.any_changed == 'true' From 1faaf7a9ff56a7457251939486dabe0d40cdc597 Mon Sep 17 00:00:00 2001 From: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com> Date: Mon, 8 Jun 2026 23:41:51 +0530 Subject: [PATCH 02/24] fix(llm_eval): repair test_qwen3_eval_fp8 end-to-end (#1650) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### What does this PR do? Type of change: Bug fix `tests/examples/llm_eval/test_llm_eval.py::test_qwen3_eval_fp8` was silently passing while its evals crashed, then began failing as a timeout. This repairs the whole pipeline: - **lm_eval `IndexError` (root cause):** TRT-LLM KV-cache prefix reuse returns truncated `context_logits` for shared-prefix requests (e.g. hellaswag's one-context / many-endings), which breaks `parse_logprobs`. Add an `enable_kv_cache_reuse` flag to `modelopt.deploy.llm.LLM` (default `True`, unchanged) and disable it for the eval deployment so full-length context logits are returned. - **Silent CI green:** `python eval.py | tee result.txt` returns `tee`'s exit code, so a crashing eval was masked. Add `set -o pipefail` to `huggingface_example.sh` so failures fail the test. - **Long-prompt overflows:** with the tiny test model's toy tokenizer, gsm8k/MMLU prompts exceed `max_seq_len`. Bump test `max_position_embeddings` to 8192, skip MMLU prompts that don't fit even at zero-shot, and add an MMLU sample limit (`--mmlu_limit`). - **human-eval build failures:** install with `--no-build-isolation` (`pkg_resources` is absent in pip's isolated build env), patch its malformed `console_scripts` entry point, and pin the clone. - **Cleanups:** gate the post-quant `run_tensorrt_llm.py` smoke test behind the `quant` task (eval tasks deploy on their own; ~45s saved for eval-only runs); replace the SIGPIPE-prone serve-readiness `tail -f | while` with a poll loop (required under `pipefail`). ### Usage N/A — example/test fix. ### Testing All four eval tasks verified end-to-end in the CI container (TRT-LLM 1.3.0rc17, RTX 6000 Ada): lm_eval (hellaswag + gsm8k), MMLU, and simple_eval (humaneval) all complete with exit 0 and no `IndexError`/overflow. Cold full run ≈ 340s on this GPU. CI test on 2-gpu: https://github.com/NVIDIA/Model-Optimizer/actions/runs/27154417497/job/80153551154 ### Before your PR is "*Ready for review*" - Is this change backward compatible?: ✅ (new `enable_kv_cache_reuse` defaults to current behavior; new script flags are optional) - If you copied code from any other sources or added a new PIP dependency, did you follow guidance in `CONTRIBUTING.md`: N/A (no new dependencies) - Did you write any new necessary tests?: N/A (fixes and strengthens an existing test) - Did you update Changelog?: N/A (bug fix to examples/tests) - Did you get Claude approval on this PR?: ❌ (pending) ### Additional Information The full test runs ~340s on an RTX 6000 Ada; CI runners are historically slower, while `@pytest.mark.timeout` is set to 600 — worth watching the first CI run and bumping if it's close. 🤖 Generated with [Claude Code](https://claude.com/claude-code) ## Summary by CodeRabbit * **New Features** * Added an option to limit MMLU evaluation length. * **Bug Fixes** * Disabled KV-cache prefix reuse for evaluations needing per-token context logits to prevent truncated/incorrect logprobs. * Skip examples whose prompts remain too long; warn and report accuracy as NaN if all examples are skipped. * **Chores / Scripts** * Improved example scripts for reproducible installs, patched entry point handling, pipeline failure detection, conditional test invocation, polling-based log wait, and a new CLI flag for MMLU limits. * **Tests** * Increased timeout and prompt headroom; capped MMLU smoke tests for speed. --------- Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com> Co-authored-by: Claude Opus 4.8 Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com> --- examples/llm_eval/lm_eval_tensorrt_llm.py | 4 +++ examples/llm_eval/mmlu.py | 21 +++++++++++++-- examples/llm_eval/run_simple_eval.sh | 14 +++++++++- examples/llm_ptq/run_tensorrt_llm.py | 9 ++++++- .../llm_ptq/scripts/huggingface_example.sh | 27 ++++++++++++++----- examples/llm_ptq/scripts/parser.sh | 5 +++- modelopt/deploy/llm/generate.py | 12 +++++++++ tests/examples/llm_eval/test_llm_eval.py | 13 ++++++--- 8 files changed, 89 insertions(+), 16 deletions(-) diff --git a/examples/llm_eval/lm_eval_tensorrt_llm.py b/examples/llm_eval/lm_eval_tensorrt_llm.py index 181fc9c79f1..f65dad53655 100644 --- a/examples/llm_eval/lm_eval_tensorrt_llm.py +++ b/examples/llm_eval/lm_eval_tensorrt_llm.py @@ -64,6 +64,10 @@ def __init__( tokenizer=self.tokenizer, max_batch_size=int(batch_size), max_seq_len=max_length, + # Loglikelihood tasks request context logits. KV cache prefix reuse would return + # logits only for the recomputed suffix on shared-prefix requests (e.g. hellaswag), + # truncating context_logits and breaking parse_logprobs. Disable it. + enable_kv_cache_reuse=False, ) self.max_length = max_length - 1 logger.info("Loaded TRT-LLM") diff --git a/examples/llm_eval/mmlu.py b/examples/llm_eval/mmlu.py index a3047fcc9e1..3d03240c408 100755 --- a/examples/llm_eval/mmlu.py +++ b/examples/llm_eval/mmlu.py @@ -183,7 +183,10 @@ def gen_prompt(train_df, subject, k=-1): def evaluate(args, subject, model: EvalModel | LLM, dev_df, test_df): cors = [] all_probs = [] - for i in range(test_df.shape[0]): + num_examples = test_df.shape[0] + if args.limit is not None: + num_examples = min(num_examples, args.limit) + for i in range(num_examples): # get prompt and make sure it fits k = args.ntrain prompt_end = format_example(test_df, i, include_answer=False) @@ -201,6 +204,12 @@ def check_valid_length(model, prompt): train_prompt = gen_prompt(dev_df, subject, k) prompt = train_prompt + prompt_end + # Skip examples that do not fit even at zero-shot, otherwise the backend rejects + # prompts longer than max_seq_len and aborts the whole evaluation. + if not check_valid_length(model, prompt): + print(f"Skipping {subject} example {i}: prompt exceeds max_seq_len even at 0-shot.") + continue + label = test_df.iloc[i, test_df.shape[1] - 1] if isinstance(model, EvalModel): pred = model.run(prompt) @@ -212,7 +221,11 @@ def check_valid_length(model, prompt): cors.append(cor) all_probs.append(probs) - acc = np.mean(cors) + if not cors: + # Every example was skipped (all prompts exceeded max_seq_len). Surface it instead of + # silently producing a nan accuracy downstream. + print(f"WARNING: all {subject} examples were skipped; reporting accuracy as nan.") + acc = np.mean(cors) if cors else float("nan") cors = np.array(cors) all_probs = np.array(all_probs) @@ -233,8 +246,12 @@ def main( auto_quantize_score_size: int = 128, auto_quantize_checkpoint: str | None = None, sparse_cfg: str | None = None, + limit: int | None = None, **kwargs, ): + if limit is not None and limit <= 0: + raise ValueError(f"limit must be a positive integer when provided, got {limit}.") + random.seed(RAND_SEED) np.random.seed(RAND_SEED) diff --git a/examples/llm_eval/run_simple_eval.sh b/examples/llm_eval/run_simple_eval.sh index 5f40b4ce8b3..763e88432f8 100644 --- a/examples/llm_eval/run_simple_eval.sh +++ b/examples/llm_eval/run_simple_eval.sh @@ -28,11 +28,23 @@ if [ ! -d "human-eval" ]; then git clone https://github.com/openai/human-eval.git fi +# Pin to a known commit for reproducibility (and so the entry-point patch below matches), forcing +# it every run so a reused checkout cannot drift to an arbitrary revision. -f discards the patch +# applied to setup.py on a previous run before re-applying it below. +git -C human-eval checkout -q -f 6d43fb980f9fee3c892a914eda09951f772ad10d + +# human-eval's console_scripts entry point lacks the ":callable" suffix, which newer pip/setuptools +# reject ("A callable suffix is required"). The target module defines main(), so point at it. +sed -i 's|human_eval\.evaluate_functional_correctness"|human_eval.evaluate_functional_correctness:main"|' human-eval/setup.py + if [ ! -d "simple-evals" ]; then git clone https://github.com/openai/simple-evals.git fi -pip install -e human-eval +# --no-build-isolation: human-eval's legacy setup.py imports pkg_resources at build time, +# which pip's isolated build env does not provide with newer setuptools. Build against the +# base environment (which has setuptools/pkg_resources) instead. +pip install -e human-eval --no-build-isolation pip install openai pushd simple-evals diff --git a/examples/llm_ptq/run_tensorrt_llm.py b/examples/llm_ptq/run_tensorrt_llm.py index 56d25df709c..f7cc588f40a 100644 --- a/examples/llm_ptq/run_tensorrt_llm.py +++ b/examples/llm_ptq/run_tensorrt_llm.py @@ -66,7 +66,14 @@ def run(args): print("TensorRT-LLM example outputs:") - llm = LLM(args.checkpoint_dir, tokenizer=tokenizer, max_batch_size=len(input_texts)) + # generate_context_logits() below requires KV cache reuse disabled: with prefix block reuse, + # shared-prefix inputs return truncated (silently incorrect) context logits. + llm = LLM( + args.checkpoint_dir, + tokenizer=tokenizer, + max_batch_size=len(input_texts), + enable_kv_cache_reuse=False, + ) torch.cuda.cudart().cudaProfilerStart() outputs = llm.generate_text(input_texts, args.max_output_len) torch.cuda.cudart().cudaProfilerStop() diff --git a/examples/llm_ptq/scripts/huggingface_example.sh b/examples/llm_ptq/scripts/huggingface_example.sh index 541c349da08..3f51e5b73f3 100755 --- a/examples/llm_ptq/scripts/huggingface_example.sh +++ b/examples/llm_ptq/scripts/huggingface_example.sh @@ -29,6 +29,10 @@ for i in $(env | grep ^SLURM_ | cut -d"=" -f 1); do unset -v $i; done for i in $(env | grep ^PMI_ | cut -d"=" -f 1); do unset -v $i; done for i in $(env | grep ^PMIX_ | cut -d"=" -f 1); do unset -v $i; done +# Fail on errors inside pipelines (e.g. `python eval.py | tee result.txt`), otherwise a crashing +# eval is masked by tee's exit code and the script passes silently. +set -o pipefail + if [ -z "$MODEL_PATH" ]; then echo "Unsupported model argument: Expected a huggingface model path or model name" >&2 exit 1 @@ -216,7 +220,11 @@ if [[ $TASKS =~ "quant" ]] || [[ ! -d "$SAVE_PATH" ]] || [[ ! $(ls -A $SAVE_PATH RUN_ARGS+=" --trust_remote_code " fi - python run_tensorrt_llm.py --checkpoint_dir=$SAVE_PATH $RUN_ARGS + # Only run the deploy+generate smoke test when "quant" is explicitly requested. Eval tasks + # (lm_eval/mmlu/simple_eval) deploy the checkpoint themselves, so it is redundant there. + if [[ $TASKS =~ "quant" ]]; then + python run_tensorrt_llm.py --checkpoint_dir=$SAVE_PATH $RUN_ARGS + fi fi if [[ -d "${MODEL_PATH}" ]]; then @@ -285,11 +293,16 @@ if [[ $TASKS =~ "mmlu" ]]; then tar -xf /tmp/mmlu.tar -C data && mv data/data $MMLU_DATA_PATH fi + mmlu_flags="" + if [ -n "$MMLU_LIMIT" ]; then + mmlu_flags+=" --limit $MMLU_LIMIT " + fi + python mmlu.py \ --model_name causal \ --model_path $MODEL_ABS_PATH \ --checkpoint_dir $SAVE_PATH \ - --data_dir $MMLU_DATA_PATH | tee $MMLU_RESULT + --data_dir $MMLU_DATA_PATH $mmlu_flags | tee $MMLU_RESULT popd fi @@ -304,16 +317,16 @@ if [[ $TASKS =~ "livecodebench" || $TASKS =~ "simple_eval" ]]; then trtllm-serve $SAVE_PATH --host 0.0.0.0 --port $PORT >$SAVE_PATH/serve.txt 2>&1 & SERVE_PID=$! - tail -f $SAVE_PATH/serve.txt | while read line; do - if echo "$line" | grep -q "Application startup complete"; then - echo "Application startup complete." - break - fi + # Poll the log instead of `tail -f | while ... break`: under `set -o pipefail` (set above), + # breaking out of that pipeline leaves tail to die by SIGPIPE, which would abort the script. + while ! grep -q "Application startup complete" $SAVE_PATH/serve.txt 2>/dev/null; do if ! kill -0 $SERVE_PID 2>/dev/null; then echo "trtllm-serve has exited." exit 1 fi + sleep 2 done + echo "Application startup complete." pushd ../llm_eval/ diff --git a/examples/llm_ptq/scripts/parser.sh b/examples/llm_ptq/scripts/parser.sh index 3be7706c4e6..3efed91bc32 100644 --- a/examples/llm_ptq/scripts/parser.sh +++ b/examples/llm_ptq/scripts/parser.sh @@ -28,6 +28,7 @@ parse_options() { LM_EVAL_TASKS="mmlu,gsm8k" LM_EVAL_LIMIT= SIMPLE_EVAL_TASKS="mmlu" + MMLU_LIMIT= TASKS="quant" @@ -38,7 +39,7 @@ parse_options() { CAST_MXFP4_TO_NVFP4=false # Parse command-line options - ARGS=$(getopt -o "" -l "model:,quant:,recipe:,kv_cache_quant:,tp:,pp:,sparsity:,awq_block_size:,calib:,calib_batch_size:,auto_quantize_bits:,output:,batch:,tasks:,lm_eval_tasks:,lm_eval_limit:,simple_eval_tasks:,simple_eval_limit:,trust_remote_code,use_seq_device_map,gpu_max_mem_percentage:,kv_cache_free_gpu_memory_fraction:,low_memory_mode,no-verbose,calib_dataset:,calib_seq:,auto_quantize_method:,auto_quantize_score_size:,auto_quantize_checkpoint:,moe_calib_experts_ratio:,cast_mxfp4_to_nvfp4" -n "$0" -- "$@") + ARGS=$(getopt -o "" -l "model:,quant:,recipe:,kv_cache_quant:,tp:,pp:,sparsity:,awq_block_size:,calib:,calib_batch_size:,auto_quantize_bits:,output:,batch:,tasks:,lm_eval_tasks:,lm_eval_limit:,simple_eval_tasks:,simple_eval_limit:,mmlu_limit:,trust_remote_code,use_seq_device_map,gpu_max_mem_percentage:,kv_cache_free_gpu_memory_fraction:,low_memory_mode,no-verbose,calib_dataset:,calib_seq:,auto_quantize_method:,auto_quantize_score_size:,auto_quantize_checkpoint:,moe_calib_experts_ratio:,cast_mxfp4_to_nvfp4" -n "$0" -- "$@") eval set -- "$ARGS" while true; do @@ -61,6 +62,7 @@ parse_options() { --lm_eval_limit ) LM_EVAL_LIMIT="$2"; shift 2;; --simple_eval_tasks ) SIMPLE_EVAL_TASKS="$2"; shift 2;; --simple_eval_limit ) SIMPLE_EVAL_LIMIT="$2"; shift 2;; + --mmlu_limit ) MMLU_LIMIT="$2"; shift 2;; --trust_remote_code ) TRUST_REMOTE_CODE=true; shift;; --use_seq_device_map ) USE_SEQ_DEVICE_MAP=true; shift;; --gpu_max_mem_percentage ) GPU_MAX_MEM_PERCENTAGE="$2"; shift 2;; @@ -161,6 +163,7 @@ parse_options() { echo "lm_eval_limit: $LM_EVAL_LIMIT" echo "simple_eval_tasks: $SIMPLE_EVAL_TASKS" echo "simple_eval_limit: $SIMPLE_EVAL_LIMIT" + echo "mmlu_limit: $MMLU_LIMIT" echo "num_sample: $NUM_SAMPLES" echo "use_seq_device_map: $USE_SEQ_DEVICE_MAP" echo "gpu_max_mem_percentage: $GPU_MAX_MEM_PERCENTAGE" diff --git a/modelopt/deploy/llm/generate.py b/modelopt/deploy/llm/generate.py index 0f649199ec2..39306504137 100644 --- a/modelopt/deploy/llm/generate.py +++ b/modelopt/deploy/llm/generate.py @@ -62,6 +62,7 @@ def __init__( trust_remote_code: bool = False, max_seq_len: int = 0, max_batch_size: int = 0, + enable_kv_cache_reuse: bool = True, ): """Initializes the LLM runner class. @@ -73,6 +74,10 @@ def __init__( trust_remote_code: whether to trust the remote code (for the torch backend). max_seq_len: Max sequence length for the LLM backend. If 0, it is not specified. max_batch_size: Max batch size for the LLM backend. If 0, it is not specified. + enable_kv_cache_reuse: whether to enable KV cache block reuse. Must be disabled when + requesting context logits (e.g. lm-eval loglikelihood tasks): with prefix block + reuse, shared-prefix requests only return logits for the recomputed suffix, which + breaks per-token logprob computation. """ with open(Path(checkpoint_dir) / "config.json") as config_file: config = json.load(config_file) @@ -124,6 +129,8 @@ def _find_max_position_embeddings(cfg: dict) -> int | None: trt_kv_cache_config.max_tokens = self._max_seq_len * ( max_batch_size if max_batch_size > 0 else 8 ) + trt_kv_cache_config.enable_block_reuse = enable_kv_cache_reuse + self._enable_kv_cache_reuse = enable_kv_cache_reuse cuda_graph_config = None if max_batch_size > 0: @@ -281,6 +288,11 @@ def generate_context_logits( assert self._support_context_logits_and_stop_words, ( "Context logits are not supported with the current tensorrt_llm version." ) + assert not self._enable_kv_cache_reuse, ( + "Context logits require enable_kv_cache_reuse=False: with KV cache prefix reuse, " + "shared-prefix requests only return logits for the recomputed suffix, producing " + "truncated (and silently incorrect) context logits." + ) assert temperature >= 0.0, "Temperature must be greater than 0.0." kwargs = _sanitize_temperature_and_top_p(temperature, top_p) diff --git a/tests/examples/llm_eval/test_llm_eval.py b/tests/examples/llm_eval/test_llm_eval.py index 421dd408e3d..2c04a878ad1 100644 --- a/tests/examples/llm_eval/test_llm_eval.py +++ b/tests/examples/llm_eval/test_llm_eval.py @@ -41,11 +41,13 @@ def test_lm_eval_hf(tmp_path): @minimum_sm(89) -@pytest.mark.timeout(600) +@pytest.mark.timeout(900) def test_qwen3_eval_fp8(tmp_path): - # Bump max_position_embeddings: TRT-LLM serve rejects prompts longer than - # max_seq_len, and the default (32) is shorter than even simple MMLU prompts. - model_dir = create_tiny_qwen3_dir(tmp_path, with_tokenizer=True, max_position_embeddings=2048) + # Bump max_position_embeddings: TRT-LLM serve rejects prompts longer than max_seq_len. + # The default (32) is shorter than even simple MMLU prompts, and 2048 is shorter than + # 5-shot gsm8k prompts (~3.9k tokens). The eval LLM caps max_seq_len at max_gen_toks + 4096, + # so 8192 leaves headroom for the longest prompts we evaluate. + model_dir = create_tiny_qwen3_dir(tmp_path, with_tokenizer=True, max_position_embeddings=8192) try: run_llm_ptq_command( model=str(model_dir), @@ -56,6 +58,9 @@ def test_qwen3_eval_fp8(tmp_path): simple_eval_tasks="humaneval", lm_eval_limit=16, simple_eval_limit=16, + # MMLU has no inherent sample cap and otherwise evaluates the full ~14k-question + # test set; limit it like lm_eval/simple_eval to keep this a fast smoke test. + mmlu_limit=16, output=128, # Cap generation length: gsm8k/humaneval otherwise generate up to 1024 tokens/sample batch=8, ) From 8768bb5a35983af3e88c80c81881ec1a5cc80e50 Mon Sep 17 00:00:00 2001 From: Rohan Joshi Date: Mon, 8 Jun 2026 15:11:37 -0700 Subject: [PATCH 03/24] Add Alpamayo-1 example (#1594) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### What does this PR do? Type of change: ? New example Adds example for Alpamayo-1 quantization with ModelOpt (FP8, NVFP4, AutoQuant) ### Usage ``` python quantize.py --ckpt nvidia/Alpamayo-R1-10B --output-dir ./alpamayo-r1-fp8 --quantize fp8 ``` ### Testing ### Before your PR is "*Ready for review*" Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/Model-Optimizer/blob/main/CONTRIBUTING.md) and your commits are signed (`git commit -s -S`). Make sure you read and follow the [Security Best Practices](https://github.com/NVIDIA/Model-Optimizer/blob/main/SECURITY.md#security-coding-practices-for-contributors) (e.g. avoiding hardcoded `trust_remote_code=True`, `torch.load(..., weights_only=False)`, `pickle`, etc.). - Is this change backward compatible?: ✅ / ❌ / N/A - If you copied code from any other sources or added a new PIP dependency, did you follow guidance in `CONTRIBUTING.md`: ✅ / ❌ / N/A - Did you write any new necessary tests?: ✅ / ❌ / N/A - Did you update [Changelog](https://github.com/NVIDIA/Model-Optimizer/blob/main/CHANGELOG.rst)?: ✅ / ❌ / N/A - Did you get Claude approval on this PR?: ✅ / ❌ / N/A ### Additional Information ## Summary by CodeRabbit * **New Features** * Added Alpamayo 1 vision-language-action model quantization example supporting FP8, NVFP4, and mixed-precision optimization modes * Introduced CLI quantization tool with calibration loop and checkpoint export capabilities for both fake-quantized and real-quantized formats * **Documentation** * Added comprehensive guide documenting the Alpamayo quantization example, model details, and usage instructions --------- Signed-off-by: Rohan Joshi Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com> --- CHANGELOG.rst | 1 + ...ws_train_set_for_calibration_25.10.parquet | Bin 0 -> 3949 bytes examples/alpamayo/README.md | 72 ++ examples/alpamayo/quantize.py | 652 ++++++++++++++++++ tests/examples/alpamayo/test_quantize.py | 78 +++ 5 files changed, 803 insertions(+) create mode 100644 examples/alpamayo/0417_16rows_train_set_for_calibration_25.10.parquet create mode 100644 examples/alpamayo/README.md create mode 100644 examples/alpamayo/quantize.py create mode 100644 tests/examples/alpamayo/test_quantize.py diff --git a/CHANGELOG.rst b/CHANGELOG.rst index da02b315f67..967368f3116 100755 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -24,6 +24,7 @@ Changelog **New Features** - Extend Claude Code agent skills for PTQ, deployment, evaluation, monitoring, and baseline-vs-quantized result comparison. Adds evaluation task references for additional benchmarks, stronger PTQ checkpoint validation gates, and session-scoped workspace/job tracking. +- Add ``examples/alpamayo`` showing FP8, NVFP4, and AutoQuantize (mixed-precision) quantization of the Alpamayo (formerly Alpamayo-R1) ~10B vision-language-action model, with a joint VLM + diffusion calibration loop and both fake-quant and ``--real-quant`` packed-checkpoint export. See `examples/alpamayo/README.md `_ for details. - Add SLURM Quality of Service (QoS) support to the ModelOpt launcher. Users can set QoS via ``slurm_config.qos`` or ``SLURM_QOS`` and the value is forwarded to ``nemo_run.SlurmExecutor``. - Add composable ``$import`` system for recipe YAML configs, enabling reusable config snippets referenced via ``{$import: name}`` markers. All built-in PTQ recipes converted to use imports with shared snippets under ``modelopt_recipes/configs/`` (numeric formats, quant_cfg building blocks, presets). See :ref:`composable-imports`. - Add offline DFlash speculative decoding training. Train the draft module from pre-computed base-model hidden states dumped by ``examples/speculative_decoding/collect_hidden_states/compute_hidden_states_hf.py``; base-model transformer layers are deleted after conversion to save memory. Controlled by the auto-derived ``dflash_offline`` flag on ``DFlashConfig`` (derived from ``data_args.offline_data_path``). The dump scripts now share ``collect_hidden_states/common.py`` for aux-layer selection (``--aux-layers eagle|dflash|``) and optional assistant-token ``loss_mask`` for answer-only-loss training. diff --git a/examples/alpamayo/0417_16rows_train_set_for_calibration_25.10.parquet b/examples/alpamayo/0417_16rows_train_set_for_calibration_25.10.parquet new file mode 100644 index 0000000000000000000000000000000000000000..c16c7d2b98ca493a148dc528f4244f412d51462d GIT binary patch literal 3949 zcmb_f&2J+~6>n$ltUZ%JW`!*qhJE=te~A4JI*HAjAzDQfTF7I>ZFr) z+xR0+G8$=4Xv7I24l5354@gKW4oG|Bz<&S|oK}bvXAX1Xz=7pewPQKS4B8dja;mG| zd-Z<5_fchM&uPfnWcKBa?7uSG*?}%LWvMOsxb3FqQ;=wjqGnqm?6?TarY^utNLXhgI2GXdm zBgzOtx{Fl>D+ZC7&h1oP_G}aKp_qwJ_DtPEHa2`D6H7I8N-PU|sn@S$SN9l(jCiJu zGy*hj-_nrE6b+n#jfE3;U)j{Quwtr6ws@>mmyvDzF4D1S66z^Bv+v|z|H@WfLm{?} zOiS^RX2=#$V+0x4XSU@t##|99^|8$i)kZ4z3D^-C5oQ`l#^1>LNzSY#-^ep z!*f*xMRpO9JsC2oyHGCE^%ap~izrMp4HfBzVj)elAXd$=5wryLDJH(=-}(9L4>hVN zHntStqf&^15@b^`rWz(G}az0!B|WrQ)|l8laK%K_x}LsTJ~EUR1%>1 zL;^lJ>8^N-P3iS`(-&Xx!?Kk9G%-IciF7$5r4q@%rc&=FrN6DFlONy6tR=s_lu1i_ znRnAt;$|k3{ERO*SAQV=b~P>i8{U6eT}w(UnY+^OSJxikn?MPkKc(XL)imVO(n_q! zr}yHC1pMA0eOrjh?0ebVc6Lb|@Jx~8jE5wM00|VA5}p!@eEbsrGQ~LhL{Z8n3qR7O z{PM^>WId8iOWBn>OZj(>+1VXwdmX%HC0MVzVKCSaM*Alu3_REbvXpx`iH>9Y)MyX} z{70n`{_#VJpA#U+&`Df zzxzh{A1mu$skqlzUqZS_V4Pop^s_|n7m4!Awep|8$#ZzMxYzvUK1s~8ko(Q8T(O+a zeZD3Y?&X&UB=X28|I6j)r6BUy>Au?!C%tG?dcXL?=cQwi!K^eS(E%&nE0!YCV<3no zVF;_ym<-3T#A4nbz*NULUNOFsU;G$UvCDnopUpg9cH{>v^!6{yVrKi}vjImeO`^c< zdohVgZ*VqO^yv|E$DCY`jfqFb1k9h8SOn%lbinsMds;BRz>%PdOJ>fFI48W+O< ziw95Ci+ez47F+4kA-axvO??HO_WPlz!=)5lNEZ@t9hjfgHvnFUuK)$n*wC&c&Bgx< zN?Zijq2}T%P+zUGA~4=nWIP@QbTSql^&FC1a)*qJ`$JLwFrY&+Jmb|LoRQ(MKjnl@ z*l-l|BQ8~x`}jUC@qIHsi^}`zeYNyL%C9($#z(F94E z(ycYMwz7LdH=AP`?F1hM70S-GhNPt**N^%K^=Rvm_PnqjOb_Zk{gAevahr+PIvs6q zIDxA+4&CT6$eihHed2D^jrzl9)18OTsJ^jv?$w%8?rWzld)%g@*dwncPufb1-)+xv zC!LlYa$6cTM|6W@b_d2X_m19q(n{Hy0}wV}H;&dUkBmOcRg$K5A&jtxBJDQnu8 zwJXAJy=Moc*F5qzgnry@Hx9uE?mPj`o1>uH(vPUp7|@<8{56}6+V<0fz`w%XJp=pd z7;<>Q%L$0m)Z;u|<-DMaDv(denY_R$YEDRN_oNFk@ccmRV|O$3z0Gh6aTpNCNvAgm zJL=t z5v_$&s0HM)dkFJm5yP8!_O?!_GJPAzfI?m4_~p>6RgXK6+j%WF@rdVl5tnR2+-IT= zJ<@89x>b9Te;;Z&d^?wL;yvu(7*CgJyP_}l0rW54J^Bjyigu?E`^?|<`nhTa@)`Sc zzwW-P0sq>m-G&<0dsXahZVft76VJ|t$!iPsJnOcqn5uPNbLhKMu-SsMcP#o8`UU1& zCtjRSrQ6;*CoOv-&hSzFZ2QQW>}+h=;*8GucSASY9Z*FRK6V1_eEVo8@Z+1P#;=Nt zSYgWVC}+fnnz&!YAnudRhj58mpmfWcZT_8zXlxK;jFd aIQv2{'<|traj_history|>' * num_traj_token}<|traj_history_end|>" + ) + + return [ + { + "role": "system", + "content": [ + { + "type": "text", + "text": "You are a driving assistant that generates safe and accurate actions.", + } + ], + }, + { + "role": "user", + "content": [{"type": "image", "image": frame} for frame in frames] + + [ + { + "type": "text", + "text": f"{hist_traj_placeholder}output the chain-of-thought reasoning of the \ + driving process, then output the future trajectory.", + } + ], + }, + { + "role": "assistant", + "content": [ + { + "type": "text", + "text": "<|cot_start|>", + } + ], + }, + ] + + +def get_processor(tokenizer: AutoTokenizer) -> AutoProcessor: + """Get the processor for the Qwen3-VL-2B-Instruct model.""" + processor_kwargs = { + "min_pixels": MIN_PIXELS, + "max_pixels": MAX_PIXELS, + } + + processor = AutoProcessor.from_pretrained(BASE_PROCESSOR_NAME, **processor_kwargs) + processor.tokenizer = tokenizer + return processor + + +def to_device( + data: Any, + device: str | torch.device | None = None, + dtype: torch.dtype | None = None, +) -> Any: + """Recursively cast data into the specified device, dtype.""" + if isinstance(data, torch.Tensor): + data = data.to( + device=device, + dtype=dtype, + ) + return data + elif isinstance(data, collections.abc.Mapping): + return {key: to_device(data[key], device=device, dtype=dtype) for key in data} + elif isinstance(data, collections.abc.Sequence) and not isinstance(data, (str, bytes)): + return [to_device(elem, device=device, dtype=dtype) for elem in data] + else: + return data + + +def _teacher_forced_flow_loss_forward( + self, + data: dict[str, Any], +) -> dict[str, torch.Tensor]: + """Differentiable forward that returns the flow-matching training targets. + + Bypasses autoregressive reasoning generation and diffusion sampling. + The VLM runs in a single non-sampling forward pass (with ```` + appended to the prompt) to build the prompt KV cache; the expert then runs once + on a linearly-interpolated noisy action and returns the predicted velocity field. + + Args: + data: dict with ``tokenized_data`` (input_ids + other processor outputs), + ``ego_history_xyz``, ``ego_history_rot``, ``ego_future_xyz``, + ``ego_future_rot``. + + Returns: + dict with keys ``v_pred`` and ``v_target``, both shape + ``(b,n_diffusion_tokens, action_dim)``. Callers compute MSE between them. + """ + ego_history_xyz = data["ego_history_xyz"] + ego_history_rot = data["ego_history_rot"] + ego_future_xyz = data["ego_future_xyz"] + ego_future_rot = data["ego_future_rot"] + b, n_traj_group, _, _ = ego_history_xyz.shape + assert n_traj_group == 1, "Only one trajectory group is supported." + + tokenized_data = dict(data["tokenized_data"]) + input_ids = tokenized_data.pop("input_ids") + traj_data_vlm = { + "ego_history_xyz": ego_history_xyz, + "ego_history_rot": ego_history_rot, + } + input_ids = self.fuse_traj_tokens(input_ids, traj_data_vlm) + device = input_ids.device + + # Append so the expert attends through the full prompt. + traj_future_start_id = self.tokenizer.convert_tokens_to_ids( + to_special_token("traj_future_start") + ) + start_col = torch.full( + (input_ids.shape[0], 1), + traj_future_start_id, + dtype=input_ids.dtype, + device=device, + ) + input_ids = torch.cat([input_ids, start_col], dim=1) + if "attention_mask" in tokenized_data and tokenized_data["attention_mask"] is not None: + am = tokenized_data["attention_mask"] + tokenized_data["attention_mask"] = torch.cat( + [am, torch.ones((am.shape[0], 1), dtype=am.dtype, device=am.device)], dim=1 + ) + + vlm_outputs = self.vlm( + input_ids=input_ids, + use_cache=True, + return_dict=True, + **tokenized_data, + ) + prompt_cache = vlm_outputs.past_key_values + prefill_seq_len = prompt_cache.get_seq_length() + rope_deltas = self.vlm.model.rope_deltas + + n_diffusion_tokens = self.action_space.get_action_space_dims()[0] + offset = torch.full((b,), prefill_seq_len, device=device, dtype=torch.long) + + position_ids = torch.arange(n_diffusion_tokens, device=device) + position_ids = einops.repeat(position_ids, "l -> 3 b l", b=b).clone() + delta = rope_deltas + offset[:, None] + position_ids += delta.to(position_ids.device) + + # No padding between prompt cache and action block: full attention mask. + attention_mask = torch.zeros( + (b, 1, n_diffusion_tokens, prefill_seq_len + n_diffusion_tokens), + dtype=torch.float32, + device=device, + ) + + forward_kwargs = {} + if self.config.expert_non_causal_attention: + forward_kwargs["is_causal"] = False + + # Build flow-matching target: x_1 = GT action, x_0 ~ N(0, I). + x_1 = self.action_space.traj_to_action( + traj_history_xyz=ego_history_xyz[:, 0], + traj_history_rot=ego_history_rot[:, 0], + traj_future_xyz=ego_future_xyz[:, 0], + traj_future_rot=ego_future_rot[:, 0], + ) # (b,n_diffusion_tokens, 2) + x_1 = x_1.to(device=device, dtype=torch.float32) + + x_0 = torch.randn_like(x_1) + t = torch.rand(b, 1, 1, device=device, dtype=x_1.dtype) + x_t = (1.0 - t) * x_0 + t * x_1 + v_target = x_1 - x_0 + + # Cast to action-module dtype to match action_in_proj / expert weights. + proj_dtype = next(self.action_in_proj.parameters()).dtype + x_t_cast = x_t.to(dtype=proj_dtype) + t_cast = t.to(dtype=proj_dtype) + + future_token_embeds = self.action_in_proj(x_t_cast, t_cast) + if future_token_embeds.dim() == 2: + future_token_embeds = future_token_embeds.view(b, n_diffusion_tokens, -1) + + expert_out = self.expert( + inputs_embeds=future_token_embeds, + position_ids=position_ids, + past_key_values=prompt_cache, + attention_mask=attention_mask, + use_cache=True, + **forward_kwargs, + ) + prompt_cache.crop(prefill_seq_len) + last_hidden = expert_out.last_hidden_state[:, -n_diffusion_tokens:] + v_pred = self.action_out_proj(last_hidden).view(b, *self.action_space.get_action_space_dims()) + + return {"v_pred": v_pred.to(torch.float32), "v_target": v_target} + + +def patch_teacher_forced_flow_loss_forward() -> None: + """Attach teacher_forced_flow_loss_forward to AlpamayoR1 if missing. + + The public OSS AlpamayoR1 (github.com/nvlabs/alpamayo) does not define this + method; it exists only on the internal training fork. The body ported above + is the calibration path used by auto_quantize_model. + """ + if not hasattr(AlpamayoR1, "teacher_forced_flow_loss_forward"): + AlpamayoR1.teacher_forced_flow_loss_forward = _teacher_forced_flow_loss_forward + + +patch_teacher_forced_flow_loss_forward() + + +def make_joint_calibration_forward_loop( + *, + clip_ids: list[str], + processor, + t0_us: int, + top_p: float, + temperature: float, + max_generation_length: int, + calibration_traj_samples: int, + device: str, +): + """ + Build a calibration loop that exercises both VLM generation and diffusion. + + This avoids text-only calibration and ensures quantizers in the rollout path + (vlm/expert/diffusion-related modules) observe representative activations. + """ + + def _calibration_loop(runtime_model): + runtime_model.eval() + with torch.no_grad(): + for clip_id in tqdm(clip_ids, desc="Calibration"): + data = load_physical_aiavdataset(clip_id, t0_us=t0_us) + messages = create_message(data["image_frames"].flatten(0, 1)) + inputs = processor.apply_chat_template( + messages, + tokenize=True, + add_generation_prompt=False, + continue_final_message=True, + return_dict=True, + return_tensors="pt", + ) + model_inputs = { + "tokenized_data": inputs, + "ego_history_xyz": data["ego_history_xyz"], + "ego_history_rot": data["ego_history_rot"], + } + model_inputs = to_device(model_inputs, device) + + with torch.autocast("cuda", dtype=torch.float16): + runtime_model.sample_trajectories_from_data_with_vlm_rollout( + data=model_inputs, + top_p=top_p, + temperature=temperature, + num_traj_samples=calibration_traj_samples, + max_generation_length=max_generation_length, + ) + + return _calibration_loop + + +def read_clip_ids_from_parquet(parquet_path: str) -> list[str]: + """ + Reads clip_ids from the parquet's "key" column. + Returns clip_ids as a list of strings (unique, preserving first occurrence order). + """ + parquet_path = str(parquet_path) + df = pd.read_parquet(parquet_path) + cols_lower = {c.lower(): c for c in df.columns} + clip_ids = df[cols_lower["key"]].astype(str).tolist() + + seen = set() + uniq = [] + for cid in clip_ids: + if cid not in seen: + seen.add(cid) + uniq.append(cid) + return uniq + + +def quantize_model(model, args, tokenizer=None, calibration_forward_loop=None): + """ + Quantize a PyTorch model using ModelOpt post-training quantization (PTQ). + + This function applies quantization to reduce model precision for faster inference + while maintaining acceptable accuracy. It uses calibration data generated from + the provided tokenizer to determine optimal quantization parameters. + + Supported quantization formats: + - fp8: 8-bit floating point quantization + - nvfp4: 4-bit NVIDIA floating point quantization + Args: + model: PyTorch model to quantize. Must be in evaluation mode. + args: Command line arguments containing quant_format. + tokenizer: Hugging Face tokenizer for creating calibration data. + Required only when `calibration_forward_loop` is not provided. + calibration_forward_loop: Optional callable taking `model` and running + calibration forward passes. Use this for non-text modules whose + forward signature is not compatible with dataset_utils batches. + + Returns: + Quantized model + """ + # Create calibration forward loop. For standard text models we can build + # it from tokenizer-based data, but vision modules often need custom args. + if calibration_forward_loop is None: + if tokenizer is None: + raise ValueError("tokenizer must be provided when calibration_forward_loop is None") + calib_dataloader = get_dataset_dataloader( + tokenizer=tokenizer, + batch_size=32, + num_samples=512, + device="cuda:0", + ) + calibrate_loop = create_forward_loop(dataloader=calib_dataloader) + else: + calibrate_loop = calibration_forward_loop + + if args.quant_format == "fp8": + quant_cfg = copy.deepcopy(mtq.FP8_DEFAULT_CFG) + elif args.quant_format == "nvfp4": + quant_cfg = copy.deepcopy(mtq.NVFP4_DEFAULT_CFG) + else: + raise RuntimeError("Unsupported quantization format") + # Keep the vision tower in high precision. Pass a non-NVFP4 cfg (num_bits=8) with + # enable=False, not just enable=False: an NVFP4-typed QuantConv3d routes to a JIT + # implicit-GEMM CUDA kernel (needs CUDA_HOME) even when disabled. + quant_cfg["quant_cfg"].append( + {"quantizer_name": "*vlm.model.visual*", "enable": False, "cfg": {"num_bits": 8}} + ) + + if args.quant_format == "nvfp4" or getattr(args, "real_quant", False): + # Keep Linear layers whose in/out features aren't multiples of 16 in high precision: + # they break the real-quant GEMM backends (NVFP4 block packing, FP8 torch._scaled_mm). + # In AlpamayoR1 these are the small action-projection heads, so the impact is negligible. + for _name, _module in model.named_modules(): + if isinstance(_module, torch.nn.Linear) and ( + _module.in_features % 16 != 0 or _module.out_features % 16 != 0 + ): + quant_cfg["quant_cfg"].append({"quantizer_name": f"{_name}.*", "enable": False}) + + model = mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop) + + print("================== quantize_model summary ==================") + mtq.print_quant_summary(model) + + return model + + +def auto_quantize_model( + model, + args, + *, + clip_ids, + processor, + t0_us: int, + device: str, +): + """ + Quantize a PyTorch model using ModelOpt's AutoQuantize API. + + Searches per-layer across [NVFP4_DEFAULT_CFG, FP8_DEFAULT_CFG] under the + effective-bits budget in args.auto_quantize_bits. Calibration runs the + teacher-forced flow-matching forward (teacher_forced_flow_loss_forward) on + the calibration clips; the MSE between v_pred and v_target is the search loss. + + Args: + model: PyTorch model to quantize. Must be in eval mode. + args: Namespace with `auto_quantize_bits` (float). + clip_ids: Iterable of clip_ids for calibration. + processor: HF processor used for chat-template tokenization. + t0_us: Trajectory anchor timestamp passed to load_physical_aiavdataset. + device: Device to place calibration tensors on. + + Returns: + Quantized model (the search_state from mtq.auto_quantize is discarded). + """ + + def _one_epoch(): + for clip_id in clip_ids: + data = load_physical_aiavdataset(clip_id, t0_us=t0_us) + messages = create_message(data["image_frames"].flatten(0, 1)) + inputs = processor.apply_chat_template( + messages, + tokenize=True, + add_generation_prompt=False, + continue_final_message=True, + return_dict=True, + return_tensors="pt", + ) + model_inputs = { + "tokenized_data": inputs, + "ego_history_xyz": data["ego_history_xyz"], + "ego_history_rot": data["ego_history_rot"], + "ego_future_xyz": data["ego_future_xyz"], + "ego_future_rot": data["ego_future_rot"], + } + yield to_device(model_inputs, device) + + class _ReusableLoader: + """Re-iterable wrapper so modelopt can run calibration + scoring passes.""" + + def __iter__(self): + return _one_epoch() + + data_loader = _ReusableLoader() + + def forward_step(runtime_model, data): + with torch.autocast("cuda", dtype=torch.bfloat16): + out = runtime_model.teacher_forced_flow_loss_forward(data=data) + v_pred, v_target = out["v_pred"], out["v_target"] + print( + f"[autoquant-fwd] v_pred: finite={torch.isfinite(v_pred).all().item()} " + f"min={v_pred.min().item():.4g} max={v_pred.max().item():.4g} " + f"abs_mean={v_pred.abs().mean().item():.4g} | " + f"v_target: finite={torch.isfinite(v_target).all().item()} " + f"min={v_target.min().item():.4g} max={v_target.max().item():.4g}" + ) + return out + + def loss_func(output, batch): + loss = torch.nn.functional.mse_loss(output["v_pred"], output["v_target"]) + print(f"[autoquant-loss] loss={loss.item():.6g} finite={torch.isfinite(loss).item()}") + return loss + + # Mirror the quantize_model exclusions via disabled_layers (fnmatch against module names), + # since the AutoQuantize search also includes NVFP4: keep the vision tower unquantized, and + # exclude Linear layers whose in/out features aren't multiples of 16. + disabled_layers = ["*lm_head*", "*vlm.model.visual*"] + for _name, _module in model.named_modules(): + if isinstance(_module, torch.nn.Linear) and ( + _module.in_features % 16 != 0 or _module.out_features % 16 != 0 + ): + disabled_layers.append(_name) + + model, search_state = mtq.auto_quantize( + model, + constraints={"effective_bits": args.auto_quantize_bits}, + quantization_formats=["NVFP4_DEFAULT_CFG", "FP8_DEFAULT_CFG"], + data_loader=data_loader, + forward_step=forward_step, + loss_func=loss_func, + disabled_layers=disabled_layers, + verbose=True, + ) + + print("================== auto_quantize search_state ==================") + print(search_state) + + print("================== auto_quantize_model summary ==================") + mtq.print_quant_summary(model) + + return model + + +def main(): + ap = argparse.ArgumentParser(description="Quantize AlpamayoR1 and export as HF checkpoint") + ap.add_argument( + "--ckpt", + type=str, + default="nvidia/Alpamayo-R1-10B", + help="HF hub id or local path of the input checkpoint", + ) + ap.add_argument( + "--output-dir", + type=str, + required=True, + help="Directory to save the quantized HF checkpoint", + ) + ap.add_argument( + "--quantize", + type=str, + required=True, + choices=["fp8", "nvfp4", "auto"], + help="Quantization format", + ) + ap.add_argument( + "--auto_quantize_bits", + type=float, + default=6.5, + help="Effective-bits budget for AutoQuantize (only used when --quantize auto)", + ) + ap.add_argument( + "--parquet", + type=str, + default="0417_16rows_train_set_for_calibration_25.10.parquet", + help="Parquet file with clip_ids for calibration", + ) + ap.add_argument( + "--t0_us", + type=int, + default=5_100_000, + help="Trajectory anchor timestamp passed to load_physical_aiavdataset", + ) + ap.add_argument("--top_p", type=float, default=0.98) + ap.add_argument("--temperature", type=float, default=0.6) + ap.add_argument("--max_generation_length", type=int, default=256) + ap.add_argument("--num_traj_samples", type=int, default=6) + ap.add_argument( + "--limit", type=int, default=16, help="How many clip_ids to use for calibration" + ) + ap.add_argument( + "--real-quant", + action="store_true", + help="Export packed real-quantized weights (fp8 / NVFP4) via " + "modelopt.torch.export.export_hf_checkpoint instead of " + "saving fake-quant fp16 weights with quantizer state.", + ) + args = ap.parse_args() + + script_dir = Path(__file__).resolve().parent + parquet_path = (script_dir / args.parquet).resolve() + + clip_ids = read_clip_ids_from_parquet(str(parquet_path)) + if args.limit is not None and args.limit > 0: + clip_ids = clip_ids[: args.limit] + print(f"Loaded {len(clip_ids)} clip_ids from: {parquet_path}") + + # Patch PreTrainedModel.from_pretrained / save_pretrained so ModelOpt state is saved with the + # checkpoint (and restored when AlpamayoR1.from_pretrained later loads the quantized weights). + mto.enable_huggingface_checkpointing() + + device = "cuda" + print(f"Loading model from {args.ckpt!r} ...") + model = AlpamayoR1.from_pretrained(args.ckpt, dtype=torch.float16).to( + device=device, dtype=torch.float16 + ) + model.eval() + + processor = get_processor(model.tokenizer) + + # Quantize using existing recipe + print(f"Quantizing model ({args.quantize}) ...") + quantization_args = argparse.Namespace( + quant_format=args.quantize, + quant_algo="max", + weight_only=False, + auto_quantize_bits=args.auto_quantize_bits, + real_quant=args.real_quant, + ) + if args.quantize == "auto": + model = auto_quantize_model( + model, + quantization_args, + clip_ids=clip_ids, + processor=processor, + t0_us=args.t0_us, + device=device, + ) + else: + # Build calibration loop + calibration_forward_loop = make_joint_calibration_forward_loop( + clip_ids=clip_ids, + processor=processor, + t0_us=args.t0_us, + top_p=args.top_p, + temperature=args.temperature, + max_generation_length=args.max_generation_length, + calibration_traj_samples=args.num_traj_samples, + device=device, + ) + model = quantize_model( + model, + quantization_args, + calibration_forward_loop=calibration_forward_loop, + ) + model.eval() + + # Save as HF-style checkpoint + os.makedirs(args.output_dir, exist_ok=True) + print(f"Saving quantized checkpoint to {args.output_dir!r} ...") + + if args.real_quant: + # Real (packed) quantization. `mtq.compress` packs weights into the low-precision + # storage format and enables ModelOpt's real-quant GEMM kernels. The ModelOpt-patched + # `save_pretrained` writes the packed weights plus a `modelopt_state.pth`, which + # `AlpamayoR1.from_pretrained` replays to reload and run real-quantized. + # + # NOTE: `export_hf_checkpoint` (the vLLM/TRT-LLM deployment format) isn't used here: it + # has no `modelopt_state.pth`, so a custom model class can't reload it via from_pretrained. + mtq.compress(model) + model.eval() + with torch.inference_mode(): + model.save_pretrained(args.output_dir) + processor.save_pretrained(args.output_dir) + model.config.save_pretrained(args.output_dir) + else: + with torch.inference_mode(): + model.save_pretrained(args.output_dir) + + processor.save_pretrained(args.output_dir) + model.config.save_pretrained(args.output_dir) + + quant_cfg = get_quant_config(model) + with open(os.path.join(args.output_dir, "hf_quant_config.json"), "w") as f: + json.dump(quant_cfg, f) + + print(f"Quantized checkpoint saved to {args.output_dir}") + + +if __name__ == "__main__": + with torch.no_grad(): + main() diff --git a/tests/examples/alpamayo/test_quantize.py b/tests/examples/alpamayo/test_quantize.py new file mode 100644 index 00000000000..e91fdebb0b6 --- /dev/null +++ b/tests/examples/alpamayo/test_quantize.py @@ -0,0 +1,78 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""CPU-only unit tests for pure helpers in examples/alpamayo/quantize.py.""" + +import sys +from pathlib import Path + +import pytest + +# quantize.py imports the gated ``alpamayo_r1`` package (and transformers) at module +# load and monkeypatches at import time, so guard collection on those being installed. +pytest.importorskip("alpamayo_r1") +pytest.importorskip("transformers") + +import pandas as pd +import torch + +sys.path.insert(0, str(Path(__file__).resolve().parents[3] / "examples" / "alpamayo")) + +import quantize + + +class TestReadClipIdsFromParquet: + def test_dedup_preserves_first_occurrence_order(self, tmp_path): + path = tmp_path / "clips.parquet" + pd.DataFrame({"key": ["a", "b", "a", "c", "b"]}).to_parquet(path) + + assert quantize.read_clip_ids_from_parquet(path) == ["a", "b", "c"] + + def test_missing_key_column_raises(self, tmp_path): + path = tmp_path / "clips.parquet" + pd.DataFrame({"id": ["a", "b"]}).to_parquet(path) + + with pytest.raises(KeyError): + quantize.read_clip_ids_from_parquet(path) + + +class TestCreateMessage: + def test_roles_and_structure(self): + messages = quantize.create_message(torch.zeros(3, 3, 8, 8)) + + assert [m["role"] for m in messages] == ["system", "user", "assistant"] + + def test_one_image_entry_per_frame_plus_trailing_text(self): + num_frames = 3 + messages = quantize.create_message(torch.zeros(num_frames, 3, 8, 8)) + + user_content = messages[1]["content"] + image_entries = [c for c in user_content if c["type"] == "image"] + text_entries = [c for c in user_content if c["type"] == "text"] + + assert len(image_entries) == num_frames + assert len(text_entries) == 1 + + def test_history_trajectory_tokens(self): + messages = quantize.create_message(torch.zeros(1, 3, 8, 8)) + + user_text = next(c["text"] for c in messages[1]["content"] if c["type"] == "text") + assert user_text.count("<|traj_history|>") == 48 + assert "<|traj_history_start|>" in user_text + assert "<|traj_history_end|>" in user_text + + def test_non_4d_frames_raises(self): + with pytest.raises(AssertionError): + quantize.create_message(torch.zeros(3, 8, 8)) From 77c266867cd06e6134b675cd84814b0bfa05f01c Mon Sep 17 00:00:00 2001 From: jingyu-ml <108295447+jingyu-ml@users.noreply.github.com> Date: Mon, 8 Jun 2026 16:02:41 -0700 Subject: [PATCH 04/24] Skip Softmax diffusion export (#1269) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### What does this PR do? Type of change: New Feature Adds HuggingFace `config.json` export of skip-softmax sparse-attention calibration for diffusion pipelines (e.g. Wan 2.2), on top of the base skip-softmax work. - **`_export_diffusers_checkpoint`** walks every `nn.Module` component of a diffusers pipeline, calls `export_sparse_attention_config`, and writes the result into that component's `config.json` under the `sparse_attention_config` key. The sparse config lives **only** in `config.json` — there is no standalone `sparse.yaml`. - **`export_sparse_attention_config`** emits a `config_groups` schema where each algorithm's parameters are nested inside its own group; only `config_groups` and `producer` are top-level: - skip-softmax group → `algorithm: "skip_softmax"`, `targets`, `ignore` (layers kept dense — e.g. cross-attention + first/last blocks), `initial_disabled_steps` (opt-in, user-set; emitted only when `> 0`), `threshold_scale_factor` (`a * exp(b * target_sparsity)`), and `target_sparsity`. - N:M group → `algorithm: "sparse_softmax"` with `sparsity_n`/`sparsity_m`, `dense_sink_tokens`, `dense_recent_tokens` flattened into the group. - **Deploy reader** (`modelopt/torch/sparsity/attention_sparsity/plugins/sparse_attn_config.py`) reads these per-group params back, keeping the export↔load round-trip consistent. - **Example wiring**: `examples/diffusers/sparsity/wan22_skip_softmax.py` gains `--export-dir`, `--skip-softmax-threshold`, and `--initial-disabled-steps`. `--export-dir` runs `export_hf_checkpoint(pipe, export_dir=...)` after calibration. - Updated `CHANGELOG.rst`. ### Usage ```bash python examples/diffusers/sparsity/wan22_skip_softmax.py \ --model-path Wan-AI/Wan2.2-T2V-A14B-Diffusers \ --calibrate --target-sparsity 0.5 --calib-size 4 \ --initial-disabled-steps 5 \ --export-dir ./wan22_skip_softmax_ckpt ``` Resulting layout — a `config.json` per component, **no `sparse.yaml`**: ``` wan22_skip_softmax_ckpt/ ├── transformer/config.json # carries sparse_attention_config ├── transformer_2/config.json # carries sparse_attention_config ├── vae/ … text_encoder/ … tokenizer/ … scheduler/ … └── model_index.json ``` A representative `config.json` entry for a diffusion transformer: ```json "sparse_attention_config": { "config_groups": { "group_0": { "algorithm": "skip_softmax", "targets": ["WanAttention"], "ignore": ["blocks.0.attn1", "blocks.0.attn2", "…"], "initial_disabled_steps": 5, "threshold_scale_factor": { "formula": "a * exp(b * target_sparsity)", "prefill": {"a": 1443.49, "b": 4.30} }, "target_sparsity": {"prefill": 0.5} } }, "producer": {"name": "modelopt", "version": "0.45.0..."} } ``` The N:M variant adds a second group: ```json "group_1": { "algorithm": "sparse_softmax", "targets": ["WanAttention"], "sparsity_n": 2, "sparsity_m": 4, "dense_sink_tokens": 0, "dense_recent_tokens": 64 } ``` ### Testing - `tests/examples/diffusers_sparsity/test_sparsity.py`: baseline / triton-baseline / fixed-threshold runs of the Wan 2.2 example, plus a Python-API calibrate → **export** test asserting the nested `sparse_attention_config` (`threshold_scale_factor`, `target_sparsity`, `ignore`, `initial_disabled_steps`) and the absence of any `sparse.yaml`. - `tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_conversion.py` and `test_sparse_attn_config.py`: unit coverage of the per-group export schema and the deploy-reader round-trip (writer nests → reader reads from groups → internal mtsa config unchanged). - Validated end-to-end on Wan 2.2 T2V-A14B: full 4-prompt / 40-step / 81-frame calibration; the exported checkpoint carries the nested schema in both `transformer` and `transformer_2` `config.json`, and runtime measurement shows ~47–49% tile sparsity at a 0.5 target. ### Before your PR is "*Ready for review*" Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/Model-Optimizer/blob/main/CONTRIBUTING.md) and your commits are signed (`git commit -s -S`). Make sure you read and follow the [Security Best Practices](https://github.com/NVIDIA/Model-Optimizer/blob/main/SECURITY.md#security-coding-practices-for-contributors) (e.g. avoiding hardcoded `trust_remote_code=True`, `torch.load(..., weights_only=False)`, `pickle`, etc.). - Is this change backward compatible?: ❌ The exported `sparse_attention_config` schema was renamed and nested per-group during 0.45.x development, and the loader reads only the new layout — checkpoints exported by earlier 0.45.x builds must be re-exported. No released version is affected. - If you copied code from any other sources or added a new PIP dependency, did you follow guidance in `CONTRIBUTING.md`: ✅ - Did you write any new necessary tests?: ✅ - Did you update [Changelog](https://github.com/NVIDIA/Model-Optimizer/blob/main/CHANGELOG.rst)?: ✅ ### Additional Information --------- Signed-off-by: Jingyu Xin Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com> Co-authored-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com> Co-authored-by: Claude Opus 4.8 Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com> --- .github/workflows/example_tests.yml | 2 +- CHANGELOG.rst | 2 +- examples/diffusers/README.md | 5 +- examples/diffusers/sparsity/README.md | 13 +- .../diffusers/sparsity/wan22_skip_softmax.py | 35 ++- modelopt/torch/export/unified_export_hf.py | 13 + .../calibration/calibrate.py | 3 + .../calibration/calibrator.py | 6 +- .../sparsity/attention_sparsity/config.py | 10 + .../sparsity/attention_sparsity/conversion.py | 103 +++++--- .../plugins/sparse_attn_config.py | 58 ++-- .../diffusers/sparsity/test_sparsity.py | 247 ++++++++++++++++++ .../diffusers_sparsity/test_sparsity.py | 104 -------- .../test_sparse_attention_conversion.py | 40 +-- .../test_sparse_attn_config.py | 71 +++-- 15 files changed, 499 insertions(+), 213 deletions(-) create mode 100644 tests/examples/diffusers/sparsity/test_sparsity.py delete mode 100644 tests/examples/diffusers_sparsity/test_sparsity.py diff --git a/.github/workflows/example_tests.yml b/.github/workflows/example_tests.yml index e69dd08448b..f93bf891b16 100644 --- a/.github/workflows/example_tests.yml +++ b/.github/workflows/example_tests.yml @@ -35,7 +35,7 @@ jobs: strategy: fail-fast: false matrix: - example: [diffusers_sparsity, gpt-oss, llm_distill, llm_qat, llm_sparsity, specdec_bench] + example: [gpt-oss, llm_distill, llm_qat, llm_sparsity, specdec_bench] include: - example: speculative_decoding docker_image: "26.01" diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 967368f3116..9f0d369827d 100755 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -75,7 +75,7 @@ Changelog - Add Puzzletron - a new algorithm for heterogeneous pruning of LLM and VLM models. See `examples/puzzletron/README.md `_ for more details. - Added iterator interface using CalibrationDataReader in ONNX quantization workflow. - Add N:M sparse softmax support to the Triton flash attention kernel (``modelopt.torch.kernels.common.attention.triton_fa``). See `examples/llm_sparsity/attention_sparsity/README.md `_ for usage. -- Add skip-softmax skipping to the Triton flash attention kernel (``modelopt.torch.kernels.common.attention.triton_fa``). See `examples/llm_sparsity/attention_sparsity/README.md `_ for usage. +- Add skip-softmax skipping to the Triton flash attention kernel for both language models and video diffusion models (``modelopt.torch.kernels.common.attention.triton_fa``). See `examples/llm_sparsity/attention_sparsity/README.md `_ and `examples/diffusers/sparsity/ `_ for usage. - Add Video Sparse Attention (VSA) method for video diffusion models (``modelopt.torch.sparsity.attention_sparsity``). VSA uses 3D block tiling with a two-branch architecture for attention speedup. - Enable PTQ workflow for the Step3.5-Flash MoE model with NVFP4 W4A4 + FP8 KV cache quantization. See `modelopt_recipes/models/Step3.5-Flash/nvfp4-mlp-only.yaml `_ for more details. - Add support for vLLM fakequant reload using ModelOpt state for HF models. See `examples/vllm_serve/README.md `_ for more details. diff --git a/examples/diffusers/README.md b/examples/diffusers/README.md index 56d9eb481bf..8fc32d7a324 100644 --- a/examples/diffusers/README.md +++ b/examples/diffusers/README.md @@ -314,7 +314,7 @@ By following these steps, your PEFT LoRA model should be efficiently quantized u ## Sparse Attention (Skip-Softmax) -Skip-softmax sparse attention skips KV tiles whose attention scores are negligible during the softmax computation, reducing FLOPs without retraining. An exponential model (`scale_factor = a * exp(b * target_sparsity)`) is calibrated once, then the target sparsity can be adjusted at runtime without recalibration. +Skip-softmax sparse attention skips KV tiles whose attention scores are negligible during the softmax computation, reducing FLOPs without retraining. An exponential model (`scale_factor = a * exp(b * target_sparsity)`) is calibrated once, then the target sparsity can be adjusted at runtime without recalibration. Calibrated coefficients can be exported as a Hugging Face checkpoint (embedded in each component's `config.json` under `sparse_attention_config`) consumed directly by TRT-LLM's `SkipSoftmaxAttentionConfig.resolve_for_target_sparsity` — no extra conversion needed downstream. ### Getting Started @@ -358,10 +358,11 @@ The 14B model automatically sparsifies both `transformer` and `transformer_2`. ```bash -# 5B/14B model +# 5B/14B model — calibrate and export a TRT-LLM-ready checkpoint python sparsity/wan22_skip_softmax.py \ --model-path Wan-AI/Wan2.2-T2V-A14B-Diffusers|Wan-AI/Wan2.2-TI2V-5B-Diffusers \ --calibrate --target-sparsity 0.5 --calib-size 4 \ + --export-dir /path/to/wan22-skip-softmax-ckpt \ --prompt "A sunset over mountains" --output out.mp4 ``` diff --git a/examples/diffusers/sparsity/README.md b/examples/diffusers/sparsity/README.md index a2b071d6c0e..e102f1b6d5d 100644 --- a/examples/diffusers/sparsity/README.md +++ b/examples/diffusers/sparsity/README.md @@ -43,10 +43,11 @@ python wan22_skip_softmax.py \ --skip-softmax-threshold 0.61557 \ --prompt "A cat playing piano" --output out.mp4 -# With calibration +# Calibrate + export for TRT-LLM deployment (typical flow) python wan22_skip_softmax.py \ --model-path /path/to/Wan2.2-T2V-A14B-Diffusers \ - --calibrate --target-sparsity 0.5 \ + --calibrate --target-sparsity 0.5 --calib-size 4 \ + --export-dir /path/to/wan22-skip-softmax-ckpt \ --prompt "A cat playing piano" --output out.mp4 # Dense baseline (no sparsity, for comparison) @@ -62,6 +63,13 @@ python wan22_skip_softmax.py \ --prompt "A cat playing piano" --output out.mp4 ``` +`--export-dir` writes a Hugging Face checkpoint with the calibrated +`threshold_scale_factor` block embedded in each component's `config.json` +(under the `sparse_attention_config` key). TRT-LLM's +`SkipSoftmaxAttentionConfig.resolve_for_target_sparsity` reads the +`(a, b)` directly via `coeffs['a'] * math.exp(coeffs['b'] * sparsity)` — +no extra conversion needed downstream. + ## Threshold Modes | Mode | How threshold reaches the kernel | Use case | @@ -73,4 +81,5 @@ python wan22_skip_softmax.py \ ## Known Issues - **14B dual transformer calibration**: Transformers are calibrated sequentially — transformer_2's calibration runs while transformer_1 is already sparsified, introducing asymmetric calibration conditions. +- **Resolution-dependent fit**: The kernel's intrinsic `S(λ)` curve depends on the spatial resolution (different attention statistics at 480×832 vs 720×1280). For tightest target ≈ achieved alignment, calibrate at the deployment `(height, width, frames)`. Within a fixed spatial resolution, achieved sparsity stays roughly aligned across frame counts. - **Minimum achievable sparsity**: Even the strictest threshold may yield 30-40% sparsity on diffusion models (many tiles are inherently negligible). Targets below this floor cause extrapolation; an inference-time warning is emitted. diff --git a/examples/diffusers/sparsity/wan22_skip_softmax.py b/examples/diffusers/sparsity/wan22_skip_softmax.py index 3f4447e0bad..890cec547ff 100644 --- a/examples/diffusers/sparsity/wan22_skip_softmax.py +++ b/examples/diffusers/sparsity/wan22_skip_softmax.py @@ -59,6 +59,7 @@ from diffusers.utils import export_to_video import modelopt.torch.sparsity.attention_sparsity as mtsa +from modelopt.torch.export import export_hf_checkpoint from modelopt.torch.sparsity.attention_sparsity.sparse_attention import SparseAttentionModule DEFAULT_MODEL_PATH = os.environ.get("WAN22_MODEL_PATH", "Wan-AI/Wan2.2-TI2V-5B-Diffusers") @@ -190,8 +191,8 @@ def parse_args() -> argparse.Namespace: parser.add_argument( "--calib-frames", type=int, - default=151, - help="Number of frames for calibration", + default=None, + help="Number of frames for calibration (default: same as --num-frames).", ) parser.add_argument( "--calib-size", @@ -199,6 +200,24 @@ def parse_args() -> argparse.Namespace: default=4, help="Number of calibration prompts from OpenVid-1M dataset", ) + + # Export options + parser.add_argument( + "--export-dir", + type=str, + default=None, + help="Export sparsified model as a HuggingFace checkpoint to this directory. " + "The sparse_attention_config (calibration params, disabled layers, etc.) " + "is written into each component's config.json.", + ) + parser.add_argument( + "--initial-disabled-steps", + type=int, + default=0, + help="User-specified initial-disabled-steps value carried into the exported " + "sparse attention config (config.json). Only emitted when > 0; not interpreted " + "at sparsify/calibration time.", + ) return parser.parse_args() @@ -233,6 +252,10 @@ def build_sparse_config(args: argparse.Namespace, num_blocks: int) -> dict: if args.skip_softmax_threshold is not None: attn_cfg["skip_softmax_threshold"] = args.skip_softmax_threshold + # Opt-in initial-disabled-steps metadata — carried through to the exported config when > 0. + if args.initial_disabled_steps > 0: + attn_cfg["initial_disabled_steps"] = args.initial_disabled_steps + sparse_cfg: dict = { "*.attn1*": attn_cfg, # Self-attention only "*.attn2*": {"enable": False}, # Text cross-attention @@ -415,11 +438,12 @@ def main() -> None: if args.calibrate: print("Warning: --calibrate is ignored when --skip-softmax-threshold is set") elif args.calibrate: + calib_frames = args.calib_frames if args.calib_frames is not None else args.num_frames forward_loop = build_calibration_forward_loop( pipe, calib_size=args.calib_size, num_steps=args.calib_steps, - num_frames=args.calib_frames, + num_frames=calib_frames, height=args.height, width=args.width, seed=args.seed, @@ -445,6 +469,11 @@ def main() -> None: torch.cuda.empty_cache() print("Cleared CUDA cache after calibration") + # ---- Export (optional) ---- + if args.export_dir and not args.baseline: + print(f"Exporting sparsified checkpoint to {args.export_dir}...") + export_hf_checkpoint(pipe, export_dir=args.export_dir) + # ---- Generate (optional) ---- if args.prompt: # Enable runtime sparsity measurement before generation diff --git a/modelopt/torch/export/unified_export_hf.py b/modelopt/torch/export/unified_export_hf.py index 00e4a7008a9..ef5757aa0cb 100644 --- a/modelopt/torch/export/unified_export_hf.py +++ b/modelopt/torch/export/unified_export_hf.py @@ -1128,6 +1128,19 @@ def _export_diffusers_checkpoint( else: _save_component_state_dict_safetensors(component, component_export_dir) + # Step 9: Update config.json with sparse attention info (both quantized and non-quantized) + if export_sparse_attention_config is not None: + sparse_attn_config = export_sparse_attention_config(component) + if sparse_attn_config is not None: + config_path = component_export_dir / "config.json" + if config_path.exists(): + with open(config_path) as file: + config_data = json.load(file) + config_data["sparse_attention_config"] = sparse_attn_config + with open(config_path, "w") as file: + json.dump(config_data, file, indent=4) + print(f" Added sparse_attention_config to {config_path.name}") + print(f" Saved to: {component_export_dir}") # Step 4: Export non-nn.Module components (tokenizers, schedulers, feature extractors, etc.) diff --git a/modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py b/modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py index 51df5bb4d4a..186331de948 100644 --- a/modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py +++ b/modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py @@ -346,6 +346,9 @@ def calibrate_sparse_attention( "a": result["a"], "b": result["b"], } + if result.get("fit_logspace"): + params["log_a"] = result["log_a"] + params["fit_logspace"] = True if "min_observed_sparsity" in result: params["min_observed_sparsity"] = result["min_observed_sparsity"] if "max_observed_sparsity" in result: diff --git a/modelopt/torch/sparsity/attention_sparsity/calibration/calibrator.py b/modelopt/torch/sparsity/attention_sparsity/calibration/calibrator.py index d3ed3303256..aded26fefdc 100644 --- a/modelopt/torch/sparsity/attention_sparsity/calibration/calibrator.py +++ b/modelopt/torch/sparsity/attention_sparsity/calibration/calibrator.py @@ -275,7 +275,7 @@ def exponential(sparsity, a, b): avg_s = np.mean([p["sparsity"] for p in points]) print(f" {threshold:<12.4f} {avg_sf:<12.2f} {avg_s:<12.2%} {len(points):<8}") - return { + result = { "phase": phase, "a": float(a), "b": float(b), @@ -283,9 +283,13 @@ def exponential(sparsity, a, b): "num_data_points": int(np.sum(valid_mask)), "total_samples": len(all_data_points), "calibration_type": "exponential", + "fit_logspace": self.fit_logspace, "min_observed_sparsity": min_observed_sparsity, "max_observed_sparsity": max_observed_sparsity, } + if self.fit_logspace: + result["log_a"] = float(log_a) + return result def _enable_calibration_mode(self, modules: list[nn.Module]): """Enable calibration mode on sparse attention modules.""" diff --git a/modelopt/torch/sparsity/attention_sparsity/config.py b/modelopt/torch/sparsity/attention_sparsity/config.py index 32a49f02e34..0186d2fb8ec 100644 --- a/modelopt/torch/sparsity/attention_sparsity/config.py +++ b/modelopt/torch/sparsity/attention_sparsity/config.py @@ -155,6 +155,16 @@ class SparseAttentionAttributeConfig(ModeloptBaseConfig): ), ) + initial_disabled_steps: int = ModeloptField( + default=0, + title="Initial disabled steps.", + description=( + "User-specified number of initial disabled steps recorded in the exported " + "sparse attention config. Passed straight through; not interpreted by " + "sparsify or calibration; only written to the checkpoint when > 0." + ), + ) + @field_validator("method") @classmethod def validate_method(cls, v): diff --git a/modelopt/torch/sparsity/attention_sparsity/conversion.py b/modelopt/torch/sparsity/attention_sparsity/conversion.py index 173eaa646b2..8c41b895a8a 100644 --- a/modelopt/torch/sparsity/attention_sparsity/conversion.py +++ b/modelopt/torch/sparsity/attention_sparsity/conversion.py @@ -392,18 +392,24 @@ def export_sparse_attention_config(model: nn.Module) -> dict[str, Any] | None: { "config_groups": { - "group_0": {"sparse_algo": "softmax_skip", "targets": ["LlamaAttention"]} - }, - "threshold_scale_factor": { - "formula": "a * exp(b * target_sparsity)", - "prefill": {"a": 7.93, "b": 8.61}, - "decode": {"a": 0.12, "b": 9.85}, - }, - "sparse_softmax": { - "sparsity_n": 2, - "sparsity_m": 4, - "dense_sink_tokens": 0, - "dense_recent_tokens": 64, + "group_0": { + "algorithm": "skip_softmax", + "targets": ["LlamaAttention"], + "threshold_scale_factor": { + "formula": "a * exp(b * target_sparsity)", + "prefill": {"a": 7.93, "b": 8.61}, + "decode": {"a": 0.12, "b": 9.85}, + }, + "target_sparsity": {"prefill": 0.5}, + }, + "group_1": { + "algorithm": "sparse_softmax", + "targets": ["LlamaAttention"], + "sparsity_n": 2, + "sparsity_m": 4, + "dense_sink_tokens": 0, + "dense_recent_tokens": 64, + }, }, "producer": {"name": "modelopt", "version": "0.37.0"}, } @@ -413,15 +419,24 @@ def export_sparse_attention_config(model: nn.Module) -> dict[str, Any] | None: target_sparse_ratio = None sparse_softmax_config = None target_classes: set[str] = set() + disabled_layer_names: list[str] = [] + initial_disabled_steps = 0 - for module in get_sparse_attention_modules(model): + for name, module in get_named_sparse_attention_modules(model): # Get the original wrapped module's class name if hasattr(module, "get_original_cls_by_level"): original_cls = module.get_original_cls_by_level(level=0) if original_cls is not None: target_classes.add(original_cls.__name__) - # Get calibration params from first module that has them + # Record layers kept dense (e.g. cross-attention, first/last blocks) so the + # deployment side sparsifies the same subset that was calibrated, rather than + # every instance of the target class. + if not module.is_enabled: + disabled_layer_names.append(get_unwrapped_name(name, model)) + continue + + # Get calibration params from first enabled module that has them if calibration_params is None: calibration_params = getattr(module._sparse_method_instance, "calibration_params", None) if target_sparse_ratio is None: @@ -430,36 +445,30 @@ def export_sparse_attention_config(model: nn.Module) -> dict[str, Any] | None: ) if sparse_softmax_config is None: sparse_softmax_config = _get_sparse_softmax_export_config(module) + # A single run-wide value: take it from the first enabled module that sets it + # (same harvesting pattern as calibration_params / target_sparse_ratio above). + if not initial_disabled_steps: + initial_disabled_steps = module._method_config.get("initial_disabled_steps", 0) if calibration_params is None and sparse_softmax_config is None: return None targets = sorted(target_classes) if target_classes else ["Attention"] - config_groups = {} + config_groups: dict[str, Any] = {} group_idx = 0 + + # Each algorithm's parameters live inside its own config group. if calibration_params is not None: - config_groups[f"group_{group_idx}"] = { - "sparse_algo": "softmax_skip", + skip_group: dict[str, Any] = { + "algorithm": "skip_softmax", "targets": targets, } - group_idx += 1 - if sparse_softmax_config is not None: - config_groups[f"group_{group_idx}"] = { - "sparse_algo": "sparse_softmax", - "targets": targets, - } - - # Build the export config - export_config: dict[str, Any] = { - "config_groups": config_groups, - "producer": { - "name": "modelopt", - "version": mo_version, - }, - } - - if calibration_params is not None: - # Build threshold_scale_factor with model parameters + if disabled_layer_names: + skip_group["ignore"] = disabled_layer_names + if initial_disabled_steps: + skip_group["initial_disabled_steps"] = initial_disabled_steps + # threshold_scale_factor (a * exp(b * target_sparsity)) and target_sparsity are + # skip-softmax-specific, so they live in this group. threshold_scale_factor: dict[str, Any] = { "formula": "a * exp(b * target_sparsity)", } @@ -469,14 +478,28 @@ def export_sparse_attention_config(model: nn.Module) -> dict[str, Any] | None: "a": calibration_params[phase]["a"], "b": calibration_params[phase]["b"], } - export_config["threshold_scale_factor"] = threshold_scale_factor + skip_group["threshold_scale_factor"] = threshold_scale_factor + if target_sparse_ratio is not None: + skip_group["target_sparsity"] = target_sparse_ratio + config_groups[f"group_{group_idx}"] = skip_group + group_idx += 1 - if calibration_params is not None and target_sparse_ratio is not None: - export_config["target_sparse_ratio"] = target_sparse_ratio if sparse_softmax_config is not None: - export_config["sparse_softmax"] = sparse_softmax_config + # N:M sparse-softmax params live in this group. + sparse_group: dict[str, Any] = { + "algorithm": "sparse_softmax", + "targets": targets, + } + sparse_group.update(sparse_softmax_config) + config_groups[f"group_{group_idx}"] = sparse_group - return export_config + return { + "config_groups": config_groups, + "producer": { + "name": "modelopt", + "version": mo_version, + }, + } def disable_sparse_attention(model: nn.Module, wildcard_or_filter_func: str | Callable): diff --git a/modelopt/torch/sparsity/attention_sparsity/plugins/sparse_attn_config.py b/modelopt/torch/sparsity/attention_sparsity/plugins/sparse_attn_config.py index 273f2015c92..f7bb5c8a5d4 100644 --- a/modelopt/torch/sparsity/attention_sparsity/plugins/sparse_attn_config.py +++ b/modelopt/torch/sparsity/attention_sparsity/plugins/sparse_attn_config.py @@ -28,9 +28,9 @@ import modelopt.torch.sparsity.attention_sparsity as mtsa -# Maps ``sparse_algo`` values without calibration metadata into mtsa presets. +# Maps ``algorithm`` values without calibration metadata into mtsa presets. ALGO_TO_PRESET = { - "softmax_skip": "SKIP_SOFTMAX_TRITON_DEFAULT", + "skip_softmax": "SKIP_SOFTMAX_TRITON_DEFAULT", } DEFAULT_TARGET_SPARSE_RATIO = {"prefill": 0.5, "decode": 0.5} @@ -83,7 +83,7 @@ def _sparse_softmax_params(sparse_meta: dict, config_groups: dict) -> dict: return params for group in config_groups.values(): - if not isinstance(group, dict) or group.get("sparse_algo") not in SPARSE_SOFTMAX_ALGOS: + if not isinstance(group, dict) or group.get("algorithm") not in SPARSE_SOFTMAX_ALGOS: continue params.update({key: group[key] for key in SPARSE_SOFTMAX_DEFAULTS if key in group}) return params @@ -95,22 +95,26 @@ def _add_sparse_softmax_params(layer_cfg: dict, sparse_meta: dict, config_groups layer_cfg.update(_sparse_softmax_params(sparse_meta, config_groups)) -def _build_calibrated_softmax_skip_config(sparse_meta: dict) -> dict: - """Build a vLLM Triton sparse config from exported calibration metadata.""" - return { - "sparse_cfg": { - "*attn*": { - "method": "triton_skip_softmax", - "threshold_scale_factor": sparse_meta["threshold_scale_factor"], - "target_sparse_ratio": _normalize_target_sparse_ratio( - sparse_meta.get("target_sparse_ratio") - ), - "backend": "triton", - "enable": True, - }, - "default": {"enable": False}, - }, +def _build_calibrated_softmax_skip_config(skip_group: dict) -> dict: + """Build a vLLM Triton sparse config from a skip_softmax config group. + + Layers recorded under ``ignore`` (kept dense at calibration time, e.g. + cross-attention) are disabled first so they remain dense on load. Order + matters: :func:`match_sparse_config` returns the first matching pattern, so + the ``ignore`` entries must precede the catch-all ``*attn*`` rule. + """ + sparse_cfg: dict = {} + for name in skip_group.get("ignore", []): + sparse_cfg[f"*{name}*"] = {"enable": False} + sparse_cfg["*attn*"] = { + "method": "triton_skip_softmax", + "threshold_scale_factor": skip_group["threshold_scale_factor"], + "target_sparse_ratio": _normalize_target_sparse_ratio(skip_group.get("target_sparsity")), + "backend": "triton", + "enable": True, } + sparse_cfg["default"] = {"enable": False} + return {"sparse_cfg": sparse_cfg} def _build_sparse_softmax_config(sparse_meta: dict, config_groups: dict) -> dict: @@ -150,7 +154,7 @@ def load_from_checkpoint_metadata(hf_config) -> tuple[dict, str] | None: Reads ``sparse_attention_config`` written by ModelOpt's HF export (``unified_export_hf.export_sparse_attention_config``). Calibrated - ``softmax_skip`` metadata is converted into a dynamic Triton config; + ``skip_softmax`` metadata is converted into a dynamic Triton config; uncalibrated algorithms fall back to mtsa presets via :data:`ALGO_TO_PRESET`. Args: @@ -170,11 +174,19 @@ def load_from_checkpoint_metadata(hf_config) -> tuple[dict, str] | None: config_groups = sparse_meta.get("config_groups", {}) if not isinstance(config_groups, dict): return None - algos = {grp.get("sparse_algo") for grp in config_groups.values() if isinstance(grp, dict)} - if "softmax_skip" in algos and _has_calibrated_threshold_scale_factor( - sparse_meta.get("threshold_scale_factor") + algos = {grp.get("algorithm") for grp in config_groups.values() if isinstance(grp, dict)} + skip_group = next( + ( + grp + for grp in config_groups.values() + if isinstance(grp, dict) and grp.get("algorithm") == "skip_softmax" + ), + None, + ) + if skip_group is not None and _has_calibrated_threshold_scale_factor( + skip_group.get("threshold_scale_factor") ): - cfg = _build_calibrated_softmax_skip_config(sparse_meta) + cfg = _build_calibrated_softmax_skip_config(skip_group) if _has_sparse_softmax_algo(algos): layer_cfg = cfg["sparse_cfg"]["*attn*"] _add_sparse_softmax_params(layer_cfg, sparse_meta, config_groups) diff --git a/tests/examples/diffusers/sparsity/test_sparsity.py b/tests/examples/diffusers/sparsity/test_sparsity.py new file mode 100644 index 00000000000..5a49093f80b --- /dev/null +++ b/tests/examples/diffusers/sparsity/test_sparsity.py @@ -0,0 +1,247 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for skip-softmax sparse attention on Wan 2.2 (examples/diffusers/sparsity/). + +Uses a tiny Wan 2.2 model (dual transformer, 2 layers, hidden_dim=24) created +from scratch. Tests run the wan22_skip_softmax.py example script in baseline, +triton-baseline, fixed-threshold, and export modes. Also includes a Python API +test for calibration params + export (calibration can't succeed on tiny models +via the Triton kernel, so params are injected directly). +""" + +import json +import math + +import pytest +import torch +from _test_utils.examples.run_command import run_example_command +from _test_utils.torch.diffusers_models import create_tiny_wan22_pipeline_dir + +EXAMPLE_PATH = "diffusers/sparsity" + +# Tiny inference settings — fast but exercises all code paths +_TINY_ARGS = [ + "--num-frames", + "5", + "--height", + "16", + "--width", + "16", + "--num-steps", + "2", + "--guidance-scale", + "1.0", + "--skip-first-last", + "0", + "--negative-prompt", + "", +] + + +@pytest.fixture(scope="session") +def tiny_wan22_path(tmp_path_factory): + """Create a tiny Wan 2.2 pipeline saved to disk (session-scoped).""" + return str(create_tiny_wan22_pipeline_dir(tmp_path_factory.mktemp("tiny_wan22"))) + + +def test_wan22_baseline(tiny_wan22_path, tmp_path): + """Dense baseline — no sparsity, default diffusers attention backend.""" + cmd = [ + "python", + "wan22_skip_softmax.py", + "--model-path", + tiny_wan22_path, + "--baseline", + "--prompt", + "test", + "--output", + str(tmp_path / "baseline.mp4"), + *_TINY_ARGS, + ] + run_example_command(cmd, EXAMPLE_PATH) + + +def test_wan22_triton_baseline(tiny_wan22_path, tmp_path): + """Triton kernel without skip-softmax (threshold=0, apples-to-apples).""" + cmd = [ + "python", + "wan22_skip_softmax.py", + "--model-path", + tiny_wan22_path, + "--triton-baseline", + "--prompt", + "test", + "--output", + str(tmp_path / "triton_baseline.mp4"), + *_TINY_ARGS, + ] + run_example_command(cmd, EXAMPLE_PATH) + + +def test_wan22_skip_softmax_threshold(tiny_wan22_path, tmp_path): + """Skip-softmax with a fixed lambda threshold — no calibration needed.""" + cmd = [ + "python", + "wan22_skip_softmax.py", + "--model-path", + tiny_wan22_path, + "--skip-softmax-threshold", + "0.03125", + "--report-avg-sparsity", + "--prompt", + "test", + "--output", + str(tmp_path / "skip_softmax_threshold.mp4"), + *_TINY_ARGS, + ] + run_example_command(cmd, EXAMPLE_PATH) + + +def test_wan22_export_sparse_checkpoint(tiny_wan22_path, tmp_path): + """Export a fixed-threshold checkpoint and verify the structure. + + A fixed ``--skip-softmax-threshold`` run is uncalibrated, so + ``export_sparse_attention_config`` returns ``None`` and no + ``sparse_attention_config`` is written (the populated-config path is covered + by ``test_wan22_calibrated_export``). Here we just verify the checkpoint is + written and that no standalone ``sparse.yaml`` is produced. + """ + export_dir = tmp_path / "sparse_export" + cmd = [ + "python", + "wan22_skip_softmax.py", + "--model-path", + tiny_wan22_path, + "--skip-softmax-threshold", + "0.03125", + "--export-dir", + str(export_dir), + *_TINY_ARGS, + ] + run_example_command(cmd, EXAMPLE_PATH) + + assert export_dir.exists() + for component in ["transformer", "transformer_2"]: + component_dir = export_dir / component + assert component_dir.exists(), f"Missing component dir: {component}" + config_path = component_dir / "config.json" + assert config_path.exists(), f"Missing config.json for {component}" + with open(config_path) as f: + config_data = json.load(f) + # Fixed (uncalibrated) threshold has nothing to export. + assert "sparse_attention_config" not in config_data, ( + f"Unexpected sparse_attention_config in {component}/config.json for a " + "fixed-threshold (uncalibrated) export" + ) + weight_files = list(component_dir.glob("*.safetensors")) + list(component_dir.glob("*.bin")) + assert len(weight_files) > 0, f"No weight files for {component}" + + # Sparse config (when present) lives only in config.json — never a standalone yaml. + assert not (export_dir / "sparse.yaml").exists(), "Unexpected top-level sparse.yaml" + + +def test_wan22_calibrated_export(tmp_path): + """Inject calibration params via the Python API and verify the exported config. + + Calibration can't succeed on tiny models via the Triton kernel (not enough + data points in the 10%-90% sparsity range), so this test sparsifies, injects + calibration params directly, and exports. Verifies the calibrated config schema + (top-level ``threshold_scale_factor`` of the form ``a * exp(b * target_sparsity)``) + and that the dense (cross-attention) layers are recorded under ``ignore``. + """ + from diffusers import AutoencoderKLWan, WanPipeline + + import modelopt.torch.sparsity.attention_sparsity as mtsa + from modelopt.torch.export import export_hf_checkpoint + from modelopt.torch.sparsity.attention_sparsity.sparse_attention import SparseAttentionModule + + pipe_dir = create_tiny_wan22_pipeline_dir(tmp_path / "model") + vae = AutoencoderKLWan.from_pretrained(pipe_dir, subfolder="vae", torch_dtype=torch.float32) + pipe = WanPipeline.from_pretrained(pipe_dir, vae=vae, torch_dtype=torch.bfloat16) + pipe.to("cuda") + + # Sparsify self-attention only; cross-attention (attn2) stays dense. + sparse_cfg = { + "*.attn1*": { + "method": "triton_skip_softmax", + "skip_softmax_threshold": 0.1, + "backend": "triton", + "is_causal": False, + "initial_disabled_steps": 5, + "enable": True, + }, + "*.attn2*": {"enable": False}, + "default": {"enable": False}, + } + config = {"sparse_cfg": sparse_cfg} + for transformer in [pipe.transformer, pipe.transformer_2]: + mtsa.sparsify(transformer, config) + + # Inject calibration params (simulating a successful log-space calibration). + test_log_a = math.log(1.5) + test_b = 3.0 + calibration_params = { + "prefill": { + "a": math.exp(test_log_a), + "b": test_b, + "log_a": test_log_a, + "fit_logspace": True, + "min_observed_sparsity": 0.15, + "max_observed_sparsity": 0.85, + }, + } + target_sparse_ratio = {"prefill": 0.5} + for transformer in [pipe.transformer, pipe.transformer_2]: + for module in transformer.modules(): + if isinstance(module, SparseAttentionModule) and module.is_enabled: + module._sparse_method_instance.calibration_params = calibration_params + module._sparse_method_instance.target_sparse_ratio = target_sparse_ratio + + export_dir = tmp_path / "calibrated_export" + export_hf_checkpoint(pipe, export_dir=export_dir) + + for component in ["transformer", "transformer_2"]: + config_path = export_dir / component / "config.json" + assert config_path.exists(), f"Missing config.json for {component}" + with open(config_path) as f: + config_data = json.load(f) + assert "sparse_attention_config" in config_data, ( + f"No sparse_attention_config in {component}/config.json" + ) + + sa_config = config_data["sparse_attention_config"] + group_0 = sa_config["config_groups"]["group_0"] + assert group_0["algorithm"] == "skip_softmax" + assert group_0["targets"] + + # Dense (uncalibrated) layers must be recorded so deployment skips them too. + assert "ignore" in group_0 + assert any(".attn2" in name for name in group_0["ignore"]) + + # Opt-in initial_disabled_steps metadata is carried through (exported only when > 0). + assert group_0["initial_disabled_steps"] == 5 + + # threshold_scale_factor lives inside the skip_softmax group. + tsf = group_0["threshold_scale_factor"] + assert tsf["formula"] == "a * exp(b * target_sparsity)" + assert tsf["prefill"]["a"] == pytest.approx(math.exp(test_log_a)) + assert tsf["prefill"]["b"] == pytest.approx(test_b) + + # Calibrated mode — no raw_threshold. + assert "raw_threshold" not in group_0 + + # Sparse config lives only in config.json — no standalone sparse.yaml. + assert not (export_dir / "sparse.yaml").exists(), "Unexpected top-level sparse.yaml" diff --git a/tests/examples/diffusers_sparsity/test_sparsity.py b/tests/examples/diffusers_sparsity/test_sparsity.py deleted file mode 100644 index 8c3698b473e..00000000000 --- a/tests/examples/diffusers_sparsity/test_sparsity.py +++ /dev/null @@ -1,104 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for skip-softmax sparse attention on Wan 2.2 (examples/diffusers/sparsity/). - -Uses a tiny Wan 2.2 model (dual transformer, 2 layers, hidden_dim=24) created -from scratch. Tests run the wan22_skip_softmax.py example script in baseline, -triton-baseline, and fixed-threshold modes. -""" - -import pytest -from _test_utils.examples.run_command import run_example_command -from _test_utils.torch.diffusers_models import create_tiny_wan22_pipeline_dir - -EXAMPLE_PATH = "diffusers/sparsity" - -# Tiny inference settings — fast but exercises all code paths -_TINY_ARGS = [ - "--num-frames", - "5", - "--height", - "16", - "--width", - "16", - "--num-steps", - "2", - "--guidance-scale", - "1.0", - "--skip-first-last", - "0", - "--negative-prompt", - "", -] - - -@pytest.fixture(scope="session") -def tiny_wan22_path(tmp_path_factory): - """Create a tiny Wan 2.2 pipeline saved to disk (session-scoped).""" - return str(create_tiny_wan22_pipeline_dir(tmp_path_factory.mktemp("tiny_wan22"))) - - -def test_wan22_baseline(tiny_wan22_path, tmp_path): - """Dense baseline — no sparsity, default diffusers attention backend.""" - cmd = [ - "python", - "wan22_skip_softmax.py", - "--model-path", - tiny_wan22_path, - "--baseline", - "--prompt", - "test", - "--output", - str(tmp_path / "baseline.mp4"), - *_TINY_ARGS, - ] - run_example_command(cmd, EXAMPLE_PATH) - - -def test_wan22_triton_baseline(tiny_wan22_path, tmp_path): - """Triton kernel without skip-softmax (threshold=0, apples-to-apples).""" - cmd = [ - "python", - "wan22_skip_softmax.py", - "--model-path", - tiny_wan22_path, - "--triton-baseline", - "--prompt", - "test", - "--output", - str(tmp_path / "triton_baseline.mp4"), - *_TINY_ARGS, - ] - run_example_command(cmd, EXAMPLE_PATH) - - -def test_wan22_skip_softmax_threshold(tiny_wan22_path, tmp_path): - """Skip-softmax with a fixed lambda threshold — no calibration needed.""" - cmd = [ - "python", - "wan22_skip_softmax.py", - "--model-path", - tiny_wan22_path, - "--skip-softmax-threshold", - "0.03125", - "--report-avg-sparsity", - "--prompt", - "test", - "--output", - str(tmp_path / "skip_softmax_threshold.mp4"), - *_TINY_ARGS, - ] - run_example_command(cmd, EXAMPLE_PATH) diff --git a/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_conversion.py b/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_conversion.py index 36d6bca60d9..6cfdd2775b9 100644 --- a/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_conversion.py +++ b/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_conversion.py @@ -380,10 +380,11 @@ def test_exports_when_calibration_present(self): out = export_sparse_attention_config(model) assert out is not None assert "config_groups" in out - tsf = out["threshold_scale_factor"] + group_0 = out["config_groups"]["group_0"] + tsf = group_0["threshold_scale_factor"] assert tsf["prefill"] == {"a": 3.14, "b": 7.5} assert tsf["decode"] == {"a": 0.5, "b": 9.0} - assert out["target_sparse_ratio"] == {"prefill": 0.4, "decode": 0.6} + assert group_0["target_sparsity"] == {"prefill": 0.4, "decode": 0.6} assert out["producer"]["name"] == "modelopt" def test_exports_sparse_softmax_metadata(self): @@ -413,14 +414,13 @@ def test_exports_sparse_softmax_metadata(self): out = export_sparse_attention_config(model) assert out is not None - assert out["config_groups"]["group_0"]["sparse_algo"] == "sparse_softmax" - assert out["sparse_softmax"] == { - "sparsity_n": 2, - "sparsity_m": 4, - "dense_sink_tokens": 4, - "dense_recent_tokens": 128, - } - assert "threshold_scale_factor" not in out + group_0 = out["config_groups"]["group_0"] + assert group_0["algorithm"] == "sparse_softmax" + assert group_0["sparsity_n"] == 2 + assert group_0["sparsity_m"] == 4 + assert group_0["dense_sink_tokens"] == 4 + assert group_0["dense_recent_tokens"] == 128 + assert "threshold_scale_factor" not in group_0 def test_exports_calibrated_skip_softmax_with_sparse_softmax_overlay(self): """Combined config exports both calibrated skip-softmax and N:M metadata.""" @@ -443,13 +443,13 @@ def test_exports_calibrated_skip_softmax_with_sparse_softmax_overlay(self): out = export_sparse_attention_config(model) assert out is not None - assert out["config_groups"]["group_0"]["sparse_algo"] == "softmax_skip" - assert out["config_groups"]["group_1"]["sparse_algo"] == "sparse_softmax" - assert out["threshold_scale_factor"]["prefill"] == {"a": 3.14, "b": 7.5} - assert out["target_sparse_ratio"] == {"prefill": 0.4, "decode": 0.6} - assert out["sparse_softmax"] == { - "sparsity_n": 2, - "sparsity_m": 4, - "dense_sink_tokens": 0, - "dense_recent_tokens": 64, - } + group_0 = out["config_groups"]["group_0"] + group_1 = out["config_groups"]["group_1"] + assert group_0["algorithm"] == "skip_softmax" + assert group_1["algorithm"] == "sparse_softmax" + assert group_0["threshold_scale_factor"]["prefill"] == {"a": 3.14, "b": 7.5} + assert group_0["target_sparsity"] == {"prefill": 0.4, "decode": 0.6} + assert group_1["sparsity_n"] == 2 + assert group_1["sparsity_m"] == 4 + assert group_1["dense_sink_tokens"] == 0 + assert group_1["dense_recent_tokens"] == 64 diff --git a/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attn_config.py b/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attn_config.py index 93a01730131..05644a4a20b 100644 --- a/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attn_config.py +++ b/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attn_config.py @@ -82,14 +82,14 @@ def test_returns_none_when_attribute_missing(self): def test_returns_none_for_unknown_algo(self): """Test that an unrecognized sparse_algo returns None.""" - meta = {"config_groups": {"group_0": {"sparse_algo": "future_algo_v9000"}}} + meta = {"config_groups": {"group_0": {"algorithm": "future_algo_v9000"}}} hf_config = types.SimpleNamespace(sparse_attention_config=meta) assert load_from_checkpoint_metadata(hf_config) is None def test_maps_uncalibrated_softmax_skip_to_preset(self): """Test that uncalibrated softmax_skip uses the static Triton preset.""" meta = { - "config_groups": {"group_0": {"sparse_algo": "softmax_skip"}}, + "config_groups": {"group_0": {"algorithm": "skip_softmax"}}, "producer": {"name": "modelopt", "version": "0.37.0"}, } hf_config = types.SimpleNamespace(sparse_attention_config=meta) @@ -107,9 +107,13 @@ def test_maps_calibrated_softmax_skip_to_dynamic_config(self): "decode": {"a": 5.0, "b": 7.0}, } meta = { - "config_groups": {"group_0": {"sparse_algo": "softmax_skip"}}, - "threshold_scale_factor": threshold_scale_factor, - "target_sparse_ratio": {"prefill": 0.4, "decode": 0.6}, + "config_groups": { + "group_0": { + "algorithm": "skip_softmax", + "threshold_scale_factor": threshold_scale_factor, + "target_sparsity": {"prefill": 0.4, "decode": 0.6}, + } + }, "producer": {"name": "modelopt", "version": "0.45.0"}, } hf_config = types.SimpleNamespace(sparse_attention_config=meta) @@ -131,12 +135,14 @@ def test_maps_calibrated_softmax_skip_to_dynamic_config(self): def test_maps_sparse_softmax_to_dynamic_config(self): """Test that checkpoint N:M sparse-softmax metadata restores layer params.""" meta = { - "config_groups": {"group_0": {"sparse_algo": "sparse_softmax"}}, - "sparse_softmax": { - "sparsity_n": 2, - "sparsity_m": 4, - "dense_sink_tokens": 4, - "dense_recent_tokens": 128, + "config_groups": { + "group_0": { + "algorithm": "sparse_softmax", + "sparsity_n": 2, + "sparsity_m": 4, + "dense_sink_tokens": 4, + "dense_recent_tokens": 128, + } }, "producer": {"name": "modelopt", "version": "0.45.0"}, } @@ -166,12 +172,17 @@ def test_maps_calibrated_softmax_skip_with_sparse_softmax_overlay(self): } meta = { "config_groups": { - "group_0": {"sparse_algo": "softmax_skip"}, - "group_1": {"sparse_algo": "sparse_softmax"}, + "group_0": { + "algorithm": "skip_softmax", + "threshold_scale_factor": threshold_scale_factor, + "target_sparsity": {"prefill": 0.4, "decode": 0.6}, + }, + "group_1": { + "algorithm": "sparse_softmax", + "sparsity_n": 2, + "sparsity_m": 4, + }, }, - "threshold_scale_factor": threshold_scale_factor, - "target_sparse_ratio": {"prefill": 0.4, "decode": 0.6}, - "sparse_softmax": {"sparsity_n": 2, "sparsity_m": 4}, "producer": {"name": "modelopt", "version": "0.45.0"}, } hf_config = types.SimpleNamespace(sparse_attention_config=meta) @@ -203,3 +214,31 @@ def test_handles_empty_config_groups(self): """Test that an empty config_groups returns None.""" hf_config = types.SimpleNamespace(sparse_attention_config={"config_groups": {}}) assert load_from_checkpoint_metadata(hf_config) is None + + def test_calibrated_softmax_skip_honors_ignore(self): + """Layers recorded under ``ignore`` stay dense on load; others are sparsified.""" + meta = { + "config_groups": { + "group_0": { + "algorithm": "skip_softmax", + "ignore": ["blocks.0.attn1", "blocks.0.attn2"], + "threshold_scale_factor": { + "formula": "a * exp(b * target_sparsity)", + "prefill": {"a": 2.0, "b": 3.0}, + }, + "target_sparsity": {"prefill": 0.5}, + } + }, + "producer": {"name": "modelopt", "version": "0.45.0"}, + } + hf_config = types.SimpleNamespace(sparse_attention_config=meta) + result = load_from_checkpoint_metadata(hf_config) + assert result is not None + cfg, _ = result + # ``ignore``'d layers stay dense. + assert match_sparse_config("transformer.blocks.0.attn1", cfg) == {"enable": False} + assert match_sparse_config("transformer.blocks.0.attn2", cfg) == {"enable": False} + # A non-ignored self-attention layer is sparsified. + sparsified = match_sparse_config("transformer.blocks.2.attn1", cfg) + assert sparsified["method"] == "triton_skip_softmax" + assert sparsified["enable"] is True From df8e97313f6e910eee05908eec90c8e7e769a92b Mon Sep 17 00:00:00 2001 From: jingyu-ml <108295447+jingyu-ml@users.noreply.github.com> Date: Mon, 8 Jun 2026 16:03:35 -0700 Subject: [PATCH 05/24] Add DMD2 distillation for Qwen-Image (fastgen) (#1326) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### What does this PR do? **Type of change:** New example + new `modelopt.torch.fastgen` library module. Adds **DMD2 (Distribution Matching Distillation) for Qwen-Image** — distilling the base model into a few-step (1–4) generator. Includes the framework-agnostic `modelopt.torch.fastgen` loss library (DMD pipeline, EMA, optional GAN discriminator) and a NeMo AutoModel–based training example with a mock-data smoke config, a real-data config, and inference / export scripts. **Noted**: the example script will be migrated to AutoModel repo ### Usage ```bash # Mock-data wiring smoke — runs end-to-end with no dataset to prepare torchrun --nproc-per-node=8 \ examples/diffusers/fastgen/dmd2_finetune.py \ --config examples/diffusers/fastgen/configs/dmd2_qwen_image_smoke.yaml ``` See `examples/diffusers/fastgen/README.md` for real-data training and inference. ### Testing Unit tests under `tests/unit/torch/fastgen/`; `pre-commit` / code-quality clean. ### Before your PR is "*Ready for review*" - Backward compatible?: ✅ (new, additive module) - Followed `CONTRIBUTING.md` for any copied code / new deps: ✅ - New tests added?: ✅ - Updated Changelog?: N/A ## Summary by CodeRabbit * **New Features** * Adds a FastGen-based distillation framework (DMD2) with student/fake-score training, EMA support, GAN discriminator branch, inference pipeline, and export utilities. * Qwen-Image integration with latent packing and feature-capture for plugin-enabled pipelines. * **Documentation** * New README, example configs, and runnable example scripts for Qwen-Image distillation and inference. * **Tests** * Comprehensive unit tests covering math parity, gradient routing, plugins, hooks, EMA, and recipe setup. --------- Signed-off-by: Jingyu Xin Co-authored-by: Claude Opus 4.8 (1M context) Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com> --- CHANGELOG.rst | 1 + examples/diffusers/fastgen/README.md | 180 +++ .../fastgen/configs/dmd2_qwen_image.yaml | 173 +++ .../configs/dmd2_qwen_image_smoke.yaml | 109 ++ examples/diffusers/fastgen/dmd2_finetune.py | 38 + examples/diffusers/fastgen/dmd2_recipe.py | 1317 +++++++++++++++++ .../fastgen/export_diffusers_qwen_image.py | 202 +++ .../fastgen/inference_dmd2_qwen_image.py | 528 +++++++ examples/diffusers/fastgen/requirements.txt | 10 + modelopt/torch/fastgen/__init__.py | 68 + modelopt/torch/fastgen/config.py | 315 ++++ modelopt/torch/fastgen/discriminators.py | 136 ++ modelopt/torch/fastgen/ema.py | 258 ++++ modelopt/torch/fastgen/factory.py | 106 ++ modelopt/torch/fastgen/flow_matching.py | 263 ++++ modelopt/torch/fastgen/loader.py | 149 ++ modelopt/torch/fastgen/losses.py | 178 +++ modelopt/torch/fastgen/methods/__init__.py | 18 + modelopt/torch/fastgen/methods/dmd.py | 812 ++++++++++ modelopt/torch/fastgen/pipeline.py | 99 ++ modelopt/torch/fastgen/plugins/__init__.py | 27 + modelopt/torch/fastgen/plugins/qwen_image.py | 381 +++++ modelopt/torch/fastgen/utils.py | 54 + .../general/distillation/dmd2_qwen_image.yaml | 80 + tests/_test_utils/torch/diffusers_models.py | 127 ++ tests/examples/diffusers/conftest.py | 17 + tests/unit/torch/fastgen/conftest.py | 114 ++ .../fastgen/test_dmd_gradient_routing.py | 162 ++ tests/unit/torch/fastgen/test_dmd_math.py | 419 ++++++ .../torch/fastgen/test_dmd_pipeline_step.py | 164 ++ .../torch/fastgen/test_hook_requirements.py | 111 ++ .../fastgen/test_pred_type_conversion.py | 218 +++ .../torch/fastgen/test_qwen_image_plugin.py | 234 +++ 33 files changed, 7068 insertions(+) create mode 100644 examples/diffusers/fastgen/README.md create mode 100644 examples/diffusers/fastgen/configs/dmd2_qwen_image.yaml create mode 100644 examples/diffusers/fastgen/configs/dmd2_qwen_image_smoke.yaml create mode 100644 examples/diffusers/fastgen/dmd2_finetune.py create mode 100644 examples/diffusers/fastgen/dmd2_recipe.py create mode 100644 examples/diffusers/fastgen/export_diffusers_qwen_image.py create mode 100644 examples/diffusers/fastgen/inference_dmd2_qwen_image.py create mode 100644 examples/diffusers/fastgen/requirements.txt create mode 100644 modelopt/torch/fastgen/__init__.py create mode 100644 modelopt/torch/fastgen/config.py create mode 100644 modelopt/torch/fastgen/discriminators.py create mode 100644 modelopt/torch/fastgen/ema.py create mode 100644 modelopt/torch/fastgen/factory.py create mode 100644 modelopt/torch/fastgen/flow_matching.py create mode 100644 modelopt/torch/fastgen/loader.py create mode 100644 modelopt/torch/fastgen/losses.py create mode 100644 modelopt/torch/fastgen/methods/__init__.py create mode 100644 modelopt/torch/fastgen/methods/dmd.py create mode 100644 modelopt/torch/fastgen/pipeline.py create mode 100644 modelopt/torch/fastgen/plugins/__init__.py create mode 100644 modelopt/torch/fastgen/plugins/qwen_image.py create mode 100644 modelopt/torch/fastgen/utils.py create mode 100644 modelopt_recipes/general/distillation/dmd2_qwen_image.yaml create mode 100644 tests/unit/torch/fastgen/conftest.py create mode 100644 tests/unit/torch/fastgen/test_dmd_gradient_routing.py create mode 100644 tests/unit/torch/fastgen/test_dmd_math.py create mode 100644 tests/unit/torch/fastgen/test_dmd_pipeline_step.py create mode 100644 tests/unit/torch/fastgen/test_hook_requirements.py create mode 100644 tests/unit/torch/fastgen/test_pred_type_conversion.py create mode 100644 tests/unit/torch/fastgen/test_qwen_image_plugin.py diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 9f0d369827d..91fca51a59c 100755 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -51,6 +51,7 @@ Changelog The legacy FSDP1 accelerate config is removed; ``llm_qat`` now documents FSDP2, DeepSpeed, and DDP backends. - The PTQ example scripts ``examples/llm_ptq/hf_ptq.py``, ``examples/llm_ptq/multinode_ptq.py`` and ``examples/megatron_bridge/quantize.py`` now derive their ``--qformat`` / ``--kv_cache_qformat`` (``--quant_cfg`` / ``--kv_cache_quant`` for Megatron-Bridge) CLI vocabularies by discovering the YAML presets under ``modelopt_recipes/configs/ptq/presets/{model,kv}/`` rather than carrying hardcoded ``QUANT_CFG_CHOICES`` / ``KV_QUANT_CFG_CHOICES`` tables. The discovery helper, alias table and ready-built ``QUANT_CFG_CHOICES`` / ``KV_QUANT_CFG_CHOICES`` mappings now live in ``modelopt.recipe.presets`` and are shared by all three scripts. Presets are loaded eagerly into a plain dict at import. Adding a new preset YAML makes it available on the CLI of all three with no script change — note this means each script now accepts every preset under those directories, not just a previously curated subset. All previously-supported short names (``int8_sq``, ``nvfp4_awq``, ``fp8_pb_wo``, ``nvfp4_mse``, ``w4a8_awq``, ``nvfp4_local_hessian``, ``fp8_pc_pt``, ``int8_wo``) keep working via a small deprecation alias table; new formats should be exposed as preset YAMLs (or, longer term, as full ``--recipe`` recipes). - Add ``configs/ptq/presets/kv/fp8_cast.yaml`` and ``configs/ptq/presets/kv/nvfp4_cast.yaml``, promoting ``fp8_cast`` / ``nvfp4_cast`` to first-class KV presets composed from the existing ``kv_fp8_cast`` / ``kv_nvfp4_cast`` unit fragments. The previous runtime ``use_constant_amax`` post-edit in ``hf_ptq.py`` is removed; ``use_constant_amax: true`` now lives in the YAML and is therefore authoritative. **Custom (out-of-tree) recipes that target a cast KV format must set ``use_constant_amax: true`` themselves on the ``[kv]_bmm_quantizer`` config** — in-tree recipes already do via the ``kv_*_cast`` units. +- Add DMD2 distillation for few-step diffusion models in ``examples/diffusers/fastgen/``: distill Qwen-Image into a 4/8-step student via Distribution Matching Distillation. See `examples/diffusers/fastgen/README.md `_ for details. **Bug Fixes** diff --git a/examples/diffusers/fastgen/README.md b/examples/diffusers/fastgen/README.md new file mode 100644 index 00000000000..b3c3bc780f9 --- /dev/null +++ b/examples/diffusers/fastgen/README.md @@ -0,0 +1,180 @@ +# DMD2 distillation for Qwen-Image + +Distill [`Qwen/Qwen-Image`](https://huggingface.co/Qwen/Qwen-Image) into a **few-step +generator** with DMD2 (Distribution Matching Distillation). The distilled student +produces images in as few as **1–4 sampling steps** while matching the base model's +output distribution. Built on `modelopt.torch.fastgen` and NeMo AutoModel's +[`TrainDiffusionRecipe`](https://github.com/NVIDIA-NeMo/Automodel/blob/main/nemo_automodel/recipes/diffusion/train.py). + +> [!NOTE] +> Qwen-Image is a third-party model with its own license terms. Review the +> [Qwen-Image model card](https://huggingface.co/Qwen/Qwen-Image) before downloading or +> redistributing weights or derivatives. + +## How DMD2 works + +DMD2 trains three networks together: + +| Model | Role | +|---|---| +| **Student** | the few-step generator you keep | +| **Fake-score** | a diffusion model that tracks the *student's* current output distribution | +| **Teacher** | the frozen base Qwen-Image model (the *target* distribution) | + +The distribution-matching gradient pushes the student toward the teacher and away from +the fake-score. Training alternates between two phases, controlled by `student_update_freq`: + +```text +each step: + if step % student_update_freq == 0: # student phase + update the student (distribution-matching [+ optional GAN] loss) + update the student EMA + else: # fake-score phase + update the fake-score network to track the student +``` + +The canonical config additionally enables **CFG** (classifier-free guidance on the +teacher) and a lightweight **GAN** branch (a discriminator head on a teacher feature +block, plus an R1 gradient penalty) for sharper samples. + +## Install + +From the repo root: + +```bash +pip install -e ".[all]" # ModelOpt + torch + diffusers +pip install -r examples/diffusers/fastgen/requirements.txt # nemo_automodel +``` + +`nemo_automodel[diffusion]` pulls in diffusers, accelerate, and the `TrainDiffusionRecipe` +this example subclasses. + +## Quick start — mock data (no dataset needed) + +The smoke config feeds random tensors at Qwen-Image's shapes, so it runs end-to-end with +**no dataset to prepare** — it exercises the full training loop (FSDP2 sharding, phase +alternation, checkpoint save/restore). Use it to validate your environment: + +```bash +torchrun --nproc-per-node=8 \ + examples/diffusers/fastgen/dmd2_finetune.py \ + --config examples/diffusers/fastgen/configs/dmd2_qwen_image_smoke.yaml +``` + +Scale `--fsdp.dp_size` to your GPU count. You'll see alternating `phase=student` / +`phase=fake_score` log lines and a checkpoint written at the last step. + +> The mock loop validates wiring only — it does **not** produce meaningful images. For +> that, train on real data (below). + +## Real-data training + +`configs/dmd2_qwen_image.yaml` is the canonical config: 4-step student, CFG, and the +GAN + R1 branch, trained on a preprocessed latent cache. Before launching, provide: + +- **A preprocessed Qwen-Image latent cache** — set `data.dataloader.cache_dir`. +- **A precomputed negative-prompt embedding** (required for CFG) — set + `data.dataloader.negative_prompt_embedding_path`. +- **An output directory** — set `checkpoint.checkpoint_dir`. + +The model path defaults to `Qwen/Qwen-Image`; point it at a local snapshot to avoid +re-downloading on every job. Then: + +```bash +torchrun --nproc-per-node=8 \ + examples/diffusers/fastgen/dmd2_finetune.py \ + --config examples/diffusers/fastgen/configs/dmd2_qwen_image.yaml \ + --step_scheduler.max_steps=5000 +``` + +Any `DMDConfig` field can be overridden on the CLI (e.g. `--dmd2.guidance_scale=3.5`). + +### Checkpoints & resuming + +Checkpoints land under `checkpoint.checkpoint_dir`. Alongside the student, the recipe +saves the DMD2 sidecars needed to resume exactly: the fake-score model + optimizer, the +student EMA (`ema_shadow.pt`), and the DMD iteration counter (`dmd_state.pt`). With +`restore_from: LATEST` a re-launch auto-resumes from the newest checkpoint; pin a +specific one with `--checkpoint.restore_from=epoch_0_step_500`. + +## Inference + +After training, sample from the distilled student. The pipeline loads your consolidated +student transformer plus the base Qwen-Image VAE / text encoder / tokenizer: + +```python +import torch +from inference_dmd2_qwen_image import QwenImageDMDInferencePipeline + +pipe = QwenImageDMDInferencePipeline.from_pretrained( + student_path="/path/to/checkpoint/epoch_0_step_500/model/consolidated", + base_pipeline_path="Qwen/Qwen-Image", + ema_path=None, # or ".../ema_shadow.pt" to sample the EMA weights + torch_dtype=torch.bfloat16, +).to("cuda") + +image = pipe( + prompt="a small red cube on a white table", + num_inference_steps=4, # match the student_sample_steps you trained with + height=1024, width=1024, + generator=torch.Generator("cuda").manual_seed(42), +).images[0] +image.save("sample.png") +``` + +Or run the bundled CLI for a quick check: + +```bash +python examples/diffusers/fastgen/inference_dmd2_qwen_image.py \ + --student_path /path/to/checkpoint/.../model/consolidated \ + --base_pipeline_path Qwen/Qwen-Image \ + --prompt "a small red cube on a white table" \ + --height 512 --width 512 +``` + +Set `num_inference_steps` to the number of steps the student was trained for +(`dmd2.student_sample_steps` — e.g. 4 for the canonical config, or 1 for a single-step +student). + +## Config reference + +| Section | Key | Role | +|---|---|---| +| `model` | `pretrained_model_name_or_path` | Qwen-Image HF id or local snapshot. | +| `model` | `mode` | `finetune` — loads the pretrained weights. | +| `step_scheduler` | `global_batch_size`, `local_batch_size`, `max_steps`, `ckpt_every_steps`, `log_every` | Standard AutoModel scheduling knobs. | +| `dmd2` | `recipe_path` | Built-in fastgen recipe to hydrate `DMDConfig` from (`general/distillation/dmd2_qwen_image`). | +| `dmd2` | `pipeline_plugin` | `qwen_image` — selects `QwenImageDMDPipeline` (2×2 patch packing / img_shapes). | +| `dmd2` | `student_sample_steps` | Number of student sampling steps (e.g. 4). | +| `dmd2` | `guidance_scale` | CFG strength on the teacher (`null` disables CFG; requires a negative-prompt embedding when set). | +| `dmd2` | `gan_loss_weight_gen`, `gan_r1_reg_weight`, `gan_feature_indices`, … | GAN branch (set `gan_loss_weight_gen: 0` to disable). | +| `dmd2` | `fake_score_lr`, `discriminator_lr` | Separate LRs for the fake-score / discriminator optimizers. | +| `dmd2` | `sample_t_cfg`, `ema` | Timestep sampling + student EMA settings. | +| `optim` | `learning_rate`, `optimizer.*` | Student AdamW knobs. | +| `fsdp` | `dp_size`, `tp_size`, `activation_checkpointing`, … | FSDP2 parallelism (set `dp_size` to your GPU count). | +| `data` | `dataloader._target_`, `cache_dir`, `negative_prompt_embedding_path` | Real latent cache vs. `build_mock_t2i_dataloader`. | +| `checkpoint` | `checkpoint_dir`, `model_save_format`, `restore_from` | Output dir, save format, resume behavior. | + +## Troubleshooting + +**`CUDA out of memory`.** Training holds three Qwen-Image transformers (student + teacher +- fake-score) plus optimizer state. Shard across more GPUs (raise `--fsdp.dp_size`), +enable `--fsdp.activation_checkpointing=true`, or use the mock smoke for wiring checks. + +**Loss is `NaN` on step 0.** Almost always an out-of-range timestep — confirm you haven't +overridden `dmd2.pred_type` away from `flow` (Qwen-Image is a rectified-flow model) or +changed the timestep schedule. + +**`guidance_scale is set but negative_encoder_hidden_states was not provided`.** CFG needs +a precomputed negative-prompt embedding. Set `data.dataloader.negative_prompt_embedding_path`, +or set `dmd2.guidance_scale: null` to disable CFG. + +**Dataloader yields empty batches.** Ensure your cache has at least +`local_batch_size * fsdp.dp_size` items; the distributed sampler drops incomplete batches. + +## Reference + +- Fastgen library: [`modelopt/torch/fastgen/`](../../../modelopt/torch/fastgen/) +- Built-in recipe: [`modelopt_recipes/general/distillation/dmd2_qwen_image.yaml`](../../../modelopt_recipes/general/distillation/dmd2_qwen_image.yaml) +- AutoModel recipe this example subclasses: + [`nemo_automodel/recipes/diffusion/train.py`](https://github.com/NVIDIA-NeMo/Automodel/blob/main/nemo_automodel/recipes/diffusion/train.py) diff --git a/examples/diffusers/fastgen/configs/dmd2_qwen_image.yaml b/examples/diffusers/fastgen/configs/dmd2_qwen_image.yaml new file mode 100644 index 00000000000..791b0efbaa9 --- /dev/null +++ b/examples/diffusers/fastgen/configs/dmd2_qwen_image.yaml @@ -0,0 +1,173 @@ +# Qwen-Image DMD2 — canonical real-data training config. +# +# Enables the full Qwen-Image DMD2 setup: 4-step student, CFG, and the GAN + R1 +# branch. This is the real-data training config; the mock-data wiring smoke +# (no dataset required) lives in ``dmd2_qwen_image_smoke.yaml``. +# +# Launch with torchrun, scaling ``--fsdp.dp_size`` to your GPU count: +# +# torchrun --nproc-per-node= \ +# examples/diffusers/fastgen/dmd2_finetune.py \ +# --config examples/diffusers/fastgen/configs/dmd2_qwen_image.yaml \ +# --step_scheduler.max_steps=5000 +# +# The data.* and checkpoint.* paths below are placeholders — point them at your +# own preprocessed latent cache and output directory before launching. + +seed: 42 + +wandb: + project: fastgen-dmd2-qwen-image + mode: online + name: qwen_image_dmd2 + +dist_env: + backend: nccl + timeout_minutes: 60 + +model: + pretrained_model_name_or_path: Qwen/Qwen-Image + mode: finetune + +step_scheduler: + # Must be divisible by local_batch_size * dp_size. For the canonical + # 32-node GB200 run below, dp_size=128 and local_batch_size=1, so GBS=128 + # gives one micro-batch per rank and no gradient accumulation. + global_batch_size: 128 + local_batch_size: 1 + ckpt_every_steps: 500 + # With 40K cached samples and GBS=128, one epoch is ~313 optimizer steps. + # 16 epochs gives >=5000 optimizer steps; max_steps below stops exactly at 5000. + num_epochs: 16 + log_every: 1 + # Production placeholder; CLI override expected for the long run. + max_steps: 5000 + +# ─── DMD2 block ───────────────────────────────────────────────────────────────── +# ``recipe_path`` is required by ``_resolve_dmd_config`` in dmd2_recipe.py. +# Every actual DMDConfig knob is explicitly pinned below so this YAML is the +# single source of truth for the formal run — the recipe defaults at +# ``modelopt_recipes/general/distillation/dmd2_qwen_image.yaml`` are NOT +# consulted for anything set here. Pydantic's ``model_copy(update=...)`` is +# a shallow merge, so when overriding any ``sample_t_cfg`` or ``ema`` +# sub-field we re-list every field in that block to avoid silent drops. +dmd2: + recipe_path: general/distillation/dmd2_qwen_image + pipeline_plugin: qwen_image + qwen_image_guidance: + + # ── DMD2 method core ── + pred_type: flow # Qwen-Image is rectified flow + num_train_timesteps: # continuous t ∈ [0, 1]; QwenImageDMDPipeline forwards t verbatim + guidance_scale: 4.0 # CFG strength on teacher during the student update + student_sample_steps: 4 # 4-step student (Phase 2) + student_sample_type: ode # Euler integration when unrolling the student + # Default keeps FastGen Qwen parity: train each rung from noised real latents. + # Override with ``--dmd2.backward_simulation=true`` to no-grad unroll the + # current student from the first rung before training the selected rung. + backward_simulation: false + student_update_freq: 5 # one student step per 5 fake-score / discriminator steps + fake_score_pred_type: x0 # fake_score regresses x0; teacher/student live in flow space + + # ── GAN branch (Phase 2) ── + gan_loss_weight_gen: 0.03 # generator weight in student loss + gan_use_same_t_noise: true # share (t, noise) between generator + discriminator updates + gan_r1_reg_weight: 0.1 # R1 gradient penalty on real samples + gan_r1_reg_alpha: 0.1 # R1 EMA coefficient + + # ── Optimizer LRs (student LR comes from ``optim.learning_rate`` below) ── + fake_score_lr: 2.0e-6 + discriminator_lr: 2.0e-6 + + # ── GAN discriminator placement & dim ── + gan_feature_indices: [30] # tap transformer_blocks[30] for the feature head + gan_num_blocks: 60 # Qwen-Image has 60 transformer blocks total + gan_inner_dim: 3072 # Qwen-Image hidden_size (matches FastGen reference) + + # ── 4-step student timestep schedule ── + # Exact ``torch.linspace(max_t=0.999, 0.0, 5).tolist()``, which is also the + # inference pipeline's default schedule when no t_list is passed (see + # ``inference_dmd2_qwen_image.py:259``). Training draws t uniformly from + # t_list[:-1], so the 4 trained timesteps exactly match the 4 inference + # sample points. The earlier ``[0.999, 0.75, 0.5, 0.25, 0.0]`` was + # ``linspace(1.0, 0, 5)`` with t=1 shaved to 0.999, leaving a silent ~0.3% + # train↔inference skew on each non-endpoint timestep. + sample_t_cfg: + # ``time_dist_type`` governs the *perturbation* timestep ``t`` that gets + # sampled on every loss path — VSD perturbation in compute_student_loss + # (dmd.py:417), fake-score DSM perturbation in compute_fake_score_loss + # (dmd.py:529), and GAN/discriminator perturbation in + # compute_discriminator_loss (dmd.py:605). All three call + # ``self.sample_timesteps`` → ``sample_t`` → reads ``time_dist_type``. + # + # It does NOT govern the student's *starting* timestep ``t_student``: + # under student_sample_steps > 1, ``_build_student_input`` calls + # ``sample_from_t_list`` (dmd.py:346) which samples uniformly from + # ``t_list[:-1]`` regardless of ``time_dist_type``. + # + # ``uniform`` matches FastGen's reference Qwen DMD2 config. Earlier formal + # runs used ``logitnormal`` (concentrates ``t`` toward the middle of + # [min_t, max_t]) to follow Automodel's Qwen finetune reference; we now + # default to uniform here to keep launchers parity-aligned without an + # EXTRA_ARGS override. Flip to ``logitnormal`` on the CLI for an ablation. + time_dist_type: uniform + min_t: 0.001 + max_t: 0.999 + p_mean: 0.0 + p_std: 1.0 + t_list: [0.999, 0.74925, 0.4995, 0.24975, 0.0] + + # ── Student EMA ── + ema: + decay: 0.9999 + type: constant + start_iter: 0 + fsdp2: true + mode: full_tensor + +optim: + learning_rate: 2.0e-6 + optimizer: + weight_decay: 0.01 + betas: [0.9, 0.999] + +lr_scheduler: + lr_decay_style: constant + lr_warmup_steps: 0 + min_lr: 2.0e-6 + +fsdp: + tp_size: 1 + cp_size: 1 + pp_size: 1 + dp_replicate_size: 1 + dp_size: 128 # 32 nodes × 4 GPUs/node = 128 GPUs + activation_checkpointing: true + +# Real cached latents (40K items, 1024x1024). Negative-prompt embedding for +# CFG is loaded inside the dataloader via ``negative_prompt_embedding_path``. +data: + dataloader: + _target_: nemo_automodel.components.datasets.diffusion.build_text_to_image_multiresolution_dataloader + cache_dir: /path/to/preprocessed/qwen_image_1024p + base_resolution: [1024, 1024] + batch_size: 1 + drop_last: false + shuffle: true + num_workers: 0 + negative_prompt_embedding_path: /path/to/preprocessed/qwen_image_1024p/negative_prompt_embedding.pt + +# Inference-loadable safetensors saves so checkpoints are usable without a +# secondary export pass. +checkpoint: + enabled: true + checkpoint_dir: /path/to/output/qwen_image_dmd2/checkpoints + model_save_format: safetensors + save_consolidated: true + v4_compatible: true + diffusers_compatible: true + # ``LATEST`` auto-resumes from the most recent checkpoint in checkpoint_dir + # if one exists, and starts fresh otherwise (safe on first launch). To pin a + # specific checkpoint, override on the CLI: + # --checkpoint.restore_from=epoch_0_step_500 + restore_from: LATEST diff --git a/examples/diffusers/fastgen/configs/dmd2_qwen_image_smoke.yaml b/examples/diffusers/fastgen/configs/dmd2_qwen_image_smoke.yaml new file mode 100644 index 00000000000..2d8f2ad3581 --- /dev/null +++ b/examples/diffusers/fastgen/configs/dmd2_qwen_image_smoke.yaml @@ -0,0 +1,109 @@ +# DMD2 on Qwen-Image — mock-data wiring smoke (NOT for real training). +# +# Feeds random tensors at Qwen-Image's shapes/dtypes via +# ``build_mock_t2i_dataloader``. Useful for end-to-end wiring tests (FSDP2, +# phase routing, checkpoint save/restore) but useless for image quality — real +# training uses ``dmd2_qwen_image.yaml`` (real cache + CFG + GAN). + +seed: 42 + +wandb: + project: fastgen-dmd2-qwen-image + mode: online + name: phase1_smoke + +dist_env: + backend: nccl + timeout_minutes: 60 + +model: + # Qwen-Image text-to-image checkpoint (HF id, or a local snapshot path to + # avoid hitting HF on every job). + pretrained_model_name_or_path: Qwen/Qwen-Image + mode: finetune + +step_scheduler: + global_batch_size: 8 + local_batch_size: 1 + ckpt_every_steps: 100 + num_epochs: 1 + log_every: 1 + # Hard cap the Phase 1 smoke at 100 optimizer steps. Flip to null for full runs. + max_steps: 100 + +# ─── DMD2-specific block ──────────────────────────────────────────────────────────── +dmd2: + # Built-in fastgen recipe for Qwen-Image: pred_type=flow, num_train_timesteps=null + # (Qwen normalises t internally), logit_normal time sampling, student_update_freq=5, + # fake_score_pred_type=x0, gan_loss_weight_gen=0 (Phase 1), EMA on. + recipe_path: general/distillation/dmd2_qwen_image + + # Phase 1 overrides: + # * GAN disabled — no discriminator shipped in Phase 1. + # * CFG disabled — no negative-prompt embedding precompute yet; guidance_scale=null + # short-circuits the negative-conditioning branch inside compute_student_loss. + gan_loss_weight_gen: 0.0 + guidance_scale: + + # Phase-1-only knobs (NOT on DMDConfig — consumed directly by the recipe): + # LR for the fake-score AdamW; matches the student LR below. + fake_score_lr: 1.0e-5 + + # Explicit pipeline plugin selector. Auto-detect via model_id substring works for a + # local Qwen-Image snapshot path, but spelling it out keeps the choice visible. + pipeline_plugin: qwen_image + + # Optional guidance scalar forwarded to the transformer's ``guidance`` kwarg every + # call. The shipped ``Qwen/Qwen-Image`` checkpoint has guidance_embeds=false, so + # leave this null. Set to e.g. 3.5 only if you've fine-tuned a guidance-embed + # variant. + qwen_image_guidance: + +# Student LR + optimizer. +optim: + learning_rate: 1.0e-5 + optimizer: + weight_decay: 0.01 + betas: [0.9, 0.999] + +# Constant LR for the smoke — flip to cosine/linear once the loop is validated. +lr_scheduler: + lr_decay_style: constant + lr_warmup_steps: 0 + min_lr: 1.0e-5 + +# FSDP2 config. Scale dp_size to match the GPU count of the run. +fsdp: + tp_size: 1 + cp_size: 1 + pp_size: 1 + dp_replicate_size: 1 + dp_size: 8 + activation_checkpointing: true + +# Mock data by default — the debug target is the DMD2 loop, not the data pipeline. +# Swap _target_ to build_text_to_image_multiresolution_dataloader once a real +# preprocessed cache is available. +data: + dataloader: + _target_: nemo_automodel.components.datasets.diffusion.build_mock_t2i_dataloader + # Qwen-Image VAE: 16 latent channels, 8x spatial downsample. + # 256x256 image -> 32x32 latent. Must be even (2x2 patch packing). + num_channels: 16 + spatial_h: 32 + spatial_w: 32 + # Qwen2.5-VL hidden_dim = 3584 (text_encoder/config.json: hidden_size=3584). + text_seq_len: 512 + text_embed_dim: 3584 + length: 256 + num_workers: 0 + shuffle: true + +checkpoint: + enabled: true + checkpoint_dir: /path/to/output/qwen_image_dmd2_smoke/checkpoints + model_save_format: torch_save + save_consolidated: false + diffusers_compatible: false + # Set to LATEST or epoch_0_step_100 (etc.) to resume. Null = start fresh. + restore_from: diff --git a/examples/diffusers/fastgen/dmd2_finetune.py b/examples/diffusers/fastgen/dmd2_finetune.py new file mode 100644 index 00000000000..0f7936ef1bb --- /dev/null +++ b/examples/diffusers/fastgen/dmd2_finetune.py @@ -0,0 +1,38 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Entrypoint for the DMD2 Qwen-Image AutoModel example. + +Parses the YAML config + CLI overrides with AutoModel's argument parser, then hands +control to :class:`DMD2DiffusionRecipe`. +""" + +from __future__ import annotations + +from dmd2_recipe import DMD2DiffusionRecipe +from nemo_automodel.components.config._arg_parser import parse_args_and_load_config + + +def main( + default_config_path: str = "examples/diffusers/fastgen/configs/dmd2_qwen_image_smoke.yaml", +) -> None: + cfg = parse_args_and_load_config(default_config_path) + recipe = DMD2DiffusionRecipe(cfg) + recipe.setup() + recipe.run_train_validation_loop() + + +if __name__ == "__main__": + main() diff --git a/examples/diffusers/fastgen/dmd2_recipe.py b/examples/diffusers/fastgen/dmd2_recipe.py new file mode 100644 index 00000000000..9614d4283a5 --- /dev/null +++ b/examples/diffusers/fastgen/dmd2_recipe.py @@ -0,0 +1,1317 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""DMD2 distillation recipe built on NeMo AutoModel. + +This recipe subclasses :class:`nemo_automodel.recipes.diffusion.train.TrainDiffusionRecipe` +so it inherits AutoModel's student + optimizer + dataloader + checkpoint plumbing, then +drives ``modelopt.torch.fastgen.DMDPipeline`` (or a plugin subclass) through the +three-phase DMD2 alternation (student update / fake-score update / EMA step). + +Backbone: **Qwen-Image** (``Qwen/Qwen-Image``) — 4D ``image_latents``, +:class:`QwenImageDMDPipeline` handles 2x2 patch packing / img_shapes / +unpacking. Configs: ``configs/dmd2_qwen_image.yaml`` for the canonical +real-data run (4-step + CFG + GAN); ``configs/dmd2_qwen_image_smoke.yaml`` +for the mock-data wiring smoke (no dataset required). + +Launch:: + + # Mock-data wiring smoke (no real cache required). + torchrun --nproc-per-node=8 \\ + examples/diffusers/fastgen/dmd2_finetune.py \\ + --config examples/diffusers/fastgen/configs/dmd2_qwen_image_smoke.yaml + # Real-data formal training (canonical). + torchrun --nproc-per-node=8 \\ + examples/diffusers/fastgen/dmd2_finetune.py \\ + --config examples/diffusers/fastgen/configs/dmd2_qwen_image.yaml + +See ``examples/diffusers/fastgen/README.md`` for the three-phase +alternation diagram + troubleshooting notes. +""" + +from __future__ import annotations + +import json +import logging +import os +import re +import shutil +from typing import Any + +import torch +import torch.distributed as dist + +# nemo_automodel is required to run this example (installed via requirements.txt). Wrap +# the import in a clear, actionable error, but still re-raise so it fails loudly with a +# real stack — a previous gate that fell back to ``object`` silently masked missing deps +# and surfaced as a downstream ``TypeError: takes no arguments``. +try: + from nemo_automodel._diffusers.auto_diffusion_pipeline import NeMoAutoDiffusionPipeline + from nemo_automodel.recipes.diffusion.train import TrainDiffusionRecipe, is_main_process +except ImportError as exc: + raise ImportError( + "The DMD2 fastgen example requires `nemo_automodel`. Install the example " + "dependencies with:\n" + " pip install -r examples/diffusers/fastgen/requirements.txt" + ) from exc +from torch import nn + +import modelopt.torch.fastgen as mtf +from modelopt.torch.fastgen.config import DMDConfig +from modelopt.torch.fastgen.discriminators import Discriminator_ImageDiT +from modelopt.torch.fastgen.methods.dmd import DMDPipeline +from modelopt.torch.fastgen.plugins import qwen_image as qwen_image_plugin + +# Keys under the ``dmd2:`` YAML block that shadow fields on :class:`DMDConfig`. The +# recipe deep-merges these on top of the loaded built-in recipe so users can tweak DMD2 +# hyperparameters without editing the shared +# ``modelopt_recipes/general/distillation/dmd2_qwen_image.yaml`` file. +_DMD_CONFIG_OVERRIDE_KEYS = frozenset(DMDConfig.model_fields.keys()) + + +def _deep_merge_dicts(base: dict, override: dict) -> dict: + """Recursively merge ``override`` onto ``base`` and return a new dict. + + Nested dicts (e.g. the ``sample_t_cfg`` / ``ema`` sub-configs) are merged key-by-key + rather than replaced wholesale, so a YAML block that overrides a single sub-field + keeps the recipe's other sub-fields instead of silently resetting them to + :class:`DMDConfig` defaults. + """ + merged = dict(base) + for key, value in override.items(): + existing = merged.get(key) + if isinstance(value, dict) and isinstance(existing, dict): + merged[key] = _deep_merge_dicts(existing, value) + else: + merged[key] = value + return merged + + +# Auto-detect substrings (matched case-insensitively against ``model_id``) that map to +# DMDPipeline plugin subclasses. Keep this list small — adding a new entry is only the +# right move when the model has a non-diffusers transformer signature that requires a +# pack/unpack wrapper. Models with the standard ``(hidden_states, timestep, +# encoder_hidden_states)`` signature work with the base :class:`DMDPipeline`. +_PIPELINE_PLUGIN_BY_MODEL_SUBSTR = ( + ("qwen-image", "qwen_image"), + ("qwen_image", "qwen_image"), +) + +_DMD_COMPLETE_MARKER = "dmd2_complete.marker" + + +class DMD2DiffusionRecipe(TrainDiffusionRecipe): + """DMD2 recipe that reuses ``TrainDiffusionRecipe`` for the student path. + + What the superclass gives us (reused unchanged): + + - Student transformer + AdamW optimizer + LR scheduler, loaded via + :class:`NeMoAutoDiffusionPipeline` with FSDP2 sharding. + - ``self.dataloader`` / ``self.sampler`` (swapped to AutoModel's mock dataloader + when ``data.use_mock: true`` — see :meth:`_build_dataloader`). + - ``self.step_scheduler`` (gradient accumulation + checkpoint cadence). + - ``self.checkpointer`` (DCP student weights + optimizer). + - ``self.device`` / ``self.bf16`` / ``self.clip_grad_max_norm`` / etc. + + What this recipe adds: + + - A frozen teacher loaded via a second :meth:`NeMoAutoDiffusionPipeline.from_pretrained` + call with the same ``parallel_scheme`` so it lands with the same FSDP2 sharding + as the student. + - A trainable fake-score transformer loaded the same way (weights identical to the + teacher on step 0). + - A separate AdamW optimizer for the fake-score phase. + - An :class:`mtf.DMDPipeline` driving VSD + DSM + EMA. + - Sidecar checkpoint save / restore for fake-score weights, fake-score optimizer, + EMA shadow, and DMDPipeline iteration counters. + + Classifier-free guidance, the GAN discriminator branch, and real-data training are + configurable via the ``dmd2:`` / ``data:`` YAML blocks — all enabled in the canonical + ``configs/dmd2_qwen_image.yaml`` and off in the mock-data smoke. See + ``examples/diffusers/fastgen/README.md`` for details. + """ + + # ------------------------------------------------------------------ # + # Setup # + # ------------------------------------------------------------------ # + + def setup(self) -> None: + """Build the student via ``super()``, then add teacher / fake_score / DMDPipeline. + + The extras (``_teacher``, ``_fake_score``, ``_fake_score_optimizer``, + ``_dmd_pipeline``, ``_dmd_config``) are assigned through ``self.__dict__[...]`` + to bypass :meth:`BaseRecipe.__setattr__`'s auto-tracking — otherwise they'd be + added to ``__state_tracked`` and clobber the superclass's single-model / + single-optimizer checkpoint loop. + """ + # 1. Run the parent setup. Builds self.model / self.optimizer / self.lr_scheduler / + # self.dataloader / self.step_scheduler / self.checkpointer / etc. The parent's + # trailing call to self.load_checkpoint(self.restore_from) runs BEFORE our + # extras exist, so it only restores the student — that is intentional and safe. + # + # For the mock-data smoke, ``data.dataloader._target_`` in the YAML points at + # ``nemo_automodel.components.datasets.diffusion.build_mock_dataloader`` so the + # parent wires up the mock dataloader for us — no swap needed. + super().setup() + + # 2. Load the frozen teacher. Same from_pretrained path, same parallel_scheme, but + # ``load_for_training=False`` so the transformer comes back in eval mode with + # requires_grad=False. Bypass __setattr__ to stay invisible to the parent's + # __state_tracked loop. + self.__dict__["_teacher"] = self._load_frozen_teacher() + + # 4. Load the trainable fake-score. Third from_pretrained call — weights start + # identical to the teacher (both come from the same HF checkpoint). + self.__dict__["_fake_score"] = self._load_fake_score() + + # 5. Resolve the DMDConfig: load the fastgen built-in recipe, then apply any + # inline overrides under the YAML ``dmd2:`` block. + self.__dict__["_dmd_config"] = self._resolve_dmd_config() + + # 6. Optimizer for the fake-score phase. LR defaults to student LR when + # ``dmd2.fake_score_lr`` isn't set. + self.__dict__["_fake_score_optimizer"] = self._build_fake_score_optimizer() + + # 7. Optional GAN discriminator. Built when ``gan_loss_weight_gen > 0`` so the + # DMDPipeline constructor's assert is satisfied; otherwise ``discriminator=None`` + # and that assert fires if a YAML enables GAN for an unsupported backbone. + self.__dict__["_discriminator"] = self._build_discriminator() + self.__dict__["_discriminator_optimizer"] = self._build_discriminator_optimizer() + if self._discriminator is not None: + self._attach_gan_feature_capture() + + # 8. DMDPipeline. + # + # Dispatch to a plugin subclass when the backbone needs a non-diffusers call + # signature (e.g. Qwen-Image's packed-latents path). Default = base pipeline. + pipeline_cls = self._resolve_pipeline_cls() + pipeline_kwargs = self._resolve_pipeline_kwargs(pipeline_cls) + self.__dict__["_dmd_pipeline"] = pipeline_cls( + student=self.model, + teacher=self._teacher, + fake_score=self._fake_score, + config=self._dmd_config, + discriminator=self._discriminator, + **pipeline_kwargs, + ) + + # 8. Drop the parent's flow_matching_pipeline — we replace the training loop, + # so keeping it around is pure deadweight. The attribute is not tracked by + # ``__state_tracked`` (FlowMatchingPipeline is a plain class), so ``del`` is + # safe. + if hasattr(self, "flow_matching_pipeline"): + del self.flow_matching_pipeline + + # 9. Extend the student-only restore that super().setup() already ran: also + # restore the fake_score / fake_score_optimizer / EMA / DMD state from the + # same checkpoint directory. + self._restore_dmd_extras(getattr(self, "_dmd2_resolved_restore_from", self.restore_from)) + + if is_main_process(): + logging.info("[DMD2] recipe initialized: %s", self._dmd_config_summary()) + logging.info("[DMD2] full configuration:\n%s", self._dmd_full_config_log()) + + # ------------------------------------------------------------------ # + # Training loop # + # ------------------------------------------------------------------ # + + def run_train_validation_loop(self) -> None: + """Three-phase DMD2 alternation driven by ``step_scheduler``. + + Each outer iteration picks either the student or fake-score phase based on + ``global_step % student_update_freq``. The student phase runs + ``compute_student_loss`` + ``update_ema``. The fake-score phase runs + ``compute_fake_score_loss`` and, when a discriminator is configured + (``gan_loss_weight_gen > 0``), ``compute_discriminator_loss``. + + Mirrors the gating in ``FastGen/fastgen/methods/distribution_matching/dmd2.py`` + (``_student_update_step`` / ``_fake_score_discriminator_update_step``). + """ + dmd = self._dmd_pipeline + cfg = self._dmd_config + + logging.info( + "[DMD2] Starting DMD2 training on %s (pipeline=%s)", + self.model_id, + type(self._dmd_pipeline).__name__, + ) + # Dataloader target (mock vs real cache) is non-obvious from the per-step + # logs; surface it explicitly here so §16's "mock or real dataloader + # target" bullet is checkable from the startup log. + try: + dl_target = type(self.dataloader.dataset).__name__ + logging.info("[DMD2] Dataloader dataset class: %s", dl_target) + except Exception: + pass + logging.info( + "[DMD2] Global batch size: %s; local batch size: %s; DP size: %s", + self.global_batch_size, + self.local_batch_size, + self.dp_size, + ) + logging.info( + "[DMD2] student_update_freq=%d; fake_score_pred_type=%s; guidance_scale=%s;" + " gan_loss_weight_gen=%s", + cfg.student_update_freq, + cfg.fake_score_pred_type, + cfg.guidance_scale, + cfg.gan_loss_weight_gen, + ) + + global_step = int(self.step_scheduler.step) + + for epoch in self.step_scheduler.epochs: + if self.sampler is not None and hasattr(self.sampler, "set_epoch"): + self.sampler.set_epoch(epoch) + + # On resume, the diffusion sampler's load_state_dict primes + # ``_batches_to_skip`` so the next ``__iter__`` skips already-yielded + # batches. Forward that to tqdm's ``initial=`` so the progress bar + # reads e.g. ``187/313`` instead of the misleading ``0/313`` (the + # sampler resets the counter to 0 on the next ``__iter__`` call, + # so reading it here is a one-shot for the resumed epoch only). + tqdm_initial = int(getattr(self.sampler, "_batches_to_skip", 0) or 0) + + if is_main_process(): + from tqdm import tqdm + + self.step_scheduler.dataloader = tqdm( + self.dataloader, + desc=f"Epoch {epoch + 1}/{self.num_epochs}", + initial=tqdm_initial, + ) + else: + self.step_scheduler.dataloader = self.dataloader + + epoch_student_loss = 0.0 + epoch_fake_score_loss = 0.0 + student_steps = 0 + fake_score_steps = 0 + + for batch_group in self.step_scheduler: + is_student_phase = (global_step % cfg.student_update_freq) == 0 + + if is_student_phase: + self.optimizer.zero_grad(set_to_none=True) + else: + self._fake_score_optimizer.zero_grad(set_to_none=True) + + self._set_grad_requirements(is_student_phase) + + micro_losses: list[float] = [] + micro_vsd_losses: list[float] = [] + micro_disc_losses: list[float] = [] + for micro_batch in batch_group: + ( + latents, + noise, + text_embeds, + text_mask, + neg_text_embeds, + neg_text_mask, + ) = self._prepare_micro_batch(micro_batch) + + if is_student_phase: + # ``compute_student_loss`` reads ``guidance_scale`` from the + # DMDConfig when this kwarg is None. We pass the negative + # embedding unconditionally — the function ignores it when + # CFG is disabled, and raises a clear ValueError when CFG + # is enabled but no negative was supplied. + losses = dmd.compute_student_loss( + latents, + noise, + encoder_hidden_states=text_embeds, + encoder_hidden_states_mask=text_mask, + negative_encoder_hidden_states=neg_text_embeds, + negative_encoder_hidden_states_mask=neg_text_mask, + guidance_scale=None, + ) + micro_vsd_losses.append(float(losses["vsd"].item())) + else: + losses = dmd.compute_fake_score_loss( + latents, + noise, + encoder_hidden_states=text_embeds, + encoder_hidden_states_mask=text_mask, + ) + + (losses["total"] / len(batch_group)).backward() + micro_losses.append(float(losses["total"].item())) + + # GAN: in the fake-score phase, also update the discriminator + # on the same batch (FastGen pattern: + # _fake_score_discriminator_update_step). + if ( + not is_student_phase + and self._discriminator is not None + and self._discriminator_optimizer is not None + ): + self._discriminator_optimizer.zero_grad(set_to_none=True) + disc_losses = dmd.compute_discriminator_loss( + latents, + noise, + encoder_hidden_states=text_embeds, + encoder_hidden_states_mask=text_mask, + ) + (disc_losses["total"] / len(batch_group)).backward() + # Manual gradient all-reduce across DP ranks (the + # discriminator is replicated, not FSDP-sharded). + if dist.is_initialized(): + for p in self._discriminator.parameters(): + if p.grad is not None: + dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) + torch.nn.utils.clip_grad_norm_( + self._discriminator.parameters(), + max_norm=self.clip_grad_max_norm, + ) + self._discriminator_optimizer.step() + micro_disc_losses.append(float(disc_losses["total"].item())) + + # Grad clip on whichever module is the active trainable. + if is_student_phase: + grad_norm = torch.nn.utils.clip_grad_norm_( + self.model.parameters(), max_norm=self.clip_grad_max_norm + ) + else: + grad_norm = torch.nn.utils.clip_grad_norm_( + self._fake_score.parameters(), max_norm=self.clip_grad_max_norm + ) + grad_norm = float(grad_norm) if torch.is_tensor(grad_norm) else grad_norm + + # Step. + if is_student_phase: + self.optimizer.step() + dmd.update_ema() + if self.lr_scheduler is not None: + self.lr_scheduler[0].step(1) + else: + self._fake_score_optimizer.step() + + group_loss_mean = float(sum(micro_losses) / len(micro_losses)) + if is_student_phase: + epoch_student_loss += group_loss_mean + student_steps += 1 + else: + epoch_fake_score_loss += group_loss_mean + fake_score_steps += 1 + + global_step = int(self.step_scheduler.step) + + if ( + self.log_every + and self.log_every > 0 + and is_main_process() + and (global_step % self.log_every == 0) + ): + self._log_step( + global_step=global_step, + is_student_phase=is_student_phase, + group_loss=group_loss_mean, + grad_norm=grad_norm, + vsd_loss=(sum(micro_vsd_losses) / len(micro_vsd_losses)) + if micro_vsd_losses + else None, + disc_loss=(sum(micro_disc_losses) / len(micro_disc_losses)) + if micro_disc_losses + else None, + ) + + if self.step_scheduler.is_ckpt_step: + # Use the group mean of the active phase as the reported train loss. + self.save_checkpoint(epoch, global_step, group_loss_mean) + + # End-of-epoch logging. + if is_main_process(): + avg_student = ( + (epoch_student_loss / student_steps) if student_steps else float("nan") + ) + avg_fake = ( + epoch_fake_score_loss / fake_score_steps if fake_score_steps else float("nan") + ) + logging.info( + "[DMD2] Epoch %d complete. student_avg=%.6f (%d steps) " + "fake_score_avg=%.6f (%d steps)", + epoch + 1, + avg_student, + student_steps, + avg_fake, + fake_score_steps, + ) + + if torch.cuda.is_available(): + peak = torch.cuda.max_memory_allocated() / (1024**3) + reserved = torch.cuda.max_memory_reserved() / (1024**3) + rank = dist.get_rank() if dist.is_initialized() else 0 + logging.info( + "[DMD2] PEAK_MEM rank=%d max_allocated=%.2fGiB max_reserved=%.2fGiB", + rank, + peak, + reserved, + ) + + if is_main_process(): + logging.info("[DMD2] Training complete. Final step: %s", global_step) + + # ------------------------------------------------------------------ # + # Checkpoint save / restore (sidecars next to student DCP) # + # ------------------------------------------------------------------ # + + def load_checkpoint(self, restore_from: str | None = None): + """Load only from checkpoints whose DMD2 sidecars are complete. + + ``TrainDiffusionRecipe.setup()`` calls this before the DMD2-only objects exist, + so this method only resolves the path and delegates the student restore to the + parent. The sidecars are restored later by ``_restore_dmd_extras``. + """ + resolved = self._resolve_complete_dmd_checkpoint(restore_from) + self.__dict__["_dmd2_resolved_restore_from"] = resolved + + if resolved is None: + if ( + restore_from is not None + and str(restore_from).upper() == "LATEST" + and is_main_process() + ): + logging.warning( + "[DMD2] restore_from=LATEST but no complete DMD2 checkpoint was found in %s. " + "Starting fresh.", + self.checkpointer.config.checkpoint_dir, + ) + return + + super().load_checkpoint(resolved) + + def save_checkpoint( + self, + epoch: int, + step: int, + train_loss: float, + val_loss: dict[str, float] | None = None, + best_metric_key: str = "default", + ) -> None: + """Delegate student save to ``super()``, then sidecar the DMD2 extras.""" + # Recover from a partial save from a previous run (e.g. SLURM time + # limit killed the job between super().save_checkpoint() — which writes + # the model + step_scheduler + dataloader — and our DMD2 sidecar + # writes below). The parent's save_checkpoint refuses to overwrite an + # existing directory and raises FileExistsError, so without this we'd + # need a manual cleanup every time a SLURM kill landed mid-save. + path = os.path.join(self.checkpointer.config.checkpoint_dir, f"epoch_{epoch}_step_{step}") + if is_main_process() and self.checkpointer.config.enabled and os.path.exists(path): + if not self._is_dmd_checkpoint_complete(path): + logging.warning( + "[DMD2] cleaning up incomplete checkpoint directory left by a previous run: %s", + path, + ) + shutil.rmtree(path) + if dist.is_initialized(): + dist.barrier() + + previous_complete = None + if self.checkpointer.config.enabled: + previous_complete = self._find_latest_complete_dmd_checkpoint( + self.checkpointer.config.checkpoint_dir + ) + + super().save_checkpoint(epoch, step, train_loss, val_loss, best_metric_key) + + if not self.checkpointer.config.enabled: + return + + # The parent save updates LATEST before DMD2 sidecars exist. Until the marker is + # written below, keep LATEST on the previous complete DMD2 checkpoint. + if is_main_process(): + if previous_complete is not None: + self._update_latest_symlink(previous_complete) + else: + self._remove_checkpoint_pointer("LATEST") + if dist.is_initialized(): + dist.barrier() + + self._save_dmd_extras(path) + + if dist.is_initialized(): + dist.barrier() + + if is_main_process(): + self._write_dmd_complete_marker(path) + self._update_checkpoint_symlink("DMD2_LATEST", path) + self._update_latest_symlink(path) + if dist.is_initialized(): + dist.barrier() + + def _save_dmd_extras(self, path: str) -> None: + """Write fake_score DCP + fake_score_optimizer DCP + ema_shadow.pt + dmd_state.pt.""" + # fake_score weights — DCP sharded save via the same Checkpointer the parent uses + # for the student. Each rank writes its own shard. + fs_weights_dir = os.path.join(path, "fake_score") + os.makedirs(fs_weights_dir, exist_ok=True) + self.checkpointer.save_model( + model=self._fake_score, + weights_path=fs_weights_dir, + peft_config=None, + tokenizer=None, + ) + # fake_score optimizer — also DCP sharded. ``save_optimizer`` takes the optimizer + # and its owning model in order to rebuild the parameter mapping. + fs_opt_dir = os.path.join(path, "fake_score_optimizer") + os.makedirs(fs_opt_dir, exist_ok=True) + self.checkpointer.save_optimizer( + self._fake_score_optimizer, self._fake_score, fs_opt_dir, None + ) + + # EMA shadow + DMD scalar state — rank-0 torch.save. EMA's ``state_dict`` already + # materialises full tensors via ``DTensor.full_tensor()`` under FSDP2 full_tensor + # mode, so this is a single unsharded file. + if is_main_process(): + logging.info("[DMD2] saved fake_score weights -> %s", fs_weights_dir) + logging.info("[DMD2] saved fake_score optimizer -> %s", fs_opt_dir) + if self._dmd_pipeline.ema is not None: + ema_path = os.path.join(path, "ema_shadow.pt") + torch.save(self._dmd_pipeline.ema.state_dict(), ema_path) + logging.info("[DMD2] saved ema_shadow -> %s", ema_path) + state_path = os.path.join(path, "dmd_state.pt") + torch.save({"iteration": self._dmd_pipeline._iteration}, state_path) + logging.info( + "[DMD2] saved dmd_state (iteration=%d) -> %s", + int(self._dmd_pipeline._iteration), + state_path, + ) + # Discriminator + its optimizer — replicated across ranks (no FSDP), + # so rank-0 torch.save of the canonical state_dict suffices. + if self._discriminator is not None: + disc_path = os.path.join(path, "discriminator.pt") + torch.save(self._discriminator.state_dict(), disc_path) + logging.info("[DMD2] saved discriminator -> %s", disc_path) + if self._discriminator_optimizer is not None: + disc_opt_path = os.path.join(path, "discriminator_optimizer.pt") + torch.save(self._discriminator_optimizer.state_dict(), disc_opt_path) + logging.info("[DMD2] saved discriminator optimizer -> %s", disc_opt_path) + + def _write_dmd_complete_marker(self, path: str) -> None: + marker_path = os.path.join(path, _DMD_COMPLETE_MARKER) + payload = { + "checkpoint": os.path.basename(os.path.realpath(path)), + "dmd_iteration": int(self._dmd_pipeline._iteration), + } + with open(marker_path, "w") as f: + json.dump(payload, f) + f.write("\n") + logging.info("[DMD2] marked checkpoint complete -> %s", marker_path) + + def _remove_checkpoint_pointer(self, link_name: str) -> None: + ckpt_root = self.checkpointer.config.checkpoint_dir + for path in ( + os.path.join(ckpt_root, link_name), + os.path.join(ckpt_root, f"{link_name}.txt"), + ): + if os.path.lexists(path): + os.remove(path) + + def _restore_dmd_extras(self, restore_from: str | None) -> None: + """Restore fake_score + fake_score optimizer + EMA + DMD scalar state. + + No-op when no checkpoint is being restored. ``load_checkpoint`` resolves + ``LATEST`` to the latest complete DMD2 checkpoint before this method runs. + """ + if restore_from is None: + return + + ckpt_dir = self._resolve_extras_dir(restore_from) + if ckpt_dir is None or not os.path.isdir(ckpt_dir): + return + + fs_weights_dir = os.path.join(ckpt_dir, "fake_score") + fs_opt_dir = os.path.join(ckpt_dir, "fake_score_optimizer") + ema_path = os.path.join(ckpt_dir, "ema_shadow.pt") + state_path = os.path.join(ckpt_dir, "dmd_state.pt") + + # Checkpointer.save_model writes DCP shards to ``/model/``; + # load_model expects that *inner* ``model/`` dir as ``model_path`` (see + # ``BaseRecipe.load_checkpoint`` which passes ``os.path.join(ckpt_dir, "model")``). + # The kwarg name differs between save (``weights_path``) and load (``model_path``). + fs_weights_model_dir = os.path.join(fs_weights_dir, "model") + if os.path.isdir(fs_weights_model_dir): + self.checkpointer.load_model(model=self._fake_score, model_path=fs_weights_model_dir) + if is_main_process(): + logging.info("[DMD2] restored fake_score weights <- %s", fs_weights_model_dir) + elif is_main_process(): + logging.info( + "[DMD2] WARN: fake_score weights dir missing at %s -- skipping", + fs_weights_model_dir, + ) + # load_optimizer, in contrast, appends ``optim/`` internally — pass the base dir. + if os.path.isdir(os.path.join(fs_opt_dir, "optim")): + self.checkpointer.load_optimizer( + self._fake_score_optimizer, self._fake_score, fs_opt_dir, None + ) + if is_main_process(): + logging.info("[DMD2] restored fake_score optimizer <- %s", fs_opt_dir) + elif is_main_process(): + logging.info( + "[DMD2] WARN: fake_score optimizer dir missing at %s -- skipping", + fs_opt_dir, + ) + + if os.path.isfile(ema_path) and self._dmd_pipeline.ema is not None: + ema_state = torch.load(ema_path, map_location="cpu", weights_only=False) + self._dmd_pipeline.ema.load_state_dict(ema_state) + if is_main_process(): + logging.info("[DMD2] restored ema_shadow <- %s", ema_path) + if os.path.isfile(state_path): + state = torch.load(state_path, map_location="cpu", weights_only=False) + self._dmd_pipeline._iteration = int(state.get("iteration", 0)) + if is_main_process(): + logging.info( + "[DMD2] restored dmd_state (iteration=%d) <- %s", + self._dmd_pipeline._iteration, + state_path, + ) + + # Discriminator + its optimizer. + if self._discriminator is not None: + disc_path = os.path.join(ckpt_dir, "discriminator.pt") + if os.path.isfile(disc_path): + disc_state = torch.load(disc_path, map_location="cpu", weights_only=False) + self._discriminator.load_state_dict(disc_state) + if is_main_process(): + logging.info("[DMD2] restored discriminator <- %s", disc_path) + elif is_main_process(): + logging.info("[DMD2] WARN: discriminator file missing at %s -- skipping", disc_path) + if self._discriminator_optimizer is not None: + disc_opt_path = os.path.join(ckpt_dir, "discriminator_optimizer.pt") + if os.path.isfile(disc_opt_path): + disc_opt_state = torch.load(disc_opt_path, map_location="cpu", weights_only=False) + self._discriminator_optimizer.load_state_dict(disc_opt_state) + if is_main_process(): + logging.info("[DMD2] restored discriminator optimizer <- %s", disc_opt_path) + elif is_main_process(): + logging.info( + "[DMD2] WARN: discriminator optimizer file missing at %s -- skipping", + disc_opt_path, + ) + + def _resolve_extras_dir(self, restore_from: str) -> str | None: + """Best-effort resolve of the checkpoint dir, matching BaseRecipe's convention. + + For explicit paths we pass through; for ``"LATEST"`` we look under + ``checkpointer.config.checkpoint_dir``. This keeps resolution simple and delegates + the hard cases (async symlinks, cross-node shared filesystems) to the user. + """ + if os.path.isabs(restore_from): + return restore_from + # Try the checkpoint_dir-relative form first (matches the parent's symlink + # naming — "LATEST" or an explicit ``epoch_N_step_M`` subdir). + candidate = os.path.join(self.checkpointer.config.checkpoint_dir, restore_from) + if os.path.exists(candidate): + return os.path.realpath(candidate) + return None + + def _resolve_complete_dmd_checkpoint(self, restore_from: str | None) -> str | None: + ckpt_root = self.checkpointer.config.checkpoint_dir + + if restore_from is None or str(restore_from).upper() in {"LATEST", "DMD2_LATEST"}: + return self._find_latest_complete_dmd_checkpoint(ckpt_root) + + if os.path.isabs(restore_from): + candidate = restore_from + else: + candidate = os.path.join(ckpt_root, restore_from) + candidate = os.path.realpath(candidate) + + if not os.path.isdir(candidate): + return candidate + if not self._is_dmd_checkpoint_complete(candidate): + raise RuntimeError( + f"DMD2 checkpoint is incomplete and cannot be restored: {candidate}. " + "Use a complete older checkpoint or remove the partial directory." + ) + return candidate + + def _find_latest_complete_dmd_checkpoint(self, ckpt_root: str) -> str | None: + dmd2_latest = os.path.join(ckpt_root, "DMD2_LATEST") + for pointer in (dmd2_latest, os.path.join(ckpt_root, "LATEST")): + resolved = self._resolve_checkpoint_pointer(pointer) + if resolved is not None and self._is_dmd_checkpoint_complete(resolved): + return resolved + + candidates = [] + if os.path.isdir(ckpt_root): + for name in os.listdir(ckpt_root): + path = os.path.join(ckpt_root, name) + if ( + os.path.isdir(path) + and "_step_" in name + and self._is_dmd_checkpoint_complete(path) + ): + candidates.append(os.path.realpath(path)) + if not candidates: + return None + return max(candidates, key=self._checkpoint_step) + + def _resolve_checkpoint_pointer(self, pointer: str) -> str | None: + resolved = None + if os.path.islink(pointer): + try: + resolved = os.readlink(pointer) + except OSError: + return None + elif os.path.isfile(pointer + ".txt"): + try: + with open(pointer + ".txt") as f: + resolved = f.read().strip() + except OSError: + return None + if not resolved: + return None + if not os.path.isabs(resolved): + resolved = os.path.abspath(os.path.join(os.path.dirname(pointer), resolved)) + return os.path.realpath(resolved) if os.path.isdir(resolved) else None + + def _is_dmd_checkpoint_complete(self, path: str) -> bool: + path = os.path.realpath(path) + if not os.path.isdir(path): + return False + if os.path.isfile(os.path.join(path, _DMD_COMPLETE_MARKER)): + return True + + fs_model_dir = os.path.join(path, "fake_score", "model") + fs_opt_metadata = os.path.join(path, "fake_score_optimizer", "optim", ".metadata") + dmd_state = os.path.join(path, "dmd_state.pt") + complete = ( + self._dir_has_regular_file(fs_model_dir) + and os.path.isfile(fs_opt_metadata) + and os.path.isfile(dmd_state) + ) + if not complete: + return False + + if self._cfg_gan_enabled(): + return os.path.isfile(os.path.join(path, "discriminator.pt")) and os.path.isfile( + os.path.join(path, "discriminator_optimizer.pt") + ) + return True + + def _cfg_gan_enabled(self) -> bool: + cfg = getattr(self, "cfg", None) + if cfg is None: + return False + try: + return float(cfg.get("dmd2.gan_loss_weight_gen", 0.0) or 0.0) > 0 + except (TypeError, ValueError): + return False + + @staticmethod + def _dir_has_regular_file(path: str) -> bool: + if not os.path.isdir(path): + return False + try: + with os.scandir(path) as entries: + return any(entry.is_file() for entry in entries) + except OSError: + return False + + @staticmethod + def _checkpoint_step(path: str) -> int: + match = re.search(r"_step_(\d+)$", os.path.basename(os.path.realpath(path))) + return int(match.group(1)) if match else -1 + + # ------------------------------------------------------------------ # + # Helpers — teacher / fake_score loading, DMDConfig resolution # + # ------------------------------------------------------------------ # + + def _load_frozen_teacher(self) -> nn.Module: + """Load a second copy of the pretrained transformer, frozen + FSDP2-sharded. + + The same pretrained path + ``parallel_scheme`` as the student. Setting + ``load_for_training=False`` walks the parameters once and flips + ``requires_grad=False`` after FSDP2 wrapping; we also call ``.eval()`` on the + returned module just to be defensive. + """ + parallel_scheme = self._build_parallel_scheme_snapshot() + pipe, _ = NeMoAutoDiffusionPipeline.from_pretrained( + self.model_id, + torch_dtype=self.bf16, + device=self.device, + parallel_scheme=parallel_scheme, + components_to_load=["transformer"], + load_for_training=False, + low_cpu_mem_usage=True, + ) + teacher = pipe.transformer + teacher.eval() + for p in teacher.parameters(): + p.requires_grad_(False) + return teacher + + def _build_discriminator(self) -> nn.Module | None: + """Construct the Discriminator_ImageDiT when GAN is enabled. + + Returns ``None`` when ``dmd2.gan_loss_weight_gen`` is zero so the + DMDPipeline runs without a discriminator (any run with the GAN branch disabled). + """ + gan_weight = float(self.cfg.get("dmd2.gan_loss_weight_gen", 0.0) or 0.0) + if gan_weight <= 0.0: + return None + + # GAN-specific knobs read directly from the YAML so callers don't have to + # touch the built-in DMDConfig recipe just to flip feature indices. + feature_indices = self.cfg.get( + "dmd2.gan_feature_indices", [30] + ) # middle of Qwen-Image's 60 blocks + num_blocks = int(self.cfg.get("dmd2.gan_num_blocks", 60)) + inner_dim = int(self.cfg.get("dmd2.gan_inner_dim", 3072)) + + disc = Discriminator_ImageDiT( + feature_indices={int(i) for i in feature_indices}, + num_blocks=num_blocks, + inner_dim=inner_dim, + ) + disc.to(device=self.device, dtype=self.bf16) + disc.train() + for p in disc.parameters(): + p.requires_grad_(True) + if is_main_process(): + logging.info( + "[DMD2] Built discriminator: %s | num_features=%d num_blocks=%d inner_dim=%d " + "params=%d", + type(disc).__name__, + disc.num_features, + num_blocks, + inner_dim, + sum(p.numel() for p in disc.parameters()), + ) + return disc + + def _build_discriminator_optimizer(self) -> torch.optim.Optimizer | None: + """AdamW on the discriminator. No FSDP wrap — manual grad all-reduce keeps it simple.""" + if self._discriminator is None: + return None + lr = float(self.cfg.get("dmd2.discriminator_lr", 1.0e-5) or 1.0e-5) + opt = torch.optim.AdamW( + self._discriminator.parameters(), + lr=lr, + weight_decay=0.01, + betas=(0.9, 0.999), # FastGen Qwen-Image DMD2 inherits BaseOptimizerConfig betas + ) + if is_main_process(): + logging.info("[DMD2] Built discriminator optimizer: AdamW lr=%g betas=(0.9, 0.999)", lr) + return opt + + def _attach_gan_feature_capture(self) -> None: + """Install Qwen-Image feature-capture hooks on the teacher when GAN is enabled. + + Reads the latent resolution from the dataloader so the hook can reshape + ``[B, num_image_patches, 3072]`` into ``[B, 3072, H_lat//2, W_lat//2]``. + Mock dataloader → spatial_h/spatial_w from the YAML. Real dataloader → + base_resolution / vae_scale. + """ + feature_indices = list(self.cfg.get("dmd2.gan_feature_indices", [30])) + + # Resolve h_lat / w_lat. Mock has it in the YAML; real cache uses the + # configured base_resolution divided by the VAE 8x downsample. + # Convert the dataloader subtree to a plain dict — AutoModel's ConfigNode + # doesn't expose deep dotted paths like ``data.dataloader.spatial_h``. + dl_node = self.cfg.get("data.dataloader", None) + if dl_node is not None and hasattr(dl_node, "to_dict"): + dl_dict = dl_node.to_dict() + elif dl_node is not None: + try: + dl_dict = dict(dl_node) + except (TypeError, ValueError): + dl_dict = {} + else: + dl_dict = {} + + spatial_h = dl_dict.get("spatial_h") + spatial_w = dl_dict.get("spatial_w") + base_resolution = dl_dict.get("base_resolution") + if spatial_h is not None and spatial_w is not None: + h_lat = int(spatial_h) + w_lat = int(spatial_w) + elif base_resolution is not None: + h_lat = int(base_resolution[0]) // 8 + w_lat = int(base_resolution[1]) // 8 + else: + # Fallback: hope it's 64x64 (512px image). Smoke tests pin this explicitly. + h_lat, w_lat = 64, 64 + if is_main_process(): + logging.warning( + "[DMD2] Could not infer h_lat/w_lat from data.dataloader; defaulting to 64x64." + ) + + qwen_image_plugin.attach_feature_capture( + self._teacher, + feature_indices=feature_indices, + h_lat=h_lat, + w_lat=w_lat, + ) + if is_main_process(): + logging.info( + "[DMD2] Attached GAN feature capture: indices=%s h_lat=%d w_lat=%d", + feature_indices, + h_lat, + w_lat, + ) + + def _load_fake_score(self) -> nn.Module: + """Load a third copy, trainable. Weights start identical to the teacher.""" + parallel_scheme = self._build_parallel_scheme_snapshot() + pipe, _ = NeMoAutoDiffusionPipeline.from_pretrained( + self.model_id, + torch_dtype=self.bf16, + device=self.device, + parallel_scheme=parallel_scheme, + components_to_load=["transformer"], + load_for_training=True, + low_cpu_mem_usage=True, + ) + fake_score = pipe.transformer + fake_score.train() + for p in fake_score.parameters(): + p.requires_grad_(True) + return fake_score + + def _build_parallel_scheme_snapshot(self) -> dict[str, dict[str, Any]]: + """Reconstruct the FSDP2 manager_args used for the student. + + Mirrors ``build_model_and_optimizer`` in ``nemo_automodel.recipes.diffusion.train``. + We can't capture the student's ``parallel_scheme`` directly (the parent doesn't + stash it), so we rebuild it from the same YAML knobs the parent consumed. + """ + from torch.distributed.fsdp import MixedPrecisionPolicy + + fsdp_cfg = self.cfg.get("fsdp", None) or {} + ddp_cfg = self.cfg.get("ddp", None) + + world_size = dist.get_world_size() if dist.is_initialized() else 1 + + if ddp_cfg is not None: + return { + "transformer": { + "_manager_type": "ddp", + "backend": ddp_cfg.get("backend", "nccl"), + "world_size": world_size, + "activation_checkpointing": ddp_cfg.get("activation_checkpointing", False), + } + } + + dp_size = fsdp_cfg.get("dp_size") + tp_size = fsdp_cfg.get("tp_size", 1) + cp_size = fsdp_cfg.get("cp_size", 1) + pp_size = fsdp_cfg.get("pp_size", 1) + if dp_size is None: + denom = max(1, tp_size * cp_size * pp_size) + dp_size = max(1, world_size // denom) + + return { + "transformer": { + "_manager_type": "fsdp2", + "dp_size": dp_size, + "dp_replicate_size": fsdp_cfg.get("dp_replicate_size", None), + "tp_size": tp_size, + "cp_size": cp_size, + "pp_size": pp_size, + "backend": "nccl", + "world_size": world_size, + "use_hf_tp_plan": fsdp_cfg.get("use_hf_tp_plan", False), + "activation_checkpointing": fsdp_cfg.get("activation_checkpointing", True), + "mp_policy": MixedPrecisionPolicy( + param_dtype=self.bf16, + reduce_dtype=torch.float32, + output_dtype=self.bf16, + ), + } + } + + def _resolve_pipeline_cls(self) -> type[DMDPipeline]: + """Pick the DMDPipeline subclass for the current backbone. + + Resolution order: + + 1. ``dmd2.pipeline_plugin`` in the YAML (explicit override, ``null`` for base). + 2. Substring match on ``model.pretrained_model_name_or_path`` + (e.g. ``Qwen-Image`` -> ``qwen_image`` plugin). + 3. Fall back to :class:`DMDPipeline`. + """ + explicit = self.cfg.get("dmd2.pipeline_plugin", None) + if explicit is None: + model_id_lc = (self.model_id or "").lower() + for needle, plugin_name in _PIPELINE_PLUGIN_BY_MODEL_SUBSTR: + if needle in model_id_lc: + explicit = plugin_name + break + if explicit in (None, "base", "DMDPipeline"): + return DMDPipeline + if explicit == "qwen_image": + # Imported lazily so ``base`` users don't pay the import cost. + from modelopt.torch.fastgen.plugins.qwen_image import QwenImageDMDPipeline + + return QwenImageDMDPipeline + raise ValueError( + f"Unknown dmd2.pipeline_plugin={explicit!r}. Supported: null/'base', 'qwen_image'." + ) + + def _resolve_pipeline_kwargs(self, pipeline_cls: type[DMDPipeline]) -> dict[str, Any]: + """Extra kwargs to forward to the pipeline subclass constructor (plugin-specific).""" + if pipeline_cls.__name__ == "QwenImageDMDPipeline": + # Optional ``guidance`` value passed to the transformer's guidance kwarg every + # call. Independent of DMDConfig.guidance_scale (which drives the negative- + # prompt CFG path on the teacher). Leave ``None`` to skip the embedding when + # the transformer was built with ``guidance_embeds=false`` (default for + # ``Qwen/Qwen-Image``). + return {"guidance": self.cfg.get("dmd2.qwen_image_guidance", None)} + return {} + + def _resolve_dmd_config(self) -> DMDConfig: + """Load the built-in fastgen recipe, then apply inline YAML overrides.""" + dmd_cfg_node = self.cfg.get("dmd2", None) + if dmd_cfg_node is None: + raise ValueError( + "Missing ``dmd2:`` block in the YAML config. Expected at minimum " + "``dmd2.recipe_path`` pointing at a fastgen DMDConfig recipe " + "(e.g. ``general/distillation/dmd2_qwen_image``)." + ) + dmd_dict = ( + dmd_cfg_node.to_dict() if hasattr(dmd_cfg_node, "to_dict") else dict(dmd_cfg_node) + ) + + recipe_path = dmd_dict.pop("recipe_path", None) + if recipe_path is None: + raise ValueError( + "``dmd2.recipe_path`` is required — Phase 1 relies on the built-in " + "``modelopt_recipes`` path resolver to hydrate the full DMDConfig." + ) + base_config = mtf.load_dmd_config(recipe_path) + + # Filter overrides to the subset that actually corresponds to DMDConfig fields. + # Non-matching keys (e.g. ``fake_score_lr``, ``cfg_mode``) are kept as top-level + # recipe knobs and read via ``self.cfg.get("dmd2.")``. + overrides = {k: v for k, v in dmd_dict.items() if k in _DMD_CONFIG_OVERRIDE_KEYS} + if not overrides: + return base_config + # Deep-merge so a YAML block that overrides a single ``sample_t_cfg`` / ``ema`` + # sub-field keeps the recipe's other sub-fields — a shallow ``dict.update`` would + # replace the whole sub-config and silently reset its siblings to defaults. + # Re-validate the merged dict so the nested blocks become their Pydantic config + # objects instead of raw dicts. + merged = _deep_merge_dicts(base_config.model_dump(), overrides) + return DMDConfig.model_validate(merged) + + def _build_fake_score_optimizer(self) -> torch.optim.Optimizer: + """AdamW on fake_score params. LR defaults to student LR; overridable via YAML.""" + fs_lr = self.cfg.get("dmd2.fake_score_lr", None) + if fs_lr is None: + fs_lr = self.learning_rate + optimizer_cfg = self.cfg.get("optim.optimizer", {}) or {} + optimizer_cfg = ( + optimizer_cfg.to_dict() if hasattr(optimizer_cfg, "to_dict") else dict(optimizer_cfg) + ) + weight_decay = optimizer_cfg.get("weight_decay", 0.01) + betas = tuple(optimizer_cfg.get("betas", (0.9, 0.999))) + + trainable_params = [p for p in self._fake_score.parameters() if p.requires_grad] + if not trainable_params: + raise RuntimeError("No trainable parameters found in fake_score.") + return torch.optim.AdamW(trainable_params, lr=fs_lr, weight_decay=weight_decay, betas=betas) + + # ------------------------------------------------------------------ # + # Inner helpers # + # ------------------------------------------------------------------ # + + def _set_grad_requirements(self, is_student_phase: bool) -> None: + """Toggle train/eval + requires_grad across modules for the active phase. + + Mirrors FastGen's ``_setup_grad_requirements`` (``dmd2.py`` lines 67-77), + INCLUDING the discriminator toggle that was previously omitted. + + Why the discriminator toggle matters: ``compute_student_loss`` calls + ``self.discriminator(fake_feat)`` for the ``gan_gen`` term, so the + discriminator is in the student-phase backward graph. With its params + left at ``requires_grad=True``, ``total.backward()`` allocates and + fills ``.grad`` for every discriminator parameter — gradients which + the student optimizer never consumes and which the next discriminator + ``zero_grad(set_to_none=True)`` simply wipes. Freezing the discriminator + during the student phase skips that wasted memory + backward compute + without changing any numerics (the student still receives the GAN + signal through the discriminator's input-side gradient, which doesn't + require the discriminator's own params to be in the autograd graph). + + Called every step; cheap enough that we don't bother caching the last state. + """ + if is_student_phase: + self.model.train() + for p in self.model.parameters(): + p.requires_grad_(True) + self._fake_score.eval() + for p in self._fake_score.parameters(): + p.requires_grad_(False) + if self._discriminator is not None: + self._discriminator.eval() + for p in self._discriminator.parameters(): + p.requires_grad_(False) + else: + self.model.eval() + for p in self.model.parameters(): + p.requires_grad_(False) + self._fake_score.train() + for p in self._fake_score.parameters(): + p.requires_grad_(True) + if self._discriminator is not None: + self._discriminator.train() + for p in self._discriminator.parameters(): + p.requires_grad_(True) + + def _prepare_micro_batch( + self, micro_batch: dict[str, Any] + ) -> tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor | None, + torch.Tensor | None, + torch.Tensor | None, + ]: + """Extract latents, noise, text conditioning, and optional masks from a batch. + + Accepts both 5D ``video_latents`` and 4D ``image_latents`` + (Qwen-Image / Flux / SD3). Mirrors the key dispatch in + ``nemo_automodel.components.flow_matching.pipeline.FlowMatchingPipeline.step``. + + ``negative_text_embeddings`` is optional — present when the dataloader + supplies it (mock T2I, real cache with precomputed empty-prompt + embedding) and consumed by ``compute_student_loss`` only when CFG is + enabled (``dmd2.guidance_scale is not None``). + """ + if "image_latents" in micro_batch: + latents = micro_batch["image_latents"].to(self.device, dtype=self.bf16) + elif "video_latents" in micro_batch: + latents = micro_batch["video_latents"].to(self.device, dtype=self.bf16) + else: + raise KeyError( + "Batch must contain either 'image_latents' (4D) or 'video_latents' (5D). " + f"Got keys: {sorted(micro_batch.keys())}." + ) + text_embeds = micro_batch["text_embeddings"].to(self.device, dtype=self.bf16) + if text_embeds.ndim == 2: + text_embeds = text_embeds.unsqueeze(0) + text_mask = micro_batch.get("text_embeddings_mask") + if text_mask is not None: + text_mask = text_mask.to(self.device) + if text_mask.ndim == 1: + text_mask = text_mask.unsqueeze(0) + negative_text_embeds = micro_batch.get("negative_text_embeddings") + if negative_text_embeds is not None: + negative_text_embeds = negative_text_embeds.to(self.device, dtype=self.bf16) + if negative_text_embeds.ndim == 2: + negative_text_embeds = negative_text_embeds.unsqueeze(0) + negative_text_mask = micro_batch.get("negative_text_embeddings_mask") + if negative_text_mask is not None: + negative_text_mask = negative_text_mask.to(self.device) + if negative_text_mask.ndim == 1: + negative_text_mask = negative_text_mask.unsqueeze(0) + # Fresh noise per micro-batch — DMD2 samples noise independently at each loss call. + noise = torch.randn_like(latents) + return latents, noise, text_embeds, text_mask, negative_text_embeds, negative_text_mask + + def _log_step( + self, + *, + global_step: int, + is_student_phase: bool, + group_loss: float, + grad_norm: float, + vsd_loss: float | None, + disc_loss: float | None = None, + ) -> None: + """Log a single step. Stdout always; wandb when the parent set it up.""" + phase = "student" if is_student_phase else "fake_score" + + # Stdout + suffix = f" vsd={vsd_loss:.4f}" if vsd_loss is not None else "" + if disc_loss is not None: + suffix += f" disc={disc_loss:.4f}" + logging.info( + "[STEP %d] phase=%s loss=%.4f grad_norm=%.4f%s lr=%.2e", + global_step, + phase, + group_loss, + grad_norm, + suffix, + self.optimizer.param_groups[0]["lr"], + ) + + # wandb + try: + import wandb + + if wandb.run is not None: + log_dict: dict[str, Any] = { + f"{phase}/loss": group_loss, + f"{phase}/grad_norm": grad_norm, + "global_step": global_step, + "lr_student": self.optimizer.param_groups[0]["lr"], + "lr_fake_score": self._fake_score_optimizer.param_groups[0]["lr"], + } + if vsd_loss is not None: + log_dict["student/vsd"] = vsd_loss + if disc_loss is not None: + log_dict["discriminator/loss"] = disc_loss + wandb.log(log_dict, step=global_step) + except Exception: + # wandb not installed or not initialised — silent no-op. + pass + + def _dmd_config_summary(self) -> str: + """Compact one-line summary of the active DMDConfig for startup logging.""" + cfg = self._dmd_config + t_list = cfg.sample_t_cfg.t_list if cfg.sample_t_cfg is not None else None + return ( + f"pred_type={cfg.pred_type} fake_score_pred_type={cfg.fake_score_pred_type} " + f"num_train_timesteps={cfg.num_train_timesteps} " + f"student_update_freq={cfg.student_update_freq} " + f"student_sample_steps={cfg.student_sample_steps} " + f"student_sample_type={cfg.student_sample_type} " + f"backward_simulation={cfg.backward_simulation} " + f"t_list={t_list} " + f"gan_loss_weight_gen={cfg.gan_loss_weight_gen} " + f"guidance_scale={cfg.guidance_scale} ema={'on' if cfg.ema is not None else 'off'}" + ) + + def _dmd_full_config_log(self) -> str: + """Full multi-line dump of every DMD2 parameter for startup tracing. + + Two sections: the resolved DMDConfig (every Pydantic field, including + nested ``sample_t_cfg`` and ``ema`` blocks) and the recipe-side keys + under ``dmd2:`` that aren't DMDConfig fields (e.g. ``fake_score_lr``, + ``gan_feature_indices``, ``pipeline_plugin``). Combined they cover + every knob that ends up driving the DMD2 method at runtime. + """ + cfg = self._dmd_config + dmd_node = self.cfg.get("dmd2", {}) or {} + if hasattr(dmd_node, "to_dict"): + dmd_node = dmd_node.to_dict() + else: + dmd_node = dict(dmd_node) + recipe_extras = { + k: v + for k, v in dmd_node.items() + if k not in _DMD_CONFIG_OVERRIDE_KEYS and k != "recipe_path" + } + combined = { + "DMDConfig_resolved": cfg.model_dump(), + "recipe_side_extras": recipe_extras, + } + return json.dumps(combined, indent=2, default=str) diff --git a/examples/diffusers/fastgen/export_diffusers_qwen_image.py b/examples/diffusers/fastgen/export_diffusers_qwen_image.py new file mode 100644 index 00000000000..65a17c9c3c0 --- /dev/null +++ b/examples/diffusers/fastgen/export_diffusers_qwen_image.py @@ -0,0 +1,202 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Export a DMD2 student as a full diffusers-loadable QwenImagePipeline dir. + +The §10 safetensors addendum produces a transformer-only dir: + + epoch_0_step_N/model/consolidated/ + config.json + diffusion_pytorch_model.safetensors.index.json + model-00001-of-00001.safetensors + +That dir is loadable by ``QwenImageTransformer2DModel.from_pretrained`` but NOT +by ``QwenImagePipeline.from_pretrained`` or ``DiffusionPipeline.from_pretrained`` +because it lacks ``model_index.json`` and the sibling component dirs +(``vae/``, ``text_encoder/``, ``tokenizer/``, ``scheduler/``). + +This utility assembles a full pipeline dir by: + + / + model_index.json (copied from base) + transformer/ (symlinked or copied from the consolidated student) + vae/ (symlinked from base Qwen-Image) + text_encoder/ (symlinked from base Qwen-Image) + tokenizer/ (symlinked from base Qwen-Image) + scheduler/ (symlinked from base Qwen-Image) + +After this runs, the dir loads with ``QwenImagePipeline.from_pretrained(output_dir)`` +or ``DiffusionPipeline.from_pretrained(output_dir)``. + +Symlinks are the default — the base Qwen-Image components are huge (text encoder +alone is ~12 GB) and never change between DMD2 students, so copying them per +checkpoint wastes disk. Use ``--copy`` if the output dir must be portable. + +Usage:: + + python export_diffusers_qwen_image.py \\ + --student_path /path/to/checkpoint/epoch_0_step_500/model/consolidated \\ + --base_pipeline_path Qwen/Qwen-Image \\ + --output_dir /path/to/output/qwen_image_dmd2 \\ + [--copy] + +Smoke test (``--verify``) loads the assembled dir via QwenImagePipeline and +checks the transformer config matches the student's. +""" + +from __future__ import annotations + +import argparse +import json +import logging +import os +import shutil +import sys +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from pathlib import Path + +logger = logging.getLogger(__name__) + +# Components borrowed from the base Qwen-Image checkpoint. The trained +# transformer replaces the corresponding entry. +BASE_COMPONENTS = ("vae", "text_encoder", "tokenizer", "scheduler") + + +def _link_or_copy(src: str, dst: str, copy: bool) -> None: + if os.path.lexists(dst): + if os.path.islink(dst) or os.path.isfile(dst): + os.remove(dst) + else: + shutil.rmtree(dst) + if copy: + if os.path.isdir(src): + shutil.copytree(src, dst, symlinks=True) + else: + shutil.copy2(src, dst) + else: + os.symlink(os.path.abspath(src), dst) + + +def export_diffusers( + student_path: str | Path, + base_pipeline_path: str | Path, + output_dir: str | Path, + copy: bool = False, +) -> None: + student_path = str(student_path) + base_pipeline_path = str(base_pipeline_path) + output_dir = str(output_dir) + + if not os.path.isdir(student_path): + raise FileNotFoundError(f"student_path is not a directory: {student_path}") + if not os.path.isdir(base_pipeline_path): + raise FileNotFoundError(f"base_pipeline_path is not a directory: {base_pipeline_path}") + base_index = os.path.join(base_pipeline_path, "model_index.json") + if not os.path.isfile(base_index): + raise FileNotFoundError(f"base pipeline missing model_index.json: {base_index}") + + os.makedirs(output_dir, exist_ok=True) + logger.info( + "[Diffusers-Export] Output dir: %s (mode=%s)", output_dir, "copy" if copy else "symlink" + ) + + # 1. model_index.json — copy verbatim (the class registry is the same + # whether the transformer weights are live or DMD-distilled). + dst_index = os.path.join(output_dir, "model_index.json") + with open(base_index) as f: + index = json.load(f) + with open(dst_index, "w") as f: + json.dump(index, f, indent=2) + logger.info("[Diffusers-Export] Wrote %s", dst_index) + + # 2. transformer/ — link/copy from the consolidated student. + dst_transformer = os.path.join(output_dir, "transformer") + _link_or_copy(student_path, dst_transformer, copy=copy) + logger.info("[Diffusers-Export] transformer/ <- %s", student_path) + + # 3. vae / text_encoder / tokenizer / scheduler — link/copy from base. + for comp in BASE_COMPONENTS: + src = os.path.join(base_pipeline_path, comp) + dst = os.path.join(output_dir, comp) + if not os.path.isdir(src): + raise FileNotFoundError(f"base pipeline missing component: {src}") + _link_or_copy(src, dst, copy=copy) + logger.info("[Diffusers-Export] %s/ <- %s", comp, src) + + logger.info( + "[Diffusers-Export] Done. Load via QwenImagePipeline.from_pretrained(%r)", output_dir + ) + + +def _verify(output_dir: str) -> dict: + """Load via QwenImagePipeline.from_pretrained and report a small status dict.""" + import torch + from diffusers import QwenImagePipeline + + logger.info("[Diffusers-Export-Verify] Loading %s via QwenImagePipeline", output_dir) + pipe = QwenImagePipeline.from_pretrained(output_dir, torch_dtype=torch.bfloat16) + n_transformer_params = sum(p.numel() for p in pipe.transformer.parameters()) + n_vae_params = sum(p.numel() for p in pipe.vae.parameters()) + n_text_params = sum(p.numel() for p in pipe.text_encoder.parameters()) + stats = { + "loaded_class": type(pipe).__name__, + "transformer_class": type(pipe.transformer).__name__, + "transformer_params": int(n_transformer_params), + "vae_class": type(pipe.vae).__name__, + "vae_params": int(n_vae_params), + "text_encoder_class": type(pipe.text_encoder).__name__, + "text_encoder_params": int(n_text_params), + "tokenizer_class": type(pipe.tokenizer).__name__, + "scheduler_class": type(pipe.scheduler).__name__, + } + return stats + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--student_path", + required=True, + help="Consolidated student dir (e.g. .../epoch_0_step_5/model/consolidated)", + ) + parser.add_argument( + "--base_pipeline_path", + required=True, + help="Base Qwen-Image pipeline dir (vae / text_encoder / tokenizer / scheduler source)", + ) + parser.add_argument("--output_dir", required=True, help="Where to write the full pipeline dir") + parser.add_argument( + "--copy", + action="store_true", + help="Copy components instead of symlinking. Off by default to save disk.", + ) + parser.add_argument( + "--verify", action="store_true", help="Re-load via QwenImagePipeline.from_pretrained" + ) + args = parser.parse_args() + + logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s") + export_diffusers( + student_path=args.student_path, + base_pipeline_path=args.base_pipeline_path, + output_dir=args.output_dir, + copy=args.copy, + ) + if args.verify: + stats = _verify(args.output_dir) + print(json.dumps(stats, indent=2)) + sys.exit(0) diff --git a/examples/diffusers/fastgen/inference_dmd2_qwen_image.py b/examples/diffusers/fastgen/inference_dmd2_qwen_image.py new file mode 100644 index 00000000000..7748eb15ec3 --- /dev/null +++ b/examples/diffusers/fastgen/inference_dmd2_qwen_image.py @@ -0,0 +1,528 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Inference pipeline for DMD2-trained Qwen-Image students. + +Loads the consolidated safetensors transformer (from a §10 checkpoint with +``model_save_format=safetensors`` + ``save_consolidated=true``) plus the base +Qwen-Image VAE / text-encoder / tokenizer, and exposes a diffusers-style +``pipe(prompt=...).images[0]`` call that runs the DMD few-step sampler. + +Math is bit-aligned with the training-time ``_build_student_input`` in +``modelopt/torch/fastgen/methods/dmd.py``: + + Single-step (Phase 1 default): + x_T = noise * max_t # initial latent + v = student(x_T, t=max_t, text_emb) # one forward + x_0 = x_T - max_t * v # RF identity + image = vae.decode(x_0) + + Multi-step (Phase 2, ``num_inference_steps > 1``): + x = noise * t_list[0] # initial latent at t_max + for (t_cur, t_next) in zip(t_list[:-1], t_list[1:]): + v = student(x, t=t_cur, text_emb) # flow at t_cur + x_0 = x - t_cur * v # RF identity → x_0 estimate + if t_next > 0: + eps = (x - (1 - t_cur) * x_0) / t_cur # ODE: invert RF forward + x = (1 - t_next) * x_0 + t_next * eps # re-noise to t_next + else: + x = x_0 # final step + image = vae.decode(x) + +Usage:: + + from inference_dmd2_qwen_image import QwenImageDMDInferencePipeline + import torch + + pipe = QwenImageDMDInferencePipeline.from_pretrained( + student_path="/path/to/checkpoint/epoch_0_step_500/model/consolidated", + base_pipeline_path="Qwen/Qwen-Image", + ema_path=None, # or "…/epoch_0_step_5/ema_shadow.pt" + torch_dtype=torch.bfloat16, + ).to("cuda") + + image = pipe( + prompt="a small red cube on a white table", + num_inference_steps=1, + height=512, + width=512, + generator=torch.Generator("cuda").manual_seed(42), + ).images[0] + image.save("dmd2_smoke.png") +""" + +from __future__ import annotations + +import itertools +import logging +import os +from dataclasses import dataclass +from typing import TYPE_CHECKING + +import torch +from diffusers import QwenImagePipeline, QwenImageTransformer2DModel +from diffusers.utils.torch_utils import randn_tensor + +if TYPE_CHECKING: + from pathlib import Path + +logger = logging.getLogger(__name__) + + +@dataclass +class QwenImageDMDOutput: + """Container for inference outputs — mirrors diffusers' pipeline outputs.""" + + images: list + + +class QwenImageDMDInferencePipeline: + """Thin inference wrapper around a DMD2-trained Qwen-Image student. + + Wraps a stock ``diffusers.QwenImagePipeline`` whose ``transformer`` field + has been swapped for our trained student. Re-uses the pipeline's VAE, + text encoder, tokenizer, and image-processor for everything *except* the + denoising loop, which is replaced by the DMD few-step sampler. + """ + + def __init__( + self, + base_pipeline: QwenImagePipeline, + max_t: float = 0.999, + ) -> None: + self._pipe = base_pipeline + self.max_t = max_t + + # ------------------------------------------------------------------ # + # Loading # + # ------------------------------------------------------------------ # + + @classmethod + def from_pretrained( + cls, + student_path: str | Path, + base_pipeline_path: str | Path, + ema_path: str | Path | None = None, + torch_dtype: torch.dtype = torch.bfloat16, + max_t: float = 0.999, + ) -> QwenImageDMDInferencePipeline: + """Load the student + base Qwen-Image components. + + Args: + student_path: A consolidated dir from §10's safetensors save — must + contain ``config.json``, ``diffusion_pytorch_model.safetensors.index.json``, + and one or more ``*.safetensors`` shards. Loadable directly via + ``QwenImageTransformer2DModel.from_pretrained``. + base_pipeline_path: The base Qwen-Image checkpoint (e.g. + ``Qwen/Qwen-Image`` or a local snapshot). Used only for the + ``vae`` / ``text_encoder`` / ``tokenizer`` / ``image_processor``; + the transformer is replaced. + ema_path: Optional ``ema_shadow.pt`` produced by ``_save_dmd_extras``. + If provided, the EMA shadow weights are overlaid onto the + student after the safetensors load. EMA usually yields cleaner + samples than the live student. + torch_dtype: dtype for the student + VAE. ``bfloat16`` is what we + trained with. + max_t: Initial timestep for the 1-step sampler. Must match the + recipe's ``sample_t_cfg.max_t`` (default 0.999). + """ + student_path = str(student_path) + base_pipeline_path = str(base_pipeline_path) + if not os.path.isdir(student_path): + raise FileNotFoundError(f"student_path is not a directory: {student_path}") + if not os.path.isdir(base_pipeline_path): + raise FileNotFoundError(f"base_pipeline_path is not a directory: {base_pipeline_path}") + + logger.info("[DMD2-Inference] Loading trained student from %s", student_path) + student = QwenImageTransformer2DModel.from_pretrained(student_path, torch_dtype=torch_dtype) + + if ema_path is not None: + logger.info("[DMD2-Inference] Overlaying EMA shadow from %s", ema_path) + ema_state = torch.load(str(ema_path), map_location="cpu", weights_only=False) + shadow = ( + ema_state.get("shadow", ema_state) if isinstance(ema_state, dict) else ema_state + ) + if not isinstance(shadow, dict): + raise ValueError( + f"ema_shadow.pt content has unexpected type {type(shadow).__name__}; " + "expected dict[str, Tensor]." + ) + missing, unexpected = student.load_state_dict(shadow, strict=False) + if unexpected: + logger.warning( + "[DMD2-Inference] EMA overlay had %d unexpected keys (first: %s)", + len(unexpected), + unexpected[:3], + ) + if missing: + logger.warning( + "[DMD2-Inference] EMA overlay missed %d student keys (first: %s)", + len(missing), + missing[:3], + ) + + student.eval() + + logger.info( + "[DMD2-Inference] Loading base Qwen-Image pipeline from %s (transformer replaced)", + base_pipeline_path, + ) + # Passing transformer= bypasses loading the original transformer from disk; + # the rest (vae, text_encoder, tokenizer, scheduler, image_processor) loads + # normally. + pipe = QwenImagePipeline.from_pretrained( + base_pipeline_path, + transformer=student, + torch_dtype=torch_dtype, + ) + + return cls(base_pipeline=pipe, max_t=max_t) + + def to(self, device: str | torch.device) -> QwenImageDMDInferencePipeline: + self._pipe.to(device) + return self + + @property + def device(self) -> torch.device: + return self._pipe.transformer.device + + @property + def dtype(self) -> torch.dtype: + return next(self._pipe.transformer.parameters()).dtype + + # ------------------------------------------------------------------ # + # Inference # + # ------------------------------------------------------------------ # + + @torch.no_grad() + def __call__( + self, + prompt: str | list, + negative_prompt: str | list | None = None, + num_inference_steps: int = 1, + guidance_scale: float = 1.0, + height: int = 1024, + width: int = 1024, + num_images_per_prompt: int = 1, + generator: torch.Generator | None = None, + max_t: float | None = None, + t_list: list | None = None, + sample_type: str = "ode", + output_type: str = "pil", + max_sequence_length: int = 512, + ) -> QwenImageDMDOutput: + """Generate image(s) from the trained DMD2 student. + + ``num_inference_steps == 1`` runs the canonical Phase 1 single-step + sampler (``x_0 = x_T - max_t * v``). ``num_inference_steps > 1`` runs + the multi-step DMD unroll using ``t_list`` (or + ``linspace(max_t, 0, num_inference_steps + 1)`` if ``t_list`` is None). + + ``sample_type`` selects between deterministic (``"ode"``, recovers eps + from x_0 via RF identity) and stochastic (``"sde"``, fresh noise per + step). Matches FastGen's ``student_sample_type``. + + ``guidance_scale != 1.0`` activates inference-time classifier-free + guidance: the student is called twice (positive + negative prompt) per + step and the two flows are blended as ``v = v_neg + s*(v_pos - v_neg)``. + ``negative_prompt`` defaults to ``""`` (the canonical empty prompt) when + unset and CFG is engaged. + + **Note on DMD2 students trained with CFG**: a student trained with + ``dmd2.guidance_scale=4.0`` has *already* internalised CFG (its + single-pass output is the CFG-augmented teacher target). Pass + ``guidance_scale=1.0`` at inference to avoid double-CFG. Use + ``guidance_scale > 1.0`` only for students that were trained without + CFG (``dmd2.guidance_scale=null``). + """ + if sample_type not in ("ode", "sde"): + raise ValueError(f"sample_type must be 'ode' or 'sde', got {sample_type!r}") + do_cfg = guidance_scale != 1.0 + if do_cfg and negative_prompt is None: + # Default to the canonical empty unconditional prompt. + negative_prompt = "" + + max_t = float(max_t) if max_t is not None else float(self.max_t) + pipe = self._pipe + device = self.device + dtype = self.dtype + + # ---- 1. Resolve t_list ----------------------------------------------- + if num_inference_steps == 1: + schedule = [max_t, 0.0] + elif t_list is not None: + if len(t_list) != num_inference_steps + 1: + raise ValueError( + f"t_list must have num_inference_steps+1 entries " + f"(got {len(t_list)} for num_inference_steps={num_inference_steps})" + ) + schedule = [float(t) for t in t_list] + if abs(schedule[-1]) > 1e-6: + raise ValueError( + f"t_list must end at 0.0 (got {schedule[-1]}); the final step lands on x_0." + ) + else: + # Default: linear schedule from max_t to 0 (matches FastGen's + # torch.linspace(max_t, 0, sample_steps + 1) fallback). + schedule = torch.linspace(max_t, 0.0, num_inference_steps + 1).tolist() + + # ---- 2. Encode prompt(s) --------------------------------------------- + prompt_embeds, prompt_embeds_mask = pipe.encode_prompt( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + ) + neg_prompt_embeds = None + neg_prompt_embeds_mask = None + if do_cfg: + neg_prompt_embeds, neg_prompt_embeds_mask = pipe.encode_prompt( + prompt=negative_prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + ) + txt_seq_lens = ( + prompt_embeds_mask.sum(dim=1).int().tolist() if prompt_embeds_mask is not None else None + ) + neg_txt_seq_lens = ( + neg_prompt_embeds_mask.sum(dim=1).int().tolist() + if neg_prompt_embeds_mask is not None + else None + ) + + # ---- 3. Build initial noisy latents at t = schedule[0] --------------- + if isinstance(prompt, str): + batch_size = 1 + else: + batch_size = len(prompt) + batch_size = batch_size * num_images_per_prompt + + num_channels_latents = pipe.transformer.config.in_channels // 4 # 64 // 4 = 16 + h_lat = 2 * (height // (pipe.vae_scale_factor * 2)) + w_lat = 2 * (width // (pipe.vae_scale_factor * 2)) + latent_shape = (batch_size, 1, num_channels_latents, h_lat, w_lat) + + # DMD initial latents: noise * schedule[0] (RF: sigma(t0) = t0). + noise = randn_tensor(latent_shape, generator=generator, device=device, dtype=dtype) + latents_5d = noise * schedule[0] + + latents_packed = pipe._pack_latents( + latents_5d, batch_size, num_channels_latents, h_lat, w_lat + ) + img_shapes = [[(1, h_lat // 2, w_lat // 2)]] * batch_size + + # ---- 4. DMD few-step unroll ----------------------------------------- + x_packed = latents_packed + for t_cur, t_next in itertools.pairwise(schedule): + timestep = torch.tensor([t_cur], device=device, dtype=dtype).expand(batch_size) + flow_packed = pipe.transformer( + hidden_states=x_packed, + encoder_hidden_states=prompt_embeds, + encoder_hidden_states_mask=prompt_embeds_mask, + timestep=timestep, + img_shapes=img_shapes, + txt_seq_lens=txt_seq_lens, + guidance=None, + return_dict=False, + )[0] + if do_cfg: + # CFG two-pass: ``v_cfg = v_neg + s*(v_pos - v_neg)``. Equivalent + # to the FastGen training-time formula + # ``teacher_pos + (s-1) * (teacher_pos - teacher_neg)`` once + # expanded. Engaged only when the caller passes + # ``guidance_scale != 1.0`` — DMD2 students trained with a + # non-null ``dmd2.guidance_scale`` already internalise CFG, so + # leave guidance_scale=1.0 for those. + neg_flow_packed = pipe.transformer( + hidden_states=x_packed, + encoder_hidden_states=neg_prompt_embeds, + encoder_hidden_states_mask=neg_prompt_embeds_mask, + timestep=timestep, + img_shapes=img_shapes, + txt_seq_lens=neg_txt_seq_lens, + guidance=None, + return_dict=False, + )[0] + flow_packed = ( + neg_flow_packed.to(torch.float64) + + float(guidance_scale) + * (flow_packed.to(torch.float64) - neg_flow_packed.to(torch.float64)) + ).to(dtype) + # RF identity: x_0 = x_t - t_cur * v (computed in fp64 for stability). + x0_packed = ( + x_packed.to(torch.float64) - float(t_cur) * flow_packed.to(torch.float64) + ).to(dtype) + + if t_next > 1e-6: + # Re-noise x_0 forward to t_next. + if sample_type == "ode": + # Deterministic: invert the RF forward to recover the implied eps, + # then re-noise. eps = (x_t - (1 - t_cur) * x_0) / t_cur + alpha_cur = 1.0 - float(t_cur) + eps_packed = ( + (x_packed.to(torch.float64) - alpha_cur * x0_packed.to(torch.float64)) + / max(float(t_cur), 1e-6) + ).to(dtype) + else: + # Stochastic: fresh Gaussian noise. + eps_packed = torch.randn( + x_packed.shape, generator=generator, device=device, dtype=dtype + ) + # RF forward: x_{t_next} = (1 - t_next) * x_0 + t_next * eps. + alpha_next = 1.0 - float(t_next) + x_packed = ( + alpha_next * x0_packed.to(torch.float64) + + float(t_next) * eps_packed.to(torch.float64) + ).to(dtype) + else: + # Last step: x_0 is the output. + x_packed = x0_packed + + # Unpack to 5D for VAE. + x0_5d = pipe._unpack_latents(x_packed, height, width, pipe.vae_scale_factor) + + # ---- 5. VAE decode --------------------------------------------------- + # Reverse the VAE-side scaling that diffusers applied at encoding time. + latents_mean = ( + torch.tensor(pipe.vae.config.latents_mean) + .view(1, pipe.vae.config.z_dim, 1, 1, 1) + .to(device=device, dtype=dtype) + ) + latents_std = 1.0 / torch.tensor(pipe.vae.config.latents_std).view( + 1, pipe.vae.config.z_dim, 1, 1, 1 + ).to(device=device, dtype=dtype) + x0_scaled = x0_5d / latents_std + latents_mean + + # vae.decode returns 5D; the trailing [:, :, 0] drops the temporal dim + # since Qwen-Image treats images as 1-frame videos. + image_5d = pipe.vae.decode(x0_scaled, return_dict=False)[0] + image_4d = image_5d[:, :, 0] # [B, C, H, W] + + images = pipe.image_processor.postprocess(image_4d, output_type=output_type) + return QwenImageDMDOutput(images=images) + + +# ---------------------------------------------------------------------------- # +# Standalone smoke-test driver. Validates end-to-end wiring against the §10 # +# safetensors checkpoint. Mock-data training means the image won't be # +# coherent — pass criterion is just "the pipeline produces a finite image # +# tensor and the file writes successfully". # +# ---------------------------------------------------------------------------- # + + +def _smoke_test( + student_path: str, + base_pipeline_path: str, + output_png: str, + ema_path: str | None = None, + prompt: str = "a small red cube on a white table", + height: int = 512, + width: int = 512, + seed: int = 42, +) -> None: + """Run a one-shot inference and dump a PNG. + + Writes a small JSON sidecar next to the PNG with shape/dtype/range stats + so the success criteria can be machine-checked. + """ + import json + + logging.basicConfig(level=logging.INFO) + pipe = QwenImageDMDInferencePipeline.from_pretrained( + student_path=student_path, + base_pipeline_path=base_pipeline_path, + ema_path=ema_path, + torch_dtype=torch.bfloat16, + ) + device = "cuda" if torch.cuda.is_available() else "cpu" + pipe = pipe.to(device) + gen = torch.Generator(device=device).manual_seed(seed) + + out = pipe( + prompt=prompt, + num_inference_steps=1, + height=height, + width=width, + generator=gen, + ) + image = out.images[0] + + # PIL image; sanity-check shape + range. + import numpy as np + + arr = np.array(image) + stats = { + "prompt": prompt, + "height": height, + "width": width, + "seed": seed, + "ema_overlay": ema_path is not None, + "image_shape": list(arr.shape), + "image_dtype": str(arr.dtype), + "image_min": int(arr.min()), + "image_max": int(arr.max()), + "image_mean": float(arr.mean()), + "image_std": float(arr.std()), + "is_finite": bool(np.isfinite(arr).all()), + "is_not_constant": bool(arr.std() > 0), + } + + os.makedirs(os.path.dirname(output_png), exist_ok=True) + image.save(output_png) + sidecar = output_png.replace(".png", "_stats.json") + with open(sidecar, "w") as f: + json.dump(stats, f, indent=2) + print(json.dumps(stats, indent=2)) + print(f"\nImage saved to: {output_png}") + print(f"Stats sidecar: {sidecar}") + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument( + "--student_path", + required=True, + help="Path to the consolidated safetensors student checkpoint " + "(e.g. .../epoch_0_step_500/model/consolidated).", + ) + parser.add_argument( + "--base_pipeline_path", + default="Qwen/Qwen-Image", + help="Base Qwen-Image pipeline (HF id or local snapshot) for the VAE / text-encoder / tokenizer.", + ) + parser.add_argument( + "--output_png", + default="./outputs/dmd2_sample.png", + ) + parser.add_argument("--ema_path", default=None) + parser.add_argument("--prompt", default="a small red cube on a white table") + parser.add_argument("--height", type=int, default=512) + parser.add_argument("--width", type=int, default=512) + parser.add_argument("--seed", type=int, default=42) + args = parser.parse_args() + + _smoke_test( + student_path=args.student_path, + base_pipeline_path=args.base_pipeline_path, + output_png=args.output_png, + ema_path=args.ema_path, + prompt=args.prompt, + height=args.height, + width=args.width, + seed=args.seed, + ) diff --git a/examples/diffusers/fastgen/requirements.txt b/examples/diffusers/fastgen/requirements.txt new file mode 100644 index 00000000000..1cba3ce7868 --- /dev/null +++ b/examples/diffusers/fastgen/requirements.txt @@ -0,0 +1,10 @@ +# Runtime requirements for the DMD2 Qwen-Image AutoModel example. +# Torch + diffusers are already pulled in via Model-Optimizer's ``[all]`` extras. +# The one thing that's NOT shipped with Model-Optimizer is nemo_automodel. + +# NeMo AutoModel (parent recipe, dataloader, FSDP2 wrapping). +# The diffusion extras install diffusers + accelerate with matching pins. +nemo_automodel[diffusion] + +# Optional but recommended for the smoke logs. +wandb diff --git a/modelopt/torch/fastgen/__init__.py b/modelopt/torch/fastgen/__init__.py new file mode 100644 index 00000000000..507acfddd72 --- /dev/null +++ b/modelopt/torch/fastgen/__init__.py @@ -0,0 +1,68 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Framework-agnostic diffusion step-distillation losses (FastGen port). + +``modelopt.torch.fastgen`` is a loss-computation library. It accepts already-built +``nn.Module`` references (student / teacher / fake-score / optional discriminator) and +returns scalar loss tensors. It does **not** load models, manage optimizers, wrap +anything as a ``DynamicModule``, or register itself in any mode registry. + +Typical usage with a YAML-driven config:: + + import modelopt.torch.fastgen as mtf + + student, teacher = build_wan_student_and_teacher(...) + fake_score = mtf.create_fake_score(teacher) + + cfg = mtf.load_dmd_config("general/distillation/dmd2_qwen_image") + + # If GAN is enabled, expose intermediate teacher features to the discriminator. + if cfg.gan_loss_weight_gen > 0: + mtf.plugins.qwen_image.attach_feature_capture(teacher, feature_indices=[30]) + + pipeline = mtf.DMDPipeline(student, teacher, fake_score, cfg, discriminator=disc) + + # Inside the training loop (framework-owned): + if step % cfg.student_update_freq == 0: + losses = pipeline.compute_student_loss( + latents, noise, text_embeds, negative_encoder_hidden_states=neg_embeds + ) + losses["total"].backward() + student_opt.step() + pipeline.update_ema() + else: + f = pipeline.compute_fake_score_loss(latents, noise, text_embeds) + f["total"].backward() + fake_score_opt.step() + if disc is not None: + d = pipeline.compute_discriminator_loss(latents, noise, text_embeds) + d["total"].backward() + disc_opt.step() +""" + +from . import flow_matching, losses, utils +from .config import * +from .ema import * +from .factory import * +from .loader import * +from .methods.dmd import * +from .pipeline import * + +# isort: off +# Plugins must be imported after the core exports so the plugin hooks can reference +# DMDPipeline if needed in the future; also matches the ordering used by +# modelopt.torch.distill. +from . import plugins diff --git a/modelopt/torch/fastgen/config.py b/modelopt/torch/fastgen/config.py new file mode 100644 index 00000000000..30d78b8720f --- /dev/null +++ b/modelopt/torch/fastgen/config.py @@ -0,0 +1,315 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Pydantic configuration classes for the fastgen distillation pipelines. + +Configurations are layered so a method-specific config (e.g. :class:`DMDConfig`) inherits +shared diffusion-distillation hyperparameters from :class:`DistillationConfig`. All classes +inherit :class:`modelopt.torch.opt.config.ModeloptBaseConfig`, which provides torch-safe +serialization and dict-like iteration. + +The default values in :class:`DMDConfig` mirror the FastGen Wan 2.2 5B experiment at +``FastGen/fastgen/configs/experiments/WanT2V/config_dmd2_wan22_5b.py``. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Literal + +from pydantic import Field, model_validator + +from modelopt.torch.opt.config import ModeloptBaseConfig, ModeloptField + +if TYPE_CHECKING: + from pathlib import Path + +__all__ = [ + "DMDConfig", + "DistillationConfig", + "EMAConfig", + "SampleTimestepConfig", +] + +PredType = Literal["x0", "eps", "v", "flow"] +TimeDistType = Literal["uniform", "logitnormal", "lognormal", "shifted", "polynomial"] + + +class SampleTimestepConfig(ModeloptBaseConfig): + """Timestep sampling distribution for diffusion training.""" + + time_dist_type: TimeDistType = ModeloptField( + default="shifted", + title="Timestep distribution", + description=( + "Distribution used to sample the training timestep ``t``. Rectified-flow models" + " typically use ``shifted`` (Wan 2.2) or ``logitnormal`` (SD3, Flux)." + ), + ) + min_t: float = ModeloptField( + default=0.001, + title="Minimum t", + description="Lower bound of the sampling range (clamped before use).", + ) + max_t: float = ModeloptField( + default=0.999, + title="Maximum t", + description="Upper bound of the sampling range (clamped before use).", + ) + shift: float = ModeloptField( + default=5.0, + title="Shift factor", + description="Shift factor for ``time_dist_type='shifted'``; must be >= 1.", + ) + p_mean: float = ModeloptField( + default=0.0, + title="Distribution mean (log-space)", + description="Mean of the underlying normal for ``logitnormal`` / ``lognormal``.", + ) + p_std: float = ModeloptField( + default=1.0, + title="Distribution std (log-space)", + description="Standard deviation of the underlying normal for ``logitnormal`` / ``lognormal``.", + ) + t_list: list[float] | None = ModeloptField( + default=None, + title="Multi-step student timesteps", + description=( + "Explicit timestep schedule used when ``DMDConfig.student_sample_steps > 1``." + " The final element must be ``0.0``." + ), + ) + + @model_validator(mode="after") + def _check_bounds(self) -> SampleTimestepConfig: + assert 0.0 <= self.min_t < self.max_t, ( + f"require 0 <= min_t < max_t, got min_t={self.min_t}, max_t={self.max_t}" + ) + assert self.shift >= 1.0, f"shift must be >= 1, got {self.shift}" + if self.t_list is not None: + assert len(self.t_list) >= 2, "t_list must contain at least 2 entries (including t=0)" + assert self.t_list[-1] == 0.0, f"t_list[-1] must be 0.0, got {self.t_list[-1]}" + return self + + +class EMAConfig(ModeloptBaseConfig): + """Exponential moving average (EMA) hyperparameters for the student network.""" + + decay: float = ModeloptField( + default=0.9999, + title="EMA decay", + description="Decay coefficient for ``type='constant'``. Ignored for ``halflife``/``power``.", + ) + type: Literal["constant", "halflife", "power"] = ModeloptField( + default="constant", + title="EMA decay schedule", + description="Schedule used to compute the per-step decay coefficient.", + ) + start_iter: int = ModeloptField( + default=0, + title="EMA start iteration", + description="Iteration at which EMA tracking begins (EMA is initialized from the live weights at this step).", + ) + gamma: float = ModeloptField( + default=16.97, + title="Power schedule gamma", + description="Exponent for ``type='power'`` (``beta = (1 - 1/iter)**(gamma + 1)``).", + ) + halflife_kimg: float = ModeloptField( + default=500.0, + title="Halflife (kimg)", + description="Halflife in thousands of images for ``type='halflife'``.", + ) + rampup_ratio: float | None = ModeloptField( + default=0.05, + title="Halflife rampup ratio", + description="Rampup fraction for ``type='halflife'``; pass ``None`` to disable rampup.", + ) + batch_size: int = ModeloptField( + default=1, + title="Effective batch size", + description="Per-step global batch size used to convert iterations to nimg for the halflife schedule.", + ) + fsdp2: bool = ModeloptField( + default=True, + title="FSDP2 enabled", + description="If True, the EMA uses ``DTensor.full_tensor()`` to gather sharded parameters before updating.", + ) + mode: Literal["full_tensor", "local_shard"] = ModeloptField( + default="full_tensor", + title="FSDP2 gather mode", + description=( + "``full_tensor`` performs an all_gather per parameter (higher memory, exact global EMA)." + " ``local_shard`` updates each rank's local DTensor shard in place (low memory fallback)." + ), + ) + dtype: Literal["float32", "bfloat16", "float16"] | None = ModeloptField( + default="float32", + title="EMA shadow dtype", + description=( + "Precision of the EMA parameter shadows. Defaults to ``float32`` so EMA updates" + " remain numerically meaningful even when the live model is bf16/fp16 (cf. FastGen," + " which instantiates its EMA module in the net's construction dtype — typically" + " fp32). Pass ``None`` to keep param shadows in the live parameter's dtype." + " Buffer shadows always track the live dtype regardless of this setting." + ), + ) + + +class DistillationConfig(ModeloptBaseConfig): + """Shared hyperparameters for diffusion step-distillation methods. + + Concrete methods subclass this config to add method-specific fields + (see :class:`DMDConfig`). + """ + + pred_type: PredType = ModeloptField( + default="flow", + title="Network prediction parameterization", + description="Quantity predicted by the teacher / student network.", + ) + guidance_scale: float | None = ModeloptField( + default=None, + title="CFG scale", + description="Classifier-free guidance scale. If ``None`` CFG is disabled.", + ) + # ``ModeloptField`` hard-asserts on ``default_factory``; use Pydantic's ``Field`` + # directly for this mutable sub-config so each DMDConfig instance gets its own + # SampleTimestepConfig instead of sharing a single mutable default. + sample_t_cfg: SampleTimestepConfig = Field( + default_factory=SampleTimestepConfig, + title="Timestep sampling", + description="Timestep distribution used for both the teacher forward and the VSD / DSM losses.", + ) + student_sample_steps: int = ModeloptField( + default=1, + title="Student inference steps", + description="Number of denoising steps the distilled student performs at inference.", + ) + student_sample_type: Literal["sde", "ode"] = ModeloptField( + default="ode", + title="Student sampling mode", + description=( + "Integrator used when unrolling the student over ``student_sample_steps > 1`` steps." + " Consumed by inference samplers and by DMDPipeline when" + " ``DMDConfig.backward_simulation`` is enabled." + ), + ) + num_train_timesteps: int | None = ModeloptField( + default=None, + title="Training-time discrete timestep count", + description=( + "If set, the pipeline rescales the continuous RF timestep ``t ∈ [0, 1]`` to" + " ``num_train_timesteps * t`` before passing it to the model. Matches the" + " diffusers convention used by Wan 2.2 / SD3 / Flux (``num_train_timesteps = 1000``)." + " Leave ``None`` when the model wrapper already handles the rescaling internally." + ), + ) + + +class DMDConfig(DistillationConfig): + """Hyperparameters for DMD / DMD2 distribution-matching distillation. + + Default values are tuned for Wan 2.2 5B; callers fine-tune them per model. + See ``FastGen/fastgen/configs/experiments/WanT2V/config_dmd2_wan22_5b.py``. + """ + + student_update_freq: int = ModeloptField( + default=5, + title="Student update frequency", + description=( + "One student step for every ``student_update_freq`` fake-score / discriminator steps." + " Matches FastGen's DMD2 alternation. Not read by DMDPipeline; the training loop is" + " expected to enforce the alternation." + ), + ) + fake_score_pred_type: PredType | None = ModeloptField( + default="x0", + title="Fake-score prediction parameterization", + description=( + "Parameterization used when training the fake score. If ``None`` falls back to" + " :attr:`DistillationConfig.pred_type`." + ), + ) + backward_simulation: bool = ModeloptField( + default=False, + title="Backward simulation", + description=( + "When True for multi-step students, build the selected student input by" + " no-grad unrolling the current student from the first schedule rung through" + " earlier rungs, then re-noising the generated x0 at the selected rung." + " When False, use FastGen's Qwen-style noised-real latent path." + ), + ) + gan_loss_weight_gen: float = ModeloptField( + default=0.0, + title="Generator GAN weight", + description="Weight of the GAN generator term in the student loss. ``0`` disables the GAN branch.", + ) + gan_use_same_t_noise: bool = ModeloptField( + default=False, + title="Share t/noise across real and fake", + description="If True, reuse the same ``t`` and ``eps`` for real and fake samples in the discriminator update.", + ) + gan_r1_reg_weight: float = ModeloptField( + default=0.0, + title="R1 regularization weight", + description=( + "Weight of the approximate-R1 regularization term for the discriminator update. ``0`` disables R1." + " Recommended range when enabled: 100-1000." + ), + ) + gan_r1_reg_alpha: float = ModeloptField( + default=0.1, + title="R1 regularization noise scale", + description=( + "Standard deviation of the perturbation applied to real data when computing the" + " approximate R1 term." + ), + ) + ema: EMAConfig | None = ModeloptField( + default=None, + title="Student EMA", + description=( + "If set, an exponential moving average of the student is maintained and updated" + " via ``DMDPipeline.update_ema``." + ), + ) + + @model_validator(mode="after") + def _check_gan(self) -> DMDConfig: + if self.gan_r1_reg_weight > 0 and self.gan_loss_weight_gen <= 0: + raise ValueError( + "gan_r1_reg_weight > 0 requires gan_loss_weight_gen > 0 (the discriminator must be enabled)." + ) + if self.backward_simulation: + if self.student_sample_steps <= 1: + raise ValueError("backward_simulation=True requires student_sample_steps > 1.") + if self.sample_t_cfg.t_list is None: + raise ValueError("backward_simulation=True requires sample_t_cfg.t_list to be set.") + return self + + @classmethod + def from_yaml(cls, config_file: str | Path) -> DMDConfig: + """Construct a :class:`DMDConfig` from a YAML file. + + Thin wrapper around :func:`modelopt.torch.fastgen.loader.load_dmd_config`. + The resolver searches the built-in ``modelopt_recipes/`` package first, then + the filesystem. Suffixes (``.yml`` / ``.yaml``) may be omitted. + """ + # Imported lazily to avoid a circular import between this module and + # ``modelopt.torch.fastgen.loader`` (which imports :class:`DMDConfig`). + from .loader import load_dmd_config + + return load_dmd_config(config_file) diff --git a/modelopt/torch/fastgen/discriminators.py b/modelopt/torch/fastgen/discriminators.py new file mode 100644 index 00000000000..8e55233d052 --- /dev/null +++ b/modelopt/torch/fastgen/discriminators.py @@ -0,0 +1,136 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Discriminator modules for the DMD2 GAN branch. + +Ports FastGen's image-DiT discriminator from +``source/FastGen/fastgen/networks/discriminators.py`` so that ModelOpt's +:class:`~modelopt.torch.fastgen.methods.dmd.DMDPipeline` can run the GAN branch +without a FastGen dependency. The discriminator is **model-agnostic**: it takes +a list of spatial feature tensors ``[B, C, H, W]`` and returns concatenated +logits ``[B, num_heads]``. The model-specific work of producing those tensors +(installing forward hooks, reshaping packed-token streams into spatial maps) +lives in the per-model plugins (``plugins/qwen_image.py``). +""" + +from __future__ import annotations + +import torch +from torch import nn + +__all__ = ["Discriminator", "Discriminator_ImageDiT"] + + +def _get_optimal_groups(num_channels: int) -> int: + """Return a GroupNorm group count that divides ``num_channels`` evenly. + + Matches the heuristic in FastGen's discriminator: prefer 32 groups when + possible, fall back to the largest divisor below 32, and use + ``num_channels // 4`` for very small channel counts. + """ + if num_channels <= 32: + groups = max(1, num_channels // 4) + else: + groups = 32 + while groups > 1 and num_channels % groups != 0: + groups -= 1 + assert num_channels % groups == 0, f"{num_channels} not divisible by {groups}" + return groups + + +class Discriminator(nn.Module): + """Base class for DMD2 discriminators.""" + + def __init__(self, feature_indices: set[int] | None = None) -> None: + """Store the teacher block indices whose features feed the discriminator.""" + super().__init__() + self.feature_indices = feature_indices + + def forward(self, feats: list[torch.Tensor]) -> torch.Tensor: + """Map captured teacher features to discriminator logits (overridden by subclasses).""" + raise NotImplementedError("Subclasses must implement forward()") + + +# Class name kept verbatim from the FastGen reference implementation. +class Discriminator_ImageDiT(Discriminator): # noqa: N801 + """Image-DiT discriminator with one lightweight conv head per captured block. + + Input: list of feature tensors with shape ``[B, inner_dim, H, W]``, one per + block index in :attr:`feature_indices`. + + Output: concatenated logits ``[B, num_heads]`` (one column per head). The + DMD2 generator/discriminator losses read this as a 2D tensor. + + Per-head parameter count is ~``inner_dim * (inner_dim // 2) * 16 + ...``; + for ``inner_dim=3072`` (Flux / Qwen-Image) that's ~75 M params per head, so + keep ``len(feature_indices)`` small (≤3 heads is typical). + """ + + def __init__( + self, + feature_indices: set[int] | None = None, + num_blocks: int = 57, + inner_dim: int = 3072, + ) -> None: + """Build one lightweight conv classification head per captured block.""" + super().__init__(feature_indices=feature_indices) + + if self.feature_indices is None: + self.feature_indices = {int(num_blocks // 2)} + self.feature_indices = {i for i in self.feature_indices if i < num_blocks} + self.num_features = len(self.feature_indices) + self.inner_dim = inner_dim + + hidden_channels = inner_dim // 2 + self.cls_pred_heads = nn.ModuleList() + for _ in range(self.num_features): + head = nn.Sequential( + nn.Conv2d( + in_channels=inner_dim, + out_channels=hidden_channels, + kernel_size=4, + stride=2, + padding=1, + ), + nn.GroupNorm( + num_groups=_get_optimal_groups(hidden_channels), + num_channels=hidden_channels, + ), + nn.LeakyReLU(0.2), + nn.Conv2d( + in_channels=hidden_channels, + out_channels=1, + kernel_size=1, + stride=1, + padding=0, + ), + nn.AdaptiveAvgPool2d((1, 1)), + nn.Flatten(), + ) + self.cls_pred_heads.append(head) + + def forward(self, feats: list[torch.Tensor]) -> torch.Tensor: + """Run each per-block conv head and concatenate their logits to ``[B, num_heads]``.""" + if not isinstance(feats, list) or len(feats) != self.num_features: + raise ValueError( + f"Expected list of {self.num_features} feature tensors, " + f"got {type(feats).__name__} with length " + f"{len(feats) if isinstance(feats, list) else 'N/A'}" + ) + all_logits = [] + for head, feat in zip(self.cls_pred_heads, feats): + logits = head(feat) + all_logits.append(logits) + return torch.cat(all_logits, dim=1) diff --git a/modelopt/torch/fastgen/ema.py b/modelopt/torch/fastgen/ema.py new file mode 100644 index 00000000000..3088acfaa62 --- /dev/null +++ b/modelopt/torch/fastgen/ema.py @@ -0,0 +1,258 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Exponential moving average of a student network, FSDP2 DTensor aware. + +Ported from ``FastGen/fastgen/callbacks/ema.py`` (lines 20-169) but exposed as a plain +class rather than a framework-specific callback. The caller decides when to call +:meth:`update` (typically after ``optimizer.step()``), how to persist the shadow state +(via :meth:`state_dict`), and when to publish the EMA weights back to a target module +(via :meth:`copy_to`). +""" + +from __future__ import annotations + +import copy +from typing import TYPE_CHECKING + +import torch +from torch import nn + +if TYPE_CHECKING: + from .config import EMAConfig + +__all__ = ["ExponentialMovingAverage"] + + +_DTYPE_MAP: dict[str, torch.dtype] = { + "float32": torch.float32, + "bfloat16": torch.bfloat16, + "float16": torch.float16, +} + + +def _resolve_dtype(config_dtype: str | None, fallback: torch.dtype) -> torch.dtype: + """Map an ``EMAConfig.dtype`` string to a ``torch.dtype``. + + ``config_dtype is None`` falls through to ``fallback`` (the live parameter's dtype). + """ + if config_dtype is None: + return fallback + try: + return _DTYPE_MAP[config_dtype] + except KeyError as exc: + raise ValueError( + f"Unsupported EMA dtype {config_dtype!r}; expected one of {sorted(_DTYPE_MAP)} or None." + ) from exc + + +def _is_distributed_tensor(t: torch.Tensor) -> bool: + """Return True if ``t`` is a ``torch.distributed.DTensor`` supporting ``full_tensor()``.""" + return hasattr(t, "full_tensor") and callable(t.full_tensor) + + +def _gather_full(param: torch.Tensor, *, fsdp2: bool) -> torch.Tensor: + """Return a materialised full tensor for ``param``. + + Mirrors the FSDP2 branch in ``FastGen/fastgen/callbacks/ema.py:128-139``: if CPU + offloading is enabled the local shard must be moved to CUDA before ``full_tensor()`` + can perform the all-gather (which requires a CUDA backend). + """ + if fsdp2 and _is_distributed_tensor(param): + if param.device.type == "cpu": + return param.to("cuda").full_tensor() + return param.full_tensor() + return param + + +def _strip_checkpoint_prefix(name: str) -> str: + """Remove the ``_checkpoint_wrapped_module.`` prefix injected by FSDP2 activation checkpointing.""" + return name.replace("_checkpoint_wrapped_module.", "") + + +class ExponentialMovingAverage: + """FSDP2-aware EMA tracker for a PyTorch module. + + The tracker stores a shadow state dict: parameters are promoted per + :attr:`EMAConfig.dtype` (default fp32) while buffers are kept in the live module's + dtype. Buffers are replicated across ranks and stepped via ``copy_`` rather than + ``lerp_``, so the bf16-roundoff argument that motivates parameter promotion + doesn't apply — preserving the live dtype makes the buffer restore exact. + + By default the tracker materialises the full tensor per parameter + (``mode='full_tensor'``) so the EMA represents the globally averaged weights even + when the model is sharded across ranks. A ``mode='local_shard'`` fallback is + available for memory-constrained settings — it does not all-gather and therefore + each rank holds an EMA of its local shard only. + + Example:: + + ema = ExponentialMovingAverage(student, EMAConfig(decay=0.999)) + for step in range(max_steps): + ... # compute loss, backward, optimizer.step() + ema.update(student, iteration=step) + + ema.copy_to(student_for_eval) # publish for inference + """ + + def __init__(self, model: nn.Module, config: EMAConfig) -> None: + """Pre-allocate the shadow state from ``model``'s parameters and buffers.""" + self.config = config + self._shadow: dict[str, torch.Tensor] = {} + self._buffer_shadow: dict[str, torch.Tensor] = {} + self._initialized = False + + # Pre-allocate shadow storage as a deepcopy of the live parameters on their + # current devices. Shadow dtype is promoted to ``EMAConfig.dtype`` (default + # fp32) so EMA updates remain meaningful even when the live model is + # bf16/fp16: the per-step increment ``(live - shadow) * (1 - beta)`` rounds + # to zero in bf16 (unit roundoff ~2^-8 of |shadow|) long before the live + # weights have converged. Pass ``dtype=None`` to fall back to the live + # parameter's dtype. + with torch.no_grad(): + for name, p in model.named_parameters(): + clean = _strip_checkpoint_prefix(name) + full = _gather_full(p.detach(), fsdp2=config.fsdp2) + target_dtype = _resolve_dtype(config.dtype, full.dtype) + self._shadow[clean] = copy.deepcopy(full).to(dtype=target_dtype) + # Buffers are replicated (not averaged) across ranks and stepped via + # ``copy_``, so the bf16-roundoff argument that drives ``EMAConfig.dtype`` + # on parameters doesn't apply — keep buffers in the live dtype for exact + # restore. + for name, b in model.named_buffers(): + clean = _strip_checkpoint_prefix(name) + self._buffer_shadow[clean] = copy.deepcopy(b.detach()) + + # ------------------------------------------------------------------ # + # Decay schedules # + # ------------------------------------------------------------------ # + + def _beta(self, iteration: int) -> float: + cfg = self.config + if cfg.type == "constant": + return cfg.decay + if cfg.type == "power": + # (1 - 1/iter) ** (gamma + 1); iteration must be > 0 for this to be finite. + safe_iter = max(iteration, 1) + return (1.0 - 1.0 / safe_iter) ** (cfg.gamma + 1.0) + if cfg.type == "halflife": + ema_halflife_nimg = cfg.halflife_kimg * 1000.0 + cur_nimg = iteration * cfg.batch_size + if cfg.rampup_ratio is not None: + ema_halflife_nimg = min(ema_halflife_nimg, cur_nimg * cfg.rampup_ratio) + return 0.5 ** (cfg.batch_size / max(ema_halflife_nimg, 1e-8)) + raise ValueError(f"Unsupported EMA type: {cfg.type!r}") + + # ------------------------------------------------------------------ # + # Public API # + # ------------------------------------------------------------------ # + + @torch.no_grad() + def update(self, model: nn.Module, *, iteration: int) -> None: + """Update the shadow state from ``model`` at the given iteration. + + Skips updates before :attr:`EMAConfig.start_iter`. On the iteration that equals + ``start_iter`` the shadow is (re-)initialised from the live weights; after that + it is updated with ``shadow = beta * shadow + (1 - beta) * live``. + """ + if iteration < self.config.start_iter: + return + + # (Re-)initialise the shadow from the live weights. Both arms are intentional: + # ``iteration == start_iter`` inits exactly at start when start_iter > 0 (earlier + # iterations are skipped above), while ``not self._initialized`` covers start_iter + # == 0 — where the auto-incremented counter never passes 0 — plus the first call + # after a resume. + if iteration == self.config.start_iter or not self._initialized: + self._copy_from_model(model) + self._initialized = True + return + + beta = self._beta(iteration) + + for name, p in model.named_parameters(): + clean = _strip_checkpoint_prefix(name) + if clean not in self._shadow: + continue + shadow = self._shadow[clean] + if self.config.mode == "full_tensor": + live = _gather_full(p.detach(), fsdp2=self.config.fsdp2) + else: + live = p.detach().to_local() if _is_distributed_tensor(p) else p.detach() + shadow.lerp_(live.to(device=shadow.device, dtype=shadow.dtype), 1.0 - beta) + + # Buffers are replicated across ranks under FSDP2, so we just copy. + for name, b in model.named_buffers(): + clean = _strip_checkpoint_prefix(name) + if clean in self._buffer_shadow: + shadow = self._buffer_shadow[clean] + shadow.copy_(b.detach().to(device=shadow.device, dtype=shadow.dtype)) + + @torch.no_grad() + def copy_to(self, target: nn.Module) -> None: + """Load the shadow state into ``target`` (which should share the tracked module's structure). + + The target is expected to be an unsharded module (i.e. the caller has unwrapped + any FSDP2 wrappers before calling). For sharded targets, prefer saving the + shadow via :meth:`state_dict` and reloading it through the framework's usual + checkpoint path. + """ + for name, p in target.named_parameters(): + clean = _strip_checkpoint_prefix(name) + if clean in self._shadow: + shadow = self._shadow[clean] + p.data.copy_(shadow.to(device=p.device, dtype=p.dtype)) + for name, b in target.named_buffers(): + clean = _strip_checkpoint_prefix(name) + if clean in self._buffer_shadow: + shadow = self._buffer_shadow[clean] + b.data.copy_(shadow.to(device=b.device, dtype=b.dtype)) + + def state_dict(self) -> dict[str, torch.Tensor]: + """Return the shadow state (parameters + buffers) for checkpointing.""" + merged: dict[str, torch.Tensor] = {} + merged.update(self._shadow) + merged.update(self._buffer_shadow) + return merged + + def load_state_dict(self, state: dict[str, torch.Tensor]) -> None: + """Restore the shadow state from a previously saved dict.""" + for k, v in state.items(): + if k in self._shadow: + shadow = self._shadow[k] + shadow.copy_(v.to(device=shadow.device, dtype=shadow.dtype)) + elif k in self._buffer_shadow: + shadow = self._buffer_shadow[k] + shadow.copy_(v.to(device=shadow.device, dtype=shadow.dtype)) + self._initialized = True + + # ------------------------------------------------------------------ # + # Internals # + # ------------------------------------------------------------------ # + + @torch.no_grad() + def _copy_from_model(self, model: nn.Module) -> None: + for name, p in model.named_parameters(): + clean = _strip_checkpoint_prefix(name) + if clean not in self._shadow: + continue + shadow = self._shadow[clean] + live = _gather_full(p.detach(), fsdp2=self.config.fsdp2) + shadow.copy_(live.to(device=shadow.device, dtype=shadow.dtype)) + for name, b in model.named_buffers(): + clean = _strip_checkpoint_prefix(name) + if clean in self._buffer_shadow: + shadow = self._buffer_shadow[clean] + shadow.copy_(b.detach().to(device=shadow.device, dtype=shadow.dtype)) diff --git a/modelopt/torch/fastgen/factory.py b/modelopt/torch/fastgen/factory.py new file mode 100644 index 00000000000..25ca7f31331 --- /dev/null +++ b/modelopt/torch/fastgen/factory.py @@ -0,0 +1,106 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Convenience factory helpers for constructing the auxiliary DMD networks. + +These helpers are intentionally tiny — the training framework is free to build the +fake score directly (e.g. under a meta-init context for FSDP2) instead of calling +:func:`create_fake_score`. See the ModelOpt ↔ FastGen design doc (FASTGEN_MODELOPT.md, +section "How the framework can build the fake_score") for both options. +""" + +from __future__ import annotations + +import copy +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from torch import nn + +__all__ = ["create_fake_score"] + + +def _looks_fsdp_wrapped(module: nn.Module) -> bool: + """Best-effort detection of an FSDP-wrapped module. + + Matches two shapes: + + - **FSDP1** (``torch.distributed.fsdp.FullyShardedDataParallel``): child modules + carry an ``_fsdp_wrapped_module`` attribute. + - **FSDP2** (``torch.distributed._composable.fsdp.fully_shard``): parameters are + ``DTensor`` instances exposing a ``full_tensor`` method. + + Probes only the first parameter to avoid iterating a large model. Intended for + the ``create_fake_score`` fast-fail check, not as a general-purpose predicate. + """ + if any(hasattr(m, "_fsdp_wrapped_module") for m in module.modules()): + return True + for p in module.parameters(): + if hasattr(p, "full_tensor"): + return True + break # only probe the first param + return False + + +def create_fake_score(teacher: nn.Module, *, deep_copy: bool = True) -> nn.Module: + """Return a trainable fake-score network initialized from the teacher. + + This is the unit-test / single-script path; frameworks that do meta-init + FSDP2 + wrapping will typically construct the fake score themselves and pass it directly + into :class:`~modelopt.torch.fastgen.methods.dmd.DMDPipeline`. + + Args: + teacher: The already-built teacher module. Must already have its weights loaded. + deep_copy: If True, :func:`copy.deepcopy` the teacher; if False, reuse the same + instance (only sensible if the caller can guarantee it is no longer held + elsewhere as the frozen teacher). + + Returns: + A copy of ``teacher`` in training mode with all parameters requiring gradients. + + FSDP2 caveat + ------------ + ``copy.deepcopy(teacher)`` is **not safe** when the teacher is already FSDP2-wrapped + (DTensor parameters + FSDP pre/post hooks + meta-init bookkeeping). For Stage-2 FSDP2 + training, skip this factory and construct the fake score under meta-init, then + rank-0-load weights and let ``sync_module_states`` broadcast:: + + with meta_init_context(): + fake_score = build_teacher_from_config(teacher_config) + if is_rank0(): + fake_score.load_state_dict(teacher.state_dict(), strict=False) + # Wrap with FSDP2(..., sync_module_states=True) to broadcast from rank 0. + + The pattern mirrors FastGen's + ``methods/distribution_matching/dmd2.py::DMD2Model.build_model``. A dedicated + ``create_fake_score_meta`` factory is planned alongside the Stage-2 training example. + + Raises: + RuntimeError: When ``deep_copy=True`` and the teacher looks FSDP-wrapped + (either FSDP1 via ``_fsdp_wrapped_module`` or FSDP2 via DTensor + parameters). The ``deep_copy=False`` branch skips the check because + reusing the teacher directly is compatible with an FSDP-wrapped input. + """ + if deep_copy and _looks_fsdp_wrapped(teacher): + raise RuntimeError( + "create_fake_score(deep_copy=True) is not safe on an FSDP-wrapped teacher " + "(DTensor parameters + FSDP hooks + meta-init bookkeeping don't survive " + "copy.deepcopy). Construct the fake score under meta-init and rank-0-load " + "weights instead — see the 'FSDP2 caveat' section of this function's docstring." + ) + fake_score = copy.deepcopy(teacher) if deep_copy else teacher + fake_score.train() + fake_score.requires_grad_(True) + return fake_score diff --git a/modelopt/torch/fastgen/flow_matching.py b/modelopt/torch/fastgen/flow_matching.py new file mode 100644 index 00000000000..66925dd00d1 --- /dev/null +++ b/modelopt/torch/fastgen/flow_matching.py @@ -0,0 +1,263 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Rectified-flow (RF) helpers: forward process, inversions, timestep sampling. + +This module intentionally does **not** define a ``NoiseScheduler`` class. It exposes the +handful of primitives that DMD2 actually needs as plain functions, so callers can plug +fastgen into any training stack without adopting a new scheduler object. + +RF convention used throughout: ``alpha_t = 1 - t`` and ``sigma_t = t``, so +``x_t = (1 - t) * x_0 + t * eps`` with ``t in [0, 1]``. Internally all arithmetic is in +``float64`` for numerical stability, and the result is cast back to the input dtype. +""" + +from __future__ import annotations + +import math +from typing import TYPE_CHECKING + +import torch +from torch.distributions import Normal + +from .utils import expand_like + +if TYPE_CHECKING: + from .config import SampleTimestepConfig + +__all__ = [ + "add_noise", + "pred_noise_to_pred_x0", + "pred_x0_from_flow", + "rf_alpha", + "rf_sigma", + "sample_from_t_list", + "sample_timesteps", + "x0_to_eps", + "x0_to_flow", +] + + +def rf_alpha(t: torch.Tensor) -> torch.Tensor: + """Rectified-flow data coefficient ``alpha_t = 1 - t``.""" + return 1.0 - t + + +def rf_sigma(t: torch.Tensor) -> torch.Tensor: + """Rectified-flow noise coefficient ``sigma_t = t``.""" + return t + + +def add_noise(x0: torch.Tensor, eps: torch.Tensor, t: torch.Tensor) -> torch.Tensor: + """Forward process under rectified flow: ``x_t = (1 - t) * x_0 + t * eps``. + + ``t`` is broadcast across the spatial axes of ``x_0`` via :func:`expand_like`. + Computation is performed in ``float64`` for numerical stability and the output is + cast back to ``x_0``'s dtype. + """ + original_dtype = x0.dtype + x0_64 = x0.to(torch.float64) + eps_64 = eps.to(torch.float64) + t_64 = t.to(torch.float64) + alpha = expand_like(rf_alpha(t_64), x0_64) + sigma = expand_like(rf_sigma(t_64), x0_64) + x_t = x0_64 * alpha + eps_64 * sigma + return x_t.to(original_dtype) + + +def pred_noise_to_pred_x0( + pred_noise: torch.Tensor, + noisy_latents: torch.Tensor, + t: torch.Tensor, +) -> torch.Tensor: + """Convert an ``eps``-parameterized prediction to an ``x_0`` prediction under RF. + + Solves ``x_t = (1 - t) * x_0 + t * eps`` for ``x_0``: + ``x_0 = (x_t - t * eps) / (1 - t)``. + """ + original_dtype = noisy_latents.dtype + x_t = noisy_latents.to(torch.float64) + pred_noise_64 = pred_noise.to(torch.float64) + t_64 = t.to(torch.float64) + alpha = expand_like(rf_alpha(t_64), x_t) + sigma = expand_like(rf_sigma(t_64), x_t) + x0 = (x_t - sigma * pred_noise_64) / alpha.clamp_min(1e-6) + return x0.to(original_dtype) + + +def pred_x0_from_flow( + pred_flow: torch.Tensor, + noisy_latents: torch.Tensor, + t: torch.Tensor, +) -> torch.Tensor: + """Convert a flow-parameterized prediction (``v = eps - x_0``) to an ``x_0`` prediction. + + Under RF ``x_t = (1 - t) * x_0 + t * eps`` and ``v = eps - x_0`` combine to + ``x_t = x_0 + t * v``, so ``x_0 = x_t - t * v``. + """ + original_dtype = noisy_latents.dtype + x_t = noisy_latents.to(torch.float64) + v = pred_flow.to(torch.float64) + t_64 = t.to(torch.float64) + sigma = expand_like(rf_sigma(t_64), x_t) + x0 = x_t - sigma * v + return x0.to(original_dtype) + + +def x0_to_eps( + x0: torch.Tensor, + x_t: torch.Tensor, + t: torch.Tensor, +) -> torch.Tensor: + """Invert the RF forward process: ``eps = (x_t - (1 - t) * x_0) / t``. + + Used when unrolling the student in ODE mode — given the current ``x_t`` and the + student's ``x_0`` prediction, we can recover the implied ``eps`` deterministically. + """ + original_dtype = x0.dtype + x0_64 = x0.to(torch.float64) + x_t_64 = x_t.to(torch.float64) + t_64 = t.to(torch.float64) + alpha = expand_like(rf_alpha(t_64), x0_64) + sigma = expand_like(rf_sigma(t_64), x0_64) + eps = (x_t_64 - alpha * x0_64) / sigma.clamp_min(1e-6) + return eps.to(original_dtype) + + +def x0_to_flow( + x0: torch.Tensor, + x_t: torch.Tensor, + t: torch.Tensor, +) -> torch.Tensor: + """Convert an ``x_0`` prediction back into a flow-parameterized prediction under RF. + + Under RF ``x_t = (1 - t) * x_0 + t * eps`` and ``v = eps - x_0``, so + ``x_t - x_0 = t * (eps - x_0) = t * v`` and therefore ``v = (x_t - x_0) / t``. + + Used when the fake score is flow-native but the DSM loss is computed in a + different target parameterization: convert raw flow → x_0 via + :func:`pred_x0_from_flow`, then back to the loss space (which may coincide + with flow, in which case the round-trip is identity up to fp64 round-off). + """ + original_dtype = x0.dtype + x0_64 = x0.to(torch.float64) + x_t_64 = x_t.to(torch.float64) + t_64 = t.to(torch.float64) + sigma = expand_like(rf_sigma(t_64), x0_64) + flow = (x_t_64 - x0_64) / sigma.clamp_min(1e-6) + return flow.to(original_dtype) + + +# ---------------------------------------------------------------------------- # +# Timestep sampling # +# ---------------------------------------------------------------------------- # + + +def _truncated_lognormal( + n: int, + mean: float, + std: float, + *, + min_t: float, + max_t: float, + device: torch.device, + dtype: torch.dtype, +) -> torch.Tensor: + """Sample ``n`` values from a lognormal truncated to ``(min_t, max_t)``. + + Implementation ported from ``FastGen/fastgen/networks/noise_schedule.py`` (EDM + ``_truncated_lognormal_sample``). Uses CDF inversion on the underlying normal for + exact truncation. + """ + min_t = max(min_t, 1e-12) + log_min_t = torch.tensor(math.log(min_t), dtype=torch.float64) + log_max_t = torch.tensor(math.log(max_t), dtype=torch.float64) + normal = Normal( + torch.tensor(mean, dtype=torch.float64), + torch.tensor(std, dtype=torch.float64), + ) + cdf_min = normal.cdf(log_min_t) + cdf_max = normal.cdf(log_max_t) + u = torch.rand(n, dtype=torch.float64) * (cdf_max - cdf_min) + cdf_min + t = normal.icdf(u).exp() + return t.to(device=device, dtype=dtype) + + +def sample_timesteps( + n: int, + cfg: SampleTimestepConfig, + *, + device: torch.device, + dtype: torch.dtype = torch.float32, +) -> torch.Tensor: + """Sample ``n`` training timesteps according to ``cfg``. + + Supports ``uniform``, ``logitnormal``, ``lognormal``, ``shifted``, and ``polynomial`` + distributions. ``polynomial`` for RF degenerates to discrete uniform sampling from a + ``linspace(min_t, max_t, 1000)`` grid (EDM's polynomial-spaced ``_sigmas`` is + EDM-specific and not applicable under RF). + """ + min_t = cfg.min_t + max_t = cfg.max_t + + if cfg.time_dist_type == "uniform": + t = torch.rand(n, device=device, dtype=dtype) * (max_t - min_t) + min_t + elif cfg.time_dist_type == "logitnormal": + t = ( + torch.sigmoid(torch.randn(n, device=device, dtype=dtype) * cfg.p_std + cfg.p_mean) + * (max_t - min_t) + + min_t + ) + elif cfg.time_dist_type == "lognormal": + t = _truncated_lognormal( + n, + cfg.p_mean, + cfg.p_std, + min_t=min_t, + max_t=max_t, + device=device, + dtype=dtype, + ) + elif cfg.time_dist_type == "shifted": + t = torch.rand(n, device=device, dtype=dtype) * (max_t - min_t) + min_t + t = t * cfg.shift / (t * (cfg.shift - 1.0) + 1.0) + elif cfg.time_dist_type == "polynomial": + grid = torch.linspace(min_t, max_t, 1000, device=device, dtype=dtype) + idx = torch.randint(0, grid.numel(), (n,), device=device) + t = grid[idx] + else: + raise ValueError(f"Unsupported time_dist_type={cfg.time_dist_type!r}") + + return t.clamp(min_t, max_t) + + +def sample_from_t_list( + n: int, + t_list: list[float], + *, + device: torch.device, + dtype: torch.dtype = torch.float32, +) -> torch.Tensor: + """Sample ``n`` starting timesteps uniformly from ``t_list[:-1]``. + + Used for multi-step student training: ``t_list`` encodes the inference trajectory + (``t_list[-1]`` must be ``0``), and a random intermediate timestep is sampled so the + student is trained at every rung of the trajectory. + """ + assert len(t_list) >= 2, "t_list must have at least 2 entries (including the final 0)" + assert t_list[-1] == 0.0, f"t_list[-1] must be 0.0, got {t_list[-1]}" + t_tensor = torch.tensor(t_list, device=device, dtype=dtype) + ids = torch.randint(0, t_tensor.numel() - 1, (n,), device=device) + return t_tensor[ids] diff --git a/modelopt/torch/fastgen/loader.py b/modelopt/torch/fastgen/loader.py new file mode 100644 index 00000000000..175ffa23661 --- /dev/null +++ b/modelopt/torch/fastgen/loader.py @@ -0,0 +1,149 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""YAML-driven configuration loading for fastgen distillation pipelines. + +YAML is the first-class entry point for DMD configurations — the fastgen library +does not expect callers to hand-build Python dicts. Typical usage:: + + from modelopt.torch.fastgen import DMDConfig, load_dmd_config + + # (a) Load a built-in recipe by relative path + cfg = load_dmd_config("general/distillation/dmd2_qwen_image") + + # (b) Load a user-provided file + cfg = load_dmd_config("/path/to/my_dmd.yaml") + + # (c) Equivalent classmethod + cfg = DMDConfig.from_yaml("/path/to/my_dmd.yaml") + +The loader resolves paths in two places, in order: + +1. ``modelopt_recipes/`` (the built-in recipes package shipped with ModelOpt) — resolved + via :func:`importlib.resources.files`. Suffixes ``.yml`` / ``.yaml`` may be omitted. +2. The filesystem (absolute or working-directory-relative). + +Suffixes ``.yml`` and ``.yaml`` are both accepted. +""" + +from __future__ import annotations + +import contextlib +from importlib.resources import files +from pathlib import Path +from typing import TYPE_CHECKING, Any + +# ``Traversable`` moved out of ``importlib.abc`` in Python 3.11. We only need it for +# type hints, but suppress ImportError so older runtimes can still import this module. +with contextlib.suppress(ImportError): + from importlib.resources.abc import Traversable + +import yaml + +from .config import DMDConfig + +if TYPE_CHECKING: + from importlib.abc import Traversable + +__all__ = ["load_config", "load_dmd_config"] + + +# Root to all built-in recipes shipped with modelopt. +_BUILTIN_RECIPES_LIB = files("modelopt_recipes") + + +_SUFFIXES = (".yml", ".yaml") + + +def _candidate_paths(config_file: str | Path) -> list[Path | Traversable]: + """Return the ordered list of locations to probe for ``config_file``.""" + candidates: list[Path | Traversable] = [] + + # Normalize to string for suffix probing; keep Path/Traversable behavior otherwise. + if isinstance(config_file, str): + base = config_file + if base.endswith(_SUFFIXES): + candidates.append(Path(base)) + candidates.append(_BUILTIN_RECIPES_LIB.joinpath(base)) + else: + candidates.extend(Path(base + suffix) for suffix in _SUFFIXES) + candidates.extend(_BUILTIN_RECIPES_LIB.joinpath(base + suffix) for suffix in _SUFFIXES) + elif isinstance(config_file, Path): + if config_file.suffix in _SUFFIXES: + candidates.append(config_file) + if not config_file.is_absolute(): + candidates.append(_BUILTIN_RECIPES_LIB.joinpath(str(config_file))) + else: + candidates.extend(Path(str(config_file) + suffix) for suffix in _SUFFIXES) + if not config_file.is_absolute(): + candidates.extend( + _BUILTIN_RECIPES_LIB.joinpath(str(config_file) + suffix) for suffix in _SUFFIXES + ) + else: + raise TypeError( + f"Expected str or Path for config_file, got {type(config_file).__name__!r}." + ) + return candidates + + +def load_config(config_file: str | Path) -> dict[str, Any]: + """Load a YAML file and return the parsed mapping. + + Mirrors :func:`modelopt.recipe._config_loader.load_config` in spirit but without + the ExMy-num-bits post-processing that is specific to quantization recipes. + + Args: + config_file: YAML path. Suffix is optional; resolution searches the built-in + ``modelopt_recipes/`` package first, then the filesystem. + + Returns: + The parsed dictionary. An empty file yields ``{}``. + """ + for candidate in _candidate_paths(config_file): + if candidate.is_file(): + data = yaml.safe_load(candidate.read_text(encoding="utf-8")) + if data is None: + return {} + if not isinstance(data, dict): + raise ValueError( + f"Config file {candidate!s} must contain a YAML mapping, got {type(data).__name__}." + ) + return data + raise FileNotFoundError( + f"Cannot locate config file {config_file!r}; searched both the built-in " + f"recipe library and the filesystem." + ) + + +def load_dmd_config(config_file: str | Path) -> DMDConfig: + """Load a YAML file and construct a :class:`DMDConfig`. + + The YAML is validated against :class:`DMDConfig`'s Pydantic schema — unknown keys + raise ``ValidationError``. + + Example YAML:: + + pred_type: flow + guidance_scale: 5.0 + student_sample_steps: 2 + gan_loss_weight_gen: 0.03 + sample_t_cfg: + time_dist_type: shifted + t_list: [0.999, 0.833, 0.0] + ema: + decay: 0.9999 + """ + data = load_config(config_file) + return DMDConfig(**data) diff --git a/modelopt/torch/fastgen/losses.py b/modelopt/torch/fastgen/losses.py new file mode 100644 index 00000000000..8b5b1c9faf2 --- /dev/null +++ b/modelopt/torch/fastgen/losses.py @@ -0,0 +1,178 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Pure loss functions used by the fastgen distillation pipelines. + +All functions in this module are stateless: they take tensors in and return a scalar +loss tensor. They do not touch any ``nn.Module``. Higher-level orchestration (teacher +forward, CFG, noise scheduling) lives in :mod:`modelopt.torch.fastgen.methods.dmd`. + +Math ported from ``FastGen/fastgen/methods/common_loss.py`` (lines 12-136) and +``FastGen/fastgen/methods/distribution_matching/dmd2.py`` lines 287-317 (R1). +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import torch +import torch.nn.functional as F + +from .utils import expand_like + +if TYPE_CHECKING: + from collections.abc import Callable + +__all__ = [ + "dsm_loss", + "gan_disc_loss", + "gan_gen_loss", + "r1_loss", + "vsd_loss", +] + + +def dsm_loss( + pred_type: str, + net_pred: torch.Tensor, + *, + x0: torch.Tensor | None = None, + eps: torch.Tensor | None = None, + t: torch.Tensor | None = None, + alpha_fn: Callable[[torch.Tensor], torch.Tensor] | None = None, + sigma_fn: Callable[[torch.Tensor], torch.Tensor] | None = None, +) -> torch.Tensor: + """Denoising score-matching loss for ``x0`` / ``eps`` / ``v`` / ``flow`` predictions. + + The forward process is ``x_t = alpha_t * x_0 + sigma_t * eps``. For + ``pred_type='v'`` we need ``alpha_t`` and ``sigma_t``, which are supplied as + callables rather than a full noise-scheduler object so this function stays + scheduler-agnostic. + + Args: + pred_type: One of ``"x0"``, ``"eps"``, ``"v"``, ``"flow"``. + net_pred: The network output; its interpretation is determined by ``pred_type``. + x0: Clean data. Required for all ``pred_type`` except ``"eps"``. + eps: Noise used in the forward process. Required for all ``pred_type`` except ``"x0"``. + t: Timesteps in ``[0, 1]`` (or scheduler convention). Required for ``pred_type='v'``. + alpha_fn: Callable mapping ``t`` -> ``alpha_t``. Required for ``pred_type='v'``. + sigma_fn: Callable mapping ``t`` -> ``sigma_t``. Required for ``pred_type='v'``. + + Returns: + Scalar MSE loss. + """ + if pred_type == "x0": + assert x0 is not None, "x0 is required for pred_type='x0'" + return F.mse_loss(x0, net_pred, reduction="mean") + if pred_type == "eps": + assert eps is not None, "eps is required for pred_type='eps'" + return F.mse_loss(eps, net_pred, reduction="mean") + if pred_type == "v": + assert x0 is not None and eps is not None and t is not None, ( + "x0, eps, and t are required for pred_type='v'" + ) + assert alpha_fn is not None and sigma_fn is not None, ( + "alpha_fn and sigma_fn are required for pred_type='v'" + ) + alpha_t = expand_like(alpha_fn(t), x0).to(device=x0.device, dtype=x0.dtype) + sigma_t = expand_like(sigma_fn(t), x0).to(device=x0.device, dtype=x0.dtype) + v = alpha_t * eps - sigma_t * x0 + return F.mse_loss(v, net_pred, reduction="mean") + if pred_type == "flow": + assert x0 is not None and eps is not None, "x0 and eps are required for pred_type='flow'" + flow_velocity = eps - x0 + return F.mse_loss(flow_velocity, net_pred, reduction="mean") + raise ValueError(f"Unknown pred_type {pred_type!r}; expected one of 'x0', 'eps', 'v', 'flow'.") + + +def vsd_loss( + gen_data: torch.Tensor, + teacher_x0: torch.Tensor, + fake_score_x0: torch.Tensor, + additional_scale: torch.Tensor | None = None, +) -> torch.Tensor: + """Variational score-distillation (VSD) loss used by the DMD student update. + + Implements the FastGen formulation: a per-sample weight + ``w = 1 / (mean_abs(gen_data - teacher_x0) + 1e-6)`` is computed in fp32 for + numerical stability, then the gradient ``(fake_score_x0 - teacher_x0) * w`` is + subtracted from the generated data to form a pseudo-target. The loss is + ``0.5 * MSE(gen_data, pseudo_target)``. + + Args: + gen_data: Student-generated clean data ``x_0``. + teacher_x0: Teacher ``x_0`` prediction (after CFG, if enabled). Detached. + fake_score_x0: Fake-score ``x_0`` prediction. Detached. + additional_scale: Optional per-sample scale applied multiplicatively to the weight. + + Returns: + Scalar VSD loss. + """ + dims = tuple(range(1, teacher_x0.ndim)) + + with torch.no_grad(): + original_dtype = gen_data.dtype + gen_data_fp32 = gen_data.float() + teacher_x0_fp32 = teacher_x0.float() + + diff_abs_mean = (gen_data_fp32 - teacher_x0_fp32).abs().mean(dim=dims, keepdim=True) + w_fp32 = 1.0 / (diff_abs_mean + 1e-6) + + if additional_scale is not None: + w_fp32 = w_fp32 * expand_like(additional_scale.float(), w_fp32) + + w = w_fp32.to(dtype=original_dtype) + vsd_grad = (fake_score_x0 - teacher_x0) * w + pseudo_target = gen_data - vsd_grad + + return 0.5 * F.mse_loss(gen_data, pseudo_target, reduction="mean") + + +def gan_gen_loss(fake_logits: torch.Tensor) -> torch.Tensor: + """Softplus GAN generator loss: ``E[softplus(-fake_logits)]``. + + Args: + fake_logits: Discriminator logits on generated samples. Must be 2D: ``(B, num_heads)``. + """ + assert fake_logits.ndim == 2, f"fake_logits must be 2D, got shape {tuple(fake_logits.shape)}" + return F.softplus(-fake_logits).mean() + + +def gan_disc_loss(real_logits: torch.Tensor, fake_logits: torch.Tensor) -> torch.Tensor: + """Softplus GAN discriminator loss: ``E[softplus(fake_logits)] + E[softplus(-real_logits)]``.""" + assert real_logits.ndim == 2, f"real_logits must be 2D, got shape {tuple(real_logits.shape)}" + assert fake_logits.ndim == 2, f"fake_logits must be 2D, got shape {tuple(fake_logits.shape)}" + return F.softplus(fake_logits).mean() + F.softplus(-real_logits).mean() + + +def r1_loss( + real_logits: torch.Tensor, + perturbed_real_logits: torch.Tensor, +) -> torch.Tensor: + """Approximate R1 regularization (APT formulation). + + Penalizes the discriminator for being sensitive to small noise perturbations of + the real data. The caller is responsible for computing ``perturbed_real_logits`` + by re-running the teacher feature extractor and discriminator on real data that + has been perturbed with ``alpha * randn_like(real)``; this function only applies + the final MSE between the two logit sets. + + See ``FastGen/fastgen/methods/distribution_matching/dmd2.py`` lines 287-317. + """ + assert real_logits.shape == perturbed_real_logits.shape, ( + f"real_logits {tuple(real_logits.shape)} and perturbed_real_logits " + f"{tuple(perturbed_real_logits.shape)} must have matching shapes" + ) + return F.mse_loss(real_logits, perturbed_real_logits, reduction="mean") diff --git a/modelopt/torch/fastgen/methods/__init__.py b/modelopt/torch/fastgen/methods/__init__.py new file mode 100644 index 00000000000..c999ca87f6f --- /dev/null +++ b/modelopt/torch/fastgen/methods/__init__.py @@ -0,0 +1,18 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Concrete distillation method implementations (DMD, future: Self-Forcing, CausVid, ...).""" + +from .dmd import * diff --git a/modelopt/torch/fastgen/methods/dmd.py b/modelopt/torch/fastgen/methods/dmd.py new file mode 100644 index 00000000000..a9d9add2439 --- /dev/null +++ b/modelopt/torch/fastgen/methods/dmd.py @@ -0,0 +1,812 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Distribution Matching Distillation (DMD2) pipeline. + +:class:`DMDPipeline` holds references to the student / teacher / fake-score / (optional) +discriminator and exposes the three loss-computation entry points that a training loop +calls from each update step: + +- :meth:`DMDPipeline.compute_student_loss` — variational score-distillation loss plus an optional + GAN generator term. +- :meth:`DMDPipeline.compute_fake_score_loss` — denoising score matching against the student's + generated samples. +- :meth:`DMDPipeline.compute_discriminator_loss` — GAN discriminator loss plus an optional R1 + regularizer. + +The pipeline does **not** own optimizers, schedulers, gradient toggles, or device placement. +Callers drive the alternation between student / fake-score / discriminator updates, toggle +``requires_grad``, and call the appropriate ``compute_*_loss`` each step. + +Math is a close port of ``FastGen/fastgen/methods/distribution_matching/dmd2.py`` (lines +45-455). See the docstrings on the individual methods for line-level references. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +import torch +import torch.distributed as dist +from torch import nn + +from ..ema import ExponentialMovingAverage +from ..flow_matching import ( + add_noise, + pred_noise_to_pred_x0, + pred_x0_from_flow, + rf_alpha, + rf_sigma, + sample_from_t_list, + x0_to_eps, + x0_to_flow, +) +from ..losses import dsm_loss, gan_disc_loss, gan_gen_loss, r1_loss, vsd_loss +from ..pipeline import DistillationPipeline +from ..utils import classifier_free_guidance + +if TYPE_CHECKING: + from ..config import DMDConfig + +__all__ = ["DMDPipeline"] + + +# ---------------------------------------------------------------------------- # +# Feature capture helper (duck-typed so tests can bypass the capture plugin) # +# ---------------------------------------------------------------------------- # + + +def _drain_if_hooked(module: nn.Module) -> list[torch.Tensor] | None: + """Drain the feature-capture buffer on ``module`` if hooks are attached. + + Returns the captured tensors in insertion order (clearing the buffer in-place), or + ``None`` when no hooks are installed. Non-raising by design so pipeline-internal + call sites can drain unconditionally after every teacher forward — this prevents + the buffer from growing across steps when hooks are attached but the GAN branch is + disabled (e.g. an ablation). Callers that need the strict "did you forget to attach + hooks?" failure mode should call :func:`_require_hooked` on the result. + """ + captured = getattr(module, "_fastgen_captured", None) + if captured is None: + return None + out = list(captured) + captured.clear() + return out + + +def _require_hooked( + features: list[torch.Tensor] | None, + *, + which: str, +) -> list[torch.Tensor]: + """Adapter that turns a ``None`` drain result into a clear ``RuntimeError``. + + Use at pipeline sites that *must* consume captured features (i.e. the GAN-enabled + paths in ``compute_student_loss`` / ``compute_discriminator_loss``). Keeps the + non-raising ``_drain_if_hooked`` primitive for the "drain-and-discard" sites. + + The message names the attribute the pipeline looks for + (``teacher._fastgen_captured``) so a debugger can grep straight to the hook + installation site. + """ + if features is None: + raise RuntimeError( + f"Feature-capture hooks are required on the teacher ({which} branch): " + "teacher._fastgen_captured is missing. Call " + "modelopt.torch.fastgen.plugins.qwen_image.attach_feature_capture(teacher, ...) " + "before running this loss." + ) + return features + + +# ---------------------------------------------------------------------------- # +# DMDPipeline # +# ---------------------------------------------------------------------------- # + + +class DMDPipeline(DistillationPipeline): + """DMD2 loss pipeline. + + Args: + student: Trainable student module. Must be callable with ``(hidden_states, timestep, + encoder_hidden_states=..., **kwargs)`` and return either a ``Tensor``, a + ``(Tensor, ...)`` tuple (as diffusers returns with ``return_dict=False``), or an + object with a ``.sample`` attribute. + teacher: Frozen reference module with the same call signature. If ``discriminator`` + is provided, feature-capture hooks must be attached to ``teacher`` before + calling ``compute_*_loss`` — see :func:`modelopt.torch.fastgen.plugins.qwen_image.attach_feature_capture`. + fake_score: Trainable auxiliary module (same signature as teacher/student). Used to + approximate the student's generated distribution for the VSD gradient. + config: :class:`~modelopt.torch.fastgen.config.DMDConfig` with the hyperparameters. + discriminator: Optional discriminator. Required when ``config.gan_loss_weight_gen > 0``. + Must accept ``list[Tensor]`` (the captured teacher features) and return a 2D logit tensor. + """ + + def __init__( + self, + student: nn.Module, + teacher: nn.Module, + fake_score: nn.Module, + config: DMDConfig, + *, + discriminator: nn.Module | None = None, + ) -> None: + """Wire up student / teacher / fake-score / discriminator and create the EMA tracker.""" + super().__init__(student, teacher, config) + self.fake_score = fake_score + self.discriminator = discriminator + self._ema: ExponentialMovingAverage | None = ( + ExponentialMovingAverage(student, config.ema) if config.ema is not None else None + ) + self._iteration = 0 + + if config.gan_loss_weight_gen > 0 and discriminator is None: + raise ValueError( + "gan_loss_weight_gen > 0 requires a discriminator to be provided to DMDPipeline." + ) + + # Re-declare config at the class level so type checkers see ``DMDConfig`` here + # even though the base class stores it as ``DistillationConfig``. At runtime the + # attribute is set by :meth:`DistillationPipeline.__init__`. + config: DMDConfig + + @property + def ema(self) -> ExponentialMovingAverage | None: + """Reference to the student EMA tracker, if configured.""" + return self._ema + + # ================================================================== # + # Model-call helpers # + # ================================================================== # + + def _call_model( + self, + model: nn.Module, + hidden_states: torch.Tensor, + timestep: torch.Tensor, + encoder_hidden_states: torch.Tensor | None = None, + **model_kwargs: Any, + ) -> torch.Tensor: + """Forward a diffusers-style transformer and return the raw prediction tensor. + + Assumes the target module accepts ``hidden_states`` / ``timestep`` / + ``encoder_hidden_states`` as kwargs and returns one of: + + * a ``torch.Tensor`` (custom modules), + * a ``tuple`` whose first element is the prediction (diffusers ``return_dict=False``), + * an object with a ``.sample`` attribute (diffusers ``return_dict=True``). + + **Timestep convention.** ``timestep`` is passed verbatim to the model by default. + If :attr:`DistillationConfig.num_train_timesteps` is set, the continuous RF time + ``t ∈ [0, 1]`` is rescaled to ``num_train_timesteps * t`` before the call — which + matches the diffusers training convention for Wan 2.2, SD3, Flux. Leave + ``num_train_timesteps=None`` when the upstream model wrapper (e.g. a VaceWan-style + module) already scales the timestep internally. + + Subclass and override this method for modules with non-diffusers signatures + (positional-only args, alternate kwarg names) or bespoke timestep transforms. + """ + call_kwargs: dict[str, Any] = dict(model_kwargs) + call_kwargs["hidden_states"] = hidden_states + if self.config.num_train_timesteps is not None: + # Cast to match the hidden-state dtype, mirroring FastGen's VaceWan + # wrapper (``noise_scheduler.rescale_t(t).to(dtype=x_t.dtype)``). + timestep = (timestep * float(self.config.num_train_timesteps)).to( + dtype=hidden_states.dtype + ) + call_kwargs["timestep"] = timestep + if encoder_hidden_states is not None: + call_kwargs["encoder_hidden_states"] = encoder_hidden_states + + out = model(**call_kwargs) + if isinstance(out, torch.Tensor): + return out + if isinstance(out, tuple): + return out[0] + if hasattr(out, "sample"): + return out.sample + raise TypeError( + f"DMDPipeline._call_model could not extract a tensor from output of type " + f"{type(out).__name__!r}. Override ``_call_model`` in a subclass to handle " + f"custom module signatures." + ) + + @staticmethod + def _raw_to_x0( + raw: torch.Tensor, + x_t: torch.Tensor, + t: torch.Tensor, + *, + native_pred_type: str, + ) -> torch.Tensor: + """Convert a raw model output in ``native_pred_type`` space to an ``x_0`` estimate. + + ``native_pred_type`` is the parameterization the module *actually* predicts + (i.e. its architecture's native output), not the space a downstream loss + wants to operate in. Under RF, ``flow`` and ``v`` are equivalent (both are + ``eps - x_0``). + """ + if native_pred_type == "x0": + return raw + if native_pred_type == "eps": + return pred_noise_to_pred_x0(raw, x_t, t) + if native_pred_type in ("flow", "v"): + return pred_x0_from_flow(raw, x_t, t) + raise ValueError(f"Unsupported native_pred_type={native_pred_type!r}") + + @staticmethod + def _x0_to_raw( + x0: torch.Tensor, + x_t: torch.Tensor, + t: torch.Tensor, + *, + target_pred_type: str, + ) -> torch.Tensor: + """Inverse of :meth:`_raw_to_x0` — project an ``x_0`` estimate into ``target_pred_type`` space.""" + if target_pred_type == "x0": + return x0 + if target_pred_type == "eps": + return x0_to_eps(x0, x_t, t) + if target_pred_type in ("flow", "v"): + return x0_to_flow(x0, x_t, t) + raise ValueError(f"Unsupported target_pred_type={target_pred_type!r}") + + def _convert_pred( + self, + raw: torch.Tensor, + x_t: torch.Tensor, + t: torch.Tensor, + *, + from_pred_type: str, + to_pred_type: str, + ) -> torch.Tensor: + """Project a prediction between parameterizations via the ``x_0`` hub. + + Used by :meth:`compute_fake_score_loss` to land the fake-score's raw + output in the DSM loss's target space. Short-circuits to identity when + both spaces agree. + """ + if from_pred_type == to_pred_type: + return raw + x0 = self._raw_to_x0(raw, x_t, t, native_pred_type=from_pred_type) + if to_pred_type == "x0": + return x0 + return self._x0_to_raw(x0, x_t, t, target_pred_type=to_pred_type) + + def _predict_x0( + self, + model: nn.Module, + hidden_states: torch.Tensor, + timestep: torch.Tensor, + encoder_hidden_states: torch.Tensor | None = None, + *, + native_pred_type: str | None = None, + **model_kwargs: Any, + ) -> torch.Tensor: + """Forward ``model`` and return its ``x_0`` estimate. + + ``native_pred_type`` declares the module's **architectural** output + parameterization — NOT any downstream loss's target space. In the DMD2 + setup the student / teacher / fake_score are arch-twins, so this defaults + to :attr:`DistillationConfig.pred_type`; callers should only override it + when wiring in a model whose architecture genuinely differs. + """ + raw = self._call_model( + model, hidden_states, timestep, encoder_hidden_states, **model_kwargs + ) + native_pred_type = native_pred_type or self.config.pred_type + return self._raw_to_x0(raw, hidden_states, timestep, native_pred_type=native_pred_type) + + # ================================================================== # + # Noise / timestep sampling # + # ================================================================== # + + def _build_backward_simulated_student_input( + self, + noise: torch.Tensor, + encoder_hidden_states: torch.Tensor | None = None, + **model_kwargs: Any, + ) -> tuple[torch.Tensor, torch.Tensor]: + """Build a multi-step student input by no-grad unrolling the student. + + This mirrors the SDXL DMD2 ``--backward_simulation`` idea in RF space: + choose a schedule rung, generate an x0 distribution by running the current + student through all earlier rungs, then re-noise that generated x0 at the + selected rung. ``student_sample_type`` controls whether intermediate + transitions reuse the implied ODE noise or draw fresh SDE noise. + """ + cfg = self.config + t_list = cfg.sample_t_cfg.t_list + if t_list is None: + raise ValueError( + "backward_simulation=True requires DMDConfig.sample_t_cfg.t_list to be set." + ) + if len(t_list) != cfg.student_sample_steps + 1: + raise ValueError( + "backward_simulation=True expects len(sample_t_cfg.t_list) == " + "student_sample_steps + 1, got " + f"{len(t_list)} vs {cfg.student_sample_steps + 1}." + ) + + batch_size = noise.shape[0] + device = noise.device + dtype = noise.dtype + num_train_rungs = len(t_list) - 1 + selected_idx_tensor = torch.randint( + 0, num_train_rungs, (1,), device=device, dtype=torch.long + ) + if dist.is_available() and dist.is_initialized(): + dist.broadcast(selected_idx_tensor, src=0) + selected_idx = int(selected_idx_tensor.item()) + t_student = torch.full( + (batch_size,), float(t_list[selected_idx]), device=device, dtype=torch.float32 + ) + + # First rung is the initial RF noise state, matching inference's + # ``latents = noise * schedule[0]`` and SDXL's pure-noise special case. + if selected_idx == 0: + input_student = (noise.to(torch.float64) * float(t_list[0])).to(dtype) + return input_student, t_student + + current = (torch.randn_like(noise).to(torch.float64) * float(t_list[0])).to(dtype) + generated_x0: torch.Tensor | None = None + with torch.no_grad(): + for step_idx in range(selected_idx): + t_cur = torch.full( + (batch_size,), float(t_list[step_idx]), device=device, dtype=torch.float32 + ) + generated_x0 = self._predict_x0( + self.student, + current, + t_cur, + encoder_hidden_states=encoder_hidden_states, + **model_kwargs, + ) + + if step_idx == selected_idx - 1: + break + + t_next = torch.full( + (batch_size,), + float(t_list[step_idx + 1]), + device=device, + dtype=torch.float32, + ) + if cfg.student_sample_type == "ode": + step_noise = x0_to_eps(generated_x0, current, t_cur) + elif cfg.student_sample_type == "sde": + step_noise = torch.randn_like(noise) + else: + raise ValueError( + "student_sample_type must be one of {'ode', 'sde'}, got " + f"{cfg.student_sample_type!r}." + ) + current = add_noise(generated_x0, step_noise, t_next) + + if generated_x0 is None: + raise RuntimeError("backward simulation did not produce a generated x0.") + input_student = add_noise(generated_x0.detach(), noise, t_student) + return input_student, t_student + + def _build_student_input( + self, + latents: torch.Tensor, + noise: torch.Tensor, + encoder_hidden_states: torch.Tensor | None = None, + **model_kwargs: Any, + ) -> tuple[torch.Tensor, torch.Tensor]: + """Construct ``(input_student, t_student)`` for the student forward pass. + + - ``student_sample_steps == 1``: Use the maximum training timestep and set the + student's input to ``sigma(max_t) * noise = max_t * noise`` (RF). + - ``student_sample_steps > 1`` and ``backward_simulation=False``: sample a + random intermediate timestep from ``config.sample_t_cfg.t_list`` and + noise the real latents up to that timestep. + - ``student_sample_steps > 1`` and ``backward_simulation=True``: no-grad + unroll the current student to the selected rung and noise that generated + x0, matching the SDXL DMD2 backward-simulation training regime. + """ + cfg = self.config + batch_size = latents.shape[0] + device = latents.device + + if cfg.student_sample_steps == 1: + max_t = cfg.sample_t_cfg.max_t + t_student = torch.full((batch_size,), max_t, device=device, dtype=torch.float32) + # Under RF, ``sigma(max_t) = max_t``. Do the scaling in fp64 and cast back + # to mirror FastGen's ``BaseNoiseSchedule.latents`` — matters for bf16 + # student input at ``max_t ≈ 0.999`` where naive bf16 multiply loses + # ~10 bits of mantissa relative to the fp64 path. + original_dtype = noise.dtype + input_student = (noise.to(torch.float64) * float(max_t)).to(original_dtype) + else: + if cfg.sample_t_cfg.t_list is None: + raise ValueError( + "student_sample_steps > 1 requires DMDConfig.sample_t_cfg.t_list to be set." + ) + if cfg.backward_simulation: + return self._build_backward_simulated_student_input( + noise, + encoder_hidden_states=encoder_hidden_states, + **model_kwargs, + ) + t_student = sample_from_t_list( + batch_size, + cfg.sample_t_cfg.t_list, + device=device, + dtype=torch.float32, + ) + input_student = add_noise(latents, noise, t_student) + return input_student, t_student + + # ================================================================== # + # Public API # + # ================================================================== # + + def compute_student_loss( + self, + latents: torch.Tensor, + noise: torch.Tensor, + encoder_hidden_states: torch.Tensor | None = None, + *, + negative_encoder_hidden_states: torch.Tensor | None = None, + negative_encoder_hidden_states_mask: torch.Tensor | None = None, + guidance_scale: float | None = None, + **model_kwargs: Any, + ) -> dict[str, torch.Tensor]: + """Compute the student update losses. + + The returned dict always contains ``"vsd"`` and ``"total"``. When the GAN branch + is enabled (``discriminator is not None`` and ``config.gan_loss_weight_gen > 0``), + ``"gan_gen"`` is also present. + + Gradient flow summary: + + - VSD gradient: flows through ``student`` only (``teacher_x0`` is detached, + ``fake_score_x0`` is computed under ``torch.no_grad()``). + - GAN generator gradient: flows through ``student`` via the feature-capture + hooks on the teacher. The teacher forward is therefore **not** wrapped in + ``torch.no_grad()`` when the GAN branch is active. + + Args: + latents: Real clean-data latents ``x_0``. Used only when + ``student_sample_steps > 1`` to construct ``input_student``. + noise: Pure Gaussian noise tensor matching ``latents`` in shape/dtype. + encoder_hidden_states: Positive conditioning passed unchanged to all three + models. + negative_encoder_hidden_states: Negative conditioning used by classifier-free + guidance. Required when ``guidance_scale`` (or :attr:`DMDConfig.guidance_scale`) + is not ``None``. + negative_encoder_hidden_states_mask: Optional negative-conditioning mask. Used + for models such as Qwen-Image whose positional embedding depends on the + real text sequence length. + guidance_scale: Overrides :attr:`DMDConfig.guidance_scale` for this call. + ``None`` keeps the config-level value. + **model_kwargs: Forwarded verbatim to ``student``, ``teacher``, and ``fake_score``. + + Returns: + Dictionary with keys ``"vsd"``, ``"total"``, and optionally ``"gan_gen"``. + """ + cfg = self.config + batch_size = latents.shape[0] + device = latents.device + gan_enabled = self.discriminator is not None and cfg.gan_loss_weight_gen > 0 + + # 1. Student input. + input_student, t_student = self._build_student_input( + latents, + noise, + encoder_hidden_states=encoder_hidden_states, + **model_kwargs, + ) + + # 2. Student forward -> x0. + gen_data = self._predict_x0( + self.student, + input_student, + t_student, + encoder_hidden_states=encoder_hidden_states, + **model_kwargs, + ) + + # 3. Sample perturbation timesteps and noise, perturb gen_data. + t = self.sample_timesteps(batch_size, device=device, dtype=torch.float32) + eps = torch.randn_like(latents) + perturbed = add_noise(gen_data, eps, t) + + # 4. Fake score prediction (no grad). + # + # VSD always operates in x_0 space, regardless of ``fake_score_pred_type`` + # (which controls the DSM loss space on the fake-score side — see + # :meth:`compute_fake_score_loss`). The fake_score's architecture matches the + # student's in the DMD2 setup, so its native output parameterization is + # ``cfg.pred_type``; ``_predict_x0`` converts from that to x_0 automatically. + with torch.no_grad(): + fake_score_x0 = self._predict_x0( + self.fake_score, + perturbed, + t, + encoder_hidden_states=encoder_hidden_states, + **model_kwargs, + ) + + # 5. Teacher forward. + fake_feat: list[torch.Tensor] | None = None + if gan_enabled: + # Grad must flow through the teacher for the GAN generator term, since the + # captured features depend on perturbed -> gen_data -> student weights. + teacher_x0 = self._predict_x0( + self.teacher, + perturbed, + t, + encoder_hidden_states=encoder_hidden_states, + **model_kwargs, + ) + fake_feat = _require_hooked(_drain_if_hooked(self.teacher), which="student-fake") + else: + with torch.no_grad(): + teacher_x0 = self._predict_x0( + self.teacher, + perturbed, + t, + encoder_hidden_states=encoder_hidden_states, + **model_kwargs, + ) + # Drain any hooks attached but not consumed (e.g. hooks left over from a + # previous GAN-enabled ablation). No-op when hooks aren't installed. + _ = _drain_if_hooked(self.teacher) + + # 6. Classifier-free guidance (applied to teacher_x0 prior to detach). + effective_scale = guidance_scale if guidance_scale is not None else cfg.guidance_scale + if effective_scale is not None: + if negative_encoder_hidden_states is None: + raise ValueError( + "guidance_scale is set but negative_encoder_hidden_states was not provided." + ) + with torch.no_grad(): + negative_model_kwargs = dict(model_kwargs) + if negative_encoder_hidden_states_mask is not None: + negative_model_kwargs["encoder_hidden_states_mask"] = ( + negative_encoder_hidden_states_mask + ) + else: + negative_model_kwargs.pop("encoder_hidden_states_mask", None) + negative_model_kwargs.pop("txt_seq_lens", None) + teacher_x0_neg = self._predict_x0( + self.teacher, + perturbed, + t, + encoder_hidden_states=negative_encoder_hidden_states, + **negative_model_kwargs, + ) + # Negative-branch features are never used for GAN — drain unconditionally so + # the buffer stays clean for subsequent calls. + _ = _drain_if_hooked(self.teacher) + teacher_x0 = classifier_free_guidance(teacher_x0, teacher_x0_neg, effective_scale) + + teacher_x0 = teacher_x0.detach() + + # 7. Losses. + vsd = vsd_loss(gen_data, teacher_x0, fake_score_x0) + + if gan_enabled: + # ``fake_feat`` is guaranteed non-None by ``_require_hooked`` above; + # ``gan_enabled`` implies a discriminator was provided. + assert self.discriminator is not None + gan_gen = gan_gen_loss(self.discriminator(fake_feat)) + total = vsd + cfg.gan_loss_weight_gen * gan_gen + return {"vsd": vsd, "gan_gen": gan_gen, "total": total} + + return {"vsd": vsd, "total": vsd} + + def compute_fake_score_loss( + self, + latents: torch.Tensor, + noise: torch.Tensor, + encoder_hidden_states: torch.Tensor | None = None, + **model_kwargs: Any, + ) -> dict[str, torch.Tensor]: + """Compute the fake-score (auxiliary) update loss. + + The fake score is trained with denoising score matching against the student's + generated samples. The student forward is wrapped in ``torch.no_grad()`` — the + gradient here is w.r.t. ``fake_score`` only. + + Returns a dict with ``"fake_score"`` and ``"total"`` (both equal). + """ + cfg = self.config + batch_size = latents.shape[0] + device = latents.device + + # 1. Build student input. + input_student, t_student = self._build_student_input( + latents, + noise, + encoder_hidden_states=encoder_hidden_states, + **model_kwargs, + ) + + # 2. Generate data from student (no grad). + with torch.no_grad(): + gen_data = self._predict_x0( + self.student, + input_student, + t_student, + encoder_hidden_states=encoder_hidden_states, + **model_kwargs, + ) + + # 3. Perturb gen_data. + t = self.sample_timesteps(batch_size, device=device, dtype=torch.float32) + eps = torch.randn_like(latents) + perturbed = add_noise(gen_data, eps, t) + + # 4. Fake-score forward (grad flows here). + # + # The fake_score's architectural output parameterization is ``cfg.pred_type`` + # (same arch as teacher/student in DMD2). ``fake_score_pred_type`` controls + # which space the DSM loss is computed in — it is a loss-side knob, not a + # model-side one. When the two differ (e.g. the Wan 2.2 recipe with + # flow-native models and ``fake_score_pred_type='x0'``), we project the raw + # output through the ``x_0`` hub into the loss space before calling + # ``dsm_loss``. When they agree, ``_convert_pred`` short-circuits to identity. + fake_pred_type = cfg.fake_score_pred_type or cfg.pred_type + raw = self._call_model( + self.fake_score, + perturbed, + t, + encoder_hidden_states=encoder_hidden_states, + **model_kwargs, + ) + pred_in_loss_space = self._convert_pred( + raw, + perturbed, + t, + from_pred_type=cfg.pred_type, + to_pred_type=fake_pred_type, + ) + + # 5. DSM loss in the chosen parameterization. + loss = dsm_loss( + fake_pred_type, + pred_in_loss_space, + x0=gen_data, + eps=eps, + t=t, + alpha_fn=rf_alpha, + sigma_fn=rf_sigma, + ) + return {"fake_score": loss, "total": loss} + + def compute_discriminator_loss( + self, + latents: torch.Tensor, + noise: torch.Tensor, + encoder_hidden_states: torch.Tensor | None = None, + **model_kwargs: Any, + ) -> dict[str, torch.Tensor]: + """Compute the discriminator update loss (GAN + optional R1). + + Teacher and student forwards are wrapped in ``torch.no_grad()``; gradient flows + only through the discriminator. + + Returns a dict with ``"gan_disc"`` and ``"total"``. When + ``config.gan_r1_reg_weight > 0`` the dict also contains ``"r1"``. + """ + cfg = self.config + if self.discriminator is None: + raise RuntimeError( + "compute_discriminator_loss requires a discriminator to be set on DMDPipeline." + ) + batch_size = latents.shape[0] + device = latents.device + + # 1. Build student input and generate gen_data (no grad). + input_student, t_student = self._build_student_input( + latents, + noise, + encoder_hidden_states=encoder_hidden_states, + **model_kwargs, + ) + with torch.no_grad(): + gen_data = self._predict_x0( + self.student, + input_student, + t_student, + encoder_hidden_states=encoder_hidden_states, + **model_kwargs, + ) + + # 2. Sample fake-branch timesteps and noise. + t = self.sample_timesteps(batch_size, device=device, dtype=torch.float32) + eps = torch.randn_like(latents) + perturbed_fake = add_noise(gen_data, eps, t) + + # 3. Teacher forward on fake data to capture features. + _ = self._predict_x0( + self.teacher, + perturbed_fake, + t, + encoder_hidden_states=encoder_hidden_states, + **model_kwargs, + ) + fake_feat = _require_hooked(_drain_if_hooked(self.teacher), which="disc-fake") + + # 4. Real branch: same t/eps or re-sampled. + if cfg.gan_use_same_t_noise: + t_real = t + eps_real = eps + else: + t_real = self.sample_timesteps(batch_size, device=device, dtype=torch.float32) + eps_real = torch.randn_like(latents) + perturbed_real = add_noise(latents, eps_real, t_real) + + _ = self._predict_x0( + self.teacher, + perturbed_real, + t_real, + encoder_hidden_states=encoder_hidden_states, + **model_kwargs, + ) + real_feat = _require_hooked(_drain_if_hooked(self.teacher), which="disc-real") + + # 5. Discriminator on real / fake (grad required). + real_logits = self.discriminator(real_feat) + fake_logits = self.discriminator(fake_feat) + disc = gan_disc_loss(real_logits, fake_logits) + + result: dict[str, torch.Tensor] = {"gan_disc": disc} + + # 6. Optional R1 regularization. + if cfg.gan_r1_reg_weight > 0: + with torch.no_grad(): + perturbed_real_alpha = latents + cfg.gan_r1_reg_alpha * torch.randn_like(latents) + _ = self._predict_x0( + self.teacher, + perturbed_real_alpha, + t_real, + encoder_hidden_states=encoder_hidden_states, + **model_kwargs, + ) + real_feat_alpha = _require_hooked(_drain_if_hooked(self.teacher), which="disc-r1") + real_logits_alpha = self.discriminator(real_feat_alpha) + r1 = r1_loss(real_logits, real_logits_alpha) + total = disc + cfg.gan_r1_reg_weight * r1 + result["r1"] = r1 + result["total"] = total + else: + result["total"] = disc + return result + + # ================================================================== # + # EMA # + # ================================================================== # + + def update_ema(self, *, iteration: int | None = None) -> None: + """Update the student EMA tracker (no-op if ``config.ema`` is ``None``). + + Typically called after the student optimizer step. If ``iteration`` is not + provided, an internal counter is auto-incremented. + """ + if self._ema is None: + return + if iteration is not None: + self._iteration = iteration + else: + # Counter starts at 0 and pre-increments, so the first auto call passes 1. + # With start_iter=0 the shadow is therefore first initialised via EMA.update's + # ``not self._initialized`` arm, not the ``iteration == start_iter`` one. + self._iteration += 1 + self._ema.update(self.student, iteration=self._iteration) diff --git a/modelopt/torch/fastgen/pipeline.py b/modelopt/torch/fastgen/pipeline.py new file mode 100644 index 00000000000..5c39f70d6ae --- /dev/null +++ b/modelopt/torch/fastgen/pipeline.py @@ -0,0 +1,99 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Base class for diffusion step-distillation pipelines. + +:class:`DistillationPipeline` is deliberately minimal: it is **not** an ``nn.Module``, +does not wrap the student or teacher, does not manage optimizers or lifecycle state, +and does not register itself in any mode registry. It exists only to hold references +to the student / teacher and to freeze the teacher in a single place. + +Concrete methods — for now :class:`~modelopt.torch.fastgen.methods.dmd.DMDPipeline` — +subclass this and add ``compute_*_loss`` methods. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import torch +from torch import nn + +from .flow_matching import sample_timesteps + +if TYPE_CHECKING: + from .config import DistillationConfig + +__all__ = ["DistillationPipeline"] + + +class DistillationPipeline: + """Hold student/teacher references and expose shared utilities. + + Args: + student: Trainable student module. The pipeline does not wrap it — its lifecycle + (``train()`` / ``eval()``, ``requires_grad_``, sharding, optimizer) remains + owned by the caller. + teacher: Reference module. Frozen here via ``eval()`` + ``requires_grad_(False)``. + config: A :class:`DistillationConfig` (or subclass). + """ + + def __init__( + self, + student: nn.Module, + teacher: nn.Module, + config: DistillationConfig, + ) -> None: + """Store student / teacher references and freeze the teacher.""" + self.student = student + self.teacher = teacher.eval().requires_grad_(False) + self.config = config + + # ------------------------------------------------------------------ # + # Device / dtype inferred from the student # + # ------------------------------------------------------------------ # + + @property + def device(self) -> torch.device: + """Device of the first student parameter (best-effort; falls back to CPU).""" + for p in self.student.parameters(): + return p.device + return torch.device("cpu") + + @property + def dtype(self) -> torch.dtype: + """Dtype of the first student parameter (best-effort; falls back to float32).""" + for p in self.student.parameters(): + return p.dtype + return torch.float32 + + # ------------------------------------------------------------------ # + # Shared helpers # + # ------------------------------------------------------------------ # + + def sample_timesteps( + self, + n: int, + *, + device: torch.device | None = None, + dtype: torch.dtype = torch.float32, + ) -> torch.Tensor: + """Sample ``n`` training timesteps according to :attr:`config`.``sample_t_cfg``.""" + return sample_timesteps( + n, + self.config.sample_t_cfg, + device=device or self.device, + dtype=dtype, + ) diff --git a/modelopt/torch/fastgen/plugins/__init__.py b/modelopt/torch/fastgen/plugins/__init__.py new file mode 100644 index 00000000000..8810470b26f --- /dev/null +++ b/modelopt/torch/fastgen/plugins/__init__.py @@ -0,0 +1,27 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Optional plugins for the fastgen subpackage (gated via ``import_plugin``). + +``qwen_image`` holds the Qwen-Image pipeline plus the forward-hook helpers that expose +intermediate teacher activations to the DMD2 GAN discriminator. The import is gated so +environments that choose not to install the optional fastgen dependencies still see a +clean package import. +""" + +from modelopt.torch.utils import import_plugin + +with import_plugin("qwen_image"): + from .qwen_image import * diff --git a/modelopt/torch/fastgen/plugins/qwen_image.py b/modelopt/torch/fastgen/plugins/qwen_image.py new file mode 100644 index 00000000000..08a32b09301 --- /dev/null +++ b/modelopt/torch/fastgen/plugins/qwen_image.py @@ -0,0 +1,381 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Qwen-Image plumbing for the DMD2 pipeline. + +Qwen-Image's ``QwenImageTransformer2DModel`` does not accept the diffusers-standard +``(hidden_states[B,C,H,W], timestep, encoder_hidden_states)`` triple. Instead it expects +*packed* latents ``[B, (H//2)*(W//2), C*4]`` plus three extra kwargs +(``encoder_hidden_states_mask``, ``img_shapes``, ``guidance``) and returns its prediction +in the same packed layout. The packing / unpacking step mirrors +``QwenImagePipeline._pack_latents`` in diffusers. + +DMD2's internal math (noise injection, VSD / DSM losses, EMA updates, fake-score updates) +all operates on the *unpacked* latent ``[B, C, H, W]``, so we keep that as the +:class:`DMDPipeline` external contract and push the pack / call / unpack triple into a +single override of :meth:`DMDPipeline._call_model` on :class:`QwenImageDMDPipeline`. + +Usage from a training recipe:: + + from modelopt.torch.fastgen.plugins.qwen_image import QwenImageDMDPipeline + + pipeline = QwenImageDMDPipeline( + student=student_transformer, + teacher=teacher_transformer, + fake_score=fake_score_transformer, + config=dmd_config, + discriminator=None, + ) + +The companion ``modelopt_recipes/general/distillation/dmd2_qwen_image.yaml`` must keep +``num_train_timesteps: null`` so the continuous RF time ``t ∈ [0, 1]`` is forwarded +verbatim to the transformer (Qwen-Image normalises timesteps to ``[0, 1]`` internally; +the diffusers ``[0, 1000]`` scale used for Wan / SD3 / Flux does NOT apply here). +""" + +from __future__ import annotations + +import contextlib +from typing import TYPE_CHECKING, Any + +import torch +from torch import nn + +from ..methods.dmd import DMDPipeline + +if TYPE_CHECKING: + from ..config import DMDConfig + +__all__ = [ + "QwenImageDMDPipeline", + "attach_feature_capture", + "build_img_shapes", + "pack_latents", + "remove_feature_capture", + "unpack_latents", +] + + +# ---------------------------------------------------------------------------- # +# Latent pack / unpack helpers (2x2 patch grouping) # +# ---------------------------------------------------------------------------- # + + +def pack_latents(latents: torch.Tensor) -> torch.Tensor: + """Pack ``[B, C, H, W]`` latents into ``[B, (H//2)*(W//2), C*4]`` for Qwen-Image. + + Mirrors ``QwenImagePipeline._pack_latents`` (diffusers): groups every ``2x2`` spatial + block on each channel and lays the four values out along the channel axis so the + transformer's ``in_channels = 4 * out_channels`` patch embedding sees them as a + single token. ``H`` and ``W`` must both be even. + """ + if latents.ndim != 4: + raise ValueError( + f"pack_latents expects [B, C, H, W] (got {latents.ndim}D tensor of shape " + f"{tuple(latents.shape)})." + ) + b, c, h, w = latents.shape + if h % 2 or w % 2: + raise ValueError( + f"pack_latents requires even spatial dims, got H={h}, W={w}. " + "Increase the latent resolution or pad before packing." + ) + x = latents.view(b, c, h // 2, 2, w // 2, 2) + x = x.permute(0, 2, 4, 1, 3, 5) + x = x.reshape(b, (h // 2) * (w // 2), c * 4) + return x + + +def unpack_latents(packed: torch.Tensor, height: int, width: int) -> torch.Tensor: + """Inverse of :func:`pack_latents`. ``height`` / ``width`` are the unpacked latent dims.""" + if packed.ndim != 3: + raise ValueError( + f"unpack_latents expects [B, num_patches, C*4] (got {packed.ndim}D tensor of " + f"shape {tuple(packed.shape)})." + ) + b, num_patches, c4 = packed.shape + if c4 % 4: + raise ValueError(f"unpack_latents expects last dim divisible by 4, got {c4}.") + c = c4 // 4 + if height % 2 or width % 2: + raise ValueError( + f"unpack_latents requires even target spatial dims, got H={height}, W={width}." + ) + h2, w2 = height // 2, width // 2 + if num_patches != h2 * w2: + raise ValueError( + f"num_patches ({num_patches}) does not match H//2 * W//2 ({h2 * w2}) for " + f"target shape H={height}, W={width}." + ) + x = packed.view(b, h2, w2, c, 2, 2) + x = x.permute(0, 3, 1, 4, 2, 5) + x = x.reshape(b, c, height, width) + return x + + +def build_img_shapes(batch_size: int, h_lat: int, w_lat: int) -> list[list[tuple[int, int, int]]]: + """Build the ``img_shapes`` kwarg expected by ``QwenImageTransformer2DModel``. + + Each entry is ``[(1, h_lat // 2, w_lat // 2)]`` — one tuple per sample in the batch. + The leading ``1`` is the temporal dim (single frame for T2I). + """ + if h_lat % 2 or w_lat % 2: + raise ValueError( + f"build_img_shapes requires even latent dims, got h_lat={h_lat}, w_lat={w_lat}." + ) + return [[(1, h_lat // 2, w_lat // 2)]] * batch_size + + +# ---------------------------------------------------------------------------- # +# Pipeline subclass # +# ---------------------------------------------------------------------------- # + + +class QwenImageDMDPipeline(DMDPipeline): + """DMD2 pipeline that targets Qwen-Image's packed transformer signature. + + Drops in for :class:`DMDPipeline` and overrides :meth:`_call_model` only. All other + behaviour (noise schedules, VSD / DSM losses, EMA, GAN paths) is inherited unchanged. + + The student / teacher / fake_score modules must be the raw diffusers + ``QwenImageTransformer2DModel`` (or FSDP-sharded copy thereof). The pipeline handles + pack / call / unpack on every internal forward. + + Args: + guidance: Optional scalar guidance value forwarded to the transformer's + ``guidance`` kwarg every call. Only used when the transformer was built with + ``guidance_embeds=true`` (off by default for ``Qwen/Qwen-Image``). Leave + ``None`` to skip the guidance embedding entirely — this is independent of + :attr:`DMDConfig.guidance_scale`, which controls classifier-free guidance on + the teacher. + """ + + def __init__( + self, + student: nn.Module, + teacher: nn.Module, + fake_score: nn.Module, + config: DMDConfig, + *, + discriminator: nn.Module | None = None, + guidance: float | None = None, + ) -> None: + """Wrap the base DMD pipeline with Qwen-Image patch packing / guidance handling.""" + super().__init__( + student=student, + teacher=teacher, + fake_score=fake_score, + config=config, + discriminator=discriminator, + ) + if config.num_train_timesteps is not None: + raise ValueError( + "QwenImageDMDPipeline requires DMDConfig.num_train_timesteps=None — " + f"Qwen-Image normalises timesteps to [0, 1] internally (got " + f"num_train_timesteps={config.num_train_timesteps})." + ) + self._guidance_value = guidance + + def _call_model( + self, + model: nn.Module, + hidden_states: torch.Tensor, + timestep: torch.Tensor, + encoder_hidden_states: torch.Tensor | None = None, + **model_kwargs: Any, + ) -> torch.Tensor: + """Pack [B, C, H, W] -> packed -> call transformer -> unpack -> [B, C, H, W].""" + if hidden_states.ndim != 4: + raise ValueError( + f"QwenImageDMDPipeline._call_model expects 4D hidden_states " + f"[B, C, H, W] (got {hidden_states.ndim}D)." + ) + b, _c, h, w = hidden_states.shape + + packed = pack_latents(hidden_states) + img_shapes = build_img_shapes(b, h, w) + + call_kwargs: dict[str, Any] = dict(model_kwargs) + call_kwargs.pop("hidden_states", None) + encoder_hidden_states_mask = call_kwargs.pop("encoder_hidden_states_mask", None) + call_kwargs.pop("img_shapes", None) + call_kwargs.pop("guidance", None) + call_kwargs.pop("return_dict", None) + txt_seq_lens = call_kwargs.pop("txt_seq_lens", None) + if txt_seq_lens is None and encoder_hidden_states_mask is not None: + txt_seq_lens = encoder_hidden_states_mask.sum(dim=1).int().tolist() + + guidance = None + if self._guidance_value is not None: + guidance = torch.full( + (b,), + float(self._guidance_value), + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + + out = model( + hidden_states=packed, + timestep=timestep, + encoder_hidden_states=encoder_hidden_states, + encoder_hidden_states_mask=encoder_hidden_states_mask, + img_shapes=img_shapes, + txt_seq_lens=txt_seq_lens, + guidance=guidance, + return_dict=False, + **call_kwargs, + ) + + if isinstance(out, tuple): + raw_packed = out[0] + elif isinstance(out, torch.Tensor): + raw_packed = out + elif hasattr(out, "sample"): + raw_packed = out.sample + else: + raise TypeError( + "QwenImageDMDPipeline._call_model could not extract a tensor from output " + f"of type {type(out).__name__!r}." + ) + + return unpack_latents(raw_packed, h, w) + + +# ---------------------------------------------------------------------------- # +# GAN feature capture # +# ---------------------------------------------------------------------------- # + +# These attribute names are what the shared +# :func:`~modelopt.torch.fastgen.methods.dmd._drain_if_hooked` / +# :func:`~modelopt.torch.fastgen.methods.dmd._require_hooked` helpers look for. +_CAPTURED_ATTR = "_fastgen_captured" +_HANDLES_ATTR = "_fastgen_capture_handles" +_INDICES_ATTR = "_fastgen_capture_indices" +_SHAPE_ATTR = "_fastgen_capture_shape" + + +def attach_feature_capture( + teacher: nn.Module, + feature_indices: list[int], + h_lat: int, + w_lat: int, + *, + blocks_attr: str = "transformer_blocks", +) -> None: + """Install forward hooks on ``teacher.transformer_blocks[i]`` for each ``i`` in ``feature_indices``. + + Qwen-Image's ``QwenImageTransformerBlock.forward`` returns + ``(encoder_hidden_states, hidden_states)`` where ``hidden_states`` has shape + ``[B, num_image_patches, 3072]`` with ``num_image_patches == (h_lat // 2) * (w_lat // 2)`` + (joint dual-stream attention; the text branch is the first tuple element and + is discarded). The hook unpacks the image branch and reshapes it to + ``[B, 3072, h_lat // 2, w_lat // 2]`` so the discriminator can consume it + as standard NCHW spatial features. + + Captured tensors land in ``teacher._fastgen_captured`` (a list) for the + DMD2 ``_drain_if_hooked`` / ``_require_hooked`` helpers to pop after each + teacher forward. + + Args: + teacher: The teacher transformer module. + feature_indices: Block indices to capture (e.g. ``[30]`` or + ``[15, 30, 45]`` for the 60-block Qwen-Image teacher). + h_lat: Latent height passed to the transformer this step. Must be even. + w_lat: Latent width passed to the transformer this step. Must be even. + blocks_attr: Attribute under which the teacher exposes its block stack. + Default ``"transformer_blocks"`` matches diffusers' + ``QwenImageTransformer2DModel``. + + Raises: + AttributeError: ``teacher`` does not expose ``blocks_attr``. + IndexError: An entry of ``feature_indices`` is out of range. + ValueError: ``h_lat`` or ``w_lat`` is odd. + """ + if h_lat % 2 != 0 or w_lat % 2 != 0: + raise ValueError( + f"attach_feature_capture requires even latent dims, got h_lat={h_lat}, w_lat={w_lat}." + ) + + remove_feature_capture(teacher) + + blocks = getattr(teacher, blocks_attr, None) + if blocks is None: + raise AttributeError( + f"Teacher {type(teacher).__name__!r} does not expose a ``{blocks_attr}`` attribute; " + f"pass blocks_attr=... if the block stack is named differently." + ) + try: + num_blocks = len(blocks) + except TypeError as exc: + raise TypeError( + f"Teacher ``{blocks_attr}`` is not a sequence (got {type(blocks).__name__!r})." + ) from exc + + sorted_indices = sorted(set(feature_indices)) + for idx in sorted_indices: + if not (0 <= idx < num_blocks): + raise IndexError( + f"feature_indices entry {idx} is out of range for teacher with {num_blocks} blocks." + ) + + captured: list[torch.Tensor] = [] + setattr(teacher, _CAPTURED_ATTR, captured) + setattr(teacher, _INDICES_ATTR, list(sorted_indices)) + setattr(teacher, _SHAPE_ATTR, (h_lat // 2, w_lat // 2)) + + handles: list[Any] = [] + h_half = h_lat // 2 + w_half = w_lat // 2 + for idx in sorted_indices: + block = blocks[idx] + + def _hook(_module: nn.Module, _inputs: Any, output: Any) -> None: + # Qwen-Image block.forward returns (encoder_hidden_states, hidden_states). + if isinstance(output, tuple) and len(output) == 2: + hidden = output[1] + elif isinstance(output, torch.Tensor): + hidden = output + else: + raise TypeError( + f"Unexpected QwenImage block output type {type(output).__name__!r}; " + "expected (encoder_hidden_states, hidden_states) tuple or Tensor." + ) + # hidden: [B, num_image_patches, C] -> [B, C, H_half, W_half]. + b, s, c = hidden.shape + expected_s = h_half * w_half + if s != expected_s: + raise RuntimeError( + f"QwenImage feature-capture got hidden_states seq_len={s} but expected " + f"{expected_s} = (h_lat // 2) * (w_lat // 2). Did the input resolution " + f"drift from the attach_feature_capture-time setting?" + ) + feat = hidden.permute(0, 2, 1).reshape(b, c, h_half, w_half) + captured.append(feat) + + handles.append(block.register_forward_hook(_hook)) + + setattr(teacher, _HANDLES_ATTR, handles) + + +def remove_feature_capture(teacher: nn.Module) -> None: + """Remove previously installed feature-capture hooks (no-op if none are installed).""" + handles = getattr(teacher, _HANDLES_ATTR, None) + if handles: + for h in handles: + h.remove() + for attr in (_HANDLES_ATTR, _CAPTURED_ATTR, _INDICES_ATTR, _SHAPE_ATTR): + if hasattr(teacher, attr): + with contextlib.suppress(AttributeError): + delattr(teacher, attr) diff --git a/modelopt/torch/fastgen/utils.py b/modelopt/torch/fastgen/utils.py new file mode 100644 index 00000000000..2e9ce6a7001 --- /dev/null +++ b/modelopt/torch/fastgen/utils.py @@ -0,0 +1,54 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Small tensor helpers shared across the fastgen subpackage.""" + +from __future__ import annotations + +import torch + +__all__ = ["classifier_free_guidance", "expand_like"] + + +def expand_like(x: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + """Pad ``x`` with trailing singleton dims until it has the same ndim as ``target``. + + Used to broadcast per-sample scalars like ``alpha_t`` / ``sigma_t`` across the + spatial / temporal axes of a video latent. + + Example:: + + x = torch.ones(5) # shape (5,) + target = torch.ones(5, 4, 16, 16) + expand_like(x, target).shape # (5, 1, 1, 1) + """ + x = torch.atleast_1d(x) + while x.ndim < target.ndim: + x = x[..., None] + return x + + +def classifier_free_guidance( + cond_pred: torch.Tensor, + uncond_pred: torch.Tensor, + guidance_scale: float, +) -> torch.Tensor: + """Combine conditional and unconditional predictions via classifier-free guidance. + + Uses the DMD2 convention ``cond + (scale - 1) * (cond - uncond)``, which is + mathematically equivalent to the standard CFG formula + ``uncond + scale * (cond - uncond)``. + """ + return cond_pred + (guidance_scale - 1.0) * (cond_pred - uncond_pred) diff --git a/modelopt_recipes/general/distillation/dmd2_qwen_image.yaml b/modelopt_recipes/general/distillation/dmd2_qwen_image.yaml new file mode 100644 index 00000000000..79f8124e6fe --- /dev/null +++ b/modelopt_recipes/general/distillation/dmd2_qwen_image.yaml @@ -0,0 +1,80 @@ +# DMD2 distillation recipe for Qwen-Image (text-to-image). +# +# Maps to :class:`modelopt.torch.fastgen.DMDConfig`. Load with:: +# +# from modelopt.torch.fastgen import load_dmd_config +# cfg = load_dmd_config("general/distillation/dmd2_qwen_image") +# +# Reference: NeMo AutoModel's Qwen-Image flow-matching config +# (Automodel/examples/diffusion/finetune/qwen_image_t2i_flow.yaml). + +# Qwen-Image is rectified-flow. Under RF, "flow" and "v" are equivalent. +pred_type: flow + +# Qwen-Image normalises timesteps to [0, 1] internally (see +# QwenImageAdapter.prepare_inputs: ``timesteps = context.timesteps / 1000``). +# QwenImageDMDPipeline asserts num_train_timesteps is null so the continuous RF +# time ``t ∈ [0, 1]`` is forwarded verbatim. +num_train_timesteps: + +# Classifier-free guidance strength applied to the teacher during the student update. +# null disables CFG (skips the negative-conditioning teacher forward). +guidance_scale: + +# Phase 2: 4-step student to match FastGen's Qwen-Image DMD2 default +# (config_dmd2.py:52). ``sample_t_cfg.t_list`` below pins the per-step schedule. +student_sample_steps: 4 +student_sample_type: ode +# Default keeps FastGen Qwen parity: train each rung from noised real latents. +# Override with ``--dmd2.backward_simulation=true`` to no-grad unroll the current +# student from the first rung before training the selected rung. +backward_simulation: false + +# One student step per N fake-score / discriminator steps. +student_update_freq: 5 + +# Fake score trains in x0 space while the student/teacher operate in flow space. +fake_score_pred_type: x0 + +# GAN generator weight. 0 disables the discriminator branch (Phase 1). +gan_loss_weight_gen: 0.0 +gan_use_same_t_noise: false +gan_r1_reg_weight: 0.0 +gan_r1_reg_alpha: 0.1 + +sample_t_cfg: + # ``time_dist_type`` governs the *perturbation* timestep ``t`` sampled on + # every loss path — VSD perturbation in compute_student_loss (dmd.py:417), + # fake-score DSM perturbation in compute_fake_score_loss (dmd.py:529), and + # GAN/discriminator perturbation in compute_discriminator_loss + # (dmd.py:605). All three call ``self.sample_timesteps`` → ``sample_t`` → + # reads ``time_dist_type``. + # + # It does NOT govern the student's *starting* timestep ``t_student``: under + # student_sample_steps > 1, ``_build_student_input`` calls + # ``sample_from_t_list`` (dmd.py:346) which samples uniformly from + # ``t_list[:-1]`` regardless of ``time_dist_type``. So ``time_dist_type`` + # is a knob that's active at every loss path; the 4-step setup just makes + # it irrelevant for ``t_student`` specifically. + time_dist_type: logitnormal + min_t: 0.001 + max_t: 0.999 + p_mean: 0.0 + p_std: 1.0 + # Exact ``torch.linspace(max_t=0.999, 0.0, 5).tolist()`` — the inference + # pipeline's default schedule when no t_list is passed (see + # ``inference_dmd2_qwen_image.py:259``). Stride = 0.999 / 4 = 0.249975, so + # the four training timesteps drawn from ``t_list[:-1]`` exactly match the + # four inference sample points. The previous values + # ``[0.999, 0.75, 0.5, 0.25, 0.0]`` looked like a uniform 0.25 stride but + # were really ``linspace(1.0, 0, 5)`` with t=1 shaved to 0.999, leaving a + # silent ~0.3% train↔inference skew on each non-endpoint timestep. + t_list: [0.999, 0.74925, 0.4995, 0.24975, 0.0] + +# Student EMA. Omit this block to disable EMA tracking entirely. +ema: + decay: 0.9999 + type: constant + start_iter: 0 + fsdp2: true + mode: full_tensor diff --git a/tests/_test_utils/torch/diffusers_models.py b/tests/_test_utils/torch/diffusers_models.py index 352cc60d793..a42ebdd8ffb 100644 --- a/tests/_test_utils/torch/diffusers_models.py +++ b/tests/_test_utils/torch/diffusers_models.py @@ -42,6 +42,16 @@ except Exception: # pragma: no cover - optional diffusers models AutoencoderKLWan = None +try: + from diffusers.models.transformers import QwenImageTransformer2DModel +except Exception: # pragma: no cover - optional diffusers models + QwenImageTransformer2DModel = None + +try: + from diffusers.models.autoencoders import AutoencoderKLQwenImage +except Exception: # pragma: no cover - optional diffusers models + AutoencoderKLQwenImage = None + import modelopt.torch.opt as mto @@ -250,3 +260,120 @@ def create_tiny_wan22_pipeline_dir(tmp_path: Path) -> Path: save_dir = tmp_path / "tiny_wan22" pipe.save_pretrained(save_dir) return save_dir + + +def get_tiny_qwen_image_transformer(**config_kwargs): + """Create a tiny QwenImageTransformer2DModel for testing. + + Scaled down from the real Qwen-Image config (60 layers, 24 heads, head_dim 128, + joint_attention_dim 3584). Two constraints to keep in mind: + - ``axes_dims_rope`` must sum to ``attention_head_dim``. + - ``joint_attention_dim`` must match the text-embedding dim the model is fed. In + the DMD2 mock-data training path that is the *dataloader's* ``text_embed_dim`` + (the bundled text encoder is bypassed), so pair this with + ``--data.dataloader.text_embed_dim=``. + """ + if QwenImageTransformer2DModel is None: + pytest.skip("QwenImageTransformer2DModel is not available in this diffusers version.") + + kwargs = { + "patch_size": 2, + "in_channels": 64, # vae z_dim (16) * 2x2 patch + "out_channels": 16, # = vae z_dim + "num_layers": 2, + "attention_head_dim": 16, + "num_attention_heads": 2, # hidden_dim = 32 + "joint_attention_dim": 32, + "pooled_projection_dim": 32, + "guidance_embeds": False, + "axes_dims_rope": (8, 4, 4), # sums to attention_head_dim (16) + } + kwargs.update(**config_kwargs) + return QwenImageTransformer2DModel(**kwargs) + + +def get_tiny_qwen_image_vae(**config_kwargs): + """Create a tiny AutoencoderKLQwenImage for testing (z_dim=16 to match the transformer).""" + if AutoencoderKLQwenImage is None: + pytest.skip("AutoencoderKLQwenImage is not available in this diffusers version.") + + kwargs = { + "base_dim": 8, + "z_dim": 16, # = transformer out_channels + "dim_mult": [1, 2], + "num_res_blocks": 1, + "temperal_downsample": [True], # len == len(dim_mult) - 1 + "attn_scales": [], + "latents_mean": [0.0] * 16, # length must == z_dim + "latents_std": [1.0] * 16, + } + kwargs.update(**config_kwargs) + return AutoencoderKLQwenImage(**kwargs) + + +def create_tiny_qwen_image_pipeline_dir(tmp_path: Path) -> Path: + """Create and save a tiny Qwen-Image pipeline to a directory (SKETCH). + + Mirrors ``create_tiny_wan22_pipeline_dir``. Needs in-container validation; the + fragile piece is the Qwen2.5-VL text encoder. This prefers a tiny-random HF model + (as Wan uses ``hf-internal-testing/tiny-random-t5``); if that id drifts or the + config schema differs across transformers versions, copy the text-encoder + construction from diffusers' own QwenImage fast test + (``tests/pipelines/qwenimage/test_qwenimage.py``). + + For the DMD2 mock-data training path the transformer consumes the dataloader's + embeddings rather than the text encoder, so the bundled tiny text encoder only + needs to load; its hidden size is intentionally decoupled from the transformer's + ``joint_attention_dim`` (set the dataloader's ``text_embed_dim`` to match instead). + The saved dir loads with ``QwenImagePipeline.from_pretrained(path)``. + """ + if QwenImageTransformer2DModel is None or AutoencoderKLQwenImage is None: + pytest.skip("QwenImage diffusers classes not available in this diffusers version.") + from diffusers import FlowMatchEulerDiscreteScheduler, QwenImagePipeline + + transformers = pytest.importorskip("transformers") + + # Tiny Qwen2.5-VL text encoder + matching Qwen2 tokenizer (loaded, but bypassed + # during DMD2 mock-data training). + # NOTE (validated 2026-06-06): the hf-internal-testing id below does NOT exist on the + # Hub, so this fixture currently skips. To make the recipe e2e runnable in CI, + # construct the encoder inline from a tiny ``Qwen2_5_VLConfig`` (nested text + vision + # config) — mirror diffusers' ``QwenImagePipelineFastTests.get_dummy_components`` in + # ``tests/pipelines/qwenimage/test_qwenimage.py``. + tiny_id = "hf-internal-testing/tiny-random-Qwen2_5_VLForConditionalGeneration" + try: + text_encoder = transformers.Qwen2_5_VLForConditionalGeneration.from_pretrained(tiny_id) + tokenizer = transformers.Qwen2Tokenizer.from_pretrained(tiny_id) + except Exception as exc: # pragma: no cover - depends on hub availability / version + pytest.skip( + f"tiny Qwen2.5-VL text encoder unavailable ({exc}); " + "copy the fixture from diffusers' QwenImage fast test" + ) + + torch.manual_seed(0) + transformer = get_tiny_qwen_image_transformer() + torch.manual_seed(0) + vae = get_tiny_qwen_image_vae() + + scheduler = FlowMatchEulerDiscreteScheduler( + base_image_seq_len=256, + base_shift=0.5, + max_image_seq_len=8192, + max_shift=0.9, + num_train_timesteps=1000, + shift=1.0, + shift_terminal=0.02, + use_dynamic_shifting=True, + time_shift_type="exponential", + ) + + pipe = QwenImagePipeline( + transformer=transformer, + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + scheduler=scheduler, + ) + save_dir = tmp_path / "tiny_qwen_image" + pipe.save_pretrained(save_dir) + return save_dir diff --git a/tests/examples/diffusers/conftest.py b/tests/examples/diffusers/conftest.py index 8893d188d9e..e704f6d5879 100644 --- a/tests/examples/diffusers/conftest.py +++ b/tests/examples/diffusers/conftest.py @@ -29,3 +29,20 @@ def tiny_wan22_path(tmp_path_factory): tmp_path = tmp_path_factory.mktemp("wan22") return str(create_tiny_wan22_pipeline_dir(tmp_path)) + + +@pytest.fixture(scope="session") +def tiny_qwen_image_path(tmp_path_factory): + """Create a tiny Qwen-Image pipeline and return its path (built once per session). + + SKETCH fixture for the recipe-level DMD2 e2e (``test_fastgen_recipe_e2e.py``). + See ``create_tiny_qwen_image_pipeline_dir`` for caveats — notably the tiny + Qwen2.5-VL text encoder, which needs in-container validation. + """ + try: + from _test_utils.torch.diffusers_models import create_tiny_qwen_image_pipeline_dir + except ImportError: + pytest.skip("Qwen-Image diffusers models not available") + + tmp_path = tmp_path_factory.mktemp("qwen_image") + return str(create_tiny_qwen_image_pipeline_dir(tmp_path)) diff --git a/tests/unit/torch/fastgen/conftest.py b/tests/unit/torch/fastgen/conftest.py new file mode 100644 index 00000000000..26110005bc5 --- /dev/null +++ b/tests/unit/torch/fastgen/conftest.py @@ -0,0 +1,114 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Shared fixtures for ``tests/unit/torch/fastgen/``. + +Keeps the duplicated ``_ToyTransformer`` / ``_ToyDiscriminator`` / pipeline-builder +helpers in one place so individual test files can focus on assertions rather than +wiring. +""" + +from __future__ import annotations + +import copy + +import pytest +import torch +from torch import nn + +from modelopt.torch.fastgen import DMDConfig, DMDPipeline + + +class ToyTransformer(nn.Module): + """Minimal diffusers-shaped transformer: output = Linear(hidden_states). + + Accepts ``hidden_states`` / ``timestep`` / ``encoder_hidden_states`` / **kwargs + but ignores all of them except the first. + """ + + def __init__(self, d: int) -> None: + super().__init__() + self.linear = nn.Linear(d, d, bias=False) + + def forward(self, hidden_states, timestep=None, encoder_hidden_states=None, **kwargs): + return self.linear(hidden_states) + + +class ToyDiscriminator(nn.Module): + """Consumes ``list[Tensor]`` and returns 2D logits ``(B, 1)`` by averaging features.""" + + def forward(self, features): + x = features[0] + return x.flatten(start_dim=1).mean(dim=-1, keepdim=True) + + +@pytest.fixture +def toy_transformer_factory(): + """Return a callable ``d -> ToyTransformer(d)`` factory.""" + return ToyTransformer + + +@pytest.fixture +def toy_discriminator_factory(): + """Return a callable ``() -> ToyDiscriminator()`` factory.""" + return ToyDiscriminator + + +@pytest.fixture +def build_pipeline(): + """Factory that constructs a :class:`DMDPipeline` with toy student/teacher/fake_score. + + Usage:: + + pipeline = build_pipeline( + d=4, pred_type="flow", gan_loss_weight_gen=0.03, discriminator=ToyDiscriminator() + ) + """ + + def _build( + d: int, + *, + pred_type: str = "flow", + fake_score_pred_type: str | None = None, + num_train_timesteps: int | None = None, + gan_loss_weight_gen: float = 0.0, + gan_use_same_t_noise: bool = False, + gan_r1_reg_weight: float = 0.0, + ema=None, + discriminator: nn.Module | None = None, + seed: int = 0, + ) -> DMDPipeline: + torch.manual_seed(seed) + student = ToyTransformer(d) + teacher = ToyTransformer(d) + fake_score = copy.deepcopy(teacher) + cfg = DMDConfig( + pred_type=pred_type, + fake_score_pred_type=fake_score_pred_type, + num_train_timesteps=num_train_timesteps, + gan_loss_weight_gen=gan_loss_weight_gen, + gan_use_same_t_noise=gan_use_same_t_noise, + gan_r1_reg_weight=gan_r1_reg_weight, + ema=ema, + ) + return DMDPipeline( + student, + teacher, + fake_score, + cfg, + discriminator=discriminator, + ) + + return _build diff --git a/tests/unit/torch/fastgen/test_dmd_gradient_routing.py b/tests/unit/torch/fastgen/test_dmd_gradient_routing.py new file mode 100644 index 00000000000..5c842bd8201 --- /dev/null +++ b/tests/unit/torch/fastgen/test_dmd_gradient_routing.py @@ -0,0 +1,162 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Gradient-routing tests on ``DMDPipeline`` with tiny modules. + +Ports the in-process gradient-isolation bullets from checklist §3 (3.1, 3.2): +the student loss must only touch student params, and the fake-score loss must +only touch fake-score params. The source-grep bullets (3.4 / 3.6 / 3.7 / 3.8) +intentionally stay in ``experiments/qwen.3/run_section_3.py`` — they are +recipe-source linting, not unit-testable logic. + +Uses a ``_TinyTransformer`` with a timestep bias (the checklist's §3 module) so +the gradient signal definitely flows through the transformer when the +``compute_*_loss`` paths are exercised. +""" + +from __future__ import annotations + +import torch +from torch import nn + +from modelopt.torch.fastgen.config import DMDConfig, SampleTimestepConfig +from modelopt.torch.fastgen.methods.dmd import DMDPipeline + + +class _TinyTransformer(nn.Module): + """``(hidden_states, timestep, encoder_hidden_states, **kw) -> Tensor`` module. + + Linear projection over the flattened spatial axes plus a timestep-derived + bias. Cheap enough to run on CPU and returns a tensor in flow-space so it + can play either student / teacher / fake_score role. + """ + + def __init__(self, channels: int = 16, dim: int = 8) -> None: + super().__init__() + self.channels = channels + self.dim = dim + flat = channels * dim * dim + self.proj = nn.Linear(flat, flat, bias=True) + self.t_proj = nn.Linear(1, flat, bias=False) + + def forward(self, hidden_states, timestep, encoder_hidden_states=None, **_kw): + b = hidden_states.shape[0] + x = hidden_states.reshape(b, -1) + t = timestep.reshape(b, 1).to(x.dtype) + return (self.proj(x) + self.t_proj(t)).reshape_as(hidden_states) + + +def _make_pipeline() -> tuple[DMDPipeline, _TinyTransformer, _TinyTransformer, _TinyTransformer]: + torch.manual_seed(0) + student = _TinyTransformer() + teacher = _TinyTransformer() + fake_score = _TinyTransformer() + cfg = DMDConfig( + pred_type="flow", + num_train_timesteps=None, + student_sample_steps=1, + student_update_freq=5, + fake_score_pred_type="x0", + gan_loss_weight_gen=0.0, + guidance_scale=None, + sample_t_cfg=SampleTimestepConfig( + time_dist_type="shifted", min_t=0.001, max_t=0.999, shift=1.0 + ), + ema=None, + ) + return ( + DMDPipeline(student=student, teacher=teacher, fake_score=fake_score, config=cfg), + student, + teacher, + fake_score, + ) + + +def _has_grad(module: nn.Module) -> bool: + return any(p.grad is not None and p.grad.abs().sum().item() > 0 for p in module.parameters()) + + +# ---------------------------------------------------------------------------- # +# §3.1 — compute_student_loss: only the student gets grads # +# ---------------------------------------------------------------------------- # + + +def test_compute_student_loss_routes_gradients_to_student_only(): + pipe, student, teacher, fake_score = _make_pipeline() + + # Mirror the recipe's _set_grad_requirements for the student phase. + student.train() + for p in student.parameters(): + p.requires_grad_(True) + fake_score.eval() + for p in fake_score.parameters(): + p.requires_grad_(False) + teacher.eval() + for p in teacher.parameters(): + p.requires_grad_(False) + + torch.manual_seed(1) + latents = torch.randn(2, 16, 8, 8) + noise = torch.randn_like(latents) + text = torch.randn(2, 8, 4) + + losses = pipe.compute_student_loss( + latents, + noise, + encoder_hidden_states=text, + negative_encoder_hidden_states=None, + guidance_scale=None, + ) + losses["total"].backward() + + assert "vsd" in losses + assert "total" in losses + assert _has_grad(student) + assert not _has_grad(teacher) + assert not _has_grad(fake_score) + + +# ---------------------------------------------------------------------------- # +# §3.2 — compute_fake_score_loss: only the fake_score gets grads # +# ---------------------------------------------------------------------------- # + + +def test_compute_fake_score_loss_routes_gradients_to_fake_score_only(): + pipe, student, teacher, fake_score = _make_pipeline() + + # Mirror the fake-score phase grad config. + student.eval() + for p in student.parameters(): + p.requires_grad_(False) + fake_score.train() + for p in fake_score.parameters(): + p.requires_grad_(True) + teacher.eval() + for p in teacher.parameters(): + p.requires_grad_(False) + + torch.manual_seed(2) + latents = torch.randn(2, 16, 8, 8) + noise = torch.randn_like(latents) + text = torch.randn(2, 8, 4) + + losses = pipe.compute_fake_score_loss(latents, noise, encoder_hidden_states=text) + losses["total"].backward() + + assert "fake_score" in losses + assert "total" in losses + assert _has_grad(fake_score) + assert not _has_grad(student) + assert not _has_grad(teacher) diff --git a/tests/unit/torch/fastgen/test_dmd_math.py b/tests/unit/torch/fastgen/test_dmd_math.py new file mode 100644 index 00000000000..91db5a61392 --- /dev/null +++ b/tests/unit/torch/fastgen/test_dmd_math.py @@ -0,0 +1,419 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""DMD math parity tests against the FastGen reference implementation. + +Ports checklist §2 bullets — flow-matching identities, dsm/vsd/gan losses, CFG +formula, and the fake-score flow→x0→DSM conversion chain. The FastGen reference +math is inlined verbatim from +``source/FastGen/fastgen/methods/common_loss.py`` and +``source/FastGen/fastgen/methods/distribution_matching/dmd2.py`` so the test is +hermetic — no FastGen import required. + +Numerical tolerance is ``1e-6`` for floating-point losses; pack/permute paths +(``add_noise``, ``pred_x0_from_flow``, ``x0_to_flow``, CFG) use ``torch.equal`` +because both implementations route through fp64 intermediates with the same +operation order. +""" + +from __future__ import annotations + +import pytest +import torch +import torch.nn.functional as F +from torch import nn + +from modelopt.torch.fastgen.config import DMDConfig, SampleTimestepConfig +from modelopt.torch.fastgen.flow_matching import ( + add_noise, + pred_x0_from_flow, + rf_alpha, + rf_sigma, + x0_to_flow, +) +from modelopt.torch.fastgen.losses import dsm_loss, gan_disc_loss, gan_gen_loss, vsd_loss +from modelopt.torch.fastgen.methods import dmd as dmd_module +from modelopt.torch.fastgen.methods.dmd import DMDPipeline +from modelopt.torch.fastgen.utils import classifier_free_guidance + +# ---------------------------------------------------------------------------- # +# FastGen reference impls (math only) — inlined verbatim # +# ---------------------------------------------------------------------------- # + + +def _expand_like(t: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + while t.ndim < target.ndim: + t = t.unsqueeze(-1) + return t + + +def _fastgen_forward_process(x: torch.Tensor, eps: torch.Tensor, t: torch.Tensor) -> torch.Tensor: + """``BaseNoiseSchedule.forward_process`` with RF alpha/sigma inlined.""" + original_dtype = x.dtype + t64 = t.to(torch.float64) + x64 = x.to(torch.float64) + eps64 = eps.to(torch.float64) + alpha_t = _expand_like(1.0 - t64, x64) + sigma_t = _expand_like(t64, eps64) + return (x64 * alpha_t + eps64 * sigma_t).to(original_dtype) + + +def _fastgen_dsm(pred_type, net_pred, *, x0=None, eps=None, t=None): + """common_loss.py:12-60. RF scheduler so alpha=1-t, sigma=t for 'v'.""" + if pred_type == "x0": + return F.mse_loss(x0, net_pred, reduction="mean") + if pred_type == "eps": + return F.mse_loss(eps, net_pred, reduction="mean") + if pred_type == "v": + alpha_t = _expand_like((1.0 - t).to(dtype=x0.dtype), x0).to(device=x0.device) + sigma_t = _expand_like(t.to(dtype=x0.dtype), x0).to(device=x0.device) + v = alpha_t * eps - sigma_t * x0 + return F.mse_loss(v, net_pred, reduction="mean") + if pred_type == "flow": + return F.mse_loss(eps - x0, net_pred, reduction="mean") + raise NotImplementedError(pred_type) + + +def _fastgen_vsd(gen_data, teacher_x0, fake_score_x0): + """common_loss.py:63-103.""" + dims = tuple(range(1, teacher_x0.ndim)) + with torch.no_grad(): + original_dtype = gen_data.dtype + gen_fp32 = gen_data.float() + teacher_fp32 = teacher_x0.float() + diff_abs_mean = (gen_fp32 - teacher_fp32).abs().mean(dim=dims, keepdim=True) + w = (1 / (diff_abs_mean + 1e-6)).to(dtype=original_dtype) + pseudo_target = gen_data - (fake_score_x0 - teacher_x0) * w + return 0.5 * F.mse_loss(gen_data, pseudo_target, reduction="mean") + + +def _fastgen_cfg(cond, uncond, scale): + """dmd2.py:184 — ``teacher_x0 + (scale - 1) * (teacher_x0 - teacher_x0_neg)``.""" + return cond + (scale - 1) * (cond - uncond) + + +def _fastgen_gan_gen(fake_logits): + return F.softplus(-fake_logits).mean() + + +def _fastgen_gan_disc(real_logits, fake_logits): + return F.softplus(fake_logits).mean() + F.softplus(-real_logits).mean() + + +# ---------------------------------------------------------------------------- # +# §2.1 — RF forward process # +# ---------------------------------------------------------------------------- # + + +def test_rf_forward_matches_fastgen(): + torch.manual_seed(0) + x0 = torch.randn(2, 16, 8, 8, dtype=torch.float32) + eps = torch.randn_like(x0) + t = torch.tensor([0.1, 0.7], dtype=torch.float64) + assert torch.equal(add_noise(x0, eps, t), _fastgen_forward_process(x0, eps, t)) + + +# ---------------------------------------------------------------------------- # +# §2.2 / §2.3 — student input for single-step and multi-step # +# ---------------------------------------------------------------------------- # + + +def _student_input_pipeline(*, sample_steps: int, t_list=None) -> DMDPipeline: + cfg = DMDConfig( + pred_type="flow", + num_train_timesteps=None, + student_sample_steps=sample_steps, + student_update_freq=5, + sample_t_cfg=SampleTimestepConfig( + time_dist_type="shifted", + min_t=0.001, + max_t=0.999, + shift=5.0, + t_list=t_list, + ), + ) + return DMDPipeline( + student=nn.Identity(), teacher=nn.Identity(), fake_score=nn.Identity(), config=cfg + ) + + +def test_build_student_input_single_step_matches_max_t_noise(): + pipe = _student_input_pipeline(sample_steps=1) + latents = torch.randn(2, 16, 8, 8, dtype=torch.float32) + noise = torch.randn_like(latents) + input_student, t_student = pipe._build_student_input(latents, noise) + max_t = float(pipe.config.sample_t_cfg.max_t) + expected_input = (noise.to(torch.float64) * max_t).to(noise.dtype) + expected_t = torch.full((2,), max_t, dtype=torch.float32) + assert torch.equal(input_student, expected_input) + assert torch.equal(t_student, expected_t) + + +def test_build_student_input_multi_step_uses_t_list_prefix_and_add_noise(): + pipe = _student_input_pipeline(sample_steps=2, t_list=[0.999, 0.5, 0.0]) + torch.manual_seed(0) + latents = torch.randn(8, 16, 4, 4, dtype=torch.float32) + noise = torch.randn_like(latents) + input_student, t_student = pipe._build_student_input(latents, noise) + allowed = list(pipe.config.sample_t_cfg.t_list[:-1]) + actual = t_student.detach().cpu().tolist() + # fp32 round-trip — compare with tolerance, not set membership. + assert all(any(abs(v - a) < 1e-5 for a in allowed) for v in actual) + assert torch.equal(input_student, add_noise(latents, noise, t_student)) + + +class _ZeroFlow(nn.Module): + def forward(self, hidden_states, timestep, encoder_hidden_states=None, **_kwargs): + return torch.zeros_like(hidden_states) + + +def _backward_simulation_pipeline(*, sample_type: str = "ode") -> DMDPipeline: + cfg = DMDConfig( + pred_type="flow", + num_train_timesteps=None, + student_sample_steps=2, + student_sample_type=sample_type, + backward_simulation=True, + sample_t_cfg=SampleTimestepConfig( + time_dist_type="uniform", + min_t=0.001, + max_t=0.9, + t_list=[0.9, 0.5, 0.0], + ), + ) + model = _ZeroFlow() + return DMDPipeline(student=model, teacher=model, fake_score=model, config=cfg) + + +def test_build_student_input_backward_simulation_uses_generated_distribution(monkeypatch): + def _fixed_randint(low, high, size, *, device=None, dtype=None, **_kwargs): + assert low == 0 + assert high == 2 + return torch.ones(size, device=device, dtype=dtype or torch.long) + + def _fixed_randn_like(x, *args, **kwargs): + return torch.full_like(x, 2.0) + + monkeypatch.setattr(torch, "randint", _fixed_randint) + monkeypatch.setattr(torch, "randn_like", _fixed_randn_like) + + pipe = _backward_simulation_pipeline() + latents = torch.zeros(2, 16, 4, 4, dtype=torch.float32) + final_noise = torch.full_like(latents, 3.0) + input_student, t_student = pipe._build_student_input(latents, final_noise) + + expected_t = torch.full((2,), 0.5, dtype=torch.float32) + generated_x0 = torch.full_like(latents, 2.0 * 0.9) + expected_input = add_noise(generated_x0, final_noise, expected_t) + assert torch.equal(t_student, expected_t) + assert torch.equal(input_student, expected_input) + assert not torch.equal(input_student, add_noise(latents, final_noise, expected_t)) + + +def test_backward_simulation_selected_rung_is_broadcast(monkeypatch): + calls = [] + + def _fixed_randint(low, high, size, *, device=None, dtype=None, **_kwargs): + assert low == 0 + assert high == 2 + return torch.ones(size, device=device, dtype=dtype or torch.long) + + def _broadcast_to_first_rung(tensor, src): + assert src == 0 + calls.append(tensor.clone()) + tensor.zero_() + + monkeypatch.setattr(torch, "randint", _fixed_randint) + monkeypatch.setattr(dmd_module.dist, "is_available", lambda: True) + monkeypatch.setattr(dmd_module.dist, "is_initialized", lambda: True) + monkeypatch.setattr(dmd_module.dist, "broadcast", _broadcast_to_first_rung) + + pipe = _backward_simulation_pipeline() + latents = torch.zeros(2, 16, 4, 4, dtype=torch.float32) + final_noise = torch.full_like(latents, 3.0) + input_student, t_student = pipe._build_student_input(latents, final_noise) + + expected_t = torch.full((2,), 0.9, dtype=torch.float32) + expected_input = (final_noise.to(torch.float64) * 0.9).to(final_noise.dtype) + assert len(calls) == 1 + assert torch.equal(calls[0], torch.ones(1, dtype=torch.long)) + assert torch.equal(t_student, expected_t) + assert torch.equal(input_student, expected_input) + + +# ---------------------------------------------------------------------------- # +# §2.4 / §2.5 — flow ↔ x0 identities # +# ---------------------------------------------------------------------------- # + + +def test_pred_x0_from_flow_matches_identity(): + torch.manual_seed(1) + x_t = torch.randn(2, 16, 8, 8, dtype=torch.float32) + flow = torch.randn_like(x_t) + t = torch.tensor([0.3, 0.7], dtype=torch.float32) + mo = pred_x0_from_flow(flow, x_t, t) + t64 = _expand_like(t.to(torch.float64), x_t.to(torch.float64)) + ref = (x_t.to(torch.float64) - t64 * flow.to(torch.float64)).to(x_t.dtype) + assert torch.equal(mo, ref) + + +def test_x0_to_flow_matches_identity(): + torch.manual_seed(2) + x0 = torch.randn(2, 16, 8, 8, dtype=torch.float32) + x_t = torch.randn_like(x0) + t = torch.tensor([0.3, 0.7], dtype=torch.float32) + mo = x0_to_flow(x0, x_t, t) + t64 = _expand_like(t.to(torch.float64), x0.to(torch.float64)) + ref = ((x_t.to(torch.float64) - x0.to(torch.float64)) / t64.clamp_min(1e-6)).to(x0.dtype) + assert torch.equal(mo, ref) + + +# ---------------------------------------------------------------------------- # +# §2.6 — dsm_loss for x0 / eps / flow / v # +# ---------------------------------------------------------------------------- # + + +@pytest.mark.parametrize("pred_type", ["x0", "eps", "flow", "v"]) +def test_dsm_loss_matches_fastgen(pred_type): + torch.manual_seed(3) + x0 = torch.randn(2, 16, 8, 8, dtype=torch.float32) + eps = torch.randn_like(x0) + t = torch.tensor([0.4, 0.6], dtype=torch.float32) + net_pred = torch.randn_like(x0) + kwargs = {"x0": x0, "eps": eps, "t": t} + if pred_type == "v": + kwargs["alpha_fn"] = rf_alpha + kwargs["sigma_fn"] = rf_sigma + mo = dsm_loss(pred_type, net_pred, **kwargs).item() + fg = _fastgen_dsm(pred_type, net_pred, x0=x0, eps=eps, t=t).item() + assert abs(mo - fg) < 1e-6 + + +# ---------------------------------------------------------------------------- # +# §2.7 — vsd_loss # +# ---------------------------------------------------------------------------- # + + +def test_vsd_loss_matches_fastgen(): + torch.manual_seed(4) + gen_data = torch.randn(2, 16, 8, 8, dtype=torch.float32, requires_grad=True) + teacher_x0 = torch.randn_like(gen_data).detach() + fake_score_x0 = torch.randn_like(gen_data).detach() + mo = vsd_loss(gen_data, teacher_x0, fake_score_x0).item() + fg = _fastgen_vsd(gen_data.detach(), teacher_x0, fake_score_x0).item() + assert abs(mo - fg) < 1e-6 + + +# ---------------------------------------------------------------------------- # +# §2.8 — fake-score DSM target: ModelOpt flow→x0→DSM matches FastGen direct DSM('x0') # +# ---------------------------------------------------------------------------- # + + +def test_fake_score_flow_to_x0_dsm_matches_fastgen(): + torch.manual_seed(5) + x0_real = torch.randn(2, 16, 8, 8, dtype=torch.float32) + eps = torch.randn_like(x0_real) + t = torch.tensor([0.3, 0.7], dtype=torch.float32) + x_t = add_noise(x0_real, eps, t) + raw_flow = torch.randn_like(x0_real) + x0_pred_modelopt = DMDPipeline._raw_to_x0(raw_flow, x_t, t, native_pred_type="flow") + loss_modelopt = dsm_loss("x0", x0_pred_modelopt, x0=x0_real).item() + + # FastGen "direct" reference: x0 = x_t - t * flow. + t64 = _expand_like(t.to(torch.float64), x_t.to(torch.float64)) + x0_pred_ref = (x_t.to(torch.float64) - t64 * raw_flow.to(torch.float64)).to(x0_real.dtype) + loss_fastgen = _fastgen_dsm("x0", x0_pred_ref, x0=x0_real).item() + assert abs(loss_modelopt - loss_fastgen) < 1e-6 + + +# ---------------------------------------------------------------------------- # +# §2.9 — classifier-free guidance # +# ---------------------------------------------------------------------------- # + + +def test_classifier_free_guidance_matches_fastgen(): + torch.manual_seed(6) + cond = torch.randn(2, 16, 8, 8, dtype=torch.float32) + uncond = torch.randn_like(cond) + assert torch.equal(classifier_free_guidance(cond, uncond, 4.0), _fastgen_cfg(cond, uncond, 4.0)) + + +class _RecordingFlow(nn.Module): + def __init__(self): + super().__init__() + self.masks: list[torch.Tensor | None] = [] + + def forward(self, hidden_states, timestep, encoder_hidden_states=None, **kwargs): + mask = kwargs.get("encoder_hidden_states_mask") + self.masks.append(mask.detach().clone() if torch.is_tensor(mask) else None) + return torch.zeros_like(hidden_states) + + +def test_compute_student_loss_uses_separate_negative_cfg_mask(): + cfg = DMDConfig( + pred_type="flow", + num_train_timesteps=None, + student_sample_steps=1, + guidance_scale=4.0, + sample_t_cfg=SampleTimestepConfig(time_dist_type="uniform", min_t=0.001, max_t=0.999), + ) + student = _RecordingFlow() + teacher = _RecordingFlow() + fake_score = _RecordingFlow() + pipe = DMDPipeline(student=student, teacher=teacher, fake_score=fake_score, config=cfg) + + torch.manual_seed(7) + latents = torch.randn(2, 16, 4, 4) + noise = torch.randn_like(latents) + text = torch.randn(2, 8, 4) + neg_text = torch.randn(2, 3, 4) + text_mask = torch.ones(2, 8, dtype=torch.long) + neg_mask = torch.ones(2, 3, dtype=torch.long) + + pipe.compute_student_loss( + latents, + noise, + encoder_hidden_states=text, + encoder_hidden_states_mask=text_mask, + negative_encoder_hidden_states=neg_text, + negative_encoder_hidden_states_mask=neg_mask, + ) + + assert torch.equal(teacher.masks[0], text_mask) + assert torch.equal(teacher.masks[1], neg_mask) + + +# ---------------------------------------------------------------------------- # +# §2.10 — GAN gen/disc/R1 losses # +# ---------------------------------------------------------------------------- # + + +def test_gan_gen_loss_matches_fastgen(): + torch.manual_seed(7) + fake_logits = torch.randn(8, 1) + assert abs(gan_gen_loss(fake_logits).item() - _fastgen_gan_gen(fake_logits).item()) < 1e-6 + + +def test_gan_disc_loss_matches_fastgen(): + torch.manual_seed(7) + fake_logits = torch.randn(8, 1) + real_logits = torch.randn(8, 1) + assert ( + abs( + gan_disc_loss(real_logits, fake_logits).item() + - _fastgen_gan_disc(real_logits, fake_logits).item() + ) + < 1e-6 + ) diff --git a/tests/unit/torch/fastgen/test_dmd_pipeline_step.py b/tests/unit/torch/fastgen/test_dmd_pipeline_step.py new file mode 100644 index 00000000000..71d7a93e640 --- /dev/null +++ b/tests/unit/torch/fastgen/test_dmd_pipeline_step.py @@ -0,0 +1,164 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Hermetic DMD2 training-step test (pipeline-level). + +Exercises one full DMD2 optimizer step through the real ``QwenImageDMDPipeline`` +loss path — the student VSD phase and the fake-score DSM phase — on tiny stub +transformers, with no real Qwen weights, NCCL, or FSDP2. A single test +transitively covers the plugin's pack/unpack ``_call_model`` path, the VSD/DSM +loss math, gradient isolation between phases, an optimizer update, and a +checkpoint round-trip. +""" + +from __future__ import annotations + +import copy + +import torch +from torch import nn + +from modelopt.torch.fastgen.config import DMDConfig, SampleTimestepConfig +from modelopt.torch.fastgen.plugins.qwen_image import QwenImageDMDPipeline + + +class _TinyQwenTransformer(nn.Module): + """Grad-capable stub over the packed Qwen hidden states ``[B, num_patches, C*4]``. + + ``QwenImageDMDPipeline._call_model`` packs ``[B, C, H, W] -> [B, P, C*4]`` before + the forward and unpacks after, so a ``Linear`` over the last (``C*4``) dim gives a + real gradient path while matching the expected return shape. + """ + + def __init__(self, packed_dim: int = 64) -> None: + super().__init__() + self.proj = nn.Linear(packed_dim, packed_dim) + + def forward(self, hidden_states, **_kwargs): + return self.proj(hidden_states) + + +def _build_pipeline(): + torch.manual_seed(0) + student = _TinyQwenTransformer() + teacher = _TinyQwenTransformer() + fake_score = _TinyQwenTransformer() + cfg = DMDConfig( + pred_type="flow", + num_train_timesteps=None, # required by QwenImageDMDPipeline + student_sample_steps=1, + student_update_freq=5, + fake_score_pred_type="x0", + gan_loss_weight_gen=0.0, # no GAN branch -> no discriminator / feature hooks needed + guidance_scale=None, + sample_t_cfg=SampleTimestepConfig(time_dist_type="uniform", min_t=0.001, max_t=0.999), + ema=None, + ) + pipe = QwenImageDMDPipeline( + student=student, + teacher=teacher, + fake_score=fake_score, + config=cfg, + discriminator=None, + ) + return pipe, student, teacher, fake_score + + +def _mock_batch(batch_size: int = 2): + latents = torch.randn(batch_size, 16, 8, 8) # even dims -> packs to [B, 16, 64] + noise = torch.randn_like(latents) + text = torch.randn(batch_size, 8, 64) # shape is arbitrary; the stub ignores it + return latents, noise, text + + +def _snapshot(module: nn.Module) -> dict[str, torch.Tensor]: + return {k: v.detach().clone() for k, v in module.state_dict().items()} + + +def _params_changed(before: dict[str, torch.Tensor], module: nn.Module) -> bool: + return any(not torch.equal(before[k], v) for k, v in module.state_dict().items()) + + +def _train_only(active: nn.Module, *frozen: nn.Module) -> None: + active.train() + for p in active.parameters(): + p.requires_grad_(True) + for m in frozen: + m.eval() + for p in m.parameters(): + p.requires_grad_(False) + + +def test_dmd2_student_then_fake_score_step_updates_only_active_module(): + """One student VSD step then one fake-score DSM step: each phase yields a + finite loss, steps only its own module, and leaves the others untouched.""" + pipe, student, teacher, fake_score = _build_pipeline() + latents, noise, text = _mock_batch() + + # ---- student (VSD) phase ---- + _train_only(student, teacher, fake_score) + student_before, teacher_before = _snapshot(student), _snapshot(teacher) + opt_s = torch.optim.Adam(student.parameters(), lr=1e-2) + opt_s.zero_grad() + losses = pipe.compute_student_loss( + latents, + noise, + encoder_hidden_states=text, + negative_encoder_hidden_states=None, + guidance_scale=None, + ) + assert "vsd" in losses and "total" in losses + assert torch.isfinite(losses["total"]) + losses["total"].backward() + opt_s.step() + + assert _params_changed(student_before, student) # student learned + assert not _params_changed(teacher_before, teacher) # teacher stayed frozen + + # ---- fake-score (DSM) phase ---- + _train_only(fake_score, student, teacher) + student_after_student_phase = _snapshot(student) + fake_before = _snapshot(fake_score) + opt_f = torch.optim.Adam(fake_score.parameters(), lr=1e-2) + opt_f.zero_grad() + fs_losses = pipe.compute_fake_score_loss(latents, noise, encoder_hidden_states=text) + assert "fake_score" in fs_losses and "total" in fs_losses + assert torch.isfinite(fs_losses["total"]) + fs_losses["total"].backward() + opt_f.step() + + assert _params_changed(fake_before, fake_score) # fake_score learned + assert not _params_changed(student_after_student_phase, student) # student untouched + + +def test_dmd2_student_state_dict_round_trips(): + """After a training step, the student's state_dict reloads bit-exactly into a + fresh module — the save/restore contract the recipe's checkpointing relies on.""" + pipe, student, _teacher, _fake = _build_pipeline() + latents, noise, text = _mock_batch() + + _train_only(student, _teacher, _fake) + opt = torch.optim.Adam(student.parameters(), lr=1e-2) + opt.zero_grad() + pipe.compute_student_loss(latents, noise, encoder_hidden_states=text, guidance_scale=None)[ + "total" + ].backward() + opt.step() + + saved = copy.deepcopy(student.state_dict()) + reloaded = _TinyQwenTransformer() + reloaded.load_state_dict(saved) + for k, v in student.state_dict().items(): + assert torch.equal(v, reloaded.state_dict()[k]), k diff --git a/tests/unit/torch/fastgen/test_hook_requirements.py b/tests/unit/torch/fastgen/test_hook_requirements.py new file mode 100644 index 00000000000..00914b5d53d --- /dev/null +++ b/tests/unit/torch/fastgen/test_hook_requirements.py @@ -0,0 +1,111 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for the "did you attach hooks?" / "is this FSDP-wrapped?" runtime guards. + +Covers R2.1 (GAN branches must raise a clear ``RuntimeError`` when +``teacher._fastgen_captured`` is missing) and R2.5 (``create_fake_score`` must +reject FSDP-wrapped teachers when ``deep_copy=True``). +""" + +from __future__ import annotations + +import pytest +import torch +from torch import nn + +from modelopt.torch.fastgen import DMDConfig, DMDPipeline, create_fake_score + + +class _ToyTransformer(nn.Module): + """Linear-on-hidden-states transformer that matches the pipeline's call convention.""" + + def __init__(self, d: int) -> None: + super().__init__() + self.linear = nn.Linear(d, d, bias=False) + + def forward(self, hidden_states, timestep=None, encoder_hidden_states=None, **kwargs): + return self.linear(hidden_states) + + +class _ToyDiscriminator(nn.Module): + """Consumes ``list[Tensor]`` and returns 2D logits ``(B, 1)``.""" + + def forward(self, features): + x = features[0] + return x.flatten(start_dim=1).mean(dim=-1, keepdim=True) + + +def _build_gan_pipeline(d: int) -> DMDPipeline: + torch.manual_seed(0) + cfg = DMDConfig(pred_type="flow", gan_loss_weight_gen=0.03) + return DMDPipeline( + _ToyTransformer(d), + _ToyTransformer(d), + _ToyTransformer(d), + cfg, + discriminator=_ToyDiscriminator(), + ) + + +def test_compute_student_loss_raises_when_hooks_missing(): + """GAN-enabled ``compute_student_loss`` without ``attach_feature_capture`` + must raise a ``RuntimeError`` naming the attach helper, not strip under + ``-O`` like the previous ``assert``.""" + d, b = 4, 2 + pipeline = _build_gan_pipeline(d) + latents = torch.randn(b, d) + noise = torch.randn(b, d) + + with pytest.raises(RuntimeError, match="attach_feature_capture"): + pipeline.compute_student_loss(latents, noise) + + +def test_compute_discriminator_loss_raises_when_hooks_missing(): + """``compute_discriminator_loss`` without ``attach_feature_capture`` must + raise ``RuntimeError`` at the first drain, before the discriminator is called.""" + d, b = 4, 2 + pipeline = _build_gan_pipeline(d) + latents = torch.randn(b, d) + noise = torch.randn(b, d) + + with pytest.raises(RuntimeError, match="attach_feature_capture"): + pipeline.compute_discriminator_loss(latents, noise) + + +def test_create_fake_score_raises_on_fsdp2_wrapped(): + """``create_fake_score(deep_copy=True)`` on a module whose first parameter + looks like a DTensor (``full_tensor`` attribute) must raise with a message + pointing at the meta-init recipe in the docstring.""" + m = nn.Linear(4, 4) + # Monkey-patch the first parameter to look like a DTensor. + first = next(m.parameters()) + first.full_tensor = lambda: first # type: ignore[attr-defined] + + with pytest.raises(RuntimeError, match="meta-init"): + create_fake_score(m, deep_copy=True) + + +def test_create_fake_score_no_copy_skips_fsdp_check(): + """``deep_copy=False`` reuses the teacher directly — the FSDP-wrap check + is skipped because there is no ``copy.deepcopy`` to protect against.""" + m = nn.Linear(4, 4) + first = next(m.parameters()) + first.full_tensor = lambda: first # type: ignore[attr-defined] + + fake_score = create_fake_score(m, deep_copy=False) + assert fake_score is m + assert fake_score.training # create_fake_score must set .train() regardless + assert all(p.requires_grad for p in fake_score.parameters()) diff --git a/tests/unit/torch/fastgen/test_pred_type_conversion.py b/tests/unit/torch/fastgen/test_pred_type_conversion.py new file mode 100644 index 00000000000..69346b81838 --- /dev/null +++ b/tests/unit/torch/fastgen/test_pred_type_conversion.py @@ -0,0 +1,218 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Regression tests for fastgen pred_type conversion, timestep rescaling, and EMA dtype promotion. + +Covers three round-1 fix guards: + +1. ``fake_score_pred_type`` regression — under ``pred_type='flow'`` with + ``fake_score_pred_type='x0'``, the fake-score DSM loss must operate on the + ``x_0`` projection of the raw flow output, not on the raw flow tensor. +2. ``num_train_timesteps`` rescale — the pipeline must scale the RF + ``t ∈ [0, 1]`` to ``num_train_timesteps * t`` before passing it to the + transformer, when the knob is set. +3. EMA shadow dtype promotion — by default the shadow lives in ``float32`` + even when the live model is ``bfloat16``. +""" + +from __future__ import annotations + +import copy +from unittest import mock + +import torch +import torch.nn.functional as F +from torch import nn + +from modelopt.torch.fastgen import DMDConfig, DMDPipeline, EMAConfig, ExponentialMovingAverage +from modelopt.torch.fastgen.flow_matching import add_noise, pred_x0_from_flow +from modelopt.torch.fastgen.methods import dmd as dmd_module + + +class _ToyTransformer(nn.Module): + """Minimal diffusers-shaped transformer. Output is linear in ``hidden_states`` and + ignores ``timestep`` / ``encoder_hidden_states`` — keeps the expected-value + reconstruction analytic.""" + + def __init__(self, d: int) -> None: + super().__init__() + self.linear = nn.Linear(d, d, bias=False) + + def forward(self, hidden_states, timestep=None, encoder_hidden_states=None, **kwargs): + return self.linear(hidden_states) + + +class _TimestepEchoModel(nn.Module): + """Returns ``hidden_states + timestep`` broadcast — used to observe the rescale knob.""" + + def forward(self, hidden_states, timestep, encoder_hidden_states=None, **kwargs): + return hidden_states + timestep.view(-1, 1).to(hidden_states.dtype) + + +def _build_pipeline( + d: int, + *, + pred_type: str = "flow", + fake_score_pred_type: str | None = "x0", + num_train_timesteps: int | None = None, +): + torch.manual_seed(0) + student = _ToyTransformer(d) + teacher = _ToyTransformer(d) + fake_score = copy.deepcopy(teacher) + cfg = DMDConfig( + pred_type=pred_type, + fake_score_pred_type=fake_score_pred_type, + num_train_timesteps=num_train_timesteps, + ) + pipeline = DMDPipeline(student, teacher, fake_score, cfg) + return pipeline, student, teacher, fake_score, cfg + + +def test_fake_score_dsm_matches_manual_flow_to_x0(): + """compute_fake_score_loss under ``(flow, x0)`` must equal the manual + ``F.mse_loss(gen_data, pred_x0_from_flow(raw, x_t, t))`` reconstruction.""" + d, b = 4, 2 + pipeline, student, _teacher, fake_score, cfg = _build_pipeline( + d, pred_type="flow", fake_score_pred_type="x0" + ) + latents = torch.randn(b, d) + noise = torch.randn(b, d) + + fixed_t = torch.full((b,), 0.5) + fixed_eps = torch.randn(b, d) + + with ( + mock.patch.object(pipeline, "sample_timesteps", return_value=fixed_t), + mock.patch.object(dmd_module.torch, "randn_like", return_value=fixed_eps), + ): + actual = pipeline.compute_fake_score_loss(latents, noise)["fake_score"] + + with torch.no_grad(): + max_t = cfg.sample_t_cfg.max_t + t_student = torch.full((b,), max_t, dtype=torch.float32) + input_student = noise * max_t + gen_data = pred_x0_from_flow( + student(hidden_states=input_student, timestep=t_student), + input_student, + t_student, + ) + perturbed = add_noise(gen_data, fixed_eps, fixed_t) + fake_raw = fake_score(hidden_states=perturbed, timestep=fixed_t) + pred_x0 = pred_x0_from_flow(fake_raw, perturbed, fixed_t) + expected = F.mse_loss(gen_data, pred_x0) + + assert torch.allclose(actual, expected, atol=1e-6), ( + f"compute_fake_score_loss={actual.item():.3e}, manual={expected.item():.3e}" + ) + + +def test_student_vsd_sees_x0_not_raw_flow(): + """compute_student_loss must feed ``vsd_loss`` the x0-converted flow output of the + fake score (the prior bug was to forward the raw flow tensor instead).""" + d, b = 4, 2 + pipeline, student, _teacher, fake_score, cfg = _build_pipeline( + d, pred_type="flow", fake_score_pred_type="x0" + ) + latents = torch.randn(b, d) + noise = torch.randn(b, d) + fixed_t = torch.full((b,), 0.5) + fixed_eps = torch.randn(b, d) + + captured: dict[str, torch.Tensor] = {} + orig_vsd_loss = dmd_module.vsd_loss + + def spy(gen_data, teacher_x0, fake_score_x0, additional_scale=None): + captured["fake_score_x0"] = fake_score_x0.detach().clone() + return orig_vsd_loss(gen_data, teacher_x0, fake_score_x0, additional_scale) + + with ( + mock.patch.object(pipeline, "sample_timesteps", return_value=fixed_t), + mock.patch.object(dmd_module.torch, "randn_like", return_value=fixed_eps), + mock.patch.object(dmd_module, "vsd_loss", side_effect=spy), + ): + pipeline.compute_student_loss(latents, noise) + + with torch.no_grad(): + max_t = cfg.sample_t_cfg.max_t + t_student = torch.full((b,), max_t, dtype=torch.float32) + input_student = noise * max_t + gen_data_expected = pred_x0_from_flow( + student(hidden_states=input_student, timestep=t_student), + input_student, + t_student, + ) + perturbed = add_noise(gen_data_expected, fixed_eps, fixed_t) + fake_raw = fake_score(hidden_states=perturbed, timestep=fixed_t) + fake_x0_expected = pred_x0_from_flow(fake_raw, perturbed, fixed_t) + + assert torch.allclose(captured["fake_score_x0"], fake_x0_expected, atol=1e-6) + + +def test_call_model_rescales_timestep_when_num_train_timesteps_set(): + """``num_train_timesteps=1000`` rescales ``t`` by 1000 inside ``_call_model`` + and casts it to ``hidden_states.dtype`` (matching FastGen's VaceWan wrapper); + ``None`` leaves ``t`` untouched.""" + d, b = 3, 2 + model = _TimestepEchoModel() + x = torch.zeros(b, d) + t = torch.tensor([0.1, 0.7]) + + pipe_scaled = DMDPipeline( + _ToyTransformer(d), + _ToyTransformer(d), + _ToyTransformer(d), + DMDConfig(pred_type="x0", num_train_timesteps=1000), + ) + out_scaled = pipe_scaled._call_model(model, x, t) + assert torch.allclose(out_scaled, x + (t * 1000.0).view(-1, 1)) + + pipe_none = DMDPipeline( + _ToyTransformer(d), + _ToyTransformer(d), + _ToyTransformer(d), + DMDConfig(pred_type="x0", num_train_timesteps=None), + ) + out_none = pipe_none._call_model(model, x, t) + assert torch.allclose(out_none, x + t.view(-1, 1)) + + # bf16 hidden_states: rescaled timestep must be cast to bf16 so the + # addition inside the model returns a bf16 tensor (parity with FastGen's + # ``.to(dtype=x_t.dtype)``). fp32 timestep + bf16 hidden_states without the + # cast would either upcast the result or push dtype juggling into the model. + x_bf16 = torch.zeros(b, d, dtype=torch.bfloat16) + out_bf16 = pipe_scaled._call_model(model, x_bf16, t) + assert out_bf16.dtype == torch.bfloat16 + assert torch.allclose( + out_bf16, + x_bf16 + (t * 1000.0).to(torch.bfloat16).view(-1, 1), + ) + + +def test_ema_shadow_dtype_promotion(): + """``EMAConfig.dtype='float32'`` on a bf16 student gives an fp32 shadow; + ``dtype=None`` falls back to the live parameter's dtype.""" + torch.manual_seed(0) + student_bf16 = nn.Linear(4, 4).to(torch.bfloat16) + + cfg_fp32 = EMAConfig(fsdp2=False, dtype="float32") + ema_fp32 = ExponentialMovingAverage(student_bf16, cfg_fp32) + for name, shadow in ema_fp32.state_dict().items(): + assert shadow.dtype == torch.float32, f"{name}: expected float32, got {shadow.dtype}" + + cfg_none = EMAConfig(fsdp2=False, dtype=None) + ema_none = ExponentialMovingAverage(student_bf16, cfg_none) + for name, shadow in ema_none.state_dict().items(): + assert shadow.dtype == torch.bfloat16, f"{name}: expected bfloat16, got {shadow.dtype}" diff --git a/tests/unit/torch/fastgen/test_qwen_image_plugin.py b/tests/unit/torch/fastgen/test_qwen_image_plugin.py new file mode 100644 index 00000000000..498b6ce5f9f --- /dev/null +++ b/tests/unit/torch/fastgen/test_qwen_image_plugin.py @@ -0,0 +1,234 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for ``modelopt.torch.fastgen.plugins.qwen_image``. + +Ports the checklist §1 bullets (pack/unpack + FastGen parity + _call_model wiring) +into pytest form so they live in-repo and run under ``pytest tests/examples/diffusers``. +Adds the §6-specific bullet ``num_train_timesteps != None`` constructor error. + +The parity comparison against the FastGen reference extract is bit-exact +(``torch.equal``) — both are pure permute+reshape operations with no +floating-point arithmetic. +""" + +from __future__ import annotations + +from types import SimpleNamespace + +import pytest +import torch +from torch import nn + +from modelopt.torch.fastgen import DMDConfig +from modelopt.torch.fastgen.plugins.qwen_image import ( + QwenImageDMDPipeline, + build_img_shapes, + pack_latents, + unpack_latents, +) + +# ---------------------------------------------------------------------------- # +# §1.1 — pack / unpack inverse for representative latent sizes # +# ---------------------------------------------------------------------------- # + + +@pytest.mark.parametrize( + ("shape", "dtype"), + [ + ((1, 16, 128, 128), torch.float32), # production size + ((2, 16, 32, 32), torch.bfloat16), # batch>1 + bf16 dtype preserved + ], +) +def test_pack_unpack_inverse(shape, dtype): + """Round-trip is bit-exact; dtype/device/contiguity preserved (incl. bf16).""" + x = torch.randn(*shape, dtype=dtype) + p = pack_latents(x) + y = unpack_latents(p, shape[2], shape[3]) + assert torch.equal(x, y) + assert x.dtype == p.dtype == y.dtype + assert x.device == p.device == y.device + assert p.is_contiguous() + assert y.is_contiguous() + + +# ---------------------------------------------------------------------------- # +# §1.2 — odd spatial dims raise a clear ValueError # +# ---------------------------------------------------------------------------- # + + +def test_pack_rejects_odd_spatial(): + with pytest.raises(ValueError, match="even"): + pack_latents(torch.randn(1, 16, 31, 32)) + + +def test_unpack_rejects_odd_target(): + with pytest.raises(ValueError, match="even"): + unpack_latents(torch.randn(1, 256, 64), 31, 32) + + +# ---------------------------------------------------------------------------- # +# §1.5 — parity vs the FastGen reference # +# # +# FastGen's QwenImage class pulls heavy deps; we inline the two methods # +# verbatim from ``source/FastGen/fastgen/networks/QwenImage/network.py`` so # +# the parity check is hermetic. # +# ---------------------------------------------------------------------------- # + + +def _fastgen_pack(latents: torch.Tensor) -> torch.Tensor: + batch_size, channels, height, width = latents.shape + latents = latents.view(batch_size, channels, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 2, 4, 1, 3, 5) + return latents.reshape(batch_size, (height // 2) * (width // 2), channels * 4) + + +def _fastgen_unpack(latents: torch.Tensor, height: int, width: int) -> torch.Tensor: + batch_size = latents.shape[0] + channels = latents.shape[2] // 4 + latents = latents.reshape(batch_size, height // 2, width // 2, channels, 2, 2) + latents = latents.permute(0, 3, 1, 4, 2, 5) + return latents.reshape(batch_size, channels, height, width) + + +def test_pack_unpack_parity_vs_fastgen(): + x = torch.randn(2, 16, 32, 32) + fg_p = _fastgen_pack(x) + mo_p = pack_latents(x) + assert torch.equal(fg_p, mo_p) + + fg_u = _fastgen_unpack(fg_p, 32, 32) + mo_u = unpack_latents(mo_p, 32, 32) + assert torch.equal(fg_u, mo_u) + assert torch.equal(fg_u, x) # FastGen unpack round-trips the input + + +# ---------------------------------------------------------------------------- # +# §1.6 — build_img_shapes structural equality # +# ---------------------------------------------------------------------------- # + + +def test_build_img_shapes_structure(): + out = build_img_shapes(batch_size=2, h_lat=32, w_lat=32) + assert out == [[(1, 16, 16)], [(1, 16, 16)]] + + +# ---------------------------------------------------------------------------- # +# §1.7 / §1.8 — _call_model kwarg forwarding + unpack return styles # +# ---------------------------------------------------------------------------- # + + +class _CapturingModel(nn.Module): + """Stub transformer that records the kwargs it was called with and emits + a zero tensor of the requested shape in one of three return styles.""" + + def __init__(self, out_shape: tuple[int, ...], style: str = "tensor") -> None: + super().__init__() + self.out_shape = out_shape + self.style = style + self.last_kwargs: dict[str, object] = {} + + def forward(self, **kwargs): + self.last_kwargs = dict(kwargs) + out = torch.zeros(*self.out_shape, dtype=kwargs["hidden_states"].dtype) + if self.style == "tensor": + return out + if self.style == "tuple": + return (out, "extra") + if self.style == "sample": + return SimpleNamespace(sample=out) + raise ValueError(self.style) + + +def _make_pipeline(student: nn.Module) -> QwenImageDMDPipeline: + return QwenImageDMDPipeline( + student=student, + teacher=nn.Identity(), + fake_score=nn.Identity(), + config=DMDConfig(num_train_timesteps=None), + discriminator=None, + ) + + +def test_call_model_forwards_qwen_kwargs(): + """``_call_model`` must forward the exact Qwen signature (hidden_states packed + to ``[B, num_patches, 64]``, encoder_hidden_states verbatim, + encoder_hidden_states_mask verbatim, txt_seq_lens derived from the mask, + img_shapes as ``[[(1, h//2, w//2)]] * B``, guidance=None, return_dict=False, + timestep verbatim with no /1000 rescale).""" + b, c, h, w = 2, 16, 32, 32 + student = _CapturingModel(out_shape=(b, (h // 2) * (w // 2), c * 4), style="tensor") + pipe = _make_pipeline(student) + + hidden = torch.randn(b, c, h, w) + t = torch.full((b,), 0.5, dtype=torch.float32) + enc = torch.randn(b, 512, 3584) + mask = torch.zeros(b, 512, dtype=torch.long) + mask[0, :37] = 1 + mask[1, :42] = 1 + + out = pipe._call_model( + student, + hidden, + t, + encoder_hidden_states=enc, + encoder_hidden_states_mask=mask, + ) + + kw = student.last_kwargs + assert tuple(kw["hidden_states"].shape) == (b, (h // 2) * (w // 2), c * 4) + assert tuple(kw["encoder_hidden_states"].shape) == (b, 512, 3584) + assert torch.equal(kw["encoder_hidden_states_mask"], mask) + assert kw["txt_seq_lens"] == [37, 42] + assert kw["img_shapes"] == [[(1, h // 2, w // 2)]] * b + assert kw["guidance"] is None + assert kw["return_dict"] is False + assert torch.equal(kw["timestep"], t) # no /1000 rescale when num_train_timesteps=None + assert tuple(out.shape) == (b, c, h, w) + + +@pytest.mark.parametrize("style", ["tensor", "tuple", "sample"]) +def test_call_model_unpacks_return_styles(style): + """``_call_model`` must unpack ``tensor`` / ``tuple`` / ``.sample`` return + styles into ``[B, C, H, W]`` of the input's dtype.""" + b, c, h, w = 1, 16, 32, 32 + model = _CapturingModel(out_shape=(b, (h // 2) * (w // 2), c * 4), style=style) + pipe = _make_pipeline(model) + hidden = torch.randn(b, c, h, w) + t = torch.full((b,), 0.5, dtype=torch.float32) + enc = torch.randn(b, 512, 3584) + out = pipe._call_model(model, hidden, t, encoder_hidden_states=enc) + assert tuple(out.shape) == (b, c, h, w) + assert out.dtype == hidden.dtype + + +# ---------------------------------------------------------------------------- # +# §6.X — QwenImageDMDPipeline constructor rejects num_train_timesteps != None # +# ---------------------------------------------------------------------------- # + + +def test_constructor_rejects_non_null_num_train_timesteps(): + """The pipeline normalizes ``t ∈ [0, 1]`` internally and forwards continuous + ``t`` to the Qwen transformer. ``num_train_timesteps`` is a discretization + knob that doesn't apply — the constructor must refuse it loudly.""" + cfg = DMDConfig(num_train_timesteps=1000) + with pytest.raises(ValueError, match="num_train_timesteps"): + QwenImageDMDPipeline( + student=nn.Identity(), + teacher=nn.Identity(), + fake_score=nn.Identity(), + config=cfg, + discriminator=None, + ) From cc4066cf7010ceeee0002ffc7144aaadcd5aaa86 Mon Sep 17 00:00:00 2001 From: Chenjie Luo <108829653+cjluo-nv@users.noreply.github.com> Date: Mon, 8 Jun 2026 16:13:25 -0700 Subject: [PATCH 06/24] feat(recipes): add kv_fp8_cast variants for partial-NVFP4 and weight-only PTQ recipes (#1652) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### What does this PR do? Type of change: new feature (recipes) Several `general/ptq` recipe families shipped a data-driven FP8 KV-cache (`-kv_fp8`) variant but lacked the constant-amax `kv_fp8_cast` companion that `fp8_default` and `nvfp4_default` already have. This PR adds the missing cast variants so every KV-quantizing (and the weight-only) family offers the calibration-free FP8 KV-cache option: - `general/ptq/nvfp4_experts_only-kv_fp8_cast` - `general/ptq/nvfp4_mlp_only-kv_fp8_cast` - `general/ptq/nvfp4_omlp_only-kv_fp8_cast` - `general/ptq/nvfp4_weight_only-kv_fp8_cast` Each new recipe composes the exact same model-quant config as its existing sibling and swaps the `kv_fp8` unit for the shared `kv_fp8_cast` unit (constant-amax FP8 KV cache; no KV calibration forward pass). The docs guide table/tree and the changelog are updated to match. ### Usage ```bash python examples/llm_ptq/hf_ptq.py \ --pyt_ckpt_path \ --recipe general/ptq/nvfp4_mlp_only-kv_fp8_cast ``` ### Testing Extended the built-in PTQ smoke test `tests/unit/recipe/test_loader.py::test_load_recipe_all_builtins` with the four new recipe paths; all four load into a valid `ModelOptPTQRecipe` with a populated `quantize` section. ``` $ python -m pytest tests/unit/recipe/test_loader.py tests/unit/recipe/test_presets.py -q 180 passed ``` `pre-commit` (including the `validate modelopt recipes` hook) passes on all changed files. ### Before your PR is "*Ready for review*" - Is this change backward compatible?: ✅ (additive — only new recipe files) - If you copied code from any other sources or added a new PIP dependency, did you follow guidance in `CONTRIBUTING.md`: N/A - Did you write any new necessary tests?: ✅ (extended the builtin recipe smoke test) - Did you update [Changelog](https://github.com/NVIDIA/Model-Optimizer/blob/main/CHANGELOG.rst)?: ✅ - Did you get Claude approval on this PR?: ❌ (not yet) ### Additional Information The two weight-only families were discussed for scope; `nvfp4_weight_only` is included (it already names a KV mode, `kv_fp16`), while `int4_blockwise_weight_only` is intentionally left untouched since it carries no `-kv_` composition. 🤖 Generated with [Claude Code](https://claude.com/claude-code) ## Summary by CodeRabbit * **New Features** * Added four new NVFP4 PTQ (Post-Training Quantization) recipe variants: experts-only, MLP-only, OMLP-only, and weight-only configurations. * All new recipes include FP8 KV-cache cast mode support for improved inference performance. * **Documentation** * Updated built-in recipes guide with new NVFP4 recipe options and repository layout. * **Tests** * Expanded recipe loader test coverage for new recipe configurations. Signed-off-by: Chenjie Luo Co-authored-by: Claude Opus 4.8 Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com> --- CHANGELOG.rst | 1 + docs/source/guides/10_recipes.rst | 12 +++++ .../ptq/nvfp4_experts_only-kv_fp8_cast.yaml | 51 ++++++++++++++++++ .../ptq/nvfp4_mlp_only-kv_fp8_cast.yaml | 52 +++++++++++++++++++ .../ptq/nvfp4_omlp_only-kv_fp8_cast.yaml | 52 +++++++++++++++++++ .../ptq/nvfp4_weight_only-kv_fp8_cast.yaml | 35 +++++++++++++ tests/unit/recipe/test_loader.py | 4 ++ 7 files changed, 207 insertions(+) create mode 100644 modelopt_recipes/general/ptq/nvfp4_experts_only-kv_fp8_cast.yaml create mode 100644 modelopt_recipes/general/ptq/nvfp4_mlp_only-kv_fp8_cast.yaml create mode 100644 modelopt_recipes/general/ptq/nvfp4_omlp_only-kv_fp8_cast.yaml create mode 100644 modelopt_recipes/general/ptq/nvfp4_weight_only-kv_fp8_cast.yaml diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 91fca51a59c..12da06a2066 100755 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -35,6 +35,7 @@ Changelog - Add ``--cast_mxfp4_to_nvfp4`` flag to ``examples/llm_ptq/hf_ptq.py`` for closed-form, bit-exact MXFP4 → NVFP4 weight conversion. Supports the GPT-OSS family (``openai/gpt-oss-20b``, ``openai/gpt-oss-120b``). See `examples/llm_ptq/README.md `__ for usage. - DeepSeek PTQ (``examples/deepseek/ptq.py``) now defaults to native top-k calibration with post-hoc per-layer peer-max sync of expert ``input_quantizer.amax``; the all-experts path is preserved behind ``--calib_all_experts``. - Add NVFP4 W4A16 weight-only quantization (``w4a16_nvfp4``): FP4 weights with group_size=16, BF16 activations, no calibration forward pass required. Use ``mtq.W4A16_NVFP4_CFG`` or ``--qformat w4a16_nvfp4`` in ``hf_ptq.py``. vLLM deployment support is in progress. +- Add FP8 KV-cache cast variants for the partial-NVFP4 and weight-only general PTQ recipes: ``general/ptq/nvfp4_mlp_only-kv_fp8_cast``, ``general/ptq/nvfp4_experts_only-kv_fp8_cast``, ``general/ptq/nvfp4_omlp_only-kv_fp8_cast``, and ``general/ptq/nvfp4_weight_only-kv_fp8_cast``. These compose the same model-quant configs as their ``-kv_fp8`` siblings with the ``kv_fp8_cast`` unit (constant-amax FP8 KV cache, no KV calibration forward pass). - Add Megatron Core export/import mapping for Qwen3-VL (``Qwen3VLForConditionalGeneration``) vision-language models. The mapping handles the ``model.language_model.`` weight prefix used by Qwen3-VL. - Add active-MoE cost accounting for ``mtq.auto_quantize`` effective-bits search. Set ``constraints={"effective_bits": ..., "cost_model": "active_moe", "cost": {"active_moe_expert_ratio": ...}}`` to weight routed MoE expert costs by active experts per token while keeping shared experts fully counted. The ``hf_ptq.py`` AutoQuant path exposes this via ``--auto_quantize_cost_model active_moe`` and ``--auto_quantize_active_moe_expert_ratio``. - Add ``DATASET_COMBOS`` to ``modelopt.torch.utils.dataset_utils`` — single ``--dataset`` tokens that fan out to multiple registered datasets; per-entry ``num_samples`` is split evenly across the members. Initial combos: ``cnn_nemotron_v2_mix`` (``cnn_dailymail`` + ``nemotron-post-training-dataset-v2``, used by ``hf_ptq.py`` when no ``--dataset`` is provided) and ``nemotron-post-training-v3`` (the seven ``nvidia/Nemotron-*`` SFT datasets added in #1498, mirroring the `nemotron-post-training-v3 collection `_). Combo names are listed by ``get_supported_datasets()`` and surfaced in ``--dataset`` help. ``get_dataset_dataloader`` rejects inputs that mix a combo with one of its member datasets (e.g. ``cnn_dailymail,cnn_nemotron_v2_mix``) to avoid double-sampling, and ``get_dataset_samples`` rejects combo names so callers route through the dataloader. ``hf_ptq.py`` default ``--calib_size`` is bumped from ``512`` to ``1024`` so the total calibration sample count under the new default combo matches the previous two-dataset fallback. diff --git a/docs/source/guides/10_recipes.rst b/docs/source/guides/10_recipes.rst index ed2473c0233..7b9180c52d6 100644 --- a/docs/source/guides/10_recipes.rst +++ b/docs/source/guides/10_recipes.rst @@ -499,14 +499,22 @@ General PTQ recipes are model-agnostic and apply to any supported architecture: - NVFP4 W4A4, FP8 KV cache with data-driven calibration * - ``general/ptq/nvfp4_default-kv_nvfp4_cast`` - NVFP4 W4A4, NVFP4 KV cache with constant amax, max calibration + * - ``general/ptq/nvfp4_mlp_only-kv_fp8_cast`` + - NVFP4 for MLP layers only, FP8 KV cache with constant amax * - ``general/ptq/nvfp4_mlp_only-kv_fp8`` - NVFP4 for MLP layers only, FP8 KV cache + * - ``general/ptq/nvfp4_experts_only-kv_fp8_cast`` + - NVFP4 for MoE expert layers only, FP8 KV cache with constant amax * - ``general/ptq/nvfp4_experts_only-kv_fp8`` - NVFP4 for MoE expert layers only, FP8 KV cache * - ``general/ptq/nvfp4_experts_only-kv_fp8_layerwise`` - NVFP4 for MoE expert layers only, FP8 KV cache, layerwise calibration + * - ``general/ptq/nvfp4_omlp_only-kv_fp8_cast`` + - NVFP4 for output projection + MLP layers, FP8 KV cache with constant amax * - ``general/ptq/nvfp4_omlp_only-kv_fp8`` - NVFP4 for output projection + MLP layers, FP8 KV cache + * - ``general/ptq/nvfp4_weight_only-kv_fp8_cast`` + - NVFP4 W4A16 weight-only, FP8 KV cache with constant amax Model-specific recipes ---------------------- @@ -668,10 +676,14 @@ The ``modelopt_recipes/`` package is organized as follows: | +-- nvfp4_default-kv_fp8_cast.yaml | +-- nvfp4_default-kv_fp8.yaml | +-- nvfp4_default-kv_nvfp4_cast.yaml + | +-- nvfp4_mlp_only-kv_fp8_cast.yaml | +-- nvfp4_mlp_only-kv_fp8.yaml + | +-- nvfp4_experts_only-kv_fp8_cast.yaml | +-- nvfp4_experts_only-kv_fp8.yaml | +-- nvfp4_experts_only-kv_fp8_layerwise.yaml + | +-- nvfp4_omlp_only-kv_fp8_cast.yaml | +-- nvfp4_omlp_only-kv_fp8.yaml + | +-- nvfp4_weight_only-kv_fp8_cast.yaml +-- huggingface/ # Model-specific recipes | +-- / # see modelopt_recipes/huggingface/README.md | +-- / diff --git a/modelopt_recipes/general/ptq/nvfp4_experts_only-kv_fp8_cast.yaml b/modelopt_recipes/general/ptq/nvfp4_experts_only-kv_fp8_cast.yaml new file mode 100644 index 00000000000..8b4d2c20f30 --- /dev/null +++ b/modelopt_recipes/general/ptq/nvfp4_experts_only-kv_fp8_cast.yaml @@ -0,0 +1,51 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Composed PTQ recipe for expert-only dynamic NVFP4 quantization with FP8 KV-cache cast mode. + +imports: + base_disable_all: configs/ptq/units/base_disable_all + default_disabled_quantizers: configs/ptq/units/default_disabled_quantizers + nvfp4: configs/numerics/nvfp4 + kv_fp8_cast: configs/ptq/units/kv_fp8_cast + +metadata: + recipe_type: ptq + description: >- + Applies dynamic NVFP4 only to expert-layer weight and input quantizers, plus FP8 KV-cache cast + mode using constant amax; uses max calibration. +quantize: + algorithm: + method: max + # Max calibration is fast and does not typically need checkpointing. + # layerwise=false required for VLMs where the decoder layers are nested under + # `model.language_model.layers` (layerwise_calibrate can't find them otherwise). + layerwise: false + quant_cfg: + - $import: base_disable_all + - quantizer_name: '*mlp.experts*weight_quantizer' + cfg: + $import: nvfp4 + - quantizer_name: '*mlp.experts*input_quantizer' + cfg: + $import: nvfp4 + - quantizer_name: '*block_sparse_moe*weight_quantizer' + cfg: + $import: nvfp4 + - quantizer_name: '*block_sparse_moe*input_quantizer' + cfg: + $import: nvfp4 + - $import: kv_fp8_cast + - $import: default_disabled_quantizers diff --git a/modelopt_recipes/general/ptq/nvfp4_mlp_only-kv_fp8_cast.yaml b/modelopt_recipes/general/ptq/nvfp4_mlp_only-kv_fp8_cast.yaml new file mode 100644 index 00000000000..225ecf7f086 --- /dev/null +++ b/modelopt_recipes/general/ptq/nvfp4_mlp_only-kv_fp8_cast.yaml @@ -0,0 +1,52 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Composed PTQ recipe for MLP/MoE-only dynamic NVFP4 quantization with FP8 KV-cache cast mode. + +imports: + base_disable_all: configs/ptq/units/base_disable_all + default_disabled_quantizers: configs/ptq/units/default_disabled_quantizers + nvfp4: configs/numerics/nvfp4 + kv_fp8_cast: configs/ptq/units/kv_fp8_cast + +metadata: + recipe_type: ptq + description: >- + Applies dynamic NVFP4 only to MLP/MoE weight and input quantizers, plus FP8 KV-cache cast mode + using constant amax; uses max calibration. +quantize: + algorithm: max + quant_cfg: + - $import: base_disable_all + - quantizer_name: '*mlp*weight_quantizer' + cfg: + $import: nvfp4 + - quantizer_name: '*mlp*input_quantizer' + cfg: + $import: nvfp4 + - quantizer_name: '*block_sparse_moe*weight_quantizer' + cfg: + $import: nvfp4 + - quantizer_name: '*block_sparse_moe*input_quantizer' + cfg: + $import: nvfp4 + - quantizer_name: '*.experts.*weight_quantizer' + cfg: + $import: nvfp4 + - quantizer_name: '*.experts.*input_quantizer' + cfg: + $import: nvfp4 + - $import: kv_fp8_cast + - $import: default_disabled_quantizers diff --git a/modelopt_recipes/general/ptq/nvfp4_omlp_only-kv_fp8_cast.yaml b/modelopt_recipes/general/ptq/nvfp4_omlp_only-kv_fp8_cast.yaml new file mode 100644 index 00000000000..ba9e1e1c27a --- /dev/null +++ b/modelopt_recipes/general/ptq/nvfp4_omlp_only-kv_fp8_cast.yaml @@ -0,0 +1,52 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Composed PTQ recipe for output-projection and MLP/MoE dynamic NVFP4 quantization with FP8 KV-cache cast mode. + +imports: + base_disable_all: configs/ptq/units/base_disable_all + default_disabled_quantizers: configs/ptq/units/default_disabled_quantizers + nvfp4: configs/numerics/nvfp4 + kv_fp8_cast: configs/ptq/units/kv_fp8_cast + +metadata: + recipe_type: ptq + description: >- + Applies dynamic NVFP4 to output-projection and MLP/MoE weight and input quantizers, plus + FP8 KV-cache cast mode using constant amax; uses max calibration. +quantize: + algorithm: max + quant_cfg: + - $import: base_disable_all + - quantizer_name: '*mlp*weight_quantizer' + cfg: + $import: nvfp4 + - quantizer_name: '*mlp*input_quantizer' + cfg: + $import: nvfp4 + - quantizer_name: '*block_sparse_moe*weight_quantizer' + cfg: + $import: nvfp4 + - quantizer_name: '*block_sparse_moe*input_quantizer' + cfg: + $import: nvfp4 + - quantizer_name: '*o_proj*weight_quantizer' + cfg: + $import: nvfp4 + - quantizer_name: '*o_proj*input_quantizer' + cfg: + $import: nvfp4 + - $import: kv_fp8_cast + - $import: default_disabled_quantizers diff --git a/modelopt_recipes/general/ptq/nvfp4_weight_only-kv_fp8_cast.yaml b/modelopt_recipes/general/ptq/nvfp4_weight_only-kv_fp8_cast.yaml new file mode 100644 index 00000000000..ff335405f30 --- /dev/null +++ b/modelopt_recipes/general/ptq/nvfp4_weight_only-kv_fp8_cast.yaml @@ -0,0 +1,35 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Composed PTQ recipe for NVFP4 W4A16 weight-only quantization with FP8 KV-cache cast mode. + +imports: + base_disable_all: configs/ptq/units/base_disable_all + default_disabled_quantizers: configs/ptq/units/default_disabled_quantizers + w4a16_nvfp4: configs/ptq/units/w4_nvfp4 + kv_fp8_cast: configs/ptq/units/kv_fp8_cast + +metadata: + recipe_type: ptq + description: >- + NVFP4 W4A16 weight-only, BF16 activations, plus FP8 KV-cache cast mode using constant amax; uses + max calibration. No calibration forward pass required. +quantize: + algorithm: max + quant_cfg: + - $import: base_disable_all + - $import: w4a16_nvfp4 + - $import: kv_fp8_cast + - $import: default_disabled_quantizers diff --git a/tests/unit/recipe/test_loader.py b/tests/unit/recipe/test_loader.py index b5c433888d7..c0f4dfef9af 100644 --- a/tests/unit/recipe/test_loader.py +++ b/tests/unit/recipe/test_loader.py @@ -161,9 +161,13 @@ def test_load_recipe_builtin_description(): "general/ptq/nvfp4_default-kv_nvfp4_cast", "general/ptq/nvfp4_default-kv_none-gptq", "general/ptq/nvfp4_experts_only-kv_fp8", + "general/ptq/nvfp4_experts_only-kv_fp8_cast", "general/ptq/nvfp4_experts_only-kv_fp8_layerwise", "general/ptq/nvfp4_mlp_only-kv_fp8", + "general/ptq/nvfp4_mlp_only-kv_fp8_cast", "general/ptq/nvfp4_omlp_only-kv_fp8", + "general/ptq/nvfp4_omlp_only-kv_fp8_cast", + "general/ptq/nvfp4_weight_only-kv_fp8_cast", ] From bb988d23d4c7c6f46e65911abf359235b00432d9 Mon Sep 17 00:00:00 2001 From: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com> Date: Tue, 9 Jun 2026 22:16:02 +0530 Subject: [PATCH 07/24] ci: cache JIT-compiled CUDA torch extensions in GPU/example tests (#1651) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### What does this PR do? Type of change: CI / infrastructure (build-time speedup) ModelOpt's CUDA quantization extensions (`modelopt_cuda_ext`, `_fp8`, `_mx`) JIT-compile via `torch.utils.cpp_extension.load()` on first use — ~110–140s **each** in a fresh container, which is the dominant cost of the `gpu_trtllm` job and the TRT-LLM example jobs. This caches them across runs. The logic lives in a reusable composite action, **`.github/actions/cache-extensions`**, used by both `gpu_tests.yml` and `_example_tests_runner.yml`: - Sets a **literal in-container `TORCH_EXTENSIONS_DIR`** (`/root/.cache/torch_extensions`). `${{ github.workspace }}` can't be used — for `container:` jobs it resolves to the *host* path, which is mounted elsewhere (`/__w`) inside the container, so torch and the cache step would disagree on the location. - Caches that dir with `actions/cache`, keyed on a caller-supplied **env discriminator** (`rtxpro6000` + container image) plus a `hashFiles` of the kernel/loader sources — so the cache busts on any kernel change and is scoped per arch+image. - On an **exact hit**, **backdates the kernel sources** below the cached objects so ninja reuses them. (Touching the *objects* instead desyncs ninja's `.ninja_deps`, which records each output's build-time mtime → `stored deps info out of date` → rebuild.) Also fixes the unused `runner` default in `_example_tests_runner.yml` (`h100` → `rtxpro6000`) so it can't seed a wrong-arch cache. ### Usage N/A — CI only. To reuse from another job: ```yaml - uses: ./.github/actions/cache-extensions with: cache-key: rtxpro6000-${{ matrix.container_image }} # GPU arch + image ``` ### Testing Validated on `gpu_trtllm`: cache hit → `ninja: no work to do` → `test_cuda_ext*` dropped from **113s / 108s / 139s → 2.8s / 0.03s / 0.03s** (~360s saved per run). Jobs that build no extension (e.g. `gpu_vllm`) simply skip the save. ### Before your PR is "*Ready for review*" - Is this change backward compatible?: ✅ (CI-only; key busts on source/image change) - If you copied code from any other sources or added a new PIP dependency, did you follow guidance in `CONTRIBUTING.md`: N/A - Did you write any new necessary tests?: N/A - Did you update Changelog?: N/A (CI infrastructure) - Did you get Claude approval on this PR?: ❌ (pending) ### Additional Information - Single-arch assumption: callers pass `rtxpro6000` in `cache-key`; if the runner fleet ever mixes GPU archs, update that prefix (the cache path is not arch-specific). - No explicit TTL: the key is content-addressed, and GitHub auto-evicts caches unused for 7 days (+ 10 GB/repo LRU). 🤖 Generated with [Claude Code](https://claude.com/claude-code) Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com> Co-authored-by: Claude Opus 4.8 Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com> --- .github/actions/cache-extensions/action.yml | 24 +++++++++++++++++++++ .github/workflows/_example_tests_runner.yml | 5 ++++- .github/workflows/gpu_tests.yml | 5 ++++- 3 files changed, 32 insertions(+), 2 deletions(-) create mode 100644 .github/actions/cache-extensions/action.yml diff --git a/.github/actions/cache-extensions/action.yml b/.github/actions/cache-extensions/action.yml new file mode 100644 index 00000000000..3df12f8f6fc --- /dev/null +++ b/.github/actions/cache-extensions/action.yml @@ -0,0 +1,24 @@ +name: Cache extensions +description: Cache and reuse JIT-compiled extensions (e.g. CUDA extensions) across runs. + +inputs: + cache-key: + description: Environment discriminator for the cache key (e.g. GPU arch + container image). + required: true + +runs: + using: composite + steps: + - shell: bash + run: echo "TORCH_EXTENSIONS_DIR=/root/.cache/torch_extensions" >> "$GITHUB_ENV" + - id: cache + uses: actions/cache@v4 + with: + path: /root/.cache/torch_extensions + key: torch-ext-${{ inputs.cache-key }}-${{ hashFiles('modelopt/torch/kernels/quantization/**', 'modelopt/torch/quantization/extensions.py', 'modelopt/torch/utils/cpp_extension.py') + }} + # On a cache hit, backdate sources below the cached objects so ninja reuses them (touching + # the objects instead would desync ninja's .ninja_deps and force a rebuild). + - if: steps.cache.outputs.cache-hit == 'true' + shell: bash + run: find modelopt/torch/kernels/quantization -exec touch -d '2000-01-01' {} + diff --git a/.github/workflows/_example_tests_runner.yml b/.github/workflows/_example_tests_runner.yml index c7a04e17ef9..2bca58b8a81 100644 --- a/.github/workflows/_example_tests_runner.yml +++ b/.github/workflows/_example_tests_runner.yml @@ -26,7 +26,7 @@ on: description: "GitHub runner to use" required: false type: string - default: "linux-amd64-gpu-h100-latest-1" + default: "linux-amd64-gpu-rtxpro6000-latest-1" jobs: run-test: @@ -41,6 +41,9 @@ jobs: steps: - uses: actions/checkout@v6 - uses: nv-gha-runners/setup-proxy-cache@main + - uses: ./.github/actions/cache-extensions + with: + cache-key: rtxpro6000-${{ inputs.docker_image }} - name: Setup environment variables run: | echo "LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/usr/include:/usr/lib/x86_64-linux-gnu:/usr/local/tensorrt/targets/x86_64-linux-gnu/lib" >> $GITHUB_ENV diff --git a/.github/workflows/gpu_tests.yml b/.github/workflows/gpu_tests.yml index ab23036e0f4..009089c1233 100644 --- a/.github/workflows/gpu_tests.yml +++ b/.github/workflows/gpu_tests.yml @@ -45,7 +45,7 @@ jobs: timeout: 60 container_image: nvcr.io/nvidia/nemo:26.04 - example: gpu_trtllm - timeout: 30 + timeout: 15 container_image: nvcr.io/nvidia/tensorrt-llm/release:1.3.0rc17 - example: gpu_vllm timeout: 15 @@ -66,6 +66,9 @@ jobs: run: apt-get update && apt-get install -y git - uses: actions/checkout@v6 - uses: nv-gha-runners/setup-proxy-cache@main + - uses: ./.github/actions/cache-extensions + with: + cache-key: rtxpro6000-${{ matrix.container_image }} - name: Setup environment variables run: | echo "LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/usr/include:/usr/lib/x86_64-linux-gnu" >> $GITHUB_ENV From 5f4cc79a7ac3a339f0428cece2d176ed14129c81 Mon Sep 17 00:00:00 2001 From: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com> Date: Wed, 10 Jun 2026 00:25:42 +0530 Subject: [PATCH 08/24] Migrate Nemotron-3-Nano tutorial PTQ to MBridge scripts and move under examples/megatron_bridge (#1601) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### What does this PR do? Type of change: documentation (+ minor test fixes) Migrates the Nemotron-3-Nano-30B-A3B-BF16 tutorial quantization step from `examples/llm_ptq/hf_ptq.py` to the Megatron-Bridge quantize + export, and relocates the tutorial next to the scripts it now uses. Now that the whole tutorial is Megatron-Bridge based, it lives under `examples/megatron_bridge/`. - **Quantization migration:** replace the single `hf_ptq.py` call with `examples/megatron_bridge/quantize.py` (calibrate + save a Megatron checkpoint) → `examples/megatron_bridge/export.py` (deployable unified HF checkpoint). The FP8 results table is refreshed with the `quantize.py` numbers (same defaults, slightly better on average). - **Relocation:** moved `examples/pruning/minitron/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16/` → `examples/megatron_bridge/tutorials/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16/`. A **redirect-stub `README.md`** remains at the old path (a directory symlink isn't traversable in the GitHub web UI), and all in-repo references (root README, CHANGELOG, pruning READMEs, megatron_bridge README) plus the tutorial's own relative links are updated. - **Evaluation:** per-format vLLM benchmark commands (BF16 / FP8), FP8 deployment notes documented in `nemo_evaluator.yaml`, reduced LiveCodeBench/AIME `num_repeats` (were too slow), and bumped the `nemo-evaluator-launcher` pin. - **Misc:** drop the `examples/megatron_bridge/requirements.txt` `transformers<5` pin in favor of an inline "downgrade `transformers<5` to save pruned Nemotron checkpoints" note; guard the hybrid Mamba-MoE sharded-state-dict test behind `HAS_MAMBA` (requires `mamba_ssm`); shrink the tiny Gemma3 test fixture's attention heads. > **Note:** the **NVFP4 + QAD** experiments (formerly the focus of this PR) are split out — their accuracy/throughput results are still in progress — and will follow in a separate PR on top of this one. ### Testing Docs-only + test-guard changes. Pre-commit hooks (markdownlint, RST checks, ruff, mypy) pass. The tutorial's relative links and the old-path redirect stub were verified to resolve to real files. ### Before your PR is "*Ready for review*" - Is this change backward compatible?: ✅ (old tutorial path still resolves via a redirect-stub README; `quantize.py`/`export.py` already exist in `examples/megatron_bridge`) - If you copied code from any other sources or added a new PIP dependency, did you follow guidance in `CONTRIBUTING.md`: N/A - Did you write any new necessary tests?: N/A (adjusts/guards existing tests only) - Did you update [Changelog](https://github.com/NVIDIA/Model-Optimizer/blob/main/CHANGELOG.rst)?: ✅ (existing tutorial entry updated to the new path) - Did you get Claude approval on this PR?: ✅ ### Additional Information Supersedes the previous "Part 3 of 4 (NVFP4 + QAD docs)" scope of this PR; the NVFP4 + QAD tutorial additions will land in a follow-up. ## Summary by CodeRabbit * **Documentation** * Moved the Nemotron-3-Nano-30B-A3B tutorial into the Megatron-Bridge tutorials and replaced the old file with a pointer to the new location. * Updated vLLM throughput numbers to 2.6× and expanded results/throughput tables. * Reworked the FP8 quantization/export workflow and added a note to use transformers<5 when saving pruned models. * Added a tutorials index and adjusted evaluator launcher pin and repeat counts. * **Tests** * Tests now detect optional Mamba support and skip related tests when unavailable. --------- Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com> Co-authored-by: Claude Opus 4.8 (1M context) Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com> --- CHANGELOG.rst | 2 +- README.md | 2 +- examples/megatron_bridge/README.md | 11 +- examples/megatron_bridge/requirements.txt | 2 - .../ABLATIONS.md | 0 .../README.md | 545 ++++++++++++++++++ .../figures/learning_curves.png | Bin .../nemo_evaluator.yaml | 43 +- examples/megatron_bridge/tutorials/README.md | 10 + examples/pruning/README.md | 2 +- .../README.md | 492 +--------------- .../NVIDIA-Nemotron-Nano-9B-v2/README.md | 23 +- .../nemo_evaluator.yaml | 25 +- .../pruning/minitron_vs_puzzletron/README.md | 6 +- tests/_test_utils/torch/megatron/models.py | 1 + .../_test_utils/torch/transformers_models.py | 2 +- .../quantization/plugins/test_megatron.py | 3 + 17 files changed, 617 insertions(+), 552 deletions(-) delete mode 100644 examples/megatron_bridge/requirements.txt rename examples/{pruning/minitron => megatron_bridge/tutorials}/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16/ABLATIONS.md (100%) create mode 100644 examples/megatron_bridge/tutorials/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16/README.md rename examples/{pruning/minitron => megatron_bridge/tutorials}/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16/figures/learning_curves.png (100%) rename examples/{pruning/minitron => megatron_bridge/tutorials}/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16/nemo_evaluator.yaml (76%) create mode 100644 examples/megatron_bridge/tutorials/README.md diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 12da06a2066..d39a76dd8e3 100755 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -31,7 +31,7 @@ Changelog - Add support for ``active_params`` (for MoE models) and ``memory_mb`` constraints in Minitron pruning on top of existing ``params`` constraint. You can also provide multiple constraints. See `examples/pruning/README.md `_ for more details. The underlying utility functions ``mcore_param_count``, ``mcore_memory_footprint_mb``, and ``print_mcore_model_stats`` in ``modelopt.torch.nas.plugins.megatron_model_stats`` are also available for standalone use to compute parameter counts and memory footprints (weights + KV-cache + Mamba state) for any Megatron-Core model. - Add Minitron pruning support for Megatron-Bridge Gemma3 models. - Add quantization examples for the Megatron-Bridge framework: post-training quantization (`quantize.py `_), export to a deployable HuggingFace checkpoint (`export.py `_), and Quantization Aware Distillation (extend existing `distill.py `_). -- Add end-to-end tutorial for Minitron pruning + two-phase distillation (80B @ 8K + 20B @ 32K long-context = 100B tokens) + FP8 PTQ + vLLM deployment for Nemotron-3-Nano-30B-A3B-BF16 (MoE + Mamba-Transformer hybrid) → Pruned 22B/A3.0B active params, along with data blend preparation steps (with tool-calling data) and detailed pruning / data-blend / long-context ablations. See `examples/pruning/minitron/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16/README.md `_ for details. +- Add end-to-end optimization tutorial for Minitron pruning + two-phase distillation (80B @ 8K + 20B @ 32K long-context = 100B tokens) + FP8 PTQ + vLLM deployment for Nemotron-3-Nano-30B-A3B-BF16 (MoE + Mamba-Transformer hybrid) → Pruned 22B/A3.0B active params, along with data blend preparation steps (with tool-calling data) and detailed pruning / data-blend / long-context ablations. See `examples/megatron_bridge/tutorials/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16/README.md `_ for details. - Add ``--cast_mxfp4_to_nvfp4`` flag to ``examples/llm_ptq/hf_ptq.py`` for closed-form, bit-exact MXFP4 → NVFP4 weight conversion. Supports the GPT-OSS family (``openai/gpt-oss-20b``, ``openai/gpt-oss-120b``). See `examples/llm_ptq/README.md `__ for usage. - DeepSeek PTQ (``examples/deepseek/ptq.py``) now defaults to native top-k calibration with post-hoc per-layer peer-max sync of expert ``input_quantizer.amax``; the all-experts path is preserved behind ``--calib_all_experts``. - Add NVFP4 W4A16 weight-only quantization (``w4a16_nvfp4``): FP4 weights with group_size=16, BF16 activations, no calibration forward pass required. Use ``mtq.W4A16_NVFP4_CFG`` or ``--qformat w4a16_nvfp4`` in ``hf_ptq.py``. vLLM deployment support is in progress. diff --git a/README.md b/README.md index b3de87a9e4f..452a0767016 100644 --- a/README.md +++ b/README.md @@ -26,7 +26,7 @@ Model Optimizer is also integrated with [NVIDIA Megatron-Bridge](https://github. ## Latest News -- [2026/05/27] [**End-to-end Minitron workflow for Nemotron-3-Nano-30B-A3B**](./examples/pruning/minitron/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16): Pruning + two-phase distillation + FP8 quantization achieving 1.64× vLLM throughput and 2.6× memory reduction. +- [2026/05/27] [**End-to-end Minitron workflow for Nemotron-3-Nano-30B-A3B**](./examples/megatron_bridge/tutorials/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16): Pruning + two-phase distillation + FP8 quantization achieving 2.6× vLLM throughput and 2.6× memory reduction. - [2026/05/13] [**Puzzletron**](./examples/puzzletron): A new algorithm for heterogeneous pruning & NAS of LLM and VLM models. - [2026/04/15] Customer story: [Domyn compresses Colosseum-355B → 260B using ModelOpt's Minitron pruning + distillation](https://www.domyn.com/blog/domyn-large-the-journey-of-a-european-sovereign-ai-model-for-regulated-industries) - [2026/03/17] Customer story: [Bielik.AI builds Bielik Minitron 7B (33% smaller, 50% faster, 90% quality retained) using ModelOpt's Minitron pruning + distillation](https://bielik.ai/en/nvidia-gtc-bielik-minitron-premiere/) diff --git a/examples/megatron_bridge/README.md b/examples/megatron_bridge/README.md index 715a9aefc3d..6bb1f19be5c 100644 --- a/examples/megatron_bridge/README.md +++ b/examples/megatron_bridge/README.md @@ -16,7 +16,7 @@ This directory contains examples of using Model Optimizer with [NeMo Megatron-Br > [!TIP] -> Checkout the [Nemotron-3-Nano-30B-A3B pruning + distillation (with data blend prep) + quantization tutorial](../pruning/minitron/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16/README.md) for a complete end-to-end workflow using Megatron-Bridge! +> Checkout the [Nemotron-3-Nano-30B-A3B pruning + distillation (with data blend prep) + quantization tutorial](tutorials/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16/README.md) for a complete end-to-end workflow using Megatron-Bridge! ## Pre-Requisites @@ -47,12 +47,6 @@ docker run \ > [!WARNING] > Use `python -m pip` instead of `pip` to avoid conflicts with the system-wide installed packages in the NeMo containers. You may also refer to this [doc](https://github.com/NVIDIA-NeMo/Megatron-Bridge/blob/main/docker/common/README.md#installing-packages-inside-the-container) on how to correctly install packages in the NeMo containers without breaking existing torch installation. -Also install additional dependencies from the [requirements.txt](./requirements.txt) file. - -```bash -python -m pip install -r requirements.txt -``` - You also need to login with your HuggingFace token to download gated datasets / models. Note that the default dataset for pruning and quantization is [`nemotron-post-training-dataset-v2`](https://huggingface.co/datasets/nvidia/Nemotron-Post-Training-Dataset-v2), which is gated. @@ -307,6 +301,9 @@ torchrun --nproc_per_node 1 prune_minitron.py --help > uneven PP by setting `--num_layers_in_first_pipeline_stage` and `--num_layers_in_last_pipeline_stage`. > E.g. for Qwen3-8B with 36 layers and 8 GPUs, you can set both to 3 to get 3-5-5-5-5-5-5-3 layers per GPU. +> [!NOTE] +> If pruning a Nemotron model and you want to save the pruned model back in HF format, please downgrade to `transformers<5` via `python -m pip install "transformers<5"` before pruning. + ## Resources - 📅 [Roadmap](https://github.com/NVIDIA/Model-Optimizer/issues/146) diff --git a/examples/megatron_bridge/requirements.txt b/examples/megatron_bridge/requirements.txt deleted file mode 100644 index ec38c2f7ee7..00000000000 --- a/examples/megatron_bridge/requirements.txt +++ /dev/null @@ -1,2 +0,0 @@ -# Saving some pruned models (e.g. Nemotron-3-Nano-30B-A3B-BF16) have issues with transformers>=5.0 -transformers<5.0 diff --git a/examples/pruning/minitron/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16/ABLATIONS.md b/examples/megatron_bridge/tutorials/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16/ABLATIONS.md similarity index 100% rename from examples/pruning/minitron/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16/ABLATIONS.md rename to examples/megatron_bridge/tutorials/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16/ABLATIONS.md diff --git a/examples/megatron_bridge/tutorials/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16/README.md b/examples/megatron_bridge/tutorials/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16/README.md new file mode 100644 index 00000000000..1c7729da51a --- /dev/null +++ b/examples/megatron_bridge/tutorials/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16/README.md @@ -0,0 +1,545 @@ +# Nemotron-3-Nano-30B-A3B: Prune + Distill + Quantize + vLLM Deployment + +End-to-end optimization of [NVIDIA-Nemotron-3-Nano-30B-A3B-BF16](https://huggingface.co/nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16) demonstrating how ModelOpt techniques stack: Minitron structured pruning → Megatron-Bridge knowledge distillation to recover accuracy → evaluation benchmarking → FP8 quantization → vLLM deployment and throughput benchmarking. This document covers: + +1. **[Data Preparation](#1-data-preparation)** — tokenizing the training blend for distillation +2. **[Pruning](#2-pruning)** — Minitron structured pruning +3. **[Distillation](#3-distillation)** — recovering accuracy via Megatron-Bridge knowledge distillation +4. **[Evaluation](#4-evaluation)** — benchmarking with NeMo Evaluator across MMLU Pro, GPQA Diamond, AIME, and more +5. **[Quantization](#5-quantization)** — FP8 PTQ on the distilled checkpoint using ModelOpt's `examples/megatron_bridge/quantize.py` script +6. **[vLLM Inference Benchmarking](#6-vllm-inference-benchmarking)** — throughput comparison of BF16 vs FP8 on a single H100 + +## Results + +![Benchmark Recovery During Knowledge Distillation](figures/learning_curves.png) + +| Model | MMLU Pro | GPQA Diamond | LiveCodeBench v6 | AIME 2025 | IFBench | SciCode (Subtask) | Average | +| --- | --- | --- | --- | --- | --- | --- | --- | +| Pruned 22B/A3.0B (no distillation) | 47.1 | 33.5 | 27.4 | 15.5 | 36.9 | 12.1 | 28.8 | +| Distill @ 2.5B tokens (100 iters at 8K SeqLen) | 73.3 | 63.7 | 55.3 | 77.6 | 59.1 | 25.1 | 59.0 | +| Distill @ 20B tokens (800 iters at 8K SeqLen) | 74.8 | 66.0 | 62.3 | 79.6 | 65.4 | 26.1 | 62.4 | +| Distill @ 40B tokens (1600 iters at 8K SeqLen) | 76.4 | 67.2 | 62.3 | 79.8 | 66.0 | 26.6 | 63.1 | +| Distill @ 60B tokens (2400 iters at 8K SeqLen) | 76.1 | 68.1 | 63.6 | 78.8 | 67.3 | 27.0 | 63.5 | +| Distill @ 80B tokens (3200 iters at 8K SeqLen) | 76.5 | 69.1 | 63.9 | 80.7 | 66.5 | 29.0 | 64.3 | +| Distill @ 82.5B tokens (+100 iters at 32K SeqLen) | 76.2 | 69.8 | 64.8 | 87.0 | 68.2 | 27.0 | 65.5 | +| Distill @ 100B tokens (+800 iters at 32K SeqLen) - **BF16** | 76.6 | 69.6 | 66.1 | 87.3 | 68.9 | 28.4 | 66.2 | +| Distill @ 100B tokens + **FP8 Quantize** | 76.7 | 70.7 | 65.5 | 87.3 | 69.0 | 28.5 | 66.3 | +| Nemotron-3-Nano-30B-A3B-BF16 (official, 31.6B/A3.6B) | 78.0 | 70.3 | 67.9 | 87.1 | 69.1 | 31.8 | 67.4 | + +### vLLM Throughput (single H100, ISL=32768, OSL=1024) + +| Checkpoint | Model loading memory | Output tokens/s | Speedup vs Nemotron-3-Nano-30B-A3B-BF16 | +| --- | --- | --- | --- | +| Nemotron-3-Nano-30B-A3B-BF16 (official, 31.6B/A3.6B) | 58.9 GiB | 598 | 1.0× | +| Nemotron-3-Nano-30B-A3B-FP8 (official) | 31.4 GiB | 1,323 | 2.2× | +| Nemotron-3-Nano-Pruned-22B-A3.0B-BF16 | 41.5 GiB | 1,190 | 2.0× | +| Nemotron-3-Nano-Pruned-22B-A3.0B-FP8 | 22.8 GiB | 1,576 | 2.6× | + +Pruning alone (BF16 → Pruned-A3.0B BF16) gives a **2.0×** throughput speedup with a 30% memory reduction (58.9 → 41.5 GiB), and FP8 quantization alone (BF16 → FP8) gives a **2.2×** speedup with a 47% memory reduction. Stacking both — pruning + FP8 — compounds to a **2.6×** throughput speedup and a **2.6× memory reduction** (58.9 → 22.8 GiB) relative to the original 30B BF16 model, while preserving most of the benchmark accuracy. See [Section 6](#6-vllm-inference-benchmarking) for the benchmark command. + +Distillation uses the **30% Pretraining (Code 5, General 20, MATH 5) + 70% Post-training v1/v3 (Math 27, Coding 20, Science 13, IF 5, Tool calling 5)** blend (see [Data Blend](#data-blend) below) with an **80B @ 8K + 20B @ 32K = 100B token** schedule. Blend ablations and long-context phase ablations are in [ABLATIONS.md](ABLATIONS.md). + +> [!TIP] +> From the benchmark numbers above, the model is still learning at 100B tokens and that further training (or a higher-quality data blend) would continue to close the gap to the original 31.6B/A3.6B model. + +> [!NOTE] +> Exact numbers may vary depending on deployment and evaluation setup. All models above (including the official model) were evaluated with the same [evaluation setup](#4-evaluation) for fair comparison. These numbers may differ from those reported on the official [Nemotron-3-Nano-30B-A3B-BF16](https://huggingface.co/nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16) HuggingFace model card. + +--- + +## Steps to Reproduce + +**Environment:** Container `nvcr.io/nvidia/nemo:26.04`, ModelOpt 0.45.0. See the [Megatron-Bridge README](../../README.md) for environment setup (including ModelOpt mount path) and container usage. Pruning Nemotron models requires `transformers<5` via `python -m pip install "transformers<5"` else saving pruned model as HF checkpoint may fail. + +### 1. Data Preparation + +See [examples/dataset/MEGATRON_DATA_PREP.md](../../../dataset/MEGATRON_DATA_PREP.md) for tokenization commands for all datasets used in this blend. + +For this experiment: `TOKENIZER=nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16`, `OUTPUT_DIR=tokenized_nemotron_3`. + +#### Data Blend + +**30% Pretraining (Code 5, General 20, MATH 5) + 70% Post-training v1/v3 (Math 27, Coding 20, Science 13, IF 5, Tool calling 5)** + +| Dataset | Tokens | Weight | Notes | +| ---------------------------------------------------------- | ------ | ------ | ---------------------------------------------- | +| Nemotron-Pretraining-SFT-v1 / Code (10M samples) | 7B | 5 | Pretraining code | +| Nemotron-Pretraining-SFT-v1 / General (10M samples) | 16B | 20 | Upweighted to close MMLU gap | +| Nemotron-Pretraining-SFT-v1 / MATH (10M samples) | 13B | 5 | Pretraining math | +| Nemotron-Math-v2 / high_part00 | 13B | 10 | Hard math reasoning | +| Nemotron-SFT-Math-v3 / train | 52B | 17 | Hard math reasoning with full reasoning traces | +| Nemotron-SFT-Competitive-Programming-v2 / python_00 | 7B | 15 | Python reasoning traces | +| Nemotron-SFT-Competitive-Programming-v2 / cpp_00 | 7B | 5 | C++ reasoning traces | +| Nemotron-Post-Training-Dataset-v1 / stem (5M samples) | 22B | 8 | Broad STEM | +| Nemotron-Science-v1 / MCQ | 0.5B | 3 | GPQA MCQ format alignment | +| Nemotron-Science-v1 / RQA | 0.3B | 2 | GPQA format diversity | +| Nemotron-SFT-Instruction-Following-Chat-v2 / reasoning_on | 2B | 3 | Instruction following (thinking on) | +| Nemotron-SFT-Instruction-Following-Chat-v2 / reasoning_off | 1B | 2 | Instruction following (thinking off) | +| Nemotron-Agentic-v1 / tool_calling | 1B | 5 | Tool-use scaffolding; helps SciCode / GPQA | + +
+Data blend for distillation (click to expand) + +```bash +DATA_BLEND=" \ +5 tokenized_nemotron_3/nvidia--Nemotron-Pretraining-SFT-v1_Nemotron-SFT-Code_train_text_max10000000 \ +20 tokenized_nemotron_3/nvidia--Nemotron-Pretraining-SFT-v1_Nemotron-SFT-General_train_text_max10000000 \ +5 tokenized_nemotron_3/nvidia--Nemotron-Pretraining-SFT-v1_Nemotron-SFT-MATH_train_text_max10000000 \ +10 tokenized_nemotron_3/nvidia--Nemotron-Math-v2_default_high_part00_messages \ +17 tokenized_nemotron_3/nvidia--Nemotron-SFT-Math-v3_default_train_messages \ +15 tokenized_nemotron_3/competitive_programming_python_00_messages \ +5 tokenized_nemotron_3/competitive_programming_cpp_00_messages \ +8 tokenized_nemotron_3/nvidia--Nemotron-Post-Training-Dataset-v1_default_stem_messages_max5000000 \ +3 tokenized_nemotron_3/MCQ_messages \ +2 tokenized_nemotron_3/RQA_messages \ +3 tokenized_nemotron_3/reasoning_on_messages \ +2 tokenized_nemotron_3/reasoning_off_messages \ +5 tokenized_nemotron_3/nvidia--Nemotron-Agentic-v1_tool_calling_messages \ +" +``` + +
+ +#### General Guidelines + +The optimal blend is 30% pretraining and 70% post-training data. Exact proportions may vary depending on the benchmarks you care about. The blend above was designed to maximize recovery on popular General Knowledge, Reasoning, Instruction Following, and Tool Calling benchmarks. The key design decisions were: + +- **30% pretraining data** closes the MMLU gap that arises from training exclusively on reasoning-heavy post-training data. The General split (20%) is upweighted specifically to recover general knowledge recall. +- **Math (27%)** is the largest post-training category because AIME and MMLU Pro respond strongly to more math reasoning tokens. We use a mix of `Nemotron-Math-v2` and `Nemotron-SFT-Math-v3` for higher quality math reasoning signal with full reasoning traces. +- **Science (13%)** uses `Nemotron-Post-Training-Dataset-v1 / stem` as the primary source for volume and GPQA stability, with small allocations to `Nemotron-Science-v1` MCQ/RQA subsets for format alignment with GPQA's multiple-choice structure. +- **Instruction following (5%)** saturates quickly so a small allocation is sufficient. +- **Tool calling (5%)** uses `Nemotron-Agentic-v1 / tool_calling`. Our evals run with `--enable-auto-tool-choice`, so the student needs explicit exposure to function-call schemas; this helps SciCode (heavy Python tool use) and GPQA Diamond (which can benefit from calculator tools). + +This blend intentionally omits capabilities not targeted in this experiment (e.g. multilingual, SWE). Depending on what benchmarks matter for your use case, you can substitute or add datasets from the [Nemotron Post-Training v3 collection](https://huggingface.co/collections/nvidia/nemotron-post-training-v3), for example: + +| Capability | Relevant datasets | +| --- | --- | +| Multilingual | `Nemotron-SFT-Multilingual-v1` | +| Software engineering (SWE) | `Nemotron-SFT-SWE-v2` | +| Safety / alignment | `Nemotron-SFT-Safety-v1` | + +When adding new datasets, reduce weights of lower-priority categories proportionally to keep the total at 100%. + +--- + +### 2. Pruning + +Here we prune the [NVIDIA-Nemotron-3-Nano-30B-A3B-BF16](https://huggingface.co/nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16) HuggingFace checkpoint from 31.6B/A3.6B to 3.0B active parameters. The output is a pruned HuggingFace checkpoint that feeds into the distillation step. + +Run on **1 node with 8x H100** (~1 hour) + +
+Pruning command (click to expand) + +```bash +torchrun --nproc_per_node 8 /opt/Model-Optimizer/examples/megatron_bridge/prune_minitron.py \ + --pp_size 8 \ + --hf_model_name_or_path nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16 \ + --trust_remote_code \ + --prune_target_params 28e9 \ + --prune_target_active_params 3e9 \ + --hparams_to_skip num_attention_heads \ + --seq_length 8192 \ + --output_hf_path /path/to/Nemotron-3-Nano-30B-A3B-Pruned-A3.0B \ + --top_k 20 \ + --max_depth_pruning 0.15 \ + --max_width_pruning 0.30 \ + --prune_score_func mmlu_10pct_bs32 \ + --num_layers_in_first_pipeline_stage 5 \ + --num_layers_in_last_pipeline_stage 5 +``` + +Non-default arguments: + +- `--hparams_to_skip num_attention_heads` (default: none) — attention heads pruning is harder to recover, hence skipped +- `--seq_length 8192` (default: 4096) — dataset has longer sequences +- `--prune_target_active_params 3e9` — MoE-specific; the **primary** pruning constraint — targets active params rather than total params, which is what matters for MoE inference cost +- `--prune_target_params 28e9` — upper bound on total params only; the actual pruned model total can range anywhere from ~20B to 28B depending on which architecture wins — see pruning logs below for the top 20 candidates. You may also skip this argument all together for simplicity. +- `--top_k 20` (default: 10) — larger candidate pool for better architecture search +- `--max_depth_pruning 0.15` (default: 0.20) — tighter constraint since candidates with 42–46 layers universally fail for this model +- `--max_width_pruning 0.30` (default: 0.40) — tighter constraint to prevent head_dim≤48 and hidden=2048 dead zones +- `--prune_score_func mmlu_10pct_bs32` (default: `mmlu_10pct_bs1`) — batch_size=32 for ~3–4× faster candidate scoring +- `--num_layers_in_first_pipeline_stage 5 --num_layers_in_last_pipeline_stage 5` — Uneven pipeline parallelism since 52 layers is not divisible by 8 GPUs + +**NOTE**: The tighter search space constraints here (`--max_depth_pruning`, `--max_width_pruning`) are specific to Nemotron hybrid models (Mamba + Attention + MoE). Unlike standard transformers which expose only layers/hidden/attention/FFN dimensions, these models add Mamba-specific dimensions (`mamba_num_heads`, `mamba_head_dim`) and MoE dimensions (`num_moe_experts`, `moe_ffn_hidden_size`, `moe_shared_expert_intermediate_size`), making the combined search space much larger. The default 40%/20% bounds cast too wide a net and waste compute on dead-zone architectures. + +See [ABLATIONS.md](ABLATIONS.md#pruning) for the full architecture search analysis across various candidates. +
+ +
+Pruning logs (top 20 candidates, best subnet, layer patterns) (click to expand) + +```text +╭──────────────────────────────────────────────────── Original Model Stats ─────────────────────────────────────────────────────╮ +│ Total Parameters 31.58B │ +│ Active Parameters 3.58B │ +│ Memory (BF16, seq_length=8192, batch_size=1) weights: 60230.1 MB, kv_cache: 48.0 MB, mamba_state: 23.8 MB, Total: 60301.9 MB │ +╰───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯ + + Search Space + (≤30% width / ≤15% depth pruning) +┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ +┃ Hyperparameter ┃ Choices ┃ +┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩ +│ num_layers │ [46, 48, 50, 52] │ +│ hidden_size │ [2048, 2304, 2560, 2688] │ +│ mamba_num_heads │ [48, 56, 64] │ +│ mamba_head_dim │ [48, 56, 64] │ +│ num_moe_experts │ [96, 104, 112, 120, 128] │ +│ moe_ffn_hidden_size │ [1536, 1792, 1856] │ +│ moe_shared_expert_intermediate_size │ [2816, 3072, 3328, 3584, 3712] │ +├─────────────────────────────────────┼────────────────────────────────┤ +│ Search space size │ 10800 │ +└─────────────────────────────────────┴────────────────────────────────┘ + +Top 20 Candidates with Scores +┏━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┳━━━━━━━━┳━━━━━━━━┓ +┃ # ┃ export_config ┃ active_params ┃ params ┃ score ┃ +┡━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━╇━━━━━━━━╇━━━━━━━━┩ +│ 1 │ {'num_layers': 46, 'hidden_size': 2560, 'mamba_num_heads': 56, 'mamba_head_dim': 64, 'num_moe_experts': 120, │ 3.00B │ 27.06B │ 0.3399 │ +│ │ 'moe_ffn_hidden_size': 1792, 'moe_shared_expert_intermediate_size': 3072} │ │ │ │ +│ 2 │ {'num_layers': 48, 'hidden_size': 2560, 'mamba_num_heads': 56, 'mamba_head_dim': 56, 'num_moe_experts': 112, │ 3.00B │ 25.37B │ 0.4650 │ +│ │ 'moe_ffn_hidden_size': 1792, 'moe_shared_expert_intermediate_size': 3072} │ │ │ │ +│ 3 │ {'num_layers': 46, 'hidden_size': 2560, 'mamba_num_heads': 64, 'mamba_head_dim': 56, 'num_moe_experts': 112, │ 3.00B │ 25.37B │ 0.2343 │ +│ │ 'moe_ffn_hidden_size': 1792, 'moe_shared_expert_intermediate_size': 3072} │ │ │ │ +│ 4 │ {'num_layers': 52, 'hidden_size': 2688, 'mamba_num_heads': 56, 'mamba_head_dim': 48, 'num_moe_experts': 96, │ 3.00B │ 20.09B │ 0.2552 │ +│ │ 'moe_ffn_hidden_size': 1536, 'moe_shared_expert_intermediate_size': 3072} │ │ │ │ +│ 5 │ {'num_layers': 52, 'hidden_size': 2688, 'mamba_num_heads': 48, 'mamba_head_dim': 56, 'num_moe_experts': 104, │ 3.00B │ 21.61B │ 0.2601 │ +│ │ 'moe_ffn_hidden_size': 1536, 'moe_shared_expert_intermediate_size': 3072} │ │ │ │ +│ 6 │ {'num_layers': 52, 'hidden_size': 2560, 'mamba_num_heads': 48, 'mamba_head_dim': 64, 'num_moe_experts': 96, │ 3.00B │ 19.28B │ 0.3762 │ +│ │ 'moe_ffn_hidden_size': 1536, 'moe_shared_expert_intermediate_size': 3712} │ │ │ │ +│ 7 │ {'num_layers': 52, 'hidden_size': 2304, 'mamba_num_heads': 64, 'mamba_head_dim': 64, 'num_moe_experts': 104, │ 3.00B │ 22.28B │ 0.4783 │ +│ │ 'moe_ffn_hidden_size': 1856, 'moe_shared_expert_intermediate_size': 3072} │ │ │ │ +│ 8 │ {'num_layers': 52, 'hidden_size': 2560, 'mamba_num_heads': 48, 'mamba_head_dim': 48, 'num_moe_experts': 96, │ 3.00B │ 21.99B │ 0.2420 │ +│ │ 'moe_ffn_hidden_size': 1792, 'moe_shared_expert_intermediate_size': 3328} │ │ │ │ +│ 9 │ {'num_layers': 50, 'hidden_size': 2560, 'mamba_num_heads': 48, 'mamba_head_dim': 48, 'num_moe_experts': 112, │ 3.00B │ 25.37B │ 0.2399 │ +│ │ 'moe_ffn_hidden_size': 1792, 'moe_shared_expert_intermediate_size': 3712} │ │ │ │ +│ 10 │ {'num_layers': 50, 'hidden_size': 2560, 'mamba_num_heads': 48, 'mamba_head_dim': 48, 'num_moe_experts': 112, │ 3.00B │ 26.17B │ 0.2601 │ +│ │ 'moe_ffn_hidden_size': 1856, 'moe_shared_expert_intermediate_size': 3328} │ │ │ │ +│ 11 │ {'num_layers': 46, 'hidden_size': 2560, 'mamba_num_heads': 56, 'mamba_head_dim': 64, 'num_moe_experts': 112, │ 3.00B │ 25.37B │ 0.2503 │ +│ │ 'moe_ffn_hidden_size': 1792, 'moe_shared_expert_intermediate_size': 3072} │ │ │ │ +│ 12 │ {'num_layers': 48, 'hidden_size': 2560, 'mamba_num_heads': 56, 'mamba_head_dim': 56, 'num_moe_experts': 104, │ 3.00B │ 23.68B │ 0.4329 │ +│ │ 'moe_ffn_hidden_size': 1792, 'moe_shared_expert_intermediate_size': 3072} │ │ │ │ +│ 13 │ {'num_layers': 46, 'hidden_size': 2688, 'mamba_num_heads': 64, 'mamba_head_dim': 64, 'num_moe_experts': 128, │ 3.00B │ 26.17B │ 0.2587 │ +│ │ 'moe_ffn_hidden_size': 1536, 'moe_shared_expert_intermediate_size': 2816} │ │ │ │ +│ 14 │ {'num_layers': 46, 'hidden_size': 2560, 'mamba_num_heads': 64, 'mamba_head_dim': 56, 'num_moe_experts': 104, │ 3.00B │ 23.68B │ 0.2336 │ +│ │ 'moe_ffn_hidden_size': 1792, 'moe_shared_expert_intermediate_size': 3072} │ │ │ │ +│ 15 │ {'num_layers': 52, 'hidden_size': 2688, 'mamba_num_heads': 48, 'mamba_head_dim': 56, 'num_moe_experts': 96, │ 3.00B │ 20.09B │ 0.2559 │ +│ │ 'moe_ffn_hidden_size': 1536, 'moe_shared_expert_intermediate_size': 3072} │ │ │ │ +│ 16 │ {'num_layers': 52, 'hidden_size': 2304, 'mamba_num_heads': 64, 'mamba_head_dim': 64, 'num_moe_experts': 96, │ 3.00B │ 20.70B │ 0.4608 │ +│ │ 'moe_ffn_hidden_size': 1856, 'moe_shared_expert_intermediate_size': 3072} │ │ │ │ +│ 17 │ {'num_layers': 50, 'hidden_size': 2560, 'mamba_num_heads': 48, 'mamba_head_dim': 48, 'num_moe_experts': 104, │ 3.00B │ 23.68B │ 0.2455 │ +│ │ 'moe_ffn_hidden_size': 1792, 'moe_shared_expert_intermediate_size': 3712} │ │ │ │ +│ 18 │ {'num_layers': 50, 'hidden_size': 2560, 'mamba_num_heads': 48, 'mamba_head_dim': 48, 'num_moe_experts': 104, │ 3.00B │ 24.42B │ 0.2503 │ +│ │ 'moe_ffn_hidden_size': 1856, 'moe_shared_expert_intermediate_size': 3328} │ │ │ │ +│ 19 │ {'num_layers': 48, 'hidden_size': 2560, 'mamba_num_heads': 48, 'mamba_head_dim': 48, 'num_moe_experts': 120, │ 3.00B │ 27.92B │ 0.2587 │ +│ │ 'moe_ffn_hidden_size': 1856, 'moe_shared_expert_intermediate_size': 3712} │ │ │ │ +│ 20 │ {'num_layers': 46, 'hidden_size': 2560, 'mamba_num_heads': 56, 'mamba_head_dim': 64, 'num_moe_experts': 104, │ 3.00B │ 23.68B │ 0.2469 │ +│ │ 'moe_ffn_hidden_size': 1792, 'moe_shared_expert_intermediate_size': 3072} │ │ │ │ +└────┴───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┴───────────────┴────────┴────────┘ + +╭──────────────────────────────────────────────────────────────────────── Best Subnet ─────────────────────────────────────────────────────────────────────────╮ +│ export_config {'num_layers': 52, 'hidden_size': 2304, 'mamba_num_heads': 64, 'mamba_head_dim': 64, 'num_moe_experts': 104, 'moe_ffn_hidden_size': 1856, │ +│ 'moe_shared_expert_intermediate_size': 3072} │ +│ active_params 3.00B │ +│ params 22.28B │ +│ score 0.4783 │ +╰──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯ + +Original hybrid_layer_pattern: MEMEM*EMEMEM*EMEMEM*EMEMEM*EMEMEM*EMEMEMEM*EMEMEMEME +Pruned hybrid_layer_pattern: MEMEM*EMEMEM*EMEMEM*EMEMEM*EMEMEM*EMEMEMEM*EMEMEMEME + +╭───────────────────────────────────────────────────── Pruned Model Stats ──────────────────────────────────────────────────────╮ +│ Total Parameters 22.28B │ +│ Active Parameters 3.00B │ +│ Memory (BF16, seq_length=8192, batch_size=1) weights: 42489.7 MB, kv_cache: 48.0 MB, mamba_state: 23.8 MB, Total: 42561.6 MB │ +╰───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯ +``` + +
+ +> [!TIP] +> Candidate selection above relies on the pruning score alone — it does not run a short KD trial per candidate to pick the winner. The main post-pruning distillation in [Section 3](#3-distillation) is still performed on the selected candidate. If you want a stronger pick, take a few top candidates' `export_config` from the logs above (where the score is similar to the best subnet), export them separately, run KD for ~2B tokens on each, and pick the best on your target metrics. See [ABLATIONS.md — 1st vs 2nd best candidate](ABLATIONS.md#distillation-results-1st-best-vs-2nd-best-pruning-candidate) for a concrete comparison. + +--- + +### 3. Distillation + +Distillation is run in two phases: an 80B-token phase at 8K sequence length, followed by a 20B-token long-context phase at 32K sequence length. The two phases are launched as separate runs with an intermediate Megatron→HF checkpoint conversion, because the long-context phase changes `seq_length`, `gbs`, and `cp_size` — Megatron's checkpoint resume bookkeeping (sample counter is in absolute samples, iteration counter is in iter-units tied to `gbs`) does not handle a mid-run `gbs` change cleanly. + +Minimum hardware: **4 nodes × 8x H100 (32 GPUs)** for the 8K phase — required by `TP=4 × EP=8`. The 32K phase additionally requires context parallel to fit the longer sequence, doubling the minimum to **8 nodes × 8x H100 (64 GPUs)**. On **96 nodes × 8x H100 (768 GPUs total)**, it takes ~900 H100 GPU-hours per 10B tokens (400 iters), i.e. ~70 min wall-clock per 10B tokens on 96 nodes. Full schedule (80B @ 8K + 20B @ 32K = 100B tokens, 4k total steps) takes ~9k H100 GPU-hours (~12 hours wall-clock). + +#### 3a. Phase 1 — 80B tokens @ 8K seq length + +
+Phase 1 distillation command (click to expand) + +> NOTE: We use `python -u` for slurm multi-node run here. + +```bash +python -u /opt/Model-Optimizer/examples/megatron_bridge/distill.py \ + --teacher_hf_path nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16 \ + --student_hf_path /path/to/Nemotron-3-Nano-30B-A3B-Pruned-A3.0B \ + --trust_remote_code \ + --tp_size 4 \ + --ep_size 8 \ + --data_paths "${DATA_BLEND}" \ + --data_path_to_cache /path/to/cache \ + --seq_length 8192 \ + --mbs 1 \ + --gbs 3072 \ + --train_iters 3200 \ + --lr 1e-4 \ + --min_lr 1e-5 \ + --lr_warmup_iters 25 \ + --eval_interval 200 \ + --eval_iters 8 \ + --log_interval 10 \ + --output_dir /path/to/distill_output_phase1_8k + +# Optional: Weights & Biases logging +# --wandb_project \ +# --wandb_entity \ +# --wandb_exp_name +``` + +Non-default arguments: + +- `--seq_length 8192` (default: 4096) +- `--gbs 3072` (default: 768) — matches the original Nemotron-3-Nano-30B training GBS from the paper, kept to preserve the training distribution +- `--train_iters 3200` — 80B tokens at GBS 3072 × seq_length 8192 +- `--lr 1e-4 --min_lr 1e-5 --lr_warmup_iters 25` — cosine fully decays over 3200 iters; the model is approaching saturation at 8K by this point (see [ABLATIONS.md — 8K trajectory](ABLATIONS.md#effect-of-data-blend-tool_calling)). +- `--eval_interval 200` (default: 100) — less frequent eval to save compute +- `--eval_iters 8` (default: 32) — since GBS is 4× larger than default + +All other arguments use defaults. +
+ +#### 3b. Convert Phase 1 final checkpoint to HuggingFace format + +Phase 2 starts as a separate run from a fresh HuggingFace student checkpoint, so the final Phase 1 Megatron checkpoint must be exported to HF first using the Megatron-Bridge conversion script (see [Megatron-Bridge README](../../README.md) for full details). You can also use this same script to convert any intermediate Phase 1 checkpoint to HF format for evaluation along the way. + +
+Checkpoint conversion command (click to expand) + +> NOTE: Below command only works for non-quantized checkpoints. For quantized checkpoints, we use the `export.py` script in Section 5 to directly export the quantized checkpoint to Unified HF format for deployment. + +```bash +python /opt/Megatron-Bridge/examples/conversion/convert_checkpoints.py export \ + --hf-model /path/to/Nemotron-3-Nano-30B-A3B-Pruned-A3.0B \ + --megatron-path /path/to/distill_output_phase1_8k/checkpoints/iter_0003200 \ + --hf-path /path/to/distill_output_phase1_8k/checkpoints/hf_iter_0003200 +``` + +
+ +#### 3c. Phase 2 — 20B tokens @ 32K seq length + +Phase 2 is a **fresh run** with the Phase 1 final checkpoint as the new student. It uses a different `--seed` so the data blend reshuffles (otherwise the model would see overlapping prefix of the same samples it already saw at 8K). The LR is bumped back up modestly to capture the rapid long-context adaptation observed in [ABLATIONS.md — Effect of long context training](ABLATIONS.md#effect-of-long-context-training). + +
+Phase 2 distillation command (click to expand) + +```bash +python -u /opt/Model-Optimizer/examples/megatron_bridge/distill.py \ + --teacher_hf_path nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16 \ + --student_hf_path /path/to/distill_output_phase1_8k/checkpoints/hf_iter_0003200 \ + --trust_remote_code \ + --tp_size 4 \ + --cp_size 2 \ + --ep_size 8 \ + --seed 5678 \ + --data_paths "${DATA_BLEND}" \ + --data_path_to_cache /path/to/cache \ + --seq_length 32768 \ + --mbs 1 \ + --gbs 768 \ + --train_iters 800 \ + --lr 2e-5 \ + --min_lr 1e-5 \ + --lr_warmup_iters 10 \ + --recompute_granularity selective \ + --recompute_modules moe \ + --eval_interval 200 \ + --eval_iters 8 \ + --log_interval 10 \ + --output_dir /path/to/distill_output_phase2_32k +``` + +Changed arguments from Phase 1: + +- `--student_hf_path` — points at the HF export of the Phase 1 final checkpoint +- `--seq_length 32768` — long-context phase +- `--gbs 768` — `seq_length × gbs` product unchanged, so each iter still processes the same number of tokens +- `--cp_size 2` — context parallel is needed to fit the longer sequence; doubles the minimum-hardware footprint to 8 nodes +- `--train_iters 800` — 20B tokens at GBS 768 × seq_length 32768 +- `--lr 2e-5 --min_lr 1e-5 --lr_warmup_iters 10` — modest LR bump for the long-context adaptation (Phase 1 ended at fully-decayed LR 1e-5); the 10-iter warmup re-populates Adam moment estimates which restart from zero in a fresh run +- `--recompute_granularity selective --recompute_modules moe` — selective MoE recompute further reduces activation memory at 32K. You may skip this if you have more memory. +- `--seed 5678` — different from the Phase 1 seed (default 1234) so the data blend reshuffles +- `--output_dir /path/to/distill_output_phase2_32k` — must be a **fresh directory** different from Phase 1's, so distill.py's resume mechanism (which auto-loads from `/checkpoints` if it exists) does not pull in stale state + +
+ +For multi-node Slurm runs, see the [Megatron-Bridge README](../../README.md#slurm-usage) for details. + +> [!NOTE] +> This is pure SFT-style distillation — no RL or online reward signal is used. Adding an RL-based post-training step after distillation is a natural next step that could further improve some of these benchmarks. + +--- + +### 4. Evaluation + +The eval config in [nemo_evaluator.yaml](nemo_evaluator.yaml) is for Slurm-based evaluation — it submits a vLLM serving job (with tool calling enabled via `--enable-auto-tool-choice --tool-call-parser qwen3_coder`) and runs evals against it. For local model execution and evaluation, refer to the [NeMo Evaluator documentation](https://docs.nvidia.com/nemo/evaluator/latest/) or this [blog](https://huggingface.co/blog/nvidia/nemotron-3-nano-evaluation-recipe). + +
+Evaluation launch steps (click to expand) + +Before running, update the following fields in the `nemo_evaluator.yaml` file or overwrite them in the command line with `-o
+ +**Tasks and exact metric names reported in the results table:** + +| Benchmark | Library | num_repeats | Metric name | +| --- | --- | --- | --- | +| MMLU Pro | NeMo Evaluator | 1 | `mmlu-pro_pass_at_1_symbolic_correct` | +| GPQA Diamond | NeMo Evaluator | 8 | `gpqa_pass_at_1_avg-of-8_symbolic_correct` | +| LiveCodeBench v6 | NeMo Evaluator | 4 | `livecodebench_pass_at_1_avg-of-4_accuracy` | +| AIME 2025 | NeMo Evaluator | 32 | `aime25_pass_at_1_avg-of-32_symbolic_correct` | +| IFBench | NeMo Evaluator | 8 | `ifbench_pass_at_1_avg-of-8_average_score` | +| SciCode (Subtask) | NeMo Evaluator | 8 | `scicode_pass_at_1_avg-of-8_subtask_accuracy` | + +For more details on NeMo Evaluator, see the [GitHub repo](https://github.com/NVIDIA-NeMo/evaluator) and [documentation](https://docs.nvidia.com/nemo/evaluator/latest/). + +--- + +### 5. Quantization + +ModelOpt allows stacking multiple optimization techniques. Here we stack FP8 quantization on top of the pruned and distilled model to get an even more optimized model. See [examples/megatron_bridge/README.md](../../README.md) for the full Megatron-Bridge PTQ documentation. + +Similar to the official [Nemotron-3-Nano-30B-A3B-FP8](https://huggingface.co/nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-FP8) model, if you want to quantize the pruned 22B/A3.0B model to FP8, the Mamba, MoE, and MLP layers are quantized to FP8, while the attention layers and the Conv1d components within the Mamba layers are kept in BF16 to avoid accuracy degradation. + +This is done with the `MAMBA_MOE_FP8_CONSERVATIVE_CFG` config defined in [`modelopt/torch/quantization/config.py`](../../../../modelopt/torch/quantization/config.py), which you select by passing `--quant_cfg MAMBA_MOE_FP8_CONSERVATIVE_CFG` below. For a faster model at the cost of a larger accuracy drop, you can use `MAMBA_MOE_FP8_AGGRESSIVE_CFG` instead. + +> [!NOTE] +> You can also quantize to NVFP4 using `--quant_cfg MAMBA_MOE_NVFP4_CONSERVATIVE_CFG` or `MAMBA_MOE_NVFP4_AGGRESSIVE_CFG` (faster, more accuracy drop), which may require further distillation (QAD) to recover accuracy and a Blackwell GPU for deployment. + +Quantization is a two-step flow: `quantize.py` calibrates and saves a Megatron checkpoint, then `export.py` converts it to a deployable HuggingFace checkpoint (the unified HF exporter loads at TP=1, so pipeline parallelism is used to shard across GPUs). Both steps take a few minutes on 8x H100. + +**Step 1 — calibrate and save the quantized Megatron checkpoint:** + +
+FP8 PTQ command (click to expand) + +```bash +torchrun --nproc_per_node 8 /opt/Model-Optimizer/examples/megatron_bridge/quantize.py \ + --hf_model_name_or_path /path/to/distill_output_phase2_32k/checkpoints/hf_iter_0000800 \ + --trust_remote_code \ + --tp_size 8 \ + --quant_cfg MAMBA_MOE_FP8_CONSERVATIVE_CFG \ + --calib_batch_size 32 \ + --seq_length 512 \ + --export_megatron_path /path/to/distill_output_phase2_32k/checkpoints/iter_0000800_fp8_megatron \ + --skip_generate +``` + +
+ +**Step 2 — export the Megatron checkpoint to a deployable HuggingFace checkpoint:** + +
+Export command (click to expand) + +```bash +torchrun --nproc_per_node 1 /opt/Model-Optimizer/examples/megatron_bridge/export.py \ + --hf_model_name_or_path /path/to/distill_output_phase2_32k/checkpoints/hf_iter_0000800 \ + --megatron_path /path/to/distill_output_phase2_32k/checkpoints/iter_0000800_fp8_megatron \ + --trust_remote_code \ + --pp_size 1 \ + --export_unified_hf_path /path/to/distill_output_phase2_32k/checkpoints/hf_iter_0000800_fp8 +``` + +
+ +The exported HuggingFace checkpoint is directly deployable with [vLLM](https://github.com/vllm-project/vllm), [TensorRT-LLM](https://github.com/NVIDIA/TensorRT-LLM) and [SGLang](https://github.com/sgl-project/sglang). + +> [!TIP] +> Run text generation on sample prompts to sanity-check the quantized checkpoint generates reasonable output: +> +> ```bash +> python /opt/Model-Optimizer/examples/megatron_bridge/generate_vllm.py \ +> --model /path/to/distill_output_phase2_32k/checkpoints/hf_iter_0000800_fp8 \ +> --trust_remote_code +> ``` + +> [!TIP] +> You can run the evaluation using the same `nemo_evaluator.yaml` file for the quantized checkpoint also — just apply the FP8 deployment tweaks documented at the top of the yaml. + +See FP8 vs BF16 results in the [Results](#results) section above. + +--- + +### 6. vLLM Inference Benchmarking + +Benchmark throughput using [vLLM](https://github.com/vllm-project/vllm) on a single H100 GPU. + +
+vLLM benchmark commands (ISL=32768, OSL=1024) (click to expand) + +```bash +# BF16 (original or pruned) +vllm bench throughput \ + --model \ + --random-input-len 32768 \ + --random-output-len 1024 \ + --trust-remote-code \ + --mamba_ssm_cache_dtype float32 \ + --load-format safetensors + +# FP8 (Hopper GPU) +VLLM_USE_FLASHINFER_MOE_FP8=1 VLLM_FLASHINFER_MOE_BACKEND=throughput \ +vllm bench throughput \ + --model \ + --random-input-len 32768 \ + --random-output-len 1024 \ + --trust-remote-code \ + --mamba_ssm_cache_dtype float32 \ + --kv-cache-dtype fp8 \ + --load-format safetensors +``` + +
+ +See the [vLLM Throughput table in Results](#vllm-throughput-single-h100-isl32768-osl1024) for measured numbers. + +> [!TIP] +> To deploy the model with vLLM, you can refer to the [vLLM Quickstart documentation](https://docs.vllm.ai/en/stable/getting_started/quickstart/). diff --git a/examples/pruning/minitron/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16/figures/learning_curves.png b/examples/megatron_bridge/tutorials/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16/figures/learning_curves.png similarity index 100% rename from examples/pruning/minitron/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16/figures/learning_curves.png rename to examples/megatron_bridge/tutorials/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16/figures/learning_curves.png diff --git a/examples/pruning/minitron/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16/nemo_evaluator.yaml b/examples/megatron_bridge/tutorials/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16/nemo_evaluator.yaml similarity index 76% rename from examples/pruning/minitron/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16/nemo_evaluator.yaml rename to examples/megatron_bridge/tutorials/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16/nemo_evaluator.yaml index cacbc078807..b2f63f17e86 100644 --- a/examples/pruning/minitron/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16/nemo_evaluator.yaml +++ b/examples/megatron_bridge/tutorials/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16/nemo_evaluator.yaml @@ -3,11 +3,15 @@ # Before running, update the following fields in the yaml: # - `execution.hostname` — your Slurm login node hostname # - `execution.account` — your Slurm account -# - `deployment.checkpoint_path` — Hugging Face checkpoint path (original, pruned or quantized) -# - `evaluation.nemo_evaluator_config.config.params.extra.tokenizer` — same path as `checkpoint_path` +# - `deployment.checkpoint_path` — Hugging Face checkpoint path (original, pruned, or quantized) +# +# This config is set up for a BF16 checkpoint. For an FP8 checkpoint, also apply the deployment changes below: +# - FP8: add `--kv-cache-dtype fp8` to `deployment.extra_args`, and set in `deployment.env_vars`: +# VLLM_USE_FLASHINFER_MOE_FP8: "1" +# VLLM_FLASHINFER_MOE_BACKEND: throughput # # Usage: -# pip install "nemo-evaluator-launcher[all]==0.1.90" +# pip install "nemo-evaluator-launcher[all]==0.1.82" # # # Set required environment variables: # export HF_TOKEN= @@ -31,9 +35,9 @@ defaults: execution: type: slurm - hostname: + hostname: ??? username: ${oc.env:USER} - account: + account: ??? partition: batch num_nodes: 1 ntasks_per_node: 1 @@ -54,17 +58,22 @@ execution: # Note: Only tp=1 works for Nano (Mamba-based hybrid architecture) deployment: # Update this to your Hugging Face checkpoint path (original, pruned or quantized) - checkpoint_path: + checkpoint_path: ??? served_model_name: Nemotron-3-Nano-30B-A3B port: 8000 tensor_parallel_size: 1 pipeline_parallel_size: 1 data_parallel_size: 8 gpu_memory_utilization: 0.90 - extra_args: "--max-model-len 262144 --enable-log-requests --no-enable-prefix-caching --trust-remote-code --mamba_ssm_cache_dtype float32 --enable-auto-tool-choice\ - \ --tool-call-parser qwen3_coder --reasoning-parser-plugin /checkpoint/nano_v3_reasoning_parser.py --reasoning-parser nano_v3" - env_vars: - VLLM_FLASHINFER_MOE_BACKEND: throughput + # extra_args is for the BF16 checkpoint. For an FP8 checkpoint, append `--kv-cache-dtype fp8`. + extra_args: "--max-model-len 262144 --max-num-seqs 8 --enable-log-requests --no-enable-prefix-caching --trust-remote-code --mamba_ssm_cache_dtype float32\ + \ --enable-auto-tool-choice --tool-call-parser qwen3_coder --reasoning-parser-plugin /checkpoint/nano_v3_reasoning_parser.py --reasoning-parser nano_v3" + # env_vars is for the BF16 checkpoint (no MoE backend flags needed). For an FP8 + # checkpoint, replace `{}` with the block below (see notes at top of file): + # FP8: + # VLLM_USE_FLASHINFER_MOE_FP8: "1" + # VLLM_FLASHINFER_MOE_BACKEND: throughput + env_vars: {} endpoints: chat: /v1/chat/completions completions: /v1/completions @@ -99,8 +108,8 @@ evaluation: max_retries: 10 extra: tokenizer_backend: huggingface - # Update tokenizer path to match checkpoint_path above - tokenizer: + # Auto-derived from deployment.checkpoint_path below + tokenizer: ${deployment.checkpoint_path} env_vars: HF_TOKEN: HF_TOKEN HF_HOME: HF_HOME @@ -118,7 +127,6 @@ evaluation: nemo_evaluator_config: config: params: - # limit_samples: 8 extra: num_repeats: 1 args: "++prompt_config=eval/aai/mcq-10choices-boxed" @@ -130,7 +138,6 @@ evaluation: nemo_evaluator_config: config: params: - # limit_samples: 8 extra: num_repeats: 8 args: "++prompt_config=eval/aai/mcq-4choices" @@ -142,9 +149,8 @@ evaluation: nemo_evaluator_config: config: params: - # limit_samples: 8 extra: - num_repeats: 8 + num_repeats: 4 dataset_split: test_v6_2408_2505 # 4. AIME 2025 @@ -154,9 +160,8 @@ evaluation: nemo_evaluator_config: config: params: - # limit_samples: 8 extra: - num_repeats: 64 + num_repeats: 32 # 5. IFBench - name: ns_ifbench @@ -165,7 +170,6 @@ evaluation: nemo_evaluator_config: config: params: - # limit_samples: 8 extra: num_repeats: 8 @@ -176,6 +180,5 @@ evaluation: nemo_evaluator_config: config: params: - # limit_samples: 8 extra: num_repeats: 8 diff --git a/examples/megatron_bridge/tutorials/README.md b/examples/megatron_bridge/tutorials/README.md new file mode 100644 index 00000000000..9a27c6914d9 --- /dev/null +++ b/examples/megatron_bridge/tutorials/README.md @@ -0,0 +1,10 @@ +# Megatron-Bridge Tutorials + +End-to-end tutorials that combine ModelOpt optimization techniques on [NVIDIA Megatron-Bridge](https://github.com/NVIDIA-NeMo/Megatron-Bridge) models. +Each one walks through a complete workflow using the scripts in [examples/megatron_bridge](../README.md) (`prune_minitron.py`, `distill.py`, `quantize.py`, `export.py`). + +## Available tutorials + +| Tutorial | What it covers | +| --- | --- | +| [NVIDIA-Nemotron-3-Nano-30B-A3B-BF16](NVIDIA-Nemotron-3-Nano-30B-A3B-BF16/README.md) | End-to-end optimization of the Nemotron-3-Nano-30B-A3B-BF16 (MoE + Mamba-Transformer hybrid) model: Minitron structured pruning (31.6B/A3.6B → 22B/A3.0B) → two-phase knowledge distillation (100B tokens, 8K then 32K seq length) → quantization → vLLM deployment. Includes data-blend preparation, evaluation setup, and detailed pruning / data-blend / long-context ablations. | diff --git a/examples/pruning/README.md b/examples/pruning/README.md index 0bf8d904ecf..025c61a6712 100644 --- a/examples/pruning/README.md +++ b/examples/pruning/README.md @@ -294,7 +294,7 @@ After pruning, distillation is required to recover model accuracy. Below are rec End-to-end distillation results with Megatron-Bridge after Minitron and Puzzletron pruning: -- **[Minitron — Nemotron-3-Nano-30B-A3B-BF16](minitron/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16/README.md)** ⭐ *recommended — newer and most comprehensive*: End-to-end tutorial of structured pruning for Nemotron-3-Nano-30B-A3B-BF16 (31.6B/A3.6B) to 22B/A3.0B active parameters followed by two-phase knowledge distillation (80B tokens @ 8K seq length + 20B tokens @ 32K seq length = 100B tokens total), quantization, and vLLM deployment. Covers MoE + Mamba-Transformer hybrid, tool-calling data, and a long-context fine-tuning phase. Achieves near-parity with the official 30B model across popular pretraining and reasoning benchmarks while delivering up to 1.64× throughput speedup and 2.6× memory reduction when combined with FP8 quantization. +- **[Minitron — Nemotron-3-Nano-30B-A3B-BF16](../megatron_bridge/tutorials/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16/README.md)** ⭐ *recommended — newer and most comprehensive*: End-to-end tutorial of structured pruning for Nemotron-3-Nano-30B-A3B-BF16 (31.6B/A3.6B) to 22B/A3.0B active parameters followed by two-phase knowledge distillation (80B tokens @ 8K seq length + 20B tokens @ 32K seq length = 100B tokens total), quantization, and vLLM deployment. Covers MoE + Mamba-Transformer hybrid, tool-calling data, and a long-context fine-tuning phase. Achieves near-parity with the official 30B model across popular pretraining and reasoning benchmarks while delivering up to 2.6× throughput speedup and 2.6× memory reduction when combined with FP8 quantization. - **[Minitron — Nemotron-Nano-9B-v2](minitron/NVIDIA-Nemotron-Nano-9B-v2/README.md)**: Earlier end-to-end tutorial covering structured pruning of the dense Mamba-Transformer Nemotron-Nano-9B-v2 to 7B followed by knowledge distillation up to 80B tokens, quantization, and vLLM deployment. Simpler architecture, single-phase 8K seq length distillation, no tool-calling or long-context phase. - **[Puzzletron — Qwen3-8B and Llama-3.1-8B-Instruct](puzzletron/Llama-3.1-8B-Instruct.md)**: MIP-based compression followed by short distillation runs on WikiText-103. Shows MMLU recovery and illustrates the importance of using larger datasets to avoid overfitting. diff --git a/examples/pruning/minitron/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16/README.md b/examples/pruning/minitron/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16/README.md index 59737f26206..338c7973f0d 100644 --- a/examples/pruning/minitron/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16/README.md +++ b/examples/pruning/minitron/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16/README.md @@ -1,490 +1,4 @@ -# Nemotron-3-Nano-30B-A3B: Prune + Distill + Quantize + vLLM Deployment +# This tutorial has moved -End-to-end optimization of [NVIDIA-Nemotron-3-Nano-30B-A3B-BF16](https://huggingface.co/nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16) demonstrating how ModelOpt techniques stack: Minitron structured pruning → Megatron-Bridge knowledge distillation to recover accuracy → evaluation benchmarking → FP8 quantization → vLLM deployment and throughput benchmarking. This document covers: - -1. **[Data Preparation](#1-data-preparation)** — tokenizing the training blend for distillation -2. **[Pruning](#2-pruning)** — Minitron structured pruning -3. **[Distillation](#3-distillation)** — recovering accuracy via Megatron-Bridge knowledge distillation -4. **[Evaluation](#4-evaluation)** — benchmarking with NeMo Evaluator across MMLU Pro, GPQA Diamond, AIME, and more -5. **[Quantization](#5-quantization)** — FP8 PTQ on the distilled checkpoint using ModelOpt's `examples/llm_ptq/hf_ptq.py` script -6. **[vLLM Inference Benchmarking](#6-vllm-inference-benchmarking)** — throughput comparison of BF16 vs FP8 on a single H100 - -## Results - -![Benchmark Recovery During Knowledge Distillation](figures/learning_curves.png) - -| Model | MMLU Pro | GPQA Diamond | LiveCodeBench v6 | AIME 2025 | IFBench | SciCode (Subtask) | Average | -| --- | --- | --- | --- | --- | --- | --- | --- | -| Pruned 22B/A3.0B (no distillation) | 47.1 | 33.5 | 27.4 | 15.5 | 36.9 | 12.1 | 28.8 | -| Distill @ 2.5B tokens (100 iters at 8K Seq Length) | 73.3 | 63.7 | 55.3 | 77.6 | 59.1 | 25.1 | 59.0 | -| Distill @ 20B tokens (800 iters at 8K Seq Length) | 74.8 | 66.0 | 62.3 | 79.6 | 65.4 | 26.1 | 62.4 | -| Distill @ 40B tokens (1600 iters at 8K Seq Length) | 76.4 | 67.2 | 62.3 | 79.8 | 66.0 | 26.6 | 63.1 | -| Distill @ 60B tokens (2400 iters at 8K Seq Length) | 76.1 | 68.1 | 63.6 | 78.8 | 67.3 | 27.0 | 63.5 | -| Distill @ 80B tokens (3200 iters at 8K Seq Length) | 76.5 | 69.1 | 63.9 | 80.7 | 66.5 | 29.0 | 64.3 | -| Distill @ 82.5B tokens (+100 iters at 32K Seq Length) | 76.2 | 69.8 | 64.8 | 87.0 | 68.2 | 27.0 | 65.5 | -| Distill @ 100B tokens (+800 iters at 32K Seq Length) - **BF16** | 76.6 | 69.6 | 66.1 | 87.3 | 68.9 | 28.4 | 66.2 | -| Distill @ 100B tokens + **FP8 Quantize** | 76.3 | 69.8 | 65.5 | 86.0 | 69.7 | 27.9 | 65.9 | -| NVIDIA-Nemotron-3-Nano-30B-A3B-BF16 (official, 31.6B/A3.6B) | 78.0 | 70.3 | 67.9 | 87.1 | 69.1 | 31.8 | 67.4 | - -### vLLM Throughput (single H100, ISL=32768, OSL=1024) - -| Checkpoint | Model loading memory | Output tokens/s | Speedup vs Nemotron-3-Nano-30B-A3B BF16 | -| --- | --- | --- | --- | -| Nemotron-3-Nano-30B-A3B-BF16 (official, 31.6B/A3.6B) | 58.9 GiB | 1,006 | 1.00× | -| Nemotron-3-Nano-30B-A3B-FP8 (official) | 31.4 GiB | 1,404 | 1.40× | -| Nemotron-3-Nano-30B-A3B-Pruned-A3.0B (22B/A3.0B) | 41.5 GiB | 1,301 | 1.29× | -| Nemotron-3-Nano-30B-A3B-Pruned-A3.0B-FP8 | 22.8 GiB | 1,653 | 1.64× | - -Pruning alone (BF16 → Pruned-A3.0B BF16) gives a **1.29×** throughput speedup with a 30% memory reduction (58.9 → 41.5 GiB), and FP8 quantization alone (BF16 → FP8) gives a **1.40×** speedup with a 47% memory reduction. Stacking both — pruning + FP8 — compounds to a **1.64×** throughput speedup and a **2.6× memory reduction** (58.9 → 22.8 GiB) relative to the original 30B BF16 model, while preserving most of the benchmark accuracy. The NemotronH hybrid architecture (Mamba + Attention + MoE) moderates the FP8 gain relative to pure-transformer models, since Attention and Conv1d layers are not quantized. See [Section 6](#6-vllm-inference-benchmarking) for the benchmark command. - -Distillation uses the **30% Pretraining (Code 5, General 20, MATH 5) + 70% Post-training v1/v3 (Math 27, Coding 20, Science 13, IF 5, Tool calling 5)** blend (see [Data Blend](#data-blend) below) with an **80B @ 8K + 20B @ 32K = 100B token** schedule. Blend ablations and long-context phase ablations are in [ABLATIONS.md](ABLATIONS.md). - -> [!TIP] -> From the benchmark numbers above, the model is still learning at 100B tokens and that further training (or a higher-quality data blend) would continue to close the gap to the original 31.6B/A3.6B model. - -> [!NOTE] -> Exact numbers may vary depending on deployment and evaluation setup. All models above (including the official model) were evaluated once with the same [evaluation setup](#4-evaluation) for fair comparison. These numbers may differ from those reported on the official [Nemotron-3-Nano-30B-A3B-BF16](https://huggingface.co/nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16) HuggingFace model card. - ---- - -## Steps to Reproduce - -**Environment:** Container `nvcr.io/nvidia/nemo:26.04`, ModelOpt 0.45.0. See the [Megatron-Bridge README](../../../megatron_bridge/README.md) for environment setup (including ModelOpt mount path) and container usage. - -### 1. Data Preparation - -See [examples/dataset/MEGATRON_DATA_PREP.md](../../../dataset/MEGATRON_DATA_PREP.md) for tokenization commands for all datasets used in this blend. - -For this experiment: `TOKENIZER=nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16`, `OUTPUT_DIR=tokenized_nemotron_3`. - -#### Data Blend - -**30% Pretraining (Code 5, General 20, MATH 5) + 70% Post-training v1/v3 (Math 27, Coding 20, Science 13, IF 5, Tool calling 5)** - -| Dataset | Tokens | Weight | Notes | -| ---------------------------------------------------------- | ------ | ------ | ---------------------------------------------- | -| Nemotron-Pretraining-SFT-v1 / Code (10M samples) | 7B | 5 | Pretraining code | -| Nemotron-Pretraining-SFT-v1 / General (10M samples) | 16B | 20 | Upweighted to close MMLU gap | -| Nemotron-Pretraining-SFT-v1 / MATH (10M samples) | 13B | 5 | Pretraining math | -| Nemotron-Math-v2 / high_part00 | 13B | 10 | Hard math reasoning | -| Nemotron-SFT-Math-v3 / train | 52B | 17 | Hard math reasoning with full reasoning traces | -| Nemotron-SFT-Competitive-Programming-v2 / python_00 | 7B | 15 | Python reasoning traces | -| Nemotron-SFT-Competitive-Programming-v2 / cpp_00 | 7B | 5 | C++ reasoning traces | -| Nemotron-Post-Training-Dataset-v1 / stem (5M samples) | 22B | 8 | Broad STEM | -| Nemotron-Science-v1 / MCQ | 0.5B | 3 | GPQA MCQ format alignment | -| Nemotron-Science-v1 / RQA | 0.3B | 2 | GPQA format diversity | -| Nemotron-SFT-Instruction-Following-Chat-v2 / reasoning_on | 2B | 3 | Instruction following (thinking on) | -| Nemotron-SFT-Instruction-Following-Chat-v2 / reasoning_off | 1B | 2 | Instruction following (thinking off) | -| Nemotron-Agentic-v1 / tool_calling | 1B | 5 | Tool-use scaffolding; helps SciCode / GPQA | - -
-Data blend for distillation (click to expand) - -```bash -DATA_BLEND=" \ -5 tokenized_nemotron_3/nvidia--Nemotron-Pretraining-SFT-v1_Nemotron-SFT-Code_train_text_max10000000 \ -20 tokenized_nemotron_3/nvidia--Nemotron-Pretraining-SFT-v1_Nemotron-SFT-General_train_text_max10000000 \ -5 tokenized_nemotron_3/nvidia--Nemotron-Pretraining-SFT-v1_Nemotron-SFT-MATH_train_text_max10000000 \ -10 tokenized_nemotron_3/nvidia--Nemotron-Math-v2_default_high_part00_messages \ -17 tokenized_nemotron_3/nvidia--Nemotron-SFT-Math-v3_default_train_messages \ -15 tokenized_nemotron_3/competitive_programming_python_00_messages \ -5 tokenized_nemotron_3/competitive_programming_cpp_00_messages \ -8 tokenized_nemotron_3/nvidia--Nemotron-Post-Training-Dataset-v1_default_stem_messages_max5000000 \ -3 tokenized_nemotron_3/MCQ_messages \ -2 tokenized_nemotron_3/RQA_messages \ -3 tokenized_nemotron_3/reasoning_on_messages \ -2 tokenized_nemotron_3/reasoning_off_messages \ -5 tokenized_nemotron_3/nvidia--Nemotron-Agentic-v1_tool_calling_messages \ -" -``` - -
- -#### General Guidelines - -The optimal blend is 30% pretraining and 70% post-training data. Exact proportions may vary depending on the benchmarks you care about. The blend above was designed to maximize recovery on popular General Knowledge, Reasoning, Instruction Following, and Tool Calling benchmarks. The key design decisions were: - -- **30% pretraining data** closes the MMLU gap that arises from training exclusively on reasoning-heavy post-training data. The General split (20%) is upweighted specifically to recover general knowledge recall. -- **Math (27%)** is the largest post-training category because AIME and MMLU Pro respond strongly to more math reasoning tokens. We use a mix of `Nemotron-Math-v2` and `Nemotron-SFT-Math-v3` for higher quality math reasoning signal with full reasoning traces. -- **Science (13%)** uses `Nemotron-Post-Training-Dataset-v1 / stem` as the primary source for volume and GPQA stability, with small allocations to `Nemotron-Science-v1` MCQ/RQA subsets for format alignment with GPQA's multiple-choice structure. -- **Instruction following (5%)** saturates quickly so a small allocation is sufficient. -- **Tool calling (5%)** uses `Nemotron-Agentic-v1 / tool_calling`. Our evals run with `--enable-auto-tool-choice`, so the student needs explicit exposure to function-call schemas; this helps SciCode (heavy Python tool use) and GPQA Diamond (which can benefit from calculator tools). - -This blend intentionally omits capabilities not targeted in this experiment (e.g. multilingual, SWE). Depending on what benchmarks matter for your use case, you can substitute or add datasets from the [Nemotron Post-Training v3 collection](https://huggingface.co/collections/nvidia/nemotron-post-training-v3), for example: - -| Capability | Relevant datasets | -| --- | --- | -| Multilingual | `Nemotron-SFT-Multilingual-v1` | -| Software engineering (SWE) | `Nemotron-SFT-SWE-v2` | -| Safety / alignment | `Nemotron-SFT-Safety-v1` | - -When adding new datasets, reduce weights of lower-priority categories proportionally to keep the total at 100%. - ---- - -### 2. Pruning - -Here we prune the [NVIDIA-Nemotron-3-Nano-30B-A3B-BF16](https://huggingface.co/nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16) HuggingFace checkpoint from 31.6B/A3.6B to 3.0B active parameters. The output is a pruned HuggingFace checkpoint that feeds into the distillation step. - -Run on **1 node with 8x H100** (~1 hour) - -
-Pruning command (click to expand) - -```bash -torchrun --nproc_per_node 8 /opt/Model-Optimizer/examples/megatron_bridge/prune_minitron.py \ - --pp_size 8 \ - --hf_model_name_or_path nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16 \ - --trust_remote_code \ - --prune_target_params 28e9 \ - --prune_target_active_params 3e9 \ - --hparams_to_skip num_attention_heads \ - --seq_length 8192 \ - --output_hf_path /path/to/Nemotron-3-Nano-30B-A3B-Pruned-A3.0B \ - --top_k 20 \ - --max_depth_pruning 0.15 \ - --max_width_pruning 0.30 \ - --prune_score_func mmlu_10pct_bs32 \ - --num_layers_in_first_pipeline_stage 5 \ - --num_layers_in_last_pipeline_stage 5 -``` - -Non-default arguments: - -- `--hparams_to_skip num_attention_heads` (default: none) — attention heads pruning is harder to recover, hence skipped -- `--seq_length 8192` (default: 4096) — dataset has longer sequences -- `--prune_target_active_params 3e9` — MoE-specific; the **primary** pruning constraint — targets active params rather than total params, which is what matters for MoE inference cost -- `--prune_target_params 28e9` — upper bound on total params only; the actual pruned model total can range anywhere from ~20B to 28B depending on which architecture wins — see pruning logs below for the top 20 candidates. You may also skip this argument all together for simplicity. -- `--top_k 20` (default: 10) — larger candidate pool for better architecture search -- `--max_depth_pruning 0.15` (default: 0.20) — tighter constraint since candidates with 42–46 layers universally fail for this model -- `--max_width_pruning 0.30` (default: 0.40) — tighter constraint to prevent head_dim≤48 and hidden=2048 dead zones -- `--prune_score_func mmlu_10pct_bs32` (default: `mmlu_10pct_bs1`) — batch_size=32 for ~3–4× faster candidate scoring -- `--num_layers_in_first_pipeline_stage 5 --num_layers_in_last_pipeline_stage 5` — Uneven pipeline parallelism since 52 layers is not divisible by 8 GPUs - -**NOTE**: The tighter search space constraints here (`--max_depth_pruning`, `--max_width_pruning`) are specific to Nemotron hybrid models (Mamba + Attention + MoE). Unlike standard transformers which expose only layers/hidden/attention/FFN dimensions, these models add Mamba-specific dimensions (`mamba_num_heads`, `mamba_head_dim`) and MoE dimensions (`num_moe_experts`, `moe_ffn_hidden_size`, `moe_shared_expert_intermediate_size`), making the combined search space much larger. The default 40%/20% bounds cast too wide a net and waste compute on dead-zone architectures. - -See [ABLATIONS.md](ABLATIONS.md#pruning) for the full architecture search analysis across various candidates. -
- -
-Pruning logs (top 20 candidates, best subnet, layer patterns) (click to expand) - -```text -╭──────────────────────────────────────────────────── Original Model Stats ─────────────────────────────────────────────────────╮ -│ Total Parameters 31.58B │ -│ Active Parameters 3.58B │ -│ Memory (BF16, seq_length=8192, batch_size=1) weights: 60230.1 MB, kv_cache: 48.0 MB, mamba_state: 23.8 MB, Total: 60301.9 MB │ -╰───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯ - - Search Space - (≤30% width / ≤15% depth pruning) -┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ -┃ Hyperparameter ┃ Choices ┃ -┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩ -│ num_layers │ [46, 48, 50, 52] │ -│ hidden_size │ [2048, 2304, 2560, 2688] │ -│ mamba_num_heads │ [48, 56, 64] │ -│ mamba_head_dim │ [48, 56, 64] │ -│ num_moe_experts │ [96, 104, 112, 120, 128] │ -│ moe_ffn_hidden_size │ [1536, 1792, 1856] │ -│ moe_shared_expert_intermediate_size │ [2816, 3072, 3328, 3584, 3712] │ -├─────────────────────────────────────┼────────────────────────────────┤ -│ Search space size │ 10800 │ -└─────────────────────────────────────┴────────────────────────────────┘ - -Top 20 Candidates with Scores -┏━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┳━━━━━━━━┳━━━━━━━━┓ -┃ # ┃ export_config ┃ active_params ┃ params ┃ score ┃ -┡━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━╇━━━━━━━━╇━━━━━━━━┩ -│ 1 │ {'num_layers': 46, 'hidden_size': 2560, 'mamba_num_heads': 56, 'mamba_head_dim': 64, 'num_moe_experts': 120, │ 3.00B │ 27.06B │ 0.3399 │ -│ │ 'moe_ffn_hidden_size': 1792, 'moe_shared_expert_intermediate_size': 3072} │ │ │ │ -│ 2 │ {'num_layers': 48, 'hidden_size': 2560, 'mamba_num_heads': 56, 'mamba_head_dim': 56, 'num_moe_experts': 112, │ 3.00B │ 25.37B │ 0.4650 │ -│ │ 'moe_ffn_hidden_size': 1792, 'moe_shared_expert_intermediate_size': 3072} │ │ │ │ -│ 3 │ {'num_layers': 46, 'hidden_size': 2560, 'mamba_num_heads': 64, 'mamba_head_dim': 56, 'num_moe_experts': 112, │ 3.00B │ 25.37B │ 0.2343 │ -│ │ 'moe_ffn_hidden_size': 1792, 'moe_shared_expert_intermediate_size': 3072} │ │ │ │ -│ 4 │ {'num_layers': 52, 'hidden_size': 2688, 'mamba_num_heads': 56, 'mamba_head_dim': 48, 'num_moe_experts': 96, │ 3.00B │ 20.09B │ 0.2552 │ -│ │ 'moe_ffn_hidden_size': 1536, 'moe_shared_expert_intermediate_size': 3072} │ │ │ │ -│ 5 │ {'num_layers': 52, 'hidden_size': 2688, 'mamba_num_heads': 48, 'mamba_head_dim': 56, 'num_moe_experts': 104, │ 3.00B │ 21.61B │ 0.2601 │ -│ │ 'moe_ffn_hidden_size': 1536, 'moe_shared_expert_intermediate_size': 3072} │ │ │ │ -│ 6 │ {'num_layers': 52, 'hidden_size': 2560, 'mamba_num_heads': 48, 'mamba_head_dim': 64, 'num_moe_experts': 96, │ 3.00B │ 19.28B │ 0.3762 │ -│ │ 'moe_ffn_hidden_size': 1536, 'moe_shared_expert_intermediate_size': 3712} │ │ │ │ -│ 7 │ {'num_layers': 52, 'hidden_size': 2304, 'mamba_num_heads': 64, 'mamba_head_dim': 64, 'num_moe_experts': 104, │ 3.00B │ 22.28B │ 0.4783 │ -│ │ 'moe_ffn_hidden_size': 1856, 'moe_shared_expert_intermediate_size': 3072} │ │ │ │ -│ 8 │ {'num_layers': 52, 'hidden_size': 2560, 'mamba_num_heads': 48, 'mamba_head_dim': 48, 'num_moe_experts': 96, │ 3.00B │ 21.99B │ 0.2420 │ -│ │ 'moe_ffn_hidden_size': 1792, 'moe_shared_expert_intermediate_size': 3328} │ │ │ │ -│ 9 │ {'num_layers': 50, 'hidden_size': 2560, 'mamba_num_heads': 48, 'mamba_head_dim': 48, 'num_moe_experts': 112, │ 3.00B │ 25.37B │ 0.2399 │ -│ │ 'moe_ffn_hidden_size': 1792, 'moe_shared_expert_intermediate_size': 3712} │ │ │ │ -│ 10 │ {'num_layers': 50, 'hidden_size': 2560, 'mamba_num_heads': 48, 'mamba_head_dim': 48, 'num_moe_experts': 112, │ 3.00B │ 26.17B │ 0.2601 │ -│ │ 'moe_ffn_hidden_size': 1856, 'moe_shared_expert_intermediate_size': 3328} │ │ │ │ -│ 11 │ {'num_layers': 46, 'hidden_size': 2560, 'mamba_num_heads': 56, 'mamba_head_dim': 64, 'num_moe_experts': 112, │ 3.00B │ 25.37B │ 0.2503 │ -│ │ 'moe_ffn_hidden_size': 1792, 'moe_shared_expert_intermediate_size': 3072} │ │ │ │ -│ 12 │ {'num_layers': 48, 'hidden_size': 2560, 'mamba_num_heads': 56, 'mamba_head_dim': 56, 'num_moe_experts': 104, │ 3.00B │ 23.68B │ 0.4329 │ -│ │ 'moe_ffn_hidden_size': 1792, 'moe_shared_expert_intermediate_size': 3072} │ │ │ │ -│ 13 │ {'num_layers': 46, 'hidden_size': 2688, 'mamba_num_heads': 64, 'mamba_head_dim': 64, 'num_moe_experts': 128, │ 3.00B │ 26.17B │ 0.2587 │ -│ │ 'moe_ffn_hidden_size': 1536, 'moe_shared_expert_intermediate_size': 2816} │ │ │ │ -│ 14 │ {'num_layers': 46, 'hidden_size': 2560, 'mamba_num_heads': 64, 'mamba_head_dim': 56, 'num_moe_experts': 104, │ 3.00B │ 23.68B │ 0.2336 │ -│ │ 'moe_ffn_hidden_size': 1792, 'moe_shared_expert_intermediate_size': 3072} │ │ │ │ -│ 15 │ {'num_layers': 52, 'hidden_size': 2688, 'mamba_num_heads': 48, 'mamba_head_dim': 56, 'num_moe_experts': 96, │ 3.00B │ 20.09B │ 0.2559 │ -│ │ 'moe_ffn_hidden_size': 1536, 'moe_shared_expert_intermediate_size': 3072} │ │ │ │ -│ 16 │ {'num_layers': 52, 'hidden_size': 2304, 'mamba_num_heads': 64, 'mamba_head_dim': 64, 'num_moe_experts': 96, │ 3.00B │ 20.70B │ 0.4608 │ -│ │ 'moe_ffn_hidden_size': 1856, 'moe_shared_expert_intermediate_size': 3072} │ │ │ │ -│ 17 │ {'num_layers': 50, 'hidden_size': 2560, 'mamba_num_heads': 48, 'mamba_head_dim': 48, 'num_moe_experts': 104, │ 3.00B │ 23.68B │ 0.2455 │ -│ │ 'moe_ffn_hidden_size': 1792, 'moe_shared_expert_intermediate_size': 3712} │ │ │ │ -│ 18 │ {'num_layers': 50, 'hidden_size': 2560, 'mamba_num_heads': 48, 'mamba_head_dim': 48, 'num_moe_experts': 104, │ 3.00B │ 24.42B │ 0.2503 │ -│ │ 'moe_ffn_hidden_size': 1856, 'moe_shared_expert_intermediate_size': 3328} │ │ │ │ -│ 19 │ {'num_layers': 48, 'hidden_size': 2560, 'mamba_num_heads': 48, 'mamba_head_dim': 48, 'num_moe_experts': 120, │ 3.00B │ 27.92B │ 0.2587 │ -│ │ 'moe_ffn_hidden_size': 1856, 'moe_shared_expert_intermediate_size': 3712} │ │ │ │ -│ 20 │ {'num_layers': 46, 'hidden_size': 2560, 'mamba_num_heads': 56, 'mamba_head_dim': 64, 'num_moe_experts': 104, │ 3.00B │ 23.68B │ 0.2469 │ -│ │ 'moe_ffn_hidden_size': 1792, 'moe_shared_expert_intermediate_size': 3072} │ │ │ │ -└────┴───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┴───────────────┴────────┴────────┘ - -╭──────────────────────────────────────────────────────────────────────── Best Subnet ─────────────────────────────────────────────────────────────────────────╮ -│ export_config {'num_layers': 52, 'hidden_size': 2304, 'mamba_num_heads': 64, 'mamba_head_dim': 64, 'num_moe_experts': 104, 'moe_ffn_hidden_size': 1856, │ -│ 'moe_shared_expert_intermediate_size': 3072} │ -│ active_params 3.00B │ -│ params 22.28B │ -│ score 0.4783 │ -╰──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯ - -Original hybrid_layer_pattern: MEMEM*EMEMEM*EMEMEM*EMEMEM*EMEMEM*EMEMEMEM*EMEMEMEME -Pruned hybrid_layer_pattern: MEMEM*EMEMEM*EMEMEM*EMEMEM*EMEMEM*EMEMEMEM*EMEMEMEME - -╭───────────────────────────────────────────────────── Pruned Model Stats ──────────────────────────────────────────────────────╮ -│ Total Parameters 22.28B │ -│ Active Parameters 3.00B │ -│ Memory (BF16, seq_length=8192, batch_size=1) weights: 42489.7 MB, kv_cache: 48.0 MB, mamba_state: 23.8 MB, Total: 42561.6 MB │ -╰───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯ -``` - -
- -> [!TIP] -> Candidate selection above relies on the pruning score alone — it does not run a short KD trial per candidate to pick the winner. The main post-pruning distillation in [Section 3](#3-distillation) is still performed on the selected candidate. If you want a stronger pick, take a few top candidates' `export_config` from the logs above (where the score is similar to the best subnet), export them separately, run KD for ~2B tokens on each, and pick the best on your target metrics. See [ABLATIONS.md — 1st vs 2nd best candidate](ABLATIONS.md#distillation-results-1st-best-vs-2nd-best-pruning-candidate) for a concrete comparison. - ---- - -### 3. Distillation - -Distillation is run in two phases: an 80B-token phase at 8K sequence length, followed by a 20B-token long-context phase at 32K sequence length. The two phases are launched as separate runs with an intermediate Megatron→HF checkpoint conversion, because the long-context phase changes `seq_length`, `gbs`, and `cp_size` — Megatron's checkpoint resume bookkeeping (sample counter is in absolute samples, iteration counter is in iter-units tied to `gbs`) does not handle a mid-run `gbs` change cleanly. - -Minimum hardware: **4 nodes × 8x H100 (32 GPUs)** for the 8K phase — required by `TP=4 × EP=8`. The 32K phase additionally requires context parallel to fit the longer sequence, doubling the minimum to **8 nodes × 8x H100 (64 GPUs)**. On **96 nodes × 8x H100 (768 GPUs total)**, it takes ~900 H100 GPU-hours per 10B tokens (400 iters), i.e. ~70 min wall-clock per 10B tokens on 96 nodes. Full schedule (80B @ 8K + 20B @ 32K = 100B tokens, 4k total steps) takes ~9k H100 GPU-hours (~12 hours wall-clock). - -#### 3a. Phase 1 — 80B tokens @ 8K seq length - -
-Phase 1 distillation command (click to expand) - -```bash -python -u /opt/Model-Optimizer/examples/megatron_bridge/distill.py \ - --teacher_hf_path nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16 \ - --student_hf_path /path/to/Nemotron-3-Nano-30B-A3B-Pruned-A3.0B \ - --trust_remote_code \ - --tp_size 4 \ - --ep_size 8 \ - --data_paths "${DATA_BLEND}" \ - --data_path_to_cache /path/to/cache \ - --seq_length 8192 \ - --mbs 1 \ - --gbs 3072 \ - --train_iters 3200 \ - --lr 1e-4 \ - --min_lr 1e-5 \ - --lr_warmup_iters 25 \ - --eval_interval 200 \ - --eval_iters 8 \ - --log_interval 10 \ - --output_dir /path/to/distill_output_phase1_8k - -# Optional: Weights & Biases logging -# --wandb_project \ -# --wandb_entity \ -# --wandb_exp_name -``` - -Non-default arguments: - -- `--seq_length 8192` (default: 4096) -- `--gbs 3072` (default: 768) — matches the original Nemotron-3-Nano-30B training GBS from the paper, kept to preserve the training distribution -- `--train_iters 3200` — 80B tokens at GBS 3072 × seq_length 8192 -- `--lr 1e-4 --min_lr 1e-5 --lr_warmup_iters 25` — cosine fully decays over 3200 iters; the model is approaching saturation at 8K by this point (see [ABLATIONS.md — 8K trajectory](ABLATIONS.md#effect-of-data-blend-tool_calling)). -- `--eval_interval 200` (default: 100) — less frequent eval to save compute -- `--eval_iters 8` (default: 32) — since GBS is 4× larger than default - -All other arguments use defaults. -
- -#### 3b. Convert Phase 1 final checkpoint to HuggingFace format - -Phase 2 starts as a separate run from a fresh HuggingFace student checkpoint, so the final Phase 1 Megatron checkpoint must be exported to HF first using the Megatron-Bridge conversion script (see [Megatron-Bridge README](../../../megatron_bridge/README.md) for full details). You can also use this same script to convert any intermediate Phase 1 checkpoint to HF format for evaluation along the way. - -
-Checkpoint conversion command (click to expand) - -```bash -python /opt/Megatron-Bridge/examples/conversion/convert_checkpoints.py export \ - --hf-model /path/to/Nemotron-3-Nano-30B-A3B-Pruned-A3.0B \ - --megatron-path /path/to/distill_output_phase1_8k/checkpoints/iter_0003200 \ - --hf-path /path/to/distill_output_phase1_8k/checkpoints/hf_iter_0003200 -``` - -
- -#### 3c. Phase 2 — 20B tokens @ 32K seq length - -Phase 2 is a **fresh run** with the Phase 1 final checkpoint as the new student. It uses a different `--seed` so the data blend reshuffles (otherwise the model would see overlapping prefix of the same samples it already saw at 8K). The LR is bumped back up modestly to capture the rapid long-context adaptation observed in [ABLATIONS.md — Effect of long context training](ABLATIONS.md#effect-of-long-context-training). - -
-Phase 2 distillation command (click to expand) - -```bash -python -u /opt/Model-Optimizer/examples/megatron_bridge/distill.py \ - --teacher_hf_path nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16 \ - --student_hf_path /path/to/distill_output_phase1_8k/checkpoints/hf_iter_0003200 \ - --trust_remote_code \ - --tp_size 4 \ - --cp_size 2 \ - --ep_size 8 \ - --seed 5678 \ - --data_paths "${DATA_BLEND}" \ - --data_path_to_cache /path/to/cache \ - --seq_length 32768 \ - --mbs 1 \ - --gbs 768 \ - --train_iters 800 \ - --lr 2e-5 \ - --min_lr 1e-5 \ - --lr_warmup_iters 10 \ - --recompute_granularity selective \ - --recompute_modules moe \ - --eval_interval 200 \ - --eval_iters 8 \ - --log_interval 10 \ - --output_dir /path/to/distill_output_phase2_32k -``` - -Changed arguments from Phase 1: - -- `--student_hf_path` — points at the HF export of the Phase 1 final checkpoint -- `--seq_length 32768` — long-context phase -- `--gbs 768` — `seq_length × gbs` product unchanged, so each iter still processes the same number of tokens -- `--cp_size 2` — context parallel is needed to fit the longer sequence; doubles the minimum-hardware footprint to 8 nodes -- `--train_iters 800` — 20B tokens at GBS 768 × seq_length 32768 -- `--lr 2e-5 --min_lr 1e-5 --lr_warmup_iters 10` — modest LR bump for the long-context adaptation (Phase 1 ended at fully-decayed LR 1e-5); the 10-iter warmup re-populates Adam moment estimates which restart from zero in a fresh run -- `--recompute_granularity selective --recompute_modules moe` — selective MoE recompute further reduces activation memory at 32K. You may skip this if you have more memory. -- `--seed 5678` — different from the Phase 1 seed (default 1234) so the data blend reshuffles -- `--output_dir /path/to/distill_output_phase2_32k` — must be a **fresh directory** different from Phase 1's, so distill.py's resume mechanism (which auto-loads from `/checkpoints` if it exists) does not pull in stale state - -
- -For multi-node Slurm runs, see the [Megatron-Bridge README](../../../megatron_bridge/README.md#slurm-usage) for details. - -> [!NOTE] -> This is pure SFT-style distillation — no RL or online reward signal is used. Adding an RL-based post-training step after distillation is a natural next step that could further improve some of these benchmarks. - ---- - -### 4. Evaluation - -The eval config in [nemo_evaluator.yaml](nemo_evaluator.yaml) is for Slurm-based evaluation — it submits a vLLM serving job (with tool calling enabled via `--enable-auto-tool-choice --tool-call-parser qwen3_coder`) and runs evals against it. For local model execution and evaluation, refer to the [NeMo Evaluator documentation](https://docs.nvidia.com/nemo/evaluator/latest/) or this [blog](https://huggingface.co/blog/nvidia/nemotron-3-nano-evaluation-recipe). - -Before running, update the following fields in the yaml or overwrite them in the command line with `-o