Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
330 changes: 330 additions & 0 deletions demos/Direct_Path_Patching_Demo.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,330 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Direct Path Patching in TransformerLens\n",
"\n",
"This notebook demonstrates **direct path patching** — a causal intervention technique for mechanistic interpretability.\n",
"\n",
"## What is path patching?\n",
"\n",
"Standard **activation patching** replaces the full residual stream at a layer with activations from a different (\"clean\") run. This tells you: *does this layer's residual stream matter for the output?*\n",
"\n",
"**Direct path patching** is more targeted: it patches only the contribution of a specific component (e.g., a single attention head) to a specific downstream position, without affecting other paths. This lets you isolate *which component, sending to which position, causes the output difference*.\n",
"\n",
"### The difference\n",
"\n",
"| Technique | What you patch | What you learn |\n",
"|---|---|---|\n",
"| Activation patching | Full residual stream at layer L, position p | Does the state at (L, p) matter? |\n",
"| Direct path patching | Output of component C → position p only | Does C's direct contribution to p matter? |\n",
"\n",
"## Setup"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Install if needed\n",
"# !pip install transformer_lens"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from transformer_lens import HookedTransformer\n",
"import einops\n",
"\n",
"device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
"print(f\"Using device: {device}\")\n",
"\n",
"model = HookedTransformer.from_pretrained(\"gpt2\", device=device)\n",
"model.eval()\n",
"print(f\"Loaded GPT-2: {model.cfg.n_layers} layers, d_model={model.cfg.d_model}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Part 1: Basic Activation Patching\n",
"\n",
"Before doing direct path patching, let's establish the baseline: full residual stream patching.\n",
"\n",
"We use a **clean prompt** and a **corrupted prompt** that differ in one key token."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Classic IOI-style setup: two prompts differing in one name\n",
"clean_prompt = \"When Mary and John went to the store, John gave a bag to\"\n",
"corrupt_prompt = \"When Mary and John went to the store, Mary gave a bag to\"\n",
"\n",
"clean_tokens = model.to_tokens(clean_prompt)\n",
"corrupt_tokens = model.to_tokens(corrupt_prompt)\n",
"\n",
"# The clean answer is \" Mary\", the corrupted answer is \" John\"\n",
"answer_token = model.to_single_token(\" Mary\")\n",
"wrong_token = model.to_single_token(\" John\")\n",
"\n",
"print(f\"Clean tokens shape: {clean_tokens.shape}\")\n",
"print(f\"Correct answer token: {answer_token} = '{model.to_string(answer_token)}'\")\n",
"print(f\"Wrong answer token: {wrong_token} = '{model.to_string(wrong_token)}'\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def logit_diff(logits, answer_tok=answer_token, wrong_tok=wrong_token):\n",
" \"\"\"Difference in logits between the correct and incorrect answer tokens.\"\"\"\n",
" return logits[0, -1, answer_tok] - logits[0, -1, wrong_tok]\n",
"\n",
"# Baseline measurements\n",
"with torch.no_grad():\n",
" clean_logits, clean_cache = model.run_with_cache(clean_tokens)\n",
" corrupt_logits, corrupt_cache = model.run_with_cache(corrupt_tokens)\n",
"\n",
"clean_diff = logit_diff(clean_logits).item()\n",
"corrupt_diff = logit_diff(corrupt_logits).item()\n",
"\n",
"print(f\"Clean logit diff (higher = model says Mary): {clean_diff:+.3f}\")\n",
"print(f\"Corrupted logit diff: {corrupt_diff:+.3f}\")\n",
"print(f\"\\nTotal effect to explain: {clean_diff - corrupt_diff:.3f}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def patch_residual_stream(layer, position):\n",
" \"\"\"\n",
" Patch the residual stream at (layer, position) from the clean cache\n",
" into a corrupted forward pass. Returns the normalized logit diff.\n",
" \"\"\"\n",
" hook_name = f\"blocks.{layer}.hook_resid_post\"\n",
" clean_act = clean_cache[hook_name][:, position, :].clone()\n",
"\n",
" def patch_hook(resid, hook):\n",
" resid[:, position, :] = clean_act\n",
" return resid\n",
"\n",
" with torch.no_grad():\n",
" patched_logits = model.run_with_hooks(\n",
" corrupt_tokens,\n",
" fwd_hooks=[(hook_name, patch_hook)]\n",
" )\n",
"\n",
" patched_diff = logit_diff(patched_logits).item()\n",
" # Normalized: 0 = no recovery, 1 = full recovery\n",
" return (patched_diff - corrupt_diff) / (clean_diff - corrupt_diff)\n",
"\n",
"\n",
"# Scan all layers at the final token position\n",
"final_pos = clean_tokens.shape[1] - 1\n",
"print(f\"Patching at final token position ({final_pos}):\\n\")\n",
"print(f\"{'Layer':<8} {'Normalized recovery':>20}\")\n",
"print(\"-\" * 30)\n",
"\n",
"for layer in range(model.cfg.n_layers):\n",
" recovery = patch_residual_stream(layer, final_pos)\n",
" bar = \"█\" * int(abs(recovery) * 20)\n",
" print(f\" L{layer:<5} {recovery:>+.3f} {bar}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Part 2: Direct Path Patching\n",
"\n",
"Full residual stream patching tells us *which layer* matters. But it includes contributions from all upstream components. Direct path patching isolates the contribution of a **single attention head** to the final position.\n",
"\n",
"### How it works\n",
"\n",
"An attention head's output is added to the residual stream at every position. To patch the **direct path** from head `(layer, head)` to the final position:\n",
"\n",
"1. Run the **clean** forward pass with cache\n",
"2. Start a **corrupted** forward pass\n",
"3. At the attention output hook for `(layer, head)`, replace only that head's output at the target position with the clean value\n",
"4. Let the rest of the forward pass continue normally\n",
"\n",
"This isolates the head's direct write to the residual stream — without changing what the head attends to in earlier layers."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def patch_head_output(layer, head, position):\n",
" \"\"\"\n",
" Patch the direct output of attention head (layer, head) at `position`\n",
" from the clean cache into a corrupted forward pass.\n",
" Returns normalized logit diff recovery.\n",
" \"\"\"\n",
" hook_name = f\"blocks.{layer}.attn.hook_result\"\n",
" # hook_result shape: [batch, seq, n_heads, d_model]\n",
" clean_head_out = clean_cache[hook_name][:, position, head, :].clone()\n",
"\n",
" def patch_hook(result, hook):\n",
" result[:, position, head, :] = clean_head_out\n",
" return result\n",
"\n",
" with torch.no_grad():\n",
" patched_logits = model.run_with_hooks(\n",
" corrupt_tokens,\n",
" fwd_hooks=[(hook_name, patch_hook)]\n",
" )\n",
"\n",
" patched_diff = logit_diff(patched_logits).item()\n",
" return (patched_diff - corrupt_diff) / (clean_diff - corrupt_diff)\n",
"\n",
"\n",
"# Scan all heads at the final token position\n",
"print(f\"Direct path patching: head outputs → final position ({final_pos})\")\n",
"print(f\"(Normalized recovery: 0=no effect, 1=full recovery)\\n\")\n",
"\n",
"results = {}\n",
"for layer in range(model.cfg.n_layers):\n",
" for head in range(model.cfg.n_heads):\n",
" recovery = patch_head_output(layer, head, final_pos)\n",
" results[(layer, head)] = recovery\n",
"\n",
"# Show top 10 most important heads\n",
"sorted_heads = sorted(results.items(), key=lambda x: abs(x[1]), reverse=True)\n",
"print(f\"{'Head':<12} {'Recovery':>12}\")\n",
"print(\"-\" * 26)\n",
"for (layer, head), recovery in sorted_heads[:10]:\n",
" bar = \"█\" * int(abs(recovery) * 30)\n",
" print(f\" L{layer}H{head:<6} {recovery:>+.3f} {bar}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Part 3: Visualize the Path Patching Results\n",
"\n",
"Plot a heatmap of direct path recovery across all (layer, head) pairs."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"try:\n",
" import matplotlib.pyplot as plt\n",
" import numpy as np\n",
"\n",
" grid = np.zeros((model.cfg.n_layers, model.cfg.n_heads))\n",
" for (layer, head), recovery in results.items():\n",
" grid[layer, head] = recovery\n",
"\n",
" fig, ax = plt.subplots(figsize=(12, 6))\n",
" im = ax.imshow(grid, cmap=\"RdBu\", vmin=-0.5, vmax=0.5, aspect=\"auto\")\n",
" ax.set_xlabel(\"Head\", fontsize=12)\n",
" ax.set_ylabel(\"Layer\", fontsize=12)\n",
" ax.set_title(\n",
" \"Direct Path Patching Recovery\\n\"\n",
" \"(Clean → Corrupted, patching head output at final token position)\",\n",
" fontsize=13\n",
" )\n",
" ax.set_xticks(range(model.cfg.n_heads))\n",
" ax.set_yticks(range(model.cfg.n_layers))\n",
" plt.colorbar(im, ax=ax, label=\"Normalized logit diff recovery\")\n",
" plt.tight_layout()\n",
" plt.show()\n",
" print(\"Heads with high recovery (red) have strong direct causal influence on the output.\")\n",
"except ImportError:\n",
" print(\"matplotlib not installed — skipping plot. Top heads shown in table above.\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Part 4: Why Direct Path Patching Matters\n",
"\n",
"### Difference from full activation patching\n",
"\n",
"Full residual stream patching at layer L includes contributions from *all* heads in layers 0..L plus all MLPs. It's a coarse measure.\n",
"\n",
"Direct path patching isolates a single head's write, making it much more precise for identifying specific circuit components.\n",
"\n",
"### Caution: the \"no mediation\" assumption\n",
"\n",
"Direct path patching measures the **direct** causal path: `head output → final position → logits`. It does NOT capture:\n",
"- Indirect paths where head A's output is read by head B, which then writes to the final position\n",
"- Effects mediated through MLPs\n",
"\n",
"For full circuit analysis, combine direct path patching with attention pattern analysis (`hook_attn`) to trace indirect paths.\n",
"\n",
"### Connection to the circuit discovery literature\n",
"\n",
"Direct path patching is the core operation in:\n",
"- Wang et al. (2022): *Interpretability in the Wild* (IOI circuit)\n",
"- Conmy et al. (2023): *ACDC* (automated circuit discovery)\n",
"- Marks et al. (2024): *Sparse Feature Circuits*\n",
"\n",
"### Limitations and robustness\n",
"\n",
"**Important:** High path patching recovery at a fixed position does not guarantee that the head is causally necessary for the behavior in all input distributions. Heads that appear causal on one prompt distribution may have their contribution \"read\" by the model differently on adversarially constructed inputs.\n",
"\n",
"Always validate circuit findings with:\n",
"1. Causal ablations (zero-ablate or mean-ablate the head)\n",
"2. Multiple prompt distributions\n",
"3. Cross-model replication"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Summary\n",
"\n",
"| Operation | TransformerLens hook | What it patches |\n",
"|---|---|---|\n",
"| Residual stream patching | `blocks.{L}.hook_resid_post` | Full residual stream at layer L |\n",
"| Direct head output patching | `blocks.{L}.attn.hook_result` | Single head's write to residual stream |\n",
"| Attention pattern patching | `blocks.{L}.attn.hook_attn` | Which tokens a head attends to |\n",
"| MLP output patching | `blocks.{L}.hook_mlp_out` | MLP contribution to residual stream |\n",
"\n",
"Each hook corresponds to a different granularity of causal intervention. Start coarse (residual stream), then refine (head output) to identify the minimal circuit responsible for a behavior."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"name": "python",
"version": "3.10.0"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
Loading
Loading