From 7474d2b708e4fa09407af27ccf4ec41e8aeba928 Mon Sep 17 00:00:00 2001 From: Paul PETIT Date: Thu, 26 Mar 2026 10:58:47 +0100 Subject: [PATCH 1/2] feat(ai_hook): refactor into new "ai" vertical --- .importlinter | 1 - ggshield/cmd/install.py | 6 +- ggshield/cmd/secret/scan/ai_hook.py | 4 +- ggshield/verticals/ai/agents/__init__.py | 9 + .../ai_hook => ai/agents}/claude_code.py | 14 +- .../{secret/ai_hook => ai/agents}/copilot.py | 12 +- .../{secret/ai_hook => ai/agents}/cursor.py | 14 +- ggshield/verticals/ai/hooks.py | 165 +++++++ .../{secret/ai_hook => ai}/installation.py | 27 +- .../{secret/ai_hook => ai}/models.py | 93 ++-- .../secret/{ai_hook/scanner.py => ai_hook.py} | 177 +------- ggshield/verticals/secret/ai_hook/__init__.py | 5 - scripts/generate-import-linter-config.py | 2 - .../{secret/ai_hook => ai}/test_hooks.py | 422 +++--------------- .../ai_hook => ai}/test_installation.py | 48 +- tests/unit/verticals/secret/test_ai_hooks.py | 263 +++++++++++ 16 files changed, 613 insertions(+), 649 deletions(-) create mode 100644 ggshield/verticals/ai/agents/__init__.py rename ggshield/verticals/{secret/ai_hook => ai/agents}/claude_code.py (92%) rename ggshield/verticals/{secret/ai_hook => ai/agents}/copilot.py (81%) rename ggshield/verticals/{secret/ai_hook => ai/agents}/cursor.py (87%) create mode 100644 ggshield/verticals/ai/hooks.py rename ggshield/verticals/{secret/ai_hook => ai}/installation.py (87%) rename ggshield/verticals/{secret/ai_hook => ai}/models.py (80%) rename ggshield/verticals/secret/{ai_hook/scanner.py => ai_hook.py} (50%) delete mode 100644 ggshield/verticals/secret/ai_hook/__init__.py rename tests/unit/verticals/{secret/ai_hook => ai}/test_hooks.py (60%) rename tests/unit/verticals/{secret/ai_hook => ai}/test_installation.py (90%) create mode 100644 tests/unit/verticals/secret/test_ai_hooks.py diff --git a/.importlinter b/.importlinter index a56c7d3c7e..866d4d06db 100644 --- a/.importlinter +++ b/.importlinter @@ -46,7 +46,6 @@ ignore_imports = ggshield.cmd.hmsl.** -> ggshield.verticals.hmsl.** ggshield.cmd.honeytoken.** -> ggshield.verticals.honeytoken ggshield.cmd.honeytoken.** -> ggshield.verticals.honeytoken.** - ggshield.cmd.install -> ggshield.verticals.secret.ai_hook ggshield.cmd.install.** -> ggshield.verticals.install ggshield.cmd.install.** -> ggshield.verticals.install.** ggshield.cmd.plugin.** -> ggshield.core.plugin diff --git a/ggshield/cmd/install.py b/ggshield/cmd/install.py index f835926d7f..9afee1469c 100644 --- a/ggshield/cmd/install.py +++ b/ggshield/cmd/install.py @@ -10,7 +10,7 @@ from ggshield.core.dirs import get_data_dir from ggshield.core.errors import UnexpectedError from ggshield.utils.git_shell import check_git_dir, git -from ggshield.verticals.secret.ai_hook import AI_FLAVORS, install_hooks +from ggshield.verticals.ai.installation import AGENTS, install_hooks # This snippet is used by the global hook to call the hook defined in the @@ -39,7 +39,7 @@ @click.option( "--hook-type", "-t", - type=click.Choice(["pre-commit", "pre-push"] + list(AI_FLAVORS.keys())), + type=click.Choice(["pre-commit", "pre-push"] + list(AGENTS.keys())), help="Type of hook to install.", default="pre-commit", ) @@ -61,7 +61,7 @@ def install_cmd( It can also install ggshield as a Cursor IDE or Claude Code agent hook. """ - if hook_type in AI_FLAVORS: + if hook_type in AGENTS: return install_hooks(name=hook_type, mode=mode, force=force) return_code = ( diff --git a/ggshield/cmd/secret/scan/ai_hook.py b/ggshield/cmd/secret/scan/ai_hook.py index c53ab6a716..64e835c757 100644 --- a/ggshield/cmd/secret/scan/ai_hook.py +++ b/ggshield/cmd/secret/scan/ai_hook.py @@ -12,7 +12,9 @@ from ggshield.core.scan import ScanContext, ScanMode from ggshield.verticals.secret import SecretScanner from ggshield.verticals.secret.ai_hook import AIHookScanner -from ggshield.verticals.secret.ai_hook.models import MAX_READ_SIZE + + +MAX_READ_SIZE = 1024 * 1024 * 10 # We restrict stdin read to 10MB @click.command() diff --git a/ggshield/verticals/ai/agents/__init__.py b/ggshield/verticals/ai/agents/__init__.py new file mode 100644 index 0000000000..7289463cc7 --- /dev/null +++ b/ggshield/verticals/ai/agents/__init__.py @@ -0,0 +1,9 @@ +from .claude_code import Claude +from .copilot import Copilot +from .cursor import Cursor + + +AGENTS = {agent.name: agent for agent in [Cursor(), Claude(), Copilot()]} + + +__all__ = ["AGENTS", "Claude", "Copilot", "Cursor"] diff --git a/ggshield/verticals/secret/ai_hook/claude_code.py b/ggshield/verticals/ai/agents/claude_code.py similarity index 92% rename from ggshield/verticals/secret/ai_hook/claude_code.py rename to ggshield/verticals/ai/agents/claude_code.py index 378490f904..b9ea7670f1 100644 --- a/ggshield/verticals/secret/ai_hook/claude_code.py +++ b/ggshield/verticals/ai/agents/claude_code.py @@ -4,15 +4,21 @@ import click -from .models import EventType, Flavor, Result +from ..models import Agent, EventType, HookResult -class Claude(Flavor): +class Claude(Agent): """Behavior specific to Claude Code.""" - name = "Claude Code" + @property + def name(self) -> str: + return "claude-code" + + @property + def display_name(self) -> str: + return "Claude Code" - def output_result(self, result: Result) -> int: + def output_result(self, result: HookResult) -> int: response = {} if result.block: if result.payload.event_type in [ diff --git a/ggshield/verticals/secret/ai_hook/copilot.py b/ggshield/verticals/ai/agents/copilot.py similarity index 81% rename from ggshield/verticals/secret/ai_hook/copilot.py rename to ggshield/verticals/ai/agents/copilot.py index b523daa70b..300a8ce132 100644 --- a/ggshield/verticals/secret/ai_hook/copilot.py +++ b/ggshield/verticals/ai/agents/copilot.py @@ -3,8 +3,8 @@ import click +from ..models import EventType, HookResult from .claude_code import Claude -from .models import EventType, Result class Copilot(Claude): @@ -13,9 +13,15 @@ class Copilot(Claude): Inherits most of its behavior from Claude Code. """ - name = "Copilot" + @property + def name(self) -> str: + return "copilot" + + @property + def display_name(self) -> str: + return "Copilot Chat" - def output_result(self, result: Result) -> int: + def output_result(self, result: HookResult) -> int: response = {} if result.block: if result.payload.event_type == EventType.PRE_TOOL_USE: diff --git a/ggshield/verticals/secret/ai_hook/cursor.py b/ggshield/verticals/ai/agents/cursor.py similarity index 87% rename from ggshield/verticals/secret/ai_hook/cursor.py rename to ggshield/verticals/ai/agents/cursor.py index 3b20c46fec..c75e807380 100644 --- a/ggshield/verticals/secret/ai_hook/cursor.py +++ b/ggshield/verticals/ai/agents/cursor.py @@ -4,15 +4,21 @@ import click -from .models import EventType, Flavor, Result +from ..models import Agent, EventType, HookResult -class Cursor(Flavor): +class Cursor(Agent): """Behavior specific to Cursor.""" - name = "Cursor" + @property + def name(self) -> str: + return "cursor" + + @property + def display_name(self) -> str: + return "Cursor" - def output_result(self, result: Result) -> int: + def output_result(self, result: HookResult) -> int: response = {} if result.payload.event_type == EventType.USER_PROMPT: response["continue"] = not result.block diff --git a/ggshield/verticals/ai/hooks.py b/ggshield/verticals/ai/hooks.py new file mode 100644 index 0000000000..75adabcd36 --- /dev/null +++ b/ggshield/verticals/ai/hooks.py @@ -0,0 +1,165 @@ +import hashlib +import json +import re +from typing import Any, Dict, List, Sequence, Set + +from ggshield.verticals.ai.agents import Claude, Copilot, Cursor + +from .models import Agent, EventType, HookPayload, Tool + + +HOOK_NAME_TO_EVENT_TYPE = { + "userpromptsubmit": EventType.USER_PROMPT, + "beforesubmitprompt": EventType.USER_PROMPT, + "pretooluse": EventType.PRE_TOOL_USE, + "posttooluse": EventType.POST_TOOL_USE, +} + +TOOL_NAME_TO_TOOL = { + "shell": Tool.BASH, # Cursor + "bash": Tool.BASH, # Claude Code + "run_in_terminal": Tool.BASH, # Copilot + "read": Tool.READ, # Claude/Cursor + "read_file": Tool.READ, # Copilot +} + + +def lookup(data: Dict[str, Any], keys: Sequence[str], default: Any = None) -> Any: + """Returns the value of the first key found in a dictionary.""" + for key in keys: + if key in data: + return data[key] + return default + + +# Regex (and method) to look for any @file_path in the prompt. +# A list of test cases can be found in test_hooks.py. +_FILE_PATH_REGEX = re.compile( + r'@"((?:[^"\\]|\\.)*)"' # quoted: @"..." + r"|" + r"(?:\W|^)@([\w/\\.-]+)", # unquoted: @path + re.MULTILINE, +) + + +def find_filepaths(prompt: str) -> Set[str]: + """Find all file paths in the prompt.""" + paths = set() + for m in _FILE_PATH_REGEX.finditer(prompt): + path = m.group(1) or m.group(2) or "" + path = path.strip() + # Don't include trailing dots in the path + if path.endswith("."): + path = path[:-1] + if path: + paths.add(path) + return paths + + +def parse_hook_input(raw_content: str) -> list[HookPayload]: + """Parse the input content. Raises a ValueError if the input is not valid. + + Returns: + A list of payloads. Most of the time the list will contain only one payload, + but in some cases files mentioned in the prompt will be read but the + PreToolUse event will not be called. So we need to handle this case ourselves. + """ + # Parse the content as JSON + if not raw_content.strip(): + raise ValueError("Error: No input received on stdin") + try: + data = json.loads(raw_content) + except json.JSONDecodeError as e: + raise ValueError(f"Error: Failed to parse JSON from stdin: {e}") from e + + payloads = [] + + # Try to guess which AI coding assistant is calling us + agent = _detect_agent(data) + + # Infer the event type + event_name = lookup(data, ["hook_event_name", "hookEventName"], None) + if event_name is None: + raise ValueError("Error: couldn't find event type") + event_type = HOOK_NAME_TO_EVENT_TYPE.get(event_name.lower(), EventType.OTHER) + + identifier = "" + content = "" + tool = None + + # Extract the identifier and content based on the event type + if event_type == EventType.USER_PROMPT: + content = data.get("prompt", "") + # Look for files mentioned in the prompt that could be read + # without triggering a PRE_TOOL_USE event. + payloads.extend(_parse_user_prompt(content, event_type, agent)) + + elif event_type == EventType.PRE_TOOL_USE: + tool_name = data.get("tool_name", "").lower() + tool = TOOL_NAME_TO_TOOL.get(tool_name, Tool.OTHER) + tool_input = data.get("tool_input", {}) + # Select the content based on the tool + if tool == Tool.BASH: + content = tool_input.get("command", "") + identifier = content + elif tool == Tool.READ: + # We only need to deal with the identifier, the content will be read by the Scannable + identifier = lookup(tool_input, ["file_path", "filePath"], "") + + elif event_type == EventType.POST_TOOL_USE: + tool_name = data.get("tool_name", "").lower() + tool = TOOL_NAME_TO_TOOL.get(tool_name, Tool.OTHER) + content = data.get("tool_output", "") or data.get("tool_response", {}) + # Claude Code returns a dict for the tool output + if isinstance(content, (dict, list)): + content = json.dumps(content) + + # If identifier was not set, hash the content + if not identifier: + identifier = hashlib.sha256((content or "").encode()).hexdigest() + + payloads.append( + HookPayload( + event_type=event_type, + tool=tool, + content=content, + identifier=identifier, + agent=agent, + ) + ) + return payloads + + +def _detect_agent(data: Dict[str, Any]) -> Agent: + """Detect the AI code assistant.""" + if "cursor_version" in data: + return Cursor() + elif "github.copilot-chat" in data.get("transcript_path", "").lower(): + return Copilot() + # no .lower() here to reduce the risk of false positives (this is also why this check is last) + elif "session_id" in data and "claude" in data.get("transcript_path", ""): + return Claude() + # No other agent is supported yet + raise ValueError("Unsupported agent") + + +def _parse_user_prompt( + content: str, event_type: EventType, agent: Agent +) -> List[HookPayload]: + """Parse the user prompt for additional payloads that we may miss.""" + payloads = [] + # Scenario 1 (the only one we know about so far): + # Code assistants don't always trigger a PRE_TOOL_USE event when + # a file is mentioned in the prompt, especially with an "@" prefix. + matches = find_filepaths(content) + for match in matches: + payloads.append( + HookPayload( + event_type=event_type, + tool=Tool.READ, + content="", + identifier=match, + agent=agent, + ) + ) + return payloads diff --git a/ggshield/verticals/secret/ai_hook/installation.py b/ggshield/verticals/ai/installation.py similarity index 87% rename from ggshield/verticals/secret/ai_hook/installation.py rename to ggshield/verticals/ai/installation.py index 3139c7b3d5..f624c1fc15 100644 --- a/ggshield/verticals/secret/ai_hook/installation.py +++ b/ggshield/verticals/ai/installation.py @@ -9,16 +9,7 @@ from ggshield.core.dirs import get_user_home_dir from ggshield.core.errors import UnexpectedError -from .claude_code import Claude -from .copilot import Copilot -from .cursor import Cursor - - -AI_FLAVORS = { - "cursor": Cursor, - "claude-code": Claude, - "copilot": Copilot, -} +from .agents import AGENTS @dataclass @@ -41,12 +32,12 @@ def install_hooks( """ try: - flavor = AI_FLAVORS[name]() + agent = AGENTS[name] except KeyError: - raise ValueError(f"Unsupported tool name: {name}") + raise ValueError(f"Unsupported agent: {name}") base_dir = get_user_home_dir() if mode == "global" else Path(".") - settings_path = base_dir / flavor.settings_path + settings_path = base_dir / agent.settings_path command = "ggshield secret scan ai-hook" @@ -71,11 +62,11 @@ def install_hooks( stats = _fill_dict( config=existing_config, - template=flavor.settings_template, + template=agent.settings_template, command=command, overwrite=force, stats=stats, - locator=flavor.settings_locate, + locator=agent.settings_locate, ) # Ensure parent directory exists @@ -89,11 +80,11 @@ def install_hooks( # Report what happened styled_path = click.style(settings_path, fg="yellow", bold=True) if stats.added == 0 and stats.already_present > 0: - click.echo(f"{flavor.name} hooks already installed in {styled_path}") + click.echo(f"{agent.name} hooks already installed in {styled_path}") elif stats.added > 0 and stats.already_present > 0: - click.echo(f"{flavor.name} hooks updated in {styled_path}") + click.echo(f"{agent.name} hooks updated in {styled_path}") else: - click.echo(f"{flavor.name} hooks successfully added in {styled_path}") + click.echo(f"{agent.name} hooks successfully added in {styled_path}") return 0 diff --git a/ggshield/verticals/secret/ai_hook/models.py b/ggshield/verticals/ai/models.py similarity index 80% rename from ggshield/verticals/secret/ai_hook/models.py rename to ggshield/verticals/ai/models.py index a031e2eea3..fecf9ddd27 100644 --- a/ggshield/verticals/secret/ai_hook/models.py +++ b/ggshield/verticals/ai/models.py @@ -1,17 +1,13 @@ +from abc import ABC, abstractmethod from dataclasses import dataclass from enum import Enum, auto from pathlib import Path from typing import Any, Dict, List, Optional -import click - from ggshield.core.scan import File, Scannable, StringScannable from ggshield.utils.files import is_path_binary -MAX_READ_SIZE = 1024 * 1024 * 50 # We restrict payloads read to 50MB - - class EventType(Enum): """Event type constants for hook events.""" @@ -33,31 +29,64 @@ class Tool(Enum): @dataclass -class Result: +class HookResult: """Result of a scan: allow or not.""" block: bool message: str nbr_secrets: int - payload: "Payload" + payload: "HookPayload" @classmethod - def allow(cls, payload: "Payload") -> "Result": + def allow(cls, payload: "HookPayload") -> "HookResult": return cls(block=False, message="", nbr_secrets=0, payload=payload) -class Flavor: +@dataclass +class HookPayload: + event_type: EventType + tool: Optional[Tool] + content: str + identifier: str + agent: "Agent" + + @property + def scannable(self) -> Scannable: + """Return the appropriate Scannable for the payload.""" + if self.tool == Tool.READ: + path = Path(self.identifier) + if path.is_file() and not is_path_binary(path): + return File(path=self.identifier) + return StringScannable(url=self.identifier, content=self.content) + + @property + def empty(self) -> bool: + """Return True if the payload is empty.""" + return not self.scannable.is_longer_than(0) + + +class Agent(ABC): """ Class that can be derived to implement behavior specific to some AI code assistants. """ - name = "Your AI coding tool" + # Metadata - def output_result(self, result: Result) -> int: - """How to output the result of a scan. + @property + @abstractmethod + def display_name(self) -> str: + """A user-friendly name for the agent.""" - This base implementation has sensible defaults (like returning 2 in case of a block, - and printing the output in stderr or stdout). + @property + @abstractmethod + def name(self) -> str: + """The name of the agent.""" + + # Hooks + + @abstractmethod + def output_result(self, result: HookResult) -> int: + """How to output the result of a scan. This method is expected to have side effects, like printing to stdout or stderr. @@ -66,26 +95,23 @@ def output_result(self, result: Result) -> int: Returns: the exit code. """ - if result.block: - click.echo(result.message, err=True) - return 2 - else: - click.echo("No secrets found. Good to go.") - return 0 + + # Settings @property + @abstractmethod def settings_path(self) -> Path: """Path to the settings file for this AI coding tool.""" - return Path(".agents") / "hooks.json" @property + @abstractmethod def settings_template(self) -> Dict[str, Any]: """ Template for the settings file for this AI coding tool. Use the sentinel "" for the places where the command should be inserted. """ - return {} + @abstractmethod def settings_locate( self, candidates: List[Dict[str, Any]], template: Dict[str, Any] ) -> Optional[Dict[str, Any]]: @@ -102,26 +128,3 @@ def settings_locate( Returns: the object to update, or None if no object was found. """ return None - - -@dataclass -class Payload: - event_type: EventType - tool: Optional[Tool] - content: str - identifier: str - flavor: Flavor - - @property - def scannable(self) -> Scannable: - """Return the appropriate Scannable for the payload.""" - if self.tool == Tool.READ: - path = Path(self.identifier) - if path.is_file() and not is_path_binary(path): - return File(path=self.identifier) - return StringScannable(url=self.identifier, content=self.content) - - @property - def empty(self) -> bool: - """Return True if the payload is empty.""" - return not self.scannable.is_longer_than(0) diff --git a/ggshield/verticals/secret/ai_hook/scanner.py b/ggshield/verticals/secret/ai_hook.py similarity index 50% rename from ggshield/verticals/secret/ai_hook/scanner.py rename to ggshield/verticals/secret/ai_hook.py index 7eacfe5abe..332262a3bf 100644 --- a/ggshield/verticals/secret/ai_hook/scanner.py +++ b/ggshield/verticals/secret/ai_hook.py @@ -1,45 +1,18 @@ -import hashlib -import json -import re -from typing import Any, Dict, List, Sequence, Set +from typing import List from notifypy import Notify from ggshield.core.filter import censor_match from ggshield.core.scanner_ui import create_message_only_scanner_ui from ggshield.core.text_utils import pluralize, translate_validity +from ggshield.verticals.ai.hooks import parse_hook_input +from ggshield.verticals.ai.models import EventType +from ggshield.verticals.ai.models import HookPayload as Payload +from ggshield.verticals.ai.models import HookResult as Result +from ggshield.verticals.ai.models import Tool from ggshield.verticals.secret import SecretScanner -from ggshield.verticals.secret.ai_hook.copilot import Copilot from ggshield.verticals.secret.secret_scan_collection import Secret -from .claude_code import Claude -from .cursor import Cursor -from .models import EventType, Flavor, Payload, Result, Tool - - -# Regex (and method) to look for any @file_path in the prompt. -# A list of test cases can be found in test_hooks.py. -_FILE_PATH_REGEX = re.compile( - r'@"((?:[^"\\]|\\.)*)"' # quoted: @"..." - r"|" - r"(?:\W|^)@([\w/\\.-]+)", # unquoted: @path - re.MULTILINE, -) - - -def find_filepaths(prompt: str) -> Set[str]: - """Find all file paths in the prompt.""" - paths = set() - for m in _FILE_PATH_REGEX.finditer(prompt): - path = m.group(1) or m.group(2) or "" - path = path.strip() - # Don't include trailing dots in the path - if path.endswith("."): - path = path[:-1] - if path: - paths.add(path) - return paths - class AIHookScanner: """AI hook scanner. @@ -61,90 +34,17 @@ def __init__(self, scanner: SecretScanner): def scan(self, content: str) -> int: """Scan the content, print the result and return the exit code.""" - payloads = self._parse_input(content) + payloads = parse_hook_input(content) result = self._scan_payloads(payloads) payload = result.payload # Special case: in post-tool use, the action is already done: at least notify the user if result.block and payload.event_type == EventType.POST_TOOL_USE: self._send_secret_notification( - result.nbr_secrets, payload.tool or Tool.OTHER, payload.flavor.name + result.nbr_secrets, payload.tool or Tool.OTHER, payload.agent.name ) - return payload.flavor.output_result(result) - - def _parse_input(self, raw_content: str) -> list[Payload]: - """Parse the input content. Raises a ValueError if the input is not valid. - - Returns: - A list of payloads. Most of the time the list will contain only one payload, - but in some cases files mentioned in the prompt will be read but the - PreToolUse event will not be called. So we need to handle this case ourselves. - """ - # Parse the content as JSON - if not raw_content.strip(): - raise ValueError("Error: No input received on stdin") - try: - data = json.loads(raw_content) - except json.JSONDecodeError as e: - raise ValueError(f"Error: Failed to parse JSON from stdin: {e}") from e - - payloads = [] - - # Try to guess which AI coding assistant is calling us - flavor = self._detect_flavor(data) - - # Infer the event type - event_name = lookup(data, ["hook_event_name", "hookEventName"], None) - if event_name is None: - raise ValueError("Error: couldn't find event type") - event_type = HOOK_NAME_TO_EVENT_TYPE.get(event_name.lower(), EventType.OTHER) - - identifier = "" - content = "" - tool = None - - # Extract the identifier and content based on the event type - if event_type == EventType.USER_PROMPT: - content = data.get("prompt", "") - # Look for files mentioned in the prompt that could be read - # without triggering a PRE_TOOL_USE event. - payloads.extend(self._parse_user_prompt(content, event_type, flavor)) - - elif event_type == EventType.PRE_TOOL_USE: - tool_name = data.get("tool_name", "").lower() - tool = TOOL_NAME_TO_TOOL.get(tool_name, Tool.OTHER) - tool_input = data.get("tool_input", {}) - # Select the content based on the tool - if tool == Tool.BASH: - content = tool_input.get("command", "") - identifier = content - elif tool == Tool.READ: - # We only need to deal with the identifier, the content will be read by the Scannable - identifier = lookup(tool_input, ["file_path", "filePath"], "") - - elif event_type == EventType.POST_TOOL_USE: - tool_name = data.get("tool_name", "").lower() - tool = TOOL_NAME_TO_TOOL.get(tool_name, Tool.OTHER) - content = data.get("tool_output", "") or data.get("tool_response", {}) - # Claude Code returns a dict for the tool output - if isinstance(content, (dict, list)): - content = json.dumps(content) - - # If identifier was not set, hash the content - if not identifier: - identifier = hashlib.sha256((content or "").encode()).hexdigest() - - payloads.append( - Payload( - event_type=event_type, - tool=tool, - content=content, - identifier=identifier, - flavor=flavor, - ) - ) - return payloads + return payload.agent.output_result(result) def _scan_payloads(self, payloads: List[Payload]) -> Result: """Scan payloads for secrets using the SecretScanner. @@ -192,41 +92,6 @@ def _scan_content( payload=payload, ) - @staticmethod - def _detect_flavor(data: Dict[str, Any]) -> Flavor: - """Detect the AI code assistant.""" - if "cursor_version" in data: - return Cursor() - elif "github.copilot-chat" in data.get("transcript_path", "").lower(): - return Copilot() - # no .lower() here to reduce the risk of false positives (this is also why this check is last) - elif "session_id" in data and "claude" in data.get("transcript_path", ""): - return Claude() - else: - # Fallback that respect base conventions - return Flavor() - - def _parse_user_prompt( - self, content: str, event_type: EventType, flavor: Flavor - ) -> List[Payload]: - """Parse the user prompt for additional payloads that we may miss.""" - payloads = [] - # Scenario 1 (the only one we know about so far): - # Code assistants don't always trigger a PRE_TOOL_USE event when - # a file is mentioned in the prompt, especially with an "@" prefix. - matches = find_filepaths(content) - for match in matches: - payloads.append( - Payload( - event_type=event_type, - tool=Tool.READ, - content="", - identifier=match, - flavor=flavor, - ) - ) - return payloads - @staticmethod def _message_from_secrets( secrets: List[Secret], payload: Payload, escape_markdown: bool = False @@ -308,27 +173,3 @@ def _send_secret_notification( # This is best effort, we don't want to propagate an error # if the notification fails. pass - - -HOOK_NAME_TO_EVENT_TYPE = { - "userpromptsubmit": EventType.USER_PROMPT, - "beforesubmitprompt": EventType.USER_PROMPT, - "pretooluse": EventType.PRE_TOOL_USE, - "posttooluse": EventType.POST_TOOL_USE, -} - -TOOL_NAME_TO_TOOL = { - "shell": Tool.BASH, # Cursor - "bash": Tool.BASH, # Claude Code - "run_in_terminal": Tool.BASH, # Copilot - "read": Tool.READ, # Claude/Cursor - "read_file": Tool.READ, # Copilot -} - - -def lookup(data: Dict[str, Any], keys: Sequence[str], default: Any = None) -> Any: - """Returns the value of the first key found in a dictionary.""" - for key in keys: - if key in data: - return data[key] - return default diff --git a/ggshield/verticals/secret/ai_hook/__init__.py b/ggshield/verticals/secret/ai_hook/__init__.py deleted file mode 100644 index 68d4170d4a..0000000000 --- a/ggshield/verticals/secret/ai_hook/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -from .installation import AI_FLAVORS, install_hooks -from .scanner import AIHookScanner - - -__all__ = ["AIHookScanner", "install_hooks", "AI_FLAVORS"] diff --git a/scripts/generate-import-linter-config.py b/scripts/generate-import-linter-config.py index bdb0823275..8bc41c5678 100755 --- a/scripts/generate-import-linter-config.py +++ b/scripts/generate-import-linter-config.py @@ -64,8 +64,6 @@ class Contract(TypedDict): "ggshield.cmd.{}.** -> ggshield.verticals.{}.**", # FIXME: #521 - enforce boundaries between cmd.auth and verticals.hmsl "ggshield.cmd.auth.** -> ggshield.verticals.hmsl.**", - # Logic to install hooks for AI assistants - "ggshield.cmd.install -> ggshield.verticals.secret.ai_hook", ], "unmatched_ignore_imports_alerting": "none", }, diff --git a/tests/unit/verticals/secret/ai_hook/test_hooks.py b/tests/unit/verticals/ai/test_hooks.py similarity index 60% rename from tests/unit/verticals/secret/ai_hook/test_hooks.py rename to tests/unit/verticals/ai/test_hooks.py index f763cf76b4..e63d6f0361 100644 --- a/tests/unit/verticals/secret/ai_hook/test_hooks.py +++ b/tests/unit/verticals/ai/test_hooks.py @@ -1,74 +1,22 @@ import json -from collections import Counter from pathlib import Path -from typing import List, Set +from typing import Set from unittest.mock import MagicMock, patch import pytest -from ggshield.utils.git_shell import Filemode -from ggshield.verticals.secret import SecretScanner -from ggshield.verticals.secret.ai_hook.claude_code import Claude -from ggshield.verticals.secret.ai_hook.copilot import Copilot -from ggshield.verticals.secret.ai_hook.cursor import Cursor -from ggshield.verticals.secret.ai_hook.models import EventType, Flavor, Payload -from ggshield.verticals.secret.ai_hook.models import Result as HookResult -from ggshield.verticals.secret.ai_hook.models import Tool -from ggshield.verticals.secret.ai_hook.scanner import AIHookScanner, find_filepaths -from ggshield.verticals.secret.secret_scan_collection import Result as ScanResult -from ggshield.verticals.secret.secret_scan_collection import Results, Secret - - -def _mock_scanner(matches: List[str]) -> MagicMock: - """Create a mock SecretScanner that returns the given Results from scan().""" - mock = MagicMock(spec=SecretScanner) - scan_result = Results( - results=[ - ScanResult( - filename="url", - filemode=Filemode.FILE, - path=Path("."), - url="url", - secrets=[_make_secret(match) for match in matches], - ignored_secrets_count_by_kind=Counter(), - ) - ], - errors=[], - ) - mock.scan.return_value = scan_result - return mock - - -def _make_secret(match_str: str = "***"): - """Minimal Secret for tests; _message_from_secrets only uses detector_display_name, validity, matches[].match.""" - mock_match = MagicMock() - mock_match.match = match_str - return Secret( - detector_display_name="dummy-detector", - detector_name="dummy-detector", - detector_group_name=None, - documentation_url=None, - validity="valid", - known_secret=False, - incident_url=None, - matches=[mock_match], - ignore_reason=None, - diff_kind=None, - is_vaulted=False, - vault_type=None, - vault_name=None, - vault_path=None, - vault_path_count=None, - ) +from ggshield.verticals.ai.agents import Claude, Copilot, Cursor +from ggshield.verticals.ai.hooks import find_filepaths, parse_hook_input +from ggshield.verticals.ai.models import EventType, HookPayload, HookResult, Tool -def _dummy_payload(event_type: EventType = EventType.OTHER) -> Payload: - return Payload( +def _dummy_payload(event_type: EventType = EventType.OTHER) -> HookPayload: + return HookPayload( event_type=event_type, tool=None, content="", identifier="", - flavor=Flavor(), + agent=Cursor(), ) @@ -81,33 +29,22 @@ def tmp_file(tmp_path: Path) -> Path: class TestAIHookScannerParseInput: - """Unit tests for AIHookScanner._parse_input.""" - - def test_empty_input_raises(self): - """Empty or whitespace-only input raises ValueError.""" - scanner = AIHookScanner(_mock_scanner([])) - with pytest.raises(ValueError, match="No input received on stdin"): - scanner.scan("") - with pytest.raises(ValueError, match="No input received on stdin"): - scanner.scan(" \n ") + """Unit tests for AIHookparse_hook_input.""" def test_invalid_json_raises(self): """Invalid JSON raises ValueError with parse error.""" - scanner = AIHookScanner(_mock_scanner([])) with pytest.raises(ValueError, match="Failed to parse JSON"): - scanner._parse_input("not json {") + parse_hook_input("not json {") with pytest.raises(ValueError, match="Failed to parse JSON"): - scanner._parse_input("{ missing brace ") + parse_hook_input("{ missing brace ") def test_missing_event_type_raises(self): """JSON without event type raises ValueError.""" - scanner = AIHookScanner(_mock_scanner([])) - with pytest.raises(ValueError, match="couldn't find event type"): - scanner._parse_input('{"prompt": "hello"}') + with pytest.raises(ValueError): + parse_hook_input('{"prompt": "hello"}') def test_cursor_user_prompt(self): """Test Cursor beforeSubmitPrompt (user prompt) parsing.""" - scanner = AIHookScanner(_mock_scanner([])) data = { "conversation_id": "75fed8a8-2078-4e49-80d2-776b20d441c3", "generation_id": "1501ede6-b8ac-43f4-9943-0e218610c5c6", @@ -120,16 +57,15 @@ def test_cursor_user_prompt(self): "user_email": "user@example.com", "transcript_path": "/home/user1/.cursor/projects/foo/agent-transcripts/75fed8a8/75fed8a8.jsonl", } - payload = scanner._parse_input(json.dumps(data))[0] + payload = parse_hook_input(json.dumps(data))[0] assert payload.event_type == EventType.USER_PROMPT assert payload.content == "hello world" assert payload.tool is None assert payload.identifier != "" - assert isinstance(payload.flavor, Cursor) + assert isinstance(payload.agent, Cursor) def test_cursor_pre_tool_use_shell(self): """Test Cursor preToolUse with Shell (bash) parsing.""" - scanner = AIHookScanner(_mock_scanner([])) data = { "conversation_id": "37a17cfc-322c-47ab-88c5-e810f23f4739", "generation_id": "049f5b26-326a-4081-82c1-e5c42a63d19e", @@ -146,16 +82,15 @@ def test_cursor_pre_tool_use_shell(self): "workspace_roots": ["/home/user1/foo"], "transcript_path": "/home/user1/.cursor/projects/foo/agent-transcripts/37a17cfc/37a17cfc.jsonl", } - payload = scanner._parse_input(json.dumps(data))[0] + payload = parse_hook_input(json.dumps(data))[0] assert payload.event_type == EventType.PRE_TOOL_USE assert payload.tool == Tool.BASH assert payload.content == "whoami" assert payload.identifier == "whoami" - assert isinstance(payload.flavor, Cursor) + assert isinstance(payload.agent, Cursor) def test_cursor_pre_tool_use_read(self, tmp_file: Path): """Test Cursor preToolUse with Read (file) parsing.""" - scanner = AIHookScanner(_mock_scanner([])) data = { "conversation_id": "75fed8a8-2078-4e49-80d2-776b20d441c3", "generation_id": "1501ede6-b8ac-43f4-9943-0e218610c5c6", @@ -168,17 +103,16 @@ def test_cursor_pre_tool_use_read(self, tmp_file: Path): "workspace_roots": ["/home/user1/foo"], "transcript_path": "/home/user1/.cursor/projects/foo/agent-transcripts/75fed8a8/75fed8a8.jsonl", } - payload = scanner._parse_input(json.dumps(data))[0] + payload = parse_hook_input(json.dumps(data))[0] assert payload.event_type == EventType.PRE_TOOL_USE assert payload.tool == Tool.READ assert payload.identifier == tmp_file.as_posix() assert payload.content == "" assert payload.scannable.content == "this is the content" - assert isinstance(payload.flavor, Cursor) + assert isinstance(payload.agent, Cursor) def test_cursor_post_tool_use_shell(self): """Test Cursor postToolUse with Shell (simulated cat command result).""" - scanner = AIHookScanner(_mock_scanner([])) data = { "conversation_id": "37a17cfc-322c-47ab-88c5-e810f23f4739", "generation_id": "049f5b26-326a-4081-82c1-e5c42a63d19e", @@ -193,15 +127,14 @@ def test_cursor_post_tool_use_shell(self): "workspace_roots": ["/home/user1/foo"], "transcript_path": "/home/user/.cursor/projects/foo/agent-transcripts/37a17cfc/37a17cfc.jsonl", } - payload = scanner._parse_input(json.dumps(data))[0] + payload = parse_hook_input(json.dumps(data))[0] assert payload.event_type == EventType.POST_TOOL_USE assert payload.tool == Tool.BASH assert "user1" in payload.content - assert isinstance(payload.flavor, Cursor) + assert isinstance(payload.agent, Cursor) def test_claude_user_prompt(self): """Test Claude Code UserPromptSubmit parsing.""" - scanner = AIHookScanner(_mock_scanner([])) data = { "session_id": "273ad859-3608-4799-9971-fa15ecb1a65c", "transcript_path": "/home/user1/.claude/projects/foo/273ad859-3608-4799-9971-fa15ecb1a65c.jsonl", @@ -210,15 +143,14 @@ def test_claude_user_prompt(self): "hook_event_name": "UserPromptSubmit", "prompt": "hello world", } - payload = scanner._parse_input(json.dumps(data))[0] + payload = parse_hook_input(json.dumps(data))[0] assert payload.event_type == EventType.USER_PROMPT assert payload.content == "hello world" assert payload.tool is None - assert isinstance(payload.flavor, Claude) + assert isinstance(payload.agent, Claude) def test_claude_pre_tool_use_bash(self): """Test Claude Code PreToolUse with Bash parsing.""" - scanner = AIHookScanner(_mock_scanner([])) data = { "session_id": "3b7ae0c5-0862-4e14-aa2c-12fad909c323", "transcript_path": "/home/user1/.claude/projects/foo/3b7ae0c5.jsonl", @@ -232,15 +164,14 @@ def test_claude_pre_tool_use_bash(self): }, "tool_use_id": "toolu_01BPMKeZAMCqBtn1xJRNfDJw", } - payload = scanner._parse_input(json.dumps(data))[0] + payload = parse_hook_input(json.dumps(data))[0] assert payload.event_type == EventType.PRE_TOOL_USE assert payload.tool == Tool.BASH assert "whoami" in payload.content - assert isinstance(payload.flavor, Claude) + assert isinstance(payload.agent, Claude) def test_claude_pre_tool_use_read(self, tmp_file: Path): """Test Claude Code PreToolUse with Read parsing.""" - scanner = AIHookScanner(_mock_scanner([])) # From raw_hooks_logs: Claude PreToolUse Read data = { "session_id": "3b7ae0c5-0862-4e14-aa2c-12fad909c323", @@ -252,17 +183,16 @@ def test_claude_pre_tool_use_read(self, tmp_file: Path): "tool_input": {"file_path": tmp_file.as_posix()}, "tool_use_id": "toolu_01WabtWJpzf1ZJ8GJ3JfQEmq", } - payload = scanner._parse_input(json.dumps(data))[0] + payload = parse_hook_input(json.dumps(data))[0] assert payload.event_type == EventType.PRE_TOOL_USE assert payload.tool == Tool.READ assert payload.identifier == tmp_file.as_posix() assert payload.content == "" assert payload.scannable.content == "this is the content" - assert isinstance(payload.flavor, Claude) + assert isinstance(payload.agent, Claude) def test_claude_post_tool_use_bash(self): """Test Claude Code PostToolUse with Bash (simulated cat command result).""" - scanner = AIHookScanner(_mock_scanner([])) # From raw_hooks_logs: Claude PostToolUse Bash - tool_response has stdout data = { "session_id": "3b7ae0c5-0862-4e14-aa2c-12fad909c323", @@ -284,16 +214,15 @@ def test_claude_post_tool_use_bash(self): }, "tool_use_id": "toolu_01BPMKeZAMCqBtn1xJRNfDJw", } - payload = scanner._parse_input(json.dumps(data))[0] + payload = parse_hook_input(json.dumps(data))[0] assert payload.event_type == EventType.POST_TOOL_USE assert payload.tool == Tool.BASH # Content is json.dumps(tool_response), so the stdout is inside the string assert "user1" in payload.content - assert isinstance(payload.flavor, Claude) + assert isinstance(payload.agent, Claude) def test_claude_parse_read_files_in_prompt(self): """Test parsing "@file_path" mentions from Claude Code prompt.""" - scanner = AIHookScanner(_mock_scanner([])) data = { "session_id": "273ad859-3608-4799-9971-fa15ecb1a65c", "transcript_path": "/home/user1/.claude/projects/foo/273ad859-3608-4799-9971-fa15ecb1a65c.jsonl", @@ -302,24 +231,23 @@ def test_claude_parse_read_files_in_prompt(self): "hook_event_name": "UserPromptSubmit", "prompt": "read @folder/file.txt and summarize the content.", } - payloads = scanner._parse_input(json.dumps(data)) + payloads = parse_hook_input(json.dumps(data)) assert len(payloads) == 2 payload = payloads[0] assert payload.event_type == EventType.USER_PROMPT assert payload.tool == Tool.READ assert payload.identifier == "folder/file.txt" assert payload.content == "" # empty because inexistent file - assert isinstance(payload.flavor, Claude) + assert isinstance(payload.agent, Claude) payload = payloads[1] assert payload.event_type == EventType.USER_PROMPT assert payload.content == "read @folder/file.txt and summarize the content." assert payload.tool is None - assert isinstance(payload.flavor, Claude) + assert isinstance(payload.agent, Claude) def test_copilot_user_prompt(self): """Test Copilot UserPromptSubmit parsing.""" - scanner = AIHookScanner(_mock_scanner([])) data = { "timestamp": "2026-02-26T11:28:53.112Z", "hookEventName": "UserPromptSubmit", @@ -331,15 +259,14 @@ def test_copilot_user_prompt(self): "prompt": "hello world", "cwd": "/home/user1/foo", } - payload = scanner._parse_input(json.dumps(data))[0] + payload = parse_hook_input(json.dumps(data))[0] assert payload.event_type == EventType.USER_PROMPT assert "hello world" in payload.content assert payload.tool is None - assert isinstance(payload.flavor, Copilot) + assert isinstance(payload.agent, Copilot) def test_copilot_pre_tool_use_run_in_terminal(self): """Test Copilot PreToolUse with run_in_terminal (shell) parsing.""" - scanner = AIHookScanner(_mock_scanner([])) # From raw_hooks_logs: Copilot PreToolUse run_in_terminal data = { "timestamp": "2026-02-26T11:29:05.821Z", @@ -360,15 +287,14 @@ def test_copilot_pre_tool_use_run_in_terminal(self): "tool_use_id": "call_ADJcoVxpnzPtpU6uf0h9wzLR__vscode-1772105116075", "cwd": "/home/user1/foo", } - payload = scanner._parse_input(json.dumps(data))[0] + payload = parse_hook_input(json.dumps(data))[0] assert payload.event_type == EventType.PRE_TOOL_USE assert payload.tool == Tool.BASH assert "whoami" in payload.content - assert isinstance(payload.flavor, Copilot) + assert isinstance(payload.agent, Copilot) def test_copilot_pre_tool_use_read_file(self, tmp_file: Path): """Test Copilot PreToolUse with read_file parsing.""" - scanner = AIHookScanner(_mock_scanner([])) # From raw_hooks_logs: Copilot PreToolUse read_file (nonexistent path for deterministic test) data = { "timestamp": "2026-02-26T11:53:49.593Z", @@ -387,17 +313,16 @@ def test_copilot_pre_tool_use_read_file(self, tmp_file: Path): "tool_use_id": "call_iMFuTGETQ2z23a3xYTqcHBXp__vscode-1772105116078", "cwd": "/home/user1/foo", } - payload = scanner._parse_input(json.dumps(data))[0] + payload = parse_hook_input(json.dumps(data))[0] assert payload.event_type == EventType.PRE_TOOL_USE assert payload.tool == Tool.READ assert payload.identifier == tmp_file.as_posix() assert payload.content == "" assert payload.scannable.content == "this is the content" - assert isinstance(payload.flavor, Copilot) + assert isinstance(payload.agent, Copilot) def test_copilot_post_tool_use_run_in_terminal(self): """Test Copilot PostToolUse with run_in_terminal (simulated cat result).""" - scanner = AIHookScanner(_mock_scanner([])) # From raw_hooks_logs: Copilot PostToolUse run_in_terminal - tool_response is string data = { "timestamp": "2026-02-26T11:53:47.392Z", @@ -419,23 +344,23 @@ def test_copilot_post_tool_use_run_in_terminal(self): "tool_use_id": "call_f96KUoNCGS8jENVKnlWnSz5Q__vscode-1772105116077", "cwd": "/home/user1/foo", } - payload = scanner._parse_input(json.dumps(data))[0] + payload = parse_hook_input(json.dumps(data))[0] assert payload.event_type == EventType.POST_TOOL_USE assert payload.tool == Tool.BASH assert "user1" in payload.content - assert isinstance(payload.flavor, Copilot) + assert isinstance(payload.agent, Copilot) def test_pre_tool_use_read_with_missing_file(self): """PRE_TOOL_USE with tool_name 'read' and non-existing file yields empty content.""" - scanner = AIHookScanner(_mock_scanner([])) content = json.dumps( { "hook_event_name": "pretooluse", "tool_name": "read", "tool_input": {"file_path": "/nonexistent/path"}, + "cursor_version": "1.2.3", } ) - payload = scanner._parse_input(content)[0] + payload = parse_hook_input(content)[0] assert payload.event_type == EventType.PRE_TOOL_USE assert payload.tool == Tool.READ assert payload.identifier == "/nonexistent/path" @@ -443,75 +368,37 @@ def test_pre_tool_use_read_with_missing_file(self): def test_pre_tool_use_other_tool(self): """PRE_TOOL_USE with unknown tool yields Tool.OTHER and empty content.""" - scanner = AIHookScanner(_mock_scanner([])) data = { "hook_event_name": "PreToolUse", "tool_name": "SomeUnknownTool", "tool_input": {"arg": "value"}, + "cursor_version": "1.2.3", } - payload = scanner._parse_input(json.dumps(data))[0] + payload = parse_hook_input(json.dumps(data))[0] assert payload.event_type == EventType.PRE_TOOL_USE assert payload.tool == Tool.OTHER assert payload.content == "" def test_other_event_type(self): """Unknown event type yields EventType.OTHER with empty content.""" - scanner = AIHookScanner(_mock_scanner([])) data = { "hook_event_name": "SomeOtherEvent", "prompt": "hello", + "cursor_version": "1.2.3", } - payload = scanner._parse_input(json.dumps(data))[0] + payload = parse_hook_input(json.dumps(data))[0] assert payload.event_type == EventType.OTHER assert payload.content == "" assert payload.tool is None -class TestAIHookScannerScanContent: - """Unit tests for AIHookScanner._scan_content.""" - - def test_no_secrets_returns_allow(self): - """When scanner returns no secrets, result has block=False and nbr_secrets=0.""" - hook_scanner = AIHookScanner(_mock_scanner([])) - payload = Payload( - event_type=EventType.USER_PROMPT, - tool=None, - content="safe content", - identifier="id", - flavor=Flavor(), - ) - result = hook_scanner._scan_content(payload) - assert isinstance(result, HookResult) - assert result.block is False - assert result.nbr_secrets == 0 - assert result.message == "" - - def test_with_secrets_returns_block_and_message(self): - """When scanner returns secrets, result has block=True, nbr_secrets and message set.""" - hook_scanner = AIHookScanner(_mock_scanner(["sk-xxx"])) - payload = Payload( - event_type=EventType.USER_PROMPT, - tool=None, - content="content with sk-xxx", - identifier="id", - flavor=Flavor(), - ) - result = hook_scanner._scan_content(payload) - assert isinstance(result, HookResult) - assert result.block is True - assert result.nbr_secrets == 1 - assert "dummy-detector" in result.message - assert "secret" in result.message.lower() - assert "remove the secrets from your prompt" in result.message - - class TestFlavorOutputResult: """Unit tests for Cursor, Claude, Copilot output_result with Result objects. Mocks click.echo to capture stdout/stderr and asserts both output and return code. """ - @patch("ggshield.verticals.secret.ai_hook.cursor.click.echo") + @patch("ggshield.verticals.ai.agents.cursor.click.echo") def test_cursor_output_result_user_prompt_allow(self, mock_echo: MagicMock): """Cursor USER_PROMPT with block=False: JSON to stdout, return 0.""" result = HookResult.allow(_dummy_payload(EventType.USER_PROMPT)) @@ -524,7 +411,7 @@ def test_cursor_output_result_user_prompt_allow(self, mock_echo: MagicMock): assert out["continue"] is True assert out["user_message"] == "" - @patch("ggshield.verticals.secret.ai_hook.cursor.click.echo") + @patch("ggshield.verticals.ai.agents.cursor.click.echo") def test_cursor_output_result_user_prompt_block(self, mock_echo: MagicMock): """Cursor USER_PROMPT with block=True: JSON to stdout, return 0.""" result = HookResult( @@ -542,7 +429,7 @@ def test_cursor_output_result_user_prompt_block(self, mock_echo: MagicMock): assert out["continue"] is False assert out["user_message"] == "Remove secrets from prompt" - @patch("ggshield.verticals.secret.ai_hook.cursor.click.echo") + @patch("ggshield.verticals.ai.agents.cursor.click.echo") def test_cursor_output_result_pre_tool_use_allow(self, mock_echo: MagicMock): """Cursor PRE_TOOL_USE with block=False: permission allow, return 0.""" result = HookResult.allow(_dummy_payload(EventType.PRE_TOOL_USE)) @@ -554,7 +441,7 @@ def test_cursor_output_result_pre_tool_use_allow(self, mock_echo: MagicMock): out = json.loads(args[0]) assert out["permission"] == "allow" - @patch("ggshield.verticals.secret.ai_hook.cursor.click.echo") + @patch("ggshield.verticals.ai.agents.cursor.click.echo") def test_cursor_output_result_pre_tool_use_block(self, mock_echo: MagicMock): """Cursor PRE_TOOL_USE with block=True: permission deny, return 0.""" result = HookResult( @@ -572,7 +459,7 @@ def test_cursor_output_result_pre_tool_use_block(self, mock_echo: MagicMock): assert out["permission"] == "deny" assert out["user_message"] == "Secrets detected in command" - @patch("ggshield.verticals.secret.ai_hook.cursor.click.echo") + @patch("ggshield.verticals.ai.agents.cursor.click.echo") def test_cursor_output_result_post_tool_use(self, mock_echo: MagicMock): """Cursor POST_TOOL_USE: empty JSON to stdout, return 0.""" result = HookResult( @@ -588,7 +475,7 @@ def test_cursor_output_result_post_tool_use(self, mock_echo: MagicMock): assert kwargs.get("err", False) is False # stdout (default) assert json.loads(args[0]) == {} - @patch("ggshield.verticals.secret.ai_hook.cursor.click.echo") + @patch("ggshield.verticals.ai.agents.cursor.click.echo") def test_cursor_output_result_other_block(self, mock_echo: MagicMock): """Cursor OTHER event with block: empty JSON, return 2.""" result = HookResult( @@ -601,7 +488,7 @@ def test_cursor_output_result_other_block(self, mock_echo: MagicMock): assert code == 2 mock_echo.assert_called_once_with("{}") - @patch("ggshield.verticals.secret.ai_hook.cursor.click.echo") + @patch("ggshield.verticals.ai.agents.cursor.click.echo") def test_cursor_output_result_other_allow(self, mock_echo: MagicMock): """Cursor OTHER event without block: empty JSON, return 0.""" result = HookResult.allow(_dummy_payload(EventType.OTHER)) @@ -609,7 +496,7 @@ def test_cursor_output_result_other_allow(self, mock_echo: MagicMock): assert code == 0 mock_echo.assert_called_once_with("{}") - @patch("ggshield.verticals.secret.ai_hook.claude_code.click.echo") + @patch("ggshield.verticals.ai.agents.claude_code.click.echo") def test_claude_output_result_allow(self, mock_echo: MagicMock): """Claude with block=False: JSON continue true to stdout, return 0.""" result = HookResult.allow(_dummy_payload(EventType.USER_PROMPT)) @@ -621,7 +508,7 @@ def test_claude_output_result_allow(self, mock_echo: MagicMock): out = json.loads(args[0]) assert out["continue"] is True - @patch("ggshield.verticals.secret.ai_hook.claude_code.click.echo") + @patch("ggshield.verticals.ai.agents.claude_code.click.echo") def test_claude_output_result_block(self, mock_echo: MagicMock): """Claude with block=True: JSON continue false and stopReason to stdout, return 0.""" result = HookResult( @@ -641,7 +528,7 @@ def test_claude_output_result_block(self, mock_echo: MagicMock): out["hookSpecificOutput"]["permissionDecisionReason"] == "Secrets in file" ) - @patch("ggshield.verticals.secret.ai_hook.claude_code.click.echo") + @patch("ggshield.verticals.ai.agents.claude_code.click.echo") def test_copilot_output_result_allow(self, mock_echo: MagicMock): """Copilot with block=False: same as Claude, JSON to stdout, return 0.""" result = HookResult.allow(_dummy_payload(EventType.USER_PROMPT)) @@ -654,7 +541,7 @@ def test_copilot_output_result_allow(self, mock_echo: MagicMock): assert out["continue"] is True assert "stopReason" not in out - @patch("ggshield.verticals.secret.ai_hook.claude_code.click.echo") + @patch("ggshield.verticals.ai.agents.claude_code.click.echo") def test_copilot_output_result_block(self, mock_echo: MagicMock): """Copilot with block=True: same as Claude, JSON to stdout, return 0.""" result = HookResult( @@ -672,7 +559,7 @@ def test_copilot_output_result_block(self, mock_echo: MagicMock): assert out["decision"] == "block" assert out["reason"] == "Secret in tool output" - @patch("ggshield.verticals.secret.ai_hook.claude_code.click.echo") + @patch("ggshield.verticals.ai.agents.claude_code.click.echo") def test_copilot_other_result_block(self, mock_echo: MagicMock): """Copilot with block=True, other type of event""" result = HookResult( @@ -688,201 +575,6 @@ def test_copilot_other_result_block(self, mock_echo: MagicMock): assert not out["continue"] -class TestBaseFlavor: - """Unit tests for the base Flavor class.""" - - @patch("ggshield.verticals.secret.ai_hook.models.click.echo") - def test_base_flavor_output_result_allow(self, mock_echo: MagicMock): - """Base Flavor with block=False: prints allow message, returns 0.""" - result = HookResult.allow(_dummy_payload(EventType.USER_PROMPT)) - code = Flavor().output_result(result) - assert code == 0 - mock_echo.assert_called_once_with("No secrets found. Good to go.") - - @patch("ggshield.verticals.secret.ai_hook.models.click.echo") - def test_base_flavor_output_result_block(self, mock_echo: MagicMock): - """Base Flavor with block=True: prints message to stderr, returns 2.""" - result = HookResult( - block=True, - message="Secrets found", - nbr_secrets=1, - payload=_dummy_payload(EventType.PRE_TOOL_USE), - ) - code = Flavor().output_result(result) - assert code == 2 - mock_echo.assert_called_once_with("Secrets found", err=True) - - def test_base_flavor_settings_path(self): - """Base Flavor settings_path returns default path.""" - assert Flavor().settings_path == Path(".agents") / "hooks.json" - - def test_base_flavor_settings_template(self): - """Base Flavor settings_template returns empty dict.""" - assert Flavor().settings_template == {} - - def test_base_flavor_settings_locate(self): - """Base Flavor settings_locate always returns None.""" - assert Flavor().settings_locate([{"a": 1}], {"a": 1}) is None - - -class TestAIHookScannerScan: - """Unit tests for the AIHookScanner.scan() method.""" - - def test_scan_no_secrets_returns_zero(self): - """scan() with no secrets returns 0.""" - scanner = AIHookScanner(_mock_scanner([])) - data = { - "hook_event_name": "UserPromptSubmit", - "prompt": "hello world", - "transcript_path": "/home/user/.claude/projects/foo/session.jsonl", - } - code = scanner.scan(json.dumps(data)) - assert code == 0 - - @patch( - "ggshield.verticals.secret.ai_hook.scanner.AIHookScanner._send_secret_notification" - ) - def test_scan_post_tool_use_with_secrets_sends_notification( - self, mock_notify: MagicMock - ): - """scan() on POST_TOOL_USE with secrets sends a notification and returns 0 (no block).""" - scanner = AIHookScanner(_mock_scanner(["sk-xxx"])) - data = { - "hook_event_name": "PostToolUse", - "tool_name": "Bash", - "tool_input": {"command": "echo sk-xxx"}, - "tool_response": {"stdout": "sk-xxx\n"}, - "transcript_path": "/home/user/.claude/projects/foo/session.jsonl", - "session_id": "427ae0c5-0862-4e14-aa2c-12fad909c323", - } - code = scanner.scan(json.dumps(data)) - assert code == 0 - mock_notify.assert_called_once() - args = mock_notify.call_args[0] - assert args[0] == 1 # nbr_secrets - assert args[1] == Tool.BASH # tool - - def test_scan_pre_tool_use_with_secrets_blocks(self): - """scan() on PRE_TOOL_USE with secrets returns block result.""" - scanner = AIHookScanner(_mock_scanner(["sk-xxx"])) - data = { - "hook_event_name": "PreToolUse", - "tool_name": "Bash", - "tool_input": {"command": "echo sk-xxx"}, - "session_id": "427ae0c5-0862-4e14-aa2c-12fad909c323", - "transcript_path": "/home/user/.claude/projects/foo/session.jsonl", - } - code = scanner.scan(json.dumps(data)) - # Claude output_result always returns 0 - assert code == 0 - - def test_scan_no_content_returns_allow(self): - """scan() with no content returns 0 (and doesn't call the API).""" - mock_scanner = _mock_scanner([]) - scanner = AIHookScanner(mock_scanner) - data = { - "hook_event_name": "PreToolUse", - "tool_name": "Read", - "tool_input": {"file_path": "doesn-t-exist"}, - } - code = scanner.scan(json.dumps(data)) - assert code == 0 - mock_scanner.scan.assert_not_called() - - def test_scan_payloads_refuse_empty_list(self): - """scan() with empty list of payloads raises ValueError.""" - scanner = AIHookScanner(_mock_scanner([])) - with pytest.raises(ValueError): - scanner._scan_payloads([]) - - -class TestMessageFromSecrets: - """Unit tests for AIHookScanner._message_from_secrets with different payload types.""" - - def test_message_for_bash_tool(self): - """Message for BASH tool mentions environment variables.""" - payload = Payload( - event_type=EventType.PRE_TOOL_USE, - tool=Tool.BASH, - content="echo sk-xxx", - identifier="echo sk-xxx", - flavor=Flavor(), - ) - message = AIHookScanner._message_from_secrets([_make_secret("sk-xxx")], payload) - assert "remove the secrets from the command" in message - assert "environment variables" in message - - def test_message_for_read_tool(self): - """Message for READ tool mentions file content.""" - payload = Payload( - event_type=EventType.PRE_TOOL_USE, - tool=Tool.READ, - content="file content with secret", - identifier="/path/to/file", - flavor=Flavor(), - ) - message = AIHookScanner._message_from_secrets([_make_secret("sk-xxx")], payload) - assert "remove the secrets from" in message - - def test_message_for_other_tool(self): - """Message for OTHER tool uses generic message.""" - payload = Payload( - event_type=EventType.PRE_TOOL_USE, - tool=Tool.OTHER, - content="some content", - identifier="id", - flavor=Flavor(), - ) - message = AIHookScanner._message_from_secrets([_make_secret("sk-xxx")], payload) - assert "remove the secrets from the tool input" in message - - def test_message_escapes_markdown(self): - """When escape_markdown=True, asterisks in matches are replaced with dots.""" - payload = Payload( - event_type=EventType.USER_PROMPT, - tool=None, - content="content", - identifier="id", - flavor=Flavor(), - ) - message = AIHookScanner._message_from_secrets( - [_make_secret("sk-xxx")], payload, escape_markdown=True - ) - # The message itself should not contain raw asterisks from matches - # (the header uses ** for bold which is intentional) - assert "Detected" in message - - -class TestSendSecretNotification: - """Unit tests for AIHookScanner._send_secret_notification.""" - - @patch("ggshield.verticals.secret.ai_hook.scanner.Notify") - def test_notification_for_bash_tool(self, mock_notify_cls: MagicMock): - """Notification for BASH tool says 'running a command'.""" - AIHookScanner._send_secret_notification(1, Tool.BASH, "Claude Code") - instance = mock_notify_cls.return_value - assert "running a command" in instance.message - assert "Claude Code" in instance.message - instance.send.assert_called_once() - - @patch("ggshield.verticals.secret.ai_hook.scanner.Notify") - def test_notification_for_read_tool(self, mock_notify_cls: MagicMock): - """Notification for READ tool says 'reading a file'.""" - AIHookScanner._send_secret_notification(2, Tool.READ, "Cursor") - instance = mock_notify_cls.return_value - assert "reading a file" in instance.message - assert "2" in instance.message - instance.send.assert_called_once() - - @patch("ggshield.verticals.secret.ai_hook.scanner.Notify") - def test_notification_for_other_tool(self, mock_notify_cls: MagicMock): - """Notification for OTHER tool says 'using a tool'.""" - AIHookScanner._send_secret_notification(1, Tool.OTHER, "Copilot") - instance = mock_notify_cls.return_value - assert "using a tool" in instance.message - instance.send.assert_called_once() - - @pytest.mark.parametrize( "prompt, filepaths", [ diff --git a/tests/unit/verticals/secret/ai_hook/test_installation.py b/tests/unit/verticals/ai/test_installation.py similarity index 90% rename from tests/unit/verticals/secret/ai_hook/test_installation.py rename to tests/unit/verticals/ai/test_installation.py index 5a98ef2f7b..20e2dea79d 100644 --- a/tests/unit/verticals/secret/ai_hook/test_installation.py +++ b/tests/unit/verticals/ai/test_installation.py @@ -6,10 +6,8 @@ import pytest from ggshield.core.errors import UnexpectedError -from ggshield.verticals.secret.ai_hook.claude_code import Claude -from ggshield.verticals.secret.ai_hook.copilot import Copilot -from ggshield.verticals.secret.ai_hook.cursor import Cursor -from ggshield.verticals.secret.ai_hook.installation import ( +from ggshield.verticals.ai.agents import Claude, Copilot, Cursor +from ggshield.verticals.ai.installation import ( InstallationStats, _fill_dict, install_hooks, @@ -287,16 +285,14 @@ def test_copilot_settings_path(self): class TestInstallHooks: """Unit tests for the install_hooks function.""" - @patch("ggshield.verticals.secret.ai_hook.installation.get_user_home_dir") + @patch("ggshield.verticals.ai.installation.get_user_home_dir") def test_install_cursor_local_fresh(self, mock_home: Any, tmp_path: Path): """Install Cursor hooks locally into a fresh directory (no existing config).""" mock_home.return_value = tmp_path settings_path = tmp_path / ".cursor" / "hooks.json" assert not settings_path.exists() - with patch( - "ggshield.verticals.secret.ai_hook.installation.Path" - ) as mock_path_cls: + with patch("ggshield.verticals.ai.installation.Path") as mock_path_cls: # Make Path(".") return tmp_path so local mode writes there mock_path_cls.side_effect = lambda *a: Path(*a) if a != (".",) else tmp_path code = install_hooks("cursor", mode="local") @@ -308,7 +304,7 @@ def test_install_cursor_local_fresh(self, mock_home: Any, tmp_path: Path): for key in ("beforeSubmitPrompt", "preToolUse", "postToolUse"): assert any("ggshield" in h["command"] for h in config["hooks"][key]) - @patch("ggshield.verticals.secret.ai_hook.installation.get_user_home_dir") + @patch("ggshield.verticals.ai.installation.get_user_home_dir") def test_install_claude_global(self, mock_home: Any, tmp_path: Path): """Install Claude Code hooks globally.""" mock_home.return_value = tmp_path @@ -322,7 +318,7 @@ def test_install_claude_global(self, mock_home: Any, tmp_path: Path): for key in ("PreToolUse", "PostToolUse", "UserPromptSubmit"): assert key in config["hooks"] - @patch("ggshield.verticals.secret.ai_hook.installation.get_user_home_dir") + @patch("ggshield.verticals.ai.installation.get_user_home_dir") def test_install_copilot_global(self, mock_home: Any, tmp_path: Path): """Install Copilot hooks globally.""" mock_home.return_value = tmp_path @@ -332,12 +328,12 @@ def test_install_copilot_global(self, mock_home: Any, tmp_path: Path): settings_path = tmp_path / ".github" / "hooks" / "hooks.json" assert settings_path.exists() - def test_install_unsupported_tool_raises(self): - """install_hooks raises ValueError for unsupported tool name.""" - with pytest.raises(ValueError, match="Unsupported tool name"): - install_hooks("unknown-tool", mode="local") + def test_install_unsupported_agent_raises(self): + """install_hooks raises ValueError for unsupported agent.""" + with pytest.raises(ValueError, match="Unsupported agent"): + install_hooks("unknown-agent", mode="local") - @patch("ggshield.verticals.secret.ai_hook.installation.get_user_home_dir") + @patch("ggshield.verticals.ai.installation.get_user_home_dir") def test_install_with_existing_config(self, mock_home: Any, tmp_path: Path): """Install hooks when a config file already exists (merges).""" mock_home.return_value = tmp_path @@ -345,9 +341,7 @@ def test_install_with_existing_config(self, mock_home: Any, tmp_path: Path): settings_path.parent.mkdir(parents=True) settings_path.write_text(json.dumps({"version": 1, "other_key": "keep_me"})) - with patch( - "ggshield.verticals.secret.ai_hook.installation.Path" - ) as mock_path_cls: + with patch("ggshield.verticals.ai.installation.Path") as mock_path_cls: mock_path_cls.side_effect = lambda *a: Path(*a) if a != (".",) else tmp_path code = install_hooks("cursor", mode="local") @@ -356,7 +350,7 @@ def test_install_with_existing_config(self, mock_home: Any, tmp_path: Path): assert config["other_key"] == "keep_me" assert "hooks" in config - @patch("ggshield.verticals.secret.ai_hook.installation.get_user_home_dir") + @patch("ggshield.verticals.ai.installation.get_user_home_dir") def test_install_with_corrupt_json_raises(self, mock_home: Any, tmp_path: Path): """install_hooks raises UnexpectedError when existing config is invalid JSON.""" mock_home.return_value = tmp_path @@ -364,21 +358,17 @@ def test_install_with_corrupt_json_raises(self, mock_home: Any, tmp_path: Path): settings_path.parent.mkdir(parents=True) settings_path.write_text("{ invalid json") - with patch( - "ggshield.verticals.secret.ai_hook.installation.Path" - ) as mock_path_cls: + with patch("ggshield.verticals.ai.installation.Path") as mock_path_cls: mock_path_cls.side_effect = lambda *a: Path(*a) if a != (".",) else tmp_path with pytest.raises(UnexpectedError, match="Failed to parse"): install_hooks("cursor", mode="local") - @patch("ggshield.verticals.secret.ai_hook.installation.get_user_home_dir") + @patch("ggshield.verticals.ai.installation.get_user_home_dir") def test_install_already_present(self, mock_home: Any, tmp_path: Path): """install_hooks when hooks are already installed reports 'already installed'.""" mock_home.return_value = tmp_path - with patch( - "ggshield.verticals.secret.ai_hook.installation.Path" - ) as mock_path_cls: + with patch("ggshield.verticals.ai.installation.Path") as mock_path_cls: mock_path_cls.side_effect = lambda *a: Path(*a) if a != (".",) else tmp_path # Install once install_hooks("cursor", mode="local") @@ -387,14 +377,12 @@ def test_install_already_present(self, mock_home: Any, tmp_path: Path): assert code == 0 - @patch("ggshield.verticals.secret.ai_hook.installation.get_user_home_dir") + @patch("ggshield.verticals.ai.installation.get_user_home_dir") def test_install_force_updates(self, mock_home: Any, tmp_path: Path): """install_hooks with force=True updates existing hooks.""" mock_home.return_value = tmp_path - with patch( - "ggshield.verticals.secret.ai_hook.installation.Path" - ) as mock_path_cls: + with patch("ggshield.verticals.ai.installation.Path") as mock_path_cls: mock_path_cls.side_effect = lambda *a: Path(*a) if a != (".",) else tmp_path install_hooks("cursor", mode="local") code = install_hooks("cursor", mode="local", force=True) diff --git a/tests/unit/verticals/secret/test_ai_hooks.py b/tests/unit/verticals/secret/test_ai_hooks.py new file mode 100644 index 0000000000..5bfeca82cd --- /dev/null +++ b/tests/unit/verticals/secret/test_ai_hooks.py @@ -0,0 +1,263 @@ +import json +from collections import Counter +from pathlib import Path +from typing import List +from unittest.mock import MagicMock, patch + +import pytest + +from ggshield.utils.git_shell import Filemode +from ggshield.verticals.ai.agents import Cursor +from ggshield.verticals.secret import SecretScanner +from ggshield.verticals.secret.ai_hook import AIHookScanner, EventType, Payload +from ggshield.verticals.secret.ai_hook import Result as HookResult +from ggshield.verticals.secret.ai_hook import Tool +from ggshield.verticals.secret.secret_scan_collection import Result as ScanResult +from ggshield.verticals.secret.secret_scan_collection import Results, Secret + + +def _mock_scanner(matches: List[str]) -> MagicMock: + """Create a mock SecretScanner that returns the given Results from scan().""" + mock = MagicMock(spec=SecretScanner) + scan_result = Results( + results=[ + ScanResult( + filename="url", + filemode=Filemode.FILE, + path=Path("."), + url="url", + secrets=[_make_secret(match) for match in matches], + ignored_secrets_count_by_kind=Counter(), + ) + ], + errors=[], + ) + mock.scan.return_value = scan_result + return mock + + +def _make_secret(match_str: str = "***"): + """Minimal Secret for tests; _message_from_secrets only uses detector_display_name, validity, matches[].match.""" + mock_match = MagicMock() + mock_match.match = match_str + return Secret( + detector_display_name="dummy-detector", + detector_name="dummy-detector", + detector_group_name=None, + documentation_url=None, + validity="valid", + known_secret=False, + incident_url=None, + matches=[mock_match], + ignore_reason=None, + diff_kind=None, + is_vaulted=False, + vault_type=None, + vault_name=None, + vault_path=None, + vault_path_count=None, + ) + + +class TestAIHookScannerScanContent: + """Unit tests for AIHookScanner._scan_content.""" + + def test_no_secrets_returns_allow(self): + """When scanner returns no secrets, result has block=False and nbr_secrets=0.""" + hook_scanner = AIHookScanner(_mock_scanner([])) + payload = Payload( + event_type=EventType.USER_PROMPT, + tool=None, + content="safe content", + identifier="id", + agent=Cursor(), + ) + result = hook_scanner._scan_content(payload) + assert isinstance(result, HookResult) + assert result.block is False + assert result.nbr_secrets == 0 + assert result.message == "" + + def test_with_secrets_returns_block_and_message(self): + """When scanner returns secrets, result has block=True, nbr_secrets and message set.""" + hook_scanner = AIHookScanner(_mock_scanner(["sk-xxx"])) + payload = Payload( + event_type=EventType.USER_PROMPT, + tool=None, + content="content with sk-xxx", + identifier="id", + agent=Cursor(), + ) + result = hook_scanner._scan_content(payload) + assert isinstance(result, HookResult) + assert result.block is True + assert result.nbr_secrets == 1 + assert "dummy-detector" in result.message + assert "secret" in result.message.lower() + assert "remove the secrets from your prompt" in result.message + + +class TestAIHookScannerScan: + """Unit tests for the AIHookScanner.scan() method.""" + + def test_empty_input_raises(self): + """Empty or whitespace-only input raises ValueError.""" + scanner = AIHookScanner(_mock_scanner([])) + with pytest.raises(ValueError, match="No input received on stdin"): + scanner.scan("") + with pytest.raises(ValueError, match="No input received on stdin"): + scanner.scan(" \n ") + + def test_scan_no_secrets_returns_zero(self): + """scan() with no secrets returns 0.""" + scanner = AIHookScanner(_mock_scanner([])) + data = { + "hook_event_name": "UserPromptSubmit", + "prompt": "hello world", + "transcript_path": "/home/user/.claude/projects/foo/session.jsonl", + "cursor_version": "1.2.3", + } + code = scanner.scan(json.dumps(data)) + assert code == 0 + + @patch("ggshield.verticals.secret.ai_hook.AIHookScanner._send_secret_notification") + def test_scan_post_tool_use_with_secrets_sends_notification( + self, mock_notify: MagicMock + ): + """scan() on POST_TOOL_USE with secrets sends a notification and returns 0 (no block).""" + scanner = AIHookScanner(_mock_scanner(["sk-xxx"])) + data = { + "hook_event_name": "PostToolUse", + "tool_name": "Bash", + "tool_input": {"command": "echo sk-xxx"}, + "tool_response": {"stdout": "sk-xxx\n"}, + "transcript_path": "/home/user/.claude/projects/foo/session.jsonl", + "session_id": "427ae0c5-0862-4e14-aa2c-12fad909c323", + } + code = scanner.scan(json.dumps(data)) + assert code == 0 + mock_notify.assert_called_once() + args = mock_notify.call_args[0] + assert args[0] == 1 # nbr_secrets + assert args[1] == Tool.BASH # tool + + def test_scan_pre_tool_use_with_secrets_blocks(self): + """scan() on PRE_TOOL_USE with secrets returns block result.""" + scanner = AIHookScanner(_mock_scanner(["sk-xxx"])) + data = { + "hook_event_name": "PreToolUse", + "tool_name": "Bash", + "tool_input": {"command": "echo sk-xxx"}, + "session_id": "427ae0c5-0862-4e14-aa2c-12fad909c323", + "transcript_path": "/home/user/.claude/projects/foo/session.jsonl", + } + code = scanner.scan(json.dumps(data)) + # Claude output_result always returns 0 + assert code == 0 + + def test_scan_no_content_returns_allow(self): + """scan() with no content returns 0 (and doesn't call the API).""" + mock_scanner = _mock_scanner([]) + scanner = AIHookScanner(mock_scanner) + data = { + "hook_event_name": "PreToolUse", + "tool_name": "Read", + "tool_input": {"file_path": "doesn-t-exist"}, + "cursor_version": "1.2.3", + } + code = scanner.scan(json.dumps(data)) + assert code == 0 + mock_scanner.scan.assert_not_called() + + def test_scan_payloads_refuse_empty_list(self): + """scan() with empty list of payloads raises ValueError.""" + scanner = AIHookScanner(_mock_scanner([])) + with pytest.raises(ValueError): + scanner._scan_payloads([]) + + +class TestMessageFromSecrets: + """Unit tests for AIHookScanner._message_from_secrets with different payload types.""" + + def test_message_for_bash_tool(self): + """Message for BASH tool mentions environment variables.""" + payload = Payload( + event_type=EventType.PRE_TOOL_USE, + tool=Tool.BASH, + content="echo sk-xxx", + identifier="echo sk-xxx", + agent=Cursor(), + ) + message = AIHookScanner._message_from_secrets([_make_secret("sk-xxx")], payload) + assert "remove the secrets from the command" in message + assert "environment variables" in message + + def test_message_for_read_tool(self): + """Message for READ tool mentions file content.""" + payload = Payload( + event_type=EventType.PRE_TOOL_USE, + tool=Tool.READ, + content="file content with secret", + identifier="/path/to/file", + agent=Cursor(), + ) + message = AIHookScanner._message_from_secrets([_make_secret("sk-xxx")], payload) + assert "remove the secrets from" in message + + def test_message_for_other_tool(self): + """Message for OTHER tool uses generic message.""" + payload = Payload( + event_type=EventType.PRE_TOOL_USE, + tool=Tool.OTHER, + content="some content", + identifier="id", + agent=Cursor(), + ) + message = AIHookScanner._message_from_secrets([_make_secret("sk-xxx")], payload) + assert "remove the secrets from the tool input" in message + + def test_message_escapes_markdown(self): + """When escape_markdown=True, asterisks in matches are replaced with dots.""" + payload = Payload( + event_type=EventType.USER_PROMPT, + tool=None, + content="content", + identifier="id", + agent=Cursor(), + ) + message = AIHookScanner._message_from_secrets( + [_make_secret("sk-xxx")], payload, escape_markdown=True + ) + # The message itself should not contain raw asterisks from matches + # (the header uses ** for bold which is intentional) + assert "Detected" in message + + +class TestSendSecretNotification: + """Unit tests for AIHookScanner._send_secret_notification.""" + + @patch("ggshield.verticals.secret.ai_hook.Notify") + def test_notification_for_bash_tool(self, mock_notify_cls: MagicMock): + """Notification for BASH tool says 'running a command'.""" + AIHookScanner._send_secret_notification(1, Tool.BASH, "Claude Code") + instance = mock_notify_cls.return_value + assert "running a command" in instance.message + assert "Claude Code" in instance.message + instance.send.assert_called_once() + + @patch("ggshield.verticals.secret.ai_hook.Notify") + def test_notification_for_read_tool(self, mock_notify_cls: MagicMock): + """Notification for READ tool says 'reading a file'.""" + AIHookScanner._send_secret_notification(2, Tool.READ, "Cursor") + instance = mock_notify_cls.return_value + assert "reading a file" in instance.message + assert "2" in instance.message + instance.send.assert_called_once() + + @patch("ggshield.verticals.secret.ai_hook.Notify") + def test_notification_for_other_tool(self, mock_notify_cls: MagicMock): + """Notification for OTHER tool says 'using a tool'.""" + AIHookScanner._send_secret_notification(1, Tool.OTHER, "Copilot") + instance = mock_notify_cls.return_value + assert "using a tool" in instance.message + instance.send.assert_called_once() From 17f9363b18730565e53ae6570e15575616e1d7b6 Mon Sep 17 00:00:00 2001 From: Paul PETIT Date: Tue, 7 Apr 2026 09:04:26 +0200 Subject: [PATCH 2/2] feat(ai_hook): reduce coupling between verticals --- .importlinter | 5 +- ggshield/cmd/secret/scan/ai_hook.py | 2 +- ggshield/core/scan/__init__.py | 4 + ggshield/core/scan/scanner.py | 50 ++++ ggshield/verticals/ai/__init__.py | 10 + ggshield/verticals/ai/hooks.py | 176 ++++++++++++- ggshield/verticals/ai/installation.py | 6 +- ggshield/verticals/secret/ai_hook.py | 175 ------------ scripts/generate-import-linter-config.py | 4 + tests/unit/verticals/ai/test_hooks.py | 256 +++++++++++++++++- tests/unit/verticals/secret/test_ai_hooks.py | 263 ------------------- 11 files changed, 504 insertions(+), 447 deletions(-) create mode 100644 ggshield/core/scan/scanner.py create mode 100644 ggshield/verticals/ai/__init__.py delete mode 100644 ggshield/verticals/secret/ai_hook.py delete mode 100644 tests/unit/verticals/secret/test_ai_hooks.py diff --git a/.importlinter b/.importlinter index 866d4d06db..3e5669e1a3 100644 --- a/.importlinter +++ b/.importlinter @@ -10,7 +10,7 @@ type = layers layers = ggshield.__main__ ggshield.cmd.auth | ggshield.cmd.config | ggshield.cmd.hmsl | ggshield.cmd.honeytoken | ggshield.cmd.install | ggshield.cmd.plugin | ggshield.cmd.quota | ggshield.cmd.secret | ggshield.cmd.status | ggshield.cmd.utils - ggshield.verticals.auth | ggshield.verticals.hmsl | ggshield.verticals.secret + ggshield.verticals.ai | ggshield.verticals.auth | ggshield.verticals.hmsl | ggshield.verticals.secret ggshield.core click | ggshield.utils | pygitguardian ignore_imports = @@ -33,6 +33,7 @@ source_modules = ggshield.cmd.status ggshield.cmd.utils forbidden_modules = + ggshield.verticals.ai ggshield.verticals.auth ggshield.verticals.hmsl ggshield.verticals.secret @@ -46,6 +47,7 @@ ignore_imports = ggshield.cmd.hmsl.** -> ggshield.verticals.hmsl.** ggshield.cmd.honeytoken.** -> ggshield.verticals.honeytoken ggshield.cmd.honeytoken.** -> ggshield.verticals.honeytoken.** + ggshield.cmd.install -> ggshield.verticals.ai.installation ggshield.cmd.install.** -> ggshield.verticals.install ggshield.cmd.install.** -> ggshield.verticals.install.** ggshield.cmd.plugin.** -> ggshield.core.plugin @@ -54,6 +56,7 @@ ignore_imports = ggshield.cmd.quota.** -> ggshield.verticals.quota.** ggshield.cmd.secret.** -> ggshield.verticals.secret ggshield.cmd.secret.** -> ggshield.verticals.secret.** + ggshield.cmd.secret.scan.ai_hook -> ggshield.verticals.ai.hooks ggshield.cmd.status.** -> ggshield.verticals.status ggshield.cmd.status.** -> ggshield.verticals.status.** ggshield.cmd.utils.** -> ggshield.verticals.utils diff --git a/ggshield/cmd/secret/scan/ai_hook.py b/ggshield/cmd/secret/scan/ai_hook.py index 64e835c757..3689b00131 100644 --- a/ggshield/cmd/secret/scan/ai_hook.py +++ b/ggshield/cmd/secret/scan/ai_hook.py @@ -10,8 +10,8 @@ from ggshield.core import ui from ggshield.core.client import create_client_from_config from ggshield.core.scan import ScanContext, ScanMode +from ggshield.verticals.ai.hooks import AIHookScanner from ggshield.verticals.secret import SecretScanner -from ggshield.verticals.secret.ai_hook import AIHookScanner MAX_READ_SIZE = 1024 * 1024 * 10 # We restrict stdin read to 10MB diff --git a/ggshield/core/scan/__init__.py b/ggshield/core/scan/__init__.py index b37b6197f6..e2c370a110 100644 --- a/ggshield/core/scan/__init__.py +++ b/ggshield/core/scan/__init__.py @@ -3,6 +3,7 @@ from .scan_context import ScanContext from .scan_mode import ScanMode from .scannable import DecodeError, NonSeekableFileError, Scannable, StringScannable +from .scanner import ResultsProtocol, ScannerProtocol, SecretProtocol __all__ = [ @@ -11,8 +12,11 @@ "DecodeError", "File", "NonSeekableFileError", + "ResultsProtocol", "ScanContext", "ScanMode", "Scannable", + "ScannerProtocol", + "SecretProtocol", "StringScannable", ] diff --git a/ggshield/core/scan/scanner.py b/ggshield/core/scan/scanner.py new file mode 100644 index 0000000000..9f963669f6 --- /dev/null +++ b/ggshield/core/scan/scanner.py @@ -0,0 +1,50 @@ +""" +Protocols for SecretScanner and its results, +so that other verticals can use the scanner if they are provided one. +""" + +from collections.abc import Sequence +from typing import Iterable, Optional, Protocol + +from pygitguardian.models import Match + +from ggshield.core.scanner_ui import ScannerUI + +from . import Scannable + + +class SecretProtocol(Protocol): + """Abstract base class for secrets. + + We use getters instead of properties to have a . + """ + + @property + def detector_display_name(self) -> str: ... + + @property + def validity(self) -> str: ... + + @property + def matches(self) -> Sequence[Match]: ... + + +class ResultProtocol(Protocol): + @property + def secrets(self) -> Sequence[SecretProtocol]: ... + + +class ResultsProtocol(Protocol): + @property + def results(self) -> Sequence[ResultProtocol]: ... + + +class ScannerProtocol(Protocol): + """Protocol for scanners.""" + + def scan( + self, + files: Iterable[Scannable], + scanner_ui: ScannerUI, + scan_threads: Optional[int] = None, + ) -> ResultsProtocol: ... diff --git a/ggshield/verticals/ai/__init__.py b/ggshield/verticals/ai/__init__.py new file mode 100644 index 0000000000..cbd7e4db0c --- /dev/null +++ b/ggshield/verticals/ai/__init__.py @@ -0,0 +1,10 @@ +from .agents import AGENTS +from .hooks import AIHookScanner +from .installation import install_hooks + + +__all__ = [ + "AGENTS", + "AIHookScanner", + "install_hooks", +] diff --git a/ggshield/verticals/ai/hooks.py b/ggshield/verticals/ai/hooks.py index 75adabcd36..4307398c79 100644 --- a/ggshield/verticals/ai/hooks.py +++ b/ggshield/verticals/ai/hooks.py @@ -3,9 +3,16 @@ import re from typing import Any, Dict, List, Sequence, Set -from ggshield.verticals.ai.agents import Claude, Copilot, Cursor +from notifypy import Notify -from .models import Agent, EventType, HookPayload, Tool +from ggshield.core.filter import censor_match +from ggshield.core.scan import ScannerProtocol +from ggshield.core.scan import SecretProtocol as Secret +from ggshield.core.scanner_ui import create_message_only_scanner_ui +from ggshield.core.text_utils import pluralize, translate_validity + +from .agents import Claude, Copilot, Cursor +from .models import Agent, EventType, HookPayload, HookResult, Tool HOOK_NAME_TO_EVENT_TYPE = { @@ -163,3 +170,168 @@ def _parse_user_prompt( ) ) return payloads + + +class AIHookScanner: + """AI hook scanner. + + It is called with the payload of a hook event. + Note that instead of having a base class with common method and a subclass per supported AI tool, + we instead have a single class which detects which protocol to use. + This is because some tools sloppily support hooks from others. For instance, + Cursor will call hooks defined in the Claude Code format, but send payload in its own format. + So we can't assume which tool will call us based on the command line/hook configuration only. + + Raises: + ValueError: If the input is not valid. + """ + + def __init__(self, scanner: ScannerProtocol): + self.scanner = scanner + + def scan(self, content: str) -> int: + """Scan the content, print the result and return the exit code.""" + + payloads = parse_hook_input(content) + result = self._scan_payloads(payloads) + payload = result.payload + + # Special case: in post-tool use, the action is already done: at least notify the user + if result.block and payload.event_type == EventType.POST_TOOL_USE: + self._send_secret_notification( + result.nbr_secrets, + payload.tool or Tool.OTHER, + payload.agent.display_name, + ) + + return payload.agent.output_result(result) + + def _scan_payloads(self, payloads: List[HookPayload]) -> HookResult: + """Scan payloads for secrets using the SecretScanner. + + Returns: + The result of the first blocking payload, or a non-blocking result. + Raises a ValueError if the list is empty (we must have at least one to emit a result). + """ + if not payloads: + raise ValueError("Error: no payloads to scan") + for payload in payloads: + result = self._scan_content(payload) + if result.block: + return result + return HookResult.allow(payloads[0]) + + def _scan_content( + self, + payload: HookPayload, + ) -> HookResult: + """Scan content for secrets using the SecretScanner.""" + # Short path: if there is no content, no need to do an API call + if payload.empty: + return HookResult.allow(payload) + + with create_message_only_scanner_ui() as scanner_ui: + results = self.scanner.scan([payload.scannable], scanner_ui=scanner_ui) + # Collect all secrets from results + secrets: List[Secret] = [] + for result in results.results: + secrets.extend(result.secrets) + + if not secrets: + return HookResult.allow(payload) + + message = self._message_from_secrets( + secrets, + payload, + escape_markdown=True, + ) + return HookResult( + block=True, + message=message, + nbr_secrets=len(secrets), + payload=payload, + ) + + @staticmethod + def _message_from_secrets( + secrets: List[Secret], + payload: HookPayload, + escape_markdown: bool = False, + ) -> str: + """ + Format detected secrets into a user-friendly message. + + Args: + secrets: List of detected secrets + payload: Text to display after the secrets output + escape_markdown: If True, escape asterisks to prevent markdown interpretation + + Returns: + Formatted message describing the detected secrets + """ + count = len(secrets) + header = f"**🚨 Detected {count} {pluralize('secret', count)} 🚨**" + + secret_lines = [] + for secret in secrets: + validity = translate_validity(secret.validity).lower() + if validity == "valid": + validity = f"**{validity}**" + match_str = ", ".join(censor_match(m) for m in secret.matches) + if escape_markdown: + match_str = match_str.replace("*", "•") + secret_lines.append( + f" - {secret.detector_display_name} ({validity}): {match_str}" + ) + + if payload.tool == Tool.BASH: + if payload.event_type == EventType.POST_TOOL_USE: + message = "Secrets detected in the command output." + else: + message = ( + "Please remove the secrets from the command before executing it. " + "Consider using environment variables or a secrets manager instead." + ) + elif payload.tool == Tool.READ: + message = f"Please remove the secrets from {payload.identifier} before reading it." + elif payload.event_type == EventType.USER_PROMPT: + message = "Please remove the secrets from your prompt before submitting." + else: + message = ( + "Please remove the secrets from the tool input before executing. " + "Consider using environment variables or a secrets manager instead." + ) + + secrets_block = "\n".join(secret_lines) + return f"{header}\n{secrets_block}\n\n{message}" + + @staticmethod + def _send_secret_notification( + nbr_secrets: int, tool: Tool, agent_name: str + ) -> None: + """ + Send desktop notification when secrets are detected. + + Args: + nbr_secrets: Number of detected secrets + tool: Tool used to detect the secrets + agent_name: Name of the agent that detected the secrets + """ + source = "using a tool" + if tool == Tool.READ: + source = "reading a file" + elif tool == Tool.BASH: + source = "running a command" + notification = Notify() + notification.title = "ggshield - Secrets Detected" + notification.message = ( + f"{agent_name} got access to {nbr_secrets}" + f" {pluralize('secret', nbr_secrets)} by {source}" + ) + notification.application_name = "ggshield" + try: + notification.send() + except Exception: + # This is best effort, we don't want to propagate an error + # if the notification fails. + pass diff --git a/ggshield/verticals/ai/installation.py b/ggshield/verticals/ai/installation.py index f624c1fc15..ca328b374f 100644 --- a/ggshield/verticals/ai/installation.py +++ b/ggshield/verticals/ai/installation.py @@ -80,11 +80,11 @@ def install_hooks( # Report what happened styled_path = click.style(settings_path, fg="yellow", bold=True) if stats.added == 0 and stats.already_present > 0: - click.echo(f"{agent.name} hooks already installed in {styled_path}") + click.echo(f"{agent.display_name} hooks already installed in {styled_path}") elif stats.added > 0 and stats.already_present > 0: - click.echo(f"{agent.name} hooks updated in {styled_path}") + click.echo(f"{agent.display_name} hooks updated in {styled_path}") else: - click.echo(f"{agent.name} hooks successfully added in {styled_path}") + click.echo(f"{agent.display_name} hooks successfully added in {styled_path}") return 0 diff --git a/ggshield/verticals/secret/ai_hook.py b/ggshield/verticals/secret/ai_hook.py deleted file mode 100644 index 332262a3bf..0000000000 --- a/ggshield/verticals/secret/ai_hook.py +++ /dev/null @@ -1,175 +0,0 @@ -from typing import List - -from notifypy import Notify - -from ggshield.core.filter import censor_match -from ggshield.core.scanner_ui import create_message_only_scanner_ui -from ggshield.core.text_utils import pluralize, translate_validity -from ggshield.verticals.ai.hooks import parse_hook_input -from ggshield.verticals.ai.models import EventType -from ggshield.verticals.ai.models import HookPayload as Payload -from ggshield.verticals.ai.models import HookResult as Result -from ggshield.verticals.ai.models import Tool -from ggshield.verticals.secret import SecretScanner -from ggshield.verticals.secret.secret_scan_collection import Secret - - -class AIHookScanner: - """AI hook scanner. - - It is called with the payload of a hook event. - Note that instead of having a base class with common method and a subclass per supported AI tool, - we instead have a single class which detects which protocol to use (called "flavor"). - This is because some tools sloppily support hooks from others. For instance, - Cursor will call hooks defined in the Claude Code format, but send payload in its own format. - So we can't assume which tool will call us based on the command line/hook configuration only. - - Raises: - ValueError: If the input is not valid. - """ - - def __init__(self, scanner: SecretScanner): - self.scanner = scanner - - def scan(self, content: str) -> int: - """Scan the content, print the result and return the exit code.""" - - payloads = parse_hook_input(content) - result = self._scan_payloads(payloads) - payload = result.payload - - # Special case: in post-tool use, the action is already done: at least notify the user - if result.block and payload.event_type == EventType.POST_TOOL_USE: - self._send_secret_notification( - result.nbr_secrets, payload.tool or Tool.OTHER, payload.agent.name - ) - - return payload.agent.output_result(result) - - def _scan_payloads(self, payloads: List[Payload]) -> Result: - """Scan payloads for secrets using the SecretScanner. - - Returns: - The result of the first blocking payload, or a non-blocking result. - Raises a ValueError if the list is empty (we must have at least one to emit a result). - """ - if not payloads: - raise ValueError("Error: no payloads to scan") - for payload in payloads: - result = self._scan_content(payload) - if result.block: - return result - return Result.allow(payloads[0]) - - def _scan_content( - self, - payload: Payload, - ) -> Result: - """Scan content for secrets using the SecretScanner.""" - # Short path: if there is no content, no need to do an API call - if payload.empty: - return Result.allow(payload) - - with create_message_only_scanner_ui() as scanner_ui: - results = self.scanner.scan([payload.scannable], scanner_ui=scanner_ui) - # Collect all secrets from results - secrets: List[Secret] = [] - for result in results.results: - secrets.extend(result.secrets) - - if not secrets: - return Result.allow(payload) - - message = self._message_from_secrets( - secrets, - payload, - escape_markdown=True, - ) - return Result( - block=True, - message=message, - nbr_secrets=len(secrets), - payload=payload, - ) - - @staticmethod - def _message_from_secrets( - secrets: List[Secret], payload: Payload, escape_markdown: bool = False - ) -> str: - """ - Format detected secrets into a user-friendly message. - - Args: - secrets: List of detected secrets - payload: Text to display after the secrets output - escape_markdown: If True, escape asterisks to prevent markdown interpretation - - Returns: - Formatted message describing the detected secrets - """ - count = len(secrets) - header = f"**🚨 Detected {count} {pluralize('secret', count)} 🚨**" - - secret_lines = [] - for secret in secrets: - validity = translate_validity(secret.validity).lower() - if validity == "valid": - validity = f"**{validity}**" - match_str = ", ".join(censor_match(m) for m in secret.matches) - if escape_markdown: - match_str = match_str.replace("*", "•") - secret_lines.append( - f" - {secret.detector_display_name} ({validity}): {match_str}" - ) - - if payload.tool == Tool.BASH: - if payload.event_type == EventType.POST_TOOL_USE: - message = "Secrets detected in the command output." - else: - message = ( - "Please remove the secrets from the command before executing it. " - "Consider using environment variables or a secrets manager instead." - ) - elif payload.tool == Tool.READ: - message = f"Please remove the secrets from {payload.identifier} before reading it." - elif payload.event_type == EventType.USER_PROMPT: - message = "Please remove the secrets from your prompt before submitting." - else: - message = ( - "Please remove the secrets from the tool input before executing. " - "Consider using environment variables or a secrets manager instead." - ) - - secrets_block = "\n".join(secret_lines) - return f"{header}\n{secrets_block}\n\n{message}" - - @staticmethod - def _send_secret_notification( - nbr_secrets: int, tool: Tool, agent_name: str - ) -> None: - """ - Send desktop notification when secrets are detected. - - Args: - nbr_secrets: Number of detected secrets - tool: Tool used to detect the secrets - agent_name: Name of the agent that detected the secrets - """ - source = "using a tool" - if tool == Tool.READ: - source = "reading a file" - elif tool == Tool.BASH: - source = "running a command" - notification = Notify() - notification.title = "ggshield - Secrets Detected" - notification.message = ( - f"{agent_name} got access to {nbr_secrets}" - f" {pluralize('secret', nbr_secrets)} by {source}" - ) - notification.application_name = "ggshield" - try: - notification.send() - except Exception: - # This is best effort, we don't want to propagate an error - # if the notification fails. - pass diff --git a/scripts/generate-import-linter-config.py b/scripts/generate-import-linter-config.py index 8bc41c5678..3bf75f2a76 100755 --- a/scripts/generate-import-linter-config.py +++ b/scripts/generate-import-linter-config.py @@ -64,6 +64,10 @@ class Contract(TypedDict): "ggshield.cmd.{}.** -> ggshield.verticals.{}.**", # FIXME: #521 - enforce boundaries between cmd.auth and verticals.hmsl "ggshield.cmd.auth.** -> ggshield.verticals.hmsl.**", + # Install command import logic to install AI hooks + "ggshield.cmd.install -> ggshield.verticals.ai.installation", + # AI hook command import logic to scan AI hook payloads + "ggshield.cmd.secret.scan.ai_hook -> ggshield.verticals.ai.hooks", ], "unmatched_ignore_imports_alerting": "none", }, diff --git a/tests/unit/verticals/ai/test_hooks.py b/tests/unit/verticals/ai/test_hooks.py index e63d6f0361..adf355774f 100644 --- a/tests/unit/verticals/ai/test_hooks.py +++ b/tests/unit/verticals/ai/test_hooks.py @@ -1,13 +1,18 @@ import json +from collections import Counter from pathlib import Path -from typing import Set +from typing import List, Set from unittest.mock import MagicMock, patch import pytest +from ggshield.utils.git_shell import Filemode from ggshield.verticals.ai.agents import Claude, Copilot, Cursor -from ggshield.verticals.ai.hooks import find_filepaths, parse_hook_input +from ggshield.verticals.ai.hooks import AIHookScanner, find_filepaths, parse_hook_input from ggshield.verticals.ai.models import EventType, HookPayload, HookResult, Tool +from ggshield.verticals.secret import SecretScanner +from ggshield.verticals.secret.secret_scan_collection import Result as ScanResult +from ggshield.verticals.secret.secret_scan_collection import Results, Secret def _dummy_payload(event_type: EventType = EventType.OTHER) -> HookPayload: @@ -28,6 +33,253 @@ def tmp_file(tmp_path: Path) -> Path: return file +def _mock_scanner(matches: List[str]) -> MagicMock: + """Create a mock SecretScanner that returns the given Results from scan().""" + mock = MagicMock(spec=SecretScanner) + scan_result = Results( + results=[ + ScanResult( + filename="url", + filemode=Filemode.FILE, + path=Path("."), + url="url", + secrets=[_make_secret(match) for match in matches], + ignored_secrets_count_by_kind=Counter(), + ) + ], + errors=[], + ) + mock.scan.return_value = scan_result + return mock + + +def _make_secret(match_str: str = "***"): + """Minimal Secret for tests; _message_from_secrets only uses detector_display_name, validity, matches[].match.""" + mock_match = MagicMock() + mock_match.match = match_str + return Secret( + detector_display_name="dummy-detector", + detector_name="dummy-detector", + detector_group_name=None, + documentation_url=None, + validity="valid", + known_secret=False, + incident_url=None, + matches=[mock_match], + ignore_reason=None, + diff_kind=None, + is_vaulted=False, + vault_type=None, + vault_name=None, + vault_path=None, + vault_path_count=None, + ) + + +class TestAIHookScannerScanContent: + """Unit tests for AIHookScanner._scan_content.""" + + def test_no_secrets_returns_allow(self): + """When scanner returns no secrets, result has block=False and nbr_secrets=0.""" + hook_scanner = AIHookScanner(_mock_scanner([])) + payload = HookPayload( + event_type=EventType.USER_PROMPT, + tool=None, + content="safe content", + identifier="id", + agent=Cursor(), + ) + result = hook_scanner._scan_content(payload) + assert isinstance(result, HookResult) + assert result.block is False + assert result.nbr_secrets == 0 + assert result.message == "" + + def test_with_secrets_returns_block_and_message(self): + """When scanner returns secrets, result has block=True, nbr_secrets and message set.""" + hook_scanner = AIHookScanner(_mock_scanner(["sk-xxx"])) + payload = HookPayload( + event_type=EventType.USER_PROMPT, + tool=None, + content="content with sk-xxx", + identifier="id", + agent=Cursor(), + ) + result = hook_scanner._scan_content(payload) + assert isinstance(result, HookResult) + assert result.block is True + assert result.nbr_secrets == 1 + assert "dummy-detector" in result.message + assert "secret" in result.message.lower() + assert "remove the secrets from your prompt" in result.message + + +class TestAIHookScannerScan: + """Unit tests for the AIHookScanner.scan() method.""" + + def test_empty_input_raises(self): + """Empty or whitespace-only input raises ValueError.""" + scanner = AIHookScanner(_mock_scanner([])) + with pytest.raises(ValueError, match="No input received on stdin"): + scanner.scan("") + with pytest.raises(ValueError, match="No input received on stdin"): + scanner.scan(" \n ") + + def test_scan_no_secrets_returns_zero(self): + """scan() with no secrets returns 0.""" + scanner = AIHookScanner(_mock_scanner([])) + data = { + "hook_event_name": "UserPromptSubmit", + "prompt": "hello world", + "transcript_path": "/home/user/.claude/projects/foo/session.jsonl", + "cursor_version": "1.2.3", + } + code = scanner.scan(json.dumps(data)) + assert code == 0 + + @patch("ggshield.verticals.ai.hooks.AIHookScanner._send_secret_notification") + def test_scan_post_tool_use_with_secrets_sends_notification( + self, mock_notify: MagicMock + ): + """scan() on POST_TOOL_USE with secrets sends a notification and returns 0 (no block).""" + scanner = AIHookScanner(_mock_scanner(["sk-xxx"])) + data = { + "hook_event_name": "PostToolUse", + "tool_name": "Bash", + "tool_input": {"command": "echo sk-xxx"}, + "tool_response": {"stdout": "sk-xxx\n"}, + "transcript_path": "/home/user/.claude/projects/foo/session.jsonl", + "session_id": "427ae0c5-0862-4e14-aa2c-12fad909c323", + } + code = scanner.scan(json.dumps(data)) + assert code == 0 + mock_notify.assert_called_once() + args = mock_notify.call_args[0] + assert args[0] == 1 # nbr_secrets + assert args[1] == Tool.BASH # tool + + def test_scan_pre_tool_use_with_secrets_blocks(self): + """scan() on PRE_TOOL_USE with secrets returns block result.""" + scanner = AIHookScanner(_mock_scanner(["sk-xxx"])) + data = { + "hook_event_name": "PreToolUse", + "tool_name": "Bash", + "tool_input": {"command": "echo sk-xxx"}, + "session_id": "427ae0c5-0862-4e14-aa2c-12fad909c323", + "transcript_path": "/home/user/.claude/projects/foo/session.jsonl", + } + code = scanner.scan(json.dumps(data)) + # Claude output_result always returns 0 + assert code == 0 + + def test_scan_no_content_returns_allow(self): + """scan() with no content returns 0 (and doesn't call the API).""" + mock_scanner = _mock_scanner([]) + scanner = AIHookScanner(mock_scanner) + data = { + "hook_event_name": "PreToolUse", + "tool_name": "Read", + "tool_input": {"file_path": "doesn-t-exist"}, + "cursor_version": "1.2.3", + } + code = scanner.scan(json.dumps(data)) + assert code == 0 + mock_scanner.scan.assert_not_called() + + def test_scan_payloads_refuse_empty_list(self): + """scan() with empty list of payloads raises ValueError.""" + scanner = AIHookScanner(_mock_scanner([])) + with pytest.raises(ValueError): + scanner._scan_payloads([]) + + +class TestMessageFromSecrets: + """Unit tests for AIHookScanner._message_from_secrets with different payload types.""" + + def test_message_for_bash_tool(self): + """Message for BASH tool mentions environment variables.""" + payload = HookPayload( + event_type=EventType.PRE_TOOL_USE, + tool=Tool.BASH, + content="echo sk-xxx", + identifier="echo sk-xxx", + agent=Cursor(), + ) + message = AIHookScanner._message_from_secrets([_make_secret("sk-xxx")], payload) + assert "remove the secrets from the command" in message + assert "environment variables" in message + + def test_message_for_read_tool(self): + """Message for READ tool mentions file content.""" + payload = HookPayload( + event_type=EventType.PRE_TOOL_USE, + tool=Tool.READ, + content="file content with secret", + identifier="/path/to/file", + agent=Cursor(), + ) + message = AIHookScanner._message_from_secrets([_make_secret("sk-xxx")], payload) + assert "remove the secrets from" in message + + def test_message_for_other_tool(self): + """Message for OTHER tool uses generic message.""" + payload = HookPayload( + event_type=EventType.PRE_TOOL_USE, + tool=Tool.OTHER, + content="some content", + identifier="id", + agent=Cursor(), + ) + message = AIHookScanner._message_from_secrets([_make_secret("sk-xxx")], payload) + assert "remove the secrets from the tool input" in message + + def test_message_escapes_markdown(self): + """When escape_markdown=True, asterisks in matches are replaced with dots.""" + payload = HookPayload( + event_type=EventType.USER_PROMPT, + tool=None, + content="content", + identifier="id", + agent=Cursor(), + ) + message = AIHookScanner._message_from_secrets( + [_make_secret("sk-xxx")], payload, escape_markdown=True + ) + # The message itself should not contain raw asterisks from matches + # (the header uses ** for bold which is intentional) + assert "Detected" in message + + +class TestSendSecretNotification: + """Unit tests for AIHookScanner._send_secret_notification.""" + + @patch("ggshield.verticals.ai.hooks.Notify") + def test_notification_for_bash_tool(self, mock_notify_cls: MagicMock): + """Notification for BASH tool says 'running a command'.""" + AIHookScanner._send_secret_notification(1, Tool.BASH, "Claude Code") + instance = mock_notify_cls.return_value + assert "running a command" in instance.message + assert "Claude Code" in instance.message + instance.send.assert_called_once() + + @patch("ggshield.verticals.ai.hooks.Notify") + def test_notification_for_read_tool(self, mock_notify_cls: MagicMock): + """Notification for READ tool says 'reading a file'.""" + AIHookScanner._send_secret_notification(2, Tool.READ, "Cursor") + instance = mock_notify_cls.return_value + assert "reading a file" in instance.message + assert "2" in instance.message + instance.send.assert_called_once() + + @patch("ggshield.verticals.ai.hooks.Notify") + def test_notification_for_other_tool(self, mock_notify_cls: MagicMock): + """Notification for OTHER tool says 'using a tool'.""" + AIHookScanner._send_secret_notification(1, Tool.OTHER, "Copilot") + instance = mock_notify_cls.return_value + assert "using a tool" in instance.message + instance.send.assert_called_once() + + class TestAIHookScannerParseInput: """Unit tests for AIHookparse_hook_input.""" diff --git a/tests/unit/verticals/secret/test_ai_hooks.py b/tests/unit/verticals/secret/test_ai_hooks.py deleted file mode 100644 index 5bfeca82cd..0000000000 --- a/tests/unit/verticals/secret/test_ai_hooks.py +++ /dev/null @@ -1,263 +0,0 @@ -import json -from collections import Counter -from pathlib import Path -from typing import List -from unittest.mock import MagicMock, patch - -import pytest - -from ggshield.utils.git_shell import Filemode -from ggshield.verticals.ai.agents import Cursor -from ggshield.verticals.secret import SecretScanner -from ggshield.verticals.secret.ai_hook import AIHookScanner, EventType, Payload -from ggshield.verticals.secret.ai_hook import Result as HookResult -from ggshield.verticals.secret.ai_hook import Tool -from ggshield.verticals.secret.secret_scan_collection import Result as ScanResult -from ggshield.verticals.secret.secret_scan_collection import Results, Secret - - -def _mock_scanner(matches: List[str]) -> MagicMock: - """Create a mock SecretScanner that returns the given Results from scan().""" - mock = MagicMock(spec=SecretScanner) - scan_result = Results( - results=[ - ScanResult( - filename="url", - filemode=Filemode.FILE, - path=Path("."), - url="url", - secrets=[_make_secret(match) for match in matches], - ignored_secrets_count_by_kind=Counter(), - ) - ], - errors=[], - ) - mock.scan.return_value = scan_result - return mock - - -def _make_secret(match_str: str = "***"): - """Minimal Secret for tests; _message_from_secrets only uses detector_display_name, validity, matches[].match.""" - mock_match = MagicMock() - mock_match.match = match_str - return Secret( - detector_display_name="dummy-detector", - detector_name="dummy-detector", - detector_group_name=None, - documentation_url=None, - validity="valid", - known_secret=False, - incident_url=None, - matches=[mock_match], - ignore_reason=None, - diff_kind=None, - is_vaulted=False, - vault_type=None, - vault_name=None, - vault_path=None, - vault_path_count=None, - ) - - -class TestAIHookScannerScanContent: - """Unit tests for AIHookScanner._scan_content.""" - - def test_no_secrets_returns_allow(self): - """When scanner returns no secrets, result has block=False and nbr_secrets=0.""" - hook_scanner = AIHookScanner(_mock_scanner([])) - payload = Payload( - event_type=EventType.USER_PROMPT, - tool=None, - content="safe content", - identifier="id", - agent=Cursor(), - ) - result = hook_scanner._scan_content(payload) - assert isinstance(result, HookResult) - assert result.block is False - assert result.nbr_secrets == 0 - assert result.message == "" - - def test_with_secrets_returns_block_and_message(self): - """When scanner returns secrets, result has block=True, nbr_secrets and message set.""" - hook_scanner = AIHookScanner(_mock_scanner(["sk-xxx"])) - payload = Payload( - event_type=EventType.USER_PROMPT, - tool=None, - content="content with sk-xxx", - identifier="id", - agent=Cursor(), - ) - result = hook_scanner._scan_content(payload) - assert isinstance(result, HookResult) - assert result.block is True - assert result.nbr_secrets == 1 - assert "dummy-detector" in result.message - assert "secret" in result.message.lower() - assert "remove the secrets from your prompt" in result.message - - -class TestAIHookScannerScan: - """Unit tests for the AIHookScanner.scan() method.""" - - def test_empty_input_raises(self): - """Empty or whitespace-only input raises ValueError.""" - scanner = AIHookScanner(_mock_scanner([])) - with pytest.raises(ValueError, match="No input received on stdin"): - scanner.scan("") - with pytest.raises(ValueError, match="No input received on stdin"): - scanner.scan(" \n ") - - def test_scan_no_secrets_returns_zero(self): - """scan() with no secrets returns 0.""" - scanner = AIHookScanner(_mock_scanner([])) - data = { - "hook_event_name": "UserPromptSubmit", - "prompt": "hello world", - "transcript_path": "/home/user/.claude/projects/foo/session.jsonl", - "cursor_version": "1.2.3", - } - code = scanner.scan(json.dumps(data)) - assert code == 0 - - @patch("ggshield.verticals.secret.ai_hook.AIHookScanner._send_secret_notification") - def test_scan_post_tool_use_with_secrets_sends_notification( - self, mock_notify: MagicMock - ): - """scan() on POST_TOOL_USE with secrets sends a notification and returns 0 (no block).""" - scanner = AIHookScanner(_mock_scanner(["sk-xxx"])) - data = { - "hook_event_name": "PostToolUse", - "tool_name": "Bash", - "tool_input": {"command": "echo sk-xxx"}, - "tool_response": {"stdout": "sk-xxx\n"}, - "transcript_path": "/home/user/.claude/projects/foo/session.jsonl", - "session_id": "427ae0c5-0862-4e14-aa2c-12fad909c323", - } - code = scanner.scan(json.dumps(data)) - assert code == 0 - mock_notify.assert_called_once() - args = mock_notify.call_args[0] - assert args[0] == 1 # nbr_secrets - assert args[1] == Tool.BASH # tool - - def test_scan_pre_tool_use_with_secrets_blocks(self): - """scan() on PRE_TOOL_USE with secrets returns block result.""" - scanner = AIHookScanner(_mock_scanner(["sk-xxx"])) - data = { - "hook_event_name": "PreToolUse", - "tool_name": "Bash", - "tool_input": {"command": "echo sk-xxx"}, - "session_id": "427ae0c5-0862-4e14-aa2c-12fad909c323", - "transcript_path": "/home/user/.claude/projects/foo/session.jsonl", - } - code = scanner.scan(json.dumps(data)) - # Claude output_result always returns 0 - assert code == 0 - - def test_scan_no_content_returns_allow(self): - """scan() with no content returns 0 (and doesn't call the API).""" - mock_scanner = _mock_scanner([]) - scanner = AIHookScanner(mock_scanner) - data = { - "hook_event_name": "PreToolUse", - "tool_name": "Read", - "tool_input": {"file_path": "doesn-t-exist"}, - "cursor_version": "1.2.3", - } - code = scanner.scan(json.dumps(data)) - assert code == 0 - mock_scanner.scan.assert_not_called() - - def test_scan_payloads_refuse_empty_list(self): - """scan() with empty list of payloads raises ValueError.""" - scanner = AIHookScanner(_mock_scanner([])) - with pytest.raises(ValueError): - scanner._scan_payloads([]) - - -class TestMessageFromSecrets: - """Unit tests for AIHookScanner._message_from_secrets with different payload types.""" - - def test_message_for_bash_tool(self): - """Message for BASH tool mentions environment variables.""" - payload = Payload( - event_type=EventType.PRE_TOOL_USE, - tool=Tool.BASH, - content="echo sk-xxx", - identifier="echo sk-xxx", - agent=Cursor(), - ) - message = AIHookScanner._message_from_secrets([_make_secret("sk-xxx")], payload) - assert "remove the secrets from the command" in message - assert "environment variables" in message - - def test_message_for_read_tool(self): - """Message for READ tool mentions file content.""" - payload = Payload( - event_type=EventType.PRE_TOOL_USE, - tool=Tool.READ, - content="file content with secret", - identifier="/path/to/file", - agent=Cursor(), - ) - message = AIHookScanner._message_from_secrets([_make_secret("sk-xxx")], payload) - assert "remove the secrets from" in message - - def test_message_for_other_tool(self): - """Message for OTHER tool uses generic message.""" - payload = Payload( - event_type=EventType.PRE_TOOL_USE, - tool=Tool.OTHER, - content="some content", - identifier="id", - agent=Cursor(), - ) - message = AIHookScanner._message_from_secrets([_make_secret("sk-xxx")], payload) - assert "remove the secrets from the tool input" in message - - def test_message_escapes_markdown(self): - """When escape_markdown=True, asterisks in matches are replaced with dots.""" - payload = Payload( - event_type=EventType.USER_PROMPT, - tool=None, - content="content", - identifier="id", - agent=Cursor(), - ) - message = AIHookScanner._message_from_secrets( - [_make_secret("sk-xxx")], payload, escape_markdown=True - ) - # The message itself should not contain raw asterisks from matches - # (the header uses ** for bold which is intentional) - assert "Detected" in message - - -class TestSendSecretNotification: - """Unit tests for AIHookScanner._send_secret_notification.""" - - @patch("ggshield.verticals.secret.ai_hook.Notify") - def test_notification_for_bash_tool(self, mock_notify_cls: MagicMock): - """Notification for BASH tool says 'running a command'.""" - AIHookScanner._send_secret_notification(1, Tool.BASH, "Claude Code") - instance = mock_notify_cls.return_value - assert "running a command" in instance.message - assert "Claude Code" in instance.message - instance.send.assert_called_once() - - @patch("ggshield.verticals.secret.ai_hook.Notify") - def test_notification_for_read_tool(self, mock_notify_cls: MagicMock): - """Notification for READ tool says 'reading a file'.""" - AIHookScanner._send_secret_notification(2, Tool.READ, "Cursor") - instance = mock_notify_cls.return_value - assert "reading a file" in instance.message - assert "2" in instance.message - instance.send.assert_called_once() - - @patch("ggshield.verticals.secret.ai_hook.Notify") - def test_notification_for_other_tool(self, mock_notify_cls: MagicMock): - """Notification for OTHER tool says 'using a tool'.""" - AIHookScanner._send_secret_notification(1, Tool.OTHER, "Copilot") - instance = mock_notify_cls.return_value - assert "using a tool" in instance.message - instance.send.assert_called_once()