Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 48 additions & 1 deletion .ci/scripts/export_model_artifact.sh
Original file line number Diff line number Diff line change
Expand Up @@ -195,9 +195,17 @@ case "$HF_MODEL" in
PREPROCESSOR_FEATURE_SIZE=""
PREPROCESSOR_OUTPUT=""
;;
SocialLocalMobile/gemma-4-31B-it-HQQ-INT4)
MODEL_NAME="gemma4_31b"
TASK=""
MAX_SEQ_LEN=""
EXTRA_PIP=""
PREPROCESSOR_FEATURE_SIZE=""
PREPROCESSOR_OUTPUT=""
;;
*)
echo "Error: Unsupported model '$HF_MODEL'"
echo "Supported models: mistralai/Voxtral-Mini-3B-2507, mistralai/Voxtral-Mini-4B-Realtime-2602, openai/whisper-{small, medium, large, large-v2, large-v3, large-v3-turbo}, google/gemma-3-4b-it, Qwen/Qwen3-0.6B, nvidia/diar_streaming_sortformer_4spk-v2, nvidia/parakeet-tdt, facebook/dinov2-small-imagenet1k-1-layer, SocialLocalMobile/Qwen3.5-35B-A3B-HQQ-INT4"
echo "Supported models: mistralai/Voxtral-Mini-3B-2507, mistralai/Voxtral-Mini-4B-Realtime-2602, openai/whisper-{small, medium, large, large-v2, large-v3, large-v3-turbo}, google/gemma-3-4b-it, Qwen/Qwen3-0.6B, nvidia/diar_streaming_sortformer_4spk-v2, nvidia/parakeet-tdt, facebook/dinov2-small-imagenet1k-1-layer, SocialLocalMobile/Qwen3.5-35B-A3B-HQQ-INT4, SocialLocalMobile/gemma-4-31B-it-HQQ-INT4"
exit 1
;;
esac
Expand Down Expand Up @@ -459,6 +467,45 @@ if [ "$MODEL_NAME" = "qwen3_5_moe" ]; then
exit 0
fi

# Gemma 4 31B uses a prequantized checkpoint and custom export script.
#
# Steps: install the loader dependencies, download the checkpoint to a
# temporary directory, run an eager inference sanity check, copy the tokenizer
# into OUTPUT_DIR, export, then verify both artifacts exist. The script exits
# with status 0 at the end of this block, so nothing below it runs for this
# model.
#
# Reads:  MODEL_NAME, HF_MODEL, OUTPUT_DIR
# Writes: LOCAL_MODEL_DIR, INDUCTOR_CACHE (both removed by the EXIT trap)
if [ "$MODEL_NAME" = "gemma4_31b" ]; then
  pip install safetensors huggingface_hub gguf

  # Download prequantized model outside OUTPUT_DIR to avoid uploading on failure
  LOCAL_MODEL_DIR=$(mktemp -d) || {
    echo "Error: failed to create a temporary directory for the model download" >&2
    exit 1
  }
  INDUCTOR_CACHE=$(mktemp -d) || {
    echo "Error: failed to create a temporary directory for the inductor cache" >&2
    rm -rf -- "$LOCAL_MODEL_DIR"
    exit 1
  }
  trap 'rm -rf -- "$LOCAL_MODEL_DIR" "$INDUCTOR_CACHE"' EXIT

  # Pass the repo id and target directory as argv rather than splicing them
  # into the Python source, so quotes or backslashes in either value cannot
  # break (or alter) the Python snippet.
  python -c 'import sys; from huggingface_hub import snapshot_download; snapshot_download(sys.argv[1], local_dir=sys.argv[2])' \
    "$HF_MODEL" "$LOCAL_MODEL_DIR"

  # Sanity check: run inference on the prequantized model
  echo "::group::Inference sanity check"
  python -m executorch.examples.models.gemma4_31b.inference \
    --prequantized "$LOCAL_MODEL_DIR" \
    --prompt "What is the capital of France?" \
    --max-new-tokens 32 \
    --temperature 0 \
    --no-compile
  echo "::endgroup::"

  # Copy tokenizer for the runner
  cp -- "$LOCAL_MODEL_DIR/tokenizer.json" "${OUTPUT_DIR}/tokenizer.json"

  # Export to .pte/.ptd (short cache dir avoids objcopy symbol length issues)
  echo "::group::Export"
  TORCHINDUCTOR_CACHE_DIR="$INDUCTOR_CACHE" \
    python -m executorch.examples.models.gemma4_31b.export \
    --prequantized "$LOCAL_MODEL_DIR" \
    --output-dir "${OUTPUT_DIR}"
  echo "::endgroup::"

  # Fail loudly if the export did not produce both artifacts. An explicit check
  # does not depend on `set -e` being active and says which file is missing.
  for artifact in model.pte aoti_cuda_blob.ptd; do
    if [ ! -f "${OUTPUT_DIR}/${artifact}" ]; then
      echo "Error: expected export artifact '${OUTPUT_DIR}/${artifact}' was not produced" >&2
      ls -al "${OUTPUT_DIR}" >&2 || true
      exit 1
    fi
  done
  ls -al "${OUTPUT_DIR}"

  exit 0
