Gemma 4 31B: chat template, inv_freq dedup, CI hardening#19614
Pull request overview
Applies the Gemma 4 IT chat template to user prompts by default in both the Python inference.py and C++ runner, with an opt-out flag for pre-formatted input, to avoid degenerate output from the instruction-tuned model.
Changes:
- Add `apply_chat_template` helper and `--raw-prompt` flag in `inference.py`.
- Add `--raw_prompt` flag and template-wrapping logic in `main.cpp` (BOS prepended separately).
- Document the auto-wrapping and opt-out flags in the README.
Reviewed changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated no comments.
| File | Description |
|---|---|
| examples/models/gemma4_31b/inference.py | Defines chat template, wraps prompt by default, adds --raw-prompt flag |
| examples/models/gemma4_31b/main.cpp | Adds --raw_prompt flag and wraps prompt text with the IT chat template |
| examples/models/gemma4_31b/README.md | Documents auto-wrapping behavior and skip flags for both runners |
@claude Review this PR
Claude finished @mergennachin's task in 1m 31s
Review of PR #19614
Overall this is a clean, well-scoped change. Two issues worth addressing before merge: 1.
Gemma 4 31B-IT is instruction-tuned and produces degenerate output
without the chat template wrapping. Auto-wrap --prompt with the IT
template (<bos><|turn>user\n{prompt}<turn|>\n<|turn>model\n
<|channel>thought\n<channel|>) by default; --raw-prompt / --raw_prompt
skips wrapping for pre-formatted input.
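
The wrapping described in this commit message can be sketched in Python as follows. This is a minimal illustration, not the code from `inference.py`: the name `wrap_prompt` and its `raw` / `add_bos` parameters are hypothetical, and joining the two wrapped lines of the template with no whitespace between them is an assumption.

```python
# Template text copied from the commit message; joining its two wrapped
# lines with no whitespace in between is an assumption.
GEMMA4_IT_TEMPLATE = (
    "<|turn>user\n{prompt}<turn|>\n"
    "<|turn>model\n<|channel>thought\n<channel|>"
)
BOS = "<bos>"


def wrap_prompt(prompt: str, raw: bool = False, add_bos: bool = True) -> str:
    """Hypothetical helper: wrap a user prompt unless it is pre-formatted."""
    if raw:
        # Opt-out path: the caller already supplied a fully formatted prompt.
        return prompt
    # str.replace (rather than str.format) keeps braces in the prompt intact.
    body = GEMMA4_IT_TEMPLATE.replace("{prompt}", prompt)
    # A runner that prepends BOS separately would pass add_bos=False.
    return (BOS + body) if add_bos else body
```

Passing `add_bos=False` models a runner that prepends BOS on its own, as the review summary says `main.cpp` does.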
78ee61f to 5d5c26e
### Summary
`materialize_runtime_buffers` in model.py was zeroing out all meta buffers, including each layer's `inv_freq` (RoPE frequencies). The follow-up `attn.inv_freq.to(device)` was effectively a no-op, since it only moved the already-zeroed tensors. RoPE therefore produced cos=1, sin=0 at every position, the model had no positional information, and its output fell into a period-N echo cycle. This PR fixes the issue by recomputing `inv_freq` per layer with real values (using the layer's head_dim, partial_rotary, rope_theta, and is_sliding flag) in `materialize_runtime_buffers`.

### Test plan
Add an e2e CI job for the gemma4-31b model and check its output.
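
To see why a zeroed `inv_freq` removes positional information, here is a small pure-Python sketch of a single RoPE rotation. It is an illustration only: `rope_rotate` is a hypothetical helper, not a function from model.py, and it rotates one feature pair rather than a full tensor.

```python
import math


def rope_rotate(pair, position, inv_freq):
    """Rotate one (x, y) feature pair by the angle position * inv_freq."""
    angle = position * inv_freq
    cos, sin = math.cos(angle), math.sin(angle)
    x, y = pair
    return (x * cos - y * sin, x * sin + y * cos)


pair = (0.5, -1.25)

# Zeroed inv_freq: the angle is 0 at every position, so cos=1 and sin=0
# and the rotation leaves the pair unchanged.
zeroed = [rope_rotate(pair, pos, 0.0) for pos in range(4)]

# Non-zero inv_freq: each position gets a different rotation.
real = [rope_rotate(pair, pos, 1.0) for pos in range(4)]
```

With `inv_freq = 0.0` every entry of `zeroed` equals the input pair, so attention cannot distinguish positions; with a non-zero frequency the entries of `real` differ by position.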
Pull request overview
Copilot reviewed 7 out of 7 changed files in this pull request and generated 2 comments.
Comments suppressed due to low confidence (1)
examples/models/gemma4_31b/model.py:700
- The PR title and description state the change is about applying the Gemma 4 IT chat template in `inference.py` and the C++ runner. However, the diff also includes several unrelated changes that are not mentioned in the description:
  - `model.py`: replaces `attn.inv_freq = attn.inv_freq.to(device)` with a full re-computation of `inv_freq` (including partial-rotary / NoPE handling) in `materialize_runtime_buffers`.
  - `inference.py`: adds a new `--bf16` input path that calls `Gemma4_31B.from_hf_checkpoint`.
  - `.github/workflows/cuda.yml`: removes the `pip install gguf` + `pytest examples/models/gemma4_31b/quant/tests/ examples/models/gemma4_31b/tests/` step, and adds Gemma 4 31B-IT to the export/e2e matrices.
  - `.ci/scripts/export_model_artifact.sh` and `.ci/scripts/test_model_e2e.sh`: add a new Gemma 4 31B pipeline.

  Please either update the PR description to cover these changes, or split them into separate PRs so each change can be reviewed against a description that matches its scope.
if attn.is_sliding:
rotary_dim = attn.head_dim
else:
rotary_dim = int(attn.head_dim * attn.partial_rotary)
rope_angles = rotary_dim // 2
inv_freq_rotated = 1.0 / (
attn.rope_theta
** (
torch.arange(0, rotary_dim, 2, device=device, dtype=torch.float32)
/ attn.head_dim
)
)
nope_angles = attn.head_dim // 2 - rope_angles
if nope_angles > 0:
inv_freq = torch.cat(
[
inv_freq_rotated,
torch.zeros(nope_angles, device=device, dtype=torch.float32),
]
)
else:
inv_freq = inv_freq_rotated
attn.register_buffer("inv_freq", inv_freq, persistent=False)
…t tests
- Extract `_compute_inv_freq()` on Gemma4Attention so `__init__` and `materialize_runtime_buffers` share a single implementation.
- Check for "Paris" in the export CI inference sanity check instead of just checking the script doesn't crash.
- Restore gemma4_31b quant/pipeline unit tests in the CUDA build job.
- Update model.md to reflect that inv_freq is recomputed, not moved.
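
A shared helper of the kind this commit describes might look roughly like the sketch below. It mirrors the arithmetic of the snippet quoted in the review comment above, but uses plain Python lists instead of torch tensors so it runs without dependencies; the free-function form and the name `compute_inv_freq` are assumptions, not the actual `_compute_inv_freq()` method.

```python
def compute_inv_freq(head_dim, rope_theta, partial_rotary, is_sliding):
    """Sketch: per-layer RoPE inverse frequencies, zero-padded for NoPE dims."""
    # Sliding-window layers rotate the full head; others rotate only a fraction.
    rotary_dim = head_dim if is_sliding else int(head_dim * partial_rotary)
    rope_angles = rotary_dim // 2
    # One frequency per rotated pair; the exponent is normalised by head_dim,
    # as in the quoted snippet.
    rotated = [
        1.0 / (rope_theta ** (i / head_dim)) for i in range(0, rotary_dim, 2)
    ]
    # Remaining pairs carry no positional rotation (frequency 0).
    nope_angles = head_dim // 2 - rope_angles
    return rotated + [0.0] * nope_angles


# Illustrative values only; they are not taken from the Gemma 4 config.
sliding = compute_inv_freq(8, 10000.0, 0.25, is_sliding=True)
full = compute_inv_freq(8, 10000.0, 0.25, is_sliding=False)
```

Calling the same function from both the constructor and the buffer-materialization step means the two code paths cannot drift apart, which appears to be the point of the refactor.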