diff --git a/.ci/scripts/export_model_artifact.sh b/.ci/scripts/export_model_artifact.sh index 1f75d850e84..9adea394993 100755 --- a/.ci/scripts/export_model_artifact.sh +++ b/.ci/scripts/export_model_artifact.sh @@ -195,9 +195,17 @@ case "$HF_MODEL" in PREPROCESSOR_FEATURE_SIZE="" PREPROCESSOR_OUTPUT="" ;; + SocialLocalMobile/gemma-4-31B-it-HQQ-INT4) + MODEL_NAME="gemma4_31b" + TASK="" + MAX_SEQ_LEN="" + EXTRA_PIP="" + PREPROCESSOR_FEATURE_SIZE="" + PREPROCESSOR_OUTPUT="" + ;; *) echo "Error: Unsupported model '$HF_MODEL'" - echo "Supported models: mistralai/Voxtral-Mini-3B-2507, mistralai/Voxtral-Mini-4B-Realtime-2602, openai/whisper-{small, medium, large, large-v2, large-v3, large-v3-turbo}, google/gemma-3-4b-it, Qwen/Qwen3-0.6B, nvidia/diar_streaming_sortformer_4spk-v2, nvidia/parakeet-tdt, facebook/dinov2-small-imagenet1k-1-layer, SocialLocalMobile/Qwen3.5-35B-A3B-HQQ-INT4" + echo "Supported models: mistralai/Voxtral-Mini-3B-2507, mistralai/Voxtral-Mini-4B-Realtime-2602, openai/whisper-{small, medium, large, large-v2, large-v3, large-v3-turbo}, google/gemma-3-4b-it, Qwen/Qwen3-0.6B, nvidia/diar_streaming_sortformer_4spk-v2, nvidia/parakeet-tdt, facebook/dinov2-small-imagenet1k-1-layer, SocialLocalMobile/Qwen3.5-35B-A3B-HQQ-INT4, SocialLocalMobile/gemma-4-31B-it-HQQ-INT4" exit 1 ;; esac @@ -459,6 +467,50 @@ if [ "$MODEL_NAME" = "qwen3_5_moe" ]; then exit 0 fi +# Gemma 4 31B uses a prequantized checkpoint and custom export script +if [ "$MODEL_NAME" = "gemma4_31b" ]; then + pip install safetensors huggingface_hub gguf + + # Download prequantized model outside OUTPUT_DIR to avoid uploading on failure + LOCAL_MODEL_DIR=$(mktemp -d) + INDUCTOR_CACHE=$(mktemp -d) + trap 'rm -rf "$LOCAL_MODEL_DIR" "$INDUCTOR_CACHE"' EXIT + + python -c "from huggingface_hub import snapshot_download; snapshot_download('${HF_MODEL}', local_dir='${LOCAL_MODEL_DIR}')" + + # Sanity check: run inference on the prequantized model + echo "::group::Inference sanity check" + INFERENCE_OUTPUT=$(python -m executorch.examples.models.gemma4_31b.inference \ + --prequantized "$LOCAL_MODEL_DIR" \ + --prompt "What is the capital of France?" \ + --max-new-tokens 32 \ + --temperature 0 \ + --no-compile 2>&1) + echo "$INFERENCE_OUTPUT" + if ! echo "$INFERENCE_OUTPUT" | grep -q "Paris"; then + echo "ERROR: Inference sanity check failed — expected 'Paris' in output" + exit 1 + fi + echo "::endgroup::" + + # Copy tokenizer for the runner + cp "$LOCAL_MODEL_DIR/tokenizer.json" "${OUTPUT_DIR}/tokenizer.json" + + # Export to .pte/.ptd (short cache dir avoids objcopy symbol length issues) + echo "::group::Export" + TORCHINDUCTOR_CACHE_DIR="$INDUCTOR_CACHE" \ + python -m executorch.examples.models.gemma4_31b.export \ + --prequantized "$LOCAL_MODEL_DIR" \ + --output-dir "${OUTPUT_DIR}" + echo "::endgroup::" + + test -f "${OUTPUT_DIR}/model.pte" + test -f "${OUTPUT_DIR}/aoti_cuda_blob.ptd" + ls -al "${OUTPUT_DIR}" + + exit 0 +fi + MAX_SEQ_LEN_ARG="" if [ -n "$MAX_SEQ_LEN" ]; then MAX_SEQ_LEN_ARG="--max_seq_len $MAX_SEQ_LEN" diff --git a/.ci/scripts/test_model_e2e.sh b/.ci/scripts/test_model_e2e.sh index 1678b0a4fbb..27b0dd9d597 100755 --- a/.ci/scripts/test_model_e2e.sh +++ b/.ci/scripts/test_model_e2e.sh @@ -228,9 +228,21 @@ case "$HF_MODEL" in AUDIO_FILE="" IMAGE_PATH="" ;; + SocialLocalMobile/gemma-4-31B-it-HQQ-INT4) + MODEL_NAME="gemma4_31b" + RUNNER_TARGET="gemma4_31b_runner" + RUNNER_PATH="gemma4_31b" + EXPECTED_OUTPUT="Paris" + PREPROCESSOR="" + TOKENIZER_URL="" + TOKENIZER_FILE="tokenizer.json" + AUDIO_URL="" + AUDIO_FILE="" + IMAGE_PATH="" + ;; *) echo "Error: Unsupported model '$HF_MODEL'" - echo "Supported models: mistralai/Voxtral-Mini-3B-2507, mistralai/Voxtral-Mini-4B-Realtime-2602, nvidia/diar_streaming_sortformer_4spk-v2, openai/whisper series (whisper-{small, medium, large, large-v2, large-v3, large-v3-turbo}), google/gemma-3-4b-it, Qwen/Qwen3-0.6B, nvidia/parakeet-tdt, facebook/dinov2-small-imagenet1k-1-layer, SocialLocalMobile/Qwen3.5-35B-A3B-HQQ-INT4" + echo "Supported models: mistralai/Voxtral-Mini-3B-2507, mistralai/Voxtral-Mini-4B-Realtime-2602, nvidia/diar_streaming_sortformer_4spk-v2, openai/whisper series (whisper-{small, medium, large, large-v2, large-v3, large-v3-turbo}), google/gemma-3-4b-it, Qwen/Qwen3-0.6B, nvidia/parakeet-tdt, facebook/dinov2-small-imagenet1k-1-layer, SocialLocalMobile/Qwen3.5-35B-A3B-HQQ-INT4, SocialLocalMobile/gemma-4-31B-it-HQQ-INT4" exit 1 ;; esac @@ -244,7 +256,7 @@ echo "::group::Prepare $MODEL_NAME Artifacts" # Download tokenizer files (skip for models that bundle tokenizer in export or do not use one) -if [ "$MODEL_NAME" != "parakeet" ] && [ "$MODEL_NAME" != "voxtral_realtime" ] && [ "$MODEL_NAME" != "sortformer" ] && [ "$MODEL_NAME" != "dinov2" ] && [ "$MODEL_NAME" != "qwen3_5_moe" ]; then +if [ "$MODEL_NAME" != "parakeet" ] && [ "$MODEL_NAME" != "voxtral_realtime" ] && [ "$MODEL_NAME" != "sortformer" ] && [ "$MODEL_NAME" != "dinov2" ] && [ "$MODEL_NAME" != "qwen3_5_moe" ] && [ "$MODEL_NAME" != "gemma4_31b" ]; then if [ "$TOKENIZER_FILE" != "" ]; then curl -L $TOKENIZER_URL/$TOKENIZER_FILE -o $MODEL_DIR/$TOKENIZER_FILE else @@ -368,6 +380,9 @@ EOF qwen3_5_moe) RUNNER_ARGS="$RUNNER_ARGS --tokenizer_path ${MODEL_DIR}/$TOKENIZER_FILE --prompt 'What is the capital of France?' --max_new_tokens 128 --temperature 0 --cuda_graph" ;; + gemma4_31b) + RUNNER_ARGS="$RUNNER_ARGS --tokenizer_path ${MODEL_DIR}/$TOKENIZER_FILE --prompt 'What is the capital of France?' --max_new_tokens 128 --temperature 0 --cuda_graph" + ;; voxtral_realtime) RUNNER_ARGS="--model_path ${MODEL_DIR}/model.pte --tokenizer_path ${MODEL_DIR}/$TOKENIZER_FILE --preprocessor_path ${MODEL_DIR}/$PREPROCESSOR --audio_path ${MODEL_DIR}/$AUDIO_FILE --temperature 0" # Add CUDA data path if present diff --git a/.github/workflows/cuda.yml b/.github/workflows/cuda.yml index 087917c1116..eb7fc5a8939 100644 --- a/.github/workflows/cuda.yml +++ b/.github/workflows/cuda.yml @@ -185,6 +185,8 @@ jobs: name: "dinov2-small-imagenet1k-1-layer" - repo: "SocialLocalMobile" name: "Qwen3.5-35B-A3B-HQQ-INT4" + - repo: "SocialLocalMobile" + name: "gemma-4-31B-it-HQQ-INT4" quant: - "non-quantized" - "quantized-int4-tile-packed" @@ -204,6 +206,15 @@ jobs: repo: "SocialLocalMobile" name: "Qwen3.5-35B-A3B-HQQ-INT4" quant: "quantized-int4-weight-only" + # Gemma 4 31B uses a prequantized checkpoint, only tile-packed + - model: + repo: "SocialLocalMobile" + name: "gemma-4-31B-it-HQQ-INT4" + quant: "non-quantized" + - model: + repo: "SocialLocalMobile" + name: "gemma-4-31B-it-HQQ-INT4" + quant: "quantized-int4-weight-only" # Voxtral Realtime only supports int4-tile-packed on CUDA - model: repo: "mistralai" @@ -258,7 +269,7 @@ jobs: with: timeout: 90 secrets-env: EXECUTORCH_HF_TOKEN - runner: ${{ matrix.model.name == 'Qwen3.5-35B-A3B-HQQ-INT4' && 'linux.aws.a100' || 'linux.g5.4xlarge.nvidia.gpu' }} + runner: ${{ (matrix.model.name == 'Qwen3.5-35B-A3B-HQQ-INT4' || matrix.model.name == 'gemma-4-31B-it-HQQ-INT4') && 'linux.aws.a100' || 'linux.g5.4xlarge.nvidia.gpu' }} gpu-arch-type: cuda gpu-arch-version: 12.6 use-custom-docker-registry: false @@ -315,6 +326,8 @@ jobs: name: "dinov2-small-imagenet1k-1-layer" - repo: "SocialLocalMobile" name: "Qwen3.5-35B-A3B-HQQ-INT4" + - repo: "SocialLocalMobile" + name: "gemma-4-31B-it-HQQ-INT4" quant: - "non-quantized" - "quantized-int4-tile-packed" @@ -334,6 +347,15 @@ jobs: repo: "SocialLocalMobile" name: "Qwen3.5-35B-A3B-HQQ-INT4" quant: "quantized-int4-weight-only" + # Gemma 4 31B uses a prequantized checkpoint, only tile-packed + - model: + repo: "SocialLocalMobile" + name: "gemma-4-31B-it-HQQ-INT4" + quant: "non-quantized" + - model: + repo: "SocialLocalMobile" + name: "gemma-4-31B-it-HQQ-INT4" + quant: "quantized-int4-weight-only" # Voxtral Realtime only supports int4-tile-packed on CUDA - model: repo: "mistralai" @@ -382,7 +404,7 @@ jobs: quant: "non-quantized" with: timeout: 90 - runner: ${{ matrix.model.name == 'Qwen3.5-35B-A3B-HQQ-INT4' && 'linux.aws.a100' || 'linux.g5.4xlarge.nvidia.gpu' }} + runner: ${{ (matrix.model.name == 'Qwen3.5-35B-A3B-HQQ-INT4' || matrix.model.name == 'gemma-4-31B-it-HQQ-INT4') && 'linux.aws.a100' || 'linux.g5.4xlarge.nvidia.gpu' }} gpu-arch-type: cuda gpu-arch-version: 12.6 use-custom-docker-registry: false diff --git a/examples/models/gemma4_31b/README.md b/examples/models/gemma4_31b/README.md index 6f567d739b7..94783c8f823 100644 --- a/examples/models/gemma4_31b/README.md +++ b/examples/models/gemma4_31b/README.md @@ -79,6 +79,9 @@ Writes `model.pte` and `model.ptd` into `--output-dir`. ## Eager inference +The prompt is automatically wrapped with the Gemma 4 IT chat template. +Pass `--raw-prompt` to skip template wrapping for pre-formatted input. + ```bash python examples/models/gemma4_31b/inference.py \ --prequantized ./gemma4_31b_int4 \ @@ -109,6 +112,9 @@ The binary lands at `cmake-out/examples/models/gemma4_31b/gemma4_31b_runner`. ## Run the .pte +The prompt is automatically wrapped with the Gemma 4 IT chat template. +Pass `--raw_prompt` to skip template wrapping for pre-formatted input. + ```bash ./gemma4_31b_runner \ --model_path ./gemma4_31b_exports/model.pte \ diff --git a/examples/models/gemma4_31b/inference.py b/examples/models/gemma4_31b/inference.py index 12785450d8c..e1563c04ff6 100644 --- a/examples/models/gemma4_31b/inference.py +++ b/examples/models/gemma4_31b/inference.py @@ -6,12 +6,15 @@ """Eager inference on Gemma 4 31B-IT (CUDA + torch.compile). -Two input paths: +Three input paths: --prequantized Load a quantized checkpoint (from quantize_and_save.py). --gguf Load a GGUF file (e.g., Q4_K_M from the community). + --bf16 Load the bf16 HF safetensors checkpoint via from_hf_checkpoint. -Packs for the target backend (--backend cuda), materializes runtime buffers, -optionally compiles with ``torch.compile``, and generates text autoregressively. +Gemma 4 31B-IT is instruction-tuned and requires chat-template formatting. +The ``--prompt`` is automatically wrapped with the Gemma 4 chat template +(``<|turn>user\\n{prompt}\\n<|turn>model\\n<|channel>thought\\n``; BOS is prepended separately). +Pass ``--raw-prompt`` to skip template wrapping (e.g., for pre-formatted input). Usage: python inference.py \\ @@ -33,7 +36,10 @@ import torch from executorch.examples.models.gemma4_31b.export import load_prequantized_model -from executorch.examples.models.gemma4_31b.model import materialize_runtime_buffers +from executorch.examples.models.gemma4_31b.model import ( + Gemma4_31B, + materialize_runtime_buffers, +) def _move_to_cuda(model, config) -> None: @@ -63,6 +69,18 @@ def _move_to_cuda(model, config) -> None: materialize_runtime_buffers(model, dtype=torch.bfloat16, device="cuda") +def apply_chat_template(prompt: str) -> str: + """Wrap a user prompt in the Gemma 4 IT chat template. + + Does not include BOS — ``generate()`` prepends it at the token-ID level. + """ + return ( + "<|turn>user\n" + + prompt + + "\n<|turn>model\n<|channel>thought\n" + ) + + def generate( model, tokenizer, @@ -131,6 +149,11 @@ def main() -> None: default=None, help="Path to a GGUF file (e.g., gemma-4-31B-it-Q4_K_M.gguf).", ) + src.add_argument( + "--bf16", + default=None, + help="Path to a bf16 hf directory (e.g., gemma-4-31B).", + ) parser.add_argument( "--tokenizer-path", default=None, @@ -155,6 +178,11 @@ def main() -> None: default=4096, help="KV cache length to allocate for this run.", ) + parser.add_argument( + "--raw-prompt", + action="store_true", + help="Skip chat-template wrapping (use if the prompt is already formatted).", + ) parser.add_argument( "--no-compile", action="store_true", @@ -171,12 +199,34 @@ def main() -> None: if args.backend == "cuda" and not torch.cuda.is_available(): parser.error("CUDA is required for the cuda backend.") + # ---- Tokenizer ---- + if args.tokenizer_path: + tokenizer_path = args.tokenizer_path + elif args.prequantized: + tokenizer_path = os.path.join(args.prequantized, "tokenizer.json") + elif args.bf16: + tokenizer_path = os.path.join(args.bf16, "tokenizer.json") + else: + parser.error("--tokenizer-path is required with --gguf.") + from tokenizers import Tokenizer + + tokenizer = Tokenizer.from_file(tokenizer_path) + + prompt_str = args.prompt if args.raw_prompt else apply_chat_template(args.prompt) + + # Gemma 4 EOS tokens (from generation_config.json: ids 1, 50, 106). + eos_token_ids = {1, 50, 106} + if args.gguf: from executorch.examples.models.gemma4_31b.gguf_loader import load_gguf_model model, config = load_gguf_model( args.gguf, args.max_seq_len, backend=args.backend ) + elif args.bf16: + model, config = Gemma4_31B.from_hf_checkpoint( + args.bf16, max_seq_len=args.max_seq_len + ) else: print(f"Loading prequantized model from {args.prequantized}...") model, config = load_prequantized_model( @@ -191,19 +241,6 @@ def main() -> None: print("Compiling model with torch.compile...") model = torch.compile(model, mode="default") - if args.tokenizer_path: - tokenizer_path = args.tokenizer_path - elif args.prequantized: - tokenizer_path = os.path.join(args.prequantized, "tokenizer.json") - else: - parser.error("--tokenizer-path is required with --gguf.") - from tokenizers import Tokenizer - - tokenizer = Tokenizer.from_file(tokenizer_path) - - # Gemma 4 EOS tokens (from generation_config.json: ids 1, 50, 106). - eos_token_ids = {1, 50, 106} - print(f"\nPrompt: {args.prompt}") print("-" * 40) @@ -211,7 +248,7 @@ def main() -> None: output = generate( model, tokenizer, - args.prompt, + prompt_str, max_new_tokens=args.max_new_tokens, temperature=args.temperature, eos_token_ids=eos_token_ids, diff --git a/examples/models/gemma4_31b/main.cpp b/examples/models/gemma4_31b/main.cpp index 0be2fef517c..3ddf64e410f 100644 --- a/examples/models/gemma4_31b/main.cpp +++ b/examples/models/gemma4_31b/main.cpp @@ -65,6 +65,10 @@ DEFINE_double(temperature, 0.8, "Sampling temperature (0 = near-greedy)."); DEFINE_int32(max_new_tokens, 128, "Maximum tokens to generate."); DEFINE_int32(bos_id, 2, "BOS token id to prepend (Gemma convention: 2)."); DEFINE_int32(eos_id, 1, "EOS token id (Gemma convention: 1)."); +DEFINE_bool( + raw_prompt, + false, + "Skip chat-template wrapping (use if the prompt is already formatted)."); DEFINE_bool( cuda_graph, false, @@ -232,6 +236,14 @@ int main(int argc, char** argv) { (std::istreambuf_iterator(f)), std::istreambuf_iterator()); } + // Wrap with Gemma 4 IT chat template unless --raw_prompt is set. + // BOS is prepended separately below; this adds the turn structure and the + // empty thought block required by the instruction-tuned model. + if (!FLAGS_raw_prompt) { + prompt_text = "<|turn>user\n" + prompt_text + + "\n<|turn>model\n<|channel>thought\n"; + } + // Encode prompt auto encode_result = tokenizer->encode(prompt_text); if (!encode_result.ok()) { diff --git a/examples/models/gemma4_31b/model.md b/examples/models/gemma4_31b/model.md index 8233b6d430e..51e420528f1 100644 --- a/examples/models/gemma4_31b/model.md +++ b/examples/models/gemma4_31b/model.md @@ -165,7 +165,7 @@ RoPE inv_freq buffers, and scalar constants are still on the meta device. them with real tensors: - KV caches → zeros in `dtype` (bf16 for inference, bf16 for export) -- `inv_freq` → moved to target device (cos/sin computed on the fly per forward) +- `inv_freq` → recomputed on target device (cos/sin computed on the fly per forward) - `embed_normalizer`, `logit_softcap`, `cache_positions` → scalar constants Called by `export.py` (device="cpu" for tracing) and `inference.py` diff --git a/examples/models/gemma4_31b/model.py b/examples/models/gemma4_31b/model.py index b0eb4004c52..f0aa2fac982 100644 --- a/examples/models/gemma4_31b/model.py +++ b/examples/models/gemma4_31b/model.py @@ -251,25 +251,7 @@ def __init__(self, config: Gemma4_31BConfig, layer_idx: int): self.k_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps) self.v_norm = RMSNormNoWeight(self.head_dim, eps=config.rms_norm_eps) - # Precomputed RoPE table for this layer (per-layer because head_dim - # and theta differ between sliding and full attention). For full - # attention layers we pass freq_base_dim=head_dim so the zero-padded - # On-the-fly RoPE: store only inv_freq, compute cos/sin per forward. - # Saves memory vs precomputed [max_seq_len, head_dim] tables. - if self.is_sliding: - rotary_dim = self.head_dim - else: - rotary_dim = int(self.head_dim * self.partial_rotary) - rope_angles = rotary_dim // 2 - inv_freq_rotated = 1.0 / ( - self.rope_theta ** (torch.arange(0, rotary_dim, 2).float() / self.head_dim) - ) - nope_angles = self.head_dim // 2 - rope_angles - if nope_angles > 0: - inv_freq = torch.cat([inv_freq_rotated, torch.zeros(nope_angles)]) - else: - inv_freq = inv_freq_rotated - self.register_buffer("inv_freq", inv_freq, persistent=False) + self.register_buffer("inv_freq", self._compute_inv_freq(), persistent=False) # KV cache. Sliding layers use a ring buffer (2x window) to save # memory; full layers use a flat buffer (max_seq_len). @@ -289,6 +271,30 @@ def __init__(self, config: Gemma4_31BConfig, layer_idx: int): use_index_copy=True, ) + def _compute_inv_freq(self, device: Optional[torch.device] = None) -> torch.Tensor: + """Compute RoPE inverse-frequency table for this layer.""" + if self.is_sliding: + rotary_dim = self.head_dim + else: + rotary_dim = int(self.head_dim * self.partial_rotary) + rope_angles = rotary_dim // 2 + inv_freq_rotated = 1.0 / ( + self.rope_theta + ** ( + torch.arange(0, rotary_dim, 2, device=device, dtype=torch.float32) + / self.head_dim + ) + ) + nope_angles = self.head_dim // 2 - rope_angles + if nope_angles > 0: + return torch.cat( + [ + inv_freq_rotated, + torch.zeros(nope_angles, device=device, dtype=torch.float32), + ] + ) + return inv_freq_rotated + def forward( self, x: torch.Tensor, @@ -675,7 +681,9 @@ def materialize_runtime_buffers( for layer in model.layers: attn = layer.self_attn - attn.inv_freq = attn.inv_freq.to(device) + attn.register_buffer( + "inv_freq", attn._compute_inv_freq(device=device), persistent=False + ) model.register_buffer( "embed_normalizer",