fi

MAX_SEQ_LEN_ARG=""
if [ -n "$MAX_SEQ_LEN" ]; then
MAX_SEQ_LEN_ARG="--max_seq_len $MAX_SEQ_LEN"
Expand Down
19 changes: 17 additions & 2 deletions .ci/scripts/test_model_e2e.sh
Original file line number Diff line number Diff line change
Expand Up @@ -228,9 +228,21 @@ case "$HF_MODEL" in
AUDIO_FILE=""
IMAGE_PATH=""
;;
SocialLocalMobile/gemma-4-31B-it-HQQ-INT4)
MODEL_NAME="gemma4_31b"
RUNNER_TARGET="gemma4_31b_runner"
RUNNER_PATH="gemma4_31b"
EXPECTED_OUTPUT="Paris"
PREPROCESSOR=""
TOKENIZER_URL=""
TOKENIZER_FILE="tokenizer.json"
AUDIO_URL=""
AUDIO_FILE=""
IMAGE_PATH=""
;;
*)
echo "Error: Unsupported model '$HF_MODEL'"
echo "Supported models: mistralai/Voxtral-Mini-3B-2507, mistralai/Voxtral-Mini-4B-Realtime-2602, nvidia/diar_streaming_sortformer_4spk-v2, openai/whisper series (whisper-{small, medium, large, large-v2, large-v3, large-v3-turbo}), google/gemma-3-4b-it, Qwen/Qwen3-0.6B, nvidia/parakeet-tdt, facebook/dinov2-small-imagenet1k-1-layer, SocialLocalMobile/Qwen3.5-35B-A3B-HQQ-INT4"
echo "Supported models: mistralai/Voxtral-Mini-3B-2507, mistralai/Voxtral-Mini-4B-Realtime-2602, nvidia/diar_streaming_sortformer_4spk-v2, openai/whisper series (whisper-{small, medium, large, large-v2, large-v3, large-v3-turbo}), google/gemma-3-4b-it, Qwen/Qwen3-0.6B, nvidia/parakeet-tdt, facebook/dinov2-small-imagenet1k-1-layer, SocialLocalMobile/Qwen3.5-35B-A3B-HQQ-INT4, SocialLocalMobile/gemma-4-31B-it-HQQ-INT4"
exit 1
;;
esac
Expand All @@ -244,7 +256,7 @@ echo "::group::Prepare $MODEL_NAME Artifacts"


