diff --git a/.importlinter b/.importlinter index a56c7d3c7e..3e5669e1a3 100644 --- a/.importlinter +++ b/.importlinter @@ -10,7 +10,7 @@ type = layers layers = ggshield.__main__ ggshield.cmd.auth | ggshield.cmd.config | ggshield.cmd.hmsl | ggshield.cmd.honeytoken | ggshield.cmd.install | ggshield.cmd.plugin | ggshield.cmd.quota | ggshield.cmd.secret | ggshield.cmd.status | ggshield.cmd.utils - ggshield.verticals.auth | ggshield.verticals.hmsl | ggshield.verticals.secret + ggshield.verticals.ai | ggshield.verticals.auth | ggshield.verticals.hmsl | ggshield.verticals.secret ggshield.core click | ggshield.utils | pygitguardian ignore_imports = @@ -33,6 +33,7 @@ source_modules = ggshield.cmd.status ggshield.cmd.utils forbidden_modules = + ggshield.verticals.ai ggshield.verticals.auth ggshield.verticals.hmsl ggshield.verticals.secret @@ -46,7 +47,7 @@ ignore_imports = ggshield.cmd.hmsl.** -> ggshield.verticals.hmsl.** ggshield.cmd.honeytoken.** -> ggshield.verticals.honeytoken ggshield.cmd.honeytoken.** -> ggshield.verticals.honeytoken.** - ggshield.cmd.install -> ggshield.verticals.secret.ai_hook + ggshield.cmd.install -> ggshield.verticals.ai.installation ggshield.cmd.install.** -> ggshield.verticals.install ggshield.cmd.install.** -> ggshield.verticals.install.** ggshield.cmd.plugin.** -> ggshield.core.plugin @@ -55,6 +56,7 @@ ignore_imports = ggshield.cmd.quota.** -> ggshield.verticals.quota.** ggshield.cmd.secret.** -> ggshield.verticals.secret ggshield.cmd.secret.** -> ggshield.verticals.secret.** + ggshield.cmd.secret.scan.ai_hook -> ggshield.verticals.ai.hooks ggshield.cmd.status.** -> ggshield.verticals.status ggshield.cmd.status.** -> ggshield.verticals.status.** ggshield.cmd.utils.** -> ggshield.verticals.utils diff --git a/ggshield/cmd/install.py b/ggshield/cmd/install.py index f835926d7f..9afee1469c 100644 --- a/ggshield/cmd/install.py +++ b/ggshield/cmd/install.py @@ -10,7 +10,7 @@ from ggshield.core.dirs import get_data_dir from ggshield.core.errors import UnexpectedError from ggshield.utils.git_shell import check_git_dir, git -from ggshield.verticals.secret.ai_hook import AI_FLAVORS, install_hooks +from ggshield.verticals.ai.installation import AGENTS, install_hooks # This snippet is used by the global hook to call the hook defined in the @@ -39,7 +39,7 @@ @click.option( "--hook-type", "-t", - type=click.Choice(["pre-commit", "pre-push"] + list(AI_FLAVORS.keys())), + type=click.Choice(["pre-commit", "pre-push"] + list(AGENTS.keys())), help="Type of hook to install.", default="pre-commit", ) @@ -61,7 +61,7 @@ def install_cmd( It can also install ggshield as a Cursor IDE or Claude Code agent hook. """ - if hook_type in AI_FLAVORS: + if hook_type in AGENTS: return install_hooks(name=hook_type, mode=mode, force=force) return_code = ( diff --git a/ggshield/cmd/secret/scan/ai_hook.py b/ggshield/cmd/secret/scan/ai_hook.py index c53ab6a716..3689b00131 100644 --- a/ggshield/cmd/secret/scan/ai_hook.py +++ b/ggshield/cmd/secret/scan/ai_hook.py @@ -10,9 +10,11 @@ from ggshield.core import ui from ggshield.core.client import create_client_from_config from ggshield.core.scan import ScanContext, ScanMode +from ggshield.verticals.ai.hooks import AIHookScanner from ggshield.verticals.secret import SecretScanner -from ggshield.verticals.secret.ai_hook import AIHookScanner -from ggshield.verticals.secret.ai_hook.models import MAX_READ_SIZE + + +MAX_READ_SIZE = 1024 * 1024 * 10 # We restrict stdin read to 10MB @click.command() diff --git a/ggshield/core/scan/__init__.py b/ggshield/core/scan/__init__.py index b37b6197f6..e2c370a110 100644 --- a/ggshield/core/scan/__init__.py +++ b/ggshield/core/scan/__init__.py @@ -3,6 +3,7 @@ from .scan_context import ScanContext from .scan_mode import ScanMode from .scannable import DecodeError, NonSeekableFileError, Scannable, StringScannable +from .scanner import ResultsProtocol, ScannerProtocol, SecretProtocol __all__ = [ @@ -11,8 +12,11 @@ "DecodeError", "File", "NonSeekableFileError", + "ResultsProtocol", "ScanContext", "ScanMode", "Scannable", + "ScannerProtocol", + "SecretProtocol", "StringScannable", ] diff --git a/ggshield/core/scan/scanner.py b/ggshield/core/scan/scanner.py new file mode 100644 index 0000000000..9f963669f6 --- /dev/null +++ b/ggshield/core/scan/scanner.py @@ -0,0 +1,50 @@ +""" +Protocols for SecretScanner and its results, +so that other verticals can use the scanner if they are provided one. +""" + +from collections.abc import Sequence +from typing import Iterable, Optional, Protocol + +from pygitguardian.models import Match + +from ggshield.core.scanner_ui import ScannerUI + +from . import Scannable + + +class SecretProtocol(Protocol): + """Abstract base class for secrets. + + We use getters instead of properties to have a . + """ + + @property + def detector_display_name(self) -> str: ... + + @property + def validity(self) -> str: ... + + @property + def matches(self) -> Sequence[Match]: ... + + +class ResultProtocol(Protocol): + @property + def secrets(self) -> Sequence[SecretProtocol]: ... + + +class ResultsProtocol(Protocol): + @property + def results(self) -> Sequence[ResultProtocol]: ... + + +class ScannerProtocol(Protocol): + """Protocol for scanners.""" + + def scan( + self, + files: Iterable[Scannable], + scanner_ui: ScannerUI, + scan_threads: Optional[int] = None, + ) -> ResultsProtocol: ... diff --git a/ggshield/verticals/ai/__init__.py b/ggshield/verticals/ai/__init__.py new file mode 100644 index 0000000000..cbd7e4db0c --- /dev/null +++ b/ggshield/verticals/ai/__init__.py @@ -0,0 +1,10 @@ +from .agents import AGENTS +from .hooks import AIHookScanner +from .installation import install_hooks + + +__all__ = [ + "AGENTS", + "AIHookScanner", + "install_hooks", +] diff --git a/ggshield/verticals/ai/agents/__init__.py b/ggshield/verticals/ai/agents/__init__.py new file mode 100644 index 0000000000..7289463cc7 --- /dev/null +++ b/ggshield/verticals/ai/agents/__init__.py @@ -0,0 +1,9 @@ +from .claude_code import Claude +from .copilot import Copilot +from .cursor import Cursor + + +AGENTS = {agent.name: agent for agent in [Cursor(), Claude(), Copilot()]} + + +__all__ = ["AGENTS", "Claude", "Copilot", "Cursor"] diff --git a/ggshield/verticals/secret/ai_hook/claude_code.py b/ggshield/verticals/ai/agents/claude_code.py similarity index 92% rename from ggshield/verticals/secret/ai_hook/claude_code.py rename to ggshield/verticals/ai/agents/claude_code.py index 378490f904..b9ea7670f1 100644 --- a/ggshield/verticals/secret/ai_hook/claude_code.py +++ b/ggshield/verticals/ai/agents/claude_code.py @@ -4,15 +4,21 @@ import click -from .models import EventType, Flavor, Result +from ..models import Agent, EventType, HookResult -class Claude(Flavor): +class Claude(Agent): """Behavior specific to Claude Code.""" - name = "Claude Code" + @property + def name(self) -> str: + return "claude-code" + + @property + def display_name(self) -> str: + return "Claude Code" - def output_result(self, result: Result) -> int: + def output_result(self, result: HookResult) -> int: response = {} if result.block: if result.payload.event_type in [ diff --git a/ggshield/verticals/secret/ai_hook/copilot.py b/ggshield/verticals/ai/agents/copilot.py similarity index 81% rename from ggshield/verticals/secret/ai_hook/copilot.py rename to ggshield/verticals/ai/agents/copilot.py index b523daa70b..300a8ce132 100644 --- a/ggshield/verticals/secret/ai_hook/copilot.py +++ b/ggshield/verticals/ai/agents/copilot.py @@ -3,8 +3,8 @@ import click +from ..models import EventType, HookResult from .claude_code import Claude -from .models import EventType, Result class Copilot(Claude): @@ -13,9 +13,15 @@ class Copilot(Claude): Inherits most of its behavior from Claude Code. """ - name = "Copilot" + @property + def name(self) -> str: + return "copilot" + + @property + def display_name(self) -> str: + return "Copilot Chat" - def output_result(self, result: Result) -> int: + def output_result(self, result: HookResult) -> int: response = {} if result.block: if result.payload.event_type == EventType.PRE_TOOL_USE: diff --git a/ggshield/verticals/secret/ai_hook/cursor.py b/ggshield/verticals/ai/agents/cursor.py similarity index 87% rename from ggshield/verticals/secret/ai_hook/cursor.py rename to ggshield/verticals/ai/agents/cursor.py index 3b20c46fec..c75e807380 100644 --- a/ggshield/verticals/secret/ai_hook/cursor.py +++ b/ggshield/verticals/ai/agents/cursor.py @@ -4,15 +4,21 @@ import click -from .models import EventType, Flavor, Result +from ..models import Agent, EventType, HookResult -class Cursor(Flavor): +class Cursor(Agent): """Behavior specific to Cursor.""" - name = "Cursor" + @property + def name(self) -> str: + return "cursor" + + @property + def display_name(self) -> str: + return "Cursor" - def output_result(self, result: Result) -> int: + def output_result(self, result: HookResult) -> int: response = {} if result.payload.event_type == EventType.USER_PROMPT: response["continue"] = not result.block diff --git a/ggshield/verticals/secret/ai_hook/scanner.py b/ggshield/verticals/ai/hooks.py similarity index 57% rename from ggshield/verticals/secret/ai_hook/scanner.py rename to ggshield/verticals/ai/hooks.py index 7eacfe5abe..4307398c79 100644 --- a/ggshield/verticals/secret/ai_hook/scanner.py +++ b/ggshield/verticals/ai/hooks.py @@ -6,15 +6,37 @@ from notifypy import Notify from ggshield.core.filter import censor_match +from ggshield.core.scan import ScannerProtocol +from ggshield.core.scan import SecretProtocol as Secret from ggshield.core.scanner_ui import create_message_only_scanner_ui from ggshield.core.text_utils import pluralize, translate_validity -from ggshield.verticals.secret import SecretScanner -from ggshield.verticals.secret.ai_hook.copilot import Copilot -from ggshield.verticals.secret.secret_scan_collection import Secret -from .claude_code import Claude -from .cursor import Cursor -from .models import EventType, Flavor, Payload, Result, Tool +from .agents import Claude, Copilot, Cursor +from .models import Agent, EventType, HookPayload, HookResult, Tool + + +HOOK_NAME_TO_EVENT_TYPE = { + "userpromptsubmit": EventType.USER_PROMPT, + "beforesubmitprompt": EventType.USER_PROMPT, + "pretooluse": EventType.PRE_TOOL_USE, + "posttooluse": EventType.POST_TOOL_USE, +} + +TOOL_NAME_TO_TOOL = { + "shell": Tool.BASH, # Cursor + "bash": Tool.BASH, # Claude Code + "run_in_terminal": Tool.BASH, # Copilot + "read": Tool.READ, # Claude/Cursor + "read_file": Tool.READ, # Copilot +} + + +def lookup(data: Dict[str, Any], keys: Sequence[str], default: Any = None) -> Any: + """Returns the value of the first key found in a dictionary.""" + for key in keys: + if key in data: + return data[key] + return default # Regex (and method) to look for any @file_path in the prompt. @@ -41,12 +63,121 @@ def find_filepaths(prompt: str) -> Set[str]: return paths +def parse_hook_input(raw_content: str) -> list[HookPayload]: + """Parse the input content. Raises a ValueError if the input is not valid. + + Returns: + A list of payloads. Most of the time the list will contain only one payload, + but in some cases files mentioned in the prompt will be read but the + PreToolUse event will not be called. So we need to handle this case ourselves. + """ + # Parse the content as JSON + if not raw_content.strip(): + raise ValueError("Error: No input received on stdin") + try: + data = json.loads(raw_content) + except json.JSONDecodeError as e: + raise ValueError(f"Error: Failed to parse JSON from stdin: {e}") from e + + payloads = [] + + # Try to guess which AI coding assistant is calling us + agent = _detect_agent(data) + + # Infer the event type + event_name = lookup(data, ["hook_event_name", "hookEventName"], None) + if event_name is None: + raise ValueError("Error: couldn't find event type") + event_type = HOOK_NAME_TO_EVENT_TYPE.get(event_name.lower(), EventType.OTHER) + + identifier = "" + content = "" + tool = None + + # Extract the identifier and content based on the event type + if event_type == EventType.USER_PROMPT: + content = data.get("prompt", "") + # Look for files mentioned in the prompt that could be read + # without triggering a PRE_TOOL_USE event. + payloads.extend(_parse_user_prompt(content, event_type, agent)) + + elif event_type == EventType.PRE_TOOL_USE: + tool_name = data.get("tool_name", "").lower() + tool = TOOL_NAME_TO_TOOL.get(tool_name, Tool.OTHER) + tool_input = data.get("tool_input", {}) + # Select the content based on the tool + if tool == Tool.BASH: + content = tool_input.get("command", "") + identifier = content + elif tool == Tool.READ: + # We only need to deal with the identifier, the content will be read by the Scannable + identifier = lookup(tool_input, ["file_path", "filePath"], "") + + elif event_type == EventType.POST_TOOL_USE: + tool_name = data.get("tool_name", "").lower() + tool = TOOL_NAME_TO_TOOL.get(tool_name, Tool.OTHER) + content = data.get("tool_output", "") or data.get("tool_response", {}) + # Claude Code returns a dict for the tool output + if isinstance(content, (dict, list)): + content = json.dumps(content) + + # If identifier was not set, hash the content + if not identifier: + identifier = hashlib.sha256((content or "").encode()).hexdigest() + + payloads.append( + HookPayload( + event_type=event_type, + tool=tool, + content=content, + identifier=identifier, + agent=agent, + ) + ) + return payloads + + +def _detect_agent(data: Dict[str, Any]) -> Agent: + """Detect the AI code assistant.""" + if "cursor_version" in data: + return Cursor() + elif "github.copilot-chat" in data.get("transcript_path", "").lower(): + return Copilot() + # no .lower() here to reduce the risk of false positives (this is also why this check is last) + elif "session_id" in data and "claude" in data.get("transcript_path", ""): + return Claude() + # No other agent is supported yet + raise ValueError("Unsupported agent") + + +def _parse_user_prompt( + content: str, event_type: EventType, agent: Agent +) -> List[HookPayload]: + """Parse the user prompt for additional payloads that we may miss.""" + payloads = [] + # Scenario 1 (the only one we know about so far): + # Code assistants don't always trigger a PRE_TOOL_USE event when + # a file is mentioned in the prompt, especially with an "@" prefix. + matches = find_filepaths(content) + for match in matches: + payloads.append( + HookPayload( + event_type=event_type, + tool=Tool.READ, + content="", + identifier=match, + agent=agent, + ) + ) + return payloads + + class AIHookScanner: """AI hook scanner. It is called with the payload of a hook event. Note that instead of having a base class with common method and a subclass per supported AI tool, - we instead have a single class which detects which protocol to use (called "flavor"). + we instead have a single class which detects which protocol to use. This is because some tools sloppily support hooks from others. For instance, Cursor will call hooks defined in the Claude Code format, but send payload in its own format. So we can't assume which tool will call us based on the command line/hook configuration only. @@ -55,98 +186,27 @@ class AIHookScanner: ValueError: If the input is not valid. """ - def __init__(self, scanner: SecretScanner): + def __init__(self, scanner: ScannerProtocol): self.scanner = scanner def scan(self, content: str) -> int: """Scan the content, print the result and return the exit code.""" - payloads = self._parse_input(content) + payloads = parse_hook_input(content) result = self._scan_payloads(payloads) payload = result.payload # Special case: in post-tool use, the action is already done: at least notify the user if result.block and payload.event_type == EventType.POST_TOOL_USE: self._send_secret_notification( - result.nbr_secrets, payload.tool or Tool.OTHER, payload.flavor.name + result.nbr_secrets, + payload.tool or Tool.OTHER, + payload.agent.display_name, ) - return payload.flavor.output_result(result) - - def _parse_input(self, raw_content: str) -> list[Payload]: - """Parse the input content. Raises a ValueError if the input is not valid. - - Returns: - A list of payloads. Most of the time the list will contain only one payload, - but in some cases files mentioned in the prompt will be read but the - PreToolUse event will not be called. So we need to handle this case ourselves. - """ - # Parse the content as JSON - if not raw_content.strip(): - raise ValueError("Error: No input received on stdin") - try: - data = json.loads(raw_content) - except json.JSONDecodeError as e: - raise ValueError(f"Error: Failed to parse JSON from stdin: {e}") from e - - payloads = [] - - # Try to guess which AI coding assistant is calling us - flavor = self._detect_flavor(data) - - # Infer the event type - event_name = lookup(data, ["hook_event_name", "hookEventName"], None) - if event_name is None: - raise ValueError("Error: couldn't find event type") - event_type = HOOK_NAME_TO_EVENT_TYPE.get(event_name.lower(), EventType.OTHER) - - identifier = "" - content = "" - tool = None - - # Extract the identifier and content based on the event type - if event_type == EventType.USER_PROMPT: - content = data.get("prompt", "") - # Look for files mentioned in the prompt that could be read - # without triggering a PRE_TOOL_USE event. - payloads.extend(self._parse_user_prompt(content, event_type, flavor)) - - elif event_type == EventType.PRE_TOOL_USE: - tool_name = data.get("tool_name", "").lower() - tool = TOOL_NAME_TO_TOOL.get(tool_name, Tool.OTHER) - tool_input = data.get("tool_input", {}) - # Select the content based on the tool - if tool == Tool.BASH: - content = tool_input.get("command", "") - identifier = content - elif tool == Tool.READ: - # We only need to deal with the identifier, the content will be read by the Scannable - identifier = lookup(tool_input, ["file_path", "filePath"], "") - - elif event_type == EventType.POST_TOOL_USE: - tool_name = data.get("tool_name", "").lower() - tool = TOOL_NAME_TO_TOOL.get(tool_name, Tool.OTHER) - content = data.get("tool_output", "") or data.get("tool_response", {}) - # Claude Code returns a dict for the tool output - if isinstance(content, (dict, list)): - content = json.dumps(content) - - # If identifier was not set, hash the content - if not identifier: - identifier = hashlib.sha256((content or "").encode()).hexdigest() - - payloads.append( - Payload( - event_type=event_type, - tool=tool, - content=content, - identifier=identifier, - flavor=flavor, - ) - ) - return payloads + return payload.agent.output_result(result) - def _scan_payloads(self, payloads: List[Payload]) -> Result: + def _scan_payloads(self, payloads: List[HookPayload]) -> HookResult: """Scan payloads for secrets using the SecretScanner. Returns: @@ -159,16 +219,16 @@ def _scan_payloads(self, payloads: List[Payload]) -> Result: result = self._scan_content(payload) if result.block: return result - return Result.allow(payloads[0]) + return HookResult.allow(payloads[0]) def _scan_content( self, - payload: Payload, - ) -> Result: + payload: HookPayload, + ) -> HookResult: """Scan content for secrets using the SecretScanner.""" # Short path: if there is no content, no need to do an API call if payload.empty: - return Result.allow(payload) + return HookResult.allow(payload) with create_message_only_scanner_ui() as scanner_ui: results = self.scanner.scan([payload.scannable], scanner_ui=scanner_ui) @@ -178,58 +238,25 @@ def _scan_content( secrets.extend(result.secrets) if not secrets: - return Result.allow(payload) + return HookResult.allow(payload) message = self._message_from_secrets( secrets, payload, escape_markdown=True, ) - return Result( + return HookResult( block=True, message=message, nbr_secrets=len(secrets), payload=payload, ) - @staticmethod - def _detect_flavor(data: Dict[str, Any]) -> Flavor: - """Detect the AI code assistant.""" - if "cursor_version" in data: - return Cursor() - elif "github.copilot-chat" in data.get("transcript_path", "").lower(): - return Copilot() - # no .lower() here to reduce the risk of false positives (this is also why this check is last) - elif "session_id" in data and "claude" in data.get("transcript_path", ""): - return Claude() - else: - # Fallback that respect base conventions - return Flavor() - - def _parse_user_prompt( - self, content: str, event_type: EventType, flavor: Flavor - ) -> List[Payload]: - """Parse the user prompt for additional payloads that we may miss.""" - payloads = [] - # Scenario 1 (the only one we know about so far): - # Code assistants don't always trigger a PRE_TOOL_USE event when - # a file is mentioned in the prompt, especially with an "@" prefix. - matches = find_filepaths(content) - for match in matches: - payloads.append( - Payload( - event_type=event_type, - tool=Tool.READ, - content="", - identifier=match, - flavor=flavor, - ) - ) - return payloads - @staticmethod def _message_from_secrets( - secrets: List[Secret], payload: Payload, escape_markdown: bool = False + secrets: List[Secret], + payload: HookPayload, + escape_markdown: bool = False, ) -> str: """ Format detected secrets into a user-friendly message. @@ -308,27 +335,3 @@ def _send_secret_notification( # This is best effort, we don't want to propagate an error # if the notification fails. pass - - -HOOK_NAME_TO_EVENT_TYPE = { - "userpromptsubmit": EventType.USER_PROMPT, - "beforesubmitprompt": EventType.USER_PROMPT, - "pretooluse": EventType.PRE_TOOL_USE, - "posttooluse": EventType.POST_TOOL_USE, -} - -TOOL_NAME_TO_TOOL = { - "shell": Tool.BASH, # Cursor - "bash": Tool.BASH, # Claude Code - "run_in_terminal": Tool.BASH, # Copilot - "read": Tool.READ, # Claude/Cursor - "read_file": Tool.READ, # Copilot -} - - -def lookup(data: Dict[str, Any], keys: Sequence[str], default: Any = None) -> Any: - """Returns the value of the first key found in a dictionary.""" - for key in keys: - if key in data: - return data[key] - return default diff --git a/ggshield/verticals/secret/ai_hook/installation.py b/ggshield/verticals/ai/installation.py similarity index 87% rename from ggshield/verticals/secret/ai_hook/installation.py rename to ggshield/verticals/ai/installation.py index 3139c7b3d5..ca328b374f 100644 --- a/ggshield/verticals/secret/ai_hook/installation.py +++ b/ggshield/verticals/ai/installation.py @@ -9,16 +9,7 @@ from ggshield.core.dirs import get_user_home_dir from ggshield.core.errors import UnexpectedError -from .claude_code import Claude -from .copilot import Copilot -from .cursor import Cursor - - -AI_FLAVORS = { - "cursor": Cursor, - "claude-code": Claude, - "copilot": Copilot, -} +from .agents import AGENTS @dataclass @@ -41,12 +32,12 @@ def install_hooks( """ try: - flavor = AI_FLAVORS[name]() + agent = AGENTS[name] except KeyError: - raise ValueError(f"Unsupported tool name: {name}") + raise ValueError(f"Unsupported agent: {name}") base_dir = get_user_home_dir() if mode == "global" else Path(".") - settings_path = base_dir / flavor.settings_path + settings_path = base_dir / agent.settings_path command = "ggshield secret scan ai-hook" @@ -71,11 +62,11 @@ def install_hooks( stats = _fill_dict( config=existing_config, - template=flavor.settings_template, + template=agent.settings_template, command=command, overwrite=force, stats=stats, - locator=flavor.settings_locate, + locator=agent.settings_locate, ) # Ensure parent directory exists @@ -89,11 +80,11 @@ def install_hooks( # Report what happened styled_path = click.style(settings_path, fg="yellow", bold=True) if stats.added == 0 and stats.already_present > 0: - click.echo(f"{flavor.name} hooks already installed in {styled_path}") + click.echo(f"{agent.display_name} hooks already installed in {styled_path}") elif stats.added > 0 and stats.already_present > 0: - click.echo(f"{flavor.name} hooks updated in {styled_path}") + click.echo(f"{agent.display_name} hooks updated in {styled_path}") else: - click.echo(f"{flavor.name} hooks successfully added in {styled_path}") + click.echo(f"{agent.display_name} hooks successfully added in {styled_path}") return 0 diff --git a/ggshield/verticals/secret/ai_hook/models.py b/ggshield/verticals/ai/models.py similarity index 80% rename from ggshield/verticals/secret/ai_hook/models.py rename to ggshield/verticals/ai/models.py index a031e2eea3..fecf9ddd27 100644 --- a/ggshield/verticals/secret/ai_hook/models.py +++ b/ggshield/verticals/ai/models.py @@ -1,17 +1,13 @@ +from abc import ABC, abstractmethod from dataclasses import dataclass from enum import Enum, auto from pathlib import Path from typing import Any, Dict, List, Optional -import click - from ggshield.core.scan import File, Scannable, StringScannable from ggshield.utils.files import is_path_binary -MAX_READ_SIZE = 1024 * 1024 * 50 # We restrict payloads read to 50MB - - class EventType(Enum): """Event type constants for hook events.""" @@ -33,31 +29,64 @@ class Tool(Enum): @dataclass -class Result: +class HookResult: """Result of a scan: allow or not.""" block: bool message: str nbr_secrets: int - payload: "Payload" + payload: "HookPayload" @classmethod - def allow(cls, payload: "Payload") -> "Result": + def allow(cls, payload: "HookPayload") -> "HookResult": return cls(block=False, message="", nbr_secrets=0, payload=payload) -class Flavor: +@dataclass +class HookPayload: + event_type: EventType + tool: Optional[Tool] + content: str + identifier: str + agent: "Agent" + + @property + def scannable(self) -> Scannable: + """Return the appropriate Scannable for the payload.""" + if self.tool == Tool.READ: + path = Path(self.identifier) + if path.is_file() and not is_path_binary(path): + return File(path=self.identifier) + return StringScannable(url=self.identifier, content=self.content) + + @property + def empty(self) -> bool: + """Return True if the payload is empty.""" + return not self.scannable.is_longer_than(0) + + +class Agent(ABC): """ Class that can be derived to implement behavior specific to some AI code assistants. """ - name = "Your AI coding tool" + # Metadata - def output_result(self, result: Result) -> int: - """How to output the result of a scan. + @property + @abstractmethod + def display_name(self) -> str: + """A user-friendly name for the agent.""" - This base implementation has sensible defaults (like returning 2 in case of a block, - and printing the output in stderr or stdout). + @property + @abstractmethod + def name(self) -> str: + """The name of the agent.""" + + # Hooks + + @abstractmethod + def output_result(self, result: HookResult) -> int: + """How to output the result of a scan. This method is expected to have side effects, like printing to stdout or stderr. @@ -66,26 +95,23 @@ def output_result(self, result: Result) -> int: Returns: the exit code. """ - if result.block: - click.echo(result.message, err=True) - return 2 - else: - click.echo("No secrets found. Good to go.") - return 0 + + # Settings @property + @abstractmethod def settings_path(self) -> Path: """Path to the settings file for this AI coding tool.""" - return Path(".agents") / "hooks.json" @property + @abstractmethod def settings_template(self) -> Dict[str, Any]: """ Template for the settings file for this AI coding tool. Use the sentinel "" for the places where the command should be inserted. """ - return {} + @abstractmethod def settings_locate( self, candidates: List[Dict[str, Any]], template: Dict[str, Any] ) -> Optional[Dict[str, Any]]: @@ -102,26 +128,3 @@ def settings_locate( Returns: the object to update, or None if no object was found. """ return None - - -@dataclass -class Payload: - event_type: EventType - tool: Optional[Tool] - content: str - identifier: str - flavor: Flavor - - @property - def scannable(self) -> Scannable: - """Return the appropriate Scannable for the payload.""" - if self.tool == Tool.READ: - path = Path(self.identifier) - if path.is_file() and not is_path_binary(path): - return File(path=self.identifier) - return StringScannable(url=self.identifier, content=self.content) - - @property - def empty(self) -> bool: - """Return True if the payload is empty.""" - return not self.scannable.is_longer_than(0) diff --git a/ggshield/verticals/secret/ai_hook/__init__.py b/ggshield/verticals/secret/ai_hook/__init__.py deleted file mode 100644 index 68d4170d4a..0000000000 --- a/ggshield/verticals/secret/ai_hook/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -from .installation import AI_FLAVORS, install_hooks -from .scanner import AIHookScanner - - -__all__ = ["AIHookScanner", "install_hooks", "AI_FLAVORS"] diff --git a/scripts/generate-import-linter-config.py b/scripts/generate-import-linter-config.py index bdb0823275..3bf75f2a76 100755 --- a/scripts/generate-import-linter-config.py +++ b/scripts/generate-import-linter-config.py @@ -64,8 +64,10 @@ class Contract(TypedDict): "ggshield.cmd.{}.** -> ggshield.verticals.{}.**", # FIXME: #521 - enforce boundaries between cmd.auth and verticals.hmsl "ggshield.cmd.auth.** -> ggshield.verticals.hmsl.**", - # Logic to install hooks for AI assistants - "ggshield.cmd.install -> ggshield.verticals.secret.ai_hook", + # Install command import logic to install AI hooks + "ggshield.cmd.install -> ggshield.verticals.ai.installation", + # AI hook command import logic to scan AI hook payloads + "ggshield.cmd.secret.scan.ai_hook -> ggshield.verticals.ai.hooks", ], "unmatched_ignore_imports_alerting": "none", }, diff --git a/tests/unit/verticals/secret/ai_hook/test_hooks.py b/tests/unit/verticals/ai/test_hooks.py similarity index 83% rename from tests/unit/verticals/secret/ai_hook/test_hooks.py rename to tests/unit/verticals/ai/test_hooks.py index f763cf76b4..adf355774f 100644 --- a/tests/unit/verticals/secret/ai_hook/test_hooks.py +++ b/tests/unit/verticals/ai/test_hooks.py @@ -7,18 +7,32 @@ import pytest from ggshield.utils.git_shell import Filemode +from ggshield.verticals.ai.agents import Claude, Copilot, Cursor +from ggshield.verticals.ai.hooks import AIHookScanner, find_filepaths, parse_hook_input +from ggshield.verticals.ai.models import EventType, HookPayload, HookResult, Tool from ggshield.verticals.secret import SecretScanner -from ggshield.verticals.secret.ai_hook.claude_code import Claude -from ggshield.verticals.secret.ai_hook.copilot import Copilot -from ggshield.verticals.secret.ai_hook.cursor import Cursor -from ggshield.verticals.secret.ai_hook.models import EventType, Flavor, Payload -from ggshield.verticals.secret.ai_hook.models import Result as HookResult -from ggshield.verticals.secret.ai_hook.models import Tool -from ggshield.verticals.secret.ai_hook.scanner import AIHookScanner, find_filepaths from ggshield.verticals.secret.secret_scan_collection import Result as ScanResult from ggshield.verticals.secret.secret_scan_collection import Results, Secret +def _dummy_payload(event_type: EventType = EventType.OTHER) -> HookPayload: + return HookPayload( + event_type=event_type, + tool=None, + content="", + identifier="", + agent=Cursor(), + ) + + +@pytest.fixture +def tmp_file(tmp_path: Path) -> Path: + """Create a temporary file with content.""" + file = tmp_path / "test.txt" + file.write_text("this is the content") + return file + + def _mock_scanner(matches: List[str]) -> MagicMock: """Create a mock SecretScanner that returns the given Results from scan().""" mock = MagicMock(spec=SecretScanner) @@ -62,26 +76,46 @@ def _make_secret(match_str: str = "***"): ) -def _dummy_payload(event_type: EventType = EventType.OTHER) -> Payload: - return Payload( - event_type=event_type, - tool=None, - content="", - identifier="", - flavor=Flavor(), - ) +class TestAIHookScannerScanContent: + """Unit tests for AIHookScanner._scan_content.""" + def test_no_secrets_returns_allow(self): + """When scanner returns no secrets, result has block=False and nbr_secrets=0.""" + hook_scanner = AIHookScanner(_mock_scanner([])) + payload = HookPayload( + event_type=EventType.USER_PROMPT, + tool=None, + content="safe content", + identifier="id", + agent=Cursor(), + ) + result = hook_scanner._scan_content(payload) + assert isinstance(result, HookResult) + assert result.block is False + assert result.nbr_secrets == 0 + assert result.message == "" -@pytest.fixture -def tmp_file(tmp_path: Path) -> Path: - """Create a temporary file with content.""" - file = tmp_path / "test.txt" - file.write_text("this is the content") - return file + def test_with_secrets_returns_block_and_message(self): + """When scanner returns secrets, result has block=True, nbr_secrets and message set.""" + hook_scanner = AIHookScanner(_mock_scanner(["sk-xxx"])) + payload = HookPayload( + event_type=EventType.USER_PROMPT, + tool=None, + content="content with sk-xxx", + identifier="id", + agent=Cursor(), + ) + result = hook_scanner._scan_content(payload) + assert isinstance(result, HookResult) + assert result.block is True + assert result.nbr_secrets == 1 + assert "dummy-detector" in result.message + assert "secret" in result.message.lower() + assert "remove the secrets from your prompt" in result.message -class TestAIHookScannerParseInput: - """Unit tests for AIHookScanner._parse_input.""" +class TestAIHookScannerScan: + """Unit tests for the AIHookScanner.scan() method.""" def test_empty_input_raises(self): """Empty or whitespace-only input raises ValueError.""" @@ -91,23 +125,178 @@ def test_empty_input_raises(self): with pytest.raises(ValueError, match="No input received on stdin"): scanner.scan(" \n ") + def test_scan_no_secrets_returns_zero(self): + """scan() with no secrets returns 0.""" + scanner = AIHookScanner(_mock_scanner([])) + data = { + "hook_event_name": "UserPromptSubmit", + "prompt": "hello world", + "transcript_path": "/home/user/.claude/projects/foo/session.jsonl", + "cursor_version": "1.2.3", + } + code = scanner.scan(json.dumps(data)) + assert code == 0 + + @patch("ggshield.verticals.ai.hooks.AIHookScanner._send_secret_notification") + def test_scan_post_tool_use_with_secrets_sends_notification( + self, mock_notify: MagicMock + ): + """scan() on POST_TOOL_USE with secrets sends a notification and returns 0 (no block).""" + scanner = AIHookScanner(_mock_scanner(["sk-xxx"])) + data = { + "hook_event_name": "PostToolUse", + "tool_name": "Bash", + "tool_input": {"command": "echo sk-xxx"}, + "tool_response": {"stdout": "sk-xxx\n"}, + "transcript_path": "/home/user/.claude/projects/foo/session.jsonl", + "session_id": "427ae0c5-0862-4e14-aa2c-12fad909c323", + } + code = scanner.scan(json.dumps(data)) + assert code == 0 + mock_notify.assert_called_once() + args = mock_notify.call_args[0] + assert args[0] == 1 # nbr_secrets + assert args[1] == Tool.BASH # tool + + def test_scan_pre_tool_use_with_secrets_blocks(self): + """scan() on PRE_TOOL_USE with secrets returns block result.""" + scanner = AIHookScanner(_mock_scanner(["sk-xxx"])) + data = { + "hook_event_name": "PreToolUse", + "tool_name": "Bash", + "tool_input": {"command": "echo sk-xxx"}, + "session_id": "427ae0c5-0862-4e14-aa2c-12fad909c323", + "transcript_path": "/home/user/.claude/projects/foo/session.jsonl", + } + code = scanner.scan(json.dumps(data)) + # Claude output_result always returns 0 + assert code == 0 + + def test_scan_no_content_returns_allow(self): + """scan() with no content returns 0 (and doesn't call the API).""" + mock_scanner = _mock_scanner([]) + scanner = AIHookScanner(mock_scanner) + data = { + "hook_event_name": "PreToolUse", + "tool_name": "Read", + "tool_input": {"file_path": "doesn-t-exist"}, + "cursor_version": "1.2.3", + } + code = scanner.scan(json.dumps(data)) + assert code == 0 + mock_scanner.scan.assert_not_called() + + def test_scan_payloads_refuse_empty_list(self): + """scan() with empty list of payloads raises ValueError.""" + scanner = AIHookScanner(_mock_scanner([])) + with pytest.raises(ValueError): + scanner._scan_payloads([]) + + +class TestMessageFromSecrets: + """Unit tests for AIHookScanner._message_from_secrets with different payload types.""" + + def test_message_for_bash_tool(self): + """Message for BASH tool mentions environment variables.""" + payload = HookPayload( + event_type=EventType.PRE_TOOL_USE, + tool=Tool.BASH, + content="echo sk-xxx", + identifier="echo sk-xxx", + agent=Cursor(), + ) + message = AIHookScanner._message_from_secrets([_make_secret("sk-xxx")], payload) + assert "remove the secrets from the command" in message + assert "environment variables" in message + + def test_message_for_read_tool(self): + """Message for READ tool mentions file content.""" + payload = HookPayload( + event_type=EventType.PRE_TOOL_USE, + tool=Tool.READ, + content="file content with secret", + identifier="/path/to/file", + agent=Cursor(), + ) + message = AIHookScanner._message_from_secrets([_make_secret("sk-xxx")], payload) + assert "remove the secrets from" in message + + def test_message_for_other_tool(self): + """Message for OTHER tool uses generic message.""" + payload = HookPayload( + event_type=EventType.PRE_TOOL_USE, + tool=Tool.OTHER, + content="some content", + identifier="id", + agent=Cursor(), + ) + message = AIHookScanner._message_from_secrets([_make_secret("sk-xxx")], payload) + assert "remove the secrets from the tool input" in message + + def test_message_escapes_markdown(self): + """When escape_markdown=True, asterisks in matches are replaced with dots.""" + payload = HookPayload( + event_type=EventType.USER_PROMPT, + tool=None, + content="content", + identifier="id", + agent=Cursor(), + ) + message = AIHookScanner._message_from_secrets( + [_make_secret("sk-xxx")], payload, escape_markdown=True + ) + # The message itself should not contain raw asterisks from matches + # (the header uses ** for bold which is intentional) + assert "Detected" in message + + +class TestSendSecretNotification: + """Unit tests for AIHookScanner._send_secret_notification.""" + + @patch("ggshield.verticals.ai.hooks.Notify") + def test_notification_for_bash_tool(self, mock_notify_cls: MagicMock): + """Notification for BASH tool says 'running a command'.""" + AIHookScanner._send_secret_notification(1, Tool.BASH, "Claude Code") + instance = mock_notify_cls.return_value + assert "running a command" in instance.message + assert "Claude Code" in instance.message + instance.send.assert_called_once() + + @patch("ggshield.verticals.ai.hooks.Notify") + def test_notification_for_read_tool(self, mock_notify_cls: MagicMock): + """Notification for READ tool says 'reading a file'.""" + AIHookScanner._send_secret_notification(2, Tool.READ, "Cursor") + instance = mock_notify_cls.return_value + assert "reading a file" in instance.message + assert "2" in instance.message + instance.send.assert_called_once() + + @patch("ggshield.verticals.ai.hooks.Notify") + def test_notification_for_other_tool(self, mock_notify_cls: MagicMock): + """Notification for OTHER tool says 'using a tool'.""" + AIHookScanner._send_secret_notification(1, Tool.OTHER, "Copilot") + instance = mock_notify_cls.return_value + assert "using a tool" in instance.message + instance.send.assert_called_once() + + +class TestAIHookScannerParseInput: + """Unit tests for AIHookparse_hook_input.""" + def test_invalid_json_raises(self): """Invalid JSON raises ValueError with parse error.""" - scanner = AIHookScanner(_mock_scanner([])) with pytest.raises(ValueError, match="Failed to parse JSON"): - scanner._parse_input("not json {") + parse_hook_input("not json {") with pytest.raises(ValueError, match="Failed to parse JSON"): - scanner._parse_input("{ missing brace ") + parse_hook_input("{ missing brace ") def test_missing_event_type_raises(self): """JSON without event type raises ValueError.""" - scanner = AIHookScanner(_mock_scanner([])) - with pytest.raises(ValueError, match="couldn't find event type"): - scanner._parse_input('{"prompt": "hello"}') + with pytest.raises(ValueError): + parse_hook_input('{"prompt": "hello"}') def test_cursor_user_prompt(self): """Test Cursor beforeSubmitPrompt (user prompt) parsing.""" - scanner = AIHookScanner(_mock_scanner([])) data = { "conversation_id": "75fed8a8-2078-4e49-80d2-776b20d441c3", "generation_id": "1501ede6-b8ac-43f4-9943-0e218610c5c6", @@ -120,16 +309,15 @@ def test_cursor_user_prompt(self): "user_email": "user@example.com", "transcript_path": "/home/user1/.cursor/projects/foo/agent-transcripts/75fed8a8/75fed8a8.jsonl", } - payload = scanner._parse_input(json.dumps(data))[0] + payload = parse_hook_input(json.dumps(data))[0] assert payload.event_type == EventType.USER_PROMPT assert payload.content == "hello world" assert payload.tool is None assert payload.identifier != "" - assert isinstance(payload.flavor, Cursor) + assert isinstance(payload.agent, Cursor) def test_cursor_pre_tool_use_shell(self): """Test Cursor preToolUse with Shell (bash) parsing.""" - scanner = AIHookScanner(_mock_scanner([])) data = { "conversation_id": "37a17cfc-322c-47ab-88c5-e810f23f4739", "generation_id": "049f5b26-326a-4081-82c1-e5c42a63d19e", @@ -146,16 +334,15 @@ def test_cursor_pre_tool_use_shell(self): "workspace_roots": ["/home/user1/foo"], "transcript_path": "/home/user1/.cursor/projects/foo/agent-transcripts/37a17cfc/37a17cfc.jsonl", } - payload = scanner._parse_input(json.dumps(data))[0] + payload = parse_hook_input(json.dumps(data))[0] assert payload.event_type == EventType.PRE_TOOL_USE assert payload.tool == Tool.BASH assert payload.content == "whoami" assert payload.identifier == "whoami" - assert isinstance(payload.flavor, Cursor) + assert isinstance(payload.agent, Cursor) def test_cursor_pre_tool_use_read(self, tmp_file: Path): """Test Cursor preToolUse with Read (file) parsing.""" - scanner = AIHookScanner(_mock_scanner([])) data = { "conversation_id": "75fed8a8-2078-4e49-80d2-776b20d441c3", "generation_id": "1501ede6-b8ac-43f4-9943-0e218610c5c6", @@ -168,17 +355,16 @@ def test_cursor_pre_tool_use_read(self, tmp_file: Path): "workspace_roots": ["/home/user1/foo"], "transcript_path": "/home/user1/.cursor/projects/foo/agent-transcripts/75fed8a8/75fed8a8.jsonl", } - payload = scanner._parse_input(json.dumps(data))[0] + payload = parse_hook_input(json.dumps(data))[0] assert payload.event_type == EventType.PRE_TOOL_USE assert payload.tool == Tool.READ assert payload.identifier == tmp_file.as_posix() assert payload.content == "" assert payload.scannable.content == "this is the content" - assert isinstance(payload.flavor, Cursor) + assert isinstance(payload.agent, Cursor) def test_cursor_post_tool_use_shell(self): """Test Cursor postToolUse with Shell (simulated cat command result).""" - scanner = AIHookScanner(_mock_scanner([])) data = { "conversation_id": "37a17cfc-322c-47ab-88c5-e810f23f4739", "generation_id": "049f5b26-326a-4081-82c1-e5c42a63d19e", @@ -193,15 +379,14 @@ def test_cursor_post_tool_use_shell(self): "workspace_roots": ["/home/user1/foo"], "transcript_path": "/home/user/.cursor/projects/foo/agent-transcripts/37a17cfc/37a17cfc.jsonl", } - payload = scanner._parse_input(json.dumps(data))[0] + payload = parse_hook_input(json.dumps(data))[0] assert payload.event_type == EventType.POST_TOOL_USE assert payload.tool == Tool.BASH assert "user1" in payload.content - assert isinstance(payload.flavor, Cursor) + assert isinstance(payload.agent, Cursor) def test_claude_user_prompt(self): """Test Claude Code UserPromptSubmit parsing.""" - scanner = AIHookScanner(_mock_scanner([])) data = { "session_id": "273ad859-3608-4799-9971-fa15ecb1a65c", "transcript_path": "/home/user1/.claude/projects/foo/273ad859-3608-4799-9971-fa15ecb1a65c.jsonl", @@ -210,15 +395,14 @@ def test_claude_user_prompt(self): "hook_event_name": "UserPromptSubmit", "prompt": "hello world", } - payload = scanner._parse_input(json.dumps(data))[0] + payload = parse_hook_input(json.dumps(data))[0] assert payload.event_type == EventType.USER_PROMPT assert payload.content == "hello world" assert payload.tool is None - assert isinstance(payload.flavor, Claude) + assert isinstance(payload.agent, Claude) def test_claude_pre_tool_use_bash(self): """Test Claude Code PreToolUse with Bash parsing.""" - scanner = AIHookScanner(_mock_scanner([])) data = { "session_id": "3b7ae0c5-0862-4e14-aa2c-12fad909c323", "transcript_path": "/home/user1/.claude/projects/foo/3b7ae0c5.jsonl", @@ -232,15 +416,14 @@ def test_claude_pre_tool_use_bash(self): }, "tool_use_id": "toolu_01BPMKeZAMCqBtn1xJRNfDJw", } - payload = scanner._parse_input(json.dumps(data))[0] + payload = parse_hook_input(json.dumps(data))[0] assert payload.event_type == EventType.PRE_TOOL_USE assert payload.tool == Tool.BASH assert "whoami" in payload.content - assert isinstance(payload.flavor, Claude) + assert isinstance(payload.agent, Claude) def test_claude_pre_tool_use_read(self, tmp_file: Path): """Test Claude Code PreToolUse with Read parsing.""" - scanner = AIHookScanner(_mock_scanner([])) # From raw_hooks_logs: Claude PreToolUse Read data = { "session_id": "3b7ae0c5-0862-4e14-aa2c-12fad909c323", @@ -252,17 +435,16 @@ def test_claude_pre_tool_use_read(self, tmp_file: Path): "tool_input": {"file_path": tmp_file.as_posix()}, "tool_use_id": "toolu_01WabtWJpzf1ZJ8GJ3JfQEmq", } - payload = scanner._parse_input(json.dumps(data))[0] + payload = parse_hook_input(json.dumps(data))[0] assert payload.event_type == EventType.PRE_TOOL_USE assert payload.tool == Tool.READ assert payload.identifier == tmp_file.as_posix() assert payload.content == "" assert payload.scannable.content == "this is the content" - assert isinstance(payload.flavor, Claude) + assert isinstance(payload.agent, Claude) def test_claude_post_tool_use_bash(self): """Test Claude Code PostToolUse with Bash (simulated cat command result).""" - scanner = AIHookScanner(_mock_scanner([])) # From raw_hooks_logs: Claude PostToolUse Bash - tool_response has stdout data = { "session_id": "3b7ae0c5-0862-4e14-aa2c-12fad909c323", @@ -284,16 +466,15 @@ def test_claude_post_tool_use_bash(self): }, "tool_use_id": "toolu_01BPMKeZAMCqBtn1xJRNfDJw", } - payload = scanner._parse_input(json.dumps(data))[0] + payload = parse_hook_input(json.dumps(data))[0] assert payload.event_type == EventType.POST_TOOL_USE assert payload.tool == Tool.BASH # Content is json.dumps(tool_response), so the stdout is inside the string assert "user1" in payload.content - assert isinstance(payload.flavor, Claude) + assert isinstance(payload.agent, Claude) def test_claude_parse_read_files_in_prompt(self): """Test parsing "@file_path" mentions from Claude Code prompt.""" - scanner = AIHookScanner(_mock_scanner([])) data = { "session_id": "273ad859-3608-4799-9971-fa15ecb1a65c", "transcript_path": "/home/user1/.claude/projects/foo/273ad859-3608-4799-9971-fa15ecb1a65c.jsonl", @@ -302,24 +483,23 @@ def test_claude_parse_read_files_in_prompt(self): "hook_event_name": "UserPromptSubmit", "prompt": "read @folder/file.txt and summarize the content.", } - payloads = scanner._parse_input(json.dumps(data)) + payloads = parse_hook_input(json.dumps(data)) assert len(payloads) == 2 payload = payloads[0] assert payload.event_type == EventType.USER_PROMPT assert payload.tool == Tool.READ assert payload.identifier == "folder/file.txt" assert payload.content == "" # empty because inexistent file - assert isinstance(payload.flavor, Claude) + assert isinstance(payload.agent, Claude) payload = payloads[1] assert payload.event_type == EventType.USER_PROMPT assert payload.content == "read @folder/file.txt and summarize the content." assert payload.tool is None - assert isinstance(payload.flavor, Claude) + assert isinstance(payload.agent, Claude) def test_copilot_user_prompt(self): """Test Copilot UserPromptSubmit parsing.""" - scanner = AIHookScanner(_mock_scanner([])) data = { "timestamp": "2026-02-26T11:28:53.112Z", "hookEventName": "UserPromptSubmit", @@ -331,15 +511,14 @@ def test_copilot_user_prompt(self): "prompt": "hello world", "cwd": "/home/user1/foo", } - payload = scanner._parse_input(json.dumps(data))[0] + payload = parse_hook_input(json.dumps(data))[0] assert payload.event_type == EventType.USER_PROMPT assert "hello world" in payload.content assert payload.tool is None - assert isinstance(payload.flavor, Copilot) + assert isinstance(payload.agent, Copilot) def test_copilot_pre_tool_use_run_in_terminal(self): """Test Copilot PreToolUse with run_in_terminal (shell) parsing.""" - scanner = AIHookScanner(_mock_scanner([])) # From raw_hooks_logs: Copilot PreToolUse run_in_terminal data = { "timestamp": "2026-02-26T11:29:05.821Z", @@ -360,15 +539,14 @@ def test_copilot_pre_tool_use_run_in_terminal(self): "tool_use_id": "call_ADJcoVxpnzPtpU6uf0h9wzLR__vscode-1772105116075", "cwd": "/home/user1/foo", } - payload = scanner._parse_input(json.dumps(data))[0] + payload = parse_hook_input(json.dumps(data))[0] assert payload.event_type == EventType.PRE_TOOL_USE assert payload.tool == Tool.BASH assert "whoami" in payload.content - assert isinstance(payload.flavor, Copilot) + assert isinstance(payload.agent, Copilot) def test_copilot_pre_tool_use_read_file(self, tmp_file: Path): """Test Copilot PreToolUse with read_file parsing.""" - scanner = AIHookScanner(_mock_scanner([])) # From raw_hooks_logs: Copilot PreToolUse read_file (nonexistent path for deterministic test) data = { "timestamp": "2026-02-26T11:53:49.593Z", @@ -387,17 +565,16 @@ def test_copilot_pre_tool_use_read_file(self, tmp_file: Path): "tool_use_id": "call_iMFuTGETQ2z23a3xYTqcHBXp__vscode-1772105116078", "cwd": "/home/user1/foo", } - payload = scanner._parse_input(json.dumps(data))[0] + payload = parse_hook_input(json.dumps(data))[0] assert payload.event_type == EventType.PRE_TOOL_USE assert payload.tool == Tool.READ assert payload.identifier == tmp_file.as_posix() assert payload.content == "" assert payload.scannable.content == "this is the content" - assert isinstance(payload.flavor, Copilot) + assert isinstance(payload.agent, Copilot) def test_copilot_post_tool_use_run_in_terminal(self): """Test Copilot PostToolUse with run_in_terminal (simulated cat result).""" - scanner = AIHookScanner(_mock_scanner([])) # From raw_hooks_logs: Copilot PostToolUse run_in_terminal - tool_response is string data = { "timestamp": "2026-02-26T11:53:47.392Z", @@ -419,23 +596,23 @@ def test_copilot_post_tool_use_run_in_terminal(self): "tool_use_id": "call_f96KUoNCGS8jENVKnlWnSz5Q__vscode-1772105116077", "cwd": "/home/user1/foo", } - payload = scanner._parse_input(json.dumps(data))[0] + payload = parse_hook_input(json.dumps(data))[0] assert payload.event_type == EventType.POST_TOOL_USE assert payload.tool == Tool.BASH assert "user1" in payload.content - assert isinstance(payload.flavor, Copilot) + assert isinstance(payload.agent, Copilot) def test_pre_tool_use_read_with_missing_file(self): """PRE_TOOL_USE with tool_name 'read' and non-existing file yields empty content.""" - scanner = AIHookScanner(_mock_scanner([])) content = json.dumps( { "hook_event_name": "pretooluse", "tool_name": "read", "tool_input": {"file_path": "/nonexistent/path"}, + "cursor_version": "1.2.3", } ) - payload = scanner._parse_input(content)[0] + payload = parse_hook_input(content)[0] assert payload.event_type == EventType.PRE_TOOL_USE assert payload.tool == Tool.READ assert payload.identifier == "/nonexistent/path" @@ -443,75 +620,37 @@ def test_pre_tool_use_read_with_missing_file(self): def test_pre_tool_use_other_tool(self): """PRE_TOOL_USE with unknown tool yields Tool.OTHER and empty content.""" - scanner = AIHookScanner(_mock_scanner([])) data = { "hook_event_name": "PreToolUse", "tool_name": "SomeUnknownTool", "tool_input": {"arg": "value"}, + "cursor_version": "1.2.3", } - payload = scanner._parse_input(json.dumps(data))[0] + payload = parse_hook_input(json.dumps(data))[0] assert payload.event_type == EventType.PRE_TOOL_USE assert payload.tool == Tool.OTHER assert payload.content == "" def test_other_event_type(self): """Unknown event type yields EventType.OTHER with empty content.""" - scanner = AIHookScanner(_mock_scanner([])) data = { "hook_event_name": "SomeOtherEvent", "prompt": "hello", + "cursor_version": "1.2.3", } - payload = scanner._parse_input(json.dumps(data))[0] + payload = parse_hook_input(json.dumps(data))[0] assert payload.event_type == EventType.OTHER assert payload.content == "" assert payload.tool is None -class TestAIHookScannerScanContent: - """Unit tests for AIHookScanner._scan_content.""" - - def test_no_secrets_returns_allow(self): - """When scanner returns no secrets, result has block=False and nbr_secrets=0.""" - hook_scanner = AIHookScanner(_mock_scanner([])) - payload = Payload( - event_type=EventType.USER_PROMPT, - tool=None, - content="safe content", - identifier="id", - flavor=Flavor(), - ) - result = hook_scanner._scan_content(payload) - assert isinstance(result, HookResult) - assert result.block is False - assert result.nbr_secrets == 0 - assert result.message == "" - - def test_with_secrets_returns_block_and_message(self): - """When scanner returns secrets, result has block=True, nbr_secrets and message set.""" - hook_scanner = AIHookScanner(_mock_scanner(["sk-xxx"])) - payload = Payload( - event_type=EventType.USER_PROMPT, - tool=None, - content="content with sk-xxx", - identifier="id", - flavor=Flavor(), - ) - result = hook_scanner._scan_content(payload) - assert isinstance(result, HookResult) - assert result.block is True - assert result.nbr_secrets == 1 - assert "dummy-detector" in result.message - assert "secret" in result.message.lower() - assert "remove the secrets from your prompt" in result.message - - class TestFlavorOutputResult: """Unit tests for Cursor, Claude, Copilot output_result with Result objects. Mocks click.echo to capture stdout/stderr and asserts both output and return code. """ - @patch("ggshield.verticals.secret.ai_hook.cursor.click.echo") + @patch("ggshield.verticals.ai.agents.cursor.click.echo") def test_cursor_output_result_user_prompt_allow(self, mock_echo: MagicMock): """Cursor USER_PROMPT with block=False: JSON to stdout, return 0.""" result = HookResult.allow(_dummy_payload(EventType.USER_PROMPT)) @@ -524,7 +663,7 @@ def test_cursor_output_result_user_prompt_allow(self, mock_echo: MagicMock): assert out["continue"] is True assert out["user_message"] == "" - @patch("ggshield.verticals.secret.ai_hook.cursor.click.echo") + @patch("ggshield.verticals.ai.agents.cursor.click.echo") def test_cursor_output_result_user_prompt_block(self, mock_echo: MagicMock): """Cursor USER_PROMPT with block=True: JSON to stdout, return 0.""" result = HookResult( @@ -542,7 +681,7 @@ def test_cursor_output_result_user_prompt_block(self, mock_echo: MagicMock): assert out["continue"] is False assert out["user_message"] == "Remove secrets from prompt" - @patch("ggshield.verticals.secret.ai_hook.cursor.click.echo") + @patch("ggshield.verticals.ai.agents.cursor.click.echo") def test_cursor_output_result_pre_tool_use_allow(self, mock_echo: MagicMock): """Cursor PRE_TOOL_USE with block=False: permission allow, return 0.""" result = HookResult.allow(_dummy_payload(EventType.PRE_TOOL_USE)) @@ -554,7 +693,7 @@ def test_cursor_output_result_pre_tool_use_allow(self, mock_echo: MagicMock): out = json.loads(args[0]) assert out["permission"] == "allow" - @patch("ggshield.verticals.secret.ai_hook.cursor.click.echo") + @patch("ggshield.verticals.ai.agents.cursor.click.echo") def test_cursor_output_result_pre_tool_use_block(self, mock_echo: MagicMock): """Cursor PRE_TOOL_USE with block=True: permission deny, return 0.""" result = HookResult( @@ -572,7 +711,7 @@ def test_cursor_output_result_pre_tool_use_block(self, mock_echo: MagicMock): assert out["permission"] == "deny" assert out["user_message"] == "Secrets detected in command" - @patch("ggshield.verticals.secret.ai_hook.cursor.click.echo") + @patch("ggshield.verticals.ai.agents.cursor.click.echo") def test_cursor_output_result_post_tool_use(self, mock_echo: MagicMock): """Cursor POST_TOOL_USE: empty JSON to stdout, return 0.""" result = HookResult( @@ -588,7 +727,7 @@ def test_cursor_output_result_post_tool_use(self, mock_echo: MagicMock): assert kwargs.get("err", False) is False # stdout (default) assert json.loads(args[0]) == {} - @patch("ggshield.verticals.secret.ai_hook.cursor.click.echo") + @patch("ggshield.verticals.ai.agents.cursor.click.echo") def test_cursor_output_result_other_block(self, mock_echo: MagicMock): """Cursor OTHER event with block: empty JSON, return 2.""" result = HookResult( @@ -601,7 +740,7 @@ def test_cursor_output_result_other_block(self, mock_echo: MagicMock): assert code == 2 mock_echo.assert_called_once_with("{}") - @patch("ggshield.verticals.secret.ai_hook.cursor.click.echo") + @patch("ggshield.verticals.ai.agents.cursor.click.echo") def test_cursor_output_result_other_allow(self, mock_echo: MagicMock): """Cursor OTHER event without block: empty JSON, return 0.""" result = HookResult.allow(_dummy_payload(EventType.OTHER)) @@ -609,7 +748,7 @@ def test_cursor_output_result_other_allow(self, mock_echo: MagicMock): assert code == 0 mock_echo.assert_called_once_with("{}") - @patch("ggshield.verticals.secret.ai_hook.claude_code.click.echo") + @patch("ggshield.verticals.ai.agents.claude_code.click.echo") def test_claude_output_result_allow(self, mock_echo: MagicMock): """Claude with block=False: JSON continue true to stdout, return 0.""" result = HookResult.allow(_dummy_payload(EventType.USER_PROMPT)) @@ -621,7 +760,7 @@ def test_claude_output_result_allow(self, mock_echo: MagicMock): out = json.loads(args[0]) assert out["continue"] is True - @patch("ggshield.verticals.secret.ai_hook.claude_code.click.echo") + @patch("ggshield.verticals.ai.agents.claude_code.click.echo") def test_claude_output_result_block(self, mock_echo: MagicMock): """Claude with block=True: JSON continue false and stopReason to stdout, return 0.""" result = HookResult( @@ -641,7 +780,7 @@ def test_claude_output_result_block(self, mock_echo: MagicMock): out["hookSpecificOutput"]["permissionDecisionReason"] == "Secrets in file" ) - @patch("ggshield.verticals.secret.ai_hook.claude_code.click.echo") + @patch("ggshield.verticals.ai.agents.claude_code.click.echo") def test_copilot_output_result_allow(self, mock_echo: MagicMock): """Copilot with block=False: same as Claude, JSON to stdout, return 0.""" result = HookResult.allow(_dummy_payload(EventType.USER_PROMPT)) @@ -654,7 +793,7 @@ def test_copilot_output_result_allow(self, mock_echo: MagicMock): assert out["continue"] is True assert "stopReason" not in out - @patch("ggshield.verticals.secret.ai_hook.claude_code.click.echo") + @patch("ggshield.verticals.ai.agents.claude_code.click.echo") def test_copilot_output_result_block(self, mock_echo: MagicMock): """Copilot with block=True: same as Claude, JSON to stdout, return 0.""" result = HookResult( @@ -672,7 +811,7 @@ def test_copilot_output_result_block(self, mock_echo: MagicMock): assert out["decision"] == "block" assert out["reason"] == "Secret in tool output" - @patch("ggshield.verticals.secret.ai_hook.claude_code.click.echo") + @patch("ggshield.verticals.ai.agents.claude_code.click.echo") def test_copilot_other_result_block(self, mock_echo: MagicMock): """Copilot with block=True, other type of event""" result = HookResult( @@ -688,201 +827,6 @@ def test_copilot_other_result_block(self, mock_echo: MagicMock): assert not out["continue"] -class TestBaseFlavor: - """Unit tests for the base Flavor class.""" - - @patch("ggshield.verticals.secret.ai_hook.models.click.echo") - def test_base_flavor_output_result_allow(self, mock_echo: MagicMock): - """Base Flavor with block=False: prints allow message, returns 0.""" - result = HookResult.allow(_dummy_payload(EventType.USER_PROMPT)) - code = Flavor().output_result(result) - assert code == 0 - mock_echo.assert_called_once_with("No secrets found. Good to go.") - - @patch("ggshield.verticals.secret.ai_hook.models.click.echo") - def test_base_flavor_output_result_block(self, mock_echo: MagicMock): - """Base Flavor with block=True: prints message to stderr, returns 2.""" - result = HookResult( - block=True, - message="Secrets found", - nbr_secrets=1, - payload=_dummy_payload(EventType.PRE_TOOL_USE), - ) - code = Flavor().output_result(result) - assert code == 2 - mock_echo.assert_called_once_with("Secrets found", err=True) - - def test_base_flavor_settings_path(self): - """Base Flavor settings_path returns default path.""" - assert Flavor().settings_path == Path(".agents") / "hooks.json" - - def test_base_flavor_settings_template(self): - """Base Flavor settings_template returns empty dict.""" - assert Flavor().settings_template == {} - - def test_base_flavor_settings_locate(self): - """Base Flavor settings_locate always returns None.""" - assert Flavor().settings_locate([{"a": 1}], {"a": 1}) is None - - -class TestAIHookScannerScan: - """Unit tests for the AIHookScanner.scan() method.""" - - def test_scan_no_secrets_returns_zero(self): - """scan() with no secrets returns 0.""" - scanner = AIHookScanner(_mock_scanner([])) - data = { - "hook_event_name": "UserPromptSubmit", - "prompt": "hello world", - "transcript_path": "/home/user/.claude/projects/foo/session.jsonl", - } - code = scanner.scan(json.dumps(data)) - assert code == 0 - - @patch( - "ggshield.verticals.secret.ai_hook.scanner.AIHookScanner._send_secret_notification" - ) - def test_scan_post_tool_use_with_secrets_sends_notification( - self, mock_notify: MagicMock - ): - """scan() on POST_TOOL_USE with secrets sends a notification and returns 0 (no block).""" - scanner = AIHookScanner(_mock_scanner(["sk-xxx"])) - data = { - "hook_event_name": "PostToolUse", - "tool_name": "Bash", - "tool_input": {"command": "echo sk-xxx"}, - "tool_response": {"stdout": "sk-xxx\n"}, - "transcript_path": "/home/user/.claude/projects/foo/session.jsonl", - "session_id": "427ae0c5-0862-4e14-aa2c-12fad909c323", - } - code = scanner.scan(json.dumps(data)) - assert code == 0 - mock_notify.assert_called_once() - args = mock_notify.call_args[0] - assert args[0] == 1 # nbr_secrets - assert args[1] == Tool.BASH # tool - - def test_scan_pre_tool_use_with_secrets_blocks(self): - """scan() on PRE_TOOL_USE with secrets returns block result.""" - scanner = AIHookScanner(_mock_scanner(["sk-xxx"])) - data = { - "hook_event_name": "PreToolUse", - "tool_name": "Bash", - "tool_input": {"command": "echo sk-xxx"}, - "session_id": "427ae0c5-0862-4e14-aa2c-12fad909c323", - "transcript_path": "/home/user/.claude/projects/foo/session.jsonl", - } - code = scanner.scan(json.dumps(data)) - # Claude output_result always returns 0 - assert code == 0 - - def test_scan_no_content_returns_allow(self): - """scan() with no content returns 0 (and doesn't call the API).""" - mock_scanner = _mock_scanner([]) - scanner = AIHookScanner(mock_scanner) - data = { - "hook_event_name": "PreToolUse", - "tool_name": "Read", - "tool_input": {"file_path": "doesn-t-exist"}, - } - code = scanner.scan(json.dumps(data)) - assert code == 0 - mock_scanner.scan.assert_not_called() - - def test_scan_payloads_refuse_empty_list(self): - """scan() with empty list of payloads raises ValueError.""" - scanner = AIHookScanner(_mock_scanner([])) - with pytest.raises(ValueError): - scanner._scan_payloads([]) - - -class TestMessageFromSecrets: - """Unit tests for AIHookScanner._message_from_secrets with different payload types.""" - - def test_message_for_bash_tool(self): - """Message for BASH tool mentions environment variables.""" - payload = Payload( - event_type=EventType.PRE_TOOL_USE, - tool=Tool.BASH, - content="echo sk-xxx", - identifier="echo sk-xxx", - flavor=Flavor(), - ) - message = AIHookScanner._message_from_secrets([_make_secret("sk-xxx")], payload) - assert "remove the secrets from the command" in message - assert "environment variables" in message - - def test_message_for_read_tool(self): - """Message for READ tool mentions file content.""" - payload = Payload( - event_type=EventType.PRE_TOOL_USE, - tool=Tool.READ, - content="file content with secret", - identifier="/path/to/file", - flavor=Flavor(), - ) - message = AIHookScanner._message_from_secrets([_make_secret("sk-xxx")], payload) - assert "remove the secrets from" in message - - def test_message_for_other_tool(self): - """Message for OTHER tool uses generic message.""" - payload = Payload( - event_type=EventType.PRE_TOOL_USE, - tool=Tool.OTHER, - content="some content", - identifier="id", - flavor=Flavor(), - ) - message = AIHookScanner._message_from_secrets([_make_secret("sk-xxx")], payload) - assert "remove the secrets from the tool input" in message - - def test_message_escapes_markdown(self): - """When escape_markdown=True, asterisks in matches are replaced with dots.""" - payload = Payload( - event_type=EventType.USER_PROMPT, - tool=None, - content="content", - identifier="id", - flavor=Flavor(), - ) - message = AIHookScanner._message_from_secrets( - [_make_secret("sk-xxx")], payload, escape_markdown=True - ) - # The message itself should not contain raw asterisks from matches - # (the header uses ** for bold which is intentional) - assert "Detected" in message - - -class TestSendSecretNotification: - """Unit tests for AIHookScanner._send_secret_notification.""" - - @patch("ggshield.verticals.secret.ai_hook.scanner.Notify") - def test_notification_for_bash_tool(self, mock_notify_cls: MagicMock): - """Notification for BASH tool says 'running a command'.""" - AIHookScanner._send_secret_notification(1, Tool.BASH, "Claude Code") - instance = mock_notify_cls.return_value - assert "running a command" in instance.message - assert "Claude Code" in instance.message - instance.send.assert_called_once() - - @patch("ggshield.verticals.secret.ai_hook.scanner.Notify") - def test_notification_for_read_tool(self, mock_notify_cls: MagicMock): - """Notification for READ tool says 'reading a file'.""" - AIHookScanner._send_secret_notification(2, Tool.READ, "Cursor") - instance = mock_notify_cls.return_value - assert "reading a file" in instance.message - assert "2" in instance.message - instance.send.assert_called_once() - - @patch("ggshield.verticals.secret.ai_hook.scanner.Notify") - def test_notification_for_other_tool(self, mock_notify_cls: MagicMock): - """Notification for OTHER tool says 'using a tool'.""" - AIHookScanner._send_secret_notification(1, Tool.OTHER, "Copilot") - instance = mock_notify_cls.return_value - assert "using a tool" in instance.message - instance.send.assert_called_once() - - @pytest.mark.parametrize( "prompt, filepaths", [ diff --git a/tests/unit/verticals/secret/ai_hook/test_installation.py b/tests/unit/verticals/ai/test_installation.py similarity index 90% rename from tests/unit/verticals/secret/ai_hook/test_installation.py rename to tests/unit/verticals/ai/test_installation.py index 5a98ef2f7b..20e2dea79d 100644 --- a/tests/unit/verticals/secret/ai_hook/test_installation.py +++ b/tests/unit/verticals/ai/test_installation.py @@ -6,10 +6,8 @@ import pytest from ggshield.core.errors import UnexpectedError -from ggshield.verticals.secret.ai_hook.claude_code import Claude -from ggshield.verticals.secret.ai_hook.copilot import Copilot -from ggshield.verticals.secret.ai_hook.cursor import Cursor -from ggshield.verticals.secret.ai_hook.installation import ( +from ggshield.verticals.ai.agents import Claude, Copilot, Cursor +from ggshield.verticals.ai.installation import ( InstallationStats, _fill_dict, install_hooks, @@ -287,16 +285,14 @@ def test_copilot_settings_path(self): class TestInstallHooks: """Unit tests for the install_hooks function.""" - @patch("ggshield.verticals.secret.ai_hook.installation.get_user_home_dir") + @patch("ggshield.verticals.ai.installation.get_user_home_dir") def test_install_cursor_local_fresh(self, mock_home: Any, tmp_path: Path): """Install Cursor hooks locally into a fresh directory (no existing config).""" mock_home.return_value = tmp_path settings_path = tmp_path / ".cursor" / "hooks.json" assert not settings_path.exists() - with patch( - "ggshield.verticals.secret.ai_hook.installation.Path" - ) as mock_path_cls: + with patch("ggshield.verticals.ai.installation.Path") as mock_path_cls: # Make Path(".") return tmp_path so local mode writes there mock_path_cls.side_effect = lambda *a: Path(*a) if a != (".",) else tmp_path code = install_hooks("cursor", mode="local") @@ -308,7 +304,7 @@ def test_install_cursor_local_fresh(self, mock_home: Any, tmp_path: Path): for key in ("beforeSubmitPrompt", "preToolUse", "postToolUse"): assert any("ggshield" in h["command"] for h in config["hooks"][key]) - @patch("ggshield.verticals.secret.ai_hook.installation.get_user_home_dir") + @patch("ggshield.verticals.ai.installation.get_user_home_dir") def test_install_claude_global(self, mock_home: Any, tmp_path: Path): """Install Claude Code hooks globally.""" mock_home.return_value = tmp_path @@ -322,7 +318,7 @@ def test_install_claude_global(self, mock_home: Any, tmp_path: Path): for key in ("PreToolUse", "PostToolUse", "UserPromptSubmit"): assert key in config["hooks"] - @patch("ggshield.verticals.secret.ai_hook.installation.get_user_home_dir") + @patch("ggshield.verticals.ai.installation.get_user_home_dir") def test_install_copilot_global(self, mock_home: Any, tmp_path: Path): """Install Copilot hooks globally.""" mock_home.return_value = tmp_path @@ -332,12 +328,12 @@ def test_install_copilot_global(self, mock_home: Any, tmp_path: Path): settings_path = tmp_path / ".github" / "hooks" / "hooks.json" assert settings_path.exists() - def test_install_unsupported_tool_raises(self): - """install_hooks raises ValueError for unsupported tool name.""" - with pytest.raises(ValueError, match="Unsupported tool name"): - install_hooks("unknown-tool", mode="local") + def test_install_unsupported_agent_raises(self): + """install_hooks raises ValueError for unsupported agent.""" + with pytest.raises(ValueError, match="Unsupported agent"): + install_hooks("unknown-agent", mode="local") - @patch("ggshield.verticals.secret.ai_hook.installation.get_user_home_dir") + @patch("ggshield.verticals.ai.installation.get_user_home_dir") def test_install_with_existing_config(self, mock_home: Any, tmp_path: Path): """Install hooks when a config file already exists (merges).""" mock_home.return_value = tmp_path @@ -345,9 +341,7 @@ def test_install_with_existing_config(self, mock_home: Any, tmp_path: Path): settings_path.parent.mkdir(parents=True) settings_path.write_text(json.dumps({"version": 1, "other_key": "keep_me"})) - with patch( - "ggshield.verticals.secret.ai_hook.installation.Path" - ) as mock_path_cls: + with patch("ggshield.verticals.ai.installation.Path") as mock_path_cls: mock_path_cls.side_effect = lambda *a: Path(*a) if a != (".",) else tmp_path code = install_hooks("cursor", mode="local") @@ -356,7 +350,7 @@ def test_install_with_existing_config(self, mock_home: Any, tmp_path: Path): assert config["other_key"] == "keep_me" assert "hooks" in config - @patch("ggshield.verticals.secret.ai_hook.installation.get_user_home_dir") + @patch("ggshield.verticals.ai.installation.get_user_home_dir") def test_install_with_corrupt_json_raises(self, mock_home: Any, tmp_path: Path): """install_hooks raises UnexpectedError when existing config is invalid JSON.""" mock_home.return_value = tmp_path @@ -364,21 +358,17 @@ def test_install_with_corrupt_json_raises(self, mock_home: Any, tmp_path: Path): settings_path.parent.mkdir(parents=True) settings_path.write_text("{ invalid json") - with patch( - "ggshield.verticals.secret.ai_hook.installation.Path" - ) as mock_path_cls: + with patch("ggshield.verticals.ai.installation.Path") as mock_path_cls: mock_path_cls.side_effect = lambda *a: Path(*a) if a != (".",) else tmp_path with pytest.raises(UnexpectedError, match="Failed to parse"): install_hooks("cursor", mode="local") - @patch("ggshield.verticals.secret.ai_hook.installation.get_user_home_dir") + @patch("ggshield.verticals.ai.installation.get_user_home_dir") def test_install_already_present(self, mock_home: Any, tmp_path: Path): """install_hooks when hooks are already installed reports 'already installed'.""" mock_home.return_value = tmp_path - with patch( - "ggshield.verticals.secret.ai_hook.installation.Path" - ) as mock_path_cls: + with patch("ggshield.verticals.ai.installation.Path") as mock_path_cls: mock_path_cls.side_effect = lambda *a: Path(*a) if a != (".",) else tmp_path # Install once install_hooks("cursor", mode="local") @@ -387,14 +377,12 @@ def test_install_already_present(self, mock_home: Any, tmp_path: Path): assert code == 0 - @patch("ggshield.verticals.secret.ai_hook.installation.get_user_home_dir") + @patch("ggshield.verticals.ai.installation.get_user_home_dir") def test_install_force_updates(self, mock_home: Any, tmp_path: Path): """install_hooks with force=True updates existing hooks.""" mock_home.return_value = tmp_path - with patch( - "ggshield.verticals.secret.ai_hook.installation.Path" - ) as mock_path_cls: + with patch("ggshield.verticals.ai.installation.Path") as mock_path_cls: mock_path_cls.side_effect = lambda *a: Path(*a) if a != (".",) else tmp_path install_hooks("cursor", mode="local") code = install_hooks("cursor", mode="local", force=True)