From 3be749cfd964b2dad44cbefba25f39a4337babf6 Mon Sep 17 00:00:00 2001 From: punishell Date: Wed, 10 Jun 2026 14:31:32 +0200 Subject: [PATCH] Add Gemma 4 architecture support to TransformerBridge Adds a text-only adapter covering both Gemma4ForConditionalGeneration (E2B/E4B/31B/26B-A4B) and Gemma4UnifiedForConditionalGeneration (12B), addressing #1297. Gemma 4 layers are heterogeneous: KV-shared layers drop k/v projections, K==V layers drop v_proj, and per-layer-embedding / MoE submodules appear only on some variants -- all mapped optional and delegated to HF. Unlike Gemma 1-3, Gemma4RMSNorm has no (1+weight) offset. Adds DelegatedAttentionBlockBridge (drops the split-QKV fork aliases, as MLABlockBridge does) so hook-alias resolution stays clean when attention is delegated wholesale to HF. google/gemma-4-E2B-it passes verification (P1 100%, P2 100%, P4 94.7%). - New adapter + four-place registration + gemma4/gemma4_unified model_type mappings - 10 checkpoints added to the model registry - Unit + integration tests (logit parity vs HF on all three structural variants) --- .../model_bridge/test_gemma4_bridge.py | 80 ++++++++++ .../test_gemma4_adapter.py | 124 +++++++++++++++ .../factories/architecture_adapter_factory.py | 6 + .../generalized_components/__init__.py | 2 + .../generalized_components/block.py | 31 ++++ .../model_bridge/sources/transformers.py | 4 + .../supported_architectures/__init__.py | 4 + .../supported_architectures/gemma4.py | 139 +++++++++++++++++ .../tools/model_registry/__init__.py | 4 + .../model_registry/data/supported_models.json | 146 +++++++++++++++++- .../data/verification_history.json | 12 +- .../tools/model_registry/generate_report.py | 2 + 12 files changed, 550 insertions(+), 4 deletions(-) create mode 100644 tests/integration/model_bridge/test_gemma4_bridge.py create mode 100644 tests/unit/model_bridge/supported_architectures/test_gemma4_adapter.py create mode 100644 transformer_lens/model_bridge/supported_architectures/gemma4.py diff --git a/tests/integration/model_bridge/test_gemma4_bridge.py b/tests/integration/model_bridge/test_gemma4_bridge.py new file mode 100644 index 000000000..a8e67dc28 --- /dev/null +++ b/tests/integration/model_bridge/test_gemma4_bridge.py @@ -0,0 +1,80 @@ +"""Integration tests for the Gemma 4 text-only TransformerBridge. + +Uses tiny random-init `Gemma4ForConditionalGeneration` fixtures (4 layers, d_model 8) so CI +stays light while still exercising the real per-layer heterogeneity across the family: + +- ``tiny-random/gemma-4-e`` — Per-Layer Embeddings + KV-cache sharing (E2B/E4B shape) +- ``tiny-random/gemma-4-dense`` — K==V attention on global layers, no v_proj (31B shape) +- ``tiny-random/gemma-4-moe`` — router + batched experts beside the dense MLP (26B-A4B shape) + +Confirms logit parity vs HF (the block bridge defers all math to HF) and that hooks fire on +the conventional single-stream residual. +""" + +import pytest +import torch + +from transformer_lens.model_bridge import TransformerBridge + +MODEL_NAMES = { + "ple_kv_shared": "tiny-random/gemma-4-e", + "dense_k_eq_v": "tiny-random/gemma-4-dense", + "moe": "tiny-random/gemma-4-moe", +} +IDS = torch.tensor([[1, 2, 3, 4, 5]]) + + +@pytest.fixture(scope="module", params=list(MODEL_NAMES), ids=list(MODEL_NAMES)) +def bridge(request): + return TransformerBridge.boot_transformers( + MODEL_NAMES[request.param], device="cpu", dtype=torch.float32 + ) + + +def test_text_only_logit_parity_vs_hf(bridge): + from transformers import AutoModelForCausalLM + + hf = AutoModelForCausalLM.from_pretrained( + bridge.cfg.model_name, torch_dtype=torch.float32, attn_implementation="eager" + ).eval() + with torch.no_grad(): + ref = hf(IDS).logits + out = bridge.forward(IDS, return_type="logits") + assert out.shape == ref.shape + # PLE / KV-sharing / K==V / MoE all run inside HF — the bridge is a pass-through. + assert torch.max(torch.abs(out - ref)).item() < 1e-3 + + +def test_config_from_text_config(bridge): + # Text dims resolve from the nested text_config of the multimodal model. + assert bridge.cfg.n_layers == 4 + assert getattr(bridge.cfg, "is_multimodal", False) is False + + +def test_resid_hooks_fire_with_conventional_shape(bridge): + """The residual stream is a single conventional (batch, seq, d_model) tensor.""" + captured = {} + + def cap(tensor, hook): + captured[hook.name] = tensor.detach() + return tensor + + names = [ + n + for n in bridge.hook_dict + if n.endswith("blocks.0.hook_resid_pre") or n.endswith("blocks.0.hook_resid_post") + ] + assert names, "no residual hooks registered" + with torch.no_grad(): + bridge.run_with_hooks(IDS, fwd_hooks=[(n, cap) for n in names]) + + assert captured, "residual hooks did not fire" + for tensor in captured.values(): + assert tensor.shape == (IDS.shape[0], IDS.shape[1], bridge.cfg.d_model) + + +def test_run_with_cache_text_only(bridge): + with torch.no_grad(): + logits, cache = bridge.run_with_cache(IDS) + assert torch.isfinite(logits).all() + assert len(cache) > 0 diff --git a/tests/unit/model_bridge/supported_architectures/test_gemma4_adapter.py b/tests/unit/model_bridge/supported_architectures/test_gemma4_adapter.py new file mode 100644 index 000000000..b0387f12e --- /dev/null +++ b/tests/unit/model_bridge/supported_architectures/test_gemma4_adapter.py @@ -0,0 +1,124 @@ +"""Unit tests for the Gemma 4 text-only architecture adapter.""" + +from transformer_lens.config.transformer_bridge_config import TransformerBridgeConfig +from transformer_lens.factories.architecture_adapter_factory import ( + ArchitectureAdapterFactory, +) +from transformer_lens.model_bridge.generalized_components import ( + DelegatedAttentionBlockBridge, + EmbeddingBridge, + LinearBridge, + RotaryEmbeddingBridge, + UnembeddingBridge, +) + +ARCH = "Gemma4ForConditionalGeneration" + + +def _adapter(): + # Dimensions follow google/gemma-4-E2B's text_config. + cfg = TransformerBridgeConfig( + d_model=1536, + d_head=256, + n_heads=8, + n_layers=35, + n_ctx=131072, + d_vocab=262144, + n_key_value_heads=1, + architecture=ARCH, + ) + return ArchitectureAdapterFactory.select_architecture_adapter(cfg) + + +def test_config_flags(): + a = _adapter() + # Text-only; PLE / layer_scalar / MoE residual topology is not fold-safe. + assert a.cfg.is_multimodal is False + assert a.supports_fold_ln is False + assert a.weight_processing_conversions == {} + assert a.cfg.normalization_type == "RMS" + # Gemma4RMSNorm scales by weight directly — no (1 + weight) offset, unlike Gemma 1-3. + assert a.cfg.rmsnorm_uses_offset is False + assert a.cfg.positional_embedding_type == "rotary" + assert a.applicable_phases == [1, 2, 4] + + +def test_text_path_nested_under_language_model(): + m = _adapter().component_mapping + assert m["embed"].name == "model.language_model.embed_tokens" + assert m["rotary_emb"].name == "model.language_model.rotary_emb" + assert m["blocks"].name == "model.language_model.layers" + assert m["ln_final"].name == "model.language_model.norm" + assert m["unembed"].name == "lm_head" + assert isinstance(m["embed"], EmbeddingBridge) + assert isinstance(m["rotary_emb"], RotaryEmbeddingBridge) + assert isinstance(m["blocks"], DelegatedAttentionBlockBridge) + assert isinstance(m["unembed"], UnembeddingBridge) + # Vision/audio towers are referenced-but-unbridged. + assert "vision_encoder" not in m and "audio_encoder" not in m + + +def test_block_decomposition(): + blocks = _adapter().component_mapping["blocks"] + for name in ("attn", "mlp"): + assert name in blocks.submodules + # Sandwich norms (same shape as Gemma 2/3) under canonical keys. + for norm in ("ln1", "ln1_post", "ln2", "ln2_post"): + assert norm in blocks.submodules + assert blocks.submodules[norm].optional is False + + +def test_split_qkv_fork_aliases_absent(): + """Attention is delegated wholesale to HF; per-layer structure is heterogeneous + (KV-shared layers have no k/v projections), so the split-qkv fork aliases + do not apply.""" + blocks = _adapter().component_mapping["blocks"] + for alias in ("hook_q_input", "hook_k_input", "hook_v_input", "hook_attn_in"): + assert alias not in blocks.hook_aliases + # The single-stream residual aliases remain, redirected through the sandwich norms. + assert blocks.hook_aliases["hook_resid_mid"] == "ln2.hook_in" + assert blocks.hook_aliases["hook_attn_out"] == "ln1_post.hook_out" + assert blocks.hook_aliases["hook_mlp_out"] == "ln2_post.hook_out" + + +def test_kv_shared_and_k_eq_v_submodules_are_optional(): + """KV-shared layers (E2B/E4B) drop k/v proj + norms; K==V global-attention + layers (31B / 26B-A4B) drop v_proj.""" + attn = _adapter().component_mapping["blocks"].submodules["attn"] + assert attn.submodules["q"].optional is False + assert attn.submodules["o"].optional is False + assert attn.submodules["q_norm"].optional is False + for shared in ("k", "v", "k_norm", "v_norm"): + assert attn.submodules[shared].optional is True + assert isinstance(attn.submodules["q"], LinearBridge) + + +def test_per_layer_embedding_submodules_are_optional(): + """PLE modules exist only when hidden_size_per_layer_input > 0 (E2B/E4B).""" + blocks = _adapter().component_mapping["blocks"] + for name in ( + "per_layer_input_gate", + "per_layer_projection", + "post_per_layer_input_norm", + ): + assert blocks.submodules[name].optional is True + + +def test_moe_submodules_are_optional(): + """MoE branch exists only when enable_moe_block (26B-A4B).""" + blocks = _adapter().component_mapping["blocks"] + for name in ( + "router", + "experts", + "pre_feedforward_layernorm_2", + "post_feedforward_layernorm_1", + "post_feedforward_layernorm_2", + ): + assert blocks.submodules[name].optional is True + + +def test_gated_mlp_decomposition(): + mlp = _adapter().component_mapping["blocks"].submodules["mlp"] + assert mlp.submodules["gate"].name == "gate_proj" + assert mlp.submodules["in"].name == "up_proj" + assert mlp.submodules["out"].name == "down_proj" diff --git a/transformer_lens/factories/architecture_adapter_factory.py b/transformer_lens/factories/architecture_adapter_factory.py index 49dd134f7..50b402107 100644 --- a/transformer_lens/factories/architecture_adapter_factory.py +++ b/transformer_lens/factories/architecture_adapter_factory.py @@ -23,6 +23,7 @@ Gemma3ArchitectureAdapter, Gemma3MultimodalArchitectureAdapter, Gemma3nArchitectureAdapter, + Gemma4ArchitectureAdapter, GPT2ArchitectureAdapter, Gpt2LmHeadCustomArchitectureAdapter, GPTBigCodeArchitectureAdapter, @@ -86,6 +87,11 @@ "Gemma3ForCausalLM": Gemma3ArchitectureAdapter, "Gemma3ForConditionalGeneration": Gemma3MultimodalArchitectureAdapter, "Gemma3nForConditionalGeneration": Gemma3nArchitectureAdapter, + "Gemma4ForConditionalGeneration": Gemma4ArchitectureAdapter, + # The unified (encoder-free) 12B variant's text decoder is a strict structural + # subset of gemma4 (no PLE, no MoE — both optional in the adapter), with the + # same module paths. Requires transformers >= 5.10 to load. + "Gemma4UnifiedForConditionalGeneration": Gemma4ArchitectureAdapter, "GraniteForCausalLM": GraniteArchitectureAdapter, "GraniteMoeForCausalLM": GraniteMoeArchitectureAdapter, "GraniteMoeHybridForCausalLM": GraniteMoeHybridArchitectureAdapter, diff --git a/transformer_lens/model_bridge/generalized_components/__init__.py b/transformer_lens/model_bridge/generalized_components/__init__.py index 50d139f16..31d27dd5b 100644 --- a/transformer_lens/model_bridge/generalized_components/__init__.py +++ b/transformer_lens/model_bridge/generalized_components/__init__.py @@ -7,6 +7,7 @@ ) from transformer_lens.model_bridge.generalized_components.block import ( BlockBridge, + DelegatedAttentionBlockBridge, MLABlockBridge, ParallelBlockBridge, ) @@ -108,6 +109,7 @@ "AttentionBridge", "AudioFeatureExtractorBridge", "BlockBridge", + "DelegatedAttentionBlockBridge", "MLABlockBridge", "ParallelBlockBridge", "BloomBlockBridge", diff --git a/transformer_lens/model_bridge/generalized_components/block.py b/transformer_lens/model_bridge/generalized_components/block.py index 506107781..37851b16a 100644 --- a/transformer_lens/model_bridge/generalized_components/block.py +++ b/transformer_lens/model_bridge/generalized_components/block.py @@ -405,3 +405,34 @@ def __init__( if self.hook_aliases is BlockBridge.hook_aliases: self.hook_aliases = dict(self.hook_aliases) self.hook_aliases.pop("hook_resid_mid", None) + + +class DelegatedAttentionBlockBridge(BlockBridge): + """Block whose attention is delegated wholesale to HF (no split-qkv fork). + + For architectures with heterogeneous per-layer attention structure — e.g. + Gemma 4, where KV-shared layers have no ``k_proj``/``v_proj`` at all and + K==V layers have no ``v_proj`` — there is no uniform HookPoint that + represents "input that becomes Q/K/V", so the block-level ``hook_q_input``/ + ``hook_k_input``/``hook_v_input``/``hook_attn_in`` aliases do not apply. + Type-level distinction means a reader of the adapter sees + ``DelegatedAttentionBlockBridge`` and knows those hooks are absent. + """ + + def __init__( + self, + name: str, + config: Optional[Any] = None, + submodules: Optional[Dict[str, GeneralizedComponent]] = None, + hook_alias_overrides: Optional[Dict[str, str]] = None, + ): + super().__init__( + name, + config=config, + submodules=submodules, + hook_alias_overrides=hook_alias_overrides, + ) + if self.hook_aliases is BlockBridge.hook_aliases: + self.hook_aliases = dict(self.hook_aliases) + for alias in ("hook_q_input", "hook_k_input", "hook_v_input", "hook_attn_in"): + self.hook_aliases.pop(alias, None) diff --git a/transformer_lens/model_bridge/sources/transformers.py b/transformer_lens/model_bridge/sources/transformers.py index a26e00004..3eafd3dc4 100644 --- a/transformer_lens/model_bridge/sources/transformers.py +++ b/transformer_lens/model_bridge/sources/transformers.py @@ -221,6 +221,10 @@ def determine_architecture_from_hf_config(hf_config): # gemma3n is tri-modal; the text path loads as the full ForConditionalGeneration # (vision/audio referenced but unbridged in the text-only adapter). "gemma3n": "Gemma3nForConditionalGeneration", + # gemma4 is multimodal-only; all released checkpoints load as the full + # ForConditionalGeneration (vision/audio referenced but unbridged). + "gemma4": "Gemma4ForConditionalGeneration", + "gemma4_unified": "Gemma4UnifiedForConditionalGeneration", "bert": "BertForMaskedLM", "bloom": "BloomForCausalLM", "codegen": "CodeGenForCausalLM", diff --git a/transformer_lens/model_bridge/supported_architectures/__init__.py b/transformer_lens/model_bridge/supported_architectures/__init__.py index 84e4584af..acc8be08b 100644 --- a/transformer_lens/model_bridge/supported_architectures/__init__.py +++ b/transformer_lens/model_bridge/supported_architectures/__init__.py @@ -42,6 +42,9 @@ from transformer_lens.model_bridge.supported_architectures.gemma3n import ( Gemma3nArchitectureAdapter, ) +from transformer_lens.model_bridge.supported_architectures.gemma4 import ( + Gemma4ArchitectureAdapter, +) from transformer_lens.model_bridge.supported_architectures.gpt2 import ( GPT2ArchitectureAdapter, ) @@ -189,6 +192,7 @@ "Gemma3ArchitectureAdapter", "Gemma3nArchitectureAdapter", "Gemma3MultimodalArchitectureAdapter", + "Gemma4ArchitectureAdapter", "GraniteArchitectureAdapter", "GraniteMoeArchitectureAdapter", "GraniteMoeHybridArchitectureAdapter", diff --git a/transformer_lens/model_bridge/supported_architectures/gemma4.py b/transformer_lens/model_bridge/supported_architectures/gemma4.py new file mode 100644 index 000000000..931bd2b61 --- /dev/null +++ b/transformer_lens/model_bridge/supported_architectures/gemma4.py @@ -0,0 +1,139 @@ +"""Gemma 4 text-only architecture adapter. + +Bridges the text path of the multimodal ``Gemma4ForConditionalGeneration`` +(``model.language_model`` + ``lm_head``); the vision/audio towers stay referenced but +unbridged. All released Gemma 4 checkpoints (E2B / E4B / 31B / 26B-A4B) ship as +``Gemma4ForConditionalGeneration``, so there is no separate text-only entry point. + +The same adapter also covers ``Gemma4UnifiedForConditionalGeneration`` (the +encoder-free 12B variant, transformers >= 5.10): its text decoder is a strict +structural subset — same module paths, no PLE and no MoE, both optional here. + +Per-layer structure is heterogeneous across the family, so all math is deferred to HF +and submodules are decomposed only for hooks (parity-safe delegation): + +- **KV sharing** (E2B/E4B): the last ``num_kv_shared_layers`` layers reuse earlier KV + states and drop their own ``k_proj`` / ``v_proj`` / ``k_norm`` / ``v_norm``. +- **K==V attention** (31B / 26B-A4B): global-attention layers share key and value + weights (``attention_k_eq_v``) and have no ``v_proj``. +- **Per-Layer Embeddings** (E2B/E4B): each layer mixes in a per-layer input via + ``per_layer_input_gate`` / ``per_layer_projection`` / ``post_per_layer_input_norm``. +- **MoE** (26B-A4B): layers add a ``router`` + batched ``experts`` block in parallel + with the dense MLP, sandwiched by three extra norms. + +Unlike Gemma 1-3, ``Gemma4RMSNorm`` multiplies by ``weight`` directly — there is no +``(1.0 + weight)`` offset. +""" + +from typing import Any + +from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter +from transformer_lens.model_bridge.generalized_components import ( + DelegatedAttentionBlockBridge, + EmbeddingBridge, + LinearBridge, + RotaryEmbeddingBridge, + UnembeddingBridge, +) +from transformer_lens.model_bridge.generalized_components.base import ( + GeneralizedComponent, +) + + +class Gemma4ArchitectureAdapter(ArchitectureAdapter): + """Text-only adapter for Gemma 4 (`Gemma4ForConditionalGeneration`).""" + + # Phase 3 (processed/compatibility mode) folds LN into a single residual stream, + # which the PLE residual mix, per-layer `layer_scalar` buffers, and the MoE branch + # can't represent. Phases 1 (HF parity), 2 (hooks), and 4 (text quality) apply. + applicable_phases: list[int] = [1, 2, 4] + + def __init__(self, cfg: Any) -> None: + super().__init__(cfg) + + self.cfg.is_multimodal = False + self.cfg.gated_mlp = True + self.cfg.uses_rms_norm = True + self.cfg.normalization_type = "RMS" + # Gemma4RMSNorm scales by weight directly — no (1 + weight) offset, unlike Gemma 1-3. + self.cfg.rmsnorm_uses_offset = False + self.cfg.positional_embedding_type = "rotary" + self.cfg.attn_implementation = "eager" + # PLE / layer_scalar / MoE residual topology isn't fold-safe. + self.supports_fold_ln = False + self.weight_processing_conversions: dict = {} + + self.component_mapping = { + "embed": EmbeddingBridge(name="model.language_model.embed_tokens"), + # Single rotary module serving both layer types (full / sliding) via a + # per-layer-type forward kwarg, with separate rope parameters per type. + "rotary_emb": RotaryEmbeddingBridge(name="model.language_model.rotary_emb"), + "blocks": DelegatedAttentionBlockBridge( + name="model.language_model.layers", + submodules={ + # Sandwich norms: ln1/ln1_post around attention, ln2/ln2_post + # around the MLP (same shape as Gemma 2/3). + "ln1": GeneralizedComponent(name="input_layernorm"), + "ln1_post": GeneralizedComponent(name="post_attention_layernorm"), + "ln2": GeneralizedComponent(name="pre_feedforward_layernorm"), + "ln2_post": GeneralizedComponent(name="post_feedforward_layernorm"), + # PLE residual mix — present only when hidden_size_per_layer_input > 0 + # (E2B/E4B; absent on 31B and 26B-A4B). + "per_layer_input_gate": GeneralizedComponent( + name="per_layer_input_gate", optional=True + ), + "per_layer_projection": GeneralizedComponent( + name="per_layer_projection", optional=True + ), + "post_per_layer_input_norm": GeneralizedComponent( + name="post_per_layer_input_norm", optional=True + ), + # MoE branch — present only when enable_moe_block (26B-A4B). + "router": GeneralizedComponent(name="router", optional=True), + "experts": GeneralizedComponent(name="experts", optional=True), + "pre_feedforward_layernorm_2": GeneralizedComponent( + name="pre_feedforward_layernorm_2", optional=True + ), + "post_feedforward_layernorm_1": GeneralizedComponent( + name="post_feedforward_layernorm_1", optional=True + ), + "post_feedforward_layernorm_2": GeneralizedComponent( + name="post_feedforward_layernorm_2", optional=True + ), + "attn": GeneralizedComponent( + name="self_attn", + submodules={ + "q": LinearBridge(name="q_proj"), + # KV-shared layers (E2B/E4B) drop k/v projections and norms; + # K==V layers (31B / 26B-A4B global attention) drop v_proj. + "k": LinearBridge(name="k_proj", optional=True), + "v": LinearBridge(name="v_proj", optional=True), + "o": LinearBridge(name="o_proj"), + "q_norm": GeneralizedComponent(name="q_norm"), + "k_norm": GeneralizedComponent(name="k_norm", optional=True), + "v_norm": GeneralizedComponent(name="v_norm", optional=True), + }, + ), + "mlp": GeneralizedComponent( + name="mlp", + submodules={ + "gate": LinearBridge(name="gate_proj"), + "in": LinearBridge(name="up_proj"), + "out": LinearBridge(name="down_proj"), + }, + ), + }, + ), + "ln_final": GeneralizedComponent(name="model.language_model.norm"), + "unembed": UnembeddingBridge(name="lm_head"), + } + + def setup_component_testing(self, hf_model: Any, bridge_model: Any = None) -> None: + """Force eager attention so bridge and HF match (sliding/full layer mix).""" + if hasattr(hf_model, "config") and hasattr(hf_model.config, "_attn_implementation"): + hf_model.config._attn_implementation = "eager" + language_model = getattr(getattr(hf_model, "model", None), "language_model", None) + if language_model is not None and hasattr(language_model, "layers"): + for layer in language_model.layers: + if hasattr(layer, "self_attn") and hasattr(layer.self_attn, "config"): + layer.self_attn.config._attn_implementation = "eager" diff --git a/transformer_lens/tools/model_registry/__init__.py b/transformer_lens/tools/model_registry/__init__.py index 769b9b0d1..65d24332e 100644 --- a/transformer_lens/tools/model_registry/__init__.py +++ b/transformer_lens/tools/model_registry/__init__.py @@ -59,6 +59,8 @@ "Gemma3ForCausalLM", "Gemma3ForConditionalGeneration", "Gemma3nForConditionalGeneration", + "Gemma4ForConditionalGeneration", + "Gemma4UnifiedForConditionalGeneration", "GraniteForCausalLM", "GraniteMoeForCausalLM", "GraniteMoeHybridForCausalLM", @@ -118,6 +120,8 @@ "Gemma3ForCausalLM": ["google"], "Gemma3ForConditionalGeneration": ["google"], "Gemma3nForConditionalGeneration": ["google"], + "Gemma4ForConditionalGeneration": ["google"], + "Gemma4UnifiedForConditionalGeneration": ["google"], "GemmaForCausalLM": ["google"], "GPT2LMHeadModel": ["openai-community", "stanford-crfm", "Writer"], "GPTBigCodeForCausalLM": ["bigcode"], diff --git a/transformer_lens/tools/model_registry/data/supported_models.json b/transformer_lens/tools/model_registry/data/supported_models.json index d668e14e8..15f9fb860 100644 --- a/transformer_lens/tools/model_registry/data/supported_models.json +++ b/transformer_lens/tools/model_registry/data/supported_models.json @@ -6,9 +6,9 @@ "min_downloads": 500, "scan_duration_seconds": 8.0 }, - "total_architectures": 55, - "total_models": 12112, - "total_verified": 743, + "total_architectures": 57, + "total_models": 12122, + "total_verified": 744, "models": [ { "architecture_id": "MistralForCausalLM", @@ -168137,6 +168137,146 @@ "phase4_score": null, "phase7_score": null, "phase8_score": null + }, + { + "architecture_id": "Gemma4ForConditionalGeneration", + "model_id": "google/gemma-4-E2B", + "status": 0, + "verified_date": null, + "metadata": null, + "note": null, + "phase1_score": null, + "phase2_score": null, + "phase3_score": null, + "phase4_score": null, + "phase7_score": null, + "phase8_score": null + }, + { + "architecture_id": "Gemma4ForConditionalGeneration", + "model_id": "google/gemma-4-E2B-it", + "status": 1, + "verified_date": "2026-06-10", + "metadata": null, + "note": "Full verification completed", + "phase1_score": 100.0, + "phase2_score": 100.0, + "phase3_score": null, + "phase4_score": 94.7, + "phase7_score": null, + "phase8_score": null + }, + { + "architecture_id": "Gemma4ForConditionalGeneration", + "model_id": "google/gemma-4-E4B", + "status": 0, + "verified_date": null, + "metadata": null, + "note": null, + "phase1_score": null, + "phase2_score": null, + "phase3_score": null, + "phase4_score": null, + "phase7_score": null, + "phase8_score": null + }, + { + "architecture_id": "Gemma4ForConditionalGeneration", + "model_id": "google/gemma-4-E4B-it", + "status": 0, + "verified_date": null, + "metadata": null, + "note": null, + "phase1_score": null, + "phase2_score": null, + "phase3_score": null, + "phase4_score": null, + "phase7_score": null, + "phase8_score": null + }, + { + "architecture_id": "Gemma4ForConditionalGeneration", + "model_id": "google/gemma-4-31B", + "status": 0, + "verified_date": null, + "metadata": null, + "note": null, + "phase1_score": null, + "phase2_score": null, + "phase3_score": null, + "phase4_score": null, + "phase7_score": null, + "phase8_score": null + }, + { + "architecture_id": "Gemma4ForConditionalGeneration", + "model_id": "google/gemma-4-31B-it", + "status": 0, + "verified_date": null, + "metadata": null, + "note": null, + "phase1_score": null, + "phase2_score": null, + "phase3_score": null, + "phase4_score": null, + "phase7_score": null, + "phase8_score": null + }, + { + "architecture_id": "Gemma4ForConditionalGeneration", + "model_id": "google/gemma-4-26B-A4B", + "status": 0, + "verified_date": null, + "metadata": null, + "note": null, + "phase1_score": null, + "phase2_score": null, + "phase3_score": null, + "phase4_score": null, + "phase7_score": null, + "phase8_score": null + }, + { + "architecture_id": "Gemma4ForConditionalGeneration", + "model_id": "google/gemma-4-26B-A4B-it", + "status": 0, + "verified_date": null, + "metadata": null, + "note": null, + "phase1_score": null, + "phase2_score": null, + "phase3_score": null, + "phase4_score": null, + "phase7_score": null, + "phase8_score": null + }, + { + "architecture_id": "Gemma4UnifiedForConditionalGeneration", + "model_id": "google/gemma-4-12B", + "status": 0, + "verified_date": null, + "metadata": null, + "note": null, + "phase1_score": null, + "phase2_score": null, + "phase3_score": null, + "phase4_score": null, + "phase7_score": null, + "phase8_score": null + }, + { + "architecture_id": "Gemma4UnifiedForConditionalGeneration", + "model_id": "google/gemma-4-12B-it", + "status": 0, + "verified_date": null, + "metadata": null, + "note": null, + "phase1_score": null, + "phase2_score": null, + "phase3_score": null, + "phase4_score": null, + "phase7_score": null, + "phase8_score": null } ] } diff --git a/transformer_lens/tools/model_registry/data/verification_history.json b/transformer_lens/tools/model_registry/data/verification_history.json index f96594e31..5756197ab 100644 --- a/transformer_lens/tools/model_registry/data/verification_history.json +++ b/transformer_lens/tools/model_registry/data/verification_history.json @@ -1,5 +1,5 @@ { - "last_updated": "2026-06-05T13:10:29.591019", + "last_updated": "2026-06-10T14:06:20.074159", "records": [ { "model_id": "Macropodus/macbert4mdcspell_v1", @@ -12360,6 +12360,16 @@ "notes": "Full verification completed", "invalidated": false, "invalidation_reason": null + }, + { + "model_id": "google/gemma-4-E2B-it", + "architecture_id": "Gemma4ForConditionalGeneration", + "verified_date": "2026-06-10", + "verified_by": "verify_models", + "transformerlens_version": null, + "notes": "Full verification completed", + "invalidated": false, + "invalidation_reason": null } ] } diff --git a/transformer_lens/tools/model_registry/generate_report.py b/transformer_lens/tools/model_registry/generate_report.py index 9349844a2..e3ead2444 100644 --- a/transformer_lens/tools/model_registry/generate_report.py +++ b/transformer_lens/tools/model_registry/generate_report.py @@ -37,6 +37,8 @@ "Gemma2ForCausalLM": "Google's Gemma 2 with improved architecture", "Gemma3ForCausalLM": "Google's Gemma 3 latest generation", "Gemma3nForConditionalGeneration": "Google's Gemma 3n efficient tri-modal model (text-only support)", + "Gemma4ForConditionalGeneration": "Google's Gemma 4 multimodal model family (text-only support)", + "Gemma4UnifiedForConditionalGeneration": "Google's Gemma 4 unified encoder-free multimodal model (text-only support)", "Qwen2ForCausalLM": "Alibaba's Qwen2 multilingual model", "Qwen3ForCausalLM": "Alibaba's Qwen3 latest generation", "Qwen3_5ForConditionalGeneration": "Alibaba's Qwen3.5 vision-language model",