diff --git a/tests/unit/model_bridge/compatibility/test_svd_interpreter.py b/tests/unit/model_bridge/compatibility/test_svd_interpreter.py index b2b0de900..e73a087d6 100644 --- a/tests/unit/model_bridge/compatibility/test_svd_interpreter.py +++ b/tests/unit/model_bridge/compatibility/test_svd_interpreter.py @@ -5,7 +5,7 @@ from transformer_lens import SVDInterpreter from transformer_lens.model_bridge import TransformerBridge -MODEL = "gpt2" # Use a model that works with TransformerBridge +MODEL = "Intel/tiny-random-gpt2" # Use a model that works with TransformerBridge VECTOR_TYPES = ["OV", "w_in", "w_out"] ATOL = 2e-4 # Absolute tolerance - how far does a float have to be before we consider it no longer equal? @@ -17,14 +17,7 @@ def model(): @pytest.fixture(scope="module") def second_model(): - # Use a different model architecture if available, otherwise same model - # Note: If gpt2-medium fails to load, tests that need different models will be skipped - try: - return TransformerBridge.boot_transformers("gpt2-medium", device="cpu") - except Exception: - # Fallback to same model if gpt2-medium is not available - # The test will skip if both models end up being the same - return TransformerBridge.boot_transformers(MODEL, device="cpu") + return TransformerBridge.boot_transformers("hyper-accel/tiny-random-gpt2", device="cpu") def test_svd_interpreter_returns_meaningful_values(model): @@ -56,10 +49,6 @@ def test_svd_interpreter_returns_meaningful_values(model): def test_svd_interpreter_returns_different_answers_for_different_layers(model): - # Only test if model has multiple layers - if model.cfg.n_layers < 2: - pytest.skip("Model only has one layer") - svd_interpreter = SVDInterpreter(model) # Layer 0 results @@ -91,10 +80,6 @@ def test_svd_interpreter_returns_different_answers_for_different_layers(model): def test_svd_interpreter_returns_different_answers_for_different_models(model, second_model): - # Skip if both models are the same (check model name/config, not just object ID) - if id(model) == id(second_model) or model.cfg.model_name == second_model.cfg.model_name: - pytest.skip("Same model used for both fixtures") - # Get results from first model svd_interpreter_1 = SVDInterpreter(model) ov_1 = svd_interpreter_1.get_singular_vectors(