Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
c19a062
feat: Initial Gemma4 architecture adapter (V norm, softcap, PLE/KV co…
huseyincavusbi May 11, 2026
5d5564d
feat: Register Gemma4ArchitectureAdapter in factory and __init__
huseyincavusbi May 11, 2026
b1c2a3d
feat: Add final_rms and eps_attr to Gemma4 adapter config
huseyincavusbi May 11, 2026
39565bc
fix: Use setattr for custom config fields to pass mypy
huseyincavusbi May 11, 2026
eaf190c
fix: Register Gemma4ForConditionalGeneration alias
huseyincavusbi May 11, 2026
cadfe52
fix: Dynamic text prefix for text-only vs multimodal Gemma4 variants
huseyincavusbi May 11, 2026
79d8de4
fix: Add Gemma4 to model_registry and add unit tests
huseyincavusbi May 16, 2026
35fa11c
fix: Read text_config for nested Gemma4 conditional generation attrib…
huseyincavusbi May 16, 2026
eb9f214
Remove dead v_norm weight conversion (with_scale=False has no learnab…
huseyincavusbi May 16, 2026
7fb469e
Add full Gemma4 MoE support with optional submodules for 26B-A4B
huseyincavusbi May 16, 2026
decefd8
Make k_proj, v_proj, k_norm, v_norm optional for KV-sharing layers
huseyincavusbi May 16, 2026
2700965
fix: AutoModel returns Gemma4Model directly, correct text_prefix
huseyincavusbi May 17, 2026
74b6168
fix: revert text_prefix — AutoModelForCausalLM needs model. prefix
huseyincavusbi May 17, 2026
7033b91
fix: check cfg.architecture instead of cfg.architectures for prefix d…
huseyincavusbi May 17, 2026
b54e2cd
fix: delegate to original attention on KV-sharing layers
huseyincavusbi May 17, 2026
d5ce541
fix: store computed KV in shared_kv_states for Gemma4 KV-sharing
huseyincavusbi May 17, 2026
ee60b1c
fix: add Gemma4ForConditionalGeneration to MULTIMODAL_ARCHITECTURES
huseyincavusbi May 20, 2026
6587691
fix: add use_native_generate opt-in flag for hf_generate delegation
huseyincavusbi May 29, 2026
3237784
feat: use_native_generate and prepare_loading for Gemma4 adapter
huseyincavusbi May 29, 2026
6a21267
fix: handle list eos_token_id when setting pad_token_id
huseyincavusbi May 29, 2026
0732f27
fix: apply V norm in post-reshape attention phase for Gemma4
huseyincavusbi May 30, 2026
a8d2c4e
fix: restore Gemma3nForConditionalGeneration in MULTIMODAL_ARCHITECTURES
huseyincavusbi Jun 8, 2026
a30f390
fix: remove dead eps_attr, resolve conflict marker, fix mypy
huseyincavusbi Jun 13, 2026
7eed605
feat: add multimodal vision support to Gemma4 adapter
huseyincavusbi Jun 13, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
362 changes: 362 additions & 0 deletions tests/unit/model_bridge/test_gemma4.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,362 @@
"""Unit tests for Gemma4 architecture adapter registration and configuration."""

from types import SimpleNamespace

import pytest

from transformer_lens.config.TransformerBridgeConfig import TransformerBridgeConfig

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since you began this PR, the structure of the TrasnformerBridgeConfig import path was adjusted due to a name conflict introduced in an related change refactor. Please update this to

from transformer_lens.config import TransformerBridgeConfig

from transformer_lens.factories.architecture_adapter_factory import (
SUPPORTED_ARCHITECTURES,
ArchitectureAdapterFactory,
)
from transformer_lens.model_bridge.supported_architectures.gemma4 import (
Gemma4ArchitectureAdapter,
)


def _make_text_cfg(**kwargs):
"""Create a text_config SimpleNamespace matching Gemma4TextConfig."""
defaults = dict(
hidden_size=1536,
num_attention_heads=8,
num_hidden_layers=35,
num_key_value_heads=1,
intermediate_size=6144,
vocab_size=262144,
head_dim=256,
max_position_embeddings=131072,
rms_norm_eps=1e-6,
sliding_window=512,
)
defaults.update(kwargs)
return SimpleNamespace(**defaults)


def _make_gemma4_cfg(architectures=None, text_config=None, **overrides):
"""Create a TransformerBridgeConfig for Gemma4 E2B."""
arch = architectures or ["Gemma4ForCausalLM"]
defaults = dict(
d_model=1536,
d_head=256,
n_heads=8,
n_layers=35,
n_ctx=8192,
d_vocab=262144,
n_key_value_heads=1,
d_mlp=6144,
architecture=arch[0],
)
defaults.update(overrides)
cfg = TransformerBridgeConfig(**defaults)
setattr(cfg, "architectures", arch)
if text_config is not None:
setattr(cfg, "text_config", text_config)
return cfg


class TestGemma4Registration:
"""Test that Gemma4ArchitectureAdapter is properly registered."""

def test_architecture_in_supported_architectures(self):
assert "Gemma4ForCausalLM" in SUPPORTED_ARCHITECTURES

def test_conditional_generation_registered(self):
assert "Gemma4ForConditionalGeneration" in SUPPORTED_ARCHITECTURES

def test_architecture_maps_to_correct_adapter(self):
assert SUPPORTED_ARCHITECTURES["Gemma4ForCausalLM"] is Gemma4ArchitectureAdapter

def test_factory_selects_correct_adapter(self):
cfg = _make_gemma4_cfg()
adapter = ArchitectureAdapterFactory.select_architecture_adapter(cfg)
assert isinstance(adapter, Gemma4ArchitectureAdapter)

def test_factory_selects_conditional_generation(self):
cfg = _make_gemma4_cfg(architecture="Gemma4ForConditionalGeneration")
adapter = ArchitectureAdapterFactory.select_architecture_adapter(cfg)
assert isinstance(adapter, Gemma4ArchitectureAdapter)


class TestGemma4ConfigAttributes:
"""Test Gemma4ArchitectureAdapter configuration attributes."""

@pytest.fixture
def adapter(self):
cfg = _make_gemma4_cfg(architectures=["Gemma4ForCausalLM"])
return Gemma4ArchitectureAdapter(cfg)

def test_gated_mlp(self, adapter):
assert adapter.cfg.gated_mlp is True

def test_uses_rms_norm(self, adapter):
assert adapter.cfg.uses_rms_norm is True

def test_normalization_type(self, adapter):
assert adapter.cfg.normalization_type == "RMS"

def test_final_rms(self, adapter):
assert adapter.cfg.final_rms is True

def test_rmsnorm_uses_offset(self, adapter):
assert adapter.cfg.rmsnorm_uses_offset is True

def test_positional_embedding_type(self, adapter):
assert adapter.cfg.positional_embedding_type == "rotary"

def test_attn_implementation(self, adapter):
assert adapter.cfg.attn_implementation == "eager"


class TestGemma4Softcapping:
"""Test logit and attention softcapping attribute mapping."""

def test_output_logits_soft_cap(self):
tc = _make_text_cfg(final_logit_softcapping=30.0)
cfg = _make_gemma4_cfg(text_config=tc)
adapter = Gemma4ArchitectureAdapter(cfg)
assert adapter.cfg.output_logits_soft_cap == 30.0

def test_attn_scores_soft_cap(self):
tc = _make_text_cfg(attn_logit_softcapping=50.0)
cfg = _make_gemma4_cfg(text_config=tc)
adapter = Gemma4ArchitectureAdapter(cfg)
assert adapter.cfg.attn_scores_soft_cap == 50.0

def test_no_softcapping_when_absent(self):
cfg = _make_gemma4_cfg()
adapter = Gemma4ArchitectureAdapter(cfg)
# defaults are -1.0 (unchanged by adapter when softcapping not in text config)
assert adapter.cfg.attn_scores_soft_cap == -1.0
assert adapter.cfg.output_logits_soft_cap == -1.0


class TestGemma4E2BConfig:
"""Test Gemma4 E-series specific config: PLE, KV sharing, layer_types."""

def test_hidden_size_per_layer_input(self):
tc = _make_text_cfg(hidden_size_per_layer_input=256)
cfg = _make_gemma4_cfg(text_config=tc)
adapter = Gemma4ArchitectureAdapter(cfg)
assert adapter.cfg.hidden_size_per_layer_input == 256

def test_num_kv_shared_layers(self):
tc = _make_text_cfg(num_kv_shared_layers=20)
cfg = _make_gemma4_cfg(text_config=tc)
adapter = Gemma4ArchitectureAdapter(cfg)
assert adapter.cfg.num_kv_shared_layers == 20

def test_layer_types(self):
layer_types = ["sliding_attention"] * 35
tc = _make_text_cfg(layer_types=layer_types)
cfg = _make_gemma4_cfg(text_config=tc)
adapter = Gemma4ArchitectureAdapter(cfg)
assert adapter.cfg.layer_types == layer_types

def test_ple_not_set_when_absent(self):
cfg = _make_gemma4_cfg()
adapter = Gemma4ArchitectureAdapter(cfg)
assert not hasattr(adapter.cfg, "hidden_size_per_layer_input")

def test_kv_sharing_not_set_when_absent(self):
cfg = _make_gemma4_cfg()
adapter = Gemma4ArchitectureAdapter(cfg)
assert not hasattr(adapter.cfg, "num_kv_shared_layers")


class TestGemma4MoESupport:
"""Test that MoE submodules are added when enable_moe_block=True."""

def test_moe_block_submodules_exist(self):
tc = _make_text_cfg(
enable_moe_block=True, num_experts=128, top_k_experts=8, moe_intermediate_size=704
)
cfg = _make_gemma4_cfg(text_config=tc)
adapter = Gemma4ArchitectureAdapter(cfg)
assert adapter.enable_moe_block is True
block = adapter.component_mapping["blocks"]
assert "router" in block.submodules
assert "experts" in block.submodules
assert "ln2_post_moe_1" in block.submodules
assert "ln2_pre_moe_2" in block.submodules
assert "ln2_post_moe_2" in block.submodules

def test_moe_block_names(self):
tc = _make_text_cfg(
enable_moe_block=True, num_experts=128, top_k_experts=8, moe_intermediate_size=704
)
cfg = _make_gemma4_cfg(text_config=tc)
adapter = Gemma4ArchitectureAdapter(cfg)
block = adapter.component_mapping["blocks"]
assert block.submodules["router"].name == "router"
assert block.submodules["experts"].name == "experts"
assert block.submodules["ln2_post_moe_1"].name == "post_feedforward_layernorm_1"
assert block.submodules["ln2_pre_moe_2"].name == "pre_feedforward_layernorm_2"
assert block.submodules["ln2_post_moe_2"].name == "post_feedforward_layernorm_2"

def test_dense_has_no_moe_submodules(self):
cfg = _make_gemma4_cfg()
adapter = Gemma4ArchitectureAdapter(cfg)
block = adapter.component_mapping["blocks"]
assert "router" not in block.submodules
assert "experts" not in block.submodules


class TestGemma4TextPrefix:
"""Test text prefix detection for text-only vs multimodal."""

def test_text_prefix_causal(self):
cfg = _make_gemma4_cfg(architectures=["Gemma4ForCausalLM"])
adapter = Gemma4ArchitectureAdapter(cfg)
assert adapter.text_prefix == "model"

def test_text_prefix_conditional_generation(self):
cfg = _make_gemma4_cfg(architectures=["Gemma4ForConditionalGeneration"])
adapter = Gemma4ArchitectureAdapter(cfg)
assert adapter.text_prefix == "model.language_model"


class TestGemma4ComponentMapping:
"""Test Gemma4ArchitectureAdapter component mapping."""

@pytest.fixture
def adapter(self):
cfg = _make_gemma4_cfg(architectures=["Gemma4ForCausalLM"])
return Gemma4ArchitectureAdapter(cfg)

def test_has_embed(self, adapter):
assert "embed" in adapter.component_mapping

def test_has_rotary_emb(self, adapter):
assert "rotary_emb" in adapter.component_mapping

def test_has_blocks(self, adapter):
assert "blocks" in adapter.component_mapping

def test_has_ln_final(self, adapter):
assert "ln_final" in adapter.component_mapping

def test_has_unembed(self, adapter):
assert "unembed" in adapter.component_mapping

def test_embed_path_causal(self, adapter):
assert adapter.component_mapping["embed"].name == "model.embed_tokens"
assert adapter.component_mapping["rotary_emb"].name == "model.rotary_emb"
assert adapter.component_mapping["blocks"].name == "model.layers"
assert adapter.component_mapping["ln_final"].name == "model.norm"

def test_unembed_path(self, adapter):
assert adapter.component_mapping["unembed"].name == "lm_head"

def test_block_submodules(self, adapter):
block = adapter.component_mapping["blocks"]
assert "ln1" in block.submodules
assert "ln1_post" in block.submodules
assert "ln2" in block.submodules
assert "ln2_post" in block.submodules
assert "attn" in block.submodules
assert "mlp" in block.submodules

def test_ln_names(self, adapter):
block = adapter.component_mapping["blocks"]
assert block.submodules["ln1"].name == "input_layernorm"
assert block.submodules["ln1_post"].name == "post_attention_layernorm"
assert block.submodules["ln2"].name == "pre_feedforward_layernorm"
assert block.submodules["ln2_post"].name == "post_feedforward_layernorm"

def test_attn_submodules(self, adapter):
attn = adapter.component_mapping["blocks"].submodules["attn"]
assert attn.name == "self_attn"
assert "q" in attn.submodules
assert "k" in attn.submodules
assert "v" in attn.submodules
assert "o" in attn.submodules
assert "q_norm" in attn.submodules
assert "k_norm" in attn.submodules
assert "v_norm" in attn.submodules

def test_attn_linear_names(self, adapter):
attn = adapter.component_mapping["blocks"].submodules["attn"]
assert attn.submodules["q"].name == "q_proj"
assert attn.submodules["k"].name == "k_proj"
assert attn.submodules["v"].name == "v_proj"
assert attn.submodules["o"].name == "o_proj"

def test_attn_norm_names(self, adapter):
attn = adapter.component_mapping["blocks"].submodules["attn"]
assert attn.submodules["q_norm"].name == "q_norm"
assert attn.submodules["k_norm"].name == "k_norm"
assert attn.submodules["v_norm"].name == "v_norm"

def test_mlp_submodules(self, adapter):
mlp = adapter.component_mapping["blocks"].submodules["mlp"]
assert mlp.name == "mlp"
assert "gate" in mlp.submodules
assert "in" in mlp.submodules
assert "out" in mlp.submodules

def test_mlp_linear_names(self, adapter):
mlp = adapter.component_mapping["blocks"].submodules["mlp"]
assert mlp.submodules["gate"].name == "gate_proj"
assert mlp.submodules["in"].name == "up_proj"
assert mlp.submodules["out"].name == "down_proj"


class TestGemma4ComponentMappingMultimodal:
"""Test component mapping paths for multimodal (ConditionalGeneration) variant."""

@pytest.fixture
def adapter(self):
cfg = _make_gemma4_cfg(
architecture="Gemma4ForConditionalGeneration",
architectures=["Gemma4ForConditionalGeneration"],
)
return Gemma4ArchitectureAdapter(cfg)

def test_embed_path_conditional(self, adapter):
assert adapter.component_mapping["embed"].name == "model.language_model.embed_tokens"
assert adapter.component_mapping["rotary_emb"].name == "model.language_model.rotary_emb"
assert adapter.component_mapping["blocks"].name == "model.language_model.layers"
assert adapter.component_mapping["ln_final"].name == "model.language_model.norm"

def test_is_multimodal_true_for_conditional(self, adapter):
assert adapter.cfg.is_multimodal is True

def test_has_vision_components(self, adapter):
assert "vision_encoder" in adapter.component_mapping
assert "vision_projector" in adapter.component_mapping
assert adapter.component_mapping["vision_encoder"].name == "model.vision_tower"
assert adapter.component_mapping["vision_projector"].name == "model.embed_vision"


class TestGemma4WeightConversions:
"""Test Gemma4ArchitectureAdapter weight processing conversions exist."""

@pytest.fixture
def adapter(self):
cfg = _make_gemma4_cfg(architectures=["Gemma4ForCausalLM"])
return Gemma4ArchitectureAdapter(cfg)

def test_qkv_weight_conversions(self, adapter):
assert "blocks.{i}.attn.q.weight" in adapter.weight_processing_conversions
assert "blocks.{i}.attn.k.weight" in adapter.weight_processing_conversions
assert "blocks.{i}.attn.v.weight" in adapter.weight_processing_conversions
assert "blocks.{i}.attn.o.weight" in adapter.weight_processing_conversions

def test_norm_weight_conversions(self, adapter):
assert "blocks.{i}.ln1.weight" in adapter.weight_processing_conversions
assert "blocks.{i}.ln1_post.weight" in adapter.weight_processing_conversions
assert "blocks.{i}.ln2.weight" in adapter.weight_processing_conversions
assert "blocks.{i}.ln2_post.weight" in adapter.weight_processing_conversions
assert "ln_final.weight" in adapter.weight_processing_conversions

def test_attn_norm_weight_conversions(self, adapter):
assert "blocks.{i}.attn.q_norm.weight" in adapter.weight_processing_conversions
assert "blocks.{i}.attn.k_norm.weight" in adapter.weight_processing_conversions

def test_mlp_weight_conversions(self, adapter):
assert "blocks.{i}.mlp.gate.weight" in adapter.weight_processing_conversions
assert "blocks.{i}.mlp.in.weight" in adapter.weight_processing_conversions
assert "blocks.{i}.mlp.out.weight" in adapter.weight_processing_conversions

def test_unembed_weight_conversion(self, adapter):
assert "unembed.weight" in adapter.weight_processing_conversions
3 changes: 2 additions & 1 deletion transformer_lens/benchmarks/main_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -941,7 +941,8 @@ def cleanup_model(model, model_name_str: str):
model_name, trust_remote_code=trust_remote_code, token=_hf_token()
)
if not hasattr(hf_config, "pad_token_id") or "pad_token_id" not in hf_config.__dict__:
hf_config.pad_token_id = getattr(hf_config, "eos_token_id", None)
eos = getattr(hf_config, "eos_token_id", None)
hf_config.pad_token_id = eos[0] if isinstance(eos, (list, tuple)) else eos
hf_kwargs["config"] = hf_config
if trust_remote_code:
hf_kwargs["trust_remote_code"] = True
Expand Down
3 changes: 3 additions & 0 deletions transformer_lens/factories/architecture_adapter_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
Gemma3ArchitectureAdapter,
Gemma3MultimodalArchitectureAdapter,
Gemma3nArchitectureAdapter,
Gemma4ArchitectureAdapter,
GPT2ArchitectureAdapter,
Gpt2LmHeadCustomArchitectureAdapter,
GPTBigCodeArchitectureAdapter,
Expand Down Expand Up @@ -86,6 +87,8 @@
"Gemma3ForCausalLM": Gemma3ArchitectureAdapter,
"Gemma3ForConditionalGeneration": Gemma3MultimodalArchitectureAdapter,
"Gemma3nForConditionalGeneration": Gemma3nArchitectureAdapter,
"Gemma4ForCausalLM": Gemma4ArchitectureAdapter,
"Gemma4ForConditionalGeneration": Gemma4ArchitectureAdapter,
"GraniteForCausalLM": GraniteArchitectureAdapter,
"GraniteMoeForCausalLM": GraniteMoeArchitectureAdapter,
"GraniteMoeHybridForCausalLM": GraniteMoeHybridArchitectureAdapter,
Expand Down
Loading
Loading