diff --git a/demos/Direct_Path_Patching_Demo.ipynb b/demos/Direct_Path_Patching_Demo.ipynb new file mode 100644 index 000000000..bad609134 --- /dev/null +++ b/demos/Direct_Path_Patching_Demo.ipynb @@ -0,0 +1,330 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Direct Path Patching in TransformerLens\n", + "\n", + "This notebook demonstrates **direct path patching** — a causal intervention technique for mechanistic interpretability.\n", + "\n", + "## What is path patching?\n", + "\n", + "Standard **activation patching** replaces the full residual stream at a layer with activations from a different (\"clean\") run. This tells you: *does this layer's residual stream matter for the output?*\n", + "\n", + "**Direct path patching** is more targeted: it patches only the contribution of a specific component (e.g., a single attention head) to a specific downstream position, without affecting other paths. This lets you isolate *which component, sending to which position, causes the output difference*.\n", + "\n", + "### The difference\n", + "\n", + "| Technique | What you patch | What you learn |\n", + "|---|---|---|\n", + "| Activation patching | Full residual stream at layer L, position p | Does the state at (L, p) matter? |\n", + "| Direct path patching | Output of component C → position p only | Does C's direct contribution to p matter? 
|\n", + "\n", + "## Setup" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Install if needed\n", + "# !pip install transformer_lens" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "from transformer_lens import HookedTransformer\n", + "import einops\n", + "\n", + "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", + "print(f\"Using device: {device}\")\n", + "\n", + "model = HookedTransformer.from_pretrained(\"gpt2\", device=device)\n", + "model.eval()\n", + "print(f\"Loaded GPT-2: {model.cfg.n_layers} layers, d_model={model.cfg.d_model}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Part 1: Basic Activation Patching\n", + "\n", + "Before doing direct path patching, let's establish the baseline: full residual stream patching.\n", + "\n", + "We use a **clean prompt** and a **corrupted prompt** that differ in one key token." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Classic IOI-style setup: two prompts differing in one name\n", + "clean_prompt = \"When Mary and John went to the store, John gave a bag to\"\n", + "corrupt_prompt = \"When Mary and John went to the store, Mary gave a bag to\"\n", + "\n", + "clean_tokens = model.to_tokens(clean_prompt)\n", + "corrupt_tokens = model.to_tokens(corrupt_prompt)\n", + "\n", + "# The clean answer is \" Mary\", the corrupted answer is \" John\"\n", + "answer_token = model.to_single_token(\" Mary\")\n", + "wrong_token = model.to_single_token(\" John\")\n", + "\n", + "print(f\"Clean tokens shape: {clean_tokens.shape}\")\n", + "print(f\"Correct answer token: {answer_token} = '{model.to_string(answer_token)}'\")\n", + "print(f\"Wrong answer token: {wrong_token} = '{model.to_string(wrong_token)}'\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def logit_diff(logits, answer_tok=answer_token, wrong_tok=wrong_token):\n", + " \"\"\"Difference in logits between the correct and incorrect answer tokens.\"\"\"\n", + " return logits[0, -1, answer_tok] - logits[0, -1, wrong_tok]\n", + "\n", + "# Baseline measurements\n", + "with torch.no_grad():\n", + " clean_logits, clean_cache = model.run_with_cache(clean_tokens)\n", + " corrupt_logits, corrupt_cache = model.run_with_cache(corrupt_tokens)\n", + "\n", + "clean_diff = logit_diff(clean_logits).item()\n", + "corrupt_diff = logit_diff(corrupt_logits).item()\n", + "\n", + "print(f\"Clean logit diff (higher = model says Mary): {clean_diff:+.3f}\")\n", + "print(f\"Corrupted logit diff: {corrupt_diff:+.3f}\")\n", + "print(f\"\\nTotal effect to explain: {clean_diff - corrupt_diff:.3f}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def patch_residual_stream(layer, position):\n", + " \"\"\"\n", + " Patch the 
residual stream at (layer, position) from the clean cache\n", + " into a corrupted forward pass. Returns the normalized logit diff.\n", + " \"\"\"\n", + " hook_name = f\"blocks.{layer}.hook_resid_post\"\n", + " clean_act = clean_cache[hook_name][:, position, :].clone()\n", + "\n", + " def patch_hook(resid, hook):\n", + " resid[:, position, :] = clean_act\n", + " return resid\n", + "\n", + " with torch.no_grad():\n", + " patched_logits = model.run_with_hooks(\n", + " corrupt_tokens,\n", + " fwd_hooks=[(hook_name, patch_hook)]\n", + " )\n", + "\n", + " patched_diff = logit_diff(patched_logits).item()\n", + " # Normalized: 0 = no recovery, 1 = full recovery\n", + " return (patched_diff - corrupt_diff) / (clean_diff - corrupt_diff)\n", + "\n", + "\n", + "# Scan all layers at the final token position\n", + "final_pos = clean_tokens.shape[1] - 1\n", + "print(f\"Patching at final token position ({final_pos}):\\n\")\n", + "print(f\"{'Layer':<8} {'Normalized recovery':>20}\")\n", + "print(\"-\" * 30)\n", + "\n", + "for layer in range(model.cfg.n_layers):\n", + " recovery = patch_residual_stream(layer, final_pos)\n", + " bar = \"█\" * int(abs(recovery) * 20)\n", + " print(f\" L{layer:<5} {recovery:>+.3f} {bar}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Part 2: Direct Path Patching\n", + "\n", + "Full residual stream patching tells us *which layer* matters. But it includes contributions from all upstream components. Direct path patching isolates the contribution of a **single attention head** to the final position.\n", + "\n", + "### How it works\n", + "\n", + "An attention head's output is added to the residual stream at every position. To patch the **direct path** from head `(layer, head)` to the final position:\n", + "\n", + "1. Run the **clean** forward pass with cache\n", + "2. Start a **corrupted** forward pass\n", + "3. 
At the attention output hook for `(layer, head)`, replace only that head's output at the target position with the clean value\n", + "4. Let the rest of the forward pass continue normally\n", + "\n", + "This isolates the head's direct write to the residual stream — without changing what the head attends to in earlier layers." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def patch_head_output(layer, head, position):\n", + " \"\"\"\n", + " Patch the direct output of attention head (layer, head) at `position`\n", + " from the clean cache into a corrupted forward pass.\n", + " Returns normalized logit diff recovery.\n", + " \"\"\"\n", + " hook_name = f\"blocks.{layer}.attn.hook_z\"\n", + " # Patch hook_z ([batch, seq, n_heads, d_head]): it is always cached, whereas hook_result only exists when use_attn_result is enabled\n", + " clean_head_out = clean_cache[hook_name][:, position, head, :].clone()\n", + "\n", + " def patch_hook(z, hook):\n", + " z[:, position, head, :] = clean_head_out\n", + " return z\n", + "\n", + " with torch.no_grad():\n", + " patched_logits = model.run_with_hooks(\n", + " corrupt_tokens,\n", + " fwd_hooks=[(hook_name, patch_hook)]\n", + " )\n", + "\n", + " patched_diff = logit_diff(patched_logits).item()\n", + " return (patched_diff - corrupt_diff) / (clean_diff - corrupt_diff)\n", + "\n", + "\n", + "# Scan all heads at the final token position\n", + "print(f\"Direct path patching: head outputs → final position ({final_pos})\")\n", + "print(f\"(Normalized recovery: 0=no effect, 1=full recovery)\\n\")\n", + "\n", + "results = {}\n", + "for layer in range(model.cfg.n_layers):\n", + " for head in range(model.cfg.n_heads):\n", + " recovery = patch_head_output(layer, head, final_pos)\n", + " results[(layer, head)] = recovery\n", + "\n", + "# Show top 10 most important heads\n", + "sorted_heads = sorted(results.items(), key=lambda x: abs(x[1]), reverse=True)\n", + "print(f\"{'Head':<12} {'Recovery':>12}\")\n", + "print(\"-\" * 26)\n", + "for 
(layer, head), recovery in sorted_heads[:10]:\n", + " bar = \"█\" * int(abs(recovery) * 30)\n", + " print(f\" L{layer}H{head:<6} {recovery:>+.3f} {bar}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Part 3: Visualize the Path Patching Results\n", + "\n", + "Plot a heatmap of direct path recovery across all (layer, head) pairs." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "try:\n", + " import matplotlib.pyplot as plt\n", + " import numpy as np\n", + "\n", + " grid = np.zeros((model.cfg.n_layers, model.cfg.n_heads))\n", + " for (layer, head), recovery in results.items():\n", + " grid[layer, head] = recovery\n", + "\n", + " fig, ax = plt.subplots(figsize=(12, 6))\n", + " im = ax.imshow(grid, cmap=\"RdBu_r\", vmin=-0.5, vmax=0.5, aspect=\"auto\")\n", + " ax.set_xlabel(\"Head\", fontsize=12)\n", + " ax.set_ylabel(\"Layer\", fontsize=12)\n", + " ax.set_title(\n", + " \"Direct Path Patching Recovery\\n\"\n", + " \"(Clean → Corrupted, patching head output at final token position)\",\n", + " fontsize=13\n", + " )\n", + " ax.set_xticks(range(model.cfg.n_heads))\n", + " ax.set_yticks(range(model.cfg.n_layers))\n", + " plt.colorbar(im, ax=ax, label=\"Normalized logit diff recovery\")\n", + " plt.tight_layout()\n", + " plt.show()\n", + " print(\"Heads with high recovery (red) have strong direct causal influence on the output.\")\n", + "except ImportError:\n", + " print(\"matplotlib not installed — skipping plot. Top heads shown in table above.\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Part 4: Why Direct Path Patching Matters\n", + "\n", + "### Difference from full activation patching\n", + "\n", + "Full residual stream patching at layer L includes contributions from *all* heads in layers 0..L plus all MLPs. 
It's a coarse measure.\n", + "\n", + "Direct path patching isolates a single head's write, making it much more precise for identifying specific circuit components.\n", + "\n", + "### Caution: the \"no mediation\" assumption\n", + "\n", + "Reading the recovery as the **direct** causal path `head output → final position → logits` assumes that nothing downstream mediates the effect. The patch above does not enforce this: once the head's output is replaced, later heads and MLPs at the final position read the patched residual stream, so their mediated effects are included in the recovery. It does NOT capture:\n", + "- The same head's writes at earlier positions, which later heads can move to the final position\n", + "- Which downstream route carries the effect (isolating a single route requires freezing every other component at its value from the un-patched run, as in Wang et al.)\n", + "\n", + "For full circuit analysis, combine direct path patching with attention pattern analysis (`hook_pattern`) to trace indirect paths.\n", + "\n", + "### Connection to the circuit discovery literature\n", + "\n", + "Direct path patching is the core operation in:\n", + "- Wang et al. (2022): *Interpretability in the Wild* (IOI circuit)\n", + "- Conmy et al. (2023): *ACDC* (automated circuit discovery)\n", + "- Marks et al. (2024): *Sparse Feature Circuits*\n", + "\n", + "### Limitations and robustness\n", + "\n", + "**Important:** High path patching recovery at a fixed position does not guarantee that the head is causally necessary for the behavior in all input distributions. Heads that appear causal on one prompt distribution may have their contribution \"read\" by the model differently on adversarially constructed inputs.\n", + "\n", + "Always validate circuit findings with:\n", + "1. Causal ablations (zero-ablate or mean-ablate the head)\n", + "2. Multiple prompt distributions\n", + "3. 
Cross-model replication" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Summary\n", + "\n", + "| Operation | TransformerLens hook | What it patches |\n", + "|---|---|---|\n", + "| Residual stream patching | `blocks.{L}.hook_resid_post` | Full residual stream at layer L |\n", + "| Direct head output patching | `blocks.{L}.attn.hook_result` | Single head's write to residual stream (only computed when `use_attn_result` is enabled) |\n", + "| Attention pattern patching | `blocks.{L}.attn.hook_pattern` | Which tokens a head attends to |\n", + "| MLP output patching | `blocks.{L}.hook_mlp_out` | MLP contribution to residual stream |\n", + "\n", + "Each hook corresponds to a different granularity of causal intervention. Start coarse (residual stream), then refine (head output) to identify the minimal circuit responsible for a behavior." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python", + "version": "3.10.0" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/tests/unit/model_bridge/supported_architectures/test_neo_adapter.py b/tests/unit/model_bridge/supported_architectures/test_neo_adapter.py new file mode 100644 index 000000000..b8be13f1d --- /dev/null +++ b/tests/unit/model_bridge/supported_architectures/test_neo_adapter.py @@ -0,0 +1,287 @@ +"""Unit tests for NeoArchitectureAdapter. 
+ +Tests cover: +- Config attribute validation (all required attributes are set correctly) +- Component mapping structure (correct bridge types and HF module names) +- Weight conversion keys and count +- NeoLinearTransposeConversion numerical correctness +""" + +import pytest +import torch + +from transformer_lens.config import TransformerBridgeConfig +from transformer_lens.model_bridge.generalized_components import ( + AttentionBridge, + BlockBridge, + EmbeddingBridge, + LinearBridge, + MLPBridge, + NormalizationBridge, + PosEmbedBridge, + UnembeddingBridge, +) +from transformer_lens.model_bridge.supported_architectures.neo import ( + NeoArchitectureAdapter, + NeoLinearTransposeConversion, +) + +# --------------------------------------------------------------------------- +# Helpers / fixtures +# --------------------------------------------------------------------------- + + +def _make_cfg( + n_heads: int = 4, + d_model: int = 64, + n_layers: int = 2, + d_mlp: int = 256, + d_vocab: int = 1000, + n_ctx: int = 512, +) -> TransformerBridgeConfig: + """Return a minimal TransformerBridgeConfig for Neo adapter tests.""" + return TransformerBridgeConfig( + d_model=d_model, + d_head=d_model // n_heads, + n_layers=n_layers, + n_ctx=n_ctx, + n_heads=n_heads, + d_vocab=d_vocab, + d_mlp=d_mlp, + default_prepend_bos=True, + architecture="GPTNeoForCausalLM", + ) + + +@pytest.fixture +def cfg() -> TransformerBridgeConfig: + return _make_cfg() + + +@pytest.fixture +def adapter(cfg: TransformerBridgeConfig) -> NeoArchitectureAdapter: + return NeoArchitectureAdapter(cfg) + + +# --------------------------------------------------------------------------- +# Config attribute tests +# --------------------------------------------------------------------------- + + +class TestNeoAdapterConfig: + """Tests that the adapter sets required config attributes correctly.""" + + def test_normalization_type_is_ln(self, adapter: NeoArchitectureAdapter) -> None: + assert 
adapter.cfg.normalization_type == "LN" + + def test_positional_embedding_type_is_standard(self, adapter: NeoArchitectureAdapter) -> None: + assert adapter.cfg.positional_embedding_type == "standard" + + def test_final_rms_is_false(self, adapter: NeoArchitectureAdapter) -> None: + assert adapter.cfg.final_rms is False + + def test_gated_mlp_is_false(self, adapter: NeoArchitectureAdapter) -> None: + assert adapter.cfg.gated_mlp is False + + def test_attn_only_is_false(self, adapter: NeoArchitectureAdapter) -> None: + assert adapter.cfg.attn_only is False + + +# --------------------------------------------------------------------------- +# Component mapping structure tests +# --------------------------------------------------------------------------- + + +class TestNeoAdapterComponentMapping: + """Tests that component_mapping has the correct bridge types and HF module names.""" + + # -- Top-level keys -- + + def test_embed_is_embedding_bridge(self, adapter: NeoArchitectureAdapter) -> None: + assert isinstance(adapter.component_mapping["embed"], EmbeddingBridge) + + def test_embed_name(self, adapter: NeoArchitectureAdapter) -> None: + assert adapter.component_mapping["embed"].name == "transformer.wte" + + def test_pos_embed_is_pos_embed_bridge(self, adapter: NeoArchitectureAdapter) -> None: + assert isinstance(adapter.component_mapping["pos_embed"], PosEmbedBridge) + + def test_pos_embed_name(self, adapter: NeoArchitectureAdapter) -> None: + assert adapter.component_mapping["pos_embed"].name == "transformer.wpe" + + def test_blocks_is_block_bridge(self, adapter: NeoArchitectureAdapter) -> None: + assert isinstance(adapter.component_mapping["blocks"], BlockBridge) + + def test_blocks_name(self, adapter: NeoArchitectureAdapter) -> None: + assert adapter.component_mapping["blocks"].name == "transformer.h" + + def test_ln_final_is_normalization_bridge(self, adapter: NeoArchitectureAdapter) -> None: + assert isinstance(adapter.component_mapping["ln_final"], 
NormalizationBridge) + + def test_ln_final_name(self, adapter: NeoArchitectureAdapter) -> None: + assert adapter.component_mapping["ln_final"].name == "transformer.ln_f" + + def test_unembed_is_unembedding_bridge(self, adapter: NeoArchitectureAdapter) -> None: + assert isinstance(adapter.component_mapping["unembed"], UnembeddingBridge) + + def test_unembed_name(self, adapter: NeoArchitectureAdapter) -> None: + assert adapter.component_mapping["unembed"].name == "lm_head" + + # -- Block submodules -- + + def test_blocks_ln1_is_normalization_bridge(self, adapter: NeoArchitectureAdapter) -> None: + assert isinstance( + adapter.component_mapping["blocks"].submodules["ln1"], NormalizationBridge + ) + + def test_blocks_ln1_name(self, adapter: NeoArchitectureAdapter) -> None: + assert adapter.component_mapping["blocks"].submodules["ln1"].name == "ln_1" + + def test_blocks_ln2_is_normalization_bridge(self, adapter: NeoArchitectureAdapter) -> None: + assert isinstance( + adapter.component_mapping["blocks"].submodules["ln2"], NormalizationBridge + ) + + def test_blocks_ln2_name(self, adapter: NeoArchitectureAdapter) -> None: + assert adapter.component_mapping["blocks"].submodules["ln2"].name == "ln_2" + + def test_attn_is_attention_bridge(self, adapter: NeoArchitectureAdapter) -> None: + """Neo uses separate Q/K/V projections (AttentionBridge), unlike GPT-2's combined QKV.""" + blocks = adapter.component_mapping["blocks"] + assert isinstance(blocks.submodules["attn"], AttentionBridge) + + def test_attn_name(self, adapter: NeoArchitectureAdapter) -> None: + """Neo's attention submodule is nested as attn.attention in HuggingFace.""" + blocks = adapter.component_mapping["blocks"] + assert blocks.submodules["attn"].name == "attn.attention" + + def test_attn_q_is_linear_bridge(self, adapter: NeoArchitectureAdapter) -> None: + attn = adapter.component_mapping["blocks"].submodules["attn"] + assert isinstance(attn.submodules["q"], LinearBridge) + + def test_attn_q_name(self, 
adapter: NeoArchitectureAdapter) -> None: + attn = adapter.component_mapping["blocks"].submodules["attn"] + assert attn.submodules["q"].name == "q_proj" + + def test_attn_k_is_linear_bridge(self, adapter: NeoArchitectureAdapter) -> None: + attn = adapter.component_mapping["blocks"].submodules["attn"] + assert isinstance(attn.submodules["k"], LinearBridge) + + def test_attn_k_name(self, adapter: NeoArchitectureAdapter) -> None: + attn = adapter.component_mapping["blocks"].submodules["attn"] + assert attn.submodules["k"].name == "k_proj" + + def test_attn_v_is_linear_bridge(self, adapter: NeoArchitectureAdapter) -> None: + attn = adapter.component_mapping["blocks"].submodules["attn"] + assert isinstance(attn.submodules["v"], LinearBridge) + + def test_attn_v_name(self, adapter: NeoArchitectureAdapter) -> None: + attn = adapter.component_mapping["blocks"].submodules["attn"] + assert attn.submodules["v"].name == "v_proj" + + def test_attn_o_is_linear_bridge(self, adapter: NeoArchitectureAdapter) -> None: + attn = adapter.component_mapping["blocks"].submodules["attn"] + assert isinstance(attn.submodules["o"], LinearBridge) + + def test_attn_o_name(self, adapter: NeoArchitectureAdapter) -> None: + attn = adapter.component_mapping["blocks"].submodules["attn"] + assert attn.submodules["o"].name == "out_proj" + + def test_mlp_is_mlp_bridge(self, adapter: NeoArchitectureAdapter) -> None: + blocks = adapter.component_mapping["blocks"] + assert isinstance(blocks.submodules["mlp"], MLPBridge) + + def test_mlp_name(self, adapter: NeoArchitectureAdapter) -> None: + assert adapter.component_mapping["blocks"].submodules["mlp"].name == "mlp" + + def test_mlp_in_name(self, adapter: NeoArchitectureAdapter) -> None: + mlp = adapter.component_mapping["blocks"].submodules["mlp"] + assert mlp.submodules["in"].name == "c_fc" + + def test_mlp_out_name(self, adapter: NeoArchitectureAdapter) -> None: + mlp = adapter.component_mapping["blocks"].submodules["mlp"] + assert 
mlp.submodules["out"].name == "c_proj" + + +# --------------------------------------------------------------------------- +# Weight processing conversion tests +# --------------------------------------------------------------------------- + + +class TestNeoAdapterWeightConversions: + """Tests that weight_processing_conversions has exactly the expected keys.""" + + @pytest.mark.parametrize( + "key", + [ + "blocks.{i}.attn.q.weight", + "blocks.{i}.attn.k.weight", + "blocks.{i}.attn.v.weight", + "blocks.{i}.attn.o.weight", + "blocks.{i}.mlp.in.weight", + "blocks.{i}.mlp.out.weight", + "blocks.{i}.attn.q.bias", + "blocks.{i}.attn.k.bias", + "blocks.{i}.attn.v.bias", + ], + ) + def test_conversion_key_present(self, adapter: NeoArchitectureAdapter, key: str) -> None: + assert key in adapter.weight_processing_conversions + + def test_exactly_nine_conversion_keys(self, adapter: NeoArchitectureAdapter) -> None: + assert len(adapter.weight_processing_conversions) == 9 + + +# --------------------------------------------------------------------------- +# NeoLinearTransposeConversion — numerical correctness tests +# --------------------------------------------------------------------------- + + +class TestNeoLinearTransposeConversion: + """Numerical correctness of Neo's Linear weight transposition.""" + + D_MODEL, N_HEADS, D_HEAD = 64, 4, 16 # D_MODEL = N_HEADS * D_HEAD + + def test_transpose_only_roundtrips(self) -> None: + """A weight transposed and reverted should recover the original.""" + torch.manual_seed(0) + conv = NeoLinearTransposeConversion() + original = torch.randn(self.D_MODEL, self.D_MODEL) + reverted = conv.revert(conv.handle_conversion(original)) + assert reverted.shape == original.shape + assert torch.allclose(original, reverted) + + def test_transpose_changes_shape(self) -> None: + """handle_conversion transposes [out, in] -> [in, out].""" + w = torch.zeros(128, 64) # [out_features, in_features] + out = NeoLinearTransposeConversion().handle_conversion(w) + 
assert out.shape == (64, 128) + + def test_transpose_with_rearrange_q_weight(self) -> None: + """Q/K/V weight: [d_model, n*d_head] -> transpose -> rearrange to [n, d_model, d_head].""" + conv = NeoLinearTransposeConversion("d_model (n h) -> n d_model h", n=self.N_HEADS) + w = torch.randn(self.D_MODEL, self.N_HEADS * self.D_HEAD) + out = conv.handle_conversion(w) + assert out.shape == (self.N_HEADS, self.D_MODEL, self.D_HEAD) + + def test_transpose_with_rearrange_o_weight(self) -> None: + """O weight: [n*d_head, d_model] -> transpose -> rearrange to [n, d_head, d_model].""" + conv = NeoLinearTransposeConversion("(n h) d_model -> n h d_model", n=self.N_HEADS) + w = torch.randn(self.N_HEADS * self.D_HEAD, self.D_MODEL) + out = conv.handle_conversion(w) + assert out.shape == (self.N_HEADS, self.D_HEAD, self.D_MODEL) + + def test_rearrange_roundtrip(self) -> None: + """handle_conversion -> revert recovers the original weight for Q projection.""" + torch.manual_seed(1) + conv = NeoLinearTransposeConversion("d_model (n h) -> n d_model h", n=self.N_HEADS) + original = torch.randn(self.D_MODEL, self.N_HEADS * self.D_HEAD) + recovered = conv.revert(conv.handle_conversion(original)) + assert recovered.shape == original.shape + assert torch.allclose(original, recovered, atol=1e-6) + + def test_values_preserved_after_transpose(self) -> None: + """Values should be identical after transpose (not just shape).""" + w = torch.arange(12, dtype=torch.float).reshape(3, 4) + out = NeoLinearTransposeConversion().handle_conversion(w) + assert torch.allclose(out, w.T) diff --git a/tests/unit/model_bridge/supported_architectures/test_neox_adapter.py b/tests/unit/model_bridge/supported_architectures/test_neox_adapter.py new file mode 100644 index 000000000..269b34d57 --- /dev/null +++ b/tests/unit/model_bridge/supported_architectures/test_neox_adapter.py @@ -0,0 +1,250 @@ +"""Unit tests for NeoxArchitectureAdapter. 
+ +Tests cover: +- Config attribute validation (all required attributes are set correctly) +- Component mapping structure (correct bridge types and HF module names) +- Weight conversion keys and count +""" + +import pytest + +from transformer_lens.config import TransformerBridgeConfig +from transformer_lens.model_bridge.generalized_components import ( + EmbeddingBridge, + JointQKVPositionEmbeddingsAttentionBridge, + LinearBridge, + MLPBridge, + NormalizationBridge, + ParallelBlockBridge, + RotaryEmbeddingBridge, + UnembeddingBridge, +) +from transformer_lens.model_bridge.supported_architectures.neox import ( + NeoxArchitectureAdapter, +) + +# --------------------------------------------------------------------------- +# Helpers / fixtures +# --------------------------------------------------------------------------- + + +def _make_cfg( + n_heads: int = 4, + d_model: int = 64, + n_layers: int = 2, + d_mlp: int = 256, + d_vocab: int = 1000, + n_ctx: int = 512, +) -> TransformerBridgeConfig: + """Return a minimal TransformerBridgeConfig for NeoX adapter tests.""" + return TransformerBridgeConfig( + d_model=d_model, + d_head=d_model // n_heads, + n_layers=n_layers, + n_ctx=n_ctx, + n_heads=n_heads, + d_vocab=d_vocab, + d_mlp=d_mlp, + default_prepend_bos=False, + architecture="GPTNeoXForCausalLM", + ) + + +@pytest.fixture +def cfg() -> TransformerBridgeConfig: + return _make_cfg() + + +@pytest.fixture +def adapter(cfg: TransformerBridgeConfig) -> NeoxArchitectureAdapter: + return NeoxArchitectureAdapter(cfg) + + +# --------------------------------------------------------------------------- +# Config attribute tests +# --------------------------------------------------------------------------- + + +class TestNeoxAdapterConfig: + """Tests that the adapter sets required config attributes correctly.""" + + def test_normalization_type_is_ln(self, adapter: NeoxArchitectureAdapter) -> None: + assert adapter.cfg.normalization_type == "LN" + + def 
test_positional_embedding_type_is_rotary(self, adapter: NeoxArchitectureAdapter) -> None: + assert adapter.cfg.positional_embedding_type == "rotary" + + def test_final_rms_is_false(self, adapter: NeoxArchitectureAdapter) -> None: + assert adapter.cfg.final_rms is False + + def test_gated_mlp_is_false(self, adapter: NeoxArchitectureAdapter) -> None: + assert adapter.cfg.gated_mlp is False + + def test_attn_only_is_false(self, adapter: NeoxArchitectureAdapter) -> None: + assert adapter.cfg.attn_only is False + + def test_parallel_attn_mlp_is_true(self, adapter: NeoxArchitectureAdapter) -> None: + assert adapter.cfg.parallel_attn_mlp is True + + def test_default_prepend_bos_is_false(self, adapter: NeoxArchitectureAdapter) -> None: + """GPT-NeoX/Pythia models were not trained with BOS tokens.""" + assert adapter.cfg.default_prepend_bos is False + + +# --------------------------------------------------------------------------- +# Component mapping structure tests +# --------------------------------------------------------------------------- + + +class TestNeoxAdapterComponentMapping: + """Tests that component_mapping has the correct bridge types and HF module names.""" + + # -- Top-level keys -- + + def test_embed_is_embedding_bridge(self, adapter: NeoxArchitectureAdapter) -> None: + assert isinstance(adapter.component_mapping["embed"], EmbeddingBridge) + + def test_embed_name(self, adapter: NeoxArchitectureAdapter) -> None: + assert adapter.component_mapping["embed"].name == "gpt_neox.embed_in" + + def test_rotary_emb_is_rotary_embedding_bridge(self, adapter: NeoxArchitectureAdapter) -> None: + """NeoX uses rotary embeddings instead of learned positional embeddings.""" + assert isinstance(adapter.component_mapping["rotary_emb"], RotaryEmbeddingBridge) + + def test_rotary_emb_name(self, adapter: NeoxArchitectureAdapter) -> None: + assert adapter.component_mapping["rotary_emb"].name == "gpt_neox.rotary_emb" + + def test_no_pos_embed_key(self, adapter: 
NeoxArchitectureAdapter) -> None: + """NeoX has no learned positional embedding — uses rotary instead.""" + assert "pos_embed" not in adapter.component_mapping + + def test_blocks_is_parallel_block_bridge(self, adapter: NeoxArchitectureAdapter) -> None: + """NeoX runs attention and MLP in parallel (ParallelBlockBridge).""" + assert isinstance(adapter.component_mapping["blocks"], ParallelBlockBridge) + + def test_blocks_name(self, adapter: NeoxArchitectureAdapter) -> None: + assert adapter.component_mapping["blocks"].name == "gpt_neox.layers" + + def test_ln_final_is_normalization_bridge(self, adapter: NeoxArchitectureAdapter) -> None: + assert isinstance(adapter.component_mapping["ln_final"], NormalizationBridge) + + def test_ln_final_name(self, adapter: NeoxArchitectureAdapter) -> None: + assert adapter.component_mapping["ln_final"].name == "gpt_neox.final_layer_norm" + + def test_unembed_is_unembedding_bridge(self, adapter: NeoxArchitectureAdapter) -> None: + assert isinstance(adapter.component_mapping["unembed"], UnembeddingBridge) + + def test_unembed_name(self, adapter: NeoxArchitectureAdapter) -> None: + assert adapter.component_mapping["unembed"].name == "embed_out" + + # -- Block submodules -- + + def test_blocks_ln1_is_normalization_bridge(self, adapter: NeoxArchitectureAdapter) -> None: + assert isinstance( + adapter.component_mapping["blocks"].submodules["ln1"], NormalizationBridge + ) + + def test_blocks_ln1_name(self, adapter: NeoxArchitectureAdapter) -> None: + assert adapter.component_mapping["blocks"].submodules["ln1"].name == "input_layernorm" + + def test_blocks_ln2_is_normalization_bridge(self, adapter: NeoxArchitectureAdapter) -> None: + assert isinstance( + adapter.component_mapping["blocks"].submodules["ln2"], NormalizationBridge + ) + + def test_blocks_ln2_name(self, adapter: NeoxArchitectureAdapter) -> None: + assert ( + adapter.component_mapping["blocks"].submodules["ln2"].name == "post_attention_layernorm" + ) + + def 
test_attn_is_joint_qkv_position_embeddings_bridge( + self, adapter: NeoxArchitectureAdapter + ) -> None: + """NeoX uses a combined QKV matrix with rotary embeddings.""" + blocks = adapter.component_mapping["blocks"] + assert isinstance(blocks.submodules["attn"], JointQKVPositionEmbeddingsAttentionBridge) + + def test_attn_name(self, adapter: NeoxArchitectureAdapter) -> None: + blocks = adapter.component_mapping["blocks"] + assert blocks.submodules["attn"].name == "attention" + + def test_attn_requires_attention_mask(self, adapter: NeoxArchitectureAdapter) -> None: + """GPTNeoX/StableLM requires an explicit attention mask.""" + attn = adapter.component_mapping["blocks"].submodules["attn"] + assert attn.requires_attention_mask is True + + def test_attn_qkv_is_linear_bridge(self, adapter: NeoxArchitectureAdapter) -> None: + attn = adapter.component_mapping["blocks"].submodules["attn"] + assert isinstance(attn.submodules["qkv"], LinearBridge) + + def test_attn_qkv_name(self, adapter: NeoxArchitectureAdapter) -> None: + attn = adapter.component_mapping["blocks"].submodules["attn"] + assert attn.submodules["qkv"].name == "query_key_value" + + def test_attn_o_is_linear_bridge(self, adapter: NeoxArchitectureAdapter) -> None: + attn = adapter.component_mapping["blocks"].submodules["attn"] + assert isinstance(attn.submodules["o"], LinearBridge) + + def test_attn_o_name(self, adapter: NeoxArchitectureAdapter) -> None: + attn = adapter.component_mapping["blocks"].submodules["attn"] + assert attn.submodules["o"].name == "dense" + + def test_mlp_is_mlp_bridge(self, adapter: NeoxArchitectureAdapter) -> None: + blocks = adapter.component_mapping["blocks"] + assert isinstance(blocks.submodules["mlp"], MLPBridge) + + def test_mlp_name(self, adapter: NeoxArchitectureAdapter) -> None: + assert adapter.component_mapping["blocks"].submodules["mlp"].name == "mlp" + + def test_mlp_in_name(self, adapter: NeoxArchitectureAdapter) -> None: + mlp = 
class TestNeoxAdapterWeightConversions:
    """Check the entries of ``weight_processing_conversions`` on the NeoX adapter.

    Seven keys are expected (Q/K/V weights, Q/K/V biases, output projection).
    The Q/K/V weight entries must share one source key, and so must the
    Q/K/V bias entries, because HuggingFace stores them in a combined tensor.
    """

    @pytest.mark.parametrize(
        "key",
        [
            "blocks.{i}.attn.q",
            "blocks.{i}.attn.k",
            "blocks.{i}.attn.v",
            "blocks.{i}.attn.b_Q",
            "blocks.{i}.attn.b_K",
            "blocks.{i}.attn.b_V",
            "blocks.{i}.attn.o",
        ],
    )
    def test_conversion_key_present(self, adapter: NeoxArchitectureAdapter, key: str) -> None:
        """Every expected conversion key is registered on the adapter."""
        conversions = adapter.weight_processing_conversions
        assert key in conversions

    def test_exactly_seven_conversion_keys(self, adapter: NeoxArchitectureAdapter) -> None:
        """No conversions beyond the seven expected ones are registered."""
        conversion_count = len(adapter.weight_processing_conversions)
        assert conversion_count == 7

    def test_qkv_conversions_share_source_key(self, adapter: NeoxArchitectureAdapter) -> None:
        """Q, K and V weights all read from the combined QKV weight matrix."""
        conversions = adapter.weight_processing_conversions
        combined_weight = "gpt_neox.layers.{i}.attention.query_key_value.weight"
        weight_keys = ("blocks.{i}.attn.q", "blocks.{i}.attn.k", "blocks.{i}.attn.v")
        for name in weight_keys:
            assert conversions[name].source_key == combined_weight

    def test_bias_conversions_share_source_key(self, adapter: NeoxArchitectureAdapter) -> None:
        """Q, K and V biases all read from the combined QKV bias vector."""
        conversions = adapter.weight_processing_conversions
        combined_bias = "gpt_neox.layers.{i}.attention.query_key_value.bias"
        bias_keys = ("blocks.{i}.attn.b_Q", "blocks.{i}.attn.b_K", "blocks.{i}.attn.b_V")
        for name in bias_keys:
            assert conversions[name].source_key == combined_bias

    def test_o_conversion_source_key(self, adapter: NeoxArchitectureAdapter) -> None:
        """The output projection reads from the attention ``dense`` weight."""
        o_conversion = adapter.weight_processing_conversions["blocks.{i}.attn.o"]
        dense_weight = "gpt_neox.layers.{i}.attention.dense.weight"
        assert o_conversion.source_key == dense_weight