# Download tokenizer files (skip for models that bundle tokenizer in export or do not use one)
if [ "$MODEL_NAME" != "parakeet" ] && [ "$MODEL_NAME" != "voxtral_realtime" ] && [ "$MODEL_NAME" != "sortformer" ] && [ "$MODEL_NAME" != "dinov2" ] && [ "$MODEL_NAME" != "qwen3_5_moe" ]; then
if [ "$MODEL_NAME" != "parakeet" ] && [ "$MODEL_NAME" != "voxtral_realtime" ] && [ "$MODEL_NAME" != "sortformer" ] && [ "$MODEL_NAME" != "dinov2" ] && [ "$MODEL_NAME" != "qwen3_5_moe" ] && [ "$MODEL_NAME" != "gemma4_31b" ]; then
if [ "$TOKENIZER_FILE" != "" ]; then
curl -L $TOKENIZER_URL/$TOKENIZER_FILE -o $MODEL_DIR/$TOKENIZER_FILE
else
Expand Down Expand Up @@ -368,6 +380,9 @@ EOF
qwen3_5_moe)
RUNNER_ARGS="$RUNNER_ARGS --tokenizer_path ${MODEL_DIR}/$TOKENIZER_FILE --prompt 'What is the capital of France?' --max_new_tokens 128 --temperature 0 --cuda_graph"
;;
gemma4_31b)
RUNNER_ARGS="$RUNNER_ARGS --tokenizer_path ${MODEL_DIR}/$TOKENIZER_FILE --prompt 'What is the capital of France?' --max_new_tokens 128 --temperature 0 --cuda_graph"
;;
voxtral_realtime)
RUNNER_ARGS="--model_path ${MODEL_DIR}/model.pte --tokenizer_path ${MODEL_DIR}/$TOKENIZER_FILE --preprocessor_path ${MODEL_DIR}/$PREPROCESSOR --audio_path ${MODEL_DIR}/$AUDIO_FILE --temperature 0"
# Add CUDA data path if present
Expand Down
30 changes: 24 additions & 6 deletions .github/workflows/cuda.yml
Original file line number Diff line number Diff line change
Expand Up @@ -148,10 +148,6 @@ jobs:
# Run Qwen 3.5 MoE tests (quantize roundtrip + TurboQuant KV cache + sampler)
python -m pytest examples/models/qwen3_5_moe/test_quantize_roundtrip.py examples/models/qwen3_5_moe/test_turboquant.py examples/models/qwen3_5_moe/test_sampler.py -v -o "addopts="

# Run Gemma 4 31B tests (quant unit tests + pipeline integration tests)
pip install gguf
python -m pytest examples/models/gemma4_31b/quant/tests/ examples/models/gemma4_31b/tests/ -v -o "addopts="

