Feat/gemma4 adapters #1385
Draft
huseyincavusbi wants to merge 24 commits into TransformerLensOrg:dev from huseyincavusbi:feat/gemma4-adapters
Commits
c19a062 feat: Initial Gemma4 architecture adapter (V norm, softcap, PLE/KV co… (huseyincavusbi)
5d5564d feat: Register Gemma4ArchitectureAdapter in factory and __init__ (huseyincavusbi)
b1c2a3d feat: Add final_rms and eps_attr to Gemma4 adapter config (huseyincavusbi)
39565bc fix: Use setattr for custom config fields to pass mypy (huseyincavusbi)
eaf190c fix: Register Gemma4ForConditionalGeneration alias (huseyincavusbi)
cadfe52 fix: Dynamic text prefix for text-only vs multimodal Gemma4 variants (huseyincavusbi)
79d8de4 fix: Add Gemma4 to model_registry and add unit tests (huseyincavusbi)
35fa11c fix: Read text_config for nested Gemma4 conditional generation attrib… (huseyincavusbi)
eb9f214 Remove dead v_norm weight conversion (with_scale=False has no learnab… (huseyincavusbi)
7fb469e Add full Gemma4 MoE support with optional submodules for 26B-A4B (huseyincavusbi)
decefd8 Make k_proj, v_proj, k_norm, v_norm optional for KV-sharing layers (huseyincavusbi)
2700965 fix: AutoModel returns Gemma4Model directly, correct text_prefix (huseyincavusbi)
74b6168 fix: revert text_prefix — AutoModelForCausalLM needs model. prefix (huseyincavusbi)
7033b91 fix: check cfg.architecture instead of cfg.architectures for prefix d… (huseyincavusbi)
b54e2cd fix: delegate to original attention on KV-sharing layers (huseyincavusbi)
d5ce541 fix: store computed KV in shared_kv_states for Gemma4 KV-sharing (huseyincavusbi)
ee60b1c fix: add Gemma4ForConditionalGeneration to MULTIMODAL_ARCHITECTURES (huseyincavusbi)
6587691 fix: add use_native_generate opt-in flag for hf_generate delegation (huseyincavusbi)
3237784 feat: use_native_generate and prepare_loading for Gemma4 adapter (huseyincavusbi)
6a21267 fix: handle list eos_token_id when setting pad_token_id (huseyincavusbi)
0732f27 fix: apply V norm in post-reshape attention phase for Gemma4 (huseyincavusbi)
a8d2c4e fix: restore Gemma3nForConditionalGeneration in MULTIMODAL_ARCHITECTURES (huseyincavusbi)
a30f390 fix: remove dead eps_attr, resolve conflict marker, fix mypy (huseyincavusbi)
7eed605 feat: add multimodal vision support to Gemma4 adapter (huseyincavusbi)
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,362 @@ | ||
| """Unit tests for Gemma4 architecture adapter registration and configuration.""" | ||
|
|
||
| from types import SimpleNamespace | ||
|
|
||
| import pytest | ||
|
|
||
| from transformer_lens.config.TransformerBridgeConfig import TransformerBridgeConfig | ||
| from transformer_lens.factories.architecture_adapter_factory import ( | ||
| SUPPORTED_ARCHITECTURES, | ||
| ArchitectureAdapterFactory, | ||
| ) | ||
| from transformer_lens.model_bridge.supported_architectures.gemma4 import ( | ||
| Gemma4ArchitectureAdapter, | ||
| ) | ||
|
|
||
|
|
||
| def _make_text_cfg(**kwargs): | ||
| """Create a text_config SimpleNamespace matching Gemma4TextConfig.""" | ||
| defaults = dict( | ||
| hidden_size=1536, | ||
| num_attention_heads=8, | ||
| num_hidden_layers=35, | ||
| num_key_value_heads=1, | ||
| intermediate_size=6144, | ||
| vocab_size=262144, | ||
| head_dim=256, | ||
| max_position_embeddings=131072, | ||
| rms_norm_eps=1e-6, | ||
| sliding_window=512, | ||
| ) | ||
| defaults.update(kwargs) | ||
| return SimpleNamespace(**defaults) | ||
|
|
||
|
|
||
| def _make_gemma4_cfg(architectures=None, text_config=None, **overrides): | ||
| """Create a TransformerBridgeConfig for Gemma4 E2B.""" | ||
| arch = architectures or ["Gemma4ForCausalLM"] | ||
| defaults = dict( | ||
| d_model=1536, | ||
| d_head=256, | ||
| n_heads=8, | ||
| n_layers=35, | ||
| n_ctx=8192, | ||
| d_vocab=262144, | ||
| n_key_value_heads=1, | ||
| d_mlp=6144, | ||
| architecture=arch[0], | ||
| ) | ||
| defaults.update(overrides) | ||
| cfg = TransformerBridgeConfig(**defaults) | ||
| setattr(cfg, "architectures", arch) | ||
| if text_config is not None: | ||
| setattr(cfg, "text_config", text_config) | ||
| return cfg | ||
|
|
||
|
|
||
| class TestGemma4Registration: | ||
| """Test that Gemma4ArchitectureAdapter is properly registered.""" | ||
|
|
||
| def test_architecture_in_supported_architectures(self): | ||
| assert "Gemma4ForCausalLM" in SUPPORTED_ARCHITECTURES | ||
|
|
||
| def test_conditional_generation_registered(self): | ||
| assert "Gemma4ForConditionalGeneration" in SUPPORTED_ARCHITECTURES | ||
|
|
||
| def test_architecture_maps_to_correct_adapter(self): | ||
| assert SUPPORTED_ARCHITECTURES["Gemma4ForCausalLM"] is Gemma4ArchitectureAdapter | ||
|
|
||
| def test_factory_selects_correct_adapter(self): | ||
| cfg = _make_gemma4_cfg() | ||
| adapter = ArchitectureAdapterFactory.select_architecture_adapter(cfg) | ||
| assert isinstance(adapter, Gemma4ArchitectureAdapter) | ||
|
|
||
| def test_factory_selects_conditional_generation(self): | ||
| cfg = _make_gemma4_cfg(architecture="Gemma4ForConditionalGeneration") | ||
| adapter = ArchitectureAdapterFactory.select_architecture_adapter(cfg) | ||
| assert isinstance(adapter, Gemma4ArchitectureAdapter) | ||
|
|
||
|
|
||
| class TestGemma4ConfigAttributes: | ||
| """Test Gemma4ArchitectureAdapter configuration attributes.""" | ||
|
|
||
| @pytest.fixture | ||
| def adapter(self): | ||
| cfg = _make_gemma4_cfg(architectures=["Gemma4ForCausalLM"]) | ||
| return Gemma4ArchitectureAdapter(cfg) | ||
|
|
||
| def test_gated_mlp(self, adapter): | ||
| assert adapter.cfg.gated_mlp is True | ||
|
|
||
| def test_uses_rms_norm(self, adapter): | ||
| assert adapter.cfg.uses_rms_norm is True | ||
|
|
||
| def test_normalization_type(self, adapter): | ||
| assert adapter.cfg.normalization_type == "RMS" | ||
|
|
||
| def test_final_rms(self, adapter): | ||
| assert adapter.cfg.final_rms is True | ||
|
|
||
| def test_rmsnorm_uses_offset(self, adapter): | ||
| assert adapter.cfg.rmsnorm_uses_offset is True | ||
|
|
||
| def test_positional_embedding_type(self, adapter): | ||
| assert adapter.cfg.positional_embedding_type == "rotary" | ||
|
|
||
| def test_attn_implementation(self, adapter): | ||
| assert adapter.cfg.attn_implementation == "eager" | ||
|
|
||
|
|
||
| class TestGemma4Softcapping: | ||
| """Test logit and attention softcapping attribute mapping.""" | ||
|
|
||
| def test_output_logits_soft_cap(self): | ||
| tc = _make_text_cfg(final_logit_softcapping=30.0) | ||
| cfg = _make_gemma4_cfg(text_config=tc) | ||
| adapter = Gemma4ArchitectureAdapter(cfg) | ||
| assert adapter.cfg.output_logits_soft_cap == 30.0 | ||
|
|
||
| def test_attn_scores_soft_cap(self): | ||
| tc = _make_text_cfg(attn_logit_softcapping=50.0) | ||
| cfg = _make_gemma4_cfg(text_config=tc) | ||
| adapter = Gemma4ArchitectureAdapter(cfg) | ||
| assert adapter.cfg.attn_scores_soft_cap == 50.0 | ||
|
|
||
| def test_no_softcapping_when_absent(self): | ||
| cfg = _make_gemma4_cfg() | ||
| adapter = Gemma4ArchitectureAdapter(cfg) | ||
| # defaults are -1.0 (unchanged by adapter when softcapping not in text config) | ||
| assert adapter.cfg.attn_scores_soft_cap == -1.0 | ||
| assert adapter.cfg.output_logits_soft_cap == -1.0 | ||
|
|
||
|
|
||
| class TestGemma4E2BConfig: | ||
| """Test Gemma4 E-series specific config: PLE, KV sharing, layer_types.""" | ||
|
|
||
| def test_hidden_size_per_layer_input(self): | ||
| tc = _make_text_cfg(hidden_size_per_layer_input=256) | ||
| cfg = _make_gemma4_cfg(text_config=tc) | ||
| adapter = Gemma4ArchitectureAdapter(cfg) | ||
| assert adapter.cfg.hidden_size_per_layer_input == 256 | ||
|
|
||
| def test_num_kv_shared_layers(self): | ||
| tc = _make_text_cfg(num_kv_shared_layers=20) | ||
| cfg = _make_gemma4_cfg(text_config=tc) | ||
| adapter = Gemma4ArchitectureAdapter(cfg) | ||
| assert adapter.cfg.num_kv_shared_layers == 20 | ||
|
|
||
| def test_layer_types(self): | ||
| layer_types = ["sliding_attention"] * 35 | ||
| tc = _make_text_cfg(layer_types=layer_types) | ||
| cfg = _make_gemma4_cfg(text_config=tc) | ||
| adapter = Gemma4ArchitectureAdapter(cfg) | ||
| assert adapter.cfg.layer_types == layer_types | ||
|
|
||
| def test_ple_not_set_when_absent(self): | ||
| cfg = _make_gemma4_cfg() | ||
| adapter = Gemma4ArchitectureAdapter(cfg) | ||
| assert not hasattr(adapter.cfg, "hidden_size_per_layer_input") | ||
|
|
||
| def test_kv_sharing_not_set_when_absent(self): | ||
| cfg = _make_gemma4_cfg() | ||
| adapter = Gemma4ArchitectureAdapter(cfg) | ||
| assert not hasattr(adapter.cfg, "num_kv_shared_layers") | ||
|
|
||
|
|
||
| class TestGemma4MoESupport: | ||
| """Test that MoE submodules are added when enable_moe_block=True.""" | ||
|
|
||
| def test_moe_block_submodules_exist(self): | ||
| tc = _make_text_cfg( | ||
| enable_moe_block=True, num_experts=128, top_k_experts=8, moe_intermediate_size=704 | ||
| ) | ||
| cfg = _make_gemma4_cfg(text_config=tc) | ||
| adapter = Gemma4ArchitectureAdapter(cfg) | ||
| assert adapter.enable_moe_block is True | ||
| block = adapter.component_mapping["blocks"] | ||
| assert "router" in block.submodules | ||
| assert "experts" in block.submodules | ||
| assert "ln2_post_moe_1" in block.submodules | ||
| assert "ln2_pre_moe_2" in block.submodules | ||
| assert "ln2_post_moe_2" in block.submodules | ||
|
|
||
| def test_moe_block_names(self): | ||
| tc = _make_text_cfg( | ||
| enable_moe_block=True, num_experts=128, top_k_experts=8, moe_intermediate_size=704 | ||
| ) | ||
| cfg = _make_gemma4_cfg(text_config=tc) | ||
| adapter = Gemma4ArchitectureAdapter(cfg) | ||
| block = adapter.component_mapping["blocks"] | ||
| assert block.submodules["router"].name == "router" | ||
| assert block.submodules["experts"].name == "experts" | ||
| assert block.submodules["ln2_post_moe_1"].name == "post_feedforward_layernorm_1" | ||
| assert block.submodules["ln2_pre_moe_2"].name == "pre_feedforward_layernorm_2" | ||
| assert block.submodules["ln2_post_moe_2"].name == "post_feedforward_layernorm_2" | ||
|
|
||
| def test_dense_has_no_moe_submodules(self): | ||
| cfg = _make_gemma4_cfg() | ||
| adapter = Gemma4ArchitectureAdapter(cfg) | ||
| block = adapter.component_mapping["blocks"] | ||
| assert "router" not in block.submodules | ||
| assert "experts" not in block.submodules | ||
|
|
||
|
|
||
| class TestGemma4TextPrefix: | ||
| """Test text prefix detection for text-only vs multimodal.""" | ||
|
|
||
| def test_text_prefix_causal(self): | ||
| cfg = _make_gemma4_cfg(architectures=["Gemma4ForCausalLM"]) | ||
| adapter = Gemma4ArchitectureAdapter(cfg) | ||
| assert adapter.text_prefix == "model" | ||
|
|
||
| def test_text_prefix_conditional_generation(self): | ||
| cfg = _make_gemma4_cfg(architectures=["Gemma4ForConditionalGeneration"]) | ||
| adapter = Gemma4ArchitectureAdapter(cfg) | ||
| assert adapter.text_prefix == "model.language_model" | ||
|
|
||
|
|
||
| class TestGemma4ComponentMapping: | ||
| """Test Gemma4ArchitectureAdapter component mapping.""" | ||
|
|
||
| @pytest.fixture | ||
| def adapter(self): | ||
| cfg = _make_gemma4_cfg(architectures=["Gemma4ForCausalLM"]) | ||
| return Gemma4ArchitectureAdapter(cfg) | ||
|
|
||
| def test_has_embed(self, adapter): | ||
| assert "embed" in adapter.component_mapping | ||
|
|
||
| def test_has_rotary_emb(self, adapter): | ||
| assert "rotary_emb" in adapter.component_mapping | ||
|
|
||
| def test_has_blocks(self, adapter): | ||
| assert "blocks" in adapter.component_mapping | ||
|
|
||
| def test_has_ln_final(self, adapter): | ||
| assert "ln_final" in adapter.component_mapping | ||
|
|
||
| def test_has_unembed(self, adapter): | ||
| assert "unembed" in adapter.component_mapping | ||
|
|
||
| def test_embed_path_causal(self, adapter): | ||
| assert adapter.component_mapping["embed"].name == "model.embed_tokens" | ||
| assert adapter.component_mapping["rotary_emb"].name == "model.rotary_emb" | ||
| assert adapter.component_mapping["blocks"].name == "model.layers" | ||
| assert adapter.component_mapping["ln_final"].name == "model.norm" | ||
|
|
||
| def test_unembed_path(self, adapter): | ||
| assert adapter.component_mapping["unembed"].name == "lm_head" | ||
|
|
||
| def test_block_submodules(self, adapter): | ||
| block = adapter.component_mapping["blocks"] | ||
| assert "ln1" in block.submodules | ||
| assert "ln1_post" in block.submodules | ||
| assert "ln2" in block.submodules | ||
| assert "ln2_post" in block.submodules | ||
| assert "attn" in block.submodules | ||
| assert "mlp" in block.submodules | ||
|
|
||
| def test_ln_names(self, adapter): | ||
| block = adapter.component_mapping["blocks"] | ||
| assert block.submodules["ln1"].name == "input_layernorm" | ||
| assert block.submodules["ln1_post"].name == "post_attention_layernorm" | ||
| assert block.submodules["ln2"].name == "pre_feedforward_layernorm" | ||
| assert block.submodules["ln2_post"].name == "post_feedforward_layernorm" | ||
|
|
||
| def test_attn_submodules(self, adapter): | ||
| attn = adapter.component_mapping["blocks"].submodules["attn"] | ||
| assert attn.name == "self_attn" | ||
| assert "q" in attn.submodules | ||
| assert "k" in attn.submodules | ||
| assert "v" in attn.submodules | ||
| assert "o" in attn.submodules | ||
| assert "q_norm" in attn.submodules | ||
| assert "k_norm" in attn.submodules | ||
| assert "v_norm" in attn.submodules | ||
|
|
||
| def test_attn_linear_names(self, adapter): | ||
| attn = adapter.component_mapping["blocks"].submodules["attn"] | ||
| assert attn.submodules["q"].name == "q_proj" | ||
| assert attn.submodules["k"].name == "k_proj" | ||
| assert attn.submodules["v"].name == "v_proj" | ||
| assert attn.submodules["o"].name == "o_proj" | ||
|
|
||
| def test_attn_norm_names(self, adapter): | ||
| attn = adapter.component_mapping["blocks"].submodules["attn"] | ||
| assert attn.submodules["q_norm"].name == "q_norm" | ||
| assert attn.submodules["k_norm"].name == "k_norm" | ||
| assert attn.submodules["v_norm"].name == "v_norm" | ||
|
|
||
| def test_mlp_submodules(self, adapter): | ||
| mlp = adapter.component_mapping["blocks"].submodules["mlp"] | ||
| assert mlp.name == "mlp" | ||
| assert "gate" in mlp.submodules | ||
| assert "in" in mlp.submodules | ||
| assert "out" in mlp.submodules | ||
|
|
||
| def test_mlp_linear_names(self, adapter): | ||
| mlp = adapter.component_mapping["blocks"].submodules["mlp"] | ||
| assert mlp.submodules["gate"].name == "gate_proj" | ||
| assert mlp.submodules["in"].name == "up_proj" | ||
| assert mlp.submodules["out"].name == "down_proj" | ||
|
|
||
|
|
||
| class TestGemma4ComponentMappingMultimodal: | ||
| """Test component mapping paths for multimodal (ConditionalGeneration) variant.""" | ||
|
|
||
| @pytest.fixture | ||
| def adapter(self): | ||
| cfg = _make_gemma4_cfg( | ||
| architecture="Gemma4ForConditionalGeneration", | ||
| architectures=["Gemma4ForConditionalGeneration"], | ||
| ) | ||
| return Gemma4ArchitectureAdapter(cfg) | ||
|
|
||
| def test_embed_path_conditional(self, adapter): | ||
| assert adapter.component_mapping["embed"].name == "model.language_model.embed_tokens" | ||
| assert adapter.component_mapping["rotary_emb"].name == "model.language_model.rotary_emb" | ||
| assert adapter.component_mapping["blocks"].name == "model.language_model.layers" | ||
| assert adapter.component_mapping["ln_final"].name == "model.language_model.norm" | ||
|
|
||
| def test_is_multimodal_true_for_conditional(self, adapter): | ||
| assert adapter.cfg.is_multimodal is True | ||
|
|
||
| def test_has_vision_components(self, adapter): | ||
| assert "vision_encoder" in adapter.component_mapping | ||
| assert "vision_projector" in adapter.component_mapping | ||
| assert adapter.component_mapping["vision_encoder"].name == "model.vision_tower" | ||
| assert adapter.component_mapping["vision_projector"].name == "model.embed_vision" | ||
|
|
||
|
|
||
| class TestGemma4WeightConversions: | ||
| """Test Gemma4ArchitectureAdapter weight processing conversions exist.""" | ||
|
|
||
| @pytest.fixture | ||
| def adapter(self): | ||
| cfg = _make_gemma4_cfg(architectures=["Gemma4ForCausalLM"]) | ||
| return Gemma4ArchitectureAdapter(cfg) | ||
|
|
||
| def test_qkv_weight_conversions(self, adapter): | ||
| assert "blocks.{i}.attn.q.weight" in adapter.weight_processing_conversions | ||
| assert "blocks.{i}.attn.k.weight" in adapter.weight_processing_conversions | ||
| assert "blocks.{i}.attn.v.weight" in adapter.weight_processing_conversions | ||
| assert "blocks.{i}.attn.o.weight" in adapter.weight_processing_conversions | ||
|
|
||
| def test_norm_weight_conversions(self, adapter): | ||
| assert "blocks.{i}.ln1.weight" in adapter.weight_processing_conversions | ||
| assert "blocks.{i}.ln1_post.weight" in adapter.weight_processing_conversions | ||
| assert "blocks.{i}.ln2.weight" in adapter.weight_processing_conversions | ||
| assert "blocks.{i}.ln2_post.weight" in adapter.weight_processing_conversions | ||
| assert "ln_final.weight" in adapter.weight_processing_conversions | ||
|
|
||
| def test_attn_norm_weight_conversions(self, adapter): | ||
| assert "blocks.{i}.attn.q_norm.weight" in adapter.weight_processing_conversions | ||
| assert "blocks.{i}.attn.k_norm.weight" in adapter.weight_processing_conversions | ||
|
|
||
| def test_mlp_weight_conversions(self, adapter): | ||
| assert "blocks.{i}.mlp.gate.weight" in adapter.weight_processing_conversions | ||
| assert "blocks.{i}.mlp.in.weight" in adapter.weight_processing_conversions | ||
| assert "blocks.{i}.mlp.out.weight" in adapter.weight_processing_conversions | ||
|
|
||
| def test_unembed_weight_conversion(self, adapter): | ||
| assert "unembed.weight" in adapter.weight_processing_conversions | ||
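
The prefix behaviour that `TestGemma4TextPrefix` and the two component-mapping test classes assert can be summarised in a small standalone sketch. The helper names `resolve_text_prefix` and `text_component_paths` are hypothetical and exist only for illustration; they are not taken from the adapter, whose actual implementation is not shown in this diff. The sketch encodes only the paths the tests check.

```python
# Hypothetical helpers mirroring what the tests assert; not the adapter's real code.
MULTIMODAL_ARCHITECTURE = "Gemma4ForConditionalGeneration"


def resolve_text_prefix(architecture: str) -> str:
    """Return the module path prefix under which the text decoder is expected."""
    if architecture == MULTIMODAL_ARCHITECTURE:
        # Multimodal checkpoints nest the decoder under `language_model`.
        return "model.language_model"
    # Text-only (e.g. Gemma4ForCausalLM) checkpoints use the plain `model` prefix.
    return "model"


def text_component_paths(architecture: str) -> dict:
    """Build the prefixed component paths that the mapping tests check."""
    prefix = resolve_text_prefix(architecture)
    return {
        "embed": f"{prefix}.embed_tokens",
        "rotary_emb": f"{prefix}.rotary_emb",
        "blocks": f"{prefix}.layers",
        "ln_final": f"{prefix}.norm",
    }


if __name__ == "__main__":
    for arch in ("Gemma4ForCausalLM", "Gemma4ForConditionalGeneration"):
        print(arch, text_component_paths(arch))
```

Keeping the prefix decision in one place means the causal and multimodal test classes differ only in the expected prefix string, which is how the assertions in the diff are structured.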
Since you began this PR, the `TransformerBridgeConfig` import path was adjusted due to a name conflict introduced in a related refactor. Please update this to