@pytest.fixture
def adapter(cfg: TransformerBridgeConfig) -> OpenElmArchitectureAdapter:
    """Build an OpenELM adapter from the ``cfg`` fixture."""
    openelm_adapter = OpenElmArchitectureAdapter(cfg)
    return openelm_adapter


# ---------------------------------------------------------------------------
# Config attribute tests
# ---------------------------------------------------------------------------


class TestOpenElmAdapterConfig:
    """Verify the config attributes found on the adapter after construction."""

    def test_normalization_type_is_rms(self, adapter: OpenElmArchitectureAdapter) -> None:
        norm_type = adapter.cfg.normalization_type
        assert norm_type == "RMS"

    def test_positional_embedding_type_is_rotary(self, adapter: OpenElmArchitectureAdapter) -> None:
        pos_type = adapter.cfg.positional_embedding_type
        assert pos_type == "rotary"

    def test_final_rms_is_true(self, adapter: OpenElmArchitectureAdapter) -> None:
        final_rms = adapter.cfg.final_rms
        assert final_rms is True

    def test_gated_mlp_is_true(self, adapter: OpenElmArchitectureAdapter) -> None:
        gated = adapter.cfg.gated_mlp
        assert gated is True

    def test_attn_only_is_false(self, adapter: OpenElmArchitectureAdapter) -> None:
        attn_only = adapter.cfg.attn_only
        assert attn_only is False

    def test_uses_rms_norm_is_true(self, adapter: OpenElmArchitectureAdapter) -> None:
        rms_flag = adapter.cfg.uses_rms_norm
        assert rms_flag is True

    def test_tokenizer_name(self, adapter: OpenElmArchitectureAdapter) -> None:
        """The LLaMA-2 tokenizer stands in, as OpenELM bundles no tokenizer of its own."""
        tokenizer = adapter.cfg.tokenizer_name
        assert tokenizer == "NousResearch/Llama-2-7b-hf"