export-model-cuda-artifact:
name: export-model-cuda-artifact
# Skip this job if the pull request is from a fork (HuggingFace secrets are not available)
Expand Down Expand Up @@ -185,6 +181,8 @@ jobs:
name: "dinov2-small-imagenet1k-1-layer"
- repo: "SocialLocalMobile"
name: "Qwen3.5-35B-A3B-HQQ-INT4"
- repo: "SocialLocalMobile"
name: "gemma-4-31B-it-HQQ-INT4"
quant:
- "non-quantized"
- "quantized-int4-tile-packed"
Expand All @@ -204,6 +202,15 @@ jobs:
repo: "SocialLocalMobile"
name: "Qwen3.5-35B-A3B-HQQ-INT4"
quant: "quantized-int4-weight-only"
# Gemma 4 31B uses a prequantized checkpoint, so exclude every variant except quantized-int4-tile-packed
- model:
repo: "SocialLocalMobile"
name: "gemma-4-31B-it-HQQ-INT4"
quant: "non-quantized"
- model:
repo: "SocialLocalMobile"
name: "gemma-4-31B-it-HQQ-INT4"
quant: "quantized-int4-weight-only"
# Voxtral Realtime only supports int4-tile-packed on CUDA
- model:
repo: "mistralai"
Expand Down Expand Up @@ -258,7 +265,7 @@ jobs:
with:
timeout: 90
secrets-env: EXECUTORCH_HF_TOKEN
runner: ${{ matrix.model.name == 'Qwen3.5-35B-A3B-HQQ-INT4' && 'linux.aws.a100' || 'linux.g5.4xlarge.nvidia.gpu' }}
runner: ${{ (matrix.model.name == 'Qwen3.5-35B-A3B-HQQ-INT4' || matrix.model.name == 'gemma-4-31B-it-HQQ-INT4') && 'linux.aws.a100' || 'linux.g5.4xlarge.nvidia.gpu' }}
gpu-arch-type: cuda
gpu-arch-version: 12.6
use-custom-docker-registry: false
Expand Down Expand Up @@ -315,6 +322,8 @@ jobs:
name: "dinov2-small-imagenet1k-1-layer"
- repo: "SocialLocalMobile"
name: "Qwen3.5-35B-A3B-HQQ-INT4"
- repo: "SocialLocalMobile"
name: "gemma-4-31B-it-HQQ-INT4"
quant:
- "non-quantized"
- "quantized-int4-tile-packed"
Expand All @@ -334,6 +343,15 @@ jobs:
repo: "SocialLocalMobile"
name: "Qwen3.5-35B-A3B-HQQ-INT4"
quant: "quantized-int4-weight-only"
# Gemma 4 31B uses a prequantized checkpoint, so exclude every variant except quantized-int4-tile-packed
- model:
repo: "SocialLocalMobile"
name: "gemma-4-31B-it-HQQ-INT4"
quant: "non-quantized"
- model:
repo: "SocialLocalMobile"
name: "gemma-4-31B-it-HQQ-INT4"
quant: "quantized-int4-weight-only"
# Voxtral Realtime only supports int4-tile-packed on CUDA
- model:
repo: "mistralai"
Expand Down Expand Up @@ -382,7 +400,7 @@ jobs:
quant: "non-quantized"
with:
timeout: 90
runner: ${{ matrix.model.name == 'Qwen3.5-35B-A3B-HQQ-INT4' && 'linux.aws.a100' || 'linux.g5.4xlarge.nvidia.gpu' }}
runner: ${{ (matrix.model.name == 'Qwen3.5-35B-A3B-HQQ-INT4' || matrix.model.name == 'gemma-4-31B-it-HQQ-INT4') && 'linux.aws.a100' || 'linux.g5.4xlarge.nvidia.gpu' }}
gpu-arch-type: cuda
gpu-arch-version: 12.6
use-custom-docker-registry: false
Expand Down
58 changes: 36 additions & 22 deletions examples/models/gemma4_31b/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,10 @@

"""Eager inference on Gemma 4 31B-IT (CUDA + torch.compile).

Two input paths:
Three input paths:
--prequantized <dir> Load a quantized checkpoint (from quantize_and_save.py).
--gguf <file> Load a GGUF file (e.g., Q4_K_M from the community).

Packs for the target backend (--backend cuda), materializes runtime buffers,
optionally compiles with ``torch.compile``, and generates text autoregressively.
--bf16 <dir> Load the bf16 HF safetensors checkpoint via from_hf_checkpoint.

Gemma 4 31B-IT is instruction-tuned and requires chat-template formatting.
The ``--prompt`` is automatically wrapped with the Gemma 4 chat template
Expand All @@ -38,7 +36,10 @@
import torch

from executorch.examples.models.gemma4_31b.export import load_prequantized_model
from executorch.examples.models.gemma4_31b.model import materialize_runtime_buffers
from executorch.examples.models.gemma4_31b.model import (
Gemma4_31B,
materialize_runtime_buffers,
)


def _move_to_cuda(model, config) -> None:
Expand Down Expand Up @@ -74,7 +75,8 @@ def apply_chat_template(prompt: str) -> str:
Does not include BOS — ``generate()`` prepends it at the token-ID level.
"""
return (
"<|turn>user\n" + prompt
"<|turn>user\n"
+ prompt
+ "<turn|>\n<|turn>model\n<|channel>thought\n<channel|>"
)

