From 5d5c26e221192512af7f5925429e3f8f2f7464b5 Mon Sep 17 00:00:00 2001 From: mnachin Date: Fri, 15 May 2026 06:48:26 -0700 Subject: [PATCH 1/3] Apply Gemma 4 IT chat template in inference.py and C++ runner Gemma 4 31B-IT is instruction-tuned and produces degenerate output without the chat template wrapping. Auto-wrap --prompt with the IT template (<|turn>user\n{prompt}\n<|turn>model\n <|channel>thought\n) by default; --raw-prompt / --raw_prompt skips wrapping for pre-formatted input. --- examples/models/gemma4_31b/README.md | 6 ++++++ examples/models/gemma4_31b/inference.py | 25 ++++++++++++++++++++++++- examples/models/gemma4_31b/main.cpp | 12 ++++++++++++ 3 files changed, 42 insertions(+), 1 deletion(-) diff --git a/examples/models/gemma4_31b/README.md b/examples/models/gemma4_31b/README.md index 6f567d739b7..94783c8f823 100644 --- a/examples/models/gemma4_31b/README.md +++ b/examples/models/gemma4_31b/README.md @@ -79,6 +79,9 @@ Writes `model.pte` and `model.ptd` into `--output-dir`. ## Eager inference +The prompt is automatically wrapped with the Gemma 4 IT chat template. +Pass `--raw-prompt` to skip template wrapping for pre-formatted input. + ```bash python examples/models/gemma4_31b/inference.py \ --prequantized ./gemma4_31b_int4 \ @@ -109,6 +112,9 @@ The binary lands at `cmake-out/examples/models/gemma4_31b/gemma4_31b_runner`. ## Run the .pte +The prompt is automatically wrapped with the Gemma 4 IT chat template. +Pass `--raw_prompt` to skip template wrapping for pre-formatted input. + ```bash ./gemma4_31b_runner \ --model_path ./gemma4_31b_exports/model.pte \ diff --git a/examples/models/gemma4_31b/inference.py b/examples/models/gemma4_31b/inference.py index 12785450d8c..62dfe5956a7 100644 --- a/examples/models/gemma4_31b/inference.py +++ b/examples/models/gemma4_31b/inference.py @@ -13,6 +13,11 @@ Packs for the target backend (--backend cuda), materializes runtime buffers, optionally compiles with ``torch.compile``, and generates text autoregressively. +Gemma 4 31B-IT is instruction-tuned and requires chat-template formatting. +The ``--prompt`` is automatically wrapped with the Gemma 4 chat template +(``<|turn>user\\n{prompt}\\n<|turn>model\\n<|channel>thought\\n``; BOS is prepended separately). +Pass ``--raw-prompt`` to skip template wrapping (e.g., for pre-formatted input). + Usage: python inference.py \\ --prequantized ./gemma4_31b_int4 \\ @@ -63,6 +68,17 @@ def _move_to_cuda(model, config) -> None: materialize_runtime_buffers(model, dtype=torch.bfloat16, device="cuda") +def apply_chat_template(prompt: str) -> str: + """Wrap a user prompt in the Gemma 4 IT chat template. + + Does not include BOS — ``generate()`` prepends it at the token-ID level. + """ + return ( + "<|turn>user\n" + prompt + + "\n<|turn>model\n<|channel>thought\n" + ) + + def generate( model, tokenizer, @@ -155,6 +171,11 @@ def main() -> None: default=4096, help="KV cache length to allocate for this run.", ) + parser.add_argument( + "--raw-prompt", + action="store_true", + help="Skip chat-template wrapping (use if the prompt is already formatted).", + ) parser.add_argument( "--no-compile", action="store_true", @@ -204,6 +225,8 @@ def main() -> None: # Gemma 4 EOS tokens (from generation_config.json: ids 1, 50, 106). eos_token_ids = {1, 50, 106} + prompt = args.prompt if args.raw_prompt else apply_chat_template(args.prompt) + print(f"\nPrompt: {args.prompt}") print("-" * 40) @@ -211,7 +234,7 @@ def main() -> None: output = generate( model, tokenizer, - args.prompt, + prompt, max_new_tokens=args.max_new_tokens, temperature=args.temperature, eos_token_ids=eos_token_ids, diff --git a/examples/models/gemma4_31b/main.cpp b/examples/models/gemma4_31b/main.cpp index 0be2fef517c..3ddf64e410f 100644 --- a/examples/models/gemma4_31b/main.cpp +++ b/examples/models/gemma4_31b/main.cpp @@ -65,6 +65,10 @@ DEFINE_double(temperature, 0.8, "Sampling temperature (0 = near-greedy)."); DEFINE_int32(max_new_tokens, 128, "Maximum tokens to generate."); DEFINE_int32(bos_id, 2, "BOS token id to prepend (Gemma convention: 2)."); DEFINE_int32(eos_id, 1, "EOS token id (Gemma convention: 1)."); +DEFINE_bool( + raw_prompt, + false, + "Skip chat-template wrapping (use if the prompt is already formatted)."); DEFINE_bool( cuda_graph, false, @@ -232,6 +236,14 @@ int main(int argc, char** argv) { (std::istreambuf_iterator(f)), std::istreambuf_iterator()); } + // Wrap with Gemma 4 IT chat template unless --raw_prompt is set. + // BOS is prepended separately below; this adds the turn structure and the + // empty thought block required by the instruction-tuned model. + if (!FLAGS_raw_prompt) { + prompt_text = "<|turn>user\n" + prompt_text + + "\n<|turn>model\n<|channel>thought\n"; + } + // Encode prompt auto encode_result = tokenizer->encode(prompt_text); if (!encode_result.ok()) { From 6c1603114a3213b787f8509f2c7f81532b362008 Mon Sep 17 00:00:00 2001 From: Gasoonjia Date: Mon, 18 May 2026 07:13:19 -0700 Subject: [PATCH 2/3] Gemma4 31b rope fix and ci (#19627) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### Summary Currently `materialize_runtime_buffers` in model.py was zeroing out ALL meta buffers, including each layer's inv_freq (RoPE frequencies). The follow-up `attn.inv_freq.to(device)` was a no-op on already-zero tensors. So RoPE produced cos=1, sin=0 for every position → model had NO positional information → introduce the period-N echo cycle pattern. This PR fix the issue by recomputing inv_freq per-layer with real values (using the layer's head_dim, partial_rotary, rope_theta, is_sliding flag) in materialize_runtime_buffers. ### Test plan Add e2e ci for gemma4-31b model and check its output. --- .ci/scripts/export_model_artifact.sh | 49 ++++++++++++++++++++- .ci/scripts/test_model_e2e.sh | 19 +++++++- .github/workflows/cuda.yml | 30 ++++++++++--- examples/models/gemma4_31b/inference.py | 58 +++++++++++++++---------- examples/models/gemma4_31b/model.py | 24 +++++++++- 5 files changed, 148 insertions(+), 32 deletions(-) diff --git a/.ci/scripts/export_model_artifact.sh b/.ci/scripts/export_model_artifact.sh index 1f75d850e84..4bc8485dde8 100755 --- a/.ci/scripts/export_model_artifact.sh +++ b/.ci/scripts/export_model_artifact.sh @@ -195,9 +195,17 @@ case "$HF_MODEL" in PREPROCESSOR_FEATURE_SIZE="" PREPROCESSOR_OUTPUT="" ;; + SocialLocalMobile/gemma-4-31B-it-HQQ-INT4) + MODEL_NAME="gemma4_31b" + TASK="" + MAX_SEQ_LEN="" + EXTRA_PIP="" + PREPROCESSOR_FEATURE_SIZE="" + PREPROCESSOR_OUTPUT="" + ;; *) echo "Error: Unsupported model '$HF_MODEL'" - echo "Supported models: mistralai/Voxtral-Mini-3B-2507, mistralai/Voxtral-Mini-4B-Realtime-2602, openai/whisper-{small, medium, large, large-v2, large-v3, large-v3-turbo}, google/gemma-3-4b-it, Qwen/Qwen3-0.6B, nvidia/diar_streaming_sortformer_4spk-v2, nvidia/parakeet-tdt, facebook/dinov2-small-imagenet1k-1-layer, SocialLocalMobile/Qwen3.5-35B-A3B-HQQ-INT4" + echo "Supported models: mistralai/Voxtral-Mini-3B-2507, mistralai/Voxtral-Mini-4B-Realtime-2602, openai/whisper-{small, medium, large, large-v2, large-v3, large-v3-turbo}, google/gemma-3-4b-it, Qwen/Qwen3-0.6B, nvidia/diar_streaming_sortformer_4spk-v2, nvidia/parakeet-tdt, facebook/dinov2-small-imagenet1k-1-layer, SocialLocalMobile/Qwen3.5-35B-A3B-HQQ-INT4, SocialLocalMobile/gemma-4-31B-it-HQQ-INT4" exit 1 ;; esac @@ -459,6 +467,45 @@ if [ "$MODEL_NAME" = "qwen3_5_moe" ]; then exit 0 fi +# Gemma 4 31B uses a prequantized checkpoint and custom export script +if [ "$MODEL_NAME" = "gemma4_31b" ]; then + pip install safetensors huggingface_hub gguf + + # Download prequantized model outside OUTPUT_DIR to avoid uploading on failure + LOCAL_MODEL_DIR=$(mktemp -d) + INDUCTOR_CACHE=$(mktemp -d) + trap 'rm -rf "$LOCAL_MODEL_DIR" "$INDUCTOR_CACHE"' EXIT + + python -c "from huggingface_hub import snapshot_download; snapshot_download('${HF_MODEL}', local_dir='${LOCAL_MODEL_DIR}')" + + # Sanity check: run inference on the prequantized model + echo "::group::Inference sanity check" + python -m executorch.examples.models.gemma4_31b.inference \ + --prequantized "$LOCAL_MODEL_DIR" \ + --prompt "What is the capital of France?" \ + --max-new-tokens 32 \ + --temperature 0 \ + --no-compile + echo "::endgroup::" + + # Copy tokenizer for the runner + cp "$LOCAL_MODEL_DIR/tokenizer.json" "${OUTPUT_DIR}/tokenizer.json" + + # Export to .pte/.ptd (short cache dir avoids objcopy symbol length issues) + echo "::group::Export" + TORCHINDUCTOR_CACHE_DIR="$INDUCTOR_CACHE" \ + python -m executorch.examples.models.gemma4_31b.export \ + --prequantized "$LOCAL_MODEL_DIR" \ + --output-dir "${OUTPUT_DIR}" + echo "::endgroup::" + + test -f "${OUTPUT_DIR}/model.pte" + test -f "${OUTPUT_DIR}/aoti_cuda_blob.ptd" + ls -al "${OUTPUT_DIR}" + + exit 0 +fi + MAX_SEQ_LEN_ARG="" if [ -n "$MAX_SEQ_LEN" ]; then MAX_SEQ_LEN_ARG="--max_seq_len $MAX_SEQ_LEN" diff --git a/.ci/scripts/test_model_e2e.sh b/.ci/scripts/test_model_e2e.sh index 1678b0a4fbb..27b0dd9d597 100755 --- a/.ci/scripts/test_model_e2e.sh +++ b/.ci/scripts/test_model_e2e.sh @@ -228,9 +228,21 @@ case "$HF_MODEL" in AUDIO_FILE="" IMAGE_PATH="" ;; + SocialLocalMobile/gemma-4-31B-it-HQQ-INT4) + MODEL_NAME="gemma4_31b" + RUNNER_TARGET="gemma4_31b_runner" + RUNNER_PATH="gemma4_31b" + EXPECTED_OUTPUT="Paris" + PREPROCESSOR="" + TOKENIZER_URL="" + TOKENIZER_FILE="tokenizer.json" + AUDIO_URL="" + AUDIO_FILE="" + IMAGE_PATH="" + ;; *) echo "Error: Unsupported model '$HF_MODEL'" - echo "Supported models: mistralai/Voxtral-Mini-3B-2507, mistralai/Voxtral-Mini-4B-Realtime-2602, nvidia/diar_streaming_sortformer_4spk-v2, openai/whisper series (whisper-{small, medium, large, large-v2, large-v3, large-v3-turbo}), google/gemma-3-4b-it, Qwen/Qwen3-0.6B, nvidia/parakeet-tdt, facebook/dinov2-small-imagenet1k-1-layer, SocialLocalMobile/Qwen3.5-35B-A3B-HQQ-INT4" + echo "Supported models: mistralai/Voxtral-Mini-3B-2507, mistralai/Voxtral-Mini-4B-Realtime-2602, nvidia/diar_streaming_sortformer_4spk-v2, openai/whisper series (whisper-{small, medium, large, large-v2, large-v3, large-v3-turbo}), google/gemma-3-4b-it, Qwen/Qwen3-0.6B, nvidia/parakeet-tdt, facebook/dinov2-small-imagenet1k-1-layer, SocialLocalMobile/Qwen3.5-35B-A3B-HQQ-INT4, SocialLocalMobile/gemma-4-31B-it-HQQ-INT4" exit 1 ;; esac @@ -244,7 +256,7 @@ echo "::group::Prepare $MODEL_NAME Artifacts" # Download tokenizer files (skip for models that bundle tokenizer in export or do not use one) -if [ "$MODEL_NAME" != "parakeet" ] && [ "$MODEL_NAME" != "voxtral_realtime" ] && [ "$MODEL_NAME" != "sortformer" ] && [ "$MODEL_NAME" != "dinov2" ] && [ "$MODEL_NAME" != "qwen3_5_moe" ]; then +if [ "$MODEL_NAME" != "parakeet" ] && [ "$MODEL_NAME" != "voxtral_realtime" ] && [ "$MODEL_NAME" != "sortformer" ] && [ "$MODEL_NAME" != "dinov2" ] && [ "$MODEL_NAME" != "qwen3_5_moe" ] && [ "$MODEL_NAME" != "gemma4_31b" ]; then if [ "$TOKENIZER_FILE" != "" ]; then curl -L $TOKENIZER_URL/$TOKENIZER_FILE -o $MODEL_DIR/$TOKENIZER_FILE else @@ -368,6 +380,9 @@ EOF qwen3_5_moe) RUNNER_ARGS="$RUNNER_ARGS --tokenizer_path ${MODEL_DIR}/$TOKENIZER_FILE --prompt 'What is the capital of France?' --max_new_tokens 128 --temperature 0 --cuda_graph" ;; + gemma4_31b) + RUNNER_ARGS="$RUNNER_ARGS --tokenizer_path ${MODEL_DIR}/$TOKENIZER_FILE --prompt 'What is the capital of France?' --max_new_tokens 128 --temperature 0 --cuda_graph" + ;; voxtral_realtime) RUNNER_ARGS="--model_path ${MODEL_DIR}/model.pte --tokenizer_path ${MODEL_DIR}/$TOKENIZER_FILE --preprocessor_path ${MODEL_DIR}/$PREPROCESSOR --audio_path ${MODEL_DIR}/$AUDIO_FILE --temperature 0" # Add CUDA data path if present diff --git a/.github/workflows/cuda.yml b/.github/workflows/cuda.yml index 087917c1116..52bd1943cd6 100644 --- a/.github/workflows/cuda.yml +++ b/.github/workflows/cuda.yml @@ -148,10 +148,6 @@ jobs: # Run Qwen 3.5 MoE tests (quantize roundtrip + TurboQuant KV cache + sampler) python -m pytest examples/models/qwen3_5_moe/test_quantize_roundtrip.py examples/models/qwen3_5_moe/test_turboquant.py examples/models/qwen3_5_moe/test_sampler.py -v -o "addopts=" - # Run Gemma 4 31B tests (quant unit tests + pipeline integration tests) - pip install gguf - python -m pytest examples/models/gemma4_31b/quant/tests/ examples/models/gemma4_31b/tests/ -v -o "addopts=" - export-model-cuda-artifact: name: export-model-cuda-artifact # Skip this job if the pull request is from a fork (HuggingFace secrets are not available) @@ -185,6 +181,8 @@ jobs: name: "dinov2-small-imagenet1k-1-layer" - repo: "SocialLocalMobile" name: "Qwen3.5-35B-A3B-HQQ-INT4" + - repo: "SocialLocalMobile" + name: "gemma-4-31B-it-HQQ-INT4" quant: - "non-quantized" - "quantized-int4-tile-packed" @@ -204,6 +202,15 @@ jobs: repo: "SocialLocalMobile" name: "Qwen3.5-35B-A3B-HQQ-INT4" quant: "quantized-int4-weight-only" + # Gemma 4 31B uses a prequantized checkpoint, only tile-packed + - model: + repo: "SocialLocalMobile" + name: "gemma-4-31B-it-HQQ-INT4" + quant: "non-quantized" + - model: + repo: "SocialLocalMobile" + name: "gemma-4-31B-it-HQQ-INT4" + quant: "quantized-int4-weight-only" # Voxtral Realtime only supports int4-tile-packed on CUDA - model: repo: "mistralai" @@ -258,7 +265,7 @@ jobs: with: timeout: 90 secrets-env: EXECUTORCH_HF_TOKEN - runner: ${{ matrix.model.name == 'Qwen3.5-35B-A3B-HQQ-INT4' && 'linux.aws.a100' || 'linux.g5.4xlarge.nvidia.gpu' }} + runner: ${{ (matrix.model.name == 'Qwen3.5-35B-A3B-HQQ-INT4' || matrix.model.name == 'gemma-4-31B-it-HQQ-INT4') && 'linux.aws.a100' || 'linux.g5.4xlarge.nvidia.gpu' }} gpu-arch-type: cuda gpu-arch-version: 12.6 use-custom-docker-registry: false @@ -315,6 +322,8 @@ jobs: name: "dinov2-small-imagenet1k-1-layer" - repo: "SocialLocalMobile" name: "Qwen3.5-35B-A3B-HQQ-INT4" + - repo: "SocialLocalMobile" + name: "gemma-4-31B-it-HQQ-INT4" quant: - "non-quantized" - "quantized-int4-tile-packed" @@ -334,6 +343,15 @@ jobs: repo: "SocialLocalMobile" name: "Qwen3.5-35B-A3B-HQQ-INT4" quant: "quantized-int4-weight-only" + # Gemma 4 31B uses a prequantized checkpoint, only tile-packed + - model: + repo: "SocialLocalMobile" + name: "gemma-4-31B-it-HQQ-INT4" + quant: "non-quantized" + - model: + repo: "SocialLocalMobile" + name: "gemma-4-31B-it-HQQ-INT4" + quant: "quantized-int4-weight-only" # Voxtral Realtime only supports int4-tile-packed on CUDA - model: repo: "mistralai" @@ -382,7 +400,7 @@ jobs: quant: "non-quantized" with: timeout: 90 - runner: ${{ matrix.model.name == 'Qwen3.5-35B-A3B-HQQ-INT4' && 'linux.aws.a100' || 'linux.g5.4xlarge.nvidia.gpu' }} + runner: ${{ (matrix.model.name == 'Qwen3.5-35B-A3B-HQQ-INT4' || matrix.model.name == 'gemma-4-31B-it-HQQ-INT4') && 'linux.aws.a100' || 'linux.g5.4xlarge.nvidia.gpu' }} gpu-arch-type: cuda gpu-arch-version: 12.6 use-custom-docker-registry: false diff --git a/examples/models/gemma4_31b/inference.py b/examples/models/gemma4_31b/inference.py index 62dfe5956a7..e1563c04ff6 100644 --- a/examples/models/gemma4_31b/inference.py +++ b/examples/models/gemma4_31b/inference.py @@ -6,12 +6,10 @@ """Eager inference on Gemma 4 31B-IT (CUDA + torch.compile). -Two input paths: +Three input paths: --prequantized Load a quantized checkpoint (from quantize_and_save.py). --gguf Load a GGUF file (e.g., Q4_K_M from the community). - -Packs for the target backend (--backend cuda), materializes runtime buffers, -optionally compiles with ``torch.compile``, and generates text autoregressively. + --bf16 Load the bf16 HF safetensors checkpoint via from_hf_checkpoint. Gemma 4 31B-IT is instruction-tuned and requires chat-template formatting. The ``--prompt`` is automatically wrapped with the Gemma 4 chat template @@ -38,7 +36,10 @@ import torch from executorch.examples.models.gemma4_31b.export import load_prequantized_model -from executorch.examples.models.gemma4_31b.model import materialize_runtime_buffers +from executorch.examples.models.gemma4_31b.model import ( + Gemma4_31B, + materialize_runtime_buffers, +) def _move_to_cuda(model, config) -> None: @@ -74,7 +75,8 @@ def apply_chat_template(prompt: str) -> str: Does not include BOS — ``generate()`` prepends it at the token-ID level. """ return ( - "<|turn>user\n" + prompt + "<|turn>user\n" + + prompt + "\n<|turn>model\n<|channel>thought\n" ) @@ -147,6 +149,11 @@ def main() -> None: default=None, help="Path to a GGUF file (e.g., gemma-4-31B-it-Q4_K_M.gguf).", ) + src.add_argument( + "--bf16", + default=None, + help="Path to a bf16 hf directory (e.g., gemma-4-31B).", + ) parser.add_argument( "--tokenizer-path", default=None, @@ -192,12 +199,34 @@ def main() -> None: if args.backend == "cuda" and not torch.cuda.is_available(): parser.error("CUDA is required for the cuda backend.") + # ---- Tokenizer ---- + if args.tokenizer_path: + tokenizer_path = args.tokenizer_path + elif args.prequantized: + tokenizer_path = os.path.join(args.prequantized, "tokenizer.json") + elif args.bf16: + tokenizer_path = os.path.join(args.bf16, "tokenizer.json") + else: + parser.error("--tokenizer-path is required with --gguf.") + from tokenizers import Tokenizer + + tokenizer = Tokenizer.from_file(tokenizer_path) + + prompt_str = args.prompt if args.raw_prompt else apply_chat_template(args.prompt) + + # Gemma 4 EOS tokens (from generation_config.json: ids 1, 50, 106). + eos_token_ids = {1, 50, 106} + if args.gguf: from executorch.examples.models.gemma4_31b.gguf_loader import load_gguf_model model, config = load_gguf_model( args.gguf, args.max_seq_len, backend=args.backend ) + elif args.bf16: + model, config = Gemma4_31B.from_hf_checkpoint( + args.bf16, max_seq_len=args.max_seq_len + ) else: print(f"Loading prequantized model from {args.prequantized}...") model, config = load_prequantized_model( @@ -212,21 +241,6 @@ def main() -> None: print("Compiling model with torch.compile...") model = torch.compile(model, mode="default") - if args.tokenizer_path: - tokenizer_path = args.tokenizer_path - elif args.prequantized: - tokenizer_path = os.path.join(args.prequantized, "tokenizer.json") - else: - parser.error("--tokenizer-path is required with --gguf.") - from tokenizers import Tokenizer - - tokenizer = Tokenizer.from_file(tokenizer_path) - - # Gemma 4 EOS tokens (from generation_config.json: ids 1, 50, 106). - eos_token_ids = {1, 50, 106} - - prompt = args.prompt if args.raw_prompt else apply_chat_template(args.prompt) - print(f"\nPrompt: {args.prompt}") print("-" * 40) @@ -234,7 +248,7 @@ def main() -> None: output = generate( model, tokenizer, - prompt, + prompt_str, max_new_tokens=args.max_new_tokens, temperature=args.temperature, eos_token_ids=eos_token_ids, diff --git a/examples/models/gemma4_31b/model.py b/examples/models/gemma4_31b/model.py index b0eb4004c52..b457c8807ca 100644 --- a/examples/models/gemma4_31b/model.py +++ b/examples/models/gemma4_31b/model.py @@ -675,7 +675,29 @@ def materialize_runtime_buffers( for layer in model.layers: attn = layer.self_attn - attn.inv_freq = attn.inv_freq.to(device) + if attn.is_sliding: + rotary_dim = attn.head_dim + else: + rotary_dim = int(attn.head_dim * attn.partial_rotary) + rope_angles = rotary_dim // 2 + inv_freq_rotated = 1.0 / ( + attn.rope_theta + ** ( + torch.arange(0, rotary_dim, 2, device=device, dtype=torch.float32) + / attn.head_dim + ) + ) + nope_angles = attn.head_dim // 2 - rope_angles + if nope_angles > 0: + inv_freq = torch.cat( + [ + inv_freq_rotated, + torch.zeros(nope_angles, device=device, dtype=torch.float32), + ] + ) + else: + inv_freq = inv_freq_rotated + attn.register_buffer("inv_freq", inv_freq, persistent=False) model.register_buffer( "embed_normalizer", From dbfd0ff8ddb72d0ccc5aaa09e30d0d413ba71973 Mon Sep 17 00:00:00 2001 From: Mergen Nachin Date: Mon, 18 May 2026 07:27:47 -0700 Subject: [PATCH 3/3] Deduplicate inv_freq computation, harden CI sanity check, restore unit tests - Extract _compute_inv_freq() on Gemma4Attention so __init__ and materialize_runtime_buffers share a single implementation. - Check for "Paris" in the export CI inference sanity check instead of just checking the script doesn't crash. - Restore gemma4_31b quant/pipeline unit tests in the CUDA build job. - Update model.md to reflect that inv_freq is recomputed, not moved. --- .ci/scripts/export_model_artifact.sh | 9 +++- .github/workflows/cuda.yml | 4 ++ examples/models/gemma4_31b/model.md | 2 +- examples/models/gemma4_31b/model.py | 68 +++++++++++----------------- 4 files changed, 39 insertions(+), 44 deletions(-) diff --git a/.ci/scripts/export_model_artifact.sh b/.ci/scripts/export_model_artifact.sh index 4bc8485dde8..9adea394993 100755 --- a/.ci/scripts/export_model_artifact.sh +++ b/.ci/scripts/export_model_artifact.sh @@ -480,12 +480,17 @@ if [ "$MODEL_NAME" = "gemma4_31b" ]; then # Sanity check: run inference on the prequantized model echo "::group::Inference sanity check" - python -m executorch.examples.models.gemma4_31b.inference \ + INFERENCE_OUTPUT=$(python -m executorch.examples.models.gemma4_31b.inference \ --prequantized "$LOCAL_MODEL_DIR" \ --prompt "What is the capital of France?" \ --max-new-tokens 32 \ --temperature 0 \ - --no-compile + --no-compile 2>&1) + echo "$INFERENCE_OUTPUT" + if ! echo "$INFERENCE_OUTPUT" | grep -q "Paris"; then + echo "ERROR: Inference sanity check failed — expected 'Paris' in output" + exit 1 + fi echo "::endgroup::" # Copy tokenizer for the runner diff --git a/.github/workflows/cuda.yml b/.github/workflows/cuda.yml index 52bd1943cd6..eb7fc5a8939 100644 --- a/.github/workflows/cuda.yml +++ b/.github/workflows/cuda.yml @@ -148,6 +148,10 @@ jobs: # Run Qwen 3.5 MoE tests (quantize roundtrip + TurboQuant KV cache + sampler) python -m pytest examples/models/qwen3_5_moe/test_quantize_roundtrip.py examples/models/qwen3_5_moe/test_turboquant.py examples/models/qwen3_5_moe/test_sampler.py -v -o "addopts=" + # Run Gemma 4 31B tests (quant unit tests + pipeline integration tests) + pip install gguf + python -m pytest examples/models/gemma4_31b/quant/tests/ examples/models/gemma4_31b/tests/ -v -o "addopts=" + export-model-cuda-artifact: name: export-model-cuda-artifact # Skip this job if the pull request is from a fork (HuggingFace secrets are not available) diff --git a/examples/models/gemma4_31b/model.md b/examples/models/gemma4_31b/model.md index 8233b6d430e..51e420528f1 100644 --- a/examples/models/gemma4_31b/model.md +++ b/examples/models/gemma4_31b/model.md @@ -165,7 +165,7 @@ RoPE inv_freq buffers, and scalar constants are still on the meta device. them with real tensors: - KV caches → zeros in `dtype` (bf16 for inference, bf16 for export) -- `inv_freq` → moved to target device (cos/sin computed on the fly per forward) +- `inv_freq` → recomputed on target device (cos/sin computed on the fly per forward) - `embed_normalizer`, `logit_softcap`, `cache_positions` → scalar constants Called by `export.py` (device="cpu" for tracing) and `inference.py` diff --git a/examples/models/gemma4_31b/model.py b/examples/models/gemma4_31b/model.py index b457c8807ca..f0aa2fac982 100644 --- a/examples/models/gemma4_31b/model.py +++ b/examples/models/gemma4_31b/model.py @@ -251,25 +251,7 @@ def __init__(self, config: Gemma4_31BConfig, layer_idx: int): self.k_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps) self.v_norm = RMSNormNoWeight(self.head_dim, eps=config.rms_norm_eps) - # Precomputed RoPE table for this layer (per-layer because head_dim - # and theta differ between sliding and full attention). For full - # attention layers we pass freq_base_dim=head_dim so the zero-padded - # On-the-fly RoPE: store only inv_freq, compute cos/sin per forward. - # Saves memory vs precomputed [max_seq_len, head_dim] tables. - if self.is_sliding: - rotary_dim = self.head_dim - else: - rotary_dim = int(self.head_dim * self.partial_rotary) - rope_angles = rotary_dim // 2 - inv_freq_rotated = 1.0 / ( - self.rope_theta ** (torch.arange(0, rotary_dim, 2).float() / self.head_dim) - ) - nope_angles = self.head_dim // 2 - rope_angles - if nope_angles > 0: - inv_freq = torch.cat([inv_freq_rotated, torch.zeros(nope_angles)]) - else: - inv_freq = inv_freq_rotated - self.register_buffer("inv_freq", inv_freq, persistent=False) + self.register_buffer("inv_freq", self._compute_inv_freq(), persistent=False) # KV cache. Sliding layers use a ring buffer (2x window) to save # memory; full layers use a flat buffer (max_seq_len). @@ -289,6 +271,30 @@ def __init__(self, config: Gemma4_31BConfig, layer_idx: int): use_index_copy=True, ) + def _compute_inv_freq(self, device: Optional[torch.device] = None) -> torch.Tensor: + """Compute RoPE inverse-frequency table for this layer.""" + if self.is_sliding: + rotary_dim = self.head_dim + else: + rotary_dim = int(self.head_dim * self.partial_rotary) + rope_angles = rotary_dim // 2 + inv_freq_rotated = 1.0 / ( + self.rope_theta + ** ( + torch.arange(0, rotary_dim, 2, device=device, dtype=torch.float32) + / self.head_dim + ) + ) + nope_angles = self.head_dim // 2 - rope_angles + if nope_angles > 0: + return torch.cat( + [ + inv_freq_rotated, + torch.zeros(nope_angles, device=device, dtype=torch.float32), + ] + ) + return inv_freq_rotated + def forward( self, x: torch.Tensor, @@ -675,29 +681,9 @@ def materialize_runtime_buffers( for layer in model.layers: attn = layer.self_attn - if attn.is_sliding: - rotary_dim = attn.head_dim - else: - rotary_dim = int(attn.head_dim * attn.partial_rotary) - rope_angles = rotary_dim // 2 - inv_freq_rotated = 1.0 / ( - attn.rope_theta - ** ( - torch.arange(0, rotary_dim, 2, device=device, dtype=torch.float32) - / attn.head_dim - ) + attn.register_buffer( + "inv_freq", attn._compute_inv_freq(device=device), persistent=False ) - nope_angles = attn.head_dim // 2 - rope_angles - if nope_angles > 0: - inv_freq = torch.cat( - [ - inv_freq_rotated, - torch.zeros(nope_angles, device=device, dtype=torch.float32), - ] - ) - else: - inv_freq = inv_freq_rotated - attn.register_buffer("inv_freq", inv_freq, persistent=False) model.register_buffer( "embed_normalizer",