Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 2 additions & 17 deletions tests/unit/model_bridge/compatibility/test_svd_interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from transformer_lens import SVDInterpreter
from transformer_lens.model_bridge import TransformerBridge

MODEL = "gpt2" # Use a model that works with TransformerBridge
MODEL = "Intel/tiny-random-gpt2" # Use a model that works with TransformerBridge
VECTOR_TYPES = ["OV", "w_in", "w_out"]
ATOL = 2e-4 # Absolute tolerance - how far does a float have to be before we consider it no longer equal?

Expand All @@ -17,14 +17,7 @@ def model():

@pytest.fixture(scope="module")
def second_model():
# Use a different model architecture if available, otherwise same model
# Note: If gpt2-medium fails to load, tests that need different models will be skipped
try:
return TransformerBridge.boot_transformers("gpt2-medium", device="cpu")
except Exception:
# Fallback to same model if gpt2-medium is not available
# The test will skip if both models end up being the same
return TransformerBridge.boot_transformers(MODEL, device="cpu")
return TransformerBridge.boot_transformers("hyper-accel/tiny-random-gpt2", device="cpu")


def test_svd_interpreter_returns_meaningful_values(model):
Expand Down Expand Up @@ -56,10 +49,6 @@ def test_svd_interpreter_returns_meaningful_values(model):


def test_svd_interpreter_returns_different_answers_for_different_layers(model):
# Only test if model has multiple layers
if model.cfg.n_layers < 2:
pytest.skip("Model only has one layer")

svd_interpreter = SVDInterpreter(model)

# Layer 0 results
Expand Down Expand Up @@ -91,10 +80,6 @@ def test_svd_interpreter_returns_different_answers_for_different_layers(model):


def test_svd_interpreter_returns_different_answers_for_different_models(model, second_model):
# Skip if both models are the same (check model name/config, not just object ID)
if id(model) == id(second_model) or model.cfg.model_name == second_model.cfg.model_name:
pytest.skip("Same model used for both fixtures")

# Get results from first model
svd_interpreter_1 = SVDInterpreter(model)
ov_1 = svd_interpreter_1.get_singular_vectors(
Expand Down
Loading