Expand Down Expand Up @@ -147,6 +149,11 @@ def main() -> None:
default=None,
help="Path to a GGUF file (e.g., gemma-4-31B-it-Q4_K_M.gguf).",
)
src.add_argument(
"--bf16",
default=None,
help="Path to a bf16 hf directory (e.g., gemma-4-31B).",
)
parser.add_argument(
"--tokenizer-path",
default=None,
Expand Down Expand Up @@ -192,12 +199,34 @@ def main() -> None:
if args.backend == "cuda" and not torch.cuda.is_available():
parser.error("CUDA is required for the cuda backend.")

# ---- Tokenizer ----
if args.tokenizer_path:
tokenizer_path = args.tokenizer_path
elif args.prequantized:
tokenizer_path = os.path.join(args.prequantized, "tokenizer.json")
elif args.bf16:
tokenizer_path = os.path.join(args.bf16, "tokenizer.json")
else:
parser.error("--tokenizer-path is required with --gguf.")
from tokenizers import Tokenizer

tokenizer = Tokenizer.from_file(tokenizer_path)

prompt_str = args.prompt if args.raw_prompt else apply_chat_template(args.prompt)

# Gemma 4 EOS tokens (from generation_config.json: ids 1, 50, 106).
eos_token_ids = {1, 50, 106}

if args.gguf:
from executorch.examples.models.gemma4_31b.gguf_loader import load_gguf_model

model, config = load_gguf_model(
args.gguf, args.max_seq_len, backend=args.backend
)
elif args.bf16:
model, config = Gemma4_31B.from_hf_checkpoint(
args.bf16, max_seq_len=args.max_seq_len
)
else:
print(f"Loading prequantized model from {args.prequantized}...")
model, config = load_prequantized_model(
Expand All @@ -212,29 +241,14 @@ def main() -> None:
print("Compiling model with torch.compile...")
model = torch.compile(model, mode="default")

if args.tokenizer_path:
tokenizer_path = args.tokenizer_path
elif args.prequantized:
tokenizer_path = os.path.join(args.prequantized, "tokenizer.json")
else:
parser.error("--tokenizer-path is required with --gguf.")
from tokenizers import Tokenizer

tokenizer = Tokenizer.from_file(tokenizer_path)

# Gemma 4 EOS tokens (from generation_config.json: ids 1, 50, 106).
eos_token_ids = {1, 50, 106}

prompt = args.prompt if args.raw_prompt else apply_chat_template(args.prompt)

print(f"\nPrompt: {args.prompt}")
print("-" * 40)

t0 = time.perf_counter()
output = generate(
model,
tokenizer,
prompt,
prompt_str,
max_new_tokens=args.max_new_tokens,
temperature=args.temperature,
eos_token_ids=eos_token_ids,
Expand Down
24 changes: 23 additions & 1 deletion examples/models/gemma4_31b/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -675,7 +675,29 @@ def materialize_runtime_buffers(

for layer in model.layers:
attn = layer.self_attn
attn.inv_freq = attn.inv_freq.to(device)
if attn.is_sliding:
rotary_dim = attn.head_dim
else:
rotary_dim = int(attn.head_dim * attn.partial_rotary)
rope_angles = rotary_dim // 2
inv_freq_rotated = 1.0 / (
attn.rope_theta
** (
torch.arange(0, rotary_dim, 2, device=device, dtype=torch.float32)
/ attn.head_dim
)
)
nope_angles = attn.head_dim // 2 - rope_angles
if nope_angles > 0:
inv_freq = torch.cat(
[
inv_freq_rotated,
torch.zeros(nope_angles, device=device, dtype=torch.float32),
]
)
else:
inv_freq = inv_freq_rotated
attn.register_buffer("inv_freq", inv_freq, persistent=False)

model.register_buffer(
"embed_normalizer",
Expand Down
Loading