Merged
4 changes: 0 additions & 4 deletions docs/source/_static/adapter-template.py
@@ -64,10 +64,6 @@ def __init__(self, cfg: Any) -> None:
self.cfg.attn_only = False # True only for attention-only models (rare)
self.cfg.uses_rms_norm = True # Should match normalization_type

# TODO: Set the epsilon attribute name used by this model's normalization
# Check the HF model's norm layer to find the correct attribute name
self.cfg.eps_attr = "variance_epsilon" # or "layer_norm_eps", "rms_norm_eps", etc.

# TODO: Handle GQA if applicable
# If the model uses Grouped Query Attention (n_key_value_heads < n_heads):
if hasattr(cfg, "n_key_value_heads") and cfg.n_key_value_heads is not None:
@@ -104,7 +104,6 @@ Set these on `self.cfg` in `__init__` *before* building the component mapping (t
| `gated_mlp` | `bool` | MLP has gate projection (SwiGLU) |
| `attn_only` | `bool` | Model has no MLP layers (rare) |
| `uses_rms_norm` | `bool` | Should match `normalization_type == "RMS"` |
| `eps_attr` | `str` | HF attribute name for norm epsilon |

For GQA models, also forward `n_key_value_heads`:

@@ -284,7 +283,6 @@ Both must be clean. Don't paper over mypy errors with `# type: ignore` — fix t

## Common pitfalls

- **Wrong `eps_attr` name.** Models that look identical use different attribute names (`variance_epsilon`, `rms_norm_eps`, `eps`). Read the norm class.
- **Forgetting `n_key_value_heads`.** Without it, GQA models silently reshape weights as if they were MHA — verification fails with cryptic shape errors.
- **Missing registration.** Adapter exists but the factory can't find it. Update both `__init__.py` and `architecture_adapter_factory.py`.
- **Skipping `setup_component_testing` for RoPE.** Rotary embeddings need to be wired through to each attention bridge or component testing produces nonsense.
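To make the `n_key_value_heads` pitfall concrete, here is a minimal sketch of the row arithmetic behind it. The head counts (32 query heads, 8 KV heads, `d_head` of 128) are illustrative, and `kv_rows` and `split_into_heads` are hypothetical helpers written for this note, not TransformerLens functions; the real failure surfaces inside the weight conversions.

```python
from typing import Optional, Tuple


def kv_rows(n_heads: int, d_head: int, n_key_value_heads: Optional[int] = None) -> int:
    """Rows in a K or V projection weight: one d_head-sized slab per KV head.

    Falls back to n_heads (plain MHA) when n_key_value_heads is not given.
    """
    n_kv = n_heads if n_key_value_heads is None else n_key_value_heads
    return n_kv * d_head


def split_into_heads(n_rows: int, n_heads_assumed: int, d_head: int) -> Tuple[int, int]:
    """Mimic reshaping the row axis (n * d_head) into (n, d_head)."""
    if n_rows != n_heads_assumed * d_head:
        raise ValueError(
            f"cannot split {n_rows} rows into {n_heads_assumed} heads of size {d_head}"
        )
    return (n_heads_assumed, d_head)


# A GQA checkpoint stores K with only 8 heads' worth of rows.
k_rows = kv_rows(n_heads=32, d_head=128, n_key_value_heads=8)

# Correct: the adapter forwarded n_key_value_heads, so the split uses 8 heads.
print(split_into_heads(k_rows, 8, 128))

# Wrong: treating the same tensor as MHA (32 heads) raises a shape error.
try:
    split_into_heads(k_rows, 32, 128)
except ValueError as err:
    print("shape error:", err)
```

The point of the sketch is only that the KV head count has to reach whatever code does the split; in the adapters touched by this PR that means setting `self.cfg.n_key_value_heads`.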
@@ -44,7 +44,6 @@ Set these on `self.cfg` in `__init__` before building the component mapping:
| `gated_mlp` | `bool` | Whether MLP uses gate projection | Llama=True, GPT2=False |
| `attn_only` | `bool` | Whether model has no MLP layers | Usually False |
| `uses_rms_norm` | `bool` | Redundant with normalization_type but needed | Match normalization_type |
| `eps_attr` | `str` | Attribute name for norm epsilon | `"variance_epsilon"`, `"layer_norm_eps"` |

### GQA (Grouped Query Attention)

@@ -22,7 +22,7 @@ Organize around the three things an adapter decides (config, component mapping,
| Area | Worth asserting | Skip |
| --- | --- | --- |
| **Component mapping** | The HF module paths and bridge **types** for this arch — especially non-standard ones (`transformer.wte`, `model.tok_embeddings`, `out_proj`, `fc_in`, `EncDecAttention`); the distinctive bridge (`JointQKVAttentionBridge`, `ParallelBlockBridge`, `SymbolicBridge`, `MoEBridge`, `SigLIP`); the exact submodule **set** (e.g. attention has `q_norm`/`k_norm`, or block has no `ln2`). | — |
| **Config quirks** | Propagation that drives *behavior*: `n_key_value_heads` (GQA) through the adapter's own branch, custom `eps_attr` value, softcap / `logit_scale` coercion + `None`-fallback, `rmsnorm_uses_offset`, `parallel_attn_mlp`, `uses_combined_qkv`, `supports_fold_ln=False` when a fused projection forces it, multimodal/`gated_q_proj` flags. | A flag whose only effect is the literal you set (see "config-literal" below). |
| **Config quirks** | Propagation that drives *behavior*: `n_key_value_heads` (GQA) through the adapter's own branch, softcap / `logit_scale` coercion + `None`-fallback, `rmsnorm_uses_offset`, `parallel_attn_mlp`, `uses_combined_qkv`, `supports_fold_ln=False` when a fused projection forces it, multimodal/`gated_q_proj` flags. | A flag whose only effect is the literal you set (see "config-literal" below). |
| **Weight conversions** | Logic the **adapter** implements: a fused-QKV split's numerical partition (which rows are Q vs K vs V — e.g. GPT-2 thirds, CodeGen's `[Q,V,K]` `mp_num` ordering, Baichuan/InternLM2 interleaved layouts), a manual LayerNorm fold (values folded, weight reset to ones, dtype preserved), the exact conversion **key set** (no stray norm/bias entries). | The einops rearrange itself (see "dependency test"). |
| **Overrides** | Each branch of `setup_component_testing` / `preprocess_weights` / `prepare_model` / `prepare_loading` you wrote — the happy path *and* the defensive `hasattr`/`None` guards, the no-op-when-absent path, the rejection guard. | Overrides you didn't write. |
| **Behavioral hook shapes** | Where the adapter's config drives reshaping: GQA `hook_k`/`hook_v` at `n_key_value_heads`, MQA single KV head, hybrid layers where attn hooks are **absent** on linear-attention layers. | Generic `(batch, seq, d_model)` output shape (it's the shared bridge's contract, not yours). |
2 changes: 1 addition & 1 deletion docs/source/content/contributing.md
@@ -285,7 +285,7 @@ Two test layers:

### Common adapter gotchas

- **HF raw config attributes are invisible to TL-side consumers unless explicitly propagated to `self.cfg`.** Walk the HF `config.json` and mirror any non-standard knobs (`final_logit_softcapping`, `attn_logit_softcapping`, `query_pre_attn_scalar`, `sliding_window`, `layer_types`, custom `eps_attr` names) onto `self.cfg` so weight processing and forward passes can see them.
- **HF raw config attributes are invisible to TL-side consumers unless explicitly propagated to `self.cfg`.** Walk the HF `config.json` and mirror any non-standard knobs (`final_logit_softcapping`, `attn_logit_softcapping`, `query_pre_attn_scalar`, `sliding_window`, `layer_types`) onto `self.cfg` so weight processing and forward passes can see them.
- **Some config attrs need both surface-on-cfg AND fold-into-weight** via a `preprocess_weights()` override. The trigger: a numerical operation HF's forward applies natively must also be baked into the raw weights, or `bridge.enable_compatibility_mode()` (which calls `process_weights` on raw weights) produces wrong results. Concrete examples in-tree: Cohere `logit_scale` → `unembed.weight`; Gemma embedding scale (`√d_model`) → `embed.weight`. Skip the fold and Phase 3 / Phase 4 of `verify_models` will silently degrade.
- **Tokenizer policy is per-model, not per-architecture.** Sibling models in the same family routinely differ — the chat-instruct variant may prepend BOS where the base does not, padding side can flip, EOS handling can differ. It's worth re-checking `default_prepend_bos`, padding side, and EOS handling against the specific target rather than copying them from a starter adapter. `tokenizer_config.json` is not always reliable on its own — some architectures (Cohere is a notable example) declare `add_bos_token=False` but HF's `__call__` prepends BOS anyway. The most reliable check is to invoke the tokenizer directly:

1 change: 0 additions & 1 deletion docs/source/content/debugging_numerical_divergence.md
@@ -40,7 +40,6 @@ The first hop where they disagree localizes the bug.
| Off by a constant scale in residual | Final-RMS-norm offset missing | `cfg.rmsnorm_uses_offset = True` + `ArithmeticTensorConversion(ADDITION, 1.0)` |
| Logits flat / saturated at extremes | Missing logit softcap | `cfg.output_logits_soft_cap` from HF's `final_logit_softcapping` |
| Attention pattern collapses to argmax | Missing attention-score softcap | `cfg.attn_scores_soft_cap` from HF's `attn_logit_softcapping` |
| Off by `eps` magnitudes in norm | Wrong RMSNorm eps attribute name | `cfg.eps_attr` (Llama uses `"variance_epsilon"`, most others use `"eps"`) |
| First MLP off; gate matches | Forgot gated-MLP wiring | `GatedMLPBridge` with `{gate, in, out}` submodules — not `MLPBridge` |
| Bias-related drift | Adapter assumes biases that don't exist (Llama / RMSNorm) | `ProcessWeights._safe_get_tensor` handles `None`; check the weight-processing conversions are bias-aware |
| Drift only in compatibility mode | Hook semantic carve-out missing for post-norm or MLA | See [compatibility_mode.md](compatibility_mode.md) §"Hook semantic parity" |
@@ -74,9 +74,6 @@ def _make_w_pack_component(d_model: int) -> Any:


class TestBaichuanAdapterConfig:
def test_eps_attr(self, adapter: BaichuanArchitectureAdapter) -> None:
assert adapter.cfg.eps_attr == "variance_epsilon"

def test_supports_fold_ln_false(self, adapter: BaichuanArchitectureAdapter) -> None:
assert adapter.supports_fold_ln is False

@@ -69,10 +69,6 @@ def test_uses_rms_norm_is_false(self, adapter: CohereArchitectureAdapter) -> Non
# CohereLayerNorm subtracts the mean — NOT RMSNorm.
assert adapter.cfg.uses_rms_norm is False

def test_eps_attr_is_variance_epsilon(self, adapter: CohereArchitectureAdapter) -> None:
# CohereLayerNorm stores epsilon as self.variance_epsilon.
assert adapter.cfg.eps_attr == "variance_epsilon"

def test_parallel_attn_mlp_is_true(self, adapter: CohereArchitectureAdapter) -> None:
# Single input_layernorm; attn and MLP run in parallel on same normed input.
assert adapter.cfg.parallel_attn_mlp is True
@@ -265,10 +265,6 @@ def test_only_qkvo_conversion_keys(self, adapter: GPTBigCodeArchitectureAdapter)
def test_uses_rms_norm_false(self, adapter: GPTBigCodeArchitectureAdapter) -> None:
assert adapter.cfg.uses_rms_norm is False

def test_eps_attr(self, adapter: GPTBigCodeArchitectureAdapter) -> None:
# GPT-2 family eps (not RMS variance_epsilon).
assert adapter.cfg.eps_attr == "layer_norm_epsilon"


# ---------------------------------------------------------------------------
# MQAQKVConversionRule tests
@@ -158,15 +158,6 @@ def __init__(self, cfg: TransformerBridgeConfig) -> None:
self.o_proj = nn.Linear(cfg.n_heads * cfg.d_head, cfg.d_model, bias=False)


class TestGPTOSSAdapterConfig:
"""Adapter-owned config defaults that downstream bridge code relies on."""

def test_eps_attr_is_variance_epsilon(self, adapter: GPTOSSArchitectureAdapter) -> None:
"""GPT-OSS uses HF's `variance_epsilon` attribute name on RMSNorm modules,
not the default `eps`. Downstream norm-folding reads this attribute."""
assert adapter.cfg.eps_attr == "variance_epsilon"


class TestGPTOSSWeightConversions:
"""GPT-OSS uses the standard QKVO weight conversions (no biases), with GQA head counts."""

@@ -94,9 +94,6 @@ def _fill_interleaved(
class TestInternLM2AdapterConfig:
"""Adapter sets all required config attributes."""

def test_eps_attr(self, adapter: InternLM2ArchitectureAdapter) -> None:
assert adapter.cfg.eps_attr == "variance_epsilon"

def test_supports_fold_ln_false(self, adapter: InternLM2ArchitectureAdapter) -> None:
# fold_ln silently skips attn when wqkv is fused in bridge state dict.
assert adapter.supports_fold_ln is False
@@ -63,9 +63,6 @@ def adapter(self):
def test_is_multimodal(self, adapter):
assert adapter.cfg.is_multimodal is True

def test_eps_attr(self, adapter):
assert adapter.cfg.eps_attr == "variance_epsilon"

def test_vision_config_extracted(self, adapter):
assert adapter.cfg.vision_hidden_size == 1024
assert adapter.cfg.vision_num_layers == 24
2 changes: 0 additions & 2 deletions transformer_lens/config/transformer_bridge_config.py
@@ -83,7 +83,6 @@ def __init__(
NTK_by_parts_low_freq_factor: float = 1.0,
NTK_by_parts_high_freq_factor: float = 4.0,
NTK_by_parts_factor: float = 8.0,
eps_attr: str = "eps",
rmsnorm_uses_offset: bool = False,
attn_implementation: Optional[str] = None,
# Audio model configuration
@@ -176,7 +175,6 @@ def __init__(
self.NTK_by_parts_low_freq_factor = NTK_by_parts_low_freq_factor
self.NTK_by_parts_high_freq_factor = NTK_by_parts_high_freq_factor
self.NTK_by_parts_factor = NTK_by_parts_factor
self.eps_attr = eps_attr
self.rmsnorm_uses_offset = rmsnorm_uses_offset
self.attn_implementation = attn_implementation
# Audio model configuration
@@ -98,7 +98,6 @@ HF raw config attributes are invisible to TL-side consumers unless propagated to
| `query_pre_attn_scalar` | `self.cfg.query_pre_attn_scalar` | Gemma2/3 — query scaling override |
| `sliding_window` | `self.cfg.sliding_window` | Mistral, Qwen2, Gemma2 — local-attention layers |
| `layer_types` | `self.cfg.layer_types` | Hybrid models with per-layer attention type lists |
| Non-standard RMSNorm eps key | `self.cfg.eps_attr = "<attribute_name>"` | Llama uses `"variance_epsilon"` instead of `"eps"` |
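As a rough illustration of the propagation rule in the table above, the sketch below copies a handful of HF config attributes onto a TL-side config object. The attribute names come from the tables in this PR; the `mirror_hf_attrs` helper and the `SimpleNamespace` stand-ins are assumptions made for the example, not part of the adapter API.

```python
from types import SimpleNamespace
from typing import Any, List

# HF attribute name -> attribute name on self.cfg (pairs taken from the docs in this PR).
HF_TO_CFG = {
    "final_logit_softcapping": "output_logits_soft_cap",
    "attn_logit_softcapping": "attn_scores_soft_cap",
    "query_pre_attn_scalar": "query_pre_attn_scalar",
    "sliding_window": "sliding_window",
    "layer_types": "layer_types",
}


def mirror_hf_attrs(hf_config: Any, cfg: Any) -> List[str]:
    """Copy every known knob that is present and not None; return the cfg names set."""
    copied = []
    for hf_name, cfg_name in HF_TO_CFG.items():
        value = getattr(hf_config, hf_name, None)
        if value is not None:
            setattr(cfg, cfg_name, value)
            copied.append(cfg_name)
    return copied


# Stand-ins for an HF config and self.cfg; the values are illustrative only.
hf_config = SimpleNamespace(final_logit_softcapping=30.0, sliding_window=4096)
cfg = SimpleNamespace()
print(mirror_hf_attrs(hf_config, cfg))
```

Attributes that are absent or `None` on the HF side are skipped rather than written as `None`, so TL-side defaults stay in effect.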

**Weight-fold attributes** (need BOTH surface-on-cfg AND fold-into-weight via `preprocess_weights` — see [the next section](#when-to-override-preprocess_weights)):

@@ -238,7 +237,6 @@ Failure message names the missing set. (`INTENTIONAL_EXCLUDES` in the test handl
| RoPE (rotary positional embeddings) | `llama.py`, `mistral.py`, `qwen2.py`+ | `RotaryEmbeddingBridge(name="model.rotary_emb")` + `cfg.positional_embedding_type = "rotary"` |
| GQA / MQA (`n_key_value_heads < n_heads`) | `llama.py`, `mistral.py`, `falcon.py`, `cohere.py` | Set `cfg.n_key_value_heads`; pass `n_kv_heads=` to `_qkvo_weight_conversions()` |
| RMSNorm with offset | `gemma1.py`, `gemma2.py`, `gemma3.py` | `cfg.rmsnorm_uses_offset = True` + `ArithmeticTensorConversion(ADDITION, 1.0)` |
| Custom RMSNorm eps attribute | `llama.py` | `cfg.eps_attr = "variance_epsilon"` (Llama uses this instead of `eps`) |
| Standard LayerNorm | `gpt2.py`, `bloom.py` | `cfg.normalization_type = "LN"` |
| Gated MLP (`gate_proj`, `up_proj`, `down_proj`) | `llama.py`, `mistral.py`, `gemma1.py`, `qwen2.py`+ | `GatedMLPBridge` with submodules `{gate, in, out}` |
| Combined QKV (`c_attn`) | `gpt2.py`, `bloom.py` | `QKVSplitRearrangeConversion` to split + rearrange |
@@ -324,7 +322,7 @@ class TestMyArchHookCompatibility:

No weight load, no HF Hub access — synthetic cfg + structural assertions only. Runs in default `make unit-test`.

Add one test per architecture quirk (softcaps, RMSNorm offsets, sliding window, custom `eps_attr`, MoE routing). Gemma1's "must NOT override `setup_hook_compatibility`" is a good one-quirk-one-test example.
Add one test per architecture quirk (softcaps, RMSNorm offsets, sliding window, MoE routing). Gemma1's "must NOT override `setup_hook_compatibility`" is a good one-quirk-one-test example.
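A one-quirk-one-test file might look roughly like the sketch below. `ToyAdapter` is a hypothetical stand-in that copies only the GQA guard shown in the adapter template; it is not a real adapter class, and a real test would import the adapter under test and build its config the way the existing test modules do.

```python
from types import SimpleNamespace
from typing import Any


class ToyAdapter:
    """Hypothetical stand-in for an architecture adapter (not a TransformerLens class)."""

    def __init__(self, cfg: Any) -> None:
        # Build a separate TL-side config so that propagation is observable.
        self.cfg = SimpleNamespace(n_heads=cfg.n_heads, d_head=cfg.d_head)
        self.cfg.uses_rms_norm = True
        # Same guard as the GQA branch in the adapter template.
        if hasattr(cfg, "n_key_value_heads") and cfg.n_key_value_heads is not None:
            self.cfg.n_key_value_heads = cfg.n_key_value_heads


class TestToyAdapterGQA:
    """Synthetic cfg plus structural assertions: no weights, no Hub access."""

    def test_kv_heads_propagated(self) -> None:
        adapter = ToyAdapter(SimpleNamespace(n_heads=8, d_head=16, n_key_value_heads=2))
        assert adapter.cfg.n_key_value_heads == 2

    def test_absent_kv_heads_is_noop(self) -> None:
        adapter = ToyAdapter(SimpleNamespace(n_heads=8, d_head=16))
        assert not hasattr(adapter.cfg, "n_key_value_heads")

    def test_none_kv_heads_is_noop(self) -> None:
        adapter = ToyAdapter(SimpleNamespace(n_heads=8, d_head=16, n_key_value_heads=None))
        assert not hasattr(adapter.cfg, "n_key_value_heads")
```

Each test covers one branch of the guard (the happy path, the absent attribute, and the explicit `None`), so a failure names the exact branch that regressed.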

### 2. Integration parity test — `tests/integration/model_bridge/test_<arch>_adapter.py`

@@ -186,7 +186,6 @@ def __init__(self, cfg: Any) -> None:
self.cfg.gated_mlp = True
self.cfg.attn_only = False
self.cfg.uses_rms_norm = True
self.cfg.eps_attr = "variance_epsilon"

# Fused W_pack prevents standard fold_ln from reaching Q/K/V separately.
# preprocess_weights() handles it instead.
@@ -53,10 +53,8 @@ def __init__(self, cfg: Any) -> None:
# --- Normalization ---
# CohereLayerNorm is true LayerNorm (subtracts mean), NOT RMSNorm.
# uses_rms_norm=False tells NormalizationBridge to subtract the mean.
# eps_attr="variance_epsilon": CohereLayerNorm stores eps as self.variance_epsilon.
self.cfg.normalization_type = "LN"
self.cfg.uses_rms_norm = False
self.cfg.eps_attr = "variance_epsilon"
self.cfg.final_rms = False

# --- Position embeddings and MLP ---
@@ -84,7 +84,6 @@ def __init__(self, cfg: Any) -> None:
self.cfg.gated_mlp = False
self.cfg.attn_only = False
self.cfg.uses_rms_norm = False
self.cfg.eps_attr = "layer_norm_epsilon"
self.cfg.n_key_value_heads = 1 # MQA: always 1 KV head

# Mirror GPT-2 combined-QKV flags
@@ -30,8 +30,6 @@ def __init__(self, cfg: Any) -> None:

self.cfg.normalization_type = "RMS"
self.cfg.uses_rms_norm = True
# GPT-OSS uses 'variance_epsilon' instead of 'eps' for RMSNorm
self.cfg.eps_attr = "variance_epsilon"
# GPT-OSS uses rotary position embeddings, not learned embeddings
self.cfg.positional_embedding_type = "rotary"
# GPT-OSS attention returns (output, attn_weights), not a 3-tuple
@@ -52,7 +52,6 @@ def _setup_common_config(self, cfg: Any) -> None:
self.cfg.attn_only = False
self.cfg.uses_rms_norm = True
self.cfg.default_prepend_bos = False
self.cfg.eps_attr = "variance_epsilon"

self.default_config = {
"d_model": cfg.d_model,
@@ -93,7 +93,6 @@ def __init__(self, cfg: Any) -> None:
self.cfg.gated_mlp = True
self.cfg.attn_only = False
self.cfg.uses_rms_norm = True
self.cfg.eps_attr = "variance_epsilon"

# Standard fold_ln silently skips attention when wqkv is fused (see class docstring).
# preprocess_weights() handles it instead — same approach as phi3.py.
@@ -63,8 +63,6 @@ def __init__(self, cfg: Any) -> None:
self.cfg.n_key_value_heads = cfg.n_key_value_heads

self.cfg.uses_rms_norm = True
# Llama uses 'variance_epsilon' instead of 'eps' for RMSNorm
self.cfg.eps_attr = "variance_epsilon"

self.weight_processing_conversions = {
**self._qkvo_weight_conversions(),
@@ -62,7 +62,6 @@ def __init__(self, cfg: Any) -> None:
self.cfg.attn_implementation = "eager"
self.cfg.final_rms = True
self.cfg.attn_only = False
self.cfg.eps_attr = "variance_epsilon"

# GQA support
if hasattr(cfg, "n_key_value_heads") and cfg.n_key_value_heads is not None: