From 7474d2b708e4fa09407af27ccf4ec41e8aeba928 Mon Sep 17 00:00:00 2001 From: Paul PETIT Date: Thu, 26 Mar 2026 10:58:47 +0100 Subject: [PATCH 1/3] feat(ai_hook): refactor into new "ai" vertical --- .importlinter | 1 - ggshield/cmd/install.py | 6 +- ggshield/cmd/secret/scan/ai_hook.py | 4 +- ggshield/verticals/ai/agents/__init__.py | 9 + .../ai_hook => ai/agents}/claude_code.py | 14 +- .../{secret/ai_hook => ai/agents}/copilot.py | 12 +- .../{secret/ai_hook => ai/agents}/cursor.py | 14 +- ggshield/verticals/ai/hooks.py | 165 +++++++ .../{secret/ai_hook => ai}/installation.py | 27 +- .../{secret/ai_hook => ai}/models.py | 93 ++-- .../secret/{ai_hook/scanner.py => ai_hook.py} | 177 +------- ggshield/verticals/secret/ai_hook/__init__.py | 5 - scripts/generate-import-linter-config.py | 2 - .../{secret/ai_hook => ai}/test_hooks.py | 422 +++--------------- .../ai_hook => ai}/test_installation.py | 48 +- tests/unit/verticals/secret/test_ai_hooks.py | 263 +++++++++++ 16 files changed, 613 insertions(+), 649 deletions(-) create mode 100644 ggshield/verticals/ai/agents/__init__.py rename ggshield/verticals/{secret/ai_hook => ai/agents}/claude_code.py (92%) rename ggshield/verticals/{secret/ai_hook => ai/agents}/copilot.py (81%) rename ggshield/verticals/{secret/ai_hook => ai/agents}/cursor.py (87%) create mode 100644 ggshield/verticals/ai/hooks.py rename ggshield/verticals/{secret/ai_hook => ai}/installation.py (87%) rename ggshield/verticals/{secret/ai_hook => ai}/models.py (80%) rename ggshield/verticals/secret/{ai_hook/scanner.py => ai_hook.py} (50%) delete mode 100644 ggshield/verticals/secret/ai_hook/__init__.py rename tests/unit/verticals/{secret/ai_hook => ai}/test_hooks.py (60%) rename tests/unit/verticals/{secret/ai_hook => ai}/test_installation.py (90%) create mode 100644 tests/unit/verticals/secret/test_ai_hooks.py diff --git a/.importlinter b/.importlinter index a56c7d3c7e..866d4d06db 100644 --- a/.importlinter +++ b/.importlinter @@ -46,7 +46,6 @@ ignore_imports 
= ggshield.cmd.hmsl.** -> ggshield.verticals.hmsl.** ggshield.cmd.honeytoken.** -> ggshield.verticals.honeytoken ggshield.cmd.honeytoken.** -> ggshield.verticals.honeytoken.** - ggshield.cmd.install -> ggshield.verticals.secret.ai_hook ggshield.cmd.install.** -> ggshield.verticals.install ggshield.cmd.install.** -> ggshield.verticals.install.** ggshield.cmd.plugin.** -> ggshield.core.plugin diff --git a/ggshield/cmd/install.py b/ggshield/cmd/install.py index f835926d7f..9afee1469c 100644 --- a/ggshield/cmd/install.py +++ b/ggshield/cmd/install.py @@ -10,7 +10,7 @@ from ggshield.core.dirs import get_data_dir from ggshield.core.errors import UnexpectedError from ggshield.utils.git_shell import check_git_dir, git -from ggshield.verticals.secret.ai_hook import AI_FLAVORS, install_hooks +from ggshield.verticals.ai.installation import AGENTS, install_hooks # This snippet is used by the global hook to call the hook defined in the @@ -39,7 +39,7 @@ @click.option( "--hook-type", "-t", - type=click.Choice(["pre-commit", "pre-push"] + list(AI_FLAVORS.keys())), + type=click.Choice(["pre-commit", "pre-push"] + list(AGENTS.keys())), help="Type of hook to install.", default="pre-commit", ) @@ -61,7 +61,7 @@ def install_cmd( It can also install ggshield as a Cursor IDE or Claude Code agent hook. 
""" - if hook_type in AI_FLAVORS: + if hook_type in AGENTS: return install_hooks(name=hook_type, mode=mode, force=force) return_code = ( diff --git a/ggshield/cmd/secret/scan/ai_hook.py b/ggshield/cmd/secret/scan/ai_hook.py index c53ab6a716..64e835c757 100644 --- a/ggshield/cmd/secret/scan/ai_hook.py +++ b/ggshield/cmd/secret/scan/ai_hook.py @@ -12,7 +12,9 @@ from ggshield.core.scan import ScanContext, ScanMode from ggshield.verticals.secret import SecretScanner from ggshield.verticals.secret.ai_hook import AIHookScanner -from ggshield.verticals.secret.ai_hook.models import MAX_READ_SIZE + + +MAX_READ_SIZE = 1024 * 1024 * 10 # We restrict stdin read to 10MB @click.command() diff --git a/ggshield/verticals/ai/agents/__init__.py b/ggshield/verticals/ai/agents/__init__.py new file mode 100644 index 0000000000..7289463cc7 --- /dev/null +++ b/ggshield/verticals/ai/agents/__init__.py @@ -0,0 +1,9 @@ +from .claude_code import Claude +from .copilot import Copilot +from .cursor import Cursor + + +AGENTS = {agent.name: agent for agent in [Cursor(), Claude(), Copilot()]} + + +__all__ = ["AGENTS", "Claude", "Copilot", "Cursor"] diff --git a/ggshield/verticals/secret/ai_hook/claude_code.py b/ggshield/verticals/ai/agents/claude_code.py similarity index 92% rename from ggshield/verticals/secret/ai_hook/claude_code.py rename to ggshield/verticals/ai/agents/claude_code.py index 378490f904..b9ea7670f1 100644 --- a/ggshield/verticals/secret/ai_hook/claude_code.py +++ b/ggshield/verticals/ai/agents/claude_code.py @@ -4,15 +4,21 @@ import click -from .models import EventType, Flavor, Result +from ..models import Agent, EventType, HookResult -class Claude(Flavor): +class Claude(Agent): """Behavior specific to Claude Code.""" - name = "Claude Code" + @property + def name(self) -> str: + return "claude-code" + + @property + def display_name(self) -> str: + return "Claude Code" - def output_result(self, result: Result) -> int: + def output_result(self, result: HookResult) -> int: response 
= {} if result.block: if result.payload.event_type in [ diff --git a/ggshield/verticals/secret/ai_hook/copilot.py b/ggshield/verticals/ai/agents/copilot.py similarity index 81% rename from ggshield/verticals/secret/ai_hook/copilot.py rename to ggshield/verticals/ai/agents/copilot.py index b523daa70b..300a8ce132 100644 --- a/ggshield/verticals/secret/ai_hook/copilot.py +++ b/ggshield/verticals/ai/agents/copilot.py @@ -3,8 +3,8 @@ import click +from ..models import EventType, HookResult from .claude_code import Claude -from .models import EventType, Result class Copilot(Claude): @@ -13,9 +13,15 @@ class Copilot(Claude): Inherits most of its behavior from Claude Code. """ - name = "Copilot" + @property + def name(self) -> str: + return "copilot" + + @property + def display_name(self) -> str: + return "Copilot Chat" - def output_result(self, result: Result) -> int: + def output_result(self, result: HookResult) -> int: response = {} if result.block: if result.payload.event_type == EventType.PRE_TOOL_USE: diff --git a/ggshield/verticals/secret/ai_hook/cursor.py b/ggshield/verticals/ai/agents/cursor.py similarity index 87% rename from ggshield/verticals/secret/ai_hook/cursor.py rename to ggshield/verticals/ai/agents/cursor.py index 3b20c46fec..c75e807380 100644 --- a/ggshield/verticals/secret/ai_hook/cursor.py +++ b/ggshield/verticals/ai/agents/cursor.py @@ -4,15 +4,21 @@ import click -from .models import EventType, Flavor, Result +from ..models import Agent, EventType, HookResult -class Cursor(Flavor): +class Cursor(Agent): """Behavior specific to Cursor.""" - name = "Cursor" + @property + def name(self) -> str: + return "cursor" + + @property + def display_name(self) -> str: + return "Cursor" - def output_result(self, result: Result) -> int: + def output_result(self, result: HookResult) -> int: response = {} if result.payload.event_type == EventType.USER_PROMPT: response["continue"] = not result.block diff --git a/ggshield/verticals/ai/hooks.py 
b/ggshield/verticals/ai/hooks.py new file mode 100644 index 0000000000..75adabcd36 --- /dev/null +++ b/ggshield/verticals/ai/hooks.py @@ -0,0 +1,165 @@ +import hashlib +import json +import re +from typing import Any, Dict, List, Sequence, Set + +from ggshield.verticals.ai.agents import Claude, Copilot, Cursor + +from .models import Agent, EventType, HookPayload, Tool + + +HOOK_NAME_TO_EVENT_TYPE = { + "userpromptsubmit": EventType.USER_PROMPT, + "beforesubmitprompt": EventType.USER_PROMPT, + "pretooluse": EventType.PRE_TOOL_USE, + "posttooluse": EventType.POST_TOOL_USE, +} + +TOOL_NAME_TO_TOOL = { + "shell": Tool.BASH, # Cursor + "bash": Tool.BASH, # Claude Code + "run_in_terminal": Tool.BASH, # Copilot + "read": Tool.READ, # Claude/Cursor + "read_file": Tool.READ, # Copilot +} + + +def lookup(data: Dict[str, Any], keys: Sequence[str], default: Any = None) -> Any: + """Returns the value of the first key found in a dictionary.""" + for key in keys: + if key in data: + return data[key] + return default + + +# Regex (and method) to look for any @file_path in the prompt. +# A list of test cases can be found in test_hooks.py. +_FILE_PATH_REGEX = re.compile( + r'@"((?:[^"\\]|\\.)*)"' # quoted: @"..." + r"|" + r"(?:\W|^)@([\w/\\.-]+)", # unquoted: @path + re.MULTILINE, +) + + +def find_filepaths(prompt: str) -> Set[str]: + """Find all file paths in the prompt.""" + paths = set() + for m in _FILE_PATH_REGEX.finditer(prompt): + path = m.group(1) or m.group(2) or "" + path = path.strip() + # Don't include trailing dots in the path + if path.endswith("."): + path = path[:-1] + if path: + paths.add(path) + return paths + + +def parse_hook_input(raw_content: str) -> list[HookPayload]: + """Parse the input content. Raises a ValueError if the input is not valid. + + Returns: + A list of payloads. Most of the time the list will contain only one payload, + but in some cases files mentioned in the prompt will be read but the + PreToolUse event will not be called. 
So we need to handle this case ourselves. + """ + # Parse the content as JSON + if not raw_content.strip(): + raise ValueError("Error: No input received on stdin") + try: + data = json.loads(raw_content) + except json.JSONDecodeError as e: + raise ValueError(f"Error: Failed to parse JSON from stdin: {e}") from e + + payloads = [] + + # Try to guess which AI coding assistant is calling us + agent = _detect_agent(data) + + # Infer the event type + event_name = lookup(data, ["hook_event_name", "hookEventName"], None) + if event_name is None: + raise ValueError("Error: couldn't find event type") + event_type = HOOK_NAME_TO_EVENT_TYPE.get(event_name.lower(), EventType.OTHER) + + identifier = "" + content = "" + tool = None + + # Extract the identifier and content based on the event type + if event_type == EventType.USER_PROMPT: + content = data.get("prompt", "") + # Look for files mentioned in the prompt that could be read + # without triggering a PRE_TOOL_USE event. + payloads.extend(_parse_user_prompt(content, event_type, agent)) + + elif event_type == EventType.PRE_TOOL_USE: + tool_name = data.get("tool_name", "").lower() + tool = TOOL_NAME_TO_TOOL.get(tool_name, Tool.OTHER) + tool_input = data.get("tool_input", {}) + # Select the content based on the tool + if tool == Tool.BASH: + content = tool_input.get("command", "") + identifier = content + elif tool == Tool.READ: + # We only need to deal with the identifier, the content will be read by the Scannable + identifier = lookup(tool_input, ["file_path", "filePath"], "") + + elif event_type == EventType.POST_TOOL_USE: + tool_name = data.get("tool_name", "").lower() + tool = TOOL_NAME_TO_TOOL.get(tool_name, Tool.OTHER) + content = data.get("tool_output", "") or data.get("tool_response", {}) + # Claude Code returns a dict for the tool output + if isinstance(content, (dict, list)): + content = json.dumps(content) + + # If identifier was not set, hash the content + if not identifier: + identifier = hashlib.sha256((content 
or "").encode()).hexdigest() + + payloads.append( + HookPayload( + event_type=event_type, + tool=tool, + content=content, + identifier=identifier, + agent=agent, + ) + ) + return payloads + + +def _detect_agent(data: Dict[str, Any]) -> Agent: + """Detect the AI code assistant.""" + if "cursor_version" in data: + return Cursor() + elif "github.copilot-chat" in data.get("transcript_path", "").lower(): + return Copilot() + # no .lower() here to reduce the risk of false positives (this is also why this check is last) + elif "session_id" in data and "claude" in data.get("transcript_path", ""): + return Claude() + # No other agent is supported yet + raise ValueError("Unsupported agent") + + +def _parse_user_prompt( + content: str, event_type: EventType, agent: Agent +) -> List[HookPayload]: + """Parse the user prompt for additional payloads that we may miss.""" + payloads = [] + # Scenario 1 (the only one we know about so far): + # Code assistants don't always trigger a PRE_TOOL_USE event when + # a file is mentioned in the prompt, especially with an "@" prefix. 
+ matches = find_filepaths(content) + for match in matches: + payloads.append( + HookPayload( + event_type=event_type, + tool=Tool.READ, + content="", + identifier=match, + agent=agent, + ) + ) + return payloads diff --git a/ggshield/verticals/secret/ai_hook/installation.py b/ggshield/verticals/ai/installation.py similarity index 87% rename from ggshield/verticals/secret/ai_hook/installation.py rename to ggshield/verticals/ai/installation.py index 3139c7b3d5..f624c1fc15 100644 --- a/ggshield/verticals/secret/ai_hook/installation.py +++ b/ggshield/verticals/ai/installation.py @@ -9,16 +9,7 @@ from ggshield.core.dirs import get_user_home_dir from ggshield.core.errors import UnexpectedError -from .claude_code import Claude -from .copilot import Copilot -from .cursor import Cursor - - -AI_FLAVORS = { - "cursor": Cursor, - "claude-code": Claude, - "copilot": Copilot, -} +from .agents import AGENTS @dataclass @@ -41,12 +32,12 @@ def install_hooks( """ try: - flavor = AI_FLAVORS[name]() + agent = AGENTS[name] except KeyError: - raise ValueError(f"Unsupported tool name: {name}") + raise ValueError(f"Unsupported agent: {name}") base_dir = get_user_home_dir() if mode == "global" else Path(".") - settings_path = base_dir / flavor.settings_path + settings_path = base_dir / agent.settings_path command = "ggshield secret scan ai-hook" @@ -71,11 +62,11 @@ def install_hooks( stats = _fill_dict( config=existing_config, - template=flavor.settings_template, + template=agent.settings_template, command=command, overwrite=force, stats=stats, - locator=flavor.settings_locate, + locator=agent.settings_locate, ) # Ensure parent directory exists @@ -89,11 +80,11 @@ def install_hooks( # Report what happened styled_path = click.style(settings_path, fg="yellow", bold=True) if stats.added == 0 and stats.already_present > 0: - click.echo(f"{flavor.name} hooks already installed in {styled_path}") + click.echo(f"{agent.name} hooks already installed in {styled_path}") elif stats.added > 0 and 
stats.already_present > 0: - click.echo(f"{flavor.name} hooks updated in {styled_path}") + click.echo(f"{agent.name} hooks updated in {styled_path}") else: - click.echo(f"{flavor.name} hooks successfully added in {styled_path}") + click.echo(f"{agent.name} hooks successfully added in {styled_path}") return 0 diff --git a/ggshield/verticals/secret/ai_hook/models.py b/ggshield/verticals/ai/models.py similarity index 80% rename from ggshield/verticals/secret/ai_hook/models.py rename to ggshield/verticals/ai/models.py index a031e2eea3..fecf9ddd27 100644 --- a/ggshield/verticals/secret/ai_hook/models.py +++ b/ggshield/verticals/ai/models.py @@ -1,17 +1,13 @@ +from abc import ABC, abstractmethod from dataclasses import dataclass from enum import Enum, auto from pathlib import Path from typing import Any, Dict, List, Optional -import click - from ggshield.core.scan import File, Scannable, StringScannable from ggshield.utils.files import is_path_binary -MAX_READ_SIZE = 1024 * 1024 * 50 # We restrict payloads read to 50MB - - class EventType(Enum): """Event type constants for hook events.""" @@ -33,31 +29,64 @@ class Tool(Enum): @dataclass -class Result: +class HookResult: """Result of a scan: allow or not.""" block: bool message: str nbr_secrets: int - payload: "Payload" + payload: "HookPayload" @classmethod - def allow(cls, payload: "Payload") -> "Result": + def allow(cls, payload: "HookPayload") -> "HookResult": return cls(block=False, message="", nbr_secrets=0, payload=payload) -class Flavor: +@dataclass +class HookPayload: + event_type: EventType + tool: Optional[Tool] + content: str + identifier: str + agent: "Agent" + + @property + def scannable(self) -> Scannable: + """Return the appropriate Scannable for the payload.""" + if self.tool == Tool.READ: + path = Path(self.identifier) + if path.is_file() and not is_path_binary(path): + return File(path=self.identifier) + return StringScannable(url=self.identifier, content=self.content) + + @property + def empty(self) -> 
bool: + """Return True if the payload is empty.""" + return not self.scannable.is_longer_than(0) + + +class Agent(ABC): """ Class that can be derived to implement behavior specific to some AI code assistants. """ - name = "Your AI coding tool" + # Metadata - def output_result(self, result: Result) -> int: - """How to output the result of a scan. + @property + @abstractmethod + def display_name(self) -> str: + """A user-friendly name for the agent.""" - This base implementation has sensible defaults (like returning 2 in case of a block, - and printing the output in stderr or stdout). + @property + @abstractmethod + def name(self) -> str: + """The name of the agent.""" + + # Hooks + + @abstractmethod + def output_result(self, result: HookResult) -> int: + """How to output the result of a scan. This method is expected to have side effects, like printing to stdout or stderr. @@ -66,26 +95,23 @@ def output_result(self, result: Result) -> int: Returns: the exit code. """ - if result.block: - click.echo(result.message, err=True) - return 2 - else: - click.echo("No secrets found. Good to go.") - return 0 + + # Settings @property + @abstractmethod def settings_path(self) -> Path: """Path to the settings file for this AI coding tool.""" - return Path(".agents") / "hooks.json" @property + @abstractmethod def settings_template(self) -> Dict[str, Any]: """ Template for the settings file for this AI coding tool. Use the sentinel "" for the places where the command should be inserted. """ - return {} + @abstractmethod def settings_locate( self, candidates: List[Dict[str, Any]], template: Dict[str, Any] ) -> Optional[Dict[str, Any]]: @@ -102,26 +128,3 @@ def settings_locate( Returns: the object to update, or None if no object was found. 
""" return None - - -@dataclass -class Payload: - event_type: EventType - tool: Optional[Tool] - content: str - identifier: str - flavor: Flavor - - @property - def scannable(self) -> Scannable: - """Return the appropriate Scannable for the payload.""" - if self.tool == Tool.READ: - path = Path(self.identifier) - if path.is_file() and not is_path_binary(path): - return File(path=self.identifier) - return StringScannable(url=self.identifier, content=self.content) - - @property - def empty(self) -> bool: - """Return True if the payload is empty.""" - return not self.scannable.is_longer_than(0) diff --git a/ggshield/verticals/secret/ai_hook/scanner.py b/ggshield/verticals/secret/ai_hook.py similarity index 50% rename from ggshield/verticals/secret/ai_hook/scanner.py rename to ggshield/verticals/secret/ai_hook.py index 7eacfe5abe..332262a3bf 100644 --- a/ggshield/verticals/secret/ai_hook/scanner.py +++ b/ggshield/verticals/secret/ai_hook.py @@ -1,45 +1,18 @@ -import hashlib -import json -import re -from typing import Any, Dict, List, Sequence, Set +from typing import List from notifypy import Notify from ggshield.core.filter import censor_match from ggshield.core.scanner_ui import create_message_only_scanner_ui from ggshield.core.text_utils import pluralize, translate_validity +from ggshield.verticals.ai.hooks import parse_hook_input +from ggshield.verticals.ai.models import EventType +from ggshield.verticals.ai.models import HookPayload as Payload +from ggshield.verticals.ai.models import HookResult as Result +from ggshield.verticals.ai.models import Tool from ggshield.verticals.secret import SecretScanner -from ggshield.verticals.secret.ai_hook.copilot import Copilot from ggshield.verticals.secret.secret_scan_collection import Secret -from .claude_code import Claude -from .cursor import Cursor -from .models import EventType, Flavor, Payload, Result, Tool - - -# Regex (and method) to look for any @file_path in the prompt. 
-# A list of test cases can be found in test_hooks.py. -_FILE_PATH_REGEX = re.compile( - r'@"((?:[^"\\]|\\.)*)"' # quoted: @"..." - r"|" - r"(?:\W|^)@([\w/\\.-]+)", # unquoted: @path - re.MULTILINE, -) - - -def find_filepaths(prompt: str) -> Set[str]: - """Find all file paths in the prompt.""" - paths = set() - for m in _FILE_PATH_REGEX.finditer(prompt): - path = m.group(1) or m.group(2) or "" - path = path.strip() - # Don't include trailing dots in the path - if path.endswith("."): - path = path[:-1] - if path: - paths.add(path) - return paths - class AIHookScanner: """AI hook scanner. @@ -61,90 +34,17 @@ def __init__(self, scanner: SecretScanner): def scan(self, content: str) -> int: """Scan the content, print the result and return the exit code.""" - payloads = self._parse_input(content) + payloads = parse_hook_input(content) result = self._scan_payloads(payloads) payload = result.payload # Special case: in post-tool use, the action is already done: at least notify the user if result.block and payload.event_type == EventType.POST_TOOL_USE: self._send_secret_notification( - result.nbr_secrets, payload.tool or Tool.OTHER, payload.flavor.name + result.nbr_secrets, payload.tool or Tool.OTHER, payload.agent.name ) - return payload.flavor.output_result(result) - - def _parse_input(self, raw_content: str) -> list[Payload]: - """Parse the input content. Raises a ValueError if the input is not valid. - - Returns: - A list of payloads. Most of the time the list will contain only one payload, - but in some cases files mentioned in the prompt will be read but the - PreToolUse event will not be called. So we need to handle this case ourselves. 
- """ - # Parse the content as JSON - if not raw_content.strip(): - raise ValueError("Error: No input received on stdin") - try: - data = json.loads(raw_content) - except json.JSONDecodeError as e: - raise ValueError(f"Error: Failed to parse JSON from stdin: {e}") from e - - payloads = [] - - # Try to guess which AI coding assistant is calling us - flavor = self._detect_flavor(data) - - # Infer the event type - event_name = lookup(data, ["hook_event_name", "hookEventName"], None) - if event_name is None: - raise ValueError("Error: couldn't find event type") - event_type = HOOK_NAME_TO_EVENT_TYPE.get(event_name.lower(), EventType.OTHER) - - identifier = "" - content = "" - tool = None - - # Extract the identifier and content based on the event type - if event_type == EventType.USER_PROMPT: - content = data.get("prompt", "") - # Look for files mentioned in the prompt that could be read - # without triggering a PRE_TOOL_USE event. - payloads.extend(self._parse_user_prompt(content, event_type, flavor)) - - elif event_type == EventType.PRE_TOOL_USE: - tool_name = data.get("tool_name", "").lower() - tool = TOOL_NAME_TO_TOOL.get(tool_name, Tool.OTHER) - tool_input = data.get("tool_input", {}) - # Select the content based on the tool - if tool == Tool.BASH: - content = tool_input.get("command", "") - identifier = content - elif tool == Tool.READ: - # We only need to deal with the identifier, the content will be read by the Scannable - identifier = lookup(tool_input, ["file_path", "filePath"], "") - - elif event_type == EventType.POST_TOOL_USE: - tool_name = data.get("tool_name", "").lower() - tool = TOOL_NAME_TO_TOOL.get(tool_name, Tool.OTHER) - content = data.get("tool_output", "") or data.get("tool_response", {}) - # Claude Code returns a dict for the tool output - if isinstance(content, (dict, list)): - content = json.dumps(content) - - # If identifier was not set, hash the content - if not identifier: - identifier = hashlib.sha256((content or "").encode()).hexdigest() 
- - payloads.append( - Payload( - event_type=event_type, - tool=tool, - content=content, - identifier=identifier, - flavor=flavor, - ) - ) - return payloads + return payload.agent.output_result(result) def _scan_payloads(self, payloads: List[Payload]) -> Result: """Scan payloads for secrets using the SecretScanner. @@ -192,41 +92,6 @@ def _scan_content( payload=payload, ) - @staticmethod - def _detect_flavor(data: Dict[str, Any]) -> Flavor: - """Detect the AI code assistant.""" - if "cursor_version" in data: - return Cursor() - elif "github.copilot-chat" in data.get("transcript_path", "").lower(): - return Copilot() - # no .lower() here to reduce the risk of false positives (this is also why this check is last) - elif "session_id" in data and "claude" in data.get("transcript_path", ""): - return Claude() - else: - # Fallback that respect base conventions - return Flavor() - - def _parse_user_prompt( - self, content: str, event_type: EventType, flavor: Flavor - ) -> List[Payload]: - """Parse the user prompt for additional payloads that we may miss.""" - payloads = [] - # Scenario 1 (the only one we know about so far): - # Code assistants don't always trigger a PRE_TOOL_USE event when - # a file is mentioned in the prompt, especially with an "@" prefix. - matches = find_filepaths(content) - for match in matches: - payloads.append( - Payload( - event_type=event_type, - tool=Tool.READ, - content="", - identifier=match, - flavor=flavor, - ) - ) - return payloads - @staticmethod def _message_from_secrets( secrets: List[Secret], payload: Payload, escape_markdown: bool = False @@ -308,27 +173,3 @@ def _send_secret_notification( # This is best effort, we don't want to propagate an error # if the notification fails. 
pass - - -HOOK_NAME_TO_EVENT_TYPE = { - "userpromptsubmit": EventType.USER_PROMPT, - "beforesubmitprompt": EventType.USER_PROMPT, - "pretooluse": EventType.PRE_TOOL_USE, - "posttooluse": EventType.POST_TOOL_USE, -} - -TOOL_NAME_TO_TOOL = { - "shell": Tool.BASH, # Cursor - "bash": Tool.BASH, # Claude Code - "run_in_terminal": Tool.BASH, # Copilot - "read": Tool.READ, # Claude/Cursor - "read_file": Tool.READ, # Copilot -} - - -def lookup(data: Dict[str, Any], keys: Sequence[str], default: Any = None) -> Any: - """Returns the value of the first key found in a dictionary.""" - for key in keys: - if key in data: - return data[key] - return default diff --git a/ggshield/verticals/secret/ai_hook/__init__.py b/ggshield/verticals/secret/ai_hook/__init__.py deleted file mode 100644 index 68d4170d4a..0000000000 --- a/ggshield/verticals/secret/ai_hook/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -from .installation import AI_FLAVORS, install_hooks -from .scanner import AIHookScanner - - -__all__ = ["AIHookScanner", "install_hooks", "AI_FLAVORS"] diff --git a/scripts/generate-import-linter-config.py b/scripts/generate-import-linter-config.py index bdb0823275..8bc41c5678 100755 --- a/scripts/generate-import-linter-config.py +++ b/scripts/generate-import-linter-config.py @@ -64,8 +64,6 @@ class Contract(TypedDict): "ggshield.cmd.{}.** -> ggshield.verticals.{}.**", # FIXME: #521 - enforce boundaries between cmd.auth and verticals.hmsl "ggshield.cmd.auth.** -> ggshield.verticals.hmsl.**", - # Logic to install hooks for AI assistants - "ggshield.cmd.install -> ggshield.verticals.secret.ai_hook", ], "unmatched_ignore_imports_alerting": "none", }, diff --git a/tests/unit/verticals/secret/ai_hook/test_hooks.py b/tests/unit/verticals/ai/test_hooks.py similarity index 60% rename from tests/unit/verticals/secret/ai_hook/test_hooks.py rename to tests/unit/verticals/ai/test_hooks.py index f763cf76b4..e63d6f0361 100644 --- a/tests/unit/verticals/secret/ai_hook/test_hooks.py +++ 
b/tests/unit/verticals/ai/test_hooks.py @@ -1,74 +1,22 @@ import json -from collections import Counter from pathlib import Path -from typing import List, Set +from typing import Set from unittest.mock import MagicMock, patch import pytest -from ggshield.utils.git_shell import Filemode -from ggshield.verticals.secret import SecretScanner -from ggshield.verticals.secret.ai_hook.claude_code import Claude -from ggshield.verticals.secret.ai_hook.copilot import Copilot -from ggshield.verticals.secret.ai_hook.cursor import Cursor -from ggshield.verticals.secret.ai_hook.models import EventType, Flavor, Payload -from ggshield.verticals.secret.ai_hook.models import Result as HookResult -from ggshield.verticals.secret.ai_hook.models import Tool -from ggshield.verticals.secret.ai_hook.scanner import AIHookScanner, find_filepaths -from ggshield.verticals.secret.secret_scan_collection import Result as ScanResult -from ggshield.verticals.secret.secret_scan_collection import Results, Secret - - -def _mock_scanner(matches: List[str]) -> MagicMock: - """Create a mock SecretScanner that returns the given Results from scan().""" - mock = MagicMock(spec=SecretScanner) - scan_result = Results( - results=[ - ScanResult( - filename="url", - filemode=Filemode.FILE, - path=Path("."), - url="url", - secrets=[_make_secret(match) for match in matches], - ignored_secrets_count_by_kind=Counter(), - ) - ], - errors=[], - ) - mock.scan.return_value = scan_result - return mock - - -def _make_secret(match_str: str = "***"): - """Minimal Secret for tests; _message_from_secrets only uses detector_display_name, validity, matches[].match.""" - mock_match = MagicMock() - mock_match.match = match_str - return Secret( - detector_display_name="dummy-detector", - detector_name="dummy-detector", - detector_group_name=None, - documentation_url=None, - validity="valid", - known_secret=False, - incident_url=None, - matches=[mock_match], - ignore_reason=None, - diff_kind=None, - is_vaulted=False, - 
vault_type=None, - vault_name=None, - vault_path=None, - vault_path_count=None, - ) +from ggshield.verticals.ai.agents import Claude, Copilot, Cursor +from ggshield.verticals.ai.hooks import find_filepaths, parse_hook_input +from ggshield.verticals.ai.models import EventType, HookPayload, HookResult, Tool -def _dummy_payload(event_type: EventType = EventType.OTHER) -> Payload: - return Payload( +def _dummy_payload(event_type: EventType = EventType.OTHER) -> HookPayload: + return HookPayload( event_type=event_type, tool=None, content="", identifier="", - flavor=Flavor(), + agent=Cursor(), ) @@ -81,33 +29,22 @@ def tmp_file(tmp_path: Path) -> Path: class TestAIHookScannerParseInput: - """Unit tests for AIHookScanner._parse_input.""" - - def test_empty_input_raises(self): - """Empty or whitespace-only input raises ValueError.""" - scanner = AIHookScanner(_mock_scanner([])) - with pytest.raises(ValueError, match="No input received on stdin"): - scanner.scan("") - with pytest.raises(ValueError, match="No input received on stdin"): - scanner.scan(" \n ") + """Unit tests for parse_hook_input.""" def test_invalid_json_raises(self): """Invalid JSON raises ValueError with parse error.""" - scanner = AIHookScanner(_mock_scanner([])) with pytest.raises(ValueError, match="Failed to parse JSON"): - scanner._parse_input("not json {") + parse_hook_input("not json {") with pytest.raises(ValueError, match="Failed to parse JSON"): - scanner._parse_input("{ missing brace ") + parse_hook_input("{ missing brace ") def test_missing_event_type_raises(self): """JSON without event type raises ValueError.""" - scanner = AIHookScanner(_mock_scanner([])) - with pytest.raises(ValueError, match="couldn't find event type"): - scanner._parse_input('{"prompt": "hello"}') + with pytest.raises(ValueError): + parse_hook_input('{"prompt": "hello"}') def test_cursor_user_prompt(self): """Test Cursor beforeSubmitPrompt (user prompt) parsing.""" - scanner = AIHookScanner(_mock_scanner([])) data = { 
"conversation_id": "75fed8a8-2078-4e49-80d2-776b20d441c3", "generation_id": "1501ede6-b8ac-43f4-9943-0e218610c5c6", @@ -120,16 +57,15 @@ def test_cursor_user_prompt(self): "user_email": "user@example.com", "transcript_path": "/home/user1/.cursor/projects/foo/agent-transcripts/75fed8a8/75fed8a8.jsonl", } - payload = scanner._parse_input(json.dumps(data))[0] + payload = parse_hook_input(json.dumps(data))[0] assert payload.event_type == EventType.USER_PROMPT assert payload.content == "hello world" assert payload.tool is None assert payload.identifier != "" - assert isinstance(payload.flavor, Cursor) + assert isinstance(payload.agent, Cursor) def test_cursor_pre_tool_use_shell(self): """Test Cursor preToolUse with Shell (bash) parsing.""" - scanner = AIHookScanner(_mock_scanner([])) data = { "conversation_id": "37a17cfc-322c-47ab-88c5-e810f23f4739", "generation_id": "049f5b26-326a-4081-82c1-e5c42a63d19e", @@ -146,16 +82,15 @@ def test_cursor_pre_tool_use_shell(self): "workspace_roots": ["/home/user1/foo"], "transcript_path": "/home/user1/.cursor/projects/foo/agent-transcripts/37a17cfc/37a17cfc.jsonl", } - payload = scanner._parse_input(json.dumps(data))[0] + payload = parse_hook_input(json.dumps(data))[0] assert payload.event_type == EventType.PRE_TOOL_USE assert payload.tool == Tool.BASH assert payload.content == "whoami" assert payload.identifier == "whoami" - assert isinstance(payload.flavor, Cursor) + assert isinstance(payload.agent, Cursor) def test_cursor_pre_tool_use_read(self, tmp_file: Path): """Test Cursor preToolUse with Read (file) parsing.""" - scanner = AIHookScanner(_mock_scanner([])) data = { "conversation_id": "75fed8a8-2078-4e49-80d2-776b20d441c3", "generation_id": "1501ede6-b8ac-43f4-9943-0e218610c5c6", @@ -168,17 +103,16 @@ def test_cursor_pre_tool_use_read(self, tmp_file: Path): "workspace_roots": ["/home/user1/foo"], "transcript_path": "/home/user1/.cursor/projects/foo/agent-transcripts/75fed8a8/75fed8a8.jsonl", } - payload = 
scanner._parse_input(json.dumps(data))[0] + payload = parse_hook_input(json.dumps(data))[0] assert payload.event_type == EventType.PRE_TOOL_USE assert payload.tool == Tool.READ assert payload.identifier == tmp_file.as_posix() assert payload.content == "" assert payload.scannable.content == "this is the content" - assert isinstance(payload.flavor, Cursor) + assert isinstance(payload.agent, Cursor) def test_cursor_post_tool_use_shell(self): """Test Cursor postToolUse with Shell (simulated cat command result).""" - scanner = AIHookScanner(_mock_scanner([])) data = { "conversation_id": "37a17cfc-322c-47ab-88c5-e810f23f4739", "generation_id": "049f5b26-326a-4081-82c1-e5c42a63d19e", @@ -193,15 +127,14 @@ def test_cursor_post_tool_use_shell(self): "workspace_roots": ["/home/user1/foo"], "transcript_path": "/home/user/.cursor/projects/foo/agent-transcripts/37a17cfc/37a17cfc.jsonl", } - payload = scanner._parse_input(json.dumps(data))[0] + payload = parse_hook_input(json.dumps(data))[0] assert payload.event_type == EventType.POST_TOOL_USE assert payload.tool == Tool.BASH assert "user1" in payload.content - assert isinstance(payload.flavor, Cursor) + assert isinstance(payload.agent, Cursor) def test_claude_user_prompt(self): """Test Claude Code UserPromptSubmit parsing.""" - scanner = AIHookScanner(_mock_scanner([])) data = { "session_id": "273ad859-3608-4799-9971-fa15ecb1a65c", "transcript_path": "/home/user1/.claude/projects/foo/273ad859-3608-4799-9971-fa15ecb1a65c.jsonl", @@ -210,15 +143,14 @@ def test_claude_user_prompt(self): "hook_event_name": "UserPromptSubmit", "prompt": "hello world", } - payload = scanner._parse_input(json.dumps(data))[0] + payload = parse_hook_input(json.dumps(data))[0] assert payload.event_type == EventType.USER_PROMPT assert payload.content == "hello world" assert payload.tool is None - assert isinstance(payload.flavor, Claude) + assert isinstance(payload.agent, Claude) def test_claude_pre_tool_use_bash(self): """Test Claude Code PreToolUse with 
Bash parsing.""" - scanner = AIHookScanner(_mock_scanner([])) data = { "session_id": "3b7ae0c5-0862-4e14-aa2c-12fad909c323", "transcript_path": "/home/user1/.claude/projects/foo/3b7ae0c5.jsonl", @@ -232,15 +164,14 @@ def test_claude_pre_tool_use_bash(self): }, "tool_use_id": "toolu_01BPMKeZAMCqBtn1xJRNfDJw", } - payload = scanner._parse_input(json.dumps(data))[0] + payload = parse_hook_input(json.dumps(data))[0] assert payload.event_type == EventType.PRE_TOOL_USE assert payload.tool == Tool.BASH assert "whoami" in payload.content - assert isinstance(payload.flavor, Claude) + assert isinstance(payload.agent, Claude) def test_claude_pre_tool_use_read(self, tmp_file: Path): """Test Claude Code PreToolUse with Read parsing.""" - scanner = AIHookScanner(_mock_scanner([])) # From raw_hooks_logs: Claude PreToolUse Read data = { "session_id": "3b7ae0c5-0862-4e14-aa2c-12fad909c323", @@ -252,17 +183,16 @@ def test_claude_pre_tool_use_read(self, tmp_file: Path): "tool_input": {"file_path": tmp_file.as_posix()}, "tool_use_id": "toolu_01WabtWJpzf1ZJ8GJ3JfQEmq", } - payload = scanner._parse_input(json.dumps(data))[0] + payload = parse_hook_input(json.dumps(data))[0] assert payload.event_type == EventType.PRE_TOOL_USE assert payload.tool == Tool.READ assert payload.identifier == tmp_file.as_posix() assert payload.content == "" assert payload.scannable.content == "this is the content" - assert isinstance(payload.flavor, Claude) + assert isinstance(payload.agent, Claude) def test_claude_post_tool_use_bash(self): """Test Claude Code PostToolUse with Bash (simulated cat command result).""" - scanner = AIHookScanner(_mock_scanner([])) # From raw_hooks_logs: Claude PostToolUse Bash - tool_response has stdout data = { "session_id": "3b7ae0c5-0862-4e14-aa2c-12fad909c323", @@ -284,16 +214,15 @@ def test_claude_post_tool_use_bash(self): }, "tool_use_id": "toolu_01BPMKeZAMCqBtn1xJRNfDJw", } - payload = scanner._parse_input(json.dumps(data))[0] + payload = 
parse_hook_input(json.dumps(data))[0] assert payload.event_type == EventType.POST_TOOL_USE assert payload.tool == Tool.BASH # Content is json.dumps(tool_response), so the stdout is inside the string assert "user1" in payload.content - assert isinstance(payload.flavor, Claude) + assert isinstance(payload.agent, Claude) def test_claude_parse_read_files_in_prompt(self): """Test parsing "@file_path" mentions from Claude Code prompt.""" - scanner = AIHookScanner(_mock_scanner([])) data = { "session_id": "273ad859-3608-4799-9971-fa15ecb1a65c", "transcript_path": "/home/user1/.claude/projects/foo/273ad859-3608-4799-9971-fa15ecb1a65c.jsonl", @@ -302,24 +231,23 @@ def test_claude_parse_read_files_in_prompt(self): "hook_event_name": "UserPromptSubmit", "prompt": "read @folder/file.txt and summarize the content.", } - payloads = scanner._parse_input(json.dumps(data)) + payloads = parse_hook_input(json.dumps(data)) assert len(payloads) == 2 payload = payloads[0] assert payload.event_type == EventType.USER_PROMPT assert payload.tool == Tool.READ assert payload.identifier == "folder/file.txt" assert payload.content == "" # empty because inexistent file - assert isinstance(payload.flavor, Claude) + assert isinstance(payload.agent, Claude) payload = payloads[1] assert payload.event_type == EventType.USER_PROMPT assert payload.content == "read @folder/file.txt and summarize the content." 
assert payload.tool is None - assert isinstance(payload.flavor, Claude) + assert isinstance(payload.agent, Claude) def test_copilot_user_prompt(self): """Test Copilot UserPromptSubmit parsing.""" - scanner = AIHookScanner(_mock_scanner([])) data = { "timestamp": "2026-02-26T11:28:53.112Z", "hookEventName": "UserPromptSubmit", @@ -331,15 +259,14 @@ def test_copilot_user_prompt(self): "prompt": "hello world", "cwd": "/home/user1/foo", } - payload = scanner._parse_input(json.dumps(data))[0] + payload = parse_hook_input(json.dumps(data))[0] assert payload.event_type == EventType.USER_PROMPT assert "hello world" in payload.content assert payload.tool is None - assert isinstance(payload.flavor, Copilot) + assert isinstance(payload.agent, Copilot) def test_copilot_pre_tool_use_run_in_terminal(self): """Test Copilot PreToolUse with run_in_terminal (shell) parsing.""" - scanner = AIHookScanner(_mock_scanner([])) # From raw_hooks_logs: Copilot PreToolUse run_in_terminal data = { "timestamp": "2026-02-26T11:29:05.821Z", @@ -360,15 +287,14 @@ def test_copilot_pre_tool_use_run_in_terminal(self): "tool_use_id": "call_ADJcoVxpnzPtpU6uf0h9wzLR__vscode-1772105116075", "cwd": "/home/user1/foo", } - payload = scanner._parse_input(json.dumps(data))[0] + payload = parse_hook_input(json.dumps(data))[0] assert payload.event_type == EventType.PRE_TOOL_USE assert payload.tool == Tool.BASH assert "whoami" in payload.content - assert isinstance(payload.flavor, Copilot) + assert isinstance(payload.agent, Copilot) def test_copilot_pre_tool_use_read_file(self, tmp_file: Path): """Test Copilot PreToolUse with read_file parsing.""" - scanner = AIHookScanner(_mock_scanner([])) # From raw_hooks_logs: Copilot PreToolUse read_file (nonexistent path for deterministic test) data = { "timestamp": "2026-02-26T11:53:49.593Z", @@ -387,17 +313,16 @@ def test_copilot_pre_tool_use_read_file(self, tmp_file: Path): "tool_use_id": "call_iMFuTGETQ2z23a3xYTqcHBXp__vscode-1772105116078", "cwd": "/home/user1/foo", } 
- payload = scanner._parse_input(json.dumps(data))[0] + payload = parse_hook_input(json.dumps(data))[0] assert payload.event_type == EventType.PRE_TOOL_USE assert payload.tool == Tool.READ assert payload.identifier == tmp_file.as_posix() assert payload.content == "" assert payload.scannable.content == "this is the content" - assert isinstance(payload.flavor, Copilot) + assert isinstance(payload.agent, Copilot) def test_copilot_post_tool_use_run_in_terminal(self): """Test Copilot PostToolUse with run_in_terminal (simulated cat result).""" - scanner = AIHookScanner(_mock_scanner([])) # From raw_hooks_logs: Copilot PostToolUse run_in_terminal - tool_response is string data = { "timestamp": "2026-02-26T11:53:47.392Z", @@ -419,23 +344,23 @@ def test_copilot_post_tool_use_run_in_terminal(self): "tool_use_id": "call_f96KUoNCGS8jENVKnlWnSz5Q__vscode-1772105116077", "cwd": "/home/user1/foo", } - payload = scanner._parse_input(json.dumps(data))[0] + payload = parse_hook_input(json.dumps(data))[0] assert payload.event_type == EventType.POST_TOOL_USE assert payload.tool == Tool.BASH assert "user1" in payload.content - assert isinstance(payload.flavor, Copilot) + assert isinstance(payload.agent, Copilot) def test_pre_tool_use_read_with_missing_file(self): """PRE_TOOL_USE with tool_name 'read' and non-existing file yields empty content.""" - scanner = AIHookScanner(_mock_scanner([])) content = json.dumps( { "hook_event_name": "pretooluse", "tool_name": "read", "tool_input": {"file_path": "/nonexistent/path"}, + "cursor_version": "1.2.3", } ) - payload = scanner._parse_input(content)[0] + payload = parse_hook_input(content)[0] assert payload.event_type == EventType.PRE_TOOL_USE assert payload.tool == Tool.READ assert payload.identifier == "/nonexistent/path" @@ -443,75 +368,37 @@ def test_pre_tool_use_read_with_missing_file(self): def test_pre_tool_use_other_tool(self): """PRE_TOOL_USE with unknown tool yields Tool.OTHER and empty content.""" - scanner = 
AIHookScanner(_mock_scanner([])) data = { "hook_event_name": "PreToolUse", "tool_name": "SomeUnknownTool", "tool_input": {"arg": "value"}, + "cursor_version": "1.2.3", } - payload = scanner._parse_input(json.dumps(data))[0] + payload = parse_hook_input(json.dumps(data))[0] assert payload.event_type == EventType.PRE_TOOL_USE assert payload.tool == Tool.OTHER assert payload.content == "" def test_other_event_type(self): """Unknown event type yields EventType.OTHER with empty content.""" - scanner = AIHookScanner(_mock_scanner([])) data = { "hook_event_name": "SomeOtherEvent", "prompt": "hello", + "cursor_version": "1.2.3", } - payload = scanner._parse_input(json.dumps(data))[0] + payload = parse_hook_input(json.dumps(data))[0] assert payload.event_type == EventType.OTHER assert payload.content == "" assert payload.tool is None -class TestAIHookScannerScanContent: - """Unit tests for AIHookScanner._scan_content.""" - - def test_no_secrets_returns_allow(self): - """When scanner returns no secrets, result has block=False and nbr_secrets=0.""" - hook_scanner = AIHookScanner(_mock_scanner([])) - payload = Payload( - event_type=EventType.USER_PROMPT, - tool=None, - content="safe content", - identifier="id", - flavor=Flavor(), - ) - result = hook_scanner._scan_content(payload) - assert isinstance(result, HookResult) - assert result.block is False - assert result.nbr_secrets == 0 - assert result.message == "" - - def test_with_secrets_returns_block_and_message(self): - """When scanner returns secrets, result has block=True, nbr_secrets and message set.""" - hook_scanner = AIHookScanner(_mock_scanner(["sk-xxx"])) - payload = Payload( - event_type=EventType.USER_PROMPT, - tool=None, - content="content with sk-xxx", - identifier="id", - flavor=Flavor(), - ) - result = hook_scanner._scan_content(payload) - assert isinstance(result, HookResult) - assert result.block is True - assert result.nbr_secrets == 1 - assert "dummy-detector" in result.message - assert "secret" in 
result.message.lower() - assert "remove the secrets from your prompt" in result.message - - class TestFlavorOutputResult: """Unit tests for Cursor, Claude, Copilot output_result with Result objects. Mocks click.echo to capture stdout/stderr and asserts both output and return code. """ - @patch("ggshield.verticals.secret.ai_hook.cursor.click.echo") + @patch("ggshield.verticals.ai.agents.cursor.click.echo") def test_cursor_output_result_user_prompt_allow(self, mock_echo: MagicMock): """Cursor USER_PROMPT with block=False: JSON to stdout, return 0.""" result = HookResult.allow(_dummy_payload(EventType.USER_PROMPT)) @@ -524,7 +411,7 @@ def test_cursor_output_result_user_prompt_allow(self, mock_echo: MagicMock): assert out["continue"] is True assert out["user_message"] == "" - @patch("ggshield.verticals.secret.ai_hook.cursor.click.echo") + @patch("ggshield.verticals.ai.agents.cursor.click.echo") def test_cursor_output_result_user_prompt_block(self, mock_echo: MagicMock): """Cursor USER_PROMPT with block=True: JSON to stdout, return 0.""" result = HookResult( @@ -542,7 +429,7 @@ def test_cursor_output_result_user_prompt_block(self, mock_echo: MagicMock): assert out["continue"] is False assert out["user_message"] == "Remove secrets from prompt" - @patch("ggshield.verticals.secret.ai_hook.cursor.click.echo") + @patch("ggshield.verticals.ai.agents.cursor.click.echo") def test_cursor_output_result_pre_tool_use_allow(self, mock_echo: MagicMock): """Cursor PRE_TOOL_USE with block=False: permission allow, return 0.""" result = HookResult.allow(_dummy_payload(EventType.PRE_TOOL_USE)) @@ -554,7 +441,7 @@ def test_cursor_output_result_pre_tool_use_allow(self, mock_echo: MagicMock): out = json.loads(args[0]) assert out["permission"] == "allow" - @patch("ggshield.verticals.secret.ai_hook.cursor.click.echo") + @patch("ggshield.verticals.ai.agents.cursor.click.echo") def test_cursor_output_result_pre_tool_use_block(self, mock_echo: MagicMock): """Cursor PRE_TOOL_USE with block=True: 
permission deny, return 0.""" result = HookResult( @@ -572,7 +459,7 @@ def test_cursor_output_result_pre_tool_use_block(self, mock_echo: MagicMock): assert out["permission"] == "deny" assert out["user_message"] == "Secrets detected in command" - @patch("ggshield.verticals.secret.ai_hook.cursor.click.echo") + @patch("ggshield.verticals.ai.agents.cursor.click.echo") def test_cursor_output_result_post_tool_use(self, mock_echo: MagicMock): """Cursor POST_TOOL_USE: empty JSON to stdout, return 0.""" result = HookResult( @@ -588,7 +475,7 @@ def test_cursor_output_result_post_tool_use(self, mock_echo: MagicMock): assert kwargs.get("err", False) is False # stdout (default) assert json.loads(args[0]) == {} - @patch("ggshield.verticals.secret.ai_hook.cursor.click.echo") + @patch("ggshield.verticals.ai.agents.cursor.click.echo") def test_cursor_output_result_other_block(self, mock_echo: MagicMock): """Cursor OTHER event with block: empty JSON, return 2.""" result = HookResult( @@ -601,7 +488,7 @@ def test_cursor_output_result_other_block(self, mock_echo: MagicMock): assert code == 2 mock_echo.assert_called_once_with("{}") - @patch("ggshield.verticals.secret.ai_hook.cursor.click.echo") + @patch("ggshield.verticals.ai.agents.cursor.click.echo") def test_cursor_output_result_other_allow(self, mock_echo: MagicMock): """Cursor OTHER event without block: empty JSON, return 0.""" result = HookResult.allow(_dummy_payload(EventType.OTHER)) @@ -609,7 +496,7 @@ def test_cursor_output_result_other_allow(self, mock_echo: MagicMock): assert code == 0 mock_echo.assert_called_once_with("{}") - @patch("ggshield.verticals.secret.ai_hook.claude_code.click.echo") + @patch("ggshield.verticals.ai.agents.claude_code.click.echo") def test_claude_output_result_allow(self, mock_echo: MagicMock): """Claude with block=False: JSON continue true to stdout, return 0.""" result = HookResult.allow(_dummy_payload(EventType.USER_PROMPT)) @@ -621,7 +508,7 @@ def test_claude_output_result_allow(self, mock_echo: 
MagicMock): out = json.loads(args[0]) assert out["continue"] is True - @patch("ggshield.verticals.secret.ai_hook.claude_code.click.echo") + @patch("ggshield.verticals.ai.agents.claude_code.click.echo") def test_claude_output_result_block(self, mock_echo: MagicMock): """Claude with block=True: JSON continue false and stopReason to stdout, return 0.""" result = HookResult( @@ -641,7 +528,7 @@ def test_claude_output_result_block(self, mock_echo: MagicMock): out["hookSpecificOutput"]["permissionDecisionReason"] == "Secrets in file" ) - @patch("ggshield.verticals.secret.ai_hook.claude_code.click.echo") + @patch("ggshield.verticals.ai.agents.claude_code.click.echo") def test_copilot_output_result_allow(self, mock_echo: MagicMock): """Copilot with block=False: same as Claude, JSON to stdout, return 0.""" result = HookResult.allow(_dummy_payload(EventType.USER_PROMPT)) @@ -654,7 +541,7 @@ def test_copilot_output_result_allow(self, mock_echo: MagicMock): assert out["continue"] is True assert "stopReason" not in out - @patch("ggshield.verticals.secret.ai_hook.claude_code.click.echo") + @patch("ggshield.verticals.ai.agents.claude_code.click.echo") def test_copilot_output_result_block(self, mock_echo: MagicMock): """Copilot with block=True: same as Claude, JSON to stdout, return 0.""" result = HookResult( @@ -672,7 +559,7 @@ def test_copilot_output_result_block(self, mock_echo: MagicMock): assert out["decision"] == "block" assert out["reason"] == "Secret in tool output" - @patch("ggshield.verticals.secret.ai_hook.claude_code.click.echo") + @patch("ggshield.verticals.ai.agents.claude_code.click.echo") def test_copilot_other_result_block(self, mock_echo: MagicMock): """Copilot with block=True, other type of event""" result = HookResult( @@ -688,201 +575,6 @@ def test_copilot_other_result_block(self, mock_echo: MagicMock): assert not out["continue"] -class TestBaseFlavor: - """Unit tests for the base Flavor class.""" - - 
@patch("ggshield.verticals.secret.ai_hook.models.click.echo") - def test_base_flavor_output_result_allow(self, mock_echo: MagicMock): - """Base Flavor with block=False: prints allow message, returns 0.""" - result = HookResult.allow(_dummy_payload(EventType.USER_PROMPT)) - code = Flavor().output_result(result) - assert code == 0 - mock_echo.assert_called_once_with("No secrets found. Good to go.") - - @patch("ggshield.verticals.secret.ai_hook.models.click.echo") - def test_base_flavor_output_result_block(self, mock_echo: MagicMock): - """Base Flavor with block=True: prints message to stderr, returns 2.""" - result = HookResult( - block=True, - message="Secrets found", - nbr_secrets=1, - payload=_dummy_payload(EventType.PRE_TOOL_USE), - ) - code = Flavor().output_result(result) - assert code == 2 - mock_echo.assert_called_once_with("Secrets found", err=True) - - def test_base_flavor_settings_path(self): - """Base Flavor settings_path returns default path.""" - assert Flavor().settings_path == Path(".agents") / "hooks.json" - - def test_base_flavor_settings_template(self): - """Base Flavor settings_template returns empty dict.""" - assert Flavor().settings_template == {} - - def test_base_flavor_settings_locate(self): - """Base Flavor settings_locate always returns None.""" - assert Flavor().settings_locate([{"a": 1}], {"a": 1}) is None - - -class TestAIHookScannerScan: - """Unit tests for the AIHookScanner.scan() method.""" - - def test_scan_no_secrets_returns_zero(self): - """scan() with no secrets returns 0.""" - scanner = AIHookScanner(_mock_scanner([])) - data = { - "hook_event_name": "UserPromptSubmit", - "prompt": "hello world", - "transcript_path": "/home/user/.claude/projects/foo/session.jsonl", - } - code = scanner.scan(json.dumps(data)) - assert code == 0 - - @patch( - "ggshield.verticals.secret.ai_hook.scanner.AIHookScanner._send_secret_notification" - ) - def test_scan_post_tool_use_with_secrets_sends_notification( - self, mock_notify: MagicMock - ): - 
"""scan() on POST_TOOL_USE with secrets sends a notification and returns 0 (no block).""" - scanner = AIHookScanner(_mock_scanner(["sk-xxx"])) - data = { - "hook_event_name": "PostToolUse", - "tool_name": "Bash", - "tool_input": {"command": "echo sk-xxx"}, - "tool_response": {"stdout": "sk-xxx\n"}, - "transcript_path": "/home/user/.claude/projects/foo/session.jsonl", - "session_id": "427ae0c5-0862-4e14-aa2c-12fad909c323", - } - code = scanner.scan(json.dumps(data)) - assert code == 0 - mock_notify.assert_called_once() - args = mock_notify.call_args[0] - assert args[0] == 1 # nbr_secrets - assert args[1] == Tool.BASH # tool - - def test_scan_pre_tool_use_with_secrets_blocks(self): - """scan() on PRE_TOOL_USE with secrets returns block result.""" - scanner = AIHookScanner(_mock_scanner(["sk-xxx"])) - data = { - "hook_event_name": "PreToolUse", - "tool_name": "Bash", - "tool_input": {"command": "echo sk-xxx"}, - "session_id": "427ae0c5-0862-4e14-aa2c-12fad909c323", - "transcript_path": "/home/user/.claude/projects/foo/session.jsonl", - } - code = scanner.scan(json.dumps(data)) - # Claude output_result always returns 0 - assert code == 0 - - def test_scan_no_content_returns_allow(self): - """scan() with no content returns 0 (and doesn't call the API).""" - mock_scanner = _mock_scanner([]) - scanner = AIHookScanner(mock_scanner) - data = { - "hook_event_name": "PreToolUse", - "tool_name": "Read", - "tool_input": {"file_path": "doesn-t-exist"}, - } - code = scanner.scan(json.dumps(data)) - assert code == 0 - mock_scanner.scan.assert_not_called() - - def test_scan_payloads_refuse_empty_list(self): - """scan() with empty list of payloads raises ValueError.""" - scanner = AIHookScanner(_mock_scanner([])) - with pytest.raises(ValueError): - scanner._scan_payloads([]) - - -class TestMessageFromSecrets: - """Unit tests for AIHookScanner._message_from_secrets with different payload types.""" - - def test_message_for_bash_tool(self): - """Message for BASH tool mentions 
environment variables.""" - payload = Payload( - event_type=EventType.PRE_TOOL_USE, - tool=Tool.BASH, - content="echo sk-xxx", - identifier="echo sk-xxx", - flavor=Flavor(), - ) - message = AIHookScanner._message_from_secrets([_make_secret("sk-xxx")], payload) - assert "remove the secrets from the command" in message - assert "environment variables" in message - - def test_message_for_read_tool(self): - """Message for READ tool mentions file content.""" - payload = Payload( - event_type=EventType.PRE_TOOL_USE, - tool=Tool.READ, - content="file content with secret", - identifier="/path/to/file", - flavor=Flavor(), - ) - message = AIHookScanner._message_from_secrets([_make_secret("sk-xxx")], payload) - assert "remove the secrets from" in message - - def test_message_for_other_tool(self): - """Message for OTHER tool uses generic message.""" - payload = Payload( - event_type=EventType.PRE_TOOL_USE, - tool=Tool.OTHER, - content="some content", - identifier="id", - flavor=Flavor(), - ) - message = AIHookScanner._message_from_secrets([_make_secret("sk-xxx")], payload) - assert "remove the secrets from the tool input" in message - - def test_message_escapes_markdown(self): - """When escape_markdown=True, asterisks in matches are replaced with dots.""" - payload = Payload( - event_type=EventType.USER_PROMPT, - tool=None, - content="content", - identifier="id", - flavor=Flavor(), - ) - message = AIHookScanner._message_from_secrets( - [_make_secret("sk-xxx")], payload, escape_markdown=True - ) - # The message itself should not contain raw asterisks from matches - # (the header uses ** for bold which is intentional) - assert "Detected" in message - - -class TestSendSecretNotification: - """Unit tests for AIHookScanner._send_secret_notification.""" - - @patch("ggshield.verticals.secret.ai_hook.scanner.Notify") - def test_notification_for_bash_tool(self, mock_notify_cls: MagicMock): - """Notification for BASH tool says 'running a command'.""" - 
AIHookScanner._send_secret_notification(1, Tool.BASH, "Claude Code") - instance = mock_notify_cls.return_value - assert "running a command" in instance.message - assert "Claude Code" in instance.message - instance.send.assert_called_once() - - @patch("ggshield.verticals.secret.ai_hook.scanner.Notify") - def test_notification_for_read_tool(self, mock_notify_cls: MagicMock): - """Notification for READ tool says 'reading a file'.""" - AIHookScanner._send_secret_notification(2, Tool.READ, "Cursor") - instance = mock_notify_cls.return_value - assert "reading a file" in instance.message - assert "2" in instance.message - instance.send.assert_called_once() - - @patch("ggshield.verticals.secret.ai_hook.scanner.Notify") - def test_notification_for_other_tool(self, mock_notify_cls: MagicMock): - """Notification for OTHER tool says 'using a tool'.""" - AIHookScanner._send_secret_notification(1, Tool.OTHER, "Copilot") - instance = mock_notify_cls.return_value - assert "using a tool" in instance.message - instance.send.assert_called_once() - - @pytest.mark.parametrize( "prompt, filepaths", [ diff --git a/tests/unit/verticals/secret/ai_hook/test_installation.py b/tests/unit/verticals/ai/test_installation.py similarity index 90% rename from tests/unit/verticals/secret/ai_hook/test_installation.py rename to tests/unit/verticals/ai/test_installation.py index 5a98ef2f7b..20e2dea79d 100644 --- a/tests/unit/verticals/secret/ai_hook/test_installation.py +++ b/tests/unit/verticals/ai/test_installation.py @@ -6,10 +6,8 @@ import pytest from ggshield.core.errors import UnexpectedError -from ggshield.verticals.secret.ai_hook.claude_code import Claude -from ggshield.verticals.secret.ai_hook.copilot import Copilot -from ggshield.verticals.secret.ai_hook.cursor import Cursor -from ggshield.verticals.secret.ai_hook.installation import ( +from ggshield.verticals.ai.agents import Claude, Copilot, Cursor +from ggshield.verticals.ai.installation import ( InstallationStats, _fill_dict, 
install_hooks, @@ -287,16 +285,14 @@ def test_copilot_settings_path(self): class TestInstallHooks: """Unit tests for the install_hooks function.""" - @patch("ggshield.verticals.secret.ai_hook.installation.get_user_home_dir") + @patch("ggshield.verticals.ai.installation.get_user_home_dir") def test_install_cursor_local_fresh(self, mock_home: Any, tmp_path: Path): """Install Cursor hooks locally into a fresh directory (no existing config).""" mock_home.return_value = tmp_path settings_path = tmp_path / ".cursor" / "hooks.json" assert not settings_path.exists() - with patch( - "ggshield.verticals.secret.ai_hook.installation.Path" - ) as mock_path_cls: + with patch("ggshield.verticals.ai.installation.Path") as mock_path_cls: # Make Path(".") return tmp_path so local mode writes there mock_path_cls.side_effect = lambda *a: Path(*a) if a != (".",) else tmp_path code = install_hooks("cursor", mode="local") @@ -308,7 +304,7 @@ def test_install_cursor_local_fresh(self, mock_home: Any, tmp_path: Path): for key in ("beforeSubmitPrompt", "preToolUse", "postToolUse"): assert any("ggshield" in h["command"] for h in config["hooks"][key]) - @patch("ggshield.verticals.secret.ai_hook.installation.get_user_home_dir") + @patch("ggshield.verticals.ai.installation.get_user_home_dir") def test_install_claude_global(self, mock_home: Any, tmp_path: Path): """Install Claude Code hooks globally.""" mock_home.return_value = tmp_path @@ -322,7 +318,7 @@ def test_install_claude_global(self, mock_home: Any, tmp_path: Path): for key in ("PreToolUse", "PostToolUse", "UserPromptSubmit"): assert key in config["hooks"] - @patch("ggshield.verticals.secret.ai_hook.installation.get_user_home_dir") + @patch("ggshield.verticals.ai.installation.get_user_home_dir") def test_install_copilot_global(self, mock_home: Any, tmp_path: Path): """Install Copilot hooks globally.""" mock_home.return_value = tmp_path @@ -332,12 +328,12 @@ def test_install_copilot_global(self, mock_home: Any, tmp_path: Path): 
settings_path = tmp_path / ".github" / "hooks" / "hooks.json" assert settings_path.exists() - def test_install_unsupported_tool_raises(self): - """install_hooks raises ValueError for unsupported tool name.""" - with pytest.raises(ValueError, match="Unsupported tool name"): - install_hooks("unknown-tool", mode="local") + def test_install_unsupported_agent_raises(self): + """install_hooks raises ValueError for unsupported agent.""" + with pytest.raises(ValueError, match="Unsupported agent"): + install_hooks("unknown-agent", mode="local") - @patch("ggshield.verticals.secret.ai_hook.installation.get_user_home_dir") + @patch("ggshield.verticals.ai.installation.get_user_home_dir") def test_install_with_existing_config(self, mock_home: Any, tmp_path: Path): """Install hooks when a config file already exists (merges).""" mock_home.return_value = tmp_path @@ -345,9 +341,7 @@ def test_install_with_existing_config(self, mock_home: Any, tmp_path: Path): settings_path.parent.mkdir(parents=True) settings_path.write_text(json.dumps({"version": 1, "other_key": "keep_me"})) - with patch( - "ggshield.verticals.secret.ai_hook.installation.Path" - ) as mock_path_cls: + with patch("ggshield.verticals.ai.installation.Path") as mock_path_cls: mock_path_cls.side_effect = lambda *a: Path(*a) if a != (".",) else tmp_path code = install_hooks("cursor", mode="local") @@ -356,7 +350,7 @@ def test_install_with_existing_config(self, mock_home: Any, tmp_path: Path): assert config["other_key"] == "keep_me" assert "hooks" in config - @patch("ggshield.verticals.secret.ai_hook.installation.get_user_home_dir") + @patch("ggshield.verticals.ai.installation.get_user_home_dir") def test_install_with_corrupt_json_raises(self, mock_home: Any, tmp_path: Path): """install_hooks raises UnexpectedError when existing config is invalid JSON.""" mock_home.return_value = tmp_path @@ -364,21 +358,17 @@ def test_install_with_corrupt_json_raises(self, mock_home: Any, tmp_path: Path): 
settings_path.parent.mkdir(parents=True) settings_path.write_text("{ invalid json") - with patch( - "ggshield.verticals.secret.ai_hook.installation.Path" - ) as mock_path_cls: + with patch("ggshield.verticals.ai.installation.Path") as mock_path_cls: mock_path_cls.side_effect = lambda *a: Path(*a) if a != (".",) else tmp_path with pytest.raises(UnexpectedError, match="Failed to parse"): install_hooks("cursor", mode="local") - @patch("ggshield.verticals.secret.ai_hook.installation.get_user_home_dir") + @patch("ggshield.verticals.ai.installation.get_user_home_dir") def test_install_already_present(self, mock_home: Any, tmp_path: Path): """install_hooks when hooks are already installed reports 'already installed'.""" mock_home.return_value = tmp_path - with patch( - "ggshield.verticals.secret.ai_hook.installation.Path" - ) as mock_path_cls: + with patch("ggshield.verticals.ai.installation.Path") as mock_path_cls: mock_path_cls.side_effect = lambda *a: Path(*a) if a != (".",) else tmp_path # Install once install_hooks("cursor", mode="local") @@ -387,14 +377,12 @@ def test_install_already_present(self, mock_home: Any, tmp_path: Path): assert code == 0 - @patch("ggshield.verticals.secret.ai_hook.installation.get_user_home_dir") + @patch("ggshield.verticals.ai.installation.get_user_home_dir") def test_install_force_updates(self, mock_home: Any, tmp_path: Path): """install_hooks with force=True updates existing hooks.""" mock_home.return_value = tmp_path - with patch( - "ggshield.verticals.secret.ai_hook.installation.Path" - ) as mock_path_cls: + with patch("ggshield.verticals.ai.installation.Path") as mock_path_cls: mock_path_cls.side_effect = lambda *a: Path(*a) if a != (".",) else tmp_path install_hooks("cursor", mode="local") code = install_hooks("cursor", mode="local", force=True) diff --git a/tests/unit/verticals/secret/test_ai_hooks.py b/tests/unit/verticals/secret/test_ai_hooks.py new file mode 100644 index 0000000000..5bfeca82cd --- /dev/null +++ 
b/tests/unit/verticals/secret/test_ai_hooks.py @@ -0,0 +1,263 @@ +import json +from collections import Counter +from pathlib import Path +from typing import List +from unittest.mock import MagicMock, patch + +import pytest + +from ggshield.utils.git_shell import Filemode +from ggshield.verticals.ai.agents import Cursor +from ggshield.verticals.secret import SecretScanner +from ggshield.verticals.secret.ai_hook import AIHookScanner, EventType, Payload +from ggshield.verticals.secret.ai_hook import Result as HookResult +from ggshield.verticals.secret.ai_hook import Tool +from ggshield.verticals.secret.secret_scan_collection import Result as ScanResult +from ggshield.verticals.secret.secret_scan_collection import Results, Secret + + +def _mock_scanner(matches: List[str]) -> MagicMock: + """Create a mock SecretScanner that returns the given Results from scan().""" + mock = MagicMock(spec=SecretScanner) + scan_result = Results( + results=[ + ScanResult( + filename="url", + filemode=Filemode.FILE, + path=Path("."), + url="url", + secrets=[_make_secret(match) for match in matches], + ignored_secrets_count_by_kind=Counter(), + ) + ], + errors=[], + ) + mock.scan.return_value = scan_result + return mock + + +def _make_secret(match_str: str = "***"): + """Minimal Secret for tests; _message_from_secrets only uses detector_display_name, validity, matches[].match.""" + mock_match = MagicMock() + mock_match.match = match_str + return Secret( + detector_display_name="dummy-detector", + detector_name="dummy-detector", + detector_group_name=None, + documentation_url=None, + validity="valid", + known_secret=False, + incident_url=None, + matches=[mock_match], + ignore_reason=None, + diff_kind=None, + is_vaulted=False, + vault_type=None, + vault_name=None, + vault_path=None, + vault_path_count=None, + ) + + +class TestAIHookScannerScanContent: + """Unit tests for AIHookScanner._scan_content.""" + + def test_no_secrets_returns_allow(self): + """When scanner returns no secrets, result 
has block=False and nbr_secrets=0.""" + hook_scanner = AIHookScanner(_mock_scanner([])) + payload = Payload( + event_type=EventType.USER_PROMPT, + tool=None, + content="safe content", + identifier="id", + agent=Cursor(), + ) + result = hook_scanner._scan_content(payload) + assert isinstance(result, HookResult) + assert result.block is False + assert result.nbr_secrets == 0 + assert result.message == "" + + def test_with_secrets_returns_block_and_message(self): + """When scanner returns secrets, result has block=True, nbr_secrets and message set.""" + hook_scanner = AIHookScanner(_mock_scanner(["sk-xxx"])) + payload = Payload( + event_type=EventType.USER_PROMPT, + tool=None, + content="content with sk-xxx", + identifier="id", + agent=Cursor(), + ) + result = hook_scanner._scan_content(payload) + assert isinstance(result, HookResult) + assert result.block is True + assert result.nbr_secrets == 1 + assert "dummy-detector" in result.message + assert "secret" in result.message.lower() + assert "remove the secrets from your prompt" in result.message + + +class TestAIHookScannerScan: + """Unit tests for the AIHookScanner.scan() method.""" + + def test_empty_input_raises(self): + """Empty or whitespace-only input raises ValueError.""" + scanner = AIHookScanner(_mock_scanner([])) + with pytest.raises(ValueError, match="No input received on stdin"): + scanner.scan("") + with pytest.raises(ValueError, match="No input received on stdin"): + scanner.scan(" \n ") + + def test_scan_no_secrets_returns_zero(self): + """scan() with no secrets returns 0.""" + scanner = AIHookScanner(_mock_scanner([])) + data = { + "hook_event_name": "UserPromptSubmit", + "prompt": "hello world", + "transcript_path": "/home/user/.claude/projects/foo/session.jsonl", + "cursor_version": "1.2.3", + } + code = scanner.scan(json.dumps(data)) + assert code == 0 + + @patch("ggshield.verticals.secret.ai_hook.AIHookScanner._send_secret_notification") + def 
test_scan_post_tool_use_with_secrets_sends_notification( + self, mock_notify: MagicMock + ): + """scan() on POST_TOOL_USE with secrets sends a notification and returns 0 (no block).""" + scanner = AIHookScanner(_mock_scanner(["sk-xxx"])) + data = { + "hook_event_name": "PostToolUse", + "tool_name": "Bash", + "tool_input": {"command": "echo sk-xxx"}, + "tool_response": {"stdout": "sk-xxx\n"}, + "transcript_path": "/home/user/.claude/projects/foo/session.jsonl", + "session_id": "427ae0c5-0862-4e14-aa2c-12fad909c323", + } + code = scanner.scan(json.dumps(data)) + assert code == 0 + mock_notify.assert_called_once() + args = mock_notify.call_args[0] + assert args[0] == 1 # nbr_secrets + assert args[1] == Tool.BASH # tool + + def test_scan_pre_tool_use_with_secrets_blocks(self): + """scan() on PRE_TOOL_USE with secrets returns block result.""" + scanner = AIHookScanner(_mock_scanner(["sk-xxx"])) + data = { + "hook_event_name": "PreToolUse", + "tool_name": "Bash", + "tool_input": {"command": "echo sk-xxx"}, + "session_id": "427ae0c5-0862-4e14-aa2c-12fad909c323", + "transcript_path": "/home/user/.claude/projects/foo/session.jsonl", + } + code = scanner.scan(json.dumps(data)) + # Claude output_result always returns 0 + assert code == 0 + + def test_scan_no_content_returns_allow(self): + """scan() with no content returns 0 (and doesn't call the API).""" + mock_scanner = _mock_scanner([]) + scanner = AIHookScanner(mock_scanner) + data = { + "hook_event_name": "PreToolUse", + "tool_name": "Read", + "tool_input": {"file_path": "doesn-t-exist"}, + "cursor_version": "1.2.3", + } + code = scanner.scan(json.dumps(data)) + assert code == 0 + mock_scanner.scan.assert_not_called() + + def test_scan_payloads_refuse_empty_list(self): + """scan() with empty list of payloads raises ValueError.""" + scanner = AIHookScanner(_mock_scanner([])) + with pytest.raises(ValueError): + scanner._scan_payloads([]) + + +class TestMessageFromSecrets: + """Unit tests for 
AIHookScanner._message_from_secrets with different payload types.""" + + def test_message_for_bash_tool(self): + """Message for BASH tool mentions environment variables.""" + payload = Payload( + event_type=EventType.PRE_TOOL_USE, + tool=Tool.BASH, + content="echo sk-xxx", + identifier="echo sk-xxx", + agent=Cursor(), + ) + message = AIHookScanner._message_from_secrets([_make_secret("sk-xxx")], payload) + assert "remove the secrets from the command" in message + assert "environment variables" in message + + def test_message_for_read_tool(self): + """Message for READ tool mentions file content.""" + payload = Payload( + event_type=EventType.PRE_TOOL_USE, + tool=Tool.READ, + content="file content with secret", + identifier="/path/to/file", + agent=Cursor(), + ) + message = AIHookScanner._message_from_secrets([_make_secret("sk-xxx")], payload) + assert "remove the secrets from" in message + + def test_message_for_other_tool(self): + """Message for OTHER tool uses generic message.""" + payload = Payload( + event_type=EventType.PRE_TOOL_USE, + tool=Tool.OTHER, + content="some content", + identifier="id", + agent=Cursor(), + ) + message = AIHookScanner._message_from_secrets([_make_secret("sk-xxx")], payload) + assert "remove the secrets from the tool input" in message + + def test_message_escapes_markdown(self): + """When escape_markdown=True, asterisks in matches are replaced with dots.""" + payload = Payload( + event_type=EventType.USER_PROMPT, + tool=None, + content="content", + identifier="id", + agent=Cursor(), + ) + message = AIHookScanner._message_from_secrets( + [_make_secret("sk-xxx")], payload, escape_markdown=True + ) + # The message itself should not contain raw asterisks from matches + # (the header uses ** for bold which is intentional) + assert "Detected" in message + + +class TestSendSecretNotification: + """Unit tests for AIHookScanner._send_secret_notification.""" + + @patch("ggshield.verticals.secret.ai_hook.Notify") + def 
test_notification_for_bash_tool(self, mock_notify_cls: MagicMock): + """Notification for BASH tool says 'running a command'.""" + AIHookScanner._send_secret_notification(1, Tool.BASH, "Claude Code") + instance = mock_notify_cls.return_value + assert "running a command" in instance.message + assert "Claude Code" in instance.message + instance.send.assert_called_once() + + @patch("ggshield.verticals.secret.ai_hook.Notify") + def test_notification_for_read_tool(self, mock_notify_cls: MagicMock): + """Notification for READ tool says 'reading a file'.""" + AIHookScanner._send_secret_notification(2, Tool.READ, "Cursor") + instance = mock_notify_cls.return_value + assert "reading a file" in instance.message + assert "2" in instance.message + instance.send.assert_called_once() + + @patch("ggshield.verticals.secret.ai_hook.Notify") + def test_notification_for_other_tool(self, mock_notify_cls: MagicMock): + """Notification for OTHER tool says 'using a tool'.""" + AIHookScanner._send_secret_notification(1, Tool.OTHER, "Copilot") + instance = mock_notify_cls.return_value + assert "using a tool" in instance.message + instance.send.assert_called_once() From 17f9363b18730565e53ae6570e15575616e1d7b6 Mon Sep 17 00:00:00 2001 From: Paul PETIT Date: Tue, 7 Apr 2026 09:04:26 +0200 Subject: [PATCH 2/3] feat(ai_hook): reduce coupling between verticals --- .importlinter | 5 +- ggshield/cmd/secret/scan/ai_hook.py | 2 +- ggshield/core/scan/__init__.py | 4 + ggshield/core/scan/scanner.py | 50 ++++ ggshield/verticals/ai/__init__.py | 10 + ggshield/verticals/ai/hooks.py | 176 ++++++++++++- ggshield/verticals/ai/installation.py | 6 +- ggshield/verticals/secret/ai_hook.py | 175 ------------ scripts/generate-import-linter-config.py | 4 + tests/unit/verticals/ai/test_hooks.py | 256 +++++++++++++++++- tests/unit/verticals/secret/test_ai_hooks.py | 263 ------------------- 11 files changed, 504 insertions(+), 447 deletions(-) create mode 100644 ggshield/core/scan/scanner.py create mode 100644 
ggshield/verticals/ai/__init__.py delete mode 100644 ggshield/verticals/secret/ai_hook.py delete mode 100644 tests/unit/verticals/secret/test_ai_hooks.py diff --git a/.importlinter b/.importlinter index 866d4d06db..3e5669e1a3 100644 --- a/.importlinter +++ b/.importlinter @@ -10,7 +10,7 @@ type = layers layers = ggshield.__main__ ggshield.cmd.auth | ggshield.cmd.config | ggshield.cmd.hmsl | ggshield.cmd.honeytoken | ggshield.cmd.install | ggshield.cmd.plugin | ggshield.cmd.quota | ggshield.cmd.secret | ggshield.cmd.status | ggshield.cmd.utils - ggshield.verticals.auth | ggshield.verticals.hmsl | ggshield.verticals.secret + ggshield.verticals.ai | ggshield.verticals.auth | ggshield.verticals.hmsl | ggshield.verticals.secret ggshield.core click | ggshield.utils | pygitguardian ignore_imports = @@ -33,6 +33,7 @@ source_modules = ggshield.cmd.status ggshield.cmd.utils forbidden_modules = + ggshield.verticals.ai ggshield.verticals.auth ggshield.verticals.hmsl ggshield.verticals.secret @@ -46,6 +47,7 @@ ignore_imports = ggshield.cmd.hmsl.** -> ggshield.verticals.hmsl.** ggshield.cmd.honeytoken.** -> ggshield.verticals.honeytoken ggshield.cmd.honeytoken.** -> ggshield.verticals.honeytoken.** + ggshield.cmd.install -> ggshield.verticals.ai.installation ggshield.cmd.install.** -> ggshield.verticals.install ggshield.cmd.install.** -> ggshield.verticals.install.** ggshield.cmd.plugin.** -> ggshield.core.plugin @@ -54,6 +56,7 @@ ignore_imports = ggshield.cmd.quota.** -> ggshield.verticals.quota.** ggshield.cmd.secret.** -> ggshield.verticals.secret ggshield.cmd.secret.** -> ggshield.verticals.secret.** + ggshield.cmd.secret.scan.ai_hook -> ggshield.verticals.ai.hooks ggshield.cmd.status.** -> ggshield.verticals.status ggshield.cmd.status.** -> ggshield.verticals.status.** ggshield.cmd.utils.** -> ggshield.verticals.utils diff --git a/ggshield/cmd/secret/scan/ai_hook.py b/ggshield/cmd/secret/scan/ai_hook.py index 64e835c757..3689b00131 100644 --- 
a/ggshield/cmd/secret/scan/ai_hook.py +++ b/ggshield/cmd/secret/scan/ai_hook.py @@ -10,8 +10,8 @@ from ggshield.core import ui from ggshield.core.client import create_client_from_config from ggshield.core.scan import ScanContext, ScanMode +from ggshield.verticals.ai.hooks import AIHookScanner from ggshield.verticals.secret import SecretScanner -from ggshield.verticals.secret.ai_hook import AIHookScanner MAX_READ_SIZE = 1024 * 1024 * 10 # We restrict stdin read to 10MB diff --git a/ggshield/core/scan/__init__.py b/ggshield/core/scan/__init__.py index b37b6197f6..e2c370a110 100644 --- a/ggshield/core/scan/__init__.py +++ b/ggshield/core/scan/__init__.py @@ -3,6 +3,7 @@ from .scan_context import ScanContext from .scan_mode import ScanMode from .scannable import DecodeError, NonSeekableFileError, Scannable, StringScannable +from .scanner import ResultsProtocol, ScannerProtocol, SecretProtocol __all__ = [ @@ -11,8 +12,11 @@ "DecodeError", "File", "NonSeekableFileError", + "ResultsProtocol", "ScanContext", "ScanMode", "Scannable", + "ScannerProtocol", + "SecretProtocol", "StringScannable", ] diff --git a/ggshield/core/scan/scanner.py b/ggshield/core/scan/scanner.py new file mode 100644 index 0000000000..9f963669f6 --- /dev/null +++ b/ggshield/core/scan/scanner.py @@ -0,0 +1,50 @@ +""" +Protocols for SecretScanner and its results, +so that other verticals can use the scanner if they are provided one. +""" + +from collections.abc import Sequence +from typing import Iterable, Optional, Protocol + +from pygitguardian.models import Match + +from ggshield.core.scanner_ui import ScannerUI + +from . import Scannable + + +class SecretProtocol(Protocol): + """Protocol for secrets. + + Attributes are declared as read-only properties. + """ + + @property + def detector_display_name(self) -> str: ... + + @property + def validity(self) -> str: ... + + @property + def matches(self) -> Sequence[Match]: ... 
+ + +class ResultProtocol(Protocol): + @property + def secrets(self) -> Sequence[SecretProtocol]: ... + + +class ResultsProtocol(Protocol): + @property + def results(self) -> Sequence[ResultProtocol]: ... + + +class ScannerProtocol(Protocol): + """Protocol for scanners.""" + + def scan( + self, + files: Iterable[Scannable], + scanner_ui: ScannerUI, + scan_threads: Optional[int] = None, + ) -> ResultsProtocol: ... diff --git a/ggshield/verticals/ai/__init__.py b/ggshield/verticals/ai/__init__.py new file mode 100644 index 0000000000..cbd7e4db0c --- /dev/null +++ b/ggshield/verticals/ai/__init__.py @@ -0,0 +1,10 @@ +from .agents import AGENTS +from .hooks import AIHookScanner +from .installation import install_hooks + + +__all__ = [ + "AGENTS", + "AIHookScanner", + "install_hooks", +] diff --git a/ggshield/verticals/ai/hooks.py b/ggshield/verticals/ai/hooks.py index 75adabcd36..4307398c79 100644 --- a/ggshield/verticals/ai/hooks.py +++ b/ggshield/verticals/ai/hooks.py @@ -3,9 +3,16 @@ import re from typing import Any, Dict, List, Sequence, Set -from ggshield.verticals.ai.agents import Claude, Copilot, Cursor +from notifypy import Notify -from .models import Agent, EventType, HookPayload, Tool +from ggshield.core.filter import censor_match +from ggshield.core.scan import ScannerProtocol +from ggshield.core.scan import SecretProtocol as Secret +from ggshield.core.scanner_ui import create_message_only_scanner_ui +from ggshield.core.text_utils import pluralize, translate_validity + +from .agents import Claude, Copilot, Cursor +from .models import Agent, EventType, HookPayload, HookResult, Tool HOOK_NAME_TO_EVENT_TYPE = { @@ -163,3 +170,168 @@ def _parse_user_prompt( ) ) return payloads + + +class AIHookScanner: + """AI hook scanner. + + It is called with the payload of a hook event. + Note that instead of having a base class with common method and a subclass per supported AI tool, + we instead have a single class which detects which protocol to use. 
+ This is because some tools sloppily support hooks from others. For instance, + Cursor will call hooks defined in the Claude Code format, but send payload in its own format. + So we can't assume which tool will call us based on the command line/hook configuration only. + + Raises: + ValueError: If the input is not valid. + """ + + def __init__(self, scanner: ScannerProtocol): + self.scanner = scanner + + def scan(self, content: str) -> int: + """Scan the content, print the result and return the exit code.""" + + payloads = parse_hook_input(content) + result = self._scan_payloads(payloads) + payload = result.payload + + # Special case: in post-tool use, the action is already done: at least notify the user + if result.block and payload.event_type == EventType.POST_TOOL_USE: + self._send_secret_notification( + result.nbr_secrets, + payload.tool or Tool.OTHER, + payload.agent.display_name, + ) + + return payload.agent.output_result(result) + + def _scan_payloads(self, payloads: List[HookPayload]) -> HookResult: + """Scan payloads for secrets using the SecretScanner. + + Returns: + The result of the first blocking payload, or a non-blocking result. + Raises a ValueError if the list is empty (we must have at least one to emit a result). 
+ """ + if not payloads: + raise ValueError("Error: no payloads to scan") + for payload in payloads: + result = self._scan_content(payload) + if result.block: + return result + return HookResult.allow(payloads[0]) + + def _scan_content( + self, + payload: HookPayload, + ) -> HookResult: + """Scan content for secrets using the SecretScanner.""" + # Short path: if there is no content, no need to do an API call + if payload.empty: + return HookResult.allow(payload) + + with create_message_only_scanner_ui() as scanner_ui: + results = self.scanner.scan([payload.scannable], scanner_ui=scanner_ui) + # Collect all secrets from results + secrets: List[Secret] = [] + for result in results.results: + secrets.extend(result.secrets) + + if not secrets: + return HookResult.allow(payload) + + message = self._message_from_secrets( + secrets, + payload, + escape_markdown=True, + ) + return HookResult( + block=True, + message=message, + nbr_secrets=len(secrets), + payload=payload, + ) + + @staticmethod + def _message_from_secrets( + secrets: List[Secret], + payload: HookPayload, + escape_markdown: bool = False, + ) -> str: + """ + Format detected secrets into a user-friendly message. 
+ + Args: + secrets: List of detected secrets + payload: Hook payload whose tool, event type and identifier select the closing advice + escape_markdown: If True, escape asterisks to prevent markdown interpretation + + Returns: + Formatted message describing the detected secrets + """ + count = len(secrets) + header = f"**🚨 Detected {count} {pluralize('secret', count)} 🚨**" + + secret_lines = [] + for secret in secrets: + validity = translate_validity(secret.validity).lower() + if validity == "valid": + validity = f"**{validity}**" + match_str = ", ".join(censor_match(m) for m in secret.matches) + if escape_markdown: + match_str = match_str.replace("*", "•") + secret_lines.append( + f" - {secret.detector_display_name} ({validity}): {match_str}" + ) + + if payload.tool == Tool.BASH: + if payload.event_type == EventType.POST_TOOL_USE: + message = "Secrets detected in the command output." + else: + message = ( + "Please remove the secrets from the command before executing it. " + "Consider using environment variables or a secrets manager instead." + ) + elif payload.tool == Tool.READ: + message = f"Please remove the secrets from {payload.identifier} before reading it." + elif payload.event_type == EventType.USER_PROMPT: + message = "Please remove the secrets from your prompt before submitting." + else: + message = ( + "Please remove the secrets from the tool input before executing. " + "Consider using environment variables or a secrets manager instead." + ) + + secrets_block = "\n".join(secret_lines) + return f"{header}\n{secrets_block}\n\n{message}" + + @staticmethod + def _send_secret_notification( + nbr_secrets: int, tool: Tool, agent_name: str + ) -> None: + """ + Send desktop notification when secrets are detected. 
+ + Args: + nbr_secrets: Number of detected secrets + tool: Tool through which the agent accessed the secrets + agent_name: Display name of the agent that got access to the secrets + """ + source = "using a tool" + if tool == Tool.READ: + source = "reading a file" + elif tool == Tool.BASH: + source = "running a command" + notification = Notify() + notification.title = "ggshield - Secrets Detected" + notification.message = ( + f"{agent_name} got access to {nbr_secrets}" + f" {pluralize('secret', nbr_secrets)} by {source}" + ) + notification.application_name = "ggshield" + try: + notification.send() + except Exception: + # This is best effort, we don't want to propagate an error + # if the notification fails. + pass diff --git a/ggshield/verticals/ai/installation.py b/ggshield/verticals/ai/installation.py index f624c1fc15..ca328b374f 100644 --- a/ggshield/verticals/ai/installation.py +++ b/ggshield/verticals/ai/installation.py @@ -80,11 +80,11 @@ def install_hooks( # Report what happened styled_path = click.style(settings_path, fg="yellow", bold=True) if stats.added == 0 and stats.already_present > 0: - click.echo(f"{agent.name} hooks already installed in {styled_path}") + click.echo(f"{agent.display_name} hooks already installed in {styled_path}") elif stats.added > 0 and stats.already_present > 0: - click.echo(f"{agent.name} hooks updated in {styled_path}") + click.echo(f"{agent.display_name} hooks updated in {styled_path}") else: - click.echo(f"{agent.name} hooks successfully added in {styled_path}") + click.echo(f"{agent.display_name} hooks successfully added in {styled_path}") return 0 diff --git a/ggshield/verticals/secret/ai_hook.py b/ggshield/verticals/secret/ai_hook.py deleted file mode 100644 index 332262a3bf..0000000000 --- a/ggshield/verticals/secret/ai_hook.py +++ /dev/null @@ -1,175 +0,0 @@ -from typing import List - -from notifypy import Notify - -from ggshield.core.filter import censor_match -from ggshield.core.scanner_ui import create_message_only_scanner_ui -from 
ggshield.core.text_utils import pluralize, translate_validity -from ggshield.verticals.ai.hooks import parse_hook_input -from ggshield.verticals.ai.models import EventType -from ggshield.verticals.ai.models import HookPayload as Payload -from ggshield.verticals.ai.models import HookResult as Result -from ggshield.verticals.ai.models import Tool -from ggshield.verticals.secret import SecretScanner -from ggshield.verticals.secret.secret_scan_collection import Secret - - -class AIHookScanner: - """AI hook scanner. - - It is called with the payload of a hook event. - Note that instead of having a base class with common method and a subclass per supported AI tool, - we instead have a single class which detects which protocol to use (called "flavor"). - This is because some tools sloppily support hooks from others. For instance, - Cursor will call hooks defined in the Claude Code format, but send payload in its own format. - So we can't assume which tool will call us based on the command line/hook configuration only. - - Raises: - ValueError: If the input is not valid. - """ - - def __init__(self, scanner: SecretScanner): - self.scanner = scanner - - def scan(self, content: str) -> int: - """Scan the content, print the result and return the exit code.""" - - payloads = parse_hook_input(content) - result = self._scan_payloads(payloads) - payload = result.payload - - # Special case: in post-tool use, the action is already done: at least notify the user - if result.block and payload.event_type == EventType.POST_TOOL_USE: - self._send_secret_notification( - result.nbr_secrets, payload.tool or Tool.OTHER, payload.agent.name - ) - - return payload.agent.output_result(result) - - def _scan_payloads(self, payloads: List[Payload]) -> Result: - """Scan payloads for secrets using the SecretScanner. - - Returns: - The result of the first blocking payload, or a non-blocking result. - Raises a ValueError if the list is empty (we must have at least one to emit a result). 
- """ - if not payloads: - raise ValueError("Error: no payloads to scan") - for payload in payloads: - result = self._scan_content(payload) - if result.block: - return result - return Result.allow(payloads[0]) - - def _scan_content( - self, - payload: Payload, - ) -> Result: - """Scan content for secrets using the SecretScanner.""" - # Short path: if there is no content, no need to do an API call - if payload.empty: - return Result.allow(payload) - - with create_message_only_scanner_ui() as scanner_ui: - results = self.scanner.scan([payload.scannable], scanner_ui=scanner_ui) - # Collect all secrets from results - secrets: List[Secret] = [] - for result in results.results: - secrets.extend(result.secrets) - - if not secrets: - return Result.allow(payload) - - message = self._message_from_secrets( - secrets, - payload, - escape_markdown=True, - ) - return Result( - block=True, - message=message, - nbr_secrets=len(secrets), - payload=payload, - ) - - @staticmethod - def _message_from_secrets( - secrets: List[Secret], payload: Payload, escape_markdown: bool = False - ) -> str: - """ - Format detected secrets into a user-friendly message. 
- - Args: - secrets: List of detected secrets - payload: Text to display after the secrets output - escape_markdown: If True, escape asterisks to prevent markdown interpretation - - Returns: - Formatted message describing the detected secrets - """ - count = len(secrets) - header = f"**🚨 Detected {count} {pluralize('secret', count)} 🚨**" - - secret_lines = [] - for secret in secrets: - validity = translate_validity(secret.validity).lower() - if validity == "valid": - validity = f"**{validity}**" - match_str = ", ".join(censor_match(m) for m in secret.matches) - if escape_markdown: - match_str = match_str.replace("*", "•") - secret_lines.append( - f" - {secret.detector_display_name} ({validity}): {match_str}" - ) - - if payload.tool == Tool.BASH: - if payload.event_type == EventType.POST_TOOL_USE: - message = "Secrets detected in the command output." - else: - message = ( - "Please remove the secrets from the command before executing it. " - "Consider using environment variables or a secrets manager instead." - ) - elif payload.tool == Tool.READ: - message = f"Please remove the secrets from {payload.identifier} before reading it." - elif payload.event_type == EventType.USER_PROMPT: - message = "Please remove the secrets from your prompt before submitting." - else: - message = ( - "Please remove the secrets from the tool input before executing. " - "Consider using environment variables or a secrets manager instead." - ) - - secrets_block = "\n".join(secret_lines) - return f"{header}\n{secrets_block}\n\n{message}" - - @staticmethod - def _send_secret_notification( - nbr_secrets: int, tool: Tool, agent_name: str - ) -> None: - """ - Send desktop notification when secrets are detected. 
- - Args: - nbr_secrets: Number of detected secrets - tool: Tool used to detect the secrets - agent_name: Name of the agent that detected the secrets - """ - source = "using a tool" - if tool == Tool.READ: - source = "reading a file" - elif tool == Tool.BASH: - source = "running a command" - notification = Notify() - notification.title = "ggshield - Secrets Detected" - notification.message = ( - f"{agent_name} got access to {nbr_secrets}" - f" {pluralize('secret', nbr_secrets)} by {source}" - ) - notification.application_name = "ggshield" - try: - notification.send() - except Exception: - # This is best effort, we don't want to propagate an error - # if the notification fails. - pass diff --git a/scripts/generate-import-linter-config.py b/scripts/generate-import-linter-config.py index 8bc41c5678..3bf75f2a76 100755 --- a/scripts/generate-import-linter-config.py +++ b/scripts/generate-import-linter-config.py @@ -64,6 +64,10 @@ class Contract(TypedDict): "ggshield.cmd.{}.** -> ggshield.verticals.{}.**", # FIXME: #521 - enforce boundaries between cmd.auth and verticals.hmsl "ggshield.cmd.auth.** -> ggshield.verticals.hmsl.**", + # Install command import logic to install AI hooks + "ggshield.cmd.install -> ggshield.verticals.ai.installation", + # AI hook command import logic to scan AI hook payloads + "ggshield.cmd.secret.scan.ai_hook -> ggshield.verticals.ai.hooks", ], "unmatched_ignore_imports_alerting": "none", }, diff --git a/tests/unit/verticals/ai/test_hooks.py b/tests/unit/verticals/ai/test_hooks.py index e63d6f0361..adf355774f 100644 --- a/tests/unit/verticals/ai/test_hooks.py +++ b/tests/unit/verticals/ai/test_hooks.py @@ -1,13 +1,18 @@ import json +from collections import Counter from pathlib import Path -from typing import Set +from typing import List, Set from unittest.mock import MagicMock, patch import pytest +from ggshield.utils.git_shell import Filemode from ggshield.verticals.ai.agents import Claude, Copilot, Cursor -from ggshield.verticals.ai.hooks 
import find_filepaths, parse_hook_input +from ggshield.verticals.ai.hooks import AIHookScanner, find_filepaths, parse_hook_input from ggshield.verticals.ai.models import EventType, HookPayload, HookResult, Tool +from ggshield.verticals.secret import SecretScanner +from ggshield.verticals.secret.secret_scan_collection import Result as ScanResult +from ggshield.verticals.secret.secret_scan_collection import Results, Secret def _dummy_payload(event_type: EventType = EventType.OTHER) -> HookPayload: @@ -28,6 +33,253 @@ def tmp_file(tmp_path: Path) -> Path: return file +def _mock_scanner(matches: List[str]) -> MagicMock: + """Create a mock SecretScanner that returns the given Results from scan().""" + mock = MagicMock(spec=SecretScanner) + scan_result = Results( + results=[ + ScanResult( + filename="url", + filemode=Filemode.FILE, + path=Path("."), + url="url", + secrets=[_make_secret(match) for match in matches], + ignored_secrets_count_by_kind=Counter(), + ) + ], + errors=[], + ) + mock.scan.return_value = scan_result + return mock + + +def _make_secret(match_str: str = "***"): + """Minimal Secret for tests; _message_from_secrets only uses detector_display_name, validity, matches[].match.""" + mock_match = MagicMock() + mock_match.match = match_str + return Secret( + detector_display_name="dummy-detector", + detector_name="dummy-detector", + detector_group_name=None, + documentation_url=None, + validity="valid", + known_secret=False, + incident_url=None, + matches=[mock_match], + ignore_reason=None, + diff_kind=None, + is_vaulted=False, + vault_type=None, + vault_name=None, + vault_path=None, + vault_path_count=None, + ) + + +class TestAIHookScannerScanContent: + """Unit tests for AIHookScanner._scan_content.""" + + def test_no_secrets_returns_allow(self): + """When scanner returns no secrets, result has block=False and nbr_secrets=0.""" + hook_scanner = AIHookScanner(_mock_scanner([])) + payload = HookPayload( + event_type=EventType.USER_PROMPT, + tool=None, + 
content="safe content", + identifier="id", + agent=Cursor(), + ) + result = hook_scanner._scan_content(payload) + assert isinstance(result, HookResult) + assert result.block is False + assert result.nbr_secrets == 0 + assert result.message == "" + + def test_with_secrets_returns_block_and_message(self): + """When scanner returns secrets, result has block=True, nbr_secrets and message set.""" + hook_scanner = AIHookScanner(_mock_scanner(["sk-xxx"])) + payload = HookPayload( + event_type=EventType.USER_PROMPT, + tool=None, + content="content with sk-xxx", + identifier="id", + agent=Cursor(), + ) + result = hook_scanner._scan_content(payload) + assert isinstance(result, HookResult) + assert result.block is True + assert result.nbr_secrets == 1 + assert "dummy-detector" in result.message + assert "secret" in result.message.lower() + assert "remove the secrets from your prompt" in result.message + + +class TestAIHookScannerScan: + """Unit tests for the AIHookScanner.scan() method.""" + + def test_empty_input_raises(self): + """Empty or whitespace-only input raises ValueError.""" + scanner = AIHookScanner(_mock_scanner([])) + with pytest.raises(ValueError, match="No input received on stdin"): + scanner.scan("") + with pytest.raises(ValueError, match="No input received on stdin"): + scanner.scan(" \n ") + + def test_scan_no_secrets_returns_zero(self): + """scan() with no secrets returns 0.""" + scanner = AIHookScanner(_mock_scanner([])) + data = { + "hook_event_name": "UserPromptSubmit", + "prompt": "hello world", + "transcript_path": "/home/user/.claude/projects/foo/session.jsonl", + "cursor_version": "1.2.3", + } + code = scanner.scan(json.dumps(data)) + assert code == 0 + + @patch("ggshield.verticals.ai.hooks.AIHookScanner._send_secret_notification") + def test_scan_post_tool_use_with_secrets_sends_notification( + self, mock_notify: MagicMock + ): + """scan() on POST_TOOL_USE with secrets sends a notification and returns 0 (no block).""" + scanner = 
AIHookScanner(_mock_scanner(["sk-xxx"])) + data = { + "hook_event_name": "PostToolUse", + "tool_name": "Bash", + "tool_input": {"command": "echo sk-xxx"}, + "tool_response": {"stdout": "sk-xxx\n"}, + "transcript_path": "/home/user/.claude/projects/foo/session.jsonl", + "session_id": "427ae0c5-0862-4e14-aa2c-12fad909c323", + } + code = scanner.scan(json.dumps(data)) + assert code == 0 + mock_notify.assert_called_once() + args = mock_notify.call_args[0] + assert args[0] == 1 # nbr_secrets + assert args[1] == Tool.BASH # tool + + def test_scan_pre_tool_use_with_secrets_blocks(self): + """scan() on PRE_TOOL_USE with secrets returns block result.""" + scanner = AIHookScanner(_mock_scanner(["sk-xxx"])) + data = { + "hook_event_name": "PreToolUse", + "tool_name": "Bash", + "tool_input": {"command": "echo sk-xxx"}, + "session_id": "427ae0c5-0862-4e14-aa2c-12fad909c323", + "transcript_path": "/home/user/.claude/projects/foo/session.jsonl", + } + code = scanner.scan(json.dumps(data)) + # Claude output_result always returns 0 + assert code == 0 + + def test_scan_no_content_returns_allow(self): + """scan() with no content returns 0 (and doesn't call the API).""" + mock_scanner = _mock_scanner([]) + scanner = AIHookScanner(mock_scanner) + data = { + "hook_event_name": "PreToolUse", + "tool_name": "Read", + "tool_input": {"file_path": "doesn-t-exist"}, + "cursor_version": "1.2.3", + } + code = scanner.scan(json.dumps(data)) + assert code == 0 + mock_scanner.scan.assert_not_called() + + def test_scan_payloads_refuse_empty_list(self): + """scan() with empty list of payloads raises ValueError.""" + scanner = AIHookScanner(_mock_scanner([])) + with pytest.raises(ValueError): + scanner._scan_payloads([]) + + +class TestMessageFromSecrets: + """Unit tests for AIHookScanner._message_from_secrets with different payload types.""" + + def test_message_for_bash_tool(self): + """Message for BASH tool mentions environment variables.""" + payload = HookPayload( + 
event_type=EventType.PRE_TOOL_USE, + tool=Tool.BASH, + content="echo sk-xxx", + identifier="echo sk-xxx", + agent=Cursor(), + ) + message = AIHookScanner._message_from_secrets([_make_secret("sk-xxx")], payload) + assert "remove the secrets from the command" in message + assert "environment variables" in message + + def test_message_for_read_tool(self): + """Message for READ tool mentions file content.""" + payload = HookPayload( + event_type=EventType.PRE_TOOL_USE, + tool=Tool.READ, + content="file content with secret", + identifier="/path/to/file", + agent=Cursor(), + ) + message = AIHookScanner._message_from_secrets([_make_secret("sk-xxx")], payload) + assert "remove the secrets from" in message + + def test_message_for_other_tool(self): + """Message for OTHER tool uses generic message.""" + payload = HookPayload( + event_type=EventType.PRE_TOOL_USE, + tool=Tool.OTHER, + content="some content", + identifier="id", + agent=Cursor(), + ) + message = AIHookScanner._message_from_secrets([_make_secret("sk-xxx")], payload) + assert "remove the secrets from the tool input" in message + + def test_message_escapes_markdown(self): + """When escape_markdown=True, asterisks in matches are replaced with dots.""" + payload = HookPayload( + event_type=EventType.USER_PROMPT, + tool=None, + content="content", + identifier="id", + agent=Cursor(), + ) + message = AIHookScanner._message_from_secrets( + [_make_secret("sk-xxx")], payload, escape_markdown=True + ) + # The message itself should not contain raw asterisks from matches + # (the header uses ** for bold which is intentional) + assert "Detected" in message + + +class TestSendSecretNotification: + """Unit tests for AIHookScanner._send_secret_notification.""" + + @patch("ggshield.verticals.ai.hooks.Notify") + def test_notification_for_bash_tool(self, mock_notify_cls: MagicMock): + """Notification for BASH tool says 'running a command'.""" + AIHookScanner._send_secret_notification(1, Tool.BASH, "Claude Code") + instance = 
mock_notify_cls.return_value + assert "running a command" in instance.message + assert "Claude Code" in instance.message + instance.send.assert_called_once() + + @patch("ggshield.verticals.ai.hooks.Notify") + def test_notification_for_read_tool(self, mock_notify_cls: MagicMock): + """Notification for READ tool says 'reading a file'.""" + AIHookScanner._send_secret_notification(2, Tool.READ, "Cursor") + instance = mock_notify_cls.return_value + assert "reading a file" in instance.message + assert "2" in instance.message + instance.send.assert_called_once() + + @patch("ggshield.verticals.ai.hooks.Notify") + def test_notification_for_other_tool(self, mock_notify_cls: MagicMock): + """Notification for OTHER tool says 'using a tool'.""" + AIHookScanner._send_secret_notification(1, Tool.OTHER, "Copilot") + instance = mock_notify_cls.return_value + assert "using a tool" in instance.message + instance.send.assert_called_once() + + class TestAIHookScannerParseInput: """Unit tests for AIHookparse_hook_input.""" diff --git a/tests/unit/verticals/secret/test_ai_hooks.py b/tests/unit/verticals/secret/test_ai_hooks.py deleted file mode 100644 index 5bfeca82cd..0000000000 --- a/tests/unit/verticals/secret/test_ai_hooks.py +++ /dev/null @@ -1,263 +0,0 @@ -import json -from collections import Counter -from pathlib import Path -from typing import List -from unittest.mock import MagicMock, patch - -import pytest - -from ggshield.utils.git_shell import Filemode -from ggshield.verticals.ai.agents import Cursor -from ggshield.verticals.secret import SecretScanner -from ggshield.verticals.secret.ai_hook import AIHookScanner, EventType, Payload -from ggshield.verticals.secret.ai_hook import Result as HookResult -from ggshield.verticals.secret.ai_hook import Tool -from ggshield.verticals.secret.secret_scan_collection import Result as ScanResult -from ggshield.verticals.secret.secret_scan_collection import Results, Secret - - -def _mock_scanner(matches: List[str]) -> MagicMock: - """Create 
a mock SecretScanner that returns the given Results from scan().""" - mock = MagicMock(spec=SecretScanner) - scan_result = Results( - results=[ - ScanResult( - filename="url", - filemode=Filemode.FILE, - path=Path("."), - url="url", - secrets=[_make_secret(match) for match in matches], - ignored_secrets_count_by_kind=Counter(), - ) - ], - errors=[], - ) - mock.scan.return_value = scan_result - return mock - - -def _make_secret(match_str: str = "***"): - """Minimal Secret for tests; _message_from_secrets only uses detector_display_name, validity, matches[].match.""" - mock_match = MagicMock() - mock_match.match = match_str - return Secret( - detector_display_name="dummy-detector", - detector_name="dummy-detector", - detector_group_name=None, - documentation_url=None, - validity="valid", - known_secret=False, - incident_url=None, - matches=[mock_match], - ignore_reason=None, - diff_kind=None, - is_vaulted=False, - vault_type=None, - vault_name=None, - vault_path=None, - vault_path_count=None, - ) - - -class TestAIHookScannerScanContent: - """Unit tests for AIHookScanner._scan_content.""" - - def test_no_secrets_returns_allow(self): - """When scanner returns no secrets, result has block=False and nbr_secrets=0.""" - hook_scanner = AIHookScanner(_mock_scanner([])) - payload = Payload( - event_type=EventType.USER_PROMPT, - tool=None, - content="safe content", - identifier="id", - agent=Cursor(), - ) - result = hook_scanner._scan_content(payload) - assert isinstance(result, HookResult) - assert result.block is False - assert result.nbr_secrets == 0 - assert result.message == "" - - def test_with_secrets_returns_block_and_message(self): - """When scanner returns secrets, result has block=True, nbr_secrets and message set.""" - hook_scanner = AIHookScanner(_mock_scanner(["sk-xxx"])) - payload = Payload( - event_type=EventType.USER_PROMPT, - tool=None, - content="content with sk-xxx", - identifier="id", - agent=Cursor(), - ) - result = hook_scanner._scan_content(payload) - 
assert isinstance(result, HookResult) - assert result.block is True - assert result.nbr_secrets == 1 - assert "dummy-detector" in result.message - assert "secret" in result.message.lower() - assert "remove the secrets from your prompt" in result.message - - -class TestAIHookScannerScan: - """Unit tests for the AIHookScanner.scan() method.""" - - def test_empty_input_raises(self): - """Empty or whitespace-only input raises ValueError.""" - scanner = AIHookScanner(_mock_scanner([])) - with pytest.raises(ValueError, match="No input received on stdin"): - scanner.scan("") - with pytest.raises(ValueError, match="No input received on stdin"): - scanner.scan(" \n ") - - def test_scan_no_secrets_returns_zero(self): - """scan() with no secrets returns 0.""" - scanner = AIHookScanner(_mock_scanner([])) - data = { - "hook_event_name": "UserPromptSubmit", - "prompt": "hello world", - "transcript_path": "/home/user/.claude/projects/foo/session.jsonl", - "cursor_version": "1.2.3", - } - code = scanner.scan(json.dumps(data)) - assert code == 0 - - @patch("ggshield.verticals.secret.ai_hook.AIHookScanner._send_secret_notification") - def test_scan_post_tool_use_with_secrets_sends_notification( - self, mock_notify: MagicMock - ): - """scan() on POST_TOOL_USE with secrets sends a notification and returns 0 (no block).""" - scanner = AIHookScanner(_mock_scanner(["sk-xxx"])) - data = { - "hook_event_name": "PostToolUse", - "tool_name": "Bash", - "tool_input": {"command": "echo sk-xxx"}, - "tool_response": {"stdout": "sk-xxx\n"}, - "transcript_path": "/home/user/.claude/projects/foo/session.jsonl", - "session_id": "427ae0c5-0862-4e14-aa2c-12fad909c323", - } - code = scanner.scan(json.dumps(data)) - assert code == 0 - mock_notify.assert_called_once() - args = mock_notify.call_args[0] - assert args[0] == 1 # nbr_secrets - assert args[1] == Tool.BASH # tool - - def test_scan_pre_tool_use_with_secrets_blocks(self): - """scan() on PRE_TOOL_USE with secrets returns block result.""" - scanner 
= AIHookScanner(_mock_scanner(["sk-xxx"])) - data = { - "hook_event_name": "PreToolUse", - "tool_name": "Bash", - "tool_input": {"command": "echo sk-xxx"}, - "session_id": "427ae0c5-0862-4e14-aa2c-12fad909c323", - "transcript_path": "/home/user/.claude/projects/foo/session.jsonl", - } - code = scanner.scan(json.dumps(data)) - # Claude output_result always returns 0 - assert code == 0 - - def test_scan_no_content_returns_allow(self): - """scan() with no content returns 0 (and doesn't call the API).""" - mock_scanner = _mock_scanner([]) - scanner = AIHookScanner(mock_scanner) - data = { - "hook_event_name": "PreToolUse", - "tool_name": "Read", - "tool_input": {"file_path": "doesn-t-exist"}, - "cursor_version": "1.2.3", - } - code = scanner.scan(json.dumps(data)) - assert code == 0 - mock_scanner.scan.assert_not_called() - - def test_scan_payloads_refuse_empty_list(self): - """scan() with empty list of payloads raises ValueError.""" - scanner = AIHookScanner(_mock_scanner([])) - with pytest.raises(ValueError): - scanner._scan_payloads([]) - - -class TestMessageFromSecrets: - """Unit tests for AIHookScanner._message_from_secrets with different payload types.""" - - def test_message_for_bash_tool(self): - """Message for BASH tool mentions environment variables.""" - payload = Payload( - event_type=EventType.PRE_TOOL_USE, - tool=Tool.BASH, - content="echo sk-xxx", - identifier="echo sk-xxx", - agent=Cursor(), - ) - message = AIHookScanner._message_from_secrets([_make_secret("sk-xxx")], payload) - assert "remove the secrets from the command" in message - assert "environment variables" in message - - def test_message_for_read_tool(self): - """Message for READ tool mentions file content.""" - payload = Payload( - event_type=EventType.PRE_TOOL_USE, - tool=Tool.READ, - content="file content with secret", - identifier="/path/to/file", - agent=Cursor(), - ) - message = AIHookScanner._message_from_secrets([_make_secret("sk-xxx")], payload) - assert "remove the secrets from" in 
message - - def test_message_for_other_tool(self): - """Message for OTHER tool uses generic message.""" - payload = Payload( - event_type=EventType.PRE_TOOL_USE, - tool=Tool.OTHER, - content="some content", - identifier="id", - agent=Cursor(), - ) - message = AIHookScanner._message_from_secrets([_make_secret("sk-xxx")], payload) - assert "remove the secrets from the tool input" in message - - def test_message_escapes_markdown(self): - """When escape_markdown=True, asterisks in matches are replaced with dots.""" - payload = Payload( - event_type=EventType.USER_PROMPT, - tool=None, - content="content", - identifier="id", - agent=Cursor(), - ) - message = AIHookScanner._message_from_secrets( - [_make_secret("sk-xxx")], payload, escape_markdown=True - ) - # The message itself should not contain raw asterisks from matches - # (the header uses ** for bold which is intentional) - assert "Detected" in message - - -class TestSendSecretNotification: - """Unit tests for AIHookScanner._send_secret_notification.""" - - @patch("ggshield.verticals.secret.ai_hook.Notify") - def test_notification_for_bash_tool(self, mock_notify_cls: MagicMock): - """Notification for BASH tool says 'running a command'.""" - AIHookScanner._send_secret_notification(1, Tool.BASH, "Claude Code") - instance = mock_notify_cls.return_value - assert "running a command" in instance.message - assert "Claude Code" in instance.message - instance.send.assert_called_once() - - @patch("ggshield.verticals.secret.ai_hook.Notify") - def test_notification_for_read_tool(self, mock_notify_cls: MagicMock): - """Notification for READ tool says 'reading a file'.""" - AIHookScanner._send_secret_notification(2, Tool.READ, "Cursor") - instance = mock_notify_cls.return_value - assert "reading a file" in instance.message - assert "2" in instance.message - instance.send.assert_called_once() - - @patch("ggshield.verticals.secret.ai_hook.Notify") - def test_notification_for_other_tool(self, mock_notify_cls: MagicMock): - 
"""Notification for OTHER tool says 'using a tool'.""" - AIHookScanner._send_secret_notification(1, Tool.OTHER, "Copilot") - instance = mock_notify_cls.return_value - assert "using a tool" in instance.message - instance.send.assert_called_once() From 62a67916c04f1bccdd63ad9648f085334141a66c Mon Sep 17 00:00:00 2001 From: Paul PETIT Date: Thu, 19 Mar 2026 15:14:54 +0100 Subject: [PATCH 3/3] feat(mcp): discovery get machine information monitor mcp activity monitor discovery time call ai discovery endpoint --- .importlinter | 5 +- ggshield/__main__.py | 2 + ggshield/cmd/ai/__init__.py | 12 + ggshield/cmd/ai/discover.py | 70 +++ ggshield/verticals/ai/__init__.py | 10 + ggshield/verticals/ai/agents/__init__.py | 7 +- ggshield/verticals/ai/agents/claude_code.py | 94 +++- ggshield/verticals/ai/agents/copilot.py | 49 +- ggshield/verticals/ai/agents/cursor.py | 175 ++++++- ggshield/verticals/ai/config.py | 75 +++ ggshield/verticals/ai/discovery.py | 137 ++++++ ggshield/verticals/ai/hooks.py | 16 +- ggshield/verticals/ai/mcp.py | 46 ++ ggshield/verticals/ai/models.py | 307 +++++++++++- ggshield/verticals/ai/user.py | 202 ++++++++ pyproject.toml | 4 + tests/unit/verticals/ai/test_agents.py | 486 +++++++++++++++++++ tests/unit/verticals/ai/test_cmd_ai.py | 154 ++++++ tests/unit/verticals/ai/test_config.py | 177 +++++++ tests/unit/verticals/ai/test_discovery.py | 313 ++++++++++++ tests/unit/verticals/ai/test_hooks.py | 7 + tests/unit/verticals/ai/test_mcp_activity.py | 122 +++++ tests/unit/verticals/ai/test_models.py | 323 ++++++++++++ tests/unit/verticals/ai/test_user.py | 222 +++++++++ uv.lock | 2 + 25 files changed, 3004 insertions(+), 13 deletions(-) create mode 100644 ggshield/cmd/ai/__init__.py create mode 100644 ggshield/cmd/ai/discover.py create mode 100644 ggshield/verticals/ai/config.py create mode 100644 ggshield/verticals/ai/discovery.py create mode 100644 ggshield/verticals/ai/mcp.py create mode 100644 ggshield/verticals/ai/user.py create mode 100644 
tests/unit/verticals/ai/test_agents.py create mode 100644 tests/unit/verticals/ai/test_cmd_ai.py create mode 100644 tests/unit/verticals/ai/test_config.py create mode 100644 tests/unit/verticals/ai/test_discovery.py create mode 100644 tests/unit/verticals/ai/test_mcp_activity.py create mode 100644 tests/unit/verticals/ai/test_models.py create mode 100644 tests/unit/verticals/ai/test_user.py diff --git a/.importlinter b/.importlinter index 3e5669e1a3..9308e69e03 100644 --- a/.importlinter +++ b/.importlinter @@ -9,7 +9,7 @@ name = ggshield-layers type = layers layers = ggshield.__main__ - ggshield.cmd.auth | ggshield.cmd.config | ggshield.cmd.hmsl | ggshield.cmd.honeytoken | ggshield.cmd.install | ggshield.cmd.plugin | ggshield.cmd.quota | ggshield.cmd.secret | ggshield.cmd.status | ggshield.cmd.utils + ggshield.cmd.ai | ggshield.cmd.auth | ggshield.cmd.config | ggshield.cmd.hmsl | ggshield.cmd.honeytoken | ggshield.cmd.install | ggshield.cmd.plugin | ggshield.cmd.quota | ggshield.cmd.secret | ggshield.cmd.status | ggshield.cmd.utils ggshield.verticals.ai | ggshield.verticals.auth | ggshield.verticals.hmsl | ggshield.verticals.secret ggshield.core click | ggshield.utils | pygitguardian @@ -22,6 +22,7 @@ unmatched_ignore_imports_alerting = warn name = verticals-cmd-transversals type = forbidden source_modules = + ggshield.cmd.ai ggshield.cmd.auth ggshield.cmd.config ggshield.cmd.hmsl @@ -38,6 +39,8 @@ forbidden_modules = ggshield.verticals.hmsl ggshield.verticals.secret ignore_imports = + ggshield.cmd.ai.** -> ggshield.verticals.ai + ggshield.cmd.ai.** -> ggshield.verticals.ai.** ggshield.cmd.auth.** -> ggshield.verticals.auth ggshield.cmd.auth.** -> ggshield.verticals.auth.** ggshield.cmd.auth.** -> ggshield.verticals.hmsl.** diff --git a/ggshield/__main__.py b/ggshield/__main__.py index 0baf27b776..0744cbd72f 100644 --- a/ggshield/__main__.py +++ b/ggshield/__main__.py @@ -10,6 +10,7 @@ import click from ggshield import __version__ +from ggshield.cmd.ai import 
ai_group from ggshield.cmd.auth import auth_group from ggshield.cmd.config import config_group from ggshield.cmd.hmsl import hmsl_group @@ -88,6 +89,7 @@ def _load_plugins() -> PluginRegistry: @click.group( context_settings={"help_option_names": ["-h", "--help"]}, commands={ + "ai": ai_group, "auth": auth_group, "config": config_group, "plugin": plugin_group, diff --git a/ggshield/cmd/ai/__init__.py b/ggshield/cmd/ai/__init__.py new file mode 100644 index 0000000000..5e0ee14f45 --- /dev/null +++ b/ggshield/cmd/ai/__init__.py @@ -0,0 +1,12 @@ +from typing import Any + +import click + +from ggshield.cmd.ai.discover import discover_cmd +from ggshield.cmd.utils.common_options import add_common_options + + +@click.group(commands={"discover": discover_cmd}) +@add_common_options() +def ai_group(**kwargs: Any) -> None: + """Commands to work with MCP (Model Context Protocol) servers.""" diff --git a/ggshield/cmd/ai/discover.py b/ggshield/cmd/ai/discover.py new file mode 100644 index 0000000000..f5916085bd --- /dev/null +++ b/ggshield/cmd/ai/discover.py @@ -0,0 +1,70 @@ +""" +MCP Discover command - Discovers MCP servers and optionally probes them +for tools, resources, and prompts. 
+""" + +import json +from typing import Any + +import click +from rich import print + +from ggshield.cmd.utils.common_options import add_common_options +from ggshield.cmd.utils.context_obj import ContextObj +from ggshield.core import ui +from ggshield.core.client import create_client_from_config +from ggshield.core.errors import APIKeyCheckError, UnknownInstanceError +from ggshield.verticals.ai.discovery import ( + discover_ai_configuration, + save_discovery_cache, + submit_ai_discovery, +) + + +@click.command(name="discover") +@click.option( + "--json", + "use_json", + is_flag=True, + default=False, + help="Output as JSON", +) +@add_common_options() +@click.pass_context +def discover_cmd( + ctx: click.Context, + use_json: bool, + **kwargs: Any, +) -> None: + """ + Discover MCP servers and their configuration. + + Parses MCP configuration files from supported assistants. + + Examples: + ggshield ai discover + ggshield ai discover --json + """ + + config = discover_ai_configuration() + + if use_json: + click.echo(json.dumps(config.model_dump(mode="json"), indent=2)) + else: + print(config) + + ctx_obj = ContextObj.get(ctx) + try: + client = create_client_from_config(ctx_obj.config) + except (APIKeyCheckError, UnknownInstanceError) as exc: + ui.display_warning( + f"Skipping upload of AI discovery to GitGuardian ({exc}). " + "Authenticate with `ggshield auth login` to enable upload."
+ ) + return + + try: + config = submit_ai_discovery(client, config) + save_discovery_cache(config) + except Exception as exc: + ui.display_warning(f"Could not upload AI discovery to GitGuardian: {exc}") diff --git a/ggshield/verticals/ai/__init__.py b/ggshield/verticals/ai/__init__.py index cbd7e4db0c..8ca3b1590d 100644 --- a/ggshield/verticals/ai/__init__.py +++ b/ggshield/verticals/ai/__init__.py @@ -1,4 +1,10 @@ from .agents import AGENTS +from .config import load_mcp_config +from .discovery import ( + discover_ai_configuration, + load_discovery_cache, + save_discovery_cache, +) from .hooks import AIHookScanner from .installation import install_hooks @@ -6,5 +12,9 @@ __all__ = [ "AGENTS", "AIHookScanner", + "discover_ai_configuration", "install_hooks", + "load_discovery_cache", + "load_mcp_config", + "save_discovery_cache", ] diff --git a/ggshield/verticals/ai/agents/__init__.py b/ggshield/verticals/ai/agents/__init__.py index 7289463cc7..6e354d473a 100644 --- a/ggshield/verticals/ai/agents/__init__.py +++ b/ggshield/verticals/ai/agents/__init__.py @@ -1,9 +1,14 @@ +from typing import Dict + +from ..models import Agent from .claude_code import Claude from .copilot import Copilot from .cursor import Cursor -AGENTS = {agent.name: agent for agent in [Cursor(), Claude(), Copilot()]} +AGENTS: Dict[str, Agent] = { + agent.name: agent for agent in [Cursor(), Claude(), Copilot()] +} __all__ = ["AGENTS", "Claude", "Copilot", "Cursor"] diff --git a/ggshield/verticals/ai/agents/claude_code.py b/ggshield/verticals/ai/agents/claude_code.py index b9ea7670f1..27e28c6b44 100644 --- a/ggshield/verticals/ai/agents/claude_code.py +++ b/ggshield/verticals/ai/agents/claude_code.py @@ -1,10 +1,22 @@ import json +import re from pathlib import Path -from typing import Any, Dict, List, Optional +from typing import Any, Dict, Iterator, List, Optional import click -from ..models import Agent, EventType, HookResult +from ggshield.core.dirs import get_user_home_dir + +from ..models import 
( + Agent, + AIDiscovery, + EventType, + HookPayload, + HookResult, + MCPActivityRequest, + MCPConfiguration, + Scope, +) class Claude(Agent): @@ -18,6 +30,10 @@ def name(self) -> str: def display_name(self) -> str: return "Claude Code" + @property + def config_folder(self) -> Path: + return get_user_home_dir() / ".claude" + def output_result(self, result: HookResult) -> int: response = {} if result.block: @@ -106,3 +122,77 @@ def settings_locate( if "ggshield" in command or "" in command: return obj return None + + def project_mcp_file(self, directory: Path) -> Path: + return directory / ".mcp.json" + + def _get_user_mcp_configurations(self) -> Iterator[MCPConfiguration]: + """Look into ~/.claude.json for both user-level and project-level MCP server entries.""" + # Load config file + filepath = get_user_home_dir() / ".claude.json" + if not (data := self._load_json_file(filepath)): + return + + # User-level mcpServers + yield from self._parse_servers_block(data, Scope.USER, None) + + # Per-project entries in projects dict + projects = data.get("projects", {}) + if not isinstance(projects, dict): + return + for project_key, project_data in projects.items(): + if not isinstance(project_data, dict): + continue + yield from self._parse_servers_block( + project_data, Scope.USER, Path(project_key) + ) + + def discover_project_directories(self) -> Iterator[Path]: + """Discover project directories by scraping config files.""" + history_file = self.config_folder / "history.jsonl" + projects = set() + for line in self._load_jsonl_file(history_file): + if "project" in line: + projects.add(Path(line["project"])) + for project in projects: + if project.is_dir(): + yield project.resolve() + + def parse_mcp_activity( + self, payload: HookPayload, ai_config: AIDiscovery + ) -> MCPActivityRequest: + """Parse the MCP activity from an MCP hook payload.""" + + # Claude Code's hook tool name is "mcp__{server}__{tool}" + raw_tool_name: str = payload.raw.get("tool_name", "") + parts = 
raw_tool_name.split("__") + # The server name can be anything, but we assume no MCP tool has a "__" in its name + tool = parts[-1] + server_cfg_name = "__".join(parts[1:-1]) + + # Lookup the server name based on its configuration name + # Fallback to the server name if not found + server_name = server_cfg_name + for server in ai_config.servers: + for configuration in server.configurations: + if _mangle_server_name(configuration.name) == server_cfg_name: + server_name = server.name + break + + return MCPActivityRequest( + user=ai_config.user, + tool=tool, + server=server_name, + agent=self.name, + model="", + cwd=Path(payload.raw.get("cwd", "")), + input=payload.raw.get("tool_input", {}), + ) + + +MANGLING_PATTERN = re.compile(r"[^A-Za-z0-9-]") + + +def _mangle_server_name(name: str) -> str: + """Mangle a server name in the same way Claude Code does.""" + return MANGLING_PATTERN.sub("_", name) diff --git a/ggshield/verticals/ai/agents/copilot.py b/ggshield/verticals/ai/agents/copilot.py index 300a8ce132..6e846b402a 100644 --- a/ggshield/verticals/ai/agents/copilot.py +++ b/ggshield/verticals/ai/agents/copilot.py @@ -1,9 +1,12 @@ import json from pathlib import Path +from typing import Iterator import click -from ..models import EventType, HookResult +from ggshield.core.dirs import get_user_home_dir + +from ..models import AIDiscovery, EventType, HookPayload, HookResult, MCPActivityRequest from .claude_code import Claude @@ -21,6 +24,10 @@ def name(self) -> str: def display_name(self) -> str: return "Copilot Chat" + @property + def config_folder(self) -> Path: + return get_user_home_dir() / ".config" / "Code" / "User" + def output_result(self, result: HookResult) -> int: response = {} if result.block: @@ -45,3 +52,43 @@ def output_result(self, result: HookResult) -> int: @property def settings_path(self) -> Path: return Path(".github") / "hooks" / "hooks.json" + + def project_mcp_file(self, directory: Path) -> Path: + return directory / ".vscode" / "mcp.json" + + def 
discover_project_directories(self) -> Iterator[Path]: + # Try to parse workspaces settings. + for file in self.config_folder.glob("workspaceStorage/*/workspace.json"): + if (data := self._load_json_file(file)) and "folder" in data: + path = Path(data["folder"].removeprefix("file://")) + if path.is_dir(): + yield path.resolve() + + def parse_mcp_activity( + self, payload: HookPayload, ai_config: AIDiscovery + ) -> MCPActivityRequest: + """Parse the MCP activity from an MCP hook payload.""" + + # Copilot's hook tool name is "mcp_{server}_{tool}" + # which is unfortunate because a lot of tools have a "_" in their name. + raw_tool_name: str = payload.raw.get("tool_name", "") + server_cfg_name, tool = raw_tool_name.split("_") + + # Lookup the server name based on its configuration name + # Fallback to the server name if not found + server_name = server_cfg_name + for server in ai_config.servers: + for configuration in server.configurations: + if configuration.name == server_cfg_name: + server_name = configuration.name + break + + return MCPActivityRequest( + user=ai_config.user, + tool=tool, + server=server_name, + agent=self.name, + model="", + cwd=Path(payload.raw.get("cwd", "")), + input=payload.raw.get("tool_input", {}), + ) diff --git a/ggshield/verticals/ai/agents/cursor.py b/ggshield/verticals/ai/agents/cursor.py index c75e807380..90484bc261 100644 --- a/ggshield/verticals/ai/agents/cursor.py +++ b/ggshield/verticals/ai/agents/cursor.py @@ -1,10 +1,24 @@ import json from pathlib import Path -from typing import Any, Dict, List, Optional +from typing import Any, Dict, Iterator, List, Optional import click -from ..models import Agent, EventType, HookResult +from ggshield.core.dirs import get_user_home_dir + +from ..models import ( + Agent, + AIDiscovery, + EventType, + HookPayload, + HookResult, + MCPActivityRequest, + MCPArgumentInfo, + MCPPromptInfo, + MCPResourceInfo, + MCPServer, + MCPToolInfo, +) class Cursor(Agent): @@ -18,6 +32,10 @@ def name(self) -> str: 
def discover_project_directories(self) -> Iterator[Path]:
    """Yield the project directories recorded in Cursor's workspace storage.

    Each ``workspaceStorage/*/workspace.json`` file under
    ``~/.config/Cursor/User`` is read; its "folder" entry is a ``file://``
    URI. The URI is percent-decoded (a folder named "my app" is stored as
    "my%20app") before being turned into a path. Entries that are missing,
    not strings, or that do not point to an existing directory are skipped.

    Yields:
        Resolved paths of existing project directories.
    """
    from urllib.parse import unquote

    # Because Cursor is based on VS Code, the layout is the same as Copilot's.
    user_folder = get_user_home_dir() / ".config" / "Cursor" / "User"
    for file in user_folder.glob("workspaceStorage/*/workspace.json"):
        data = self._load_json_file(file)
        folder = data.get("folder") if data else None
        if not isinstance(folder, str):
            continue
        path = Path(unquote(folder.removeprefix("file://")))
        if path.is_dir():
            yield path.resolve()
+ # (each subfolder corresponds to a different MCP server) + for file in folder.glob("*/SERVER_METADATA.json"): + metadata = self._load_json_file(file) + if metadata and metadata.get("serverName") == configuration.name: + # Found it! Update the folder + folder = file.parent + break + else: + # We didn't find our MCP server's metadata. Try next configuration. + continue + + # If we reach this code, we found our MCP server's metadata folder. + # Hopefully it is connected. If not, Cursor creates a STATUS.md file. + if (folder / "STATUS.md").exists(): + # Don't go further, we may risk discovering only an "mcp_auth" tool + # whereas the MCP server may be properly connected in another project. + continue + + filled = False + # Tools + for file in folder.glob("tools/*.json"): + tool = self._load_json_file(file) + if not isinstance(tool, dict) or "name" not in tool: + continue + server.tools.append( + MCPToolInfo( + name=tool["name"], + description=tool.get("description", ""), + arguments=_parse_tool_arguments(tool.get("arguments")), + ) + ) + filled = True + # Resources + for file in folder.glob("resources/*.json"): + resource = self._load_json_file(file) + if not isinstance(resource, dict) or "uri" not in resource: + continue + server.resources.append( + MCPResourceInfo( + uri=resource["uri"], + name=resource.get("name", ""), + description=resource.get("description", ""), + mime_type=resource.get("mimeType", ""), + ) + ) + filled = True + # Prompts + for file in folder.glob("prompts/*.json"): + prompt = self._load_json_file(file) + if not isinstance(prompt, dict) or "name" not in prompt: + continue + server.prompts.append( + MCPPromptInfo( + name=prompt["name"], description=prompt.get("description", "") + ) + ) + filled = True + if filled: + # Discovery done. Early return. 
+ return True + + return False + + def parse_mcp_activity( + self, payload: HookPayload, ai_config: AIDiscovery + ) -> MCPActivityRequest: + """Parse the MCP activity from an MCP hook payload.""" + + # Cursor only sends the MCP tool, not the server. + # Fortunately, we should have been able to discover the tools earlier. + + tools_to_server = {} + for server in ai_config.servers: + for tool in server.tools: + # Hopefully we won't have duplicates + tools_to_server[tool.name] = server.name + + raw_tool_name: str = payload.raw.get("tool_name", "") + tool_name = raw_tool_name.removeprefix("MCP:") + server_name = tools_to_server.get(tool_name, "") + + return MCPActivityRequest( + user=ai_config.user, + tool=tool_name, + server=server_name, + agent=self.name, + model=payload.raw.get("model", ""), + cwd=Path(payload.raw.get("workspace_roots", [""])[0]), + input=payload.raw.get("tool_input", {}), + ) + + +def _parse_tool_arguments( + schema: Optional[Dict[str, Any]], +) -> Optional[List[MCPArgumentInfo]]: + """Parse a JSON-Schema ``arguments`` object into a list of MCPArgumentInfo. + + The schema is expected to follow the standard MCP tool descriptor format:: + + {"type": "object", "properties": {...}, "required": [...]} + """ + if not isinstance(schema, dict): + return None + properties = schema.get("properties") + if not isinstance(properties, dict): + return None + required_set = set(schema.get("required", [])) + arguments: List[MCPArgumentInfo] = [] + for name, prop in properties.items(): + if not isinstance(prop, dict): + continue + arguments.append( + MCPArgumentInfo( + name=name, + type=prop.get("type", "string"), + description=prop.get("description"), + required=name in required_set, + ) + ) + return arguments or None diff --git a/ggshield/verticals/ai/config.py b/ggshield/verticals/ai/config.py new file mode 100644 index 0000000000..9ebfa96c66 --- /dev/null +++ b/ggshield/verticals/ai/config.py @@ -0,0 +1,75 @@ +""" +MCP configuration loading utilities. 
def extract_host_from_config(server_config: Optional[Dict[str, Any]]) -> Optional[str]:
    """Return the host an MCP server configuration points to, if it can be found.

    Lookup order:

    1. the ``CLICKHOUSE_HOST`` environment entry, returned as is;
    2. the ``GITLAB_API_URL`` environment entry, reduced to its host part;
    3. the first string in ``args`` starting with "https://" or "http://",
       reduced to its host part.

    The "host part" is whatever sits between the scheme and the first "/",
    so a port is kept. ``env`` and ``args`` set to ``null`` in the
    configuration are treated as absent.

    Args:
        server_config: one entry of an MCP servers block, or None.

    Returns:
        The host, or None when nothing matches.
    """
    if not server_config:
        return None

    # `or` also covers an explicit `"env": null` / `"args": null`.
    env_vars = server_config.get("env") or {}

    if "CLICKHOUSE_HOST" in env_vars:
        return env_vars["CLICKHOUSE_HOST"]

    if "GITLAB_API_URL" in env_vars:
        url = env_vars["GITLAB_API_URL"]
        if isinstance(url, str):
            return url.replace("https://", "").replace("http://", "").split("/")[0]

    for arg in server_config.get("args") or []:
        if not isinstance(arg, str):
            continue
        if arg.startswith("https://"):
            return arg.replace("https://", "").split("/")[0]
        if arg.startswith("http://"):
            return arg.replace("http://", "").split("/")[0]

    return None
def refresh_and_maybe_submit_discovery(client: GGClient) -> AIDiscovery:
    """Always run discovery, compare with cache, submit only if changed.

    The machine id stored in the cache (if any) is reused so that the same
    machine keeps the same identifier across runs.

    Submission is best-effort: if sending the discovery or saving the cache
    fails, the failure is logged at debug level and the locally computed
    discovery is returned.

    Args:
        client: GitGuardian API client used to submit the discovery.

    Returns:
        The cached discovery when nothing changed, the API-enriched discovery
        when submission succeeded, the local discovery otherwise.
    """
    import logging

    cached = load_discovery_cache()
    # If we already have a machine id, reuse it.
    machine_id = cached.user.machine_id if cached else None
    discovery = discover_ai_configuration(machine_id=machine_id)

    # Nothing changed: the cache already holds the API view of this discovery.
    if cached and not discovery.has_changed_from(cached):
        return cached

    try:
        # Get the updated version of the discovery, filled with data from the API.
        discovery = submit_ai_discovery(client, discovery)
        save_discovery_cache(discovery)
    except Exception:
        # Deliberately broad: discovery must never break the calling hook.
        # TODO(paul): decide whether this should be surfaced to the user.
        logging.getLogger(__name__).debug(
            "Failed to submit AI discovery", exc_info=True
        )

    return discovery
def load_discovery_cache() -> Optional[AIDiscovery]:
    """Load the discovery cache if it exists and is usable.

    Returns:
        The cached discovery, or None when the cache file does not exist,
        cannot be read, is not valid JSON, or no longer matches the
        ``AIDiscovery`` schema (for instance after a model change).
    """
    # pydantic reports both malformed JSON and schema mismatches from
    # `model_validate_json` as ValidationError, not json.JSONDecodeError.
    from pydantic import ValidationError

    cache_path = get_cache_dir() / AI_DISCOVERY_CACHE_FILENAME
    if not cache_path.exists():
        return None
    try:
        return AIDiscovery.model_validate_json(cache_path.read_text())
    except (OSError, UnicodeDecodeError, json.JSONDecodeError, ValidationError):
        # An unusable cache is the same as no cache: discovery is resubmitted.
        return None
def send_mcp_activity(client: GGClient, payload: HookPayload) -> MCPActivityResponse:
    """Build the MCP activity request and send it to the GitGuardian API.

    Args:
        client: GitGuardian API client (same instance as secret scans).
        payload: Hook payload for the MCP pre-tool event.

    Returns:
        Policy response from the API, or allow if the request fails (fail-open).
        "Fails" covers connection errors, timeouts (including read timeouts,
        which are not connection errors), non-2xx responses and response
        bodies that do not match ``MCPActivityResponse``.
    """
    from requests.exceptions import Timeout

    ai_config = refresh_and_maybe_submit_discovery(client)
    request = payload.agent.parse_mcp_activity(payload, ai_config)
    body = request.model_dump(mode="json")

    try:
        response = client.post(
            endpoint=MCP_ACTIVITY_SUBMIT_ENDPOINT,
            data=body,
        )
    except (RequestsConnectionError, Timeout):
        return _mcp_activity_fail_open()

    if not response.ok:
        return _mcp_activity_fail_open()

    try:
        return MCPActivityResponse.model_validate_json(response.text)
    except ValidationError:
        return _mcp_activity_fail_open()
class MCPConfiguration(BaseModel):
    """One MCP server entry as declared in an assistant's configuration file.

    The same server may be declared several times (by different agents, at
    user or project level); each declaration is one MCPConfiguration.
    """

    # Name under which the server is declared in the configuration file.
    name: str
    # Name of the agent whose configuration file declares the server.
    agent: str
    # Whether the declaration is user-level or project-level.
    scope: Scope
    # How the assistant talks to the server.
    transport: Transport
    # Project directory for project-level declarations, None otherwise.
    project: Optional[Path] = None
    # stdio fields
    command: Optional[str] = None
    args: List[str] = Field(default_factory=list)
    env: Dict[str, str] = Field(default_factory=dict)
    # remote fields
    url: Optional[str] = None
    headers: Dict[str, str] = Field(default_factory=dict)

    @property
    def key(self) -> ConfigurationKey:
        """Return a unique key for the configuration.

        A configuration is considered unique over (agent, scope, project, name).
        """
        return (self.agent, self.scope, self.project, self.name)
class AIDiscovery(BaseModel):
    """Result of one discovery run: who ran it and which MCP servers were found."""

    user: UserInfo
    servers: List[MCPServer] = Field(default_factory=list)
    # Metadata for analytics
    discovery_duration: float  # in s

    @property
    def configurations(self) -> Iterator[MCPConfiguration]:
        """Iterate over all MCP configurations, server by server."""
        for server in self.servers:
            yield from server.configurations

    def has_changed_from(self, other: "AIDiscovery") -> bool:
        """Check if the discovery has changed since a previous discovery.

        ``self`` is the fresh discovery, ``other`` the previous one (typically
        the cached, API-enriched version). ``discovery_duration`` is not part
        of the comparison.

        Returns:
            True if the user info differs, if the set of configurations (or
            the content of any of them) differs, or if a server of ``self``
            has capabilities that none of the matching servers of ``other``
            knows about. False otherwise.
        """
        # We compare:
        # 1. user info exactly
        if self.user != other.user:
            return True

        # 2. MCP configurations should be the same (both in number and content)
        other_configurations = {conf.key: conf for conf in other.configurations}
        new_configurations = {conf.key: conf for conf in self.configurations}
        if other_configurations != new_configurations:
            return True

        # 3. Servers may have been overridden by GIM, but we still want to detect
        # whether we discovered new capabilities unknown to GIM.
        # First, build a map to find the server(s) to compare to
        # (we know that the keys will be exactly our configurations, thanks to step 2)
        other_servers: Dict[ConfigurationKey, MCPServer] = {}
        for server in other.servers:
            for configuration in server.configurations:
                other_servers[configuration.key] = server
        # Then, for each server we found, check if we have capabilities unknown to GIM
        for server in self.servers:
            # No data, no need to compare
            if not server.tools and not server.resources and not server.prompts:
                continue
            # We don't know how our naive deduplication has been overridden by GIM,
            # so we need to compare capabilities to each distinct destination.
            candidates_by_name: Dict[str, MCPServer] = {}
            for conf in server.configurations:
                # Safe lookup: step 2 guarantees every key of ours exists in `other`.
                other_server = other_servers[conf.key]
                candidates_by_name[other_server.name] = other_server
            # If at least one has our capabilities, then no need to update.
            # Said otherwise, update if all candidates lack some of our capabilities.
            # NOTE(review): `all()` over no candidates is True, so a server with
            # capabilities but no configuration reports a change - confirm intended.
            if all(
                server.has_capabilities_unknown_to(candidate)
                for candidate in candidates_by_name.values()
            ):
                return True

        return False
def _parse_servers_block(
    self,
    data: Dict[str, Dict[str, Any]],
    scope: Scope,
    project: Optional[Path],
) -> Iterator[MCPConfiguration]:
    """Parse an MCP servers block and yield one configuration per entry.

    The block is looked up under "mcpServers" first, then "servers". An entry
    with a "url" is remote: it is SSE when its "transport" or "type" field is
    "sse", HTTP otherwise. An entry without "url" is a stdio server.

    The data comes from user-edited files, so a block or an entry that is
    not a JSON object is ignored, and "args" / "env" / "headers" set to
    ``null`` are treated as empty.

    Args:
        data: parsed content of the configuration file.
        scope: scope to record on every yielded configuration.
        project: project directory to record, or None for user-level files.

    Yields:
        One MCPConfiguration per server entry, tagged with this agent's name.
    """
    # Lookup the two usual conventions
    servers = data.get("mcpServers", data.get("servers", {}))
    if not isinstance(servers, dict):
        return

    for name, entry in servers.items():
        if not isinstance(entry, dict):
            continue

        if "url" in entry:
            # Some files declare the transport under "type", not "transport".
            declared = entry.get("transport") or entry.get("type")
            transport = Transport.SSE if declared == "sse" else Transport.HTTP
        else:
            transport = Transport.STDIO

        yield MCPConfiguration(
            name=name,
            agent=self.name,
            scope=scope,
            transport=transport,
            project=project,
            command=entry.get("command"),
            args=entry.get("args") or [],
            env=entry.get("env") or {},
            url=entry.get("url"),
            headers=entry.get("headers") or {},
        )
def discover_mcp_configurations(
    self, directories: Iterable[Path]
) -> List[MCPConfiguration]:
    """Collect the MCP configurations declared for this agent.

    User-level configurations come first, followed by the project-level
    configurations of each entry of *directories*, in iteration order.

    Args:
        directories: project directories to inspect.

    Returns:
        All configurations found, user-level ones first.
    """
    # User-level entries are fully collected before any project is visited.
    user_level = list(self._get_user_mcp_configurations())

    project_level = [
        found
        for folder in directories
        for found in self._get_project_mcp_configurations(folder)
    ]

    return user_level + project_level
+ """ + return False + + @abstractmethod + def parse_mcp_activity( + self, payload: HookPayload, ai_config: AIDiscovery + ) -> MCPActivityRequest: + """Parse the MCP activity from an MCP hook payload.""" + + # Helper methods + + def _load_json_file(self, path: Path) -> Optional[Dict[str, Any]]: + """Load a JSON file and return the data, or None if the file doesn't exist (or is not a JSON object).""" + if not path.is_file(): + return None + try: + data = json.loads(path.read_text()) + except (OSError, json.JSONDecodeError): + return None + if not isinstance(data, dict): + return None + return data + + def _load_jsonl_file(self, path: Path) -> Iterator[Dict[str, Any]]: + """Load a JSONL file and return the data line by line, + or nothing if the file doesn't exist (or is not a JSON object).""" + if not path.is_file(): + yield from [] + try: + for line in open(path, "r"): + try: + yield json.loads(line) + except json.JSONDecodeError: + continue + except OSError: + yield from [] diff --git a/ggshield/verticals/ai/user.py b/ggshield/verticals/ai/user.py new file mode 100644 index 0000000000..c7db72580c --- /dev/null +++ b/ggshield/verticals/ai/user.py @@ -0,0 +1,202 @@ +""" +This module provides information about the current user and machine. + +It should be eventually merged with the logic used by local scanning. 
+""" + +import getpass +import logging +import os +import platform +import re +import socket +import subprocess +import sys +import uuid +from pathlib import Path +from typing import Optional + +from ggshield.core.dirs import get_user_home_dir +from ggshield.verticals.ai.models import UserInfo + + +logger = logging.getLogger(__name__) + +_MAC_IOREG_UUID_RE = re.compile(r'"IOPlatformUUID"\s*=\s*"([^"]+)"') + + +def get_user_info(machine_id: Optional[str] = None) -> UserInfo: + """Collect hostname, username, machine identifier, and best-effort email.""" + return UserInfo( + hostname=_get_hostname(), + username=_get_username(), + machine_id=machine_id or _get_machine_id(), + user_email=_get_user_email(), + ) + + +def _get_hostname() -> str: + if sys.platform == "win32": + name = (os.environ.get("COMPUTERNAME") or "").strip() + if name: + return name + try: + return socket.gethostname() or "unknown" + except OSError: + return "unknown" + + +def _get_username() -> str: + try: + return getpass.getuser() + except Exception: + pass + try: + return os.getlogin() + except Exception: + return "unknown" + + +def _get_user_email() -> Optional[str]: + """Best-effort user email; tries `git config user.email` first.""" + try: + result = subprocess.run( + ["git", "config", "user.email"], + capture_output=True, + text=True, + check=False, + ) + if result.returncode == 0: + email = (result.stdout or "").strip() + return email or None + except (OSError, subprocess.SubprocessError): + pass + return None + + +def _read_first_nonempty_line(path: Path) -> Optional[str]: + try: + text = path.read_text(encoding="utf-8", errors="replace").splitlines() + except OSError: + return None + for line in text: + stripped = line.strip() + if stripped: + return stripped + return None + + +def _get_linux_system_id() -> Optional[str]: + for candidate in ( + Path("/etc/machine-id"), + Path("/sys/class/dmi/id/product_uuid"), + Path("/var/lib/dbus/machine-id"), + ): + value = 
_read_first_nonempty_line(candidate) + if value: + return value + return None + + +def _get_macos_system_id() -> Optional[str]: + try: + result = subprocess.run( + ["ioreg", "-rd1", "-c", "IOPlatformExpertDevice"], + capture_output=True, + text=True, + check=False, + ) + if result.returncode != 0 or not result.stdout: + return None + match = _MAC_IOREG_UUID_RE.search(result.stdout) + if match: + return match.group(1).strip() + except (OSError, subprocess.SubprocessError): + pass + return None + + +def _parse_wmic_uuid(stdout: str) -> Optional[str]: + for line in stdout.splitlines(): + line = line.strip() + if not line or line.upper() == "UUID": + continue + try: + return str(uuid.UUID(line)) + except ValueError: + pass + return None + + +def _get_windows_system_id() -> Optional[str]: + try: + result = subprocess.run( + ["wmic", "csproduct", "get", "uuid"], + capture_output=True, + text=True, + check=False, + ) + if result.returncode == 0 and result.stdout: + parsed = _parse_wmic_uuid(result.stdout) + if parsed: + return parsed + except (OSError, subprocess.SubprocessError): + pass + + try: + result = subprocess.run( + [ + "powershell", + "-NoProfile", + "-Command", + "(Get-CimInstance Win32_ComputerSystemProduct).UUID", + ], + capture_output=True, + text=True, + check=False, + ) + if result.returncode == 0 and result.stdout: + try: + return str(uuid.UUID(result.stdout.strip())) + except ValueError: + pass + except (OSError, subprocess.SubprocessError): + pass + return None + + +def _get_machine_id() -> str: + + # In case Satori generated a machine id, use it. 
+ path = get_user_home_dir() / ".satori" / "machine_id" + try: + if path.is_file(): + cached = _read_first_nonempty_line(path) + if cached: + return cached + except OSError: + pass + + system = platform.system().lower() + system_id = None + + if system == "darwin": + system_id = _get_macos_system_id() + elif system == "linux": + system_id = _get_linux_system_id() + elif sys.platform == "win32": + system_id = _get_windows_system_id() + + if system_id: + return system_id + + # If everything failed, use a random UUID. + # Store it so that satori can use it. + new_id = str(uuid.uuid4()) + try: + path.parent.mkdir(parents=True, exist_ok=True) + path.write_text(new_id + "\n") + except OSError: + pass + + return new_id diff --git a/pyproject.toml b/pyproject.toml index a3540af8fb..e345eab353 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -49,6 +49,7 @@ dependencies = [ "urllib3>=2.2.2,<3", "truststore>=0.10.1; python_version >= \"3.10\"", "notify-py>=0.3.43", + "pydantic>=2.11.10", ] [project.urls] @@ -61,6 +62,9 @@ ggshield = "ggshield.__main__:main" requires = ["hatchling"] build-backend = "hatchling.build" +[[tool.uv.index]] +url = "https://gitlab.gitguardian.ovh/api/v4/projects/435/packages/pypi/simple" + [tool.hatch.version] path = "ggshield/__init__.py" diff --git a/tests/unit/verticals/ai/test_agents.py b/tests/unit/verticals/ai/test_agents.py new file mode 100644 index 0000000000..3361929529 --- /dev/null +++ b/tests/unit/verticals/ai/test_agents.py @@ -0,0 +1,486 @@ +import json +from pathlib import Path +from typing import Any, Dict, List, Optional +from unittest.mock import patch + +import pytest + +from ggshield.verticals.ai.agents import Agent +from ggshield.verticals.ai.agents.claude_code import Claude, _mangle_server_name +from ggshield.verticals.ai.agents.copilot import Copilot +from ggshield.verticals.ai.agents.cursor import Cursor, _parse_tool_arguments +from ggshield.verticals.ai.models import ( + AIDiscovery, + EventType, + HookPayload, + 
MCPConfiguration, + MCPServer, + MCPToolInfo, + Scope, + Tool, + Transport, + UserInfo, +) + + +def _user() -> UserInfo: + return UserInfo( + hostname="host", username="user", machine_id="mid", user_email="u@e.com" + ) + + +def _cfg( + name: str = "srv", + agent: str = "cursor", + scope: Scope = Scope.USER, + project: Optional[Path] = None, +) -> MCPConfiguration: + return MCPConfiguration( + name=name, agent=agent, scope=scope, transport=Transport.STDIO, project=project + ) + + +def _ai_discovery(servers: Optional[List[MCPServer]] = None) -> AIDiscovery: + return AIDiscovery(user=_user(), servers=servers or [], discovery_duration=0.1) + + +def _payload( + agent: Agent, raw: Optional[Dict[str, Any]] = None, tool: Tool = Tool.MCP +) -> HookPayload: + return HookPayload( + event_type=EventType.PRE_TOOL_USE, + tool=tool, + content="", + identifier="", + agent=agent, + raw=raw or {}, + ) + + +# =========================================================================== +# Cursor +# =========================================================================== + + +class TestCursorDiscoverCapabilities: + def _setup_mcps_folder( + self, tmp_path: Path, project_path: Path, server_name: str + ) -> Path: + """Build a Cursor-style mcps// folder layout and return the agent.""" + mangled = project_path.as_posix().replace("/", "-").lstrip("-") + mcps_root = tmp_path / ".cursor" / "projects" / mangled / "mcps" + server_dir = mcps_root / "my-server" + server_dir.mkdir(parents=True) + # SERVER_METADATA.json + (server_dir / "SERVER_METADATA.json").write_text( + json.dumps({"serverName": server_name}) + ) + return server_dir + + def test_populates_tools_resources_prompts(self, tmp_path: Path): + project = Path("/home/user/project") + server_dir = self._setup_mcps_folder(tmp_path, project, "my-mcp") + + (server_dir / "tools").mkdir() + (server_dir / "tools" / "t1.json").write_text( + json.dumps({"name": "do_thing", "description": "Does a thing"}) + ) + (server_dir / "resources").mkdir() 
+ (server_dir / "resources" / "r1.json").write_text( + json.dumps({"uri": "file:///data", "name": "data"}) + ) + (server_dir / "prompts").mkdir() + (server_dir / "prompts" / "p1.json").write_text( + json.dumps({"name": "greeting", "description": "Says hi"}) + ) + + cursor = Cursor() + cfg = _cfg(name="my-mcp", agent="cursor", project=project) + server = MCPServer(name="my-mcp", configurations=[cfg]) + + with patch.object( + type(cursor), + "config_folder", + new_callable=lambda: property(lambda self: tmp_path / ".cursor"), + ): + result = cursor.discover_capabilities(server) + + assert result is True + assert len(server.tools) == 1 + assert server.tools[0].name == "do_thing" + assert len(server.resources) == 1 + assert server.resources[0].uri == "file:///data" + assert len(server.prompts) == 1 + assert server.prompts[0].name == "greeting" + + def test_status_md_present_returns_false(self, tmp_path: Path): + project = Path("/home/user/project") + server_dir = self._setup_mcps_folder(tmp_path, project, "my-mcp") + (server_dir / "STATUS.md").write_text("disconnected") + (server_dir / "tools").mkdir() + (server_dir / "tools" / "t1.json").write_text(json.dumps({"name": "t"})) + + cursor = Cursor() + cfg = _cfg(name="my-mcp", agent="cursor", project=project) + server = MCPServer(name="my-mcp", configurations=[cfg]) + + with patch.object( + type(cursor), + "config_folder", + new_callable=lambda: property(lambda self: tmp_path / ".cursor"), + ): + result = cursor.discover_capabilities(server) + + assert result is False + assert len(server.tools) == 0 + + def test_no_matching_metadata_returns_false(self, tmp_path: Path): + project = Path("/home/user/project") + self._setup_mcps_folder(tmp_path, project, "other-server") + + cursor = Cursor() + cfg = _cfg(name="my-mcp", agent="cursor", project=project) + server = MCPServer(name="my-mcp", configurations=[cfg]) + + with patch.object( + type(cursor), + "config_folder", + new_callable=lambda: property(lambda self: tmp_path / 
".cursor"), + ): + result = cursor.discover_capabilities(server) + + assert result is False + + def test_non_cursor_configuration_skipped(self): + cursor = Cursor() + cfg = _cfg(name="srv", agent="claude-code", project=Path("/proj")) + server = MCPServer(name="srv", configurations=[cfg]) + assert cursor.discover_capabilities(server) is False + + +class TestCursorDiscoverProjectDirectories: + def test_valid_workspace_json_yields_path(self, tmp_path: Path): + project_dir = tmp_path / "myproject" + project_dir.mkdir() + ws_storage = ( + tmp_path / ".config" / "Cursor" / "User" / "workspaceStorage" / "abc" + ) + ws_storage.mkdir(parents=True) + (ws_storage / "workspace.json").write_text( + json.dumps({"folder": f"file://{project_dir}"}) + ) + + cursor = Cursor() + with patch.object( + type(cursor), + "config_folder", + new_callable=lambda: property( + lambda self: tmp_path / ".config" / "Cursor" / "User" + ), + ): + with patch( + "ggshield.verticals.ai.agents.cursor.get_user_home_dir", + return_value=tmp_path, + ): + dirs = list(cursor.discover_project_directories()) + + assert project_dir.resolve() in dirs + + def test_missing_folder_key_skipped(self, tmp_path: Path): + ws_storage = ( + tmp_path / ".config" / "Cursor" / "User" / "workspaceStorage" / "abc" + ) + ws_storage.mkdir(parents=True) + (ws_storage / "workspace.json").write_text(json.dumps({"other": "val"})) + + cursor = Cursor() + with patch.object( + type(cursor), + "config_folder", + new_callable=lambda: property( + lambda self: tmp_path / ".config" / "Cursor" / "User" + ), + ): + with patch( + "ggshield.verticals.ai.agents.cursor.get_user_home_dir", + return_value=tmp_path, + ): + dirs = list(cursor.discover_project_directories()) + + assert dirs == [] + + +class TestCursorParseMcpActivity: + def test_strips_mcp_prefix_and_maps_server(self): + cursor = Cursor() + tool_info = MCPToolInfo(name="run_query") + server = MCPServer( + name="my-db-server", + tools=[tool_info], + configurations=[_cfg(name="db", 
agent="cursor")], + ) + discovery = _ai_discovery(servers=[server]) + payload = _payload( + cursor, + raw={ + "tool_name": "MCP:run_query", + "model": "gpt-4", + "workspace_roots": ["/home/user/proj"], + "tool_input": {"sql": "SELECT 1"}, + }, + ) + + req = cursor.parse_mcp_activity(payload, discovery) + + assert req.tool == "run_query" + assert req.server == "my-db-server" + assert req.model == "gpt-4" + assert req.input == {"sql": "SELECT 1"} + + def test_unknown_tool_returns_empty_server(self): + cursor = Cursor() + discovery = _ai_discovery(servers=[]) + payload = _payload(cursor, raw={"tool_name": "MCP:unknown"}) + + req = cursor.parse_mcp_activity(payload, discovery) + + assert req.tool == "unknown" + assert req.server == "" + + +# =========================================================================== +# Claude Code +# =========================================================================== + + +class TestClaudeGetUserMcpConfigurations: + def test_user_level_and_project_level_parsed(self, tmp_path: Path): + project_dir = tmp_path / "myproject" + project_dir.mkdir() + claude_json = { + "mcpServers": {"global-srv": {"command": "npx", "args": ["-y", "mcp"]}}, + "projects": { + str(project_dir): { + "mcpServers": { + "project-srv": {"command": "node", "args": ["index.js"]} + } + } + }, + } + with patch( + "ggshield.verticals.ai.agents.claude_code.get_user_home_dir", + return_value=tmp_path, + ): + (tmp_path / ".claude.json").write_text(json.dumps(claude_json)) + claude = Claude() + configs = list(claude._get_user_mcp_configurations()) + + names = {c.name for c in configs} + assert "global-srv" in names + assert "project-srv" in names + + def test_missing_file_yields_nothing(self, tmp_path: Path): + with patch( + "ggshield.verticals.ai.agents.claude_code.get_user_home_dir", + return_value=tmp_path, + ): + claude = Claude() + configs = list(claude._get_user_mcp_configurations()) + assert configs == [] + + +class TestClaudeDiscoverProjectDirectories: + def 
test_yields_existing_directories(self, tmp_path: Path): + project = tmp_path / "proj" + project.mkdir() + history = tmp_path / ".claude" / "history.jsonl" + history.parent.mkdir(parents=True) + history.write_text(json.dumps({"project": str(project)}) + "\n") + + claude = Claude() + with patch.object( + type(claude), + "config_folder", + new_callable=lambda: property(lambda self: tmp_path / ".claude"), + ): + dirs = list(claude.discover_project_directories()) + + assert project.resolve() in dirs + + def test_skips_nonexistent_directories(self, tmp_path: Path): + history = tmp_path / ".claude" / "history.jsonl" + history.parent.mkdir(parents=True) + history.write_text( + json.dumps({"project": str(tmp_path / "nonexistent")}) + "\n" + ) + + claude = Claude() + with patch.object( + type(claude), + "config_folder", + new_callable=lambda: property(lambda self: tmp_path / ".claude"), + ): + dirs = list(claude.discover_project_directories()) + + assert dirs == [] + + +class TestClaudeParseMcpActivity: + def test_parses_mcp_double_underscore_format(self): + claude = Claude() + cfg = _cfg(name="my.server", agent="claude-code") + server = MCPServer( + name="my.server", configurations=[cfg], tools=[MCPToolInfo(name="run")] + ) + discovery = _ai_discovery(servers=[server]) + # Claude mangles "my.server" → "my_server" in the tool name + payload = _payload( + claude, + raw={"tool_name": "mcp__my_server__run", "cwd": "/tmp", "tool_input": {}}, + ) + + req = claude.parse_mcp_activity(payload, discovery) + + assert req.tool == "run" + assert req.server == "my.server" + + def test_server_with_double_underscore_handled(self): + claude = Claude() + discovery = _ai_discovery(servers=[]) + payload = _payload( + claude, + raw={ + "tool_name": "mcp__a__b__tool_name", + "cwd": "/tmp", + "tool_input": {}, + }, + ) + + req = claude.parse_mcp_activity(payload, discovery) + + assert req.tool == "tool_name" + assert req.server == "a__b" # falls back to mangled name + + def 
test_fallback_to_mangled_name(self): + claude = Claude() + discovery = _ai_discovery(servers=[]) + payload = _payload( + claude, + raw={"tool_name": "mcp__unknown__do_it", "cwd": "/tmp", "tool_input": {}}, + ) + + req = claude.parse_mcp_activity(payload, discovery) + + assert req.server == "unknown" + + +# --------------------------------------------------------------------------- +# _mangle_server_name +# --------------------------------------------------------------------------- + + +class TestMangleServerName: + @pytest.mark.parametrize( + "name, expected", + [ + pytest.param("my-seRver-123", "my-seRver-123", id="alphanumeric_dashes"), + pytest.param( + "my.server/v2 alpha", "my_server_v2_alpha", id="special_chars" + ), + pytest.param("simple", "simple", id="plain_alpha"), + pytest.param("a@b#c", "a_b_c", id="symbols"), + ], + ) + def test_mangle_server_name(self, name: str, expected: str): + assert _mangle_server_name(name) == expected + + +# =========================================================================== +# Copilot +# =========================================================================== + + +class TestCopilotParseMcpActivity: + def test_simple_server_tool_split(self): + copilot = Copilot() + cfg = _cfg(name="myserver", agent="copilot") + server = MCPServer(name="myserver", configurations=[cfg]) + discovery = _ai_discovery(servers=[server]) + payload = _payload( + copilot, + raw={"tool_name": "myserver_mytool", "cwd": "/tmp", "tool_input": {}}, + ) + + req = copilot.parse_mcp_activity(payload, discovery) + + assert req.tool == "mytool" + assert req.server == "myserver" + + def test_multiple_underscores_raises(self): + """Documents that `split("_")` with >2 parts raises ValueError.""" + copilot = Copilot() + discovery = _ai_discovery(servers=[]) + payload = _payload( + copilot, + raw={ + "tool_name": "server_tool_name_extra", + "cwd": "/tmp", + "tool_input": {}, + }, + ) + + with pytest.raises(ValueError): + copilot.parse_mcp_activity(payload, 
discovery) + + def test_unknown_server_falls_back_to_cfg_name(self): + copilot = Copilot() + discovery = _ai_discovery(servers=[]) + payload = _payload( + copilot, + raw={"tool_name": "unknown_tool", "cwd": "/tmp", "tool_input": {}}, + ) + + req = copilot.parse_mcp_activity(payload, discovery) + + assert req.server == "unknown" + assert req.tool == "tool" + + +# =========================================================================== +# _parse_tool_arguments (Cursor helper) +# =========================================================================== + + +class TestParseToolArguments: + def test_valid_schema(self): + schema = { + "type": "object", + "properties": { + "query": {"type": "string", "description": "SQL query"}, + "limit": {"type": "integer"}, + }, + "required": ["query"], + } + result = _parse_tool_arguments(schema) + assert result is not None + assert len(result) == 2 + q = next(a for a in result if a.name == "query") + assert q.required is True + assert q.description == "SQL query" + lim = next(a for a in result if a.name == "limit") + assert lim.required is False + + def test_empty_properties_returns_none(self): + schema = {"type": "object", "properties": {}} + assert _parse_tool_arguments(schema) is None + + @pytest.mark.parametrize( + "schema", + [ + pytest.param(None, id="none"), + pytest.param("string", id="string"), + pytest.param(42, id="integer"), + ], + ) + def test_non_dict_schema_returns_none(self, schema: Any): + assert _parse_tool_arguments(schema) is None diff --git a/tests/unit/verticals/ai/test_cmd_ai.py b/tests/unit/verticals/ai/test_cmd_ai.py new file mode 100644 index 0000000000..75b33932e1 --- /dev/null +++ b/tests/unit/verticals/ai/test_cmd_ai.py @@ -0,0 +1,154 @@ +import json +from unittest.mock import patch + +from click.testing import CliRunner + +from ggshield.__main__ import cli +from ggshield.core.errors import APIKeyCheckError +from ggshield.verticals.ai.models import AIDiscovery, UserInfo + + +def _user(): + return 
UserInfo( + hostname="host", username="user", machine_id="mid", user_email="u@e.com" + ) + + +def _discovery(): + return AIDiscovery(user=_user(), servers=[], discovery_duration=0.1) + + +# --------------------------------------------------------------------------- +# ggshield secret scan ai-hook +# --------------------------------------------------------------------------- + + +class TestAiHookCmd: + @patch("ggshield.cmd.secret.scan.ai_hook.AIHookScanner") + @patch("ggshield.cmd.secret.scan.ai_hook.SecretScanner") + @patch("ggshield.cmd.secret.scan.ai_hook.create_client_from_config") + def test_valid_json_stdin(self, mock_client, mock_scanner_cls, mock_hook_scanner): + instance = mock_hook_scanner.return_value + instance.scan.return_value = 0 + + runner = CliRunner() + result = runner.invoke( + cli, + ["secret", "scan", "ai-hook"], + input='{"event_type": "test"}', + ) + + assert result.exit_code == 0 + instance.scan.assert_called_once() + + @patch("ggshield.cmd.secret.scan.ai_hook.AIHookScanner") + @patch("ggshield.cmd.secret.scan.ai_hook.SecretScanner") + @patch("ggshield.cmd.secret.scan.ai_hook.create_client_from_config") + def test_empty_stdin_returns_error( + self, mock_client, mock_scanner_cls, mock_hook_scanner + ): + instance = mock_hook_scanner.return_value + instance.scan.side_effect = ValueError("Empty input") + + runner = CliRunner() + result = runner.invoke( + cli, + ["secret", "scan", "ai-hook"], + input="", + ) + + assert result.exit_code == 1 + + @patch("ggshield.cmd.secret.scan.ai_hook.AIHookScanner") + @patch("ggshield.cmd.secret.scan.ai_hook.SecretScanner") + @patch("ggshield.cmd.secret.scan.ai_hook.create_client_from_config") + def test_large_stdin_does_not_crash( + self, mock_client, mock_scanner_cls, mock_hook_scanner + ): + instance = mock_hook_scanner.return_value + instance.scan.return_value = 0 + + runner = CliRunner() + large_input = "x" * (1024 * 1024) # 1 MB + result = runner.invoke( + cli, + ["secret", "scan", "ai-hook"], + 
input=large_input, + ) + + assert result.exit_code == 0 + + +# --------------------------------------------------------------------------- +# ggshield ai discover +# --------------------------------------------------------------------------- + + +class TestDiscoverCmd: + @patch( + "ggshield.cmd.ai.discover.discover_ai_configuration", + return_value=_discovery(), + ) + @patch("ggshield.cmd.ai.discover.create_client_from_config") + @patch( + "ggshield.cmd.ai.discover.submit_ai_discovery", + return_value=_discovery(), + ) + @patch("ggshield.cmd.ai.discover.save_discovery_cache") + def test_default_output(self, mock_save, mock_submit, mock_client, mock_discover): + runner = CliRunner() + result = runner.invoke(cli, ["ai", "discover"]) + + assert result.exit_code == 0 + mock_discover.assert_called_once() + + @patch( + "ggshield.cmd.ai.discover.discover_ai_configuration", + return_value=_discovery(), + ) + @patch("ggshield.cmd.ai.discover.create_client_from_config") + @patch( + "ggshield.cmd.ai.discover.submit_ai_discovery", + return_value=_discovery(), + ) + @patch("ggshield.cmd.ai.discover.save_discovery_cache") + def test_json_flag(self, mock_save, mock_submit, mock_client, mock_discover): + runner = CliRunner() + result = runner.invoke(cli, ["ai", "discover", "--json"]) + + assert result.exit_code == 0 + parsed = json.loads(result.output) + assert "user" in parsed + + @patch( + "ggshield.cmd.ai.discover.discover_ai_configuration", + return_value=_discovery(), + ) + @patch( + "ggshield.cmd.ai.discover.create_client_from_config", + side_effect=APIKeyCheckError("https://api.gitguardian.com", "no key"), + ) + def test_auth_failure_shows_warning(self, mock_client, mock_discover): + runner = CliRunner() + result = runner.invoke(cli, ["ai", "discover"]) + + assert result.exit_code == 0 + assert "Skipping upload" in result.output or "warning" in result.output.lower() + + @patch( + "ggshield.cmd.ai.discover.discover_ai_configuration", + return_value=_discovery(), + ) + 
@patch("ggshield.cmd.ai.discover.create_client_from_config") + @patch( + "ggshield.cmd.ai.discover.submit_ai_discovery", + side_effect=RuntimeError("API error"), + ) + def test_api_submission_failure_shows_warning( + self, mock_submit, mock_client, mock_discover + ): + runner = CliRunner() + result = runner.invoke(cli, ["ai", "discover"]) + + assert result.exit_code == 0 + assert "Could not upload" in result.output or "warning" in result.output.lower() diff --git a/tests/unit/verticals/ai/test_config.py b/tests/unit/verticals/ai/test_config.py new file mode 100644 index 0000000000..b1919c70ca --- /dev/null +++ b/tests/unit/verticals/ai/test_config.py @@ -0,0 +1,177 @@ +import json +from pathlib import Path +from typing import Any, Dict, Optional +from unittest.mock import patch + +import pytest + +from ggshield.verticals.ai.config import ( + extract_host_from_config, + get_mcp_remote_url, + load_json_file, + load_mcp_config, + save_json_file, +) + +# --------------------------------------------------------------------------- +# load_json_file +# --------------------------------------------------------------------------- + + +class TestLoadJsonFile: + def test_valid_json_returns_dict(self, tmp_path: Path): + f = tmp_path / "data.json" + f.write_text('{"key": "value"}') + assert load_json_file(f) == {"key": "value"} + + def test_nonexistent_file_returns_empty_dict(self, tmp_path: Path): + assert load_json_file(tmp_path / "nope.json") == {} + + def test_malformed_json_returns_empty_dict(self, tmp_path: Path): + f = tmp_path / "bad.json" + f.write_text("{invalid json") + assert load_json_file(f) == {} + + def test_oserror_returns_empty_dict(self, tmp_path: Path): + d = tmp_path / "dir.json" + d.mkdir() # reading a directory triggers OSError + assert load_json_file(d) == {} + + +# --------------------------------------------------------------------------- +# save_json_file +# --------------------------------------------------------------------------- + + +class 
TestSaveJsonFile: + def test_writes_valid_json_and_creates_parents(self, tmp_path: Path): + f = tmp_path / "sub" / "dir" / "out.json" + save_json_file(f, {"a": 1}) + assert json.loads(f.read_text()) == {"a": 1} + + def test_oserror_silently_swallowed(self, tmp_path: Path): + f = tmp_path / "readonly" / "out.json" + (tmp_path / "readonly").mkdir() + (tmp_path / "readonly").chmod(0o444) + try: + save_json_file(f, {"a": 1}) # should not raise + finally: + (tmp_path / "readonly").chmod(0o755) + + +# --------------------------------------------------------------------------- +# load_mcp_config +# --------------------------------------------------------------------------- + + +class TestLoadMcpConfig: + def test_workspace_config_found(self, tmp_path: Path): + ws = tmp_path / "project" + (ws / ".cursor").mkdir(parents=True) + (ws / ".cursor" / "mcp.json").write_text('{"servers": {"s1": {}}}') + result = load_mcp_config([str(ws)]) + assert result == {"servers": {"s1": {}}} + + def test_first_workspace_wins(self, tmp_path: Path): + ws1 = tmp_path / "p1" + ws2 = tmp_path / "p2" + for ws, content in [(ws1, '{"first": true}'), (ws2, '{"second": true}')]: + (ws / ".cursor").mkdir(parents=True) + (ws / ".cursor" / "mcp.json").write_text(content) + result = load_mcp_config([str(ws1), str(ws2)]) + assert result == {"first": True} + + def test_falls_back_to_global(self, tmp_path: Path): + global_mcp = tmp_path / ".cursor" / "mcp.json" + global_mcp.parent.mkdir(parents=True) + global_mcp.write_text('{"global": true}') + with patch( + "ggshield.verticals.ai.config.get_user_home_dir", return_value=tmp_path + ): + result = load_mcp_config([str(tmp_path / "no_workspace")]) + assert result == {"global": True} + + def test_neither_exists_returns_empty(self, tmp_path: Path): + with patch( + "ggshield.verticals.ai.config.get_user_home_dir", return_value=tmp_path + ): + result = load_mcp_config([]) + assert result == {} + + def test_list_config_skipped(self, tmp_path: Path): + ws = tmp_path 
/ "project" + (ws / ".cursor").mkdir(parents=True) + (ws / ".cursor" / "mcp.json").write_text("[1, 2, 3]") + with patch( + "ggshield.verticals.ai.config.get_user_home_dir", return_value=tmp_path + ): + result = load_mcp_config([str(ws)]) + assert result == {} + + +# --------------------------------------------------------------------------- +# extract_host_from_config +# --------------------------------------------------------------------------- + + +class TestExtractHostFromConfig: + @pytest.mark.parametrize( + "config, expected", + [ + pytest.param( + {"env": {"CLICKHOUSE_HOST": "ch.example.com"}}, + "ch.example.com", + id="clickhouse_host", + ), + pytest.param( + {"env": {"GITLAB_API_URL": "https://gitlab.corp.com/api/v4"}}, + "gitlab.corp.com", + id="gitlab_api_url", + ), + pytest.param( + {"args": ["--config", "https://api.example.com/v1/mcp"]}, + "api.example.com", + id="https_arg", + ), + pytest.param( + {"args": ["http://localhost:8080/path"]}, + "localhost:8080", + id="http_arg", + ), + pytest.param( + {"args": ["--flag", "value"], "env": {}}, + None, + id="no_match", + ), + pytest.param(None, None, id="none_input"), + ], + ) + def test_extract_host(self, config: Dict[str, Any], expected: Optional[str]): + assert extract_host_from_config(config) == expected + + +# --------------------------------------------------------------------------- +# get_mcp_remote_url +# --------------------------------------------------------------------------- + + +class TestGetMcpRemoteUrl: + @pytest.mark.parametrize( + "config, expected", + [ + pytest.param( + {"args": ["--flag", "https://remote.com/mcp", "http://other.com"]}, + "https://remote.com/mcp", + id="first_url_returned", + ), + pytest.param( + {"args": ["--flag", "not-a-url"]}, + None, + id="no_url", + ), + pytest.param({"args": []}, None, id="empty_args"), + pytest.param({}, None, id="no_args_key"), + ], + ) + def test_get_mcp_remote_url(self, config: Dict[str, Any], expected: Optional[str]): + assert 
get_mcp_remote_url(config) == expected diff --git a/tests/unit/verticals/ai/test_discovery.py b/tests/unit/verticals/ai/test_discovery.py new file mode 100644 index 0000000000..5e922ef595 --- /dev/null +++ b/tests/unit/verticals/ai/test_discovery.py @@ -0,0 +1,313 @@ +from pathlib import Path +from typing import Any +from unittest.mock import MagicMock, patch + +import pytest + +from ggshield.verticals.ai.discovery import ( + _merge_mcp_configurations, + discover_ai_configuration, + load_discovery_cache, + refresh_and_maybe_submit_discovery, + save_discovery_cache, + submit_ai_discovery, +) +from ggshield.verticals.ai.models import ( + AIDiscovery, + MCPConfiguration, + Scope, + Transport, + UserInfo, +) + + +def _user(**kwargs: Any) -> UserInfo: + defaults = dict( + hostname="host", username="user", machine_id="mid", user_email="u@e.com" + ) + return UserInfo.model_validate(defaults | kwargs) + + +def _cfg( + name: str = "srv", agent: str = "cursor", scope: Scope = Scope.USER +) -> MCPConfiguration: + return MCPConfiguration( + name=name, agent=agent, scope=scope, transport=Transport.STDIO + ) + + +def _discovery(**kwargs: Any) -> AIDiscovery: + defaults = dict(user=_user(), servers=[], discovery_duration=0.1) + return AIDiscovery.model_validate(defaults | kwargs) + + +# --------------------------------------------------------------------------- +# _merge_mcp_configurations +# --------------------------------------------------------------------------- + + +class TestMergeMcpConfigurations: + def test_different_names_produce_separate_servers(self): + configs = [_cfg(name="a"), _cfg(name="b")] + servers = _merge_mcp_configurations(configs) + assert len(servers) == 2 + names = {s.name for s in servers} + assert names == {"a", "b"} + + def test_same_name_merged_under_one_server(self): + configs = [_cfg(name="x", agent="cursor"), _cfg(name="x", agent="claude-code")] + servers = _merge_mcp_configurations(configs) + assert len(servers) == 1 + assert 
len(servers[0].configurations) == 2 + + def test_empty_list_returns_empty(self): + assert _merge_mcp_configurations([]) == [] + + +# --------------------------------------------------------------------------- +# discover_ai_configuration +# --------------------------------------------------------------------------- + + +class TestDiscoverAIConfiguration: + @patch("ggshield.verticals.ai.discovery.get_user_info", return_value=_user()) + @patch("ggshield.verticals.ai.discovery.AGENTS") + def test_aggregates_agents(self, mock_agents, mock_user_info, tmp_path: Path): + agent1 = MagicMock() + agent1.discover_project_directories.return_value = iter([tmp_path / "p1"]) + agent1.discover_mcp_configurations.return_value = [_cfg(name="s1")] + agent1.discover_capabilities.return_value = False + + agent2 = MagicMock() + agent2.discover_project_directories.return_value = iter([]) + agent2.discover_mcp_configurations.return_value = [_cfg(name="s2")] + agent2.discover_capabilities.return_value = False + + mock_agents.values.return_value = [agent1, agent2] + + result = discover_ai_configuration() + + assert result.user == _user() + assert len(result.servers) == 2 + assert result.discovery_duration > 0 + + @patch("ggshield.verticals.ai.discovery.get_user_info", return_value=_user()) + @patch("ggshield.verticals.ai.discovery.AGENTS") + def test_stops_capability_discovery_at_first_success( + self, mock_agents, mock_user_info + ): + agent1 = MagicMock() + agent1.discover_project_directories.return_value = iter([]) + agent1.discover_mcp_configurations.return_value = [_cfg(name="s")] + agent1.discover_capabilities.return_value = True + + agent2 = MagicMock() + agent2.discover_project_directories.return_value = iter([]) + agent2.discover_mcp_configurations.return_value = [] + agent2.discover_capabilities.return_value = False + + mock_agents.values.return_value = [agent1, agent2] + + discover_ai_configuration() + + agent1.discover_capabilities.assert_called_once() + 
agent2.discover_capabilities.assert_not_called() + + +# --------------------------------------------------------------------------- +# load / save discovery cache +# --------------------------------------------------------------------------- + + +class TestLoadSaveDiscoveryCache: + def test_round_trip(self, tmp_path: Path): + discovery = _discovery() + with patch( + "ggshield.verticals.ai.discovery.get_cache_dir", return_value=tmp_path + ): + save_discovery_cache(discovery) + loaded = load_discovery_cache() + assert loaded is not None + assert loaded.user == discovery.user + + def test_load_returns_none_when_missing(self, tmp_path: Path): + with patch( + "ggshield.verticals.ai.discovery.get_cache_dir", return_value=tmp_path + ): + assert load_discovery_cache() is None + + def test_load_returns_none_on_valid_json_bad_schema(self, tmp_path: Path): + """Valid JSON that doesn't match AIDiscovery schema.""" + cache_file = tmp_path / "ai_discovery.json" + cache_file.write_text('{"foo": "bar"}') + with patch( + "ggshield.verticals.ai.discovery.get_cache_dir", return_value=tmp_path + ): + from pydantic import ValidationError + + with pytest.raises(ValidationError): + load_discovery_cache() + + def test_load_returns_none_on_invalid_json(self, tmp_path: Path): + """Cache content that is not valid JSON at all.""" + cache_file = tmp_path / "ai_discovery.json" + cache_file.write_text("not json") + with patch( + "ggshield.verticals.ai.discovery.get_cache_dir", return_value=tmp_path + ): + from pydantic import ValidationError + + with pytest.raises(ValidationError): + load_discovery_cache() + + def test_load_returns_none_on_oserror(self, tmp_path: Path): + with patch( + "ggshield.verticals.ai.discovery.get_cache_dir", return_value=tmp_path + ): + cache_file = tmp_path / "ai_discovery.json" + cache_file.mkdir() # directory instead of file triggers OSError + assert load_discovery_cache() is None + + +# --------------------------------------------------------------------------- 
+# submit_ai_discovery +# --------------------------------------------------------------------------- + + +class TestSubmitAIDiscovery: + def test_successful_response(self): + discovery = _discovery() + response = MagicMock() + response.status_code = 200 + response.text = discovery.model_dump_json() + client = MagicMock() + client.post.return_value = response + + result = submit_ai_discovery(client, discovery) + assert result.user == discovery.user + + def test_non_200_raises(self): + discovery = _discovery() + response = MagicMock() + response.status_code = 500 + response.text = "Internal Server Error" + client = MagicMock() + client.post.return_value = response + + with pytest.raises(AssertionError): + submit_ai_discovery(client, discovery) + + +# --------------------------------------------------------------------------- +# refresh_and_maybe_submit_discovery +# --------------------------------------------------------------------------- + + +class TestRefreshAndMaybeSubmitDiscovery: + def _patch_all(self): + return ( + patch( + "ggshield.verticals.ai.discovery.load_discovery_cache", + ), + patch( + "ggshield.verticals.ai.discovery.discover_ai_configuration", + ), + patch( + "ggshield.verticals.ai.discovery.submit_ai_discovery", + ), + patch( + "ggshield.verticals.ai.discovery.save_discovery_cache", + ), + ) + + def test_no_cache_submits_and_saves(self): + p_load, p_discover, p_submit, p_save = self._patch_all() + with ( + p_load as m_load, + p_discover as m_discover, + p_submit as m_submit, + p_save as m_save, + ): + m_load.return_value = None + new_disc = _discovery() + m_discover.return_value = new_disc + submitted = _discovery(discovery_duration=0.5) + m_submit.return_value = submitted + + result = refresh_and_maybe_submit_discovery(MagicMock()) + + m_submit.assert_called_once() + m_save.assert_called_once_with(submitted) + assert result == submitted + + def test_unchanged_returns_cache_without_submission(self): + cached = _discovery() + p_load, p_discover, 
p_submit, p_save = self._patch_all() + with ( + p_load as m_load, + p_discover as m_discover, + p_submit as m_submit, + p_save as m_save, + ): + m_load.return_value = cached + m_discover.return_value = cached # identical discovery + + result = refresh_and_maybe_submit_discovery(MagicMock()) + + m_submit.assert_not_called() + m_save.assert_not_called() + assert result == cached + + def test_changed_submits_and_saves(self): + cached = _discovery(user=_user(hostname="old")) + new_disc = _discovery(user=_user(hostname="new")) + p_load, p_discover, p_submit, p_save = self._patch_all() + with ( + p_load as m_load, + p_discover as m_discover, + p_submit as m_submit, + p_save as m_save, + ): + m_load.return_value = cached + m_discover.return_value = new_disc + m_submit.return_value = new_disc + + result = refresh_and_maybe_submit_discovery(MagicMock()) + + m_submit.assert_called_once() + m_save.assert_called_once() + + def test_api_error_swallowed(self): + p_load, p_discover, p_submit, p_save = self._patch_all() + with ( + p_load as m_load, + p_discover as m_discover, + p_submit as m_submit, + p_save as m_save, + ): + m_load.return_value = None + new_disc = _discovery() + m_discover.return_value = new_disc + m_submit.side_effect = RuntimeError("network") + + result = refresh_and_maybe_submit_discovery(MagicMock()) + + assert result == new_disc + m_save.assert_not_called() + + def test_reuses_machine_id_from_cache(self): + cached = _discovery(user=_user(machine_id="cached-id")) + p_load, p_discover, p_submit, p_save = self._patch_all() + with ( + p_load as m_load, + p_discover as m_discover, + p_submit as m_submit, + p_save as m_save, + ): + m_load.return_value = cached + m_discover.return_value = cached + + refresh_and_maybe_submit_discovery(MagicMock()) + + _, kwargs = m_discover.call_args + assert kwargs.get("machine_id") == "cached-id" diff --git a/tests/unit/verticals/ai/test_hooks.py b/tests/unit/verticals/ai/test_hooks.py index adf355774f..021de1a128 100644 --- 
a/tests/unit/verticals/ai/test_hooks.py +++ b/tests/unit/verticals/ai/test_hooks.py @@ -22,6 +22,7 @@ def _dummy_payload(event_type: EventType = EventType.OTHER) -> HookPayload: content="", identifier="", agent=Cursor(), + raw={}, ) @@ -88,6 +89,7 @@ def test_no_secrets_returns_allow(self): content="safe content", identifier="id", agent=Cursor(), + raw={}, ) result = hook_scanner._scan_content(payload) assert isinstance(result, HookResult) @@ -104,6 +106,7 @@ def test_with_secrets_returns_block_and_message(self): content="content with sk-xxx", identifier="id", agent=Cursor(), + raw={}, ) result = hook_scanner._scan_content(payload) assert isinstance(result, HookResult) @@ -204,6 +207,7 @@ def test_message_for_bash_tool(self): content="echo sk-xxx", identifier="echo sk-xxx", agent=Cursor(), + raw={}, ) message = AIHookScanner._message_from_secrets([_make_secret("sk-xxx")], payload) assert "remove the secrets from the command" in message @@ -217,6 +221,7 @@ def test_message_for_read_tool(self): content="file content with secret", identifier="/path/to/file", agent=Cursor(), + raw={}, ) message = AIHookScanner._message_from_secrets([_make_secret("sk-xxx")], payload) assert "remove the secrets from" in message @@ -229,6 +234,7 @@ def test_message_for_other_tool(self): content="some content", identifier="id", agent=Cursor(), + raw={}, ) message = AIHookScanner._message_from_secrets([_make_secret("sk-xxx")], payload) assert "remove the secrets from the tool input" in message @@ -241,6 +247,7 @@ def test_message_escapes_markdown(self): content="content", identifier="id", agent=Cursor(), + raw={}, ) message = AIHookScanner._message_from_secrets( [_make_secret("sk-xxx")], payload, escape_markdown=True diff --git a/tests/unit/verticals/ai/test_mcp_activity.py b/tests/unit/verticals/ai/test_mcp_activity.py new file mode 100644 index 0000000000..2e4d6d5afa --- /dev/null +++ b/tests/unit/verticals/ai/test_mcp_activity.py @@ -0,0 +1,122 @@ +from unittest.mock import MagicMock, 
patch + +import pytest +from pydantic import ValidationError +from requests.exceptions import ConnectionError as RequestsConnectionError + +from ggshield.verticals.ai.mcp import send_mcp_activity +from ggshield.verticals.ai.models import ( + AIDiscovery, + EventType, + HookPayload, + MCPActivityRequest, + MCPActivityResponse, + Tool, + UserInfo, +) + + +def _user() -> UserInfo: + return UserInfo( + hostname="host", username="user", machine_id="mid", user_email="u@e.com" + ) + + +def _ai_discovery() -> AIDiscovery: + return AIDiscovery(user=_user(), servers=[], discovery_duration=0.1) + + +def _mcp_activity_request() -> MCPActivityRequest: + from pathlib import Path + + return MCPActivityRequest( + user=_user(), + tool="my_tool", + server="my_server", + agent="cursor", + model="gpt-4", + cwd=Path("/tmp"), + input={"key": "value"}, + ) + + +def _payload() -> HookPayload: + agent = MagicMock() + return HookPayload( + event_type=EventType.PRE_TOOL_USE, + tool=Tool.MCP, + content="content", + identifier="id", + agent=agent, + raw={}, + ) + + +class TestSendMCPActivity: + @patch("ggshield.verticals.ai.mcp.refresh_and_maybe_submit_discovery") + def test_successful_response(self, mock_refresh): + mock_refresh.return_value = _ai_discovery() + payload = _payload() + payload.agent.parse_mcp_activity.return_value = _mcp_activity_request() + + response_data = MCPActivityResponse(allowed=False, reason="blocked by policy") + mock_response = MagicMock() + mock_response.ok = True + mock_response.text = response_data.model_dump_json() + + client = MagicMock() + client.post.return_value = mock_response + + result = send_mcp_activity(client, payload) + + assert result.allowed is False + assert result.reason == "blocked by policy" + + @pytest.mark.parametrize( + "setup_client", + [ + pytest.param("connection_error", id="connection_error"), + pytest.param("non_ok_response", id="non_ok_response"), + pytest.param("validation_error", id="validation_error"), + ], + ) + 
@patch("ggshield.verticals.ai.mcp.refresh_and_maybe_submit_discovery") + def test_fail_open_returns_allowed(self, mock_refresh, setup_client): + mock_refresh.return_value = _ai_discovery() + payload = _payload() + payload.agent.parse_mcp_activity.return_value = _mcp_activity_request() + + client = MagicMock() + if setup_client == "connection_error": + client.post.side_effect = RequestsConnectionError("network down") + elif setup_client == "non_ok_response": + mock_response = MagicMock() + mock_response.ok = False + client.post.return_value = mock_response + elif setup_client == "validation_error": + mock_response = MagicMock() + mock_response.ok = True + mock_response.text = '{"unexpected_field": true}' + client.post.return_value = mock_response + + result = send_mcp_activity(client, payload) + + assert result.allowed is True + + @patch("ggshield.verticals.ai.mcp.refresh_and_maybe_submit_discovery") + def test_refresh_called_before_submission(self, mock_refresh): + mock_refresh.return_value = _ai_discovery() + payload = _payload() + payload.agent.parse_mcp_activity.return_value = _mcp_activity_request() + + response_data = MCPActivityResponse(allowed=True, reason="") + mock_response = MagicMock() + mock_response.ok = True + mock_response.text = response_data.model_dump_json() + + client = MagicMock() + client.post.return_value = mock_response + + send_mcp_activity(client, payload) + + mock_refresh.assert_called_once_with(client) diff --git a/tests/unit/verticals/ai/test_models.py b/tests/unit/verticals/ai/test_models.py new file mode 100644 index 0000000000..82f73a4a88 --- /dev/null +++ b/tests/unit/verticals/ai/test_models.py @@ -0,0 +1,323 @@ +from pathlib import Path +from unittest.mock import patch + +import pytest + +from ggshield.core.scan import File, StringScannable +from ggshield.verticals.ai.agents import Cursor +from ggshield.verticals.ai.models import ( + AIDiscovery, + EventType, + HookPayload, + MCPConfiguration, + MCPPromptInfo, + MCPResourceInfo, + 
MCPServer, + MCPToolInfo, + Scope, + Tool, + Transport, + UserInfo, +) + + +def _user(**kwargs) -> UserInfo: + defaults = dict( + hostname="host", username="user", machine_id="mid", user_email="u@e.com" + ) + return UserInfo(**(defaults | kwargs)) + + +def _server( + name="srv", + tools=None, + resources=None, + prompts=None, + configurations=None, +) -> MCPServer: + return MCPServer( + name=name, + tools=tools or [], + resources=resources or [], + prompts=prompts or [], + configurations=configurations or [], + ) + + +def _cfg(name="srv", agent="cursor", scope=Scope.USER, project=None) -> MCPConfiguration: + return MCPConfiguration( + name=name, + agent=agent, + scope=scope, + transport=Transport.STDIO, + project=project, + ) + + +# --------------------------------------------------------------------------- +# MCPServer.has_capabilities_unknown_to +# --------------------------------------------------------------------------- + + +class TestMCPServerHasCapabilitiesUnknownTo: + @pytest.mark.parametrize( + "self_kwargs, other_kwargs, expected", + [ + pytest.param( + {"tools": [MCPToolInfo(name="t1")]}, + {"tools": [MCPToolInfo(name="t1")]}, + False, + id="same_tools", + ), + pytest.param( + {"tools": [MCPToolInfo(name="t1"), MCPToolInfo(name="t2")]}, + {"tools": [MCPToolInfo(name="t1")]}, + True, + id="extra_tool", + ), + pytest.param( + {"resources": [MCPResourceInfo(uri="r1"), MCPResourceInfo(uri="r2")]}, + {"resources": [MCPResourceInfo(uri="r1")]}, + True, + id="extra_resource", + ), + pytest.param( + {"prompts": [MCPPromptInfo(name="p1"), MCPPromptInfo(name="p2")]}, + {"prompts": [MCPPromptInfo(name="p1")]}, + True, + id="extra_prompt", + ), + pytest.param( + {"tools": [MCPToolInfo(name="t1")]}, + {"tools": [MCPToolInfo(name="t1"), MCPToolInfo(name="t2")]}, + False, + id="subset_of_other", + ), + pytest.param({}, {}, False, id="both_empty"), + ], + ) + def test_has_capabilities_unknown_to(self, self_kwargs, other_kwargs, expected): + assert 
_server(**self_kwargs).has_capabilities_unknown_to( + _server(**other_kwargs) + ) is expected + + +# --------------------------------------------------------------------------- +# AIDiscovery.has_changed_from +# --------------------------------------------------------------------------- + + +class TestAIDiscoveryHasChangedFrom: + def test_identical_returns_false(self): + cfg = _cfg() + a = AIDiscovery( + user=_user(), + servers=[_server(configurations=[cfg])], + discovery_duration=0.1, + ) + b = AIDiscovery( + user=_user(), + servers=[_server(configurations=[cfg])], + discovery_duration=0.2, + ) + assert a.has_changed_from(b) is False + + def test_different_user_returns_true(self): + cfg = _cfg() + a = AIDiscovery( + user=_user(hostname="a"), + servers=[_server(configurations=[cfg])], + discovery_duration=0.1, + ) + b = AIDiscovery( + user=_user(hostname="b"), + servers=[_server(configurations=[cfg])], + discovery_duration=0.1, + ) + assert a.has_changed_from(b) is True + + def test_different_configuration_keys_returns_true(self): + a = AIDiscovery( + user=_user(), + servers=[_server(configurations=[_cfg(name="x")])], + discovery_duration=0.1, + ) + b = AIDiscovery( + user=_user(), + servers=[_server(configurations=[_cfg(name="y")])], + discovery_duration=0.1, + ) + assert a.has_changed_from(b) is True + + def test_new_capabilities_unknown_to_all_candidates_returns_true(self): + cfg = _cfg() + a = AIDiscovery( + user=_user(), + servers=[ + _server( + configurations=[cfg], + tools=[MCPToolInfo(name="new_tool")], + ) + ], + discovery_duration=0.1, + ) + b = AIDiscovery( + user=_user(), + servers=[_server(configurations=[cfg])], + discovery_duration=0.1, + ) + assert a.has_changed_from(b) is True + + def test_capabilities_known_to_one_candidate_returns_false(self): + cfg = _cfg() + tool = MCPToolInfo(name="known") + a = AIDiscovery( + user=_user(), + servers=[_server(configurations=[cfg], tools=[tool])], + discovery_duration=0.1, + ) + b = AIDiscovery( + user=_user(), 
+ servers=[_server(configurations=[cfg], tools=[tool])], + discovery_duration=0.1, + ) + assert a.has_changed_from(b) is False + + def test_empty_servers_returns_false(self): + a = AIDiscovery(user=_user(), servers=[], discovery_duration=0.1) + b = AIDiscovery(user=_user(), servers=[], discovery_duration=0.1) + assert a.has_changed_from(b) is False + + +# --------------------------------------------------------------------------- +# HookPayload.scannable +# --------------------------------------------------------------------------- + + +class TestHookPayloadScannable: + def test_read_tool_existing_text_file_returns_file(self, tmp_path: Path): + f = tmp_path / "code.py" + f.write_text("secret = 'abc'") + payload = HookPayload( + event_type=EventType.PRE_TOOL_USE, + tool=Tool.READ, + content="", + identifier=str(f), + agent=Cursor(), + raw={}, + ) + assert isinstance(payload.scannable, File) + + def test_read_tool_missing_file_returns_string_scannable(self): + payload = HookPayload( + event_type=EventType.PRE_TOOL_USE, + tool=Tool.READ, + content="some content", + identifier="/nonexistent/path.txt", + agent=Cursor(), + raw={}, + ) + assert isinstance(payload.scannable, StringScannable) + + def test_read_tool_binary_file_returns_string_scannable(self, tmp_path: Path): + f = tmp_path / "image.png" + f.write_bytes(b"\x89PNG\r\n\x1a\n" + b"\x00" * 100) + payload = HookPayload( + event_type=EventType.PRE_TOOL_USE, + tool=Tool.READ, + content="", + identifier=str(f), + agent=Cursor(), + raw={}, + ) + assert isinstance(payload.scannable, StringScannable) + + def test_non_read_tool_returns_string_scannable(self): + payload = HookPayload( + event_type=EventType.PRE_TOOL_USE, + tool=Tool.BASH, + content="echo hello", + identifier="cmd", + agent=Cursor(), + raw={}, + ) + assert isinstance(payload.scannable, StringScannable) + + +# --------------------------------------------------------------------------- +# HookPayload.empty +# 
--------------------------------------------------------------------------- + + +class TestHookPayloadEmpty: + @pytest.mark.parametrize( + "content, expected", + [ + pytest.param("non-empty", False, id="non_empty_content"), + pytest.param("", True, id="empty_content"), + ], + ) + def test_empty(self, content, expected): + payload = HookPayload( + event_type=EventType.USER_PROMPT, + tool=None, + content=content, + identifier="id", + agent=Cursor(), + raw={}, + ) + assert payload.empty is expected + + +# --------------------------------------------------------------------------- +# Agent._parse_servers_block +# --------------------------------------------------------------------------- + + +class TestParseServersBlock: + def _parse(self, data, scope=Scope.USER, project=None): + return list(Cursor()._parse_servers_block(data, scope, project)) + + def test_mcp_servers_key_stdio(self): + data = { + "mcpServers": { + "myserver": { + "command": "npx", + "args": ["-y", "mcp-server"], + "env": {"KEY": "val"}, + } + } + } + configs = self._parse(data) + assert len(configs) == 1 + cfg = configs[0] + assert cfg.name == "myserver" + assert cfg.transport == Transport.STDIO + assert cfg.command == "npx" + assert cfg.args == ["-y", "mcp-server"] + assert cfg.env == {"KEY": "val"} + + def test_servers_key(self): + data = {"servers": {"s1": {"command": "node"}}} + configs = self._parse(data) + assert len(configs) == 1 + assert configs[0].name == "s1" + + def test_url_entry_detected_as_http(self): + data = {"mcpServers": {"remote": {"url": "https://example.com/mcp"}}} + configs = self._parse(data) + assert configs[0].transport == Transport.HTTP + assert configs[0].url == "https://example.com/mcp" + + def test_url_entry_with_sse_transport(self): + data = { + "mcpServers": { + "remote": {"url": "https://example.com/sse", "transport": "sse"} + } + } + configs = self._parse(data) + assert configs[0].transport == Transport.SSE + + def test_empty_block_yields_nothing(self): + assert 
self._parse({}) == [] + assert self._parse({"mcpServers": {}}) == [] diff --git a/tests/unit/verticals/ai/test_user.py b/tests/unit/verticals/ai/test_user.py new file mode 100644 index 0000000000..8b828cf745 --- /dev/null +++ b/tests/unit/verticals/ai/test_user.py @@ -0,0 +1,222 @@ +import subprocess +import uuid +from pathlib import Path +from unittest.mock import MagicMock, patch + +import pytest + +from ggshield.verticals.ai.user import ( + _get_hostname, + _get_linux_system_id, + _get_machine_id, + _get_macos_system_id, + _get_user_email, + _get_username, + _parse_wmic_uuid, + get_user_info, +) + + +# --------------------------------------------------------------------------- +# get_user_info +# --------------------------------------------------------------------------- + + +class TestGetUserInfo: + @patch("ggshield.verticals.ai.user._get_hostname", return_value="myhost") + @patch("ggshield.verticals.ai.user._get_username", return_value="myuser") + @patch("ggshield.verticals.ai.user._get_machine_id", return_value="abc-123") + @patch("ggshield.verticals.ai.user._get_user_email", return_value="me@test.com") + def test_populates_all_fields(self, *_mocks): + info = get_user_info() + assert info.hostname == "myhost" + assert info.username == "myuser" + assert info.machine_id == "abc-123" + assert info.user_email == "me@test.com" + + @patch("ggshield.verticals.ai.user._get_hostname", return_value="h") + @patch("ggshield.verticals.ai.user._get_username", return_value="u") + @patch("ggshield.verticals.ai.user._get_machine_id", return_value="generated") + @patch("ggshield.verticals.ai.user._get_user_email", return_value=None) + def test_reuses_provided_machine_id(self, *_mocks): + info = get_user_info(machine_id="provided-id") + assert info.machine_id == "provided-id" + + +# --------------------------------------------------------------------------- +# _get_hostname +# --------------------------------------------------------------------------- + + +class 
TestGetHostname: + @patch("ggshield.verticals.ai.user.sys") + @patch("ggshield.verticals.ai.user.socket.gethostname", return_value="linuxbox") + def test_linux_returns_gethostname(self, _mock_host, mock_sys): + mock_sys.platform = "linux" + assert _get_hostname() == "linuxbox" + + @patch("ggshield.verticals.ai.user.sys") + @patch("ggshield.verticals.ai.user.os.environ", {"COMPUTERNAME": "WINBOX"}) + def test_windows_prefers_computername(self, mock_sys): + mock_sys.platform = "win32" + assert _get_hostname() == "WINBOX" + + @patch("ggshield.verticals.ai.user.sys") + @patch("ggshield.verticals.ai.user.socket.gethostname", side_effect=OSError) + def test_oserror_returns_unknown(self, _mock_host, mock_sys): + mock_sys.platform = "linux" + assert _get_hostname() == "unknown" + + +# --------------------------------------------------------------------------- +# _get_username +# --------------------------------------------------------------------------- + + +class TestGetUsername: + @patch("ggshield.verticals.ai.user.getpass.getuser", return_value="alice") + def test_returns_getuser(self, _mock): + assert _get_username() == "alice" + + @patch("ggshield.verticals.ai.user.os.getlogin", return_value="bob") + @patch("ggshield.verticals.ai.user.getpass.getuser", side_effect=Exception) + def test_falls_back_to_getlogin(self, *_mocks): + assert _get_username() == "bob" + + @patch("ggshield.verticals.ai.user.os.getlogin", side_effect=Exception) + @patch("ggshield.verticals.ai.user.getpass.getuser", side_effect=Exception) + def test_returns_unknown_when_both_fail(self, *_mocks): + assert _get_username() == "unknown" + + +# --------------------------------------------------------------------------- +# _get_user_email +# --------------------------------------------------------------------------- + + +class TestGetUserEmail: + @pytest.mark.parametrize( + "run_return, expected", + [ + pytest.param( + MagicMock(returncode=0, stdout="me@example.com\n"), + "me@example.com", + 
id="valid_email", + ), + pytest.param( + MagicMock(returncode=1, stdout=""), + None, + id="git_failure", + ), + pytest.param( + MagicMock(returncode=0, stdout=" \n"), + None, + id="empty_output", + ), + ], + ) + @patch("ggshield.verticals.ai.user.subprocess.run") + def test_get_user_email(self, mock_run, run_return, expected): + mock_run.return_value = run_return + assert _get_user_email() == expected + + @patch( + "ggshield.verticals.ai.user.subprocess.run", + side_effect=OSError("git not found"), + ) + def test_returns_none_on_oserror(self, _mock): + assert _get_user_email() is None + + +# --------------------------------------------------------------------------- +# _get_machine_id +# --------------------------------------------------------------------------- + + +class TestGetMachineId: + def test_returns_satori_cached_id(self, tmp_path): + satori_dir = tmp_path / ".satori" + satori_dir.mkdir() + (satori_dir / "machine_id").write_text("cached-uuid\n") + with patch( + "ggshield.verticals.ai.user.get_user_home_dir", return_value=tmp_path + ): + assert _get_machine_id() == "cached-uuid" + + @patch("ggshield.verticals.ai.user.platform.system", return_value="Linux") + @patch( + "ggshield.verticals.ai.user._get_linux_system_id", + return_value="linux-machine-id", + ) + def test_linux_reads_system_id(self, _mock_linux, _mock_platform, tmp_path): + with patch( + "ggshield.verticals.ai.user.get_user_home_dir", return_value=tmp_path + ): + assert _get_machine_id() == "linux-machine-id" + + @patch("ggshield.verticals.ai.user.platform.system", return_value="Darwin") + @patch( + "ggshield.verticals.ai.user._get_macos_system_id", + return_value="mac-uuid-123", + ) + def test_macos_parses_ioreg(self, _mock_mac, _mock_platform, tmp_path): + with patch( + "ggshield.verticals.ai.user.get_user_home_dir", return_value=tmp_path + ): + assert _get_machine_id() == "mac-uuid-123" + + @patch("ggshield.verticals.ai.user.platform.system", return_value="Linux") + 
@patch("ggshield.verticals.ai.user._get_linux_system_id", return_value=None) + @patch("ggshield.verticals.ai.user.uuid.uuid4") + def test_generates_uuid_when_all_fail( + self, mock_uuid4, _mock_linux, _mock_platform, tmp_path + ): + fixed_uuid = uuid.UUID("12345678-1234-5678-1234-567812345678") + mock_uuid4.return_value = fixed_uuid + with patch( + "ggshield.verticals.ai.user.get_user_home_dir", return_value=tmp_path + ): + result = _get_machine_id() + assert result == str(fixed_uuid) + assert (tmp_path / ".satori" / "machine_id").read_text().strip() == str( + fixed_uuid + ) + + @patch("ggshield.verticals.ai.user.platform.system", return_value="Linux") + @patch("ggshield.verticals.ai.user._get_linux_system_id", return_value=None) + @patch("ggshield.verticals.ai.user.uuid.uuid4") + def test_persistence_failure_still_returns_uuid( + self, mock_uuid4, _mock_linux, _mock_platform, tmp_path + ): + fixed_uuid = uuid.UUID("aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee") + mock_uuid4.return_value = fixed_uuid + # Make .satori a file so mkdir fails + (tmp_path / ".satori").write_text("block") + with patch( + "ggshield.verticals.ai.user.get_user_home_dir", return_value=tmp_path + ): + result = _get_machine_id() + assert result == str(fixed_uuid) + + +# --------------------------------------------------------------------------- +# _parse_wmic_uuid +# --------------------------------------------------------------------------- + + +class TestParseWmicUuid: + @pytest.mark.parametrize( + "stdout, expected", + [ + pytest.param( + "UUID\n4C4C4544-0044-4810-8057-B5C04F4A5331\n", + "4c4c4544-0044-4810-8057-b5c04f4a5331", + id="valid_uuid", + ), + pytest.param("UUID\n", None, id="header_only"), + pytest.param("UUID\nnot-a-uuid\n", None, id="invalid_line"), + pytest.param("", None, id="empty_string"), + ], + ) + def test_parse_wmic_uuid(self, stdout, expected): + assert _parse_wmic_uuid(stdout) == expected diff --git a/uv.lock b/uv.lock index 4e3307ee66..1a0607d1f6 100644 --- a/uv.lock +++ 
b/uv.lock @@ -811,6 +811,7 @@ dependencies = [ { name = "oauthlib" }, { name = "packaging" }, { name = "platformdirs" }, + { name = "pydantic" }, { name = "pygitguardian" }, { name = "pyjwt" }, { name = "python-dotenv" }, @@ -877,6 +878,7 @@ requires-dist = [ { name = "oauthlib", specifier = "~=3.2.1" }, { name = "packaging", specifier = ">=22.0" }, { name = "platformdirs", specifier = "~=3.0.0" }, + { name = "pydantic", specifier = ">=2.11.10" }, { name = "pygitguardian", specifier = "~=1.28.0" }, { name = "pyjwt", specifier = "~=2.6.0" }, { name = "python-dotenv", specifier = "~=0.21.0" },