Feat/gemma4 adapters #1385
Draft
huseyincavusbi wants to merge 34 commits into
Fix broken link in README
…ensOrg#1316)
* Add Direct Logit Attribution tool for TransformerBridge
* Resolve review feedback and add Direct Logit Attribution tests

Resolved review feedback from @jlarson4, added tests covering reconstruction invariants on a distilgpt2 bridge in compatibility mode, arguments, asserting sum(scores) == logit_diff - (b_U[correct] - b_U[wrong]) against the model's real logits, plus labels/shape and batch-averaging checks.

Added additional hardening:
- Fix a latent direction-shape bug: replace the fragile answer_tokens.numel()==1 branch with a robust reshape so single-prompt, single-token inputs are handled correctly
- Detect hybrid blocks via bridge.layer_types() instead of substring matching named_modules(), the codebase's own semantic mechanism
- Import get_act_name from transformer_lens.utilities to avoid the transformer_lens.utils DeprecationWarning; drop the invalid return_type kwarg to run_with_cache
- Register the analysis subpackage in tools/__init__.py

Closes TransformerLensOrg#1263.
…merLensOrg#1369)
* Add Direct Logit Attribution tool (TransformerLensOrg#1263)

Add transformer_lens/tools/analysis/direct_logit_attribution.py, a single-call DLA analysis that decomposes a logit (or logit difference) into per-component, per-layer (logit-lens), or per-head contributions. Wraps the existing ActivationCache primitives (decompose_resid / accumulated_resid / stack_head_results / logit_attrs) and works with both HookedTransformer and TransformerBridge, since they share the cache API. Returns a DirectLogitAttribution dataclass (attribution tensor + aligned labels, plus a top(k) helper).

Adds integration tests asserting the exact DLA correctness invariant on both systems: the complete decomposition reconstructs the model's real logit up to the unembedding bias b_U.

Closes TransformerLensOrg#1263
* Resolving conflicts between 1316 and 1369
* format fixes

---------
Co-authored-by: Azra Bano <azrabano23@gmail.com>
Co-authored-by: Jonah Larson <jonahalarson@comcast.net>
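The invariant these DLA tests assert, sum(scores) == logit_diff - (b_U[correct] - b_U[wrong]), can be illustrated with a toy example in plain Python. This is only a sketch under a simplifying assumption: there is no final LayerNorm, so the logits are the summed residual components times the unembedding plus b_U. The component names, vectors, and weights below are made up for illustration and are not taken from the PR or from any model.

```python
# Toy check of the DLA invariant using plain Python lists.
# Assumption: no final LayerNorm, so logits = (sum of components) @ W_U + b_U.

def dot(u, v):
    """Dot product of two equal-length lists."""
    return sum(a * b for a, b in zip(u, v))

# Three hypothetical residual-stream components (d_model = 2).
components = {
    "embed": [1.0, 0.0],
    "attn_0": [0.5, -1.0],
    "mlp_0": [-0.25, 2.0],
}

# Unembedding stored as one column per vocab token (d_vocab = 3), plus bias b_U.
W_U_columns = [[1.0, 2.0], [0.0, -1.0], [3.0, 0.5]]
b_U = [0.5, -0.5, 0.0]

correct, wrong = 0, 1

# The final residual stream is the sum of every component.
resid = [sum(vals) for vals in zip(*components.values())]
logits = [dot(resid, col) + b for col, b in zip(W_U_columns, b_U)]
logit_diff = logits[correct] - logits[wrong]

# Attribution direction: W_U[:, correct] - W_U[:, wrong].
direction = [c - w for c, w in zip(W_U_columns[correct], W_U_columns[wrong])]
scores = {name: dot(vec, direction) for name, vec in components.items()}

# Per-component scores sum to the logit difference minus the bias difference.
bias_term = b_U[correct] - b_U[wrong]
assert abs(sum(scores.values()) - (logit_diff - bias_term)) < 1e-9
```

In a real model the final normalization is nonlinear, so the residual components have to be scaled by the cached normalization factor before this identity holds.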
* Add Phi adapter tests
* Add comment about setup component test
* Delete redundant config literal tests
* Fixed SVD interpreter test
* Format SVD interpreter fixture test
The Restricted Loss section called loss_fn(all_logits, labels), but all_logits had been rearranged earlier into a (p, p, d_vocab) grid for the logit periodicity analysis. loss_fn's 3-D branch assumes (batch, pos, d_vocab) and takes logits[:, -1], producing a (p, p) tensor that crashes the gather against the p*p labels (TransformerLensOrg#543). Use original_logits instead, which is recomputed just above and is the same full-dataset loss the cell intends to print. Also clear the stored RuntimeError output from the cell.
Breaking: removes the public eps_attr constructor argument and the config.eps_attr attribute. The field was never read (its consumer was deleted when NormalizationBridge moved to direct HF delegation), so no model behavior changes, but it is an API removal.
…utes
- Unwrap text_config for Gemma4ForConditionalGeneration models
- Read PLE, KV sharing, layer_types, softcapping from text_cfg
- Add NotImplementedError guard for MoE variants (26B-A4B)
- Update tests to exercise text_config path
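The unwrap-and-guard pattern in this commit might look roughly like the sketch below. It is not the adapter's actual code. The helper names unwrap_text_config and read_gemma4_features are hypothetical, the configs are SimpleNamespace stand-ins rather than real Hugging Face config objects, and the numeric values are placeholders. The attribute names enable_moe_block, num_kv_shared_layers, hidden_size_per_layer_input, and layer_types are the ones mentioned in this PR.

```python
from types import SimpleNamespace


def unwrap_text_config(hf_config):
    """Return the nested text_config when present, else the config itself."""
    return getattr(hf_config, "text_config", None) or hf_config


def read_gemma4_features(hf_config):
    """Hypothetical helper: read feature flags from the (unwrapped) text config."""
    text_cfg = unwrap_text_config(hf_config)
    # Guard MoE variants up front, mirroring the commit's NotImplementedError guard.
    if getattr(text_cfg, "enable_moe_block", False):
        raise NotImplementedError("MoE variants (26B-A4B) are not handled in this sketch")
    return {
        "kv_sharing": (getattr(text_cfg, "num_kv_shared_layers", 0) or 0) > 0,
        "ple": (getattr(text_cfg, "hidden_size_per_layer_input", 0) or 0) > 0,
        "layer_types": getattr(text_cfg, "layer_types", None),
    }


# Stand-in for a multimodal wrapper config; all values are placeholders.
multimodal_cfg = SimpleNamespace(
    text_config=SimpleNamespace(
        num_kv_shared_layers=2,
        hidden_size_per_layer_input=8,
        layer_types=["sliding_attention", "full_attention"],
    )
)
features = read_gemma4_features(multimodal_cfg)

# A text-only config without text_config is read directly.
flat_features = read_gemma4_features(SimpleNamespace(num_kv_shared_layers=0))
```

Reading every flag through one unwrapped object keeps the text-only and multimodal paths identical after the first line, which is presumably why the tests were updated to exercise the text_config path.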
Description
This PR adds TransformerBridge support for the Gemma 4 model family (E2B, E4B, 26B-A4B, and 31B) through a single unified Gemma4ArchitectureAdapter.

Key Implementation Details
- gemma4.py: Dynamically handles all 4 variants by evaluating initialization configuration flags:
  - enable_moe_block=True (specifically for the 26B variant).
  - num_kv_shared_layers > 0 (for E2B/E4B).
  - hidden_size_per_layer_input > 0.
- position_embeddings_attention.py: Applies V norm post-reshape (Gemma 4 is the first architecture featuring per-head value normalization). Handles KV-sharing delegation to Hugging Face's original attention implementation when K/V submodules are omitted. Caches computed KV states in shared_kv_states post-RoPE for structural layer reuse.
- bridge.py: Introduces a use_native_generate opt-in flag. This bypasses a current Hugging Face transformers dev-version issue where eager attention causes a KV-cache dimension mismatch during generation. Setting this flag (scoped strictly to this adapter) delegates processing to HF's native generate() utilizing SDPA.
- main_benchmark.py: Fixes pad_token_id assignment when eos_token_id is a list (Gemma4 uses [1, 106]), taking the first element.

Verification & Performance
All models have been validated.
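The pad_token_id fix described for main_benchmark.py above can be sketched as follows. This is a minimal illustration, not the code in the PR: the function name resolve_pad_token_id is hypothetical, and raising on an empty list is an added assumption.

```python
def resolve_pad_token_id(eos_token_id):
    """Pick a single int pad id from an int or list-valued eos_token_id."""
    if isinstance(eos_token_id, (list, tuple)):
        if not eos_token_id:
            # Assumption: an empty list is treated as a configuration error.
            raise ValueError("eos_token_id list is empty")
        # Take the first element, as the PR description states.
        return eos_token_id[0]
    return eos_token_id


# Gemma4 reports eos_token_id as [1, 106]; a scalar id passes through unchanged.
pad_from_list = resolve_pad_token_id([1, 106])
pad_from_int = resolve_pad_token_id(2)
```

Taking the first element keeps pad_token_id a plain integer, which is what padding code paths generally expect.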
Fixes #1297