Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 80 additions & 0 deletions tests/integration/model_bridge/test_gemma4_bridge.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
"""Integration tests for the Gemma 4 text-only TransformerBridge.

Uses tiny random-init `Gemma4ForConditionalGeneration` fixtures (4 layers, d_model 8) so CI
stays light while still exercising the real per-layer heterogeneity across the family:

- ``tiny-random/gemma-4-e`` — Per-Layer Embeddings + KV-cache sharing (E2B/E4B shape)
- ``tiny-random/gemma-4-dense`` — K==V attention on global layers, no v_proj (31B shape)
- ``tiny-random/gemma-4-moe`` — router + batched experts beside the dense MLP (26B-A4B shape)

Confirms logit parity vs HF (the block bridge defers all math to HF) and that hooks fire on
the conventional single-stream residual.
"""

import pytest
import torch

from transformer_lens.model_bridge import TransformerBridge

# Short pytest ids mapped to the checkpoint names that the ``bridge`` fixture boots.
MODEL_NAMES = dict(
    ple_kv_shared="tiny-random/gemma-4-e",
    dense_k_eq_v="tiny-random/gemma-4-dense",
    moe="tiny-random/gemma-4-moe",
)
# One batch row holding the token ids 1..5; shared input for every test in this module.
IDS = torch.arange(1, 6).unsqueeze(0)


@pytest.fixture(scope="module", params=list(MODEL_NAMES), ids=list(MODEL_NAMES))
def bridge(request):
    """Boot a float32 CPU ``TransformerBridge`` for the parametrized checkpoint.

    The fixture is parametrized over the keys of ``MODEL_NAMES``; ``request.param``
    is one of those keys and is looked up to get the checkpoint name to load.
    """
    checkpoint = MODEL_NAMES[request.param]
    return TransformerBridge.boot_transformers(checkpoint, device="cpu", dtype=torch.float32)


def test_text_only_logit_parity_vs_hf(bridge):
    """Bridge logits on ``IDS`` match the HF reference logits to within 1e-3."""
    from transformers import AutoModelForCausalLM

    reference = AutoModelForCausalLM.from_pretrained(
        bridge.cfg.model_name,
        torch_dtype=torch.float32,
        attn_implementation="eager",
    )
    reference = reference.eval()

    with torch.no_grad():
        ref = reference(IDS).logits
        out = bridge.forward(IDS, return_type="logits")

    assert out.shape == ref.shape
    # PLE / KV-sharing / K==V / MoE all run inside HF — the bridge is a pass-through.
    worst_abs_error = (out - ref).abs().max().item()
    assert worst_abs_error < 1e-3


def test_config_from_text_config(bridge):
    """The bridge config reports 4 layers and is not flagged as multimodal."""
    cfg = bridge.cfg
    # Text dims resolve from the nested text_config of the multimodal model.
    assert cfg.n_layers == 4
    # A config without the attribute at all also counts as text-only.
    assert getattr(cfg, "is_multimodal", False) is False


def test_resid_hooks_fire_with_conventional_shape(bridge):
    """The residual stream is a single conventional (batch, seq, d_model) tensor."""
    resid_suffixes = ("blocks.0.hook_resid_pre", "blocks.0.hook_resid_post")
    names = [n for n in bridge.hook_dict if n.endswith(resid_suffixes)]
    assert names, "no residual hooks registered"

    captured = {}

    def record(tensor, hook):
        # Keep a detached copy keyed by the firing hook's name; pass the value through unchanged.
        captured[hook.name] = tensor.detach()
        return tensor

    with torch.no_grad():
        bridge.run_with_hooks(IDS, fwd_hooks=[(n, record) for n in names])

    assert captured, "residual hooks did not fire"
    batch, seq = IDS.shape[0], IDS.shape[1]
    expected_shape = (batch, seq, bridge.cfg.d_model)
    for activation in captured.values():
        assert activation.shape == expected_shape


