diff --git a/claudecode/claude_api_client.py b/claudecode/claude_api_client.py index 19a6d40..c5dab80 100644 --- a/claudecode/claude_api_client.py +++ b/claudecode/claude_api_client.py @@ -59,7 +59,7 @@ def validate_api_access(self) -> Tuple[bool, str]: try: # Simple test call to verify API access self.client.messages.create( - model="claude-3-5-haiku-20241022", + model=self.model, max_tokens=10, messages=[{"role": "user", "content": "Hello"}], timeout=10 diff --git a/claudecode/github_action_audit.py b/claudecode/github_action_audit.py index 7e9f608..62729aa 100644 --- a/claudecode/github_action_audit.py +++ b/claudecode/github_action_audit.py @@ -417,6 +417,7 @@ def initialize_findings_filter(custom_filtering_instructions: Optional[str] = No use_hard_exclusions=True, use_claude_filtering=True, api_key=api_key, + model=DEFAULT_CLAUDE_MODEL, custom_filtering_instructions=custom_filtering_instructions ) else: diff --git a/claudecode/test_helper_functions.py b/claudecode/test_helper_functions.py index 374a85c..f2a78cb 100644 --- a/claudecode/test_helper_functions.py +++ b/claudecode/test_helper_functions.py @@ -5,6 +5,7 @@ from unittest.mock import patch, MagicMock +from claudecode.constants import DEFAULT_CLAUDE_MODEL from claudecode.github_action_audit import ( get_environment_config, initialize_clients, @@ -115,6 +116,7 @@ def test_initialize_findings_filter_with_claude(self, mock_filter): use_hard_exclusions=True, use_claude_filtering=True, api_key='test-key-123', + model=DEFAULT_CLAUDE_MODEL, custom_filtering_instructions=None ) diff --git a/claudecode/test_model_configuration.py b/claudecode/test_model_configuration.py new file mode 100644 index 0000000..4545f43 --- /dev/null +++ b/claudecode/test_model_configuration.py @@ -0,0 +1,151 @@ +#!/usr/bin/env python3 +""" +Tests for model configuration propagation. + +Verifies that: +1. validate_api_access() uses self.model (not a hardcoded model string) +2. model parameter flows from FindingsFilter through to ClaudeAPIClient +3. GitHub Actions ::warning:: annotation is emitted when API validation fails +""" + +import io +import os +import sys +from unittest.mock import Mock, patch, MagicMock + +import pytest + +from claudecode.constants import DEFAULT_CLAUDE_MODEL + + +class TestValidateApiAccessUsesConfiguredModel: + """Test that validate_api_access() uses self.model instead of a hardcoded value.""" + + @patch("claudecode.claude_api_client.Anthropic") + def test_validate_uses_self_model(self, mock_anthropic_cls): + """validate_api_access should pass self.model to messages.create, not a hardcoded string.""" + mock_client = MagicMock() + mock_anthropic_cls.return_value = mock_client + + from claudecode.claude_api_client import ClaudeAPIClient + + custom_model = "claude-sonnet-4-20250514" + with patch.dict(os.environ, {"ANTHROPIC_API_KEY": "test-key"}): + client = ClaudeAPIClient(model=custom_model, api_key="test-key") + client.validate_api_access() + + call_kwargs = mock_client.messages.create.call_args[1] + assert call_kwargs["model"] == custom_model + + @patch("claudecode.claude_api_client.Anthropic") + def test_validate_never_uses_deprecated_haiku(self, mock_anthropic_cls): + """Regression: validate_api_access must never send the hardcoded deprecated model.""" + mock_client = MagicMock() + mock_anthropic_cls.return_value = mock_client + + from claudecode.claude_api_client import ClaudeAPIClient + + deprecated = "claude-3-5-haiku-20241022" + with patch.dict(os.environ, {"ANTHROPIC_API_KEY": "test-key"}): + client = ClaudeAPIClient(model=DEFAULT_CLAUDE_MODEL, api_key="test-key") + client.validate_api_access() + + call_kwargs = mock_client.messages.create.call_args[1] + assert call_kwargs["model"] != deprecated + + @patch("claudecode.claude_api_client.Anthropic") + def test_validate_uses_default_model_when_none_specified(self, mock_anthropic_cls): + """When model is not specified, validate_api_access should use DEFAULT_CLAUDE_MODEL.""" + mock_client = MagicMock() + mock_anthropic_cls.return_value = mock_client + + from claudecode.claude_api_client import ClaudeAPIClient + + with patch.dict(os.environ, {"ANTHROPIC_API_KEY": "test-key"}): + client = ClaudeAPIClient(api_key="test-key") + client.validate_api_access() + + call_kwargs = mock_client.messages.create.call_args[1] + assert call_kwargs["model"] == DEFAULT_CLAUDE_MODEL + + +class TestModelFlowsThroughFindingsFilter: + """Test that model propagates from initialize_findings_filter to ClaudeAPIClient.""" + + @patch("claudecode.findings_filter.ClaudeAPIClient") + def test_model_reaches_claude_api_client(self, mock_client_cls): + """Model passed to FindingsFilter should reach ClaudeAPIClient.__init__.""" + mock_client_cls.return_value.validate_api_access.return_value = (True, "") + + from claudecode.findings_filter import FindingsFilter + + test_model = "claude-opus-4-5" + FindingsFilter( + use_hard_exclusions=True, + use_claude_filtering=True, + api_key="test-key", + model=test_model, + ) + + call_kwargs = mock_client_cls.call_args[1] + assert call_kwargs["model"] == test_model + + @patch("claudecode.github_action_audit.FindingsFilter") + def test_initialize_findings_filter_passes_model(self, mock_filter_cls): + """initialize_findings_filter must pass model= to FindingsFilter.""" + mock_filter_cls.return_value = MagicMock() + + from claudecode.github_action_audit import initialize_findings_filter + + with patch.dict(os.environ, { + "ANTHROPIC_API_KEY": "test-key", + "ENABLE_CLAUDE_FILTERING": "true", + }): + initialize_findings_filter() + + call_kwargs = mock_filter_cls.call_args[1] + assert "model" in call_kwargs + assert call_kwargs["model"] == DEFAULT_CLAUDE_MODEL + + +class TestWarningAnnotationOnValidationFailure: + """Test that a GitHub Actions ::warning:: annotation is emitted when API validation fails.""" + + @patch("claudecode.findings_filter.ClaudeAPIClient") + def test_warning_annotation_emitted(self, mock_client_cls): + """When validate_api_access fails, a ::warning:: annotation should be printed.""" + mock_client_cls.return_value.validate_api_access.return_value = ( + False, "model_not_found: the model is deprecated" + ) + + from claudecode.findings_filter import FindingsFilter + + captured = io.StringIO() + with patch("sys.stdout", captured): + FindingsFilter( + use_hard_exclusions=True, + use_claude_filtering=True, + api_key="test-key", + model=DEFAULT_CLAUDE_MODEL, + ) + + output = captured.getvalue() + assert "::warning::" in output + + @patch("claudecode.findings_filter.ClaudeAPIClient") + def test_filtering_disabled_on_validation_failure(self, mock_client_cls): + """When validation fails, use_claude_filtering must be False.""" + mock_client_cls.return_value.validate_api_access.return_value = (False, "bad key") + + from claudecode.findings_filter import FindingsFilter + + with patch("sys.stdout", io.StringIO()): + f = FindingsFilter( + use_hard_exclusions=True, + use_claude_filtering=True, + api_key="test-key", + model=DEFAULT_CLAUDE_MODEL, + ) + + assert not f.use_claude_filtering + assert f.claude_client is None