# ---------------------------------------------------------------------------
# Component mapping structure tests
# ---------------------------------------------------------------------------


class TestOpenElmAdapterComponentMapping:
    """Verify the bridge classes and HF module names in ``component_mapping``."""

    # -- Lookup helpers (underscore-prefixed, so pytest does not collect them) --

    @staticmethod
    def _block_part(adapter: OpenElmArchitectureAdapter, part: str):
        """Return the ``part`` submodule of the ``blocks`` mapping entry."""
        return adapter.component_mapping["blocks"].submodules[part]

    @staticmethod
    def _nested_part(adapter: OpenElmArchitectureAdapter, part: str, leaf: str):
        """Return the ``leaf`` submodule nested under the ``part`` block submodule."""
        return adapter.component_mapping["blocks"].submodules[part].submodules[leaf]

    # -- Top-level keys --

    def test_embed_is_embedding_bridge(self, adapter: OpenElmArchitectureAdapter) -> None:
        embed = adapter.component_mapping["embed"]
        assert isinstance(embed, EmbeddingBridge)

    def test_embed_name(self, adapter: OpenElmArchitectureAdapter) -> None:
        embed = adapter.component_mapping["embed"]
        assert embed.name == "transformer.token_embeddings"

    def test_no_pos_embed_key(self, adapter: OpenElmArchitectureAdapter) -> None:
        """No shared positional embedding entry: OpenELM applies rotary embeddings per layer."""
        mapping = adapter.component_mapping
        assert "pos_embed" not in mapping

    def test_no_rotary_emb_key(self, adapter: OpenElmArchitectureAdapter) -> None:
        """RoPE lives inside each attention layer, so there is no top-level entry for it."""
        mapping = adapter.component_mapping
        assert "rotary_emb" not in mapping

    def test_blocks_is_block_bridge(self, adapter: OpenElmArchitectureAdapter) -> None:
        blocks = adapter.component_mapping["blocks"]
        assert isinstance(blocks, BlockBridge)

    def test_blocks_name(self, adapter: OpenElmArchitectureAdapter) -> None:
        blocks = adapter.component_mapping["blocks"]
        assert blocks.name == "transformer.layers"

    def test_ln_final_is_rms_normalization_bridge(
        self, adapter: OpenElmArchitectureAdapter
    ) -> None:
        ln_final = adapter.component_mapping["ln_final"]
        assert isinstance(ln_final, RMSNormalizationBridge)

    def test_ln_final_name(self, adapter: OpenElmArchitectureAdapter) -> None:
        ln_final = adapter.component_mapping["ln_final"]
        assert ln_final.name == "transformer.norm"

    def test_unembed_is_unembedding_bridge(self, adapter: OpenElmArchitectureAdapter) -> None:
        unembed = adapter.component_mapping["unembed"]
        assert isinstance(unembed, UnembeddingBridge)

    def test_unembed_name(self, adapter: OpenElmArchitectureAdapter) -> None:
        unembed = adapter.component_mapping["unembed"]
        assert unembed.name == "lm_head"

    # -- Block submodules --

    def test_blocks_ln1_is_rms_normalization_bridge(
        self, adapter: OpenElmArchitectureAdapter
    ) -> None:
        ln1 = self._block_part(adapter, "ln1")
        assert isinstance(ln1, RMSNormalizationBridge)

    def test_blocks_ln1_name(self, adapter: OpenElmArchitectureAdapter) -> None:
        ln1 = self._block_part(adapter, "ln1")
        assert ln1.name == "attn_norm"

    def test_blocks_ln2_is_rms_normalization_bridge(
        self, adapter: OpenElmArchitectureAdapter
    ) -> None:
        ln2 = self._block_part(adapter, "ln2")
        assert isinstance(ln2, RMSNormalizationBridge)

    def test_blocks_ln2_name(self, adapter: OpenElmArchitectureAdapter) -> None:
        ln2 = self._block_part(adapter, "ln2")
        assert ln2.name == "ffn_norm"

    def test_attn_is_attention_bridge(self, adapter: OpenElmArchitectureAdapter) -> None:
        attn = self._block_part(adapter, "attn")
        assert isinstance(attn, AttentionBridge)

    def test_attn_name(self, adapter: OpenElmArchitectureAdapter) -> None:
        attn = self._block_part(adapter, "attn")
        assert attn.name == "attn"

    def test_attn_requires_attention_mask(self, adapter: OpenElmArchitectureAdapter) -> None:
        attn = self._block_part(adapter, "attn")
        assert attn.requires_attention_mask is True

    def test_attn_qkv_is_linear_bridge(self, adapter: OpenElmArchitectureAdapter) -> None:
        """The attention bridge exposes one combined ``qkv`` projection, not separate q/k/v."""
        qkv = self._nested_part(adapter, "attn", "qkv")
        assert isinstance(qkv, LinearBridge)

    def test_attn_qkv_name(self, adapter: OpenElmArchitectureAdapter) -> None:
        qkv = self._nested_part(adapter, "attn", "qkv")
        assert qkv.name == "qkv_proj"

    def test_attn_o_is_linear_bridge(self, adapter: OpenElmArchitectureAdapter) -> None:
        out_proj = self._nested_part(adapter, "attn", "o")
        assert isinstance(out_proj, LinearBridge)

    def test_attn_o_name(self, adapter: OpenElmArchitectureAdapter) -> None:
        out_proj = self._nested_part(adapter, "attn", "o")
        assert out_proj.name == "out_proj"

    def test_mlp_is_mlp_bridge(self, adapter: OpenElmArchitectureAdapter) -> None:
        mlp = self._block_part(adapter, "mlp")
        assert isinstance(mlp, MLPBridge)

    def test_mlp_name(self, adapter: OpenElmArchitectureAdapter) -> None:
        """The MLP bridge maps to the HF submodule called 'ffn' (feedforward network)."""
        mlp = self._block_part(adapter, "mlp")
        assert mlp.name == "ffn"

    def test_mlp_in_name(self, adapter: OpenElmArchitectureAdapter) -> None:
        mlp_in = self._nested_part(adapter, "mlp", "in")
        assert mlp_in.name == "proj_1"

    def test_mlp_out_name(self, adapter: OpenElmArchitectureAdapter) -> None:
        mlp_out = self._nested_part(adapter, "mlp", "out")
        assert mlp_out.name == "proj_2"


# ---------------------------------------------------------------------------
# Weight processing conversion tests
# ---------------------------------------------------------------------------


class TestOpenElmAdapterWeightConversions:
    """Verify that the OpenELM adapter registers no weight processing conversions.

    Per the original test notes, OpenELM's per-layer head counts and FFN
    dimensions are handled by the native HuggingFace attention, so the bridge
    needs no static weight rearrangements.
    """

    def test_no_weight_processing_conversions(self, adapter: OpenElmArchitectureAdapter) -> None:
        conversion_count = len(adapter.weight_processing_conversions)
        assert conversion_count == 0