def test_run_with_cache_text_only(bridge):
    """``run_with_cache`` on ``IDS`` yields all-finite logits and a non-empty cache."""
    with torch.no_grad():
        result = bridge.run_with_cache(IDS)
    logits, cache = result
    all_finite = torch.isfinite(logits).all()
    assert all_finite
    assert len(cache) > 0
124 changes: 124 additions & 0 deletions tests/unit/model_bridge/supported_architectures/test_gemma4_adapter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
"""Unit tests for the Gemma 4 text-only architecture adapter."""

from transformer_lens.config.transformer_bridge_config import TransformerBridgeConfig
from transformer_lens.factories.architecture_adapter_factory import (
ArchitectureAdapterFactory,
)
from transformer_lens.model_bridge.generalized_components import (
DelegatedAttentionBlockBridge,
EmbeddingBridge,
LinearBridge,
RotaryEmbeddingBridge,
UnembeddingBridge,
)

# Architecture name under test; passed as ``architecture=`` when ``_adapter`` builds its config.
ARCH = "Gemma4ForConditionalGeneration"


def _adapter():
    """Return the adapter the factory selects for an ``ARCH`` bridge config."""
    # Dimensions follow google/gemma-4-E2B's text_config.
    config_kwargs = {
        "d_model": 1536,
        "d_head": 256,
        "n_heads": 8,
        "n_layers": 35,
        "n_ctx": 131072,
        "d_vocab": 262144,
        "n_key_value_heads": 1,
        "architecture": ARCH,
    }
    cfg = TransformerBridgeConfig(**config_kwargs)
    return ArchitectureAdapterFactory.select_architecture_adapter(cfg)


def test_config_flags():
    """The adapter sets the expected text-only, RMS-norm, rotary config flags."""
    adapter = _adapter()
    # Text-only; PLE / layer_scalar / MoE residual topology is not fold-safe.
    assert adapter.cfg.is_multimodal is False
    assert adapter.supports_fold_ln is False
    assert adapter.weight_processing_conversions == {}
    assert adapter.cfg.normalization_type == "RMS"
    # Gemma4RMSNorm scales by weight directly — no (1 + weight) offset, unlike Gemma 1-3.
    assert adapter.cfg.rmsnorm_uses_offset is False
    assert adapter.cfg.positional_embedding_type == "rotary"
    assert adapter.applicable_phases == [1, 2, 4]


def test_text_path_nested_under_language_model():
    """Top-level components map to the expected module paths and bridge types."""
    mapping = _adapter().component_mapping

    expected_paths = (
        ("embed", "model.language_model.embed_tokens"),
        ("rotary_emb", "model.language_model.rotary_emb"),
        ("blocks", "model.language_model.layers"),
        ("ln_final", "model.language_model.norm"),
        ("unembed", "lm_head"),
    )
    for key, module_path in expected_paths:
        assert mapping[key].name == module_path

    expected_types = (
        ("embed", EmbeddingBridge),
        ("rotary_emb", RotaryEmbeddingBridge),
        ("blocks", DelegatedAttentionBlockBridge),
        ("unembed", UnembeddingBridge),
    )
    for key, bridge_type in expected_types:
        assert isinstance(mapping[key], bridge_type)

    # Vision/audio towers are referenced-but-unbridged.
    for tower in ("vision_encoder", "audio_encoder"):
        assert tower not in mapping


def test_block_decomposition():
    """Blocks expose attn, mlp and four non-optional sandwich norms."""
    submodules = _adapter().component_mapping["blocks"].submodules
    for required in ("attn", "mlp"):
        assert required in submodules
    # Sandwich norms (same shape as Gemma 2/3) under canonical keys.
    sandwich_norms = ("ln1", "ln1_post", "ln2", "ln2_post")
    for norm_key in sandwich_norms:
        assert norm_key in submodules
        assert submodules[norm_key].optional is False


