Skip to content

Feat/gemma4 adapters#1385

Draft
huseyincavusbi wants to merge 34 commits into
TransformerLensOrg:mainfrom
huseyincavusbi:feat/gemma4-adapters
Draft

Feat/gemma4 adapters#1385
huseyincavusbi wants to merge 34 commits into
TransformerLensOrg:mainfrom
huseyincavusbi:feat/gemma4-adapters

Conversation

@huseyincavusbi

Copy link
Copy Markdown
Contributor

Description

This PR adds TransformerBridge support for the Gemma 4 model family (E2B, E4B, 26B-A4B, and 31B) through a single unified Gemma4ArchitectureAdapter.

Key Implementation Details

  • Unified Adapter (gemma4.py): Dynamically handles all 4 variants by evaluating initialization configuration flags:
    • MoE Blocks: Submodules conditionally spin up only when enable_moe_block=True (specifically for the 26B variant).
    • KV-Sharing: Dropped gracefully when num_kv_shared_layers > 0 (for E2B/E4B).
    • PLE Embeddings: Surfaced dynamically when hidden_size_per_layer_input > 0.
    • Weight Processing: Maps and converts Gemma 4's joint QKV layout, dual RoPE configurations, alternating sliding/full attention mechanisms, logit softcapping, and RMSNorm.
    • Includes 45 dedicated unit tests verifying config attributes, MoE behavior, and weight conversions.
  • Shared-Library Updates (3 files, fully opt-in, zero regressions on existing adapter tests):
    1. position_embeddings_attention.py: Applies V norm post-reshape (Gemma 4 is the first architecture featuring per-head value normalization). Handles KV-sharing delegation to Hugging Face's original attention implementation when K/V submodules are omitted. Caches computed KV states in shared_kv_states post-RoPE for structural layer reuse.
    2. bridge.py: Introduces a use_native_generate opt-in flag. This bypasses a current Hugging Face transformers dev-version issue where eager attention causes a KV-cache dimension mismatch during generation. Setting this flag (scoped strictly to this adapter) delegates processing to HF's native generate() utilizing SDPA.
    3. main_benchmark.py: Fixes pad_token_id assignment when eos_token_id is a list (Gemma4 uses [1, 106]), taking the first element.

Verification & Performance

All models have been validated.

Fixes #1297

Type of change

Please delete options that are not relevant.

  • New feature (non-breaking change which adds functionality)

Screenshots

Please attach before and after screenshots of the change if applicable.

Checklist:

  • I have commented my code, particularly in hard-to-understand areas
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes
  • I have not rewritten tests relating to key interfaces which would affect backward compatibility

danra and others added 30 commits June 8, 2026 09:14
…ensOrg#1316)

* Add Direct Logit Attribution tool for TransformerBridge

* Resolve review feedback and add Direct Logit Attribution tests

Resolved review feedback from @jlarson4, added tests covering
reconstruction invariants on a distilgpt2 bridge in compatibility mode,
arguments, asserting sum(scores) == logit_diff - (b_U[correct] -
b_U[wrong]) against the model's real logits, plus labels/shape and
batch-averaging checks.

Added additional hardening:
- Fix a latent direction-shape bug: replace the fragile
  answer_tokens.numel()==1 branch with a robust reshape so single-prompt,
  single-token inputs are handled correctly
- Detect hybrid blocks via bridge.layer_types() instead of substring
  matching named_modules(), the codebase's own semantic mechanism
- Import get_act_name from transformer_lens.utilities to avoid the
  transformer_lens.utils DeprecationWarning; drop the invalid
  return_type kwarg to run_with_cache
- Register the analysis subpackage in tools/__init__.py

Closes TransformerLensOrg#1263.
…merLensOrg#1369)

* Add Direct Logit Attribution tool (TransformerLensOrg#1263)

Add transformer_lens/tools/analysis/direct_logit_attribution.py, a single-call
DLA analysis that decomposes a logit (or logit difference) into per-component,
per-layer (logit-lens), or per-head contributions. Wraps the existing
ActivationCache primitives (decompose_resid / accumulated_resid /
stack_head_results / logit_attrs) and works with both HookedTransformer and
TransformerBridge, since they share the cache API.

Returns a DirectLogitAttribution dataclass (attribution tensor + aligned
labels, plus a top(k) helper). Adds integration tests asserting the exact DLA
correctness invariant on both systems: the complete decomposition reconstructs
the model's real logit up to the unembedding bias b_U.

Closes TransformerLensOrg#1263

* Resolving conflicts between 1316 and 1369

* format fixes

---------

Co-authored-by: Azra Bano <azrabano23@gmail.com>
Co-authored-by: Jonah Larson <jonahalarson@comcast.net>
* Add Phi adapter tests

* Add comment about setup component test

* Delete redundant config literal tests
* Fixed SVD interpreter test

* Format SVD interpreter fixture test
The Restricted Loss section called loss_fn(all_logits, labels), but
all_logits had been rearranged earlier into a (p, p, d_vocab) grid for
the logit periodicity analysis. loss_fn's 3-D branch assumes
(batch, pos, d_vocab) and takes logits[:, -1], producing a (p, p)
tensor that crashes the gather against the p*p labels (TransformerLensOrg#543).

Use original_logits instead, which is recomputed just above and is the
same full-dataset loss the cell intends to print. Also clear the stored
RuntimeError output from the cell.
Breaking: removes the public eps_attr constructor argument and the config.eps_attr attribute. The field was never read (its consumer was deleted when NormalizationBridge moved to direct HF delegation), so no model behavior changes, but it is an API removal.
…utes

- Unwrap text_config for Gemma4ForConditionalGeneration models
- Read PLE, KV sharing, layer_types, softcapping from text_cfg
- Add NotImplementedError guard for MoE variants (26B-A4B)
- Update tests to exercise text_config path
@huseyincavusbi huseyincavusbi marked this pull request as draft June 14, 2026 10:49
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[Proposal] Gemma4 Architecture Adapter

9 participants