def test_split_qkv_fork_aliases_absent():
    """Split-qkv fork aliases are absent from the block's hook aliases.

    Attention is delegated wholesale to HF; per-layer structure is heterogeneous
    (KV-shared layers have no k/v projections), so the split-qkv fork aliases
    do not apply.
    """
    aliases = _adapter().component_mapping["blocks"].hook_aliases
    fork_aliases = ("hook_q_input", "hook_k_input", "hook_v_input", "hook_attn_in")
    for fork_alias in fork_aliases:
        assert fork_alias not in aliases
    # The single-stream residual aliases remain, redirected through the sandwich norms.
    redirected = (
        ("hook_resid_mid", "ln2.hook_in"),
        ("hook_attn_out", "ln1_post.hook_out"),
        ("hook_mlp_out", "ln2_post.hook_out"),
    )
    for alias, target in redirected:
        assert aliases[alias] == target


def test_kv_shared_and_k_eq_v_submodules_are_optional():
    """Attention q/o/q_norm are required; k, v and their norms are optional.

    KV-shared layers (E2B/E4B) drop k/v proj + norms; K==V global-attention
    layers (31B / 26B-A4B) drop v_proj.
    """
    attn_parts = _adapter().component_mapping["blocks"].submodules["attn"].submodules
    for always_present in ("q", "o", "q_norm"):
        assert attn_parts[always_present].optional is False
    for maybe_absent in ("k", "v", "k_norm", "v_norm"):
        assert attn_parts[maybe_absent].optional is True
    assert isinstance(attn_parts["q"], LinearBridge)


def test_per_layer_embedding_submodules_are_optional():
    """PLE modules exist only when hidden_size_per_layer_input > 0 (E2B/E4B)."""
    submodules = _adapter().component_mapping["blocks"].submodules
    ple_keys = (
        "per_layer_input_gate",
        "per_layer_projection",
        "post_per_layer_input_norm",
    )
    for ple_key in ple_keys:
        assert submodules[ple_key].optional is True


def test_moe_submodules_are_optional():
    """MoE branch exists only when enable_moe_block (26B-A4B)."""
    submodules = _adapter().component_mapping["blocks"].submodules
    moe_keys = (
        "router",
        "experts",
        "pre_feedforward_layernorm_2",
        "post_feedforward_layernorm_1",
        "post_feedforward_layernorm_2",
    )
    for moe_key in moe_keys:
        assert submodules[moe_key].optional is True


def test_gated_mlp_decomposition():
    """The MLP's gate/in/out submodules map to gate_proj/up_proj/down_proj."""
    mlp_parts = _adapter().component_mapping["blocks"].submodules["mlp"].submodules
    expected_names = (
        ("gate", "gate_proj"),
        ("in", "up_proj"),
        ("out", "down_proj"),
    )
    for key, module_name in expected_names:
        assert mlp_parts[key].name == module_name
6 changes: 6 additions & 0 deletions transformer_lens/factories/architecture_adapter_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
Gemma3ArchitectureAdapter,
Gemma3MultimodalArchitectureAdapter,
Gemma3nArchitectureAdapter,
Gemma4ArchitectureAdapter,
GPT2ArchitectureAdapter,
Gpt2LmHeadCustomArchitectureAdapter,
GPTBigCodeArchitectureAdapter,
Expand Down Expand Up @@ -86,6 +87,11 @@
"Gemma3ForCausalLM": Gemma3ArchitectureAdapter,
"Gemma3ForConditionalGeneration": Gemma3MultimodalArchitectureAdapter,
"Gemma3nForConditionalGeneration": Gemma3nArchitectureAdapter,
"Gemma4ForConditionalGeneration": Gemma4ArchitectureAdapter,
# The unified (encoder-free) 12B variant's text decoder is a strict structural
# subset of gemma4 (no PLE, no MoE — both optional in the adapter), with the
# same module paths. Requires transformers >= 5.10 to load.
"Gemma4UnifiedForConditionalGeneration": Gemma4ArchitectureAdapter,
"GraniteForCausalLM": GraniteArchitectureAdapter,
"GraniteMoeForCausalLM": GraniteMoeArchitectureAdapter,
"GraniteMoeHybridForCausalLM": GraniteMoeHybridArchitectureAdapter,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
)
from transformer_lens.model_bridge.generalized_components.block import (
BlockBridge,
DelegatedAttentionBlockBridge,
MLABlockBridge,
ParallelBlockBridge,
)
Expand Down Expand Up @@ -108,6 +109,7 @@
"AttentionBridge",
"AudioFeatureExtractorBridge",
"BlockBridge",
"DelegatedAttentionBlockBridge",
"MLABlockBridge",
"ParallelBlockBridge",
"BloomBlockBridge",
Expand Down
31 changes: 31 additions & 0 deletions transformer_lens/model_bridge/generalized_components/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,3 +405,34 @@ def __init__(
if self.hook_aliases is BlockBridge.hook_aliases:
self.hook_aliases = dict(self.hook_aliases)
self.hook_aliases.pop("hook_resid_mid", None)


class DelegatedAttentionBlockBridge(BlockBridge):
    """Block bridge for layers whose attention is handed to HF as a whole.

    Some architectures have attention structure that differs from layer to
    layer. Gemma 4 is the motivating case: its KV-shared layers have no
    ``k_proj``/``v_proj`` at all, and its K==V layers have no ``v_proj``. No
    uniform HookPoint can then represent "input that becomes Q/K/V", so the
    block-level ``hook_q_input``/``hook_k_input``/``hook_v_input``/
    ``hook_attn_in`` aliases are removed (no split-qkv fork).

    Making this a distinct type means a reader of an adapter who sees
    ``DelegatedAttentionBlockBridge`` knows those hooks are absent.
    """

    def __init__(
        self,
        name: str,
        config: Optional[Any] = None,
        submodules: Optional[Dict[str, GeneralizedComponent]] = None,
        hook_alias_overrides: Optional[Dict[str, str]] = None,
    ):
        """Initialise as a ``BlockBridge``, then drop the split-qkv fork aliases.

        All four arguments are forwarded unchanged to ``BlockBridge.__init__``.
        """
        super().__init__(
            name,
            config=config,
            submodules=submodules,
            hook_alias_overrides=hook_alias_overrides,
        )
        # If this instance still points at the dict on BlockBridge itself, take a
        # private copy first so the removals below stay local to this instance.
        shares_class_dict = self.hook_aliases is BlockBridge.hook_aliases
        if shares_class_dict:
            self.hook_aliases = dict(self.hook_aliases)
        split_qkv_aliases = ("hook_q_input", "hook_k_input", "hook_v_input", "hook_attn_in")
        for fork_alias in split_qkv_aliases:
            # A default of None makes the pop a no-op when the alias is already missing.
            self.hook_aliases.pop(fork_alias, None)
4 changes: 4 additions & 0 deletions transformer_lens/model_bridge/sources/transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,10 @@ def determine_architecture_from_hf_config(hf_config):
# gemma3n is tri-modal; the text path loads as the full ForConditionalGeneration
# (vision/audio referenced but unbridged in the text-only adapter).
"gemma3n": "Gemma3nForConditionalGeneration",
# gemma4 is multimodal-only; all released checkpoints load as the full
# ForConditionalGeneration (vision/audio referenced but unbridged).
"gemma4": "Gemma4ForConditionalGeneration",
"gemma4_unified": "Gemma4UnifiedForConditionalGeneration",
"bert": "BertForMaskedLM",
"bloom": "BloomForCausalLM",
"codegen": "CodeGenForCausalLM",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@
from transformer_lens.model_bridge.supported_architectures.gemma3n import (
Gemma3nArchitectureAdapter,
)
from transformer_lens.model_bridge.supported_architectures.gemma4 import (
Gemma4ArchitectureAdapter,
)
from transformer_lens.model_bridge.supported_architectures.gpt2 import (
GPT2ArchitectureAdapter,
)
Expand Down Expand Up @@ -189,6 +192,7 @@
"Gemma3ArchitectureAdapter",
"Gemma3nArchitectureAdapter",
"Gemma3MultimodalArchitectureAdapter",
"Gemma4ArchitectureAdapter",
"GraniteArchitectureAdapter",
"GraniteMoeArchitectureAdapter",
"GraniteMoeHybridArchitectureAdapter",
Expand Down
Loading
Loading