diff --git a/.importlinter b/.importlinter index a56c7d3c7e..9308e69e03 100644 --- a/.importlinter +++ b/.importlinter @@ -9,8 +9,8 @@ name = ggshield-layers type = layers layers = ggshield.__main__ - ggshield.cmd.auth | ggshield.cmd.config | ggshield.cmd.hmsl | ggshield.cmd.honeytoken | ggshield.cmd.install | ggshield.cmd.plugin | ggshield.cmd.quota | ggshield.cmd.secret | ggshield.cmd.status | ggshield.cmd.utils - ggshield.verticals.auth | ggshield.verticals.hmsl | ggshield.verticals.secret + ggshield.cmd.ai | ggshield.cmd.auth | ggshield.cmd.config | ggshield.cmd.hmsl | ggshield.cmd.honeytoken | ggshield.cmd.install | ggshield.cmd.plugin | ggshield.cmd.quota | ggshield.cmd.secret | ggshield.cmd.status | ggshield.cmd.utils + ggshield.verticals.ai | ggshield.verticals.auth | ggshield.verticals.hmsl | ggshield.verticals.secret ggshield.core click | ggshield.utils | pygitguardian ignore_imports = @@ -22,6 +22,7 @@ unmatched_ignore_imports_alerting = warn name = verticals-cmd-transversals type = forbidden source_modules = + ggshield.cmd.ai ggshield.cmd.auth ggshield.cmd.config ggshield.cmd.hmsl @@ -33,10 +34,13 @@ source_modules = ggshield.cmd.status ggshield.cmd.utils forbidden_modules = + ggshield.verticals.ai ggshield.verticals.auth ggshield.verticals.hmsl ggshield.verticals.secret ignore_imports = + ggshield.cmd.ai.** -> ggshield.verticals.ai + ggshield.cmd.ai.** -> ggshield.verticals.ai.** ggshield.cmd.auth.** -> ggshield.verticals.auth ggshield.cmd.auth.** -> ggshield.verticals.auth.** ggshield.cmd.auth.** -> ggshield.verticals.hmsl.** @@ -46,7 +50,7 @@ ignore_imports = ggshield.cmd.hmsl.** -> ggshield.verticals.hmsl.** ggshield.cmd.honeytoken.** -> ggshield.verticals.honeytoken ggshield.cmd.honeytoken.** -> ggshield.verticals.honeytoken.** - ggshield.cmd.install -> ggshield.verticals.secret.ai_hook + ggshield.cmd.install -> ggshield.verticals.ai.installation ggshield.cmd.install.** -> ggshield.verticals.install ggshield.cmd.install.** -> 
# Click group for the AI-related commands. `discover` is the only subcommand
# registered on it here.
@click.group(commands={"discover": discover_cmd})
@add_common_options()
def ai_group(**kwargs: Any) -> None:
    """Commands to work with MCP (Model Context Protocol) servers."""
    # Intentionally no body: the group only dispatches to its subcommands.
    # The docstring above doubles as the help text click displays for the group.
@click.command(name="discover")
@click.option(
    "--json",
    "use_json",
    is_flag=True,
    default=False,
    help="Output as JSON",
)
@add_common_options()
@click.pass_context
def discover_cmd(
    ctx: click.Context,
    use_json: bool,
    **kwargs: Any,
) -> None:
    """
    Discover MCP servers and their configuration.

    Parses MCP configuration files from supported assistants, prints the
    result and uploads it to GitGuardian when a client can be created.

    \b
    Examples:
      ggshield ai discover
      ggshield ai discover --json
    """
    # NOTE: the docstring is the help text shown by click. The lone "\b" line
    # keeps click from re-wrapping the example block into a single paragraph.

    config = discover_ai_configuration()

    # Always show the local result first, so the command stays useful even
    # when the upload below is skipped or fails.
    if use_json:
        click.echo(json.dumps(config.model_dump(mode="json"), indent=2))
    else:
        print(config)

    ctx_obj = ContextObj.get(ctx)
    try:
        client = create_client_from_config(ctx_obj.config)
    except (APIKeyCheckError, UnknownInstanceError) as exc:
        # Not being authenticated is not an error for this command: warn and stop.
        ui.display_warning(
            f"Skipping upload of AI discovery to GitGuardian ({exc}). "
            "Authenticate with `ggshield auth login` to enable upload."
        )
        return

    try:
        # The API answers with an updated discovery; that version is the one cached.
        config = submit_ai_discovery(client, config)
        save_discovery_cache(config)
    except Exception as exc:
        # Top-level boundary of the command: the upload is best effort, so any
        # failure is reported as a warning instead of aborting.
        ui.display_warning(f"Could not upload AI discovery to GitGuardian: {exc}")
""" - if hook_type in AI_FLAVORS: + if hook_type in AGENTS: return install_hooks(name=hook_type, mode=mode, force=force) return_code = ( diff --git a/ggshield/cmd/secret/scan/ai_hook.py b/ggshield/cmd/secret/scan/ai_hook.py index c53ab6a716..3689b00131 100644 --- a/ggshield/cmd/secret/scan/ai_hook.py +++ b/ggshield/cmd/secret/scan/ai_hook.py @@ -10,9 +10,11 @@ from ggshield.core import ui from ggshield.core.client import create_client_from_config from ggshield.core.scan import ScanContext, ScanMode +from ggshield.verticals.ai.hooks import AIHookScanner from ggshield.verticals.secret import SecretScanner -from ggshield.verticals.secret.ai_hook import AIHookScanner -from ggshield.verticals.secret.ai_hook.models import MAX_READ_SIZE + + +MAX_READ_SIZE = 1024 * 1024 * 10 # We restrict stdin read to 10MB @click.command() diff --git a/ggshield/core/scan/__init__.py b/ggshield/core/scan/__init__.py index b37b6197f6..e2c370a110 100644 --- a/ggshield/core/scan/__init__.py +++ b/ggshield/core/scan/__init__.py @@ -3,6 +3,7 @@ from .scan_context import ScanContext from .scan_mode import ScanMode from .scannable import DecodeError, NonSeekableFileError, Scannable, StringScannable +from .scanner import ResultsProtocol, ScannerProtocol, SecretProtocol __all__ = [ @@ -11,8 +12,11 @@ "DecodeError", "File", "NonSeekableFileError", + "ResultsProtocol", "ScanContext", "ScanMode", "Scannable", + "ScannerProtocol", + "SecretProtocol", "StringScannable", ] diff --git a/ggshield/core/scan/scanner.py b/ggshield/core/scan/scanner.py new file mode 100644 index 0000000000..9f963669f6 --- /dev/null +++ b/ggshield/core/scan/scanner.py @@ -0,0 +1,50 @@ +""" +Protocols for SecretScanner and its results, +so that other verticals can use the scanner if they are provided one. +""" + +from collections.abc import Sequence +from typing import Iterable, Optional, Protocol + +from pygitguardian.models import Match + +from ggshield.core.scanner_ui import ScannerUI + +from . 
class SecretProtocol(Protocol):
    """Structural type describing a secret found by a scanner.

    Only the attributes needed by other verticals are declared. They are
    declared as read-only properties, so any object exposing these three
    attributes satisfies the protocol.
    """

    @property
    def detector_display_name(self) -> str: ...

    @property
    def validity(self) -> str: ...

    @property
    def matches(self) -> Sequence[Match]: ...


class ResultProtocol(Protocol):
    """Structural type for one scan result: exposes the secrets it holds."""

    @property
    def secrets(self) -> Sequence[SecretProtocol]: ...


class ResultsProtocol(Protocol):
    """Structural type for what a scan returns: a sequence of results."""

    @property
    def results(self) -> Sequence[ResultProtocol]: ...


class ScannerProtocol(Protocol):
    """Protocol for scanners.

    Any object with a compatible ``scan`` method can be handed to code that
    only depends on this protocol.
    """

    def scan(
        self,
        files: Iterable[Scannable],
        scanner_ui: ScannerUI,
        scan_threads: Optional[int] = None,
    ) -> ResultsProtocol: ...
class Claude(Agent):
    """Behavior specific to Claude Code.

    Covers the hook response format, the hook settings file layout, and the
    places where MCP servers and project directories are looked up.
    """

    @property
    def name(self) -> str:
        """Identifier of the agent."""
        return "claude-code"

    @property
    def display_name(self) -> str:
        """Human-readable name of the agent."""
        return "Claude Code"

    @property
    def config_folder(self) -> Path:
        """User-level configuration folder (``~/.claude``)."""
        return get_user_home_dir() / ".claude"

    def output_result(self, result: HookResult) -> int:
        """Print the hook response as JSON on stdout and return the exit code.

        The shape of the response depends on the event type of the payload.
        The exit code is always 0; the decision is carried by the JSON body.
        """
        response: Dict[str, Any] = {}
        if result.block:
            if result.payload.event_type in [
                EventType.USER_PROMPT,
                EventType.POST_TOOL_USE,
            ]:
                response["decision"] = "block"
                response["reason"] = result.message
                response["additionalContext"] = result.message
            elif result.payload.event_type == EventType.PRE_TOOL_USE:
                response["hookSpecificOutput"] = {
                    "hookEventName": "PreToolUse",
                    "permissionDecision": "deny",
                    "permissionDecisionReason": result.message,
                }
            else:
                # Should not happen, but just in case use Claude's "universal" fields.
                response = {
                    "continue": False,
                    "stopReason": result.message,
                }
        else:
            response["continue"] = True

        click.echo(json.dumps(response))
        # We don't use the return 2 convention to make sure our JSON output is read.
        return 0

    @property
    def settings_path(self) -> Path:
        """Relative path of the hook settings file."""
        return Path(".claude") / "settings.json"

    @property
    def settings_template(self) -> Dict[str, Any]:
        """Template of the hook settings: one catch-all entry per hooked event.

        A fresh structure is built on every access, so callers may mutate it.
        """

        def catch_all_entry() -> Dict[str, Any]:
            # Same entry for every event: match every tool, run one command.
            return {
                "matcher": ".*",
                "hooks": [
                    {
                        "type": "command",
                        "command": "",
                    }
                ],
            }

        return {
            "hooks": {
                event: [catch_all_entry()]
                for event in ("PreToolUse", "PostToolUse", "UserPromptSubmit")
            }
        }

    def settings_locate(
        self, candidates: List[Dict[str, Any]], template: Dict[str, Any]
    ) -> Optional[Dict[str, Any]]:
        """Return the entry of ``candidates`` corresponding to ``template``.

        Returns None when no candidate matches.
        """
        # We have two kinds of lists: at the root of each hook (with a matcher)
        # and in each hook (with a list of commands).
        if "matcher" in template:
            for obj in candidates:
                if obj.get("matcher") == template["matcher"]:
                    return obj
            return None
        for obj in candidates:
            command = obj.get("command", "")
            # NOTE(review): `"" in command` is true for every string, so the
            # first candidate always matches - confirm the intended placeholder.
            if "ggshield" in command or "" in command:
                return obj
        return None

    def project_mcp_file(self, directory: Path) -> Path:
        """Return the path of the project-level MCP file inside ``directory``."""
        return directory / ".mcp.json"

    def _get_user_mcp_configurations(self) -> Iterator[MCPConfiguration]:
        """Look into ~/.claude.json for both user-level and project-level MCP server entries."""
        # Load config file
        filepath = get_user_home_dir() / ".claude.json"
        if not (data := self._load_json_file(filepath)):
            return

        # User-level mcpServers
        yield from self._parse_servers_block(data, Scope.USER, None)

        # Per-project entries in projects dict
        projects = data.get("projects", {})
        if not isinstance(projects, dict):
            return
        for project_key, project_data in projects.items():
            if not isinstance(project_data, dict):
                continue
            yield from self._parse_servers_block(
                project_data, Scope.USER, Path(project_key)
            )

    def discover_project_directories(self) -> Iterator[Path]:
        """Discover project directories by scraping config files.

        Reads the ``project`` field of each entry of ``history.jsonl`` in the
        configuration folder and yields the ones that are existing directories.
        """
        history_file = self.config_folder / "history.jsonl"
        projects = set()
        for line in self._load_jsonl_file(history_file):
            # Skip entries that are not JSON objects instead of failing on them.
            if isinstance(line, dict) and "project" in line:
                projects.add(Path(line["project"]))
        for project in projects:
            if project.is_dir():
                yield project.resolve()

    def parse_mcp_activity(
        self, payload: HookPayload, ai_config: AIDiscovery
    ) -> MCPActivityRequest:
        """Parse the MCP activity from an MCP hook payload."""

        # Claude Code's hook tool name is "mcp__{server}__{tool}"
        raw_tool_name: str = payload.raw.get("tool_name", "")
        parts = raw_tool_name.split("__")
        # The server name can be anything, but we assume no MCP tool has a "__" in its name
        tool = parts[-1]
        server_cfg_name = "__".join(parts[1:-1])

        # Lookup the server name based on its configuration name, stopping at
        # the first match. Fallback to the name found in the hook if none matches.
        server_name = next(
            (
                server.name
                for server in ai_config.servers
                for configuration in server.configurations
                if _mangle_server_name(configuration.name) == server_cfg_name
            ),
            server_cfg_name,
        )

        return MCPActivityRequest(
            user=ai_config.user,
            tool=tool,
            server=server_name,
            agent=self.name,
            model="",
            cwd=Path(payload.raw.get("cwd", "")),
            input=payload.raw.get("tool_input", {}),
        )


# Every character that is not an ASCII letter, a digit or "-" gets replaced.
MANGLING_PATTERN = re.compile(r"[^A-Za-z0-9-]")


def _mangle_server_name(name: str) -> str:
    """Mangle a server name in the same way Claude Code does."""
    return MANGLING_PATTERN.sub("_", name)
class Copilot(Claude):
    """Behavior specific to Copilot Chat.

    Inherits most of its behavior from Claude Code.
    """

    @property
    def name(self) -> str:
        """Identifier of the agent."""
        return "copilot"

    @property
    def display_name(self) -> str:
        """Human-readable name of the agent."""
        return "Copilot Chat"

    @property
    def config_folder(self) -> Path:
        """User-level configuration folder (``~/.config/Code/User``)."""
        return get_user_home_dir() / ".config" / "Code" / "User"

    def output_result(self, result: HookResult) -> int:
        """Print the hook response as JSON on stdout and return the exit code.

        The exit code is always 0; the decision is carried by the JSON body.
        """
        response: Dict[str, Any] = {}
        if result.block:
            if result.payload.event_type == EventType.PRE_TOOL_USE:
                response["hookSpecificOutput"] = {
                    "hookEventName": "PreToolUse",
                    "permissionDecision": "deny",
                    "permissionDecisionReason": result.message,
                }
            elif result.payload.event_type == EventType.POST_TOOL_USE:
                response["decision"] = "block"
                response["reason"] = result.message
            else:
                response["continue"] = False
                response["stopReason"] = result.message
        else:
            response["continue"] = True

        click.echo(json.dumps(response))
        return 0

    @property
    def settings_path(self) -> Path:
        """Relative path of the hook settings file."""
        return Path(".github") / "hooks" / "hooks.json"

    def project_mcp_file(self, directory: Path) -> Path:
        """Return the path of the project-level MCP file inside ``directory``."""
        return directory / ".vscode" / "mcp.json"

    def discover_project_directories(self) -> Iterator[Path]:
        """Yield the existing directories listed in the workspace storage files."""
        # Try to parse workspaces settings.
        for file in self.config_folder.glob("workspaceStorage/*/workspace.json"):
            if (data := self._load_json_file(file)) and "folder" in data:
                path = Path(data["folder"].removeprefix("file://"))
                if path.is_dir():
                    yield path.resolve()

    def parse_mcp_activity(
        self, payload: HookPayload, ai_config: AIDiscovery
    ) -> MCPActivityRequest:
        """Parse the MCP activity from an MCP hook payload.

        The tool name cannot be split blindly on "_" because both the server
        and the tool may contain underscores. The known configuration names
        are used to find where the server name ends.
        """

        # Copilot's hook tool name is "mcp_{server}_{tool}"
        # which is unfortunate because a lot of tools have a "_" in their name.
        raw_tool_name: str = payload.raw.get("tool_name", "")
        remainder = raw_tool_name.removeprefix("mcp_")

        # Known configurations whose name is a prefix of the remainder.
        # NOTE(review): names are compared as-is; Claude Code mangles them -
        # confirm whether Copilot does something similar.
        candidates = [
            (configuration.name, server.name)
            for server in ai_config.servers
            for configuration in server.configurations
            if configuration.name and remainder.startswith(f"{configuration.name}_")
        ]
        if candidates:
            # Longest prefix wins, so that "foo_bar" is preferred over "foo".
            server_cfg_name, server_name = max(candidates, key=lambda c: len(c[0]))
            tool = remainder[len(server_cfg_name) + 1 :]
        else:
            # Unknown server: best effort, assume its name contains no "_".
            server_name, separator, tool = remainder.partition("_")
            if not separator:
                # No server part at all: everything is the tool name.
                server_name, tool = "", remainder

        return MCPActivityRequest(
            user=ai_config.user,
            tool=tool,
            server=server_name,
            agent=self.name,
            model="",
            cwd=Path(payload.raw.get("cwd", "")),
            input=payload.raw.get("tool_input", {}),
        )
class Cursor(Agent):
    """Behavior specific to Cursor."""

    @property
    def name(self) -> str:
        """Identifier of the agent."""
        return "cursor"

    @property
    def display_name(self) -> str:
        """Human-readable name of the agent."""
        return "Cursor"

    @property
    def config_folder(self) -> Path:
        """User-level configuration folder (``~/.cursor``)."""
        return get_user_home_dir() / ".cursor"

    def output_result(self, result: HookResult) -> int:
        """Print the hook response as JSON on stdout and return the exit code.

        For known event types the exit code is 0 and the decision is in the
        JSON body. For unknown event types an empty object is printed and the
        exit code is 2 when the result blocks.
        """
        response: Dict[str, Any] = {}
        if result.payload.event_type == EventType.USER_PROMPT:
            response["continue"] = not result.block
            response["user_message"] = result.message
        elif result.payload.event_type == EventType.PRE_TOOL_USE:
            response["permission"] = "deny" if result.block else "allow"
            response["user_message"] = result.message
            response["agent_message"] = result.message
        elif result.payload.event_type == EventType.POST_TOOL_USE:
            pass  # Nothing to do here
        else:
            # Should not happen, but just in case
            click.echo("{}")
            return 2 if result.block else 0

        click.echo(json.dumps(response))
        # We don't use the return 2 convention to make sure our JSON output is read.
        return 0

    @property
    def settings_path(self) -> Path:
        """Relative path of the hook settings file."""
        return Path(".cursor") / "hooks.json"

    @property
    def settings_template(self) -> Dict[str, Any]:
        """Template of the hook settings: one command entry per hooked event."""
        return {
            "version": 1,
            "hooks": {
                "beforeSubmitPrompt": [{"command": ""}],
                "preToolUse": [{"command": ""}],
                "postToolUse": [{"command": ""}],
            },
        }

    def settings_locate(
        self, candidates: List[Dict[str, Any]], template: Dict[str, Any]
    ) -> Optional[Dict[str, Any]]:
        """Return the entry of ``candidates`` holding our command, or None."""
        # We only have one kind of lists: in each hook. Simply look for "ggshield" or "" in the command.
        for obj in candidates:
            command = obj.get("command", "")
            # NOTE(review): `"" in command` is true for every string, so the
            # first candidate always matches - confirm the intended placeholder.
            if "ggshield" in command or "" in command:
                return obj
        return None

    def project_mcp_file(self, directory: Path) -> Path:
        """Return the path of the project-level MCP file inside ``directory``."""
        return directory / ".cursor" / "mcp.json"

    def discover_project_directories(self) -> Iterator[Path]:
        """Yield the existing directories listed in the workspace storage files."""
        # Because Cursor is based on VS Code, we can reuse the same logic as Copilot.
        user_folder = get_user_home_dir() / ".config" / "Cursor" / "User"
        for file in user_folder.glob("workspaceStorage/*/workspace.json"):
            if (data := self._load_json_file(file)) and "folder" in data:
                path = Path(data["folder"].removeprefix("file://"))
                if path.is_dir():
                    yield path.resolve()

    def discover_capabilities(self, server: MCPServer) -> bool:
        """Fill ``server`` with the tools, resources and prompts stored by Cursor.

        Returns True as soon as one configuration yielded at least one
        capability, False when nothing could be found.
        """
        # General Cursor strategy:
        # For each project where Cursor was used, it created a folder with the project name
        # in its configuration folder. Inside that folder, it stores metadata for every
        # MCP server available in that project.
        for configuration in server.configurations:
            # Look for Cursor configurations
            if configuration.agent != self.name:
                continue
            # We need a folder. Note: this also works for user-level configurations.
            # as Cursor will have a `home-` "project".
            if configuration.project is None:
                continue

            # Lookup where Cursor stores the capabilities for the given project.
            mangled = configuration.project.as_posix().replace("/", "-").lstrip("-")
            folder = self.config_folder / "projects" / mangled / "mcps"
            if not folder.exists():
                continue
            # Look for a SERVER_METADATA.json file with the expected name.
            # (each subfolder corresponds to a different MCP server)
            for file in folder.glob("*/SERVER_METADATA.json"):
                metadata = self._load_json_file(file)
                if metadata and metadata.get("serverName") == configuration.name:
                    # Found it! Update the folder
                    folder = file.parent
                    break
            else:
                # We didn't find our MCP server's metadata. Try next configuration.
                continue

            # If we reach this code, we found our MCP server's metadata folder.
            # Hopefully it is connected. If not, Cursor creates a STATUS.md file.
            if (folder / "STATUS.md").exists():
                # Don't go further, we may risk discovering only an "mcp_auth" tool
                # whereas the MCP server may be properly connected in another project.
                continue

            filled = False
            # Tools
            for file in folder.glob("tools/*.json"):
                tool = self._load_json_file(file)
                if not isinstance(tool, dict) or "name" not in tool:
                    continue
                server.tools.append(
                    MCPToolInfo(
                        name=tool["name"],
                        description=tool.get("description", ""),
                        arguments=_parse_tool_arguments(tool.get("arguments")),
                    )
                )
                filled = True
            # Resources
            for file in folder.glob("resources/*.json"):
                resource = self._load_json_file(file)
                if not isinstance(resource, dict) or "uri" not in resource:
                    continue
                server.resources.append(
                    MCPResourceInfo(
                        uri=resource["uri"],
                        name=resource.get("name", ""),
                        description=resource.get("description", ""),
                        mime_type=resource.get("mimeType", ""),
                    )
                )
                filled = True
            # Prompts
            for file in folder.glob("prompts/*.json"):
                prompt = self._load_json_file(file)
                if not isinstance(prompt, dict) or "name" not in prompt:
                    continue
                server.prompts.append(
                    MCPPromptInfo(
                        name=prompt["name"], description=prompt.get("description", "")
                    )
                )
                filled = True
            if filled:
                # Discovery done. Early return.
                return True

        return False

    def parse_mcp_activity(
        self, payload: HookPayload, ai_config: AIDiscovery
    ) -> MCPActivityRequest:
        """Parse the MCP activity from an MCP hook payload."""

        # Cursor only sends the MCP tool, not the server.
        # Fortunately, we should have been able to discover the tools earlier.

        tools_to_server = {}
        for server in ai_config.servers:
            for tool in server.tools:
                # Hopefully we won't have duplicates
                tools_to_server[tool.name] = server.name

        raw_tool_name: str = payload.raw.get("tool_name", "")
        tool_name = raw_tool_name.removeprefix("MCP:")
        server_name = tools_to_server.get(tool_name, "")

        # A missing, null or empty "workspace_roots" falls back to an empty
        # path instead of failing on the [0] lookup.
        workspace_roots = payload.raw.get("workspace_roots") or [""]

        return MCPActivityRequest(
            user=ai_config.user,
            tool=tool_name,
            server=server_name,
            agent=self.name,
            model=payload.raw.get("model", ""),
            cwd=Path(workspace_roots[0]),
            input=payload.raw.get("tool_input", {}),
        )


def _parse_tool_arguments(
    schema: Optional[Dict[str, Any]],
) -> Optional[List[MCPArgumentInfo]]:
    """Parse a JSON-Schema ``arguments`` object into a list of MCPArgumentInfo.

    The schema is expected to follow the standard MCP tool descriptor format::

        {"type": "object", "properties": {...}, "required": [...]}

    Returns None when the schema is not usable or declares no argument.
    """
    if not isinstance(schema, dict):
        return None
    properties = schema.get("properties")
    if not isinstance(properties, dict):
        return None
    # Tolerate a null or malformed "required" entry: only a list of strings counts.
    required = schema.get("required")
    required_set = (
        {name for name in required if isinstance(name, str)}
        if isinstance(required, list)
        else set()
    )
    arguments: List[MCPArgumentInfo] = []
    for name, prop in properties.items():
        if not isinstance(prop, dict):
            continue
        arguments.append(
            MCPArgumentInfo(
                name=name,
                type=prop.get("type", "string"),
                description=prop.get("description"),
                required=name in required_set,
            )
        )
    return arguments or None
def load_json_file(path: Path) -> Union[Dict[str, Any], List[Any]]:
    """Load a JSON file, returning an empty dict when it cannot be used.

    An unreadable file, invalid UTF-8, invalid JSON, or a top-level value that
    is neither an object nor an array all yield ``{}``.
    """
    try:
        data = json.loads(path.read_text(encoding="utf-8"))
    except (OSError, ValueError):
        # ValueError covers both json.JSONDecodeError and UnicodeDecodeError.
        return {}
    return data if isinstance(data, (dict, list)) else {}


def save_json_file(path: Path, data: Union[Dict[str, Any], List[Any]]) -> None:
    """Write ``data`` as indented JSON, creating parent folders as needed.

    This is best effort: filesystem errors are deliberately ignored.
    """
    try:
        path.parent.mkdir(parents=True, exist_ok=True)
        path.write_text(json.dumps(data, indent=4), encoding="utf-8")
    except OSError:
        pass


def load_mcp_config(workspace_roots: List[str]) -> Dict[str, Any]:
    """Return the first usable Cursor MCP configuration.

    Each workspace's ``.cursor/mcp.json`` is tried in order, then the one in
    the user's home directory. Returns ``{}`` when none holds a JSON object.
    """
    candidates = [Path(workspace) / ".cursor" / "mcp.json" for workspace in workspace_roots]
    candidates.append(get_user_home_dir() / ".cursor" / "mcp.json")
    for candidate in candidates:
        if not candidate.exists():
            continue
        config = load_json_file(candidate)
        if isinstance(config, dict):
            return config
    return {}


def _host_from_url(url: str) -> str:
    """Return what precedes the first "/" once the http(s) scheme is removed."""
    return url.replace("https://", "").replace("http://", "").split("/")[0]


def extract_host_from_config(server_config: Optional[Dict[str, Any]]) -> Optional[str]:
    """Guess the host targeted by an MCP server from its configuration.

    Looks, in order, at the ``CLICKHOUSE_HOST`` and ``GITLAB_API_URL``
    environment entries, then at the first http(s) URL found in ``args``.
    Returns None when nothing is found.
    """
    if not server_config:
        return None

    # "env" and "args" may be absent or null in hand-written configurations.
    env_vars = server_config.get("env")
    if not isinstance(env_vars, dict):
        env_vars = {}

    if "CLICKHOUSE_HOST" in env_vars:
        return env_vars["CLICKHOUSE_HOST"]

    gitlab_url = env_vars.get("GITLAB_API_URL")
    if isinstance(gitlab_url, str):
        return _host_from_url(gitlab_url)

    remote_url = get_mcp_remote_url(server_config)
    return _host_from_url(remote_url) if remote_url is not None else None


def get_mcp_remote_url(server_config: Dict[str, Any]) -> Optional[str]:
    """Return the first ``args`` entry that is an http(s) URL, or None."""
    args = server_config.get("args")
    if not isinstance(args, (list, tuple)):
        return None
    for arg in args:
        if isinstance(arg, str) and arg.startswith(("http://", "https://")):
            return arg
    return None
AI_DISCOVERY_CACHE_FILENAME = "ai_discovery.json"
AI_DISCOVERY_SUBMIT_ENDPOINT = "nhi/ai/discovery"


def refresh_and_maybe_submit_discovery(client: GGClient) -> AIDiscovery:
    """Always run discovery, compare with cache, submit only if changed.

    Returns the cached discovery when nothing changed, otherwise the fresh one
    (updated by the API when the submission succeeded).
    """
    # Local import: keeps this function self-contained.
    import logging

    cached = load_discovery_cache()
    # If we already have a machine id, reuse it.
    machine_id = cached.user.machine_id if cached else None
    discovery = discover_ai_configuration(machine_id=machine_id)

    # Nothing changed, keep what we have.
    if cached and not discovery.has_changed_from(cached):
        return cached

    try:
        # Get the updated version of the discovery, filled with data from the API.
        discovery = submit_ai_discovery(client, discovery)
        save_discovery_cache(discovery)
    except Exception as exc:
        # Submission is best effort and must not break the caller, but the
        # failure is at least traceable. TODO(paul): handle error
        logging.getLogger(__name__).debug("AI discovery submission failed: %s", exc)

    return discovery


def discover_ai_configuration(machine_id: Optional[str] = None) -> AIDiscovery:
    """
    Discover configurations from all supported assistants.

    Args:
        machine_id: machine identifier forwarded to ``get_user_info``;
            None when no identifier is known yet.
    """
    start_time = monotonic()
    mcp_configurations: List[MCPConfiguration] = []

    # Discovered project directories (the current directory is always included)
    projects = {Path.cwd().resolve()}
    for agent in AGENTS.values():
        projects.update(agent.discover_project_directories())

    # Discover MCP configurations
    for agent in AGENTS.values():
        mcp_configurations.extend(agent.discover_mcp_configurations(projects))

    # Merge MCP configurations into servers
    servers = _merge_mcp_configurations(mcp_configurations)

    # Try to find the servers' capabilities
    for server in servers:
        for agent in AGENTS.values():
            if agent.discover_capabilities(server):
                # Discovery succeeded for this server, no need to ask other agents.
                break

    # Add user information
    user = get_user_info(machine_id=machine_id)
    discovery_duration = monotonic() - start_time
    return AIDiscovery(
        user=user,
        servers=servers,
        discovery_duration=discovery_duration,
    )


def save_discovery_cache(config: AIDiscovery) -> None:
    """
    Save the discovery to the cache file.
    """
    cache_path = get_cache_dir() / AI_DISCOVERY_CACHE_FILENAME
    save_json_file(cache_path, config.model_dump(mode="json"))


def submit_ai_discovery(client: GGClient, discovery: AIDiscovery) -> AIDiscovery:
    """
    Send discovery results to the GitGuardian API.

    Returns the updated discovery.

    Raises:
        RuntimeError: if the API does not answer with a 200 status code.
    """
    response = client.post(
        endpoint=AI_DISCOVERY_SUBMIT_ENDPOINT,
        data=discovery.model_dump(mode="json"),
    )
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if response.status_code != 200:
        raise RuntimeError(
            f"AI discovery submission failed ({response.status_code}): {response.text}"
        )
    return AIDiscovery.model_validate_json(response.text)


def load_discovery_cache() -> Optional[AIDiscovery]:
    """Load discovery cache if it exists.

    Returns None if the cache does not exist, cannot be read, or does not
    hold a valid discovery.
    """
    cache_path = get_cache_dir() / AI_DISCOVERY_CACHE_FILENAME
    if not cache_path.exists():
        return None
    try:
        return AIDiscovery.model_validate_json(cache_path.read_text())
    except (OSError, ValueError):
        # ValueError covers json.JSONDecodeError and, assuming AIDiscovery is a
        # pydantic model, the ValidationError raised for invalid content.
        return None


def _merge_mcp_configurations(
    mcp_configurations: List[MCPConfiguration],
) -> List[MCPServer]:
    """Merge MCP configurations into servers.

    This is a first naive deduplication of MCP configurations based on their name.
    Deduplicating is useful to avoid discovering capabilities for the same server multiple times.
    We expect it to be improved by GIM later.
    """
    servers: Dict[str, List[MCPConfiguration]] = defaultdict(list)
    for configuration in mcp_configurations:
        servers[configuration.name].append(configuration)

    return [
        MCPServer(name=name, configurations=configurations)
        for name, configurations in servers.items()
    ]
+HOOK_NAME_TO_EVENT_TYPE = { + "userpromptsubmit": EventType.USER_PROMPT, + "beforesubmitprompt": EventType.USER_PROMPT, + "pretooluse": EventType.PRE_TOOL_USE, + "posttooluse": EventType.POST_TOOL_USE, +} + +TOOL_NAME_TO_TOOL = { + "shell": Tool.BASH, # Cursor + "bash": Tool.BASH, # Claude Code + "run_in_terminal": Tool.BASH, # Copilot + "read": Tool.READ, # Claude/Cursor + "read_file": Tool.READ, # Copilot +} + + +def lookup(data: Dict[str, Any], keys: Sequence[str], default: Any = None) -> Any: + """Returns the value of the first key found in a dictionary.""" + for key in keys: + if key in data: + return data[key] + return default # Regex (and method) to look for any @file_path in the prompt. @@ -41,12 +63,129 @@ def find_filepaths(prompt: str) -> Set[str]: return paths +def parse_hook_input(raw_content: str) -> list[HookPayload]: + """Parse the input content. Raises a ValueError if the input is not valid. + + Returns: + A list of payloads. Most of the time the list will contain only one payload, + but in some cases files mentioned in the prompt will be read but the + PreToolUse event will not be called. So we need to handle this case ourselves. 
+ """ + # Parse the content as JSON + if not raw_content.strip(): + raise ValueError("Error: No input received on stdin") + try: + data = json.loads(raw_content) + except json.JSONDecodeError as e: + raise ValueError(f"Error: Failed to parse JSON from stdin: {e}") from e + + payloads = [] + + # Try to guess which AI coding assistant is calling us + agent = _detect_agent(data) + + # Infer the event type + event_name = lookup(data, ["hook_event_name", "hookEventName"], None) + if event_name is None: + raise ValueError("Error: couldn't find event type") + event_type = HOOK_NAME_TO_EVENT_TYPE.get(event_name.lower(), EventType.OTHER) + + identifier = "" + content = "" + tool = None + + # Extract the identifier and content based on the event type + if event_type == EventType.USER_PROMPT: + content = data.get("prompt", "") + # Look for files mentioned in the prompt that could be read + # without triggering a PRE_TOOL_USE event. + payloads.extend(_parse_user_prompt(content, event_type, agent)) + + elif event_type == EventType.PRE_TOOL_USE: + tool = _parse_tool(data) + tool_input = data.get("tool_input", {}) + # Select the content based on the tool + if tool == Tool.BASH: + content = tool_input.get("command", "") + identifier = content + elif tool == Tool.READ: + # We only need to deal with the identifier, the content will be read by the Scannable + identifier = lookup(tool_input, ["file_path", "filePath"], "") + + elif event_type == EventType.POST_TOOL_USE: + tool = _parse_tool(data) + content = data.get("tool_output", "") or data.get("tool_response", {}) + # Claude Code returns a dict for the tool output + if isinstance(content, (dict, list)): + content = json.dumps(content) + + # If identifier was not set, hash the content + if not identifier: + identifier = hashlib.sha256((content or "").encode()).hexdigest() + + payloads.append( + HookPayload( + event_type=event_type, + tool=tool, + content=content, + identifier=identifier, + agent=agent, + raw=data, + ) + ) + return 
payloads + + +def _parse_tool(data: Dict[str, Any]) -> Tool: + """Parse the tool name.""" + tool_name = data.get("tool_name", "").lower() + if tool_name.startswith("mcp"): + return Tool.MCP + return TOOL_NAME_TO_TOOL.get(tool_name, Tool.OTHER) + + +def _detect_agent(data: Dict[str, Any]) -> Agent: + """Detect the AI code assistant.""" + if "cursor_version" in data: + return Cursor() + elif "github.copilot-chat" in data.get("transcript_path", "").lower(): + return Copilot() + # no .lower() here to reduce the risk of false positives (this is also why this check is last) + elif "session_id" in data and "claude" in data.get("transcript_path", ""): + return Claude() + # No other agent is supported yet + raise ValueError("Unsupported agent") + + +def _parse_user_prompt( + content: str, event_type: EventType, agent: Agent +) -> List[HookPayload]: + """Parse the user prompt for additional payloads that we may miss.""" + payloads = [] + # Scenario 1 (the only one we know about so far): + # Code assistants don't always trigger a PRE_TOOL_USE event when + # a file is mentioned in the prompt, especially with an "@" prefix. + matches = find_filepaths(content) + for match in matches: + payloads.append( + HookPayload( + event_type=event_type, + tool=Tool.READ, + content="", + identifier=match, + agent=agent, + raw={}, + ) + ) + return payloads + + class AIHookScanner: """AI hook scanner. It is called with the payload of a hook event. Note that instead of having a base class with common method and a subclass per supported AI tool, - we instead have a single class which detects which protocol to use (called "flavor"). + we instead have a single class which detects which protocol to use. This is because some tools sloppily support hooks from others. For instance, Cursor will call hooks defined in the Claude Code format, but send payload in its own format. So we can't assume which tool will call us based on the command line/hook configuration only. 
@@ -55,98 +194,27 @@ class AIHookScanner: ValueError: If the input is not valid. """ - def __init__(self, scanner: SecretScanner): + def __init__(self, scanner: ScannerProtocol): self.scanner = scanner def scan(self, content: str) -> int: """Scan the content, print the result and return the exit code.""" - payloads = self._parse_input(content) + payloads = parse_hook_input(content) result = self._scan_payloads(payloads) payload = result.payload # Special case: in post-tool use, the action is already done: at least notify the user if result.block and payload.event_type == EventType.POST_TOOL_USE: self._send_secret_notification( - result.nbr_secrets, payload.tool or Tool.OTHER, payload.flavor.name + result.nbr_secrets, + payload.tool or Tool.OTHER, + payload.agent.display_name, ) - return payload.flavor.output_result(result) - - def _parse_input(self, raw_content: str) -> list[Payload]: - """Parse the input content. Raises a ValueError if the input is not valid. - - Returns: - A list of payloads. Most of the time the list will contain only one payload, - but in some cases files mentioned in the prompt will be read but the - PreToolUse event will not be called. So we need to handle this case ourselves. 
- """ - # Parse the content as JSON - if not raw_content.strip(): - raise ValueError("Error: No input received on stdin") - try: - data = json.loads(raw_content) - except json.JSONDecodeError as e: - raise ValueError(f"Error: Failed to parse JSON from stdin: {e}") from e - - payloads = [] - - # Try to guess which AI coding assistant is calling us - flavor = self._detect_flavor(data) - - # Infer the event type - event_name = lookup(data, ["hook_event_name", "hookEventName"], None) - if event_name is None: - raise ValueError("Error: couldn't find event type") - event_type = HOOK_NAME_TO_EVENT_TYPE.get(event_name.lower(), EventType.OTHER) - - identifier = "" - content = "" - tool = None - - # Extract the identifier and content based on the event type - if event_type == EventType.USER_PROMPT: - content = data.get("prompt", "") - # Look for files mentioned in the prompt that could be read - # without triggering a PRE_TOOL_USE event. - payloads.extend(self._parse_user_prompt(content, event_type, flavor)) - - elif event_type == EventType.PRE_TOOL_USE: - tool_name = data.get("tool_name", "").lower() - tool = TOOL_NAME_TO_TOOL.get(tool_name, Tool.OTHER) - tool_input = data.get("tool_input", {}) - # Select the content based on the tool - if tool == Tool.BASH: - content = tool_input.get("command", "") - identifier = content - elif tool == Tool.READ: - # We only need to deal with the identifier, the content will be read by the Scannable - identifier = lookup(tool_input, ["file_path", "filePath"], "") - - elif event_type == EventType.POST_TOOL_USE: - tool_name = data.get("tool_name", "").lower() - tool = TOOL_NAME_TO_TOOL.get(tool_name, Tool.OTHER) - content = data.get("tool_output", "") or data.get("tool_response", {}) - # Claude Code returns a dict for the tool output - if isinstance(content, (dict, list)): - content = json.dumps(content) - - # If identifier was not set, hash the content - if not identifier: - identifier = hashlib.sha256((content or "").encode()).hexdigest() 
- - payloads.append( - Payload( - event_type=event_type, - tool=tool, - content=content, - identifier=identifier, - flavor=flavor, - ) - ) - return payloads + return payload.agent.output_result(result) - def _scan_payloads(self, payloads: List[Payload]) -> Result: + def _scan_payloads(self, payloads: List[HookPayload]) -> HookResult: """Scan payloads for secrets using the SecretScanner. Returns: @@ -159,16 +227,16 @@ def _scan_payloads(self, payloads: List[Payload]) -> Result: result = self._scan_content(payload) if result.block: return result - return Result.allow(payloads[0]) + return HookResult.allow(payloads[0]) def _scan_content( self, - payload: Payload, - ) -> Result: + payload: HookPayload, + ) -> HookResult: """Scan content for secrets using the SecretScanner.""" # Short path: if there is no content, no need to do an API call if payload.empty: - return Result.allow(payload) + return HookResult.allow(payload) with create_message_only_scanner_ui() as scanner_ui: results = self.scanner.scan([payload.scannable], scanner_ui=scanner_ui) @@ -178,58 +246,25 @@ def _scan_content( secrets.extend(result.secrets) if not secrets: - return Result.allow(payload) + return HookResult.allow(payload) message = self._message_from_secrets( secrets, payload, escape_markdown=True, ) - return Result( + return HookResult( block=True, message=message, nbr_secrets=len(secrets), payload=payload, ) - @staticmethod - def _detect_flavor(data: Dict[str, Any]) -> Flavor: - """Detect the AI code assistant.""" - if "cursor_version" in data: - return Cursor() - elif "github.copilot-chat" in data.get("transcript_path", "").lower(): - return Copilot() - # no .lower() here to reduce the risk of false positives (this is also why this check is last) - elif "session_id" in data and "claude" in data.get("transcript_path", ""): - return Claude() - else: - # Fallback that respect base conventions - return Flavor() - - def _parse_user_prompt( - self, content: str, event_type: EventType, flavor: Flavor 
- ) -> List[Payload]: - """Parse the user prompt for additional payloads that we may miss.""" - payloads = [] - # Scenario 1 (the only one we know about so far): - # Code assistants don't always trigger a PRE_TOOL_USE event when - # a file is mentioned in the prompt, especially with an "@" prefix. - matches = find_filepaths(content) - for match in matches: - payloads.append( - Payload( - event_type=event_type, - tool=Tool.READ, - content="", - identifier=match, - flavor=flavor, - ) - ) - return payloads - @staticmethod def _message_from_secrets( - secrets: List[Secret], payload: Payload, escape_markdown: bool = False + secrets: List[Secret], + payload: HookPayload, + escape_markdown: bool = False, ) -> str: """ Format detected secrets into a user-friendly message. @@ -308,27 +343,3 @@ def _send_secret_notification( # This is best effort, we don't want to propagate an error # if the notification fails. pass - - -HOOK_NAME_TO_EVENT_TYPE = { - "userpromptsubmit": EventType.USER_PROMPT, - "beforesubmitprompt": EventType.USER_PROMPT, - "pretooluse": EventType.PRE_TOOL_USE, - "posttooluse": EventType.POST_TOOL_USE, -} - -TOOL_NAME_TO_TOOL = { - "shell": Tool.BASH, # Cursor - "bash": Tool.BASH, # Claude Code - "run_in_terminal": Tool.BASH, # Copilot - "read": Tool.READ, # Claude/Cursor - "read_file": Tool.READ, # Copilot -} - - -def lookup(data: Dict[str, Any], keys: Sequence[str], default: Any = None) -> Any: - """Returns the value of the first key found in a dictionary.""" - for key in keys: - if key in data: - return data[key] - return default diff --git a/ggshield/verticals/secret/ai_hook/installation.py b/ggshield/verticals/ai/installation.py similarity index 87% rename from ggshield/verticals/secret/ai_hook/installation.py rename to ggshield/verticals/ai/installation.py index 3139c7b3d5..ca328b374f 100644 --- a/ggshield/verticals/secret/ai_hook/installation.py +++ b/ggshield/verticals/ai/installation.py @@ -9,16 +9,7 @@ from ggshield.core.dirs import 
get_user_home_dir from ggshield.core.errors import UnexpectedError -from .claude_code import Claude -from .copilot import Copilot -from .cursor import Cursor - - -AI_FLAVORS = { - "cursor": Cursor, - "claude-code": Claude, - "copilot": Copilot, -} +from .agents import AGENTS @dataclass @@ -41,12 +32,12 @@ def install_hooks( """ try: - flavor = AI_FLAVORS[name]() + agent = AGENTS[name] except KeyError: - raise ValueError(f"Unsupported tool name: {name}") + raise ValueError(f"Unsupported agent: {name}") base_dir = get_user_home_dir() if mode == "global" else Path(".") - settings_path = base_dir / flavor.settings_path + settings_path = base_dir / agent.settings_path command = "ggshield secret scan ai-hook" @@ -71,11 +62,11 @@ def install_hooks( stats = _fill_dict( config=existing_config, - template=flavor.settings_template, + template=agent.settings_template, command=command, overwrite=force, stats=stats, - locator=flavor.settings_locate, + locator=agent.settings_locate, ) # Ensure parent directory exists @@ -89,11 +80,11 @@ def install_hooks( # Report what happened styled_path = click.style(settings_path, fg="yellow", bold=True) if stats.added == 0 and stats.already_present > 0: - click.echo(f"{flavor.name} hooks already installed in {styled_path}") + click.echo(f"{agent.display_name} hooks already installed in {styled_path}") elif stats.added > 0 and stats.already_present > 0: - click.echo(f"{flavor.name} hooks updated in {styled_path}") + click.echo(f"{agent.display_name} hooks updated in {styled_path}") else: - click.echo(f"{flavor.name} hooks successfully added in {styled_path}") + click.echo(f"{agent.display_name} hooks successfully added in {styled_path}") return 0 diff --git a/ggshield/verticals/ai/mcp.py b/ggshield/verticals/ai/mcp.py new file mode 100644 index 0000000000..835d8ee57b --- /dev/null +++ b/ggshield/verticals/ai/mcp.py @@ -0,0 +1,46 @@ +from pydantic import ValidationError +from pygitguardian import GGClient +from requests.exceptions import 
ConnectionError as RequestsConnectionError + +from ggshield.verticals.ai.discovery import refresh_and_maybe_submit_discovery + +from .models import HookPayload, MCPActivityResponse + +# GitGuardian API (v1): POST JSON body matching MCPActivityRequest; response matches MCPActivityResponse. +MCP_ACTIVITY_SUBMIT_ENDPOINT = "nhi/ai/mcp-activity" + + +def _mcp_activity_fail_open() -> MCPActivityResponse: + return MCPActivityResponse(allowed=True, reason="") + + +def send_mcp_activity(client: GGClient, payload: HookPayload) -> MCPActivityResponse: + """Build the MCP activity request and send it to the GitGuardian API. + + Args: + client: GitGuardian API client (same instance as secret scans). + payload: Hook payload for the MCP pre-tool event. + + Returns: + Policy response from the API, or allow if the request fails (fail-open). + """ + + ai_config = refresh_and_maybe_submit_discovery(client) + request = payload.agent.parse_mcp_activity(payload, ai_config) + body = request.model_dump(mode="json") + + try: + response = client.post( + endpoint=MCP_ACTIVITY_SUBMIT_ENDPOINT, + data=body, + ) + except RequestsConnectionError: + return _mcp_activity_fail_open() + + if not response.ok: + return _mcp_activity_fail_open() + + try: + return MCPActivityResponse.model_validate_json(response.text) + except ValidationError: + return _mcp_activity_fail_open() diff --git a/ggshield/verticals/ai/models.py b/ggshield/verticals/ai/models.py new file mode 100644 index 0000000000..68392e5443 --- /dev/null +++ b/ggshield/verticals/ai/models.py @@ -0,0 +1,433 @@ +import json +from abc import ABC, abstractmethod +from collections.abc import Iterable, Iterator +from dataclasses import dataclass +from enum import Enum, auto +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple + +from pydantic import BaseModel, Field + +from ggshield.core.scan import File, Scannable, StringScannable +from ggshield.utils.files import is_path_binary + + +class EventType(Enum): + """Event 
type constants for hook events.""" + + USER_PROMPT = auto() + PRE_TOOL_USE = auto() + POST_TOOL_USE = auto() + # We are not interested in other less generic events for now + # (most of the time, one of the three above will also be called anyway) + OTHER = auto() + + +class Tool(Enum): + """Tool constants for hook events.""" + + BASH = auto() + READ = auto() + MCP = auto() + # We are not interested in other tools for now + OTHER = auto() + + +@dataclass +class HookResult: + """Result of a scan: allow or not.""" + + block: bool + message: str + nbr_secrets: int + payload: "HookPayload" + + @classmethod + def allow(cls, payload: "HookPayload") -> "HookResult": + return cls(block=False, message="", nbr_secrets=0, payload=payload) + + +@dataclass +class HookPayload: + event_type: EventType + tool: Optional[Tool] + content: str + identifier: str + agent: "Agent" + raw: Dict[str, Any] + + @property + def scannable(self) -> Scannable: + """Return the appropriate Scannable for the payload.""" + if self.tool == Tool.READ: + path = Path(self.identifier) + if path.is_file() and not is_path_binary(path): + return File(path=self.identifier) + return StringScannable(url=self.identifier, content=self.content) + + @property + def empty(self) -> bool: + """Return True if the payload is empty.""" + return not self.scannable.is_longer_than(0) + + +class Transport(Enum): + STDIO = "stdio" + HTTP = "http" + SSE = "sse" + + +class Scope(Enum): + USER = "user" + PROJECT = "project" + + +class MCPArgumentInfo(BaseModel): + name: str + type: str + description: Optional[str] = None + required: bool = False + + +class MCPToolInfo(BaseModel): + name: str + description: Optional[str] = None + arguments: Optional[List[MCPArgumentInfo]] = None + + +class MCPResourceInfo(BaseModel): + uri: str + name: Optional[str] = None + description: Optional[str] = None + mime_type: Optional[str] = None + + +class MCPPromptInfo(BaseModel): + name: str + description: Optional[str] = None + + +ConfigurationKey = 
Tuple[str, Scope, Optional[Path], str] + + +class MCPConfiguration(BaseModel): + name: str + agent: str + scope: Scope + transport: Transport + project: Optional[Path] = None + # stdio fields + command: Optional[str] = None + args: List[str] = Field(default_factory=list) + env: Dict[str, str] = Field(default_factory=dict) + # remote fields + url: Optional[str] = None + headers: Dict[str, str] = Field(default_factory=dict) + + @property + def key(self) -> ConfigurationKey: + """Return a unique key for the configuration. + + We consider that their should be unicity over (agent, scope, project, name) + """ + return (self.agent, self.scope, self.project, self.name) + + +class MCPServer(BaseModel): + name: str + display_name: Optional[str] = None + tools: List[MCPToolInfo] = Field(default_factory=list) + resources: List[MCPResourceInfo] = Field(default_factory=list) + prompts: List[MCPPromptInfo] = Field(default_factory=list) + configurations: List[MCPConfiguration] = Field(default_factory=list) + + def has_capabilities_unknown_to(self, other: "MCPServer") -> bool: + """Check if the server has capabilities unknown to the other server.""" + # Note: we assume that if we have discovered a capability, + # then we have everything (name, description, etc.) + # So we simply check if we have names the other doesn't. 
+ other_tools = {tool.name for tool in other.tools} + for tool in self.tools: + if tool.name not in other_tools: + return True + + other_resources = {resource.uri for resource in other.resources} + for resource in self.resources: + if resource.uri not in other_resources: + return True + + other_prompts = {prompt.name for prompt in other.prompts} + for prompt in self.prompts: + if prompt.name not in other_prompts: + return True + + return False + + +class UserInfo(BaseModel): + hostname: str + username: str + machine_id: str + user_email: Optional[str] = None + + +class AIDiscovery(BaseModel): + user: UserInfo + servers: List[MCPServer] = Field(default_factory=list) + # Metadata for analytics + discovery_duration: float # in s + + @property + def configurations(self) -> Iterator[MCPConfiguration]: + """Iterate over all MCP configurations.""" + for server in self.servers: + yield from server.configurations + + def has_changed_from(self, other: "AIDiscovery") -> bool: + """Check if the discovery has changed since a previous discovery.""" + # We compare : + # 1. user info exactly + if self.user != other.user: + return True + + # 2. MCP configurations should be the same (both in number and content) + other_configurations = {conf.key: conf for conf in other.configurations} + new_configurations = {conf.key: conf for conf in self.configurations} + if other_configurations != new_configurations: + return True + + # 3. Servers may have been overriden by GIM, but we still want to detect + # whether we discovered new capabilities unknown to GIM. 
+ # First, build a map to find the server(s) to compare to + # (we know that the keys will be exactly our configurations, thanks to step 2) + other_servers: Dict[ConfigurationKey, MCPServer] = {} + for server in other.servers: + for configuration in server.configurations: + other_servers[configuration.key] = server + # Then, for each server we found, check if we have capabilities unknown to GIM + for server in self.servers: + # No data, no need to compare + if not server.tools and not server.resources and not server.prompts: + continue + # We don't know how our naive deduplication have been overriden by GIM, + # so we need to compare capabilities to each distinct destination. + candidates_by_name: Dict[str, MCPServer] = {} + for conf in server.configurations: + other_server = other_servers[conf.key] + candidates_by_name[other_server.name] = other_server + # If at least one has our capabilities, then no need to update. + # said otherwise, update if all candidates don't have our capabilities. + if all( + server.has_capabilities_unknown_to(candidate) + for candidate in candidates_by_name.values() + ): + return True + + return False + + +class MCPActivityResponse(BaseModel): + allowed: bool + reason: str + + +class MCPActivityRequest(BaseModel): + user: UserInfo + tool: str + server: str + agent: str + model: str + cwd: Path + input: Dict[str, Any] + + +class Agent(ABC): + """ + Class that can be derived to implement behavior specific to some AI code assistants. + """ + + # Properties + + @property + @abstractmethod + def display_name(self) -> str: + """A user-friendly name for the agent.""" + + @property + @abstractmethod + def name(self) -> str: + """The name of the agent.""" + + @property + @abstractmethod + def config_folder(self) -> Path: + """The folder where the assistant's config files are stored.""" + + # Hooks + + @abstractmethod + def output_result(self, result: HookResult) -> int: + """How to output the result of a scan. 
+ + This method is expected to have side effects, like printing to stdout or stderr. + + Args: + result: the result of the scan. + + Returns: the exit code. + """ + + # Settings + + @property + @abstractmethod + def settings_path(self) -> Path: + """Path to the settings file for this AI coding tool.""" + + @property + @abstractmethod + def settings_template(self) -> Dict[str, Any]: + """ + Template for the settings file for this AI coding tool. + Use the sentinel "" for the places where the command should be inserted. + """ + + @abstractmethod + def settings_locate( + self, candidates: List[Dict[str, Any]], template: Dict[str, Any] + ) -> Optional[Dict[str, Any]]: + """Callback used to help locate the correct object to update in the settings. + + We don't want to overwrite other hooks nor create duplicates, so when the existing + hook configuration is traversed and we end up in a list, this callback is used to + locate the correct object to update. + + Args: + candidates: the list of objects at the level currently traversed. + template: the template of the expected object. + + Returns: the object to update, or None if no object was found. + """ + return None + + # Discovery + + @abstractmethod + def project_mcp_file(self, directory: Path) -> Path: + """The file where MCP servers are configured at the project level.""" + + @abstractmethod + def discover_project_directories(self) -> Iterator[Path]: + """Discover project directories by scraping config or history files.""" + + def _parse_servers_block( + self, + data: Dict[str, Dict[str, Any]], + scope: Scope, + project: Optional[Path], + ) -> Iterator[MCPConfiguration]: + """Utility function to parse a "mcpServer" block and return the MCP server entries. + + The format is standard across all assistants. 
+ """ + # Lookup the two usual conventions + servers = data.get("mcpServers", data.get("servers", {})) + for name, entry in servers.items(): + if "url" in entry: + if entry.get("transport") == "sse": + transport = Transport.SSE + else: + transport = Transport.HTTP + else: + transport = Transport.STDIO + + yield MCPConfiguration( + name=name, + agent=self.name, + scope=scope, + transport=transport, + project=project, + command=entry.get("command"), + args=entry.get("args", []), + env=entry.get("env", {}), + url=entry.get("url"), + headers=entry.get("headers", {}), + ) + + def _get_user_mcp_configurations(self) -> Iterator[MCPConfiguration]: + """Return the MCP server entries for user-level (global) config files. + + Default implementation looks in the config folder for a file named "mcp.json". + """ + # Load config file + filepath = self.config_folder / "mcp.json" + if not (data := self._load_json_file(filepath)): + return + yield from self._parse_servers_block(data, Scope.USER, None) + + def _get_project_mcp_configurations( + self, directory: Path + ) -> Iterator[MCPConfiguration]: + """Return the MCP server entries for project-level config files.""" + if data := self._load_json_file(self.project_mcp_file(directory)): + yield from self._parse_servers_block(data, Scope.PROJECT, directory) + + def discover_mcp_configurations( + self, directories: Iterable[Path] + ) -> List[MCPConfiguration]: + """Discover MCP configurations from user and project config files. + + Iterates over user-level paths, then project-level paths for each + directory in *directories*. + """ + results: List[MCPConfiguration] = [] + + # User-level configs + results.extend(self._get_user_mcp_configurations()) + + # Project-level configs + for directory in directories: + results.extend(self._get_project_mcp_configurations(directory)) + + return results + + def discover_capabilities(self, server: MCPServer) -> bool: + """Discover capabilities for the given server. 
+ + Returns whether the capabilities were discovered. + """ + return False + + @abstractmethod + def parse_mcp_activity( + self, payload: HookPayload, ai_config: AIDiscovery + ) -> MCPActivityRequest: + """Parse the MCP activity from an MCP hook payload.""" + + # Helper methods + + def _load_json_file(self, path: Path) -> Optional[Dict[str, Any]]: + """Load a JSON file and return the data, or None if the file doesn't exist (or is not a JSON object).""" + if not path.is_file(): + return None + try: + data = json.loads(path.read_text()) + except (OSError, json.JSONDecodeError): + return None + if not isinstance(data, dict): + return None + return data + + def _load_jsonl_file(self, path: Path) -> Iterator[Dict[str, Any]]: + """Load a JSONL file and return the data line by line, + or nothing if the file doesn't exist (or is not a JSON object).""" + if not path.is_file(): + yield from [] + try: + for line in open(path, "r"): + try: + yield json.loads(line) + except json.JSONDecodeError: + continue + except OSError: + yield from [] diff --git a/ggshield/verticals/ai/user.py b/ggshield/verticals/ai/user.py new file mode 100644 index 0000000000..c7db72580c --- /dev/null +++ b/ggshield/verticals/ai/user.py @@ -0,0 +1,202 @@ +""" +This module provides information about the current user and machine. + +It should be eventually merged with the logic used by local scanning. 
+""" + +import getpass +import logging +import os +import platform +import re +import socket +import subprocess +import sys +import uuid +from pathlib import Path +from typing import Optional + +from ggshield.core.dirs import get_user_home_dir +from ggshield.verticals.ai.models import UserInfo + + +logger = logging.getLogger(__name__) + +_MAC_IOREG_UUID_RE = re.compile(r'"IOPlatformUUID"\s*=\s*"([^"]+)"') + + +def get_user_info(machine_id: Optional[str] = None) -> UserInfo: + """Collect hostname, username, machine identifier, and best-effort email.""" + return UserInfo( + hostname=_get_hostname(), + username=_get_username(), + machine_id=machine_id or _get_machine_id(), + user_email=_get_user_email(), + ) + + +def _get_hostname() -> str: + if sys.platform == "win32": + name = (os.environ.get("COMPUTERNAME") or "").strip() + if name: + return name + try: + return socket.gethostname() or "unknown" + except OSError: + return "unknown" + + +def _get_username() -> str: + try: + return getpass.getuser() + except Exception: + pass + try: + return os.getlogin() + except Exception: + return "unknown" + + +def _get_user_email() -> Optional[str]: + """Best-effort user email; tries `git config user.email` first.""" + try: + result = subprocess.run( + ["git", "config", "user.email"], + capture_output=True, + text=True, + check=False, + ) + if result.returncode == 0: + email = (result.stdout or "").strip() + return email or None + except (OSError, subprocess.SubprocessError): + pass + return None + + +def _read_first_nonempty_line(path: Path) -> Optional[str]: + try: + text = path.read_text(encoding="utf-8", errors="replace").splitlines() + except OSError: + return None + for line in text: + stripped = line.strip() + if stripped: + return stripped + return None + + +def _get_linux_system_id() -> Optional[str]: + for candidate in ( + Path("/etc/machine-id"), + Path("/sys/class/dmi/id/product_uuid"), + Path("/var/lib/dbus/machine-id"), + ): + value = 
_read_first_nonempty_line(candidate) + if value: + return value + return None + + +def _get_macos_system_id() -> Optional[str]: + try: + result = subprocess.run( + ["ioreg", "-rd1", "-c", "IOPlatformExpertDevice"], + capture_output=True, + text=True, + check=False, + ) + if result.returncode != 0 or not result.stdout: + return None + match = _MAC_IOREG_UUID_RE.search(result.stdout) + if match: + return match.group(1).strip() + except (OSError, subprocess.SubprocessError): + pass + return None + + +def _parse_wmic_uuid(stdout: str) -> Optional[str]: + for line in stdout.splitlines(): + line = line.strip() + if not line or line.upper() == "UUID": + continue + try: + return str(uuid.UUID(line)) + except ValueError: + pass + return None + + +def _get_windows_system_id() -> Optional[str]: + try: + result = subprocess.run( + ["wmic", "csproduct", "get", "uuid"], + capture_output=True, + text=True, + check=False, + ) + if result.returncode == 0 and result.stdout: + parsed = _parse_wmic_uuid(result.stdout) + if parsed: + return parsed + except (OSError, subprocess.SubprocessError): + pass + + try: + result = subprocess.run( + [ + "powershell", + "-NoProfile", + "-Command", + "(Get-CimInstance Win32_ComputerSystemProduct).UUID", + ], + capture_output=True, + text=True, + check=False, + ) + if result.returncode == 0 and result.stdout: + try: + return str(uuid.UUID(result.stdout.strip())) + except ValueError: + pass + except (OSError, subprocess.SubprocessError): + pass + return None + + +def _get_machine_id() -> str: + + # In case Satori generated a machine id, use it. 
+ path = get_user_home_dir() / ".satori" / "machine_id" + try: + if path.is_file(): + cached = _read_first_nonempty_line(path) + if cached: + return cached + except OSError: + pass + + system = platform.system().lower() + system_id = None + + if system == "darwin": + system_id = _get_macos_system_id() + elif system == "linux": + system_id = _get_linux_system_id() + elif sys.platform == "win32": + system_id = _get_windows_system_id() + + if system_id: + return system_id + + # If everything failed, use a random UUID. + # Store it so that satori can use it. + new_id = str(uuid.uuid4()) + try: + path.parent.mkdir(parents=True, exist_ok=True) + path.write_text(new_id + "\n") + except OSError: + pass + + return new_id diff --git a/ggshield/verticals/secret/ai_hook/__init__.py b/ggshield/verticals/secret/ai_hook/__init__.py deleted file mode 100644 index 68d4170d4a..0000000000 --- a/ggshield/verticals/secret/ai_hook/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -from .installation import AI_FLAVORS, install_hooks -from .scanner import AIHookScanner - - -__all__ = ["AIHookScanner", "install_hooks", "AI_FLAVORS"] diff --git a/ggshield/verticals/secret/ai_hook/claude_code.py b/ggshield/verticals/secret/ai_hook/claude_code.py deleted file mode 100644 index 378490f904..0000000000 --- a/ggshield/verticals/secret/ai_hook/claude_code.py +++ /dev/null @@ -1,102 +0,0 @@ -import json -from pathlib import Path -from typing import Any, Dict, List, Optional - -import click - -from .models import EventType, Flavor, Result - - -class Claude(Flavor): - """Behavior specific to Claude Code.""" - - name = "Claude Code" - - def output_result(self, result: Result) -> int: - response = {} - if result.block: - if result.payload.event_type in [ - EventType.USER_PROMPT, - EventType.POST_TOOL_USE, - ]: - response["decision"] = "block" - response["reason"] = result.message - response["additionalContext"] = result.message - elif result.payload.event_type == EventType.PRE_TOOL_USE: - 
response["hookSpecificOutput"] = { - "hookEventName": "PreToolUse", - "permissionDecision": "deny", - "permissionDecisionReason": result.message, - } - else: - # Should not happen, but just in case use Claude's "universal" fields. - response = { - "continue": False, - "stopReason": result.message, - } - else: - response["continue"] = True - - click.echo(json.dumps(response)) - # We don't use the return 2 convention to make sure our JSON output is read. - return 0 - - @property - def settings_path(self) -> Path: - return Path(".claude") / "settings.json" - - @property - def settings_template(self) -> Dict[str, Any]: - return { - "hooks": { - "PreToolUse": [ - { - "matcher": ".*", - "hooks": [ - { - "type": "command", - "command": "", - } - ], - } - ], - "PostToolUse": [ - { - "matcher": ".*", - "hooks": [ - { - "type": "command", - "command": "", - } - ], - } - ], - "UserPromptSubmit": [ - { - "matcher": ".*", - "hooks": [ - { - "type": "command", - "command": "", - } - ], - } - ], - } - } - - def settings_locate( - self, candidates: List[Dict[str, Any]], template: Dict[str, Any] - ) -> Optional[Dict[str, Any]]: - # We have two kind of lists: at the root of each hook (with a matcher) - # and in each hook (with a list of commands). - if "matcher" in template: - for obj in candidates: - if obj.get("matcher") == template["matcher"]: - return obj - return None - for obj in candidates: - command = obj.get("command", "") - if "ggshield" in command or "" in command: - return obj - return None diff --git a/ggshield/verticals/secret/ai_hook/copilot.py b/ggshield/verticals/secret/ai_hook/copilot.py deleted file mode 100644 index b523daa70b..0000000000 --- a/ggshield/verticals/secret/ai_hook/copilot.py +++ /dev/null @@ -1,41 +0,0 @@ -import json -from pathlib import Path - -import click - -from .claude_code import Claude -from .models import EventType, Result - - -class Copilot(Claude): - """Behavior specific to Copilot Chat. - - Inherits most of its behavior from Claude Code. 
- """ - - name = "Copilot" - - def output_result(self, result: Result) -> int: - response = {} - if result.block: - if result.payload.event_type == EventType.PRE_TOOL_USE: - response["hookSpecificOutput"] = { - "hookEventName": "PreToolUse", - "permissionDecision": "deny", - "permissionDecisionReason": result.message, - } - elif result.payload.event_type == EventType.POST_TOOL_USE: - response["decision"] = "block" - response["reason"] = result.message - else: - response["continue"] = False - response["stopReason"] = result.message - else: - response["continue"] = True - - click.echo(json.dumps(response)) - return 0 - - @property - def settings_path(self) -> Path: - return Path(".github") / "hooks" / "hooks.json" diff --git a/ggshield/verticals/secret/ai_hook/cursor.py b/ggshield/verticals/secret/ai_hook/cursor.py deleted file mode 100644 index 3b20c46fec..0000000000 --- a/ggshield/verticals/secret/ai_hook/cursor.py +++ /dev/null @@ -1,58 +0,0 @@ -import json -from pathlib import Path -from typing import Any, Dict, List, Optional - -import click - -from .models import EventType, Flavor, Result - - -class Cursor(Flavor): - """Behavior specific to Cursor.""" - - name = "Cursor" - - def output_result(self, result: Result) -> int: - response = {} - if result.payload.event_type == EventType.USER_PROMPT: - response["continue"] = not result.block - response["user_message"] = result.message - elif result.payload.event_type == EventType.PRE_TOOL_USE: - response["permission"] = "deny" if result.block else "allow" - response["user_message"] = result.message - response["agent_message"] = result.message - elif result.payload.event_type == EventType.POST_TOOL_USE: - pass # Nothing to do here - else: - # Should not happen, but just in case - click.echo("{}") - return 2 if result.block else 0 - - click.echo(json.dumps(response)) - # We don't use the return 2 convention to make sure our JSON output is read. 
- return 0 - - @property - def settings_path(self) -> Path: - return Path(".cursor") / "hooks.json" - - @property - def settings_template(self) -> Dict[str, Any]: - return { - "version": 1, - "hooks": { - "beforeSubmitPrompt": [{"command": ""}], - "preToolUse": [{"command": ""}], - "postToolUse": [{"command": ""}], - }, - } - - def settings_locate( - self, candidates: List[Dict[str, Any]], template: Dict[str, Any] - ) -> Optional[Dict[str, Any]]: - # We only have one kind of lists: in each hook. Simply look for "ggshield" or "" in the command. - for obj in candidates: - command = obj.get("command", "") - if "ggshield" in command or "" in command: - return obj - return None diff --git a/ggshield/verticals/secret/ai_hook/models.py b/ggshield/verticals/secret/ai_hook/models.py deleted file mode 100644 index a031e2eea3..0000000000 --- a/ggshield/verticals/secret/ai_hook/models.py +++ /dev/null @@ -1,127 +0,0 @@ -from dataclasses import dataclass -from enum import Enum, auto -from pathlib import Path -from typing import Any, Dict, List, Optional - -import click - -from ggshield.core.scan import File, Scannable, StringScannable -from ggshield.utils.files import is_path_binary - - -MAX_READ_SIZE = 1024 * 1024 * 50 # We restrict payloads read to 50MB - - -class EventType(Enum): - """Event type constants for hook events.""" - - USER_PROMPT = auto() - PRE_TOOL_USE = auto() - POST_TOOL_USE = auto() - # We are not interested in other less generic events for now - # (most of the time, one of the three above will also be called anyway) - OTHER = auto() - - -class Tool(Enum): - """Tool constants for hook events.""" - - BASH = auto() - READ = auto() - # We are not interested in other tools for now - OTHER = auto() - - -@dataclass -class Result: - """Result of a scan: allow or not.""" - - block: bool - message: str - nbr_secrets: int - payload: "Payload" - - @classmethod - def allow(cls, payload: "Payload") -> "Result": - return cls(block=False, message="", nbr_secrets=0, 
payload=payload) - - -class Flavor: - """ - Class that can be derived to implement behavior specific to some AI code assistants. - """ - - name = "Your AI coding tool" - - def output_result(self, result: Result) -> int: - """How to output the result of a scan. - - This base implementation has sensible defaults (like returning 2 in case of a block, - and printing the output in stderr or stdout). - - This method is expected to have side effects, like printing to stdout or stderr. - - Args: - result: the result of the scan. - - Returns: the exit code. - """ - if result.block: - click.echo(result.message, err=True) - return 2 - else: - click.echo("No secrets found. Good to go.") - return 0 - - @property - def settings_path(self) -> Path: - """Path to the settings file for this AI coding tool.""" - return Path(".agents") / "hooks.json" - - @property - def settings_template(self) -> Dict[str, Any]: - """ - Template for the settings file for this AI coding tool. - Use the sentinel "" for the places where the command should be inserted. - """ - return {} - - def settings_locate( - self, candidates: List[Dict[str, Any]], template: Dict[str, Any] - ) -> Optional[Dict[str, Any]]: - """Callback used to help locate the correct object to update in the settings. - - We don't want to overwrite other hooks nor create duplicates, so when the existing - hook configuration is traversed and we end up in a list, this callback is used to - locate the correct object to update. - - Args: - candidates: the list of objects at the level currently traversed. - template: the template of the expected object. - - Returns: the object to update, or None if no object was found. 
- """ - return None - - -@dataclass -class Payload: - event_type: EventType - tool: Optional[Tool] - content: str - identifier: str - flavor: Flavor - - @property - def scannable(self) -> Scannable: - """Return the appropriate Scannable for the payload.""" - if self.tool == Tool.READ: - path = Path(self.identifier) - if path.is_file() and not is_path_binary(path): - return File(path=self.identifier) - return StringScannable(url=self.identifier, content=self.content) - - @property - def empty(self) -> bool: - """Return True if the payload is empty.""" - return not self.scannable.is_longer_than(0) diff --git a/pyproject.toml b/pyproject.toml index a3540af8fb..e345eab353 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -49,6 +49,7 @@ dependencies = [ "urllib3>=2.2.2,<3", "truststore>=0.10.1; python_version >= \"3.10\"", "notify-py>=0.3.43", + "pydantic>=2.11.10", ] [project.urls] @@ -61,6 +62,9 @@ ggshield = "ggshield.__main__:main" requires = ["hatchling"] build-backend = "hatchling.build" +[[tool.uv.index]] +url = "https://gitlab.gitguardian.ovh/api/v4/projects/435/packages/pypi/simple" + [tool.hatch.version] path = "ggshield/__init__.py" diff --git a/scripts/generate-import-linter-config.py b/scripts/generate-import-linter-config.py index bdb0823275..3bf75f2a76 100755 --- a/scripts/generate-import-linter-config.py +++ b/scripts/generate-import-linter-config.py @@ -64,8 +64,10 @@ class Contract(TypedDict): "ggshield.cmd.{}.** -> ggshield.verticals.{}.**", # FIXME: #521 - enforce boundaries between cmd.auth and verticals.hmsl "ggshield.cmd.auth.** -> ggshield.verticals.hmsl.**", - # Logic to install hooks for AI assistants - "ggshield.cmd.install -> ggshield.verticals.secret.ai_hook", + # Install command import logic to install AI hooks + "ggshield.cmd.install -> ggshield.verticals.ai.installation", + # AI hook command import logic to scan AI hook payloads + "ggshield.cmd.secret.scan.ai_hook -> ggshield.verticals.ai.hooks", ], "unmatched_ignore_imports_alerting": 
"none", }, diff --git a/tests/unit/verticals/ai/test_agents.py b/tests/unit/verticals/ai/test_agents.py new file mode 100644 index 0000000000..3361929529 --- /dev/null +++ b/tests/unit/verticals/ai/test_agents.py @@ -0,0 +1,486 @@ +import json +from pathlib import Path +from typing import Any, Dict, List, Optional +from unittest.mock import patch + +import pytest + +from ggshield.verticals.ai.agents import Agent +from ggshield.verticals.ai.agents.claude_code import Claude, _mangle_server_name +from ggshield.verticals.ai.agents.copilot import Copilot +from ggshield.verticals.ai.agents.cursor import Cursor, _parse_tool_arguments +from ggshield.verticals.ai.models import ( + AIDiscovery, + EventType, + HookPayload, + MCPConfiguration, + MCPServer, + MCPToolInfo, + Scope, + Tool, + Transport, + UserInfo, +) + + +def _user() -> UserInfo: + return UserInfo( + hostname="host", username="user", machine_id="mid", user_email="u@e.com" + ) + + +def _cfg( + name: str = "srv", + agent: str = "cursor", + scope: Scope = Scope.USER, + project: Optional[Path] = None, +) -> MCPConfiguration: + return MCPConfiguration( + name=name, agent=agent, scope=scope, transport=Transport.STDIO, project=project + ) + + +def _ai_discovery(servers: Optional[List[MCPServer]] = None) -> AIDiscovery: + return AIDiscovery(user=_user(), servers=servers or [], discovery_duration=0.1) + + +def _payload( + agent: Agent, raw: Optional[Dict[str, Any]] = None, tool: Tool = Tool.MCP +) -> HookPayload: + return HookPayload( + event_type=EventType.PRE_TOOL_USE, + tool=tool, + content="", + identifier="", + agent=agent, + raw=raw or {}, + ) + + +# =========================================================================== +# Cursor +# =========================================================================== + + +class TestCursorDiscoverCapabilities: + def _setup_mcps_folder( + self, tmp_path: Path, project_path: Path, server_name: str + ) -> Path: + """Build a Cursor-style mcps// folder layout and return the 
agent.""" + mangled = project_path.as_posix().replace("/", "-").lstrip("-") + mcps_root = tmp_path / ".cursor" / "projects" / mangled / "mcps" + server_dir = mcps_root / "my-server" + server_dir.mkdir(parents=True) + # SERVER_METADATA.json + (server_dir / "SERVER_METADATA.json").write_text( + json.dumps({"serverName": server_name}) + ) + return server_dir + + def test_populates_tools_resources_prompts(self, tmp_path: Path): + project = Path("/home/user/project") + server_dir = self._setup_mcps_folder(tmp_path, project, "my-mcp") + + (server_dir / "tools").mkdir() + (server_dir / "tools" / "t1.json").write_text( + json.dumps({"name": "do_thing", "description": "Does a thing"}) + ) + (server_dir / "resources").mkdir() + (server_dir / "resources" / "r1.json").write_text( + json.dumps({"uri": "file:///data", "name": "data"}) + ) + (server_dir / "prompts").mkdir() + (server_dir / "prompts" / "p1.json").write_text( + json.dumps({"name": "greeting", "description": "Says hi"}) + ) + + cursor = Cursor() + cfg = _cfg(name="my-mcp", agent="cursor", project=project) + server = MCPServer(name="my-mcp", configurations=[cfg]) + + with patch.object( + type(cursor), + "config_folder", + new_callable=lambda: property(lambda self: tmp_path / ".cursor"), + ): + result = cursor.discover_capabilities(server) + + assert result is True + assert len(server.tools) == 1 + assert server.tools[0].name == "do_thing" + assert len(server.resources) == 1 + assert server.resources[0].uri == "file:///data" + assert len(server.prompts) == 1 + assert server.prompts[0].name == "greeting" + + def test_status_md_present_returns_false(self, tmp_path: Path): + project = Path("/home/user/project") + server_dir = self._setup_mcps_folder(tmp_path, project, "my-mcp") + (server_dir / "STATUS.md").write_text("disconnected") + (server_dir / "tools").mkdir() + (server_dir / "tools" / "t1.json").write_text(json.dumps({"name": "t"})) + + cursor = Cursor() + cfg = _cfg(name="my-mcp", agent="cursor", project=project) 
+ server = MCPServer(name="my-mcp", configurations=[cfg]) + + with patch.object( + type(cursor), + "config_folder", + new_callable=lambda: property(lambda self: tmp_path / ".cursor"), + ): + result = cursor.discover_capabilities(server) + + assert result is False + assert len(server.tools) == 0 + + def test_no_matching_metadata_returns_false(self, tmp_path: Path): + project = Path("/home/user/project") + self._setup_mcps_folder(tmp_path, project, "other-server") + + cursor = Cursor() + cfg = _cfg(name="my-mcp", agent="cursor", project=project) + server = MCPServer(name="my-mcp", configurations=[cfg]) + + with patch.object( + type(cursor), + "config_folder", + new_callable=lambda: property(lambda self: tmp_path / ".cursor"), + ): + result = cursor.discover_capabilities(server) + + assert result is False + + def test_non_cursor_configuration_skipped(self): + cursor = Cursor() + cfg = _cfg(name="srv", agent="claude-code", project=Path("/proj")) + server = MCPServer(name="srv", configurations=[cfg]) + assert cursor.discover_capabilities(server) is False + + +class TestCursorDiscoverProjectDirectories: + def test_valid_workspace_json_yields_path(self, tmp_path: Path): + project_dir = tmp_path / "myproject" + project_dir.mkdir() + ws_storage = ( + tmp_path / ".config" / "Cursor" / "User" / "workspaceStorage" / "abc" + ) + ws_storage.mkdir(parents=True) + (ws_storage / "workspace.json").write_text( + json.dumps({"folder": f"file://{project_dir}"}) + ) + + cursor = Cursor() + with patch.object( + type(cursor), + "config_folder", + new_callable=lambda: property( + lambda self: tmp_path / ".config" / "Cursor" / "User" + ), + ): + with patch( + "ggshield.verticals.ai.agents.cursor.get_user_home_dir", + return_value=tmp_path, + ): + dirs = list(cursor.discover_project_directories()) + + assert project_dir.resolve() in dirs + + def test_missing_folder_key_skipped(self, tmp_path: Path): + ws_storage = ( + tmp_path / ".config" / "Cursor" / "User" / "workspaceStorage" / "abc" + ) 
+ ws_storage.mkdir(parents=True) + (ws_storage / "workspace.json").write_text(json.dumps({"other": "val"})) + + cursor = Cursor() + with patch.object( + type(cursor), + "config_folder", + new_callable=lambda: property( + lambda self: tmp_path / ".config" / "Cursor" / "User" + ), + ): + with patch( + "ggshield.verticals.ai.agents.cursor.get_user_home_dir", + return_value=tmp_path, + ): + dirs = list(cursor.discover_project_directories()) + + assert dirs == [] + + +class TestCursorParseMcpActivity: + def test_strips_mcp_prefix_and_maps_server(self): + cursor = Cursor() + tool_info = MCPToolInfo(name="run_query") + server = MCPServer( + name="my-db-server", + tools=[tool_info], + configurations=[_cfg(name="db", agent="cursor")], + ) + discovery = _ai_discovery(servers=[server]) + payload = _payload( + cursor, + raw={ + "tool_name": "MCP:run_query", + "model": "gpt-4", + "workspace_roots": ["/home/user/proj"], + "tool_input": {"sql": "SELECT 1"}, + }, + ) + + req = cursor.parse_mcp_activity(payload, discovery) + + assert req.tool == "run_query" + assert req.server == "my-db-server" + assert req.model == "gpt-4" + assert req.input == {"sql": "SELECT 1"} + + def test_unknown_tool_returns_empty_server(self): + cursor = Cursor() + discovery = _ai_discovery(servers=[]) + payload = _payload(cursor, raw={"tool_name": "MCP:unknown"}) + + req = cursor.parse_mcp_activity(payload, discovery) + + assert req.tool == "unknown" + assert req.server == "" + + +# =========================================================================== +# Claude Code +# =========================================================================== + + +class TestClaudeGetUserMcpConfigurations: + def test_user_level_and_project_level_parsed(self, tmp_path: Path): + project_dir = tmp_path / "myproject" + project_dir.mkdir() + claude_json = { + "mcpServers": {"global-srv": {"command": "npx", "args": ["-y", "mcp"]}}, + "projects": { + str(project_dir): { + "mcpServers": { + "project-srv": {"command": "node", 
"args": ["index.js"]} + } + } + }, + } + with patch( + "ggshield.verticals.ai.agents.claude_code.get_user_home_dir", + return_value=tmp_path, + ): + (tmp_path / ".claude.json").write_text(json.dumps(claude_json)) + claude = Claude() + configs = list(claude._get_user_mcp_configurations()) + + names = {c.name for c in configs} + assert "global-srv" in names + assert "project-srv" in names + + def test_missing_file_yields_nothing(self, tmp_path: Path): + with patch( + "ggshield.verticals.ai.agents.claude_code.get_user_home_dir", + return_value=tmp_path, + ): + claude = Claude() + configs = list(claude._get_user_mcp_configurations()) + assert configs == [] + + +class TestClaudeDiscoverProjectDirectories: + def test_yields_existing_directories(self, tmp_path: Path): + project = tmp_path / "proj" + project.mkdir() + history = tmp_path / ".claude" / "history.jsonl" + history.parent.mkdir(parents=True) + history.write_text(json.dumps({"project": str(project)}) + "\n") + + claude = Claude() + with patch.object( + type(claude), + "config_folder", + new_callable=lambda: property(lambda self: tmp_path / ".claude"), + ): + dirs = list(claude.discover_project_directories()) + + assert project.resolve() in dirs + + def test_skips_nonexistent_directories(self, tmp_path: Path): + history = tmp_path / ".claude" / "history.jsonl" + history.parent.mkdir(parents=True) + history.write_text( + json.dumps({"project": str(tmp_path / "nonexistent")}) + "\n" + ) + + claude = Claude() + with patch.object( + type(claude), + "config_folder", + new_callable=lambda: property(lambda self: tmp_path / ".claude"), + ): + dirs = list(claude.discover_project_directories()) + + assert dirs == [] + + +class TestClaudeParseMcpActivity: + def test_parses_mcp_double_underscore_format(self): + claude = Claude() + cfg = _cfg(name="my.server", agent="claude-code") + server = MCPServer( + name="my.server", configurations=[cfg], tools=[MCPToolInfo(name="run")] + ) + discovery = _ai_discovery(servers=[server]) + 
# Claude mangles "my.server" → "my_server" in the tool name + payload = _payload( + claude, + raw={"tool_name": "mcp__my_server__run", "cwd": "/tmp", "tool_input": {}}, + ) + + req = claude.parse_mcp_activity(payload, discovery) + + assert req.tool == "run" + assert req.server == "my.server" + + def test_server_with_double_underscore_handled(self): + claude = Claude() + discovery = _ai_discovery(servers=[]) + payload = _payload( + claude, + raw={ + "tool_name": "mcp__a__b__tool_name", + "cwd": "/tmp", + "tool_input": {}, + }, + ) + + req = claude.parse_mcp_activity(payload, discovery) + + assert req.tool == "tool_name" + assert req.server == "a__b" # falls back to mangled name + + def test_fallback_to_mangled_name(self): + claude = Claude() + discovery = _ai_discovery(servers=[]) + payload = _payload( + claude, + raw={"tool_name": "mcp__unknown__do_it", "cwd": "/tmp", "tool_input": {}}, + ) + + req = claude.parse_mcp_activity(payload, discovery) + + assert req.server == "unknown" + + +# --------------------------------------------------------------------------- +# _mangle_server_name +# --------------------------------------------------------------------------- + + +class TestMangleServerName: + @pytest.mark.parametrize( + "name, expected", + [ + pytest.param("my-seRver-123", "my-seRver-123", id="alphanumeric_dashes"), + pytest.param( + "my.server/v2 alpha", "my_server_v2_alpha", id="special_chars" + ), + pytest.param("simple", "simple", id="plain_alpha"), + pytest.param("a@b#c", "a_b_c", id="symbols"), + ], + ) + def test_mangle_server_name(self, name: str, expected: str): + assert _mangle_server_name(name) == expected + + +# =========================================================================== +# Copilot +# =========================================================================== + + +class TestCopilotParseMcpActivity: + def test_simple_server_tool_split(self): + copilot = Copilot() + cfg = _cfg(name="myserver", agent="copilot") + server = 
MCPServer(name="myserver", configurations=[cfg]) + discovery = _ai_discovery(servers=[server]) + payload = _payload( + copilot, + raw={"tool_name": "myserver_mytool", "cwd": "/tmp", "tool_input": {}}, + ) + + req = copilot.parse_mcp_activity(payload, discovery) + + assert req.tool == "mytool" + assert req.server == "myserver" + + def test_multiple_underscores_raises(self): + """Documents that `split("_")` with >2 parts raises ValueError.""" + copilot = Copilot() + discovery = _ai_discovery(servers=[]) + payload = _payload( + copilot, + raw={ + "tool_name": "server_tool_name_extra", + "cwd": "/tmp", + "tool_input": {}, + }, + ) + + with pytest.raises(ValueError): + copilot.parse_mcp_activity(payload, discovery) + + def test_unknown_server_falls_back_to_cfg_name(self): + copilot = Copilot() + discovery = _ai_discovery(servers=[]) + payload = _payload( + copilot, + raw={"tool_name": "unknown_tool", "cwd": "/tmp", "tool_input": {}}, + ) + + req = copilot.parse_mcp_activity(payload, discovery) + + assert req.server == "unknown" + assert req.tool == "tool" + + +# =========================================================================== +# _parse_tool_arguments (Cursor helper) +# =========================================================================== + + +class TestParseToolArguments: + def test_valid_schema(self): + schema = { + "type": "object", + "properties": { + "query": {"type": "string", "description": "SQL query"}, + "limit": {"type": "integer"}, + }, + "required": ["query"], + } + result = _parse_tool_arguments(schema) + assert result is not None + assert len(result) == 2 + q = next(a for a in result if a.name == "query") + assert q.required is True + assert q.description == "SQL query" + lim = next(a for a in result if a.name == "limit") + assert lim.required is False + + def test_empty_properties_returns_none(self): + schema = {"type": "object", "properties": {}} + assert _parse_tool_arguments(schema) is None + + @pytest.mark.parametrize( + "schema", + [ 
+ pytest.param(None, id="none"), + pytest.param("string", id="string"), + pytest.param(42, id="integer"), + ], + ) + def test_non_dict_schema_returns_none(self, schema: Any): + assert _parse_tool_arguments(schema) is None diff --git a/tests/unit/verticals/ai/test_cmd_ai.py b/tests/unit/verticals/ai/test_cmd_ai.py new file mode 100644 index 0000000000..75b33932e1 --- /dev/null +++ b/tests/unit/verticals/ai/test_cmd_ai.py @@ -0,0 +1,154 @@ +import json +from unittest.mock import patch + +from click.testing import CliRunner + +from ggshield.__main__ import cli +from ggshield.core.errors import APIKeyCheckError +from ggshield.verticals.ai.models import AIDiscovery, UserInfo + + +def _user(): + return UserInfo( + hostname="host", username="user", machine_id="mid", user_email="u@e.com" + ) + + +def _discovery(): + return AIDiscovery(user=_user(), servers=[], discovery_duration=0.1) + + +# --------------------------------------------------------------------------- +# ggshield secret scan ai-hook +# --------------------------------------------------------------------------- + + +class TestAiHookCmd: + @patch("ggshield.cmd.secret.scan.ai_hook.AIHookScanner") + @patch("ggshield.cmd.secret.scan.ai_hook.SecretScanner") + @patch("ggshield.cmd.secret.scan.ai_hook.create_client_from_config") + def test_valid_json_stdin(self, mock_client, mock_scanner_cls, mock_hook_scanner): + instance = mock_hook_scanner.return_value + instance.scan.return_value = 0 + + runner = CliRunner() + result = runner.invoke( + cli, + ["secret", "scan", "ai-hook"], + input='{"event_type": "test"}', + ) + + assert result.exit_code == 0 + instance.scan.assert_called_once() + + @patch("ggshield.cmd.secret.scan.ai_hook.AIHookScanner") + @patch("ggshield.cmd.secret.scan.ai_hook.SecretScanner") + @patch("ggshield.cmd.secret.scan.ai_hook.create_client_from_config") + def test_empty_stdin_returns_error( + self, mock_client, mock_scanner_cls, mock_hook_scanner + ): + instance = mock_hook_scanner.return_value + 
instance.scan.side_effect = ValueError("Empty input") + + runner = CliRunner() + result = runner.invoke( + cli, + ["secret", "scan", "ai-hook"], + input="", + ) + + assert result.exit_code == 1 + + @patch("ggshield.cmd.secret.scan.ai_hook.AIHookScanner") + @patch("ggshield.cmd.secret.scan.ai_hook.SecretScanner") + @patch("ggshield.cmd.secret.scan.ai_hook.create_client_from_config") + def test_large_stdin_does_not_crash( + self, mock_client, mock_scanner_cls, mock_hook_scanner + ): + instance = mock_hook_scanner.return_value + instance.scan.return_value = 0 + + runner = CliRunner() + large_input = "x" * (1024 * 1024) # 1 MB + result = runner.invoke( + cli, + ["secret", "scan", "ai-hook"], + input=large_input, + ) + + assert result.exit_code == 0 + + +# --------------------------------------------------------------------------- +# ggshield ai discover +# --------------------------------------------------------------------------- + + +class TestDiscoverCmd: + @patch( + "ggshield.cmd.ai.discover.discover_ai_configuration", + return_value=_discovery(), + ) + @patch("ggshield.cmd.ai.discover.create_client_from_config") + @patch( + "ggshield.cmd.ai.discover.submit_ai_discovery", + return_value=_discovery(), + ) + @patch("ggshield.cmd.ai.discover.save_discovery_cache") + def test_default_output(self, mock_save, mock_submit, mock_client, mock_discover): + runner = CliRunner() + result = runner.invoke(cli, ["ai", "discover"]) + + assert result.exit_code == 0 + mock_discover.assert_called_once() + + @patch( + "ggshield.cmd.ai.discover.discover_ai_configuration", + return_value=_discovery(), + ) + @patch("ggshield.cmd.ai.discover.create_client_from_config") + @patch( + "ggshield.cmd.ai.discover.submit_ai_discovery", + return_value=_discovery(), + ) + @patch("ggshield.cmd.ai.discover.save_discovery_cache") + def test_json_flag(self, mock_save, mock_submit, mock_client, mock_discover): + runner = CliRunner() + result = runner.invoke(cli, ["ai", "discover", "--json"]) + + assert 
result.exit_code == 0 + parsed = json.loads(result.output) + assert "user" in parsed + + @patch( + "ggshield.cmd.ai.discover.discover_ai_configuration", + return_value=_discovery(), + ) + @patch( + "ggshield.cmd.ai.discover.create_client_from_config", + side_effect=APIKeyCheckError("https://api.gitguardian.com", "no key"), + ) + def test_auth_failure_shows_warning(self, mock_client, mock_discover): + runner = CliRunner() + result = runner.invoke(cli, ["ai", "discover"]) + + assert result.exit_code == 0 + assert "Skipping upload" in result.output or "warning" in result.output.lower() + + @patch( + "ggshield.cmd.ai.discover.discover_ai_configuration", + return_value=_discovery(), + ) + @patch("ggshield.cmd.ai.discover.create_client_from_config") + @patch( + "ggshield.cmd.ai.discover.submit_ai_discovery", + side_effect=RuntimeError("API error"), + ) + def test_api_submission_failure_shows_warning( + self, mock_submit, mock_client, mock_discover + ): + runner = CliRunner() + result = runner.invoke(cli, ["ai", "discover"]) + + assert result.exit_code == 0 + assert "Could not upload" in result.output or "warning" in result.output.lower() diff --git a/tests/unit/verticals/ai/test_config.py b/tests/unit/verticals/ai/test_config.py new file mode 100644 index 0000000000..b1919c70ca --- /dev/null +++ b/tests/unit/verticals/ai/test_config.py @@ -0,0 +1,177 @@ +import json +from pathlib import Path +from typing import Any, Dict, Optional +from unittest.mock import patch + +import pytest + +from ggshield.verticals.ai.config import ( + extract_host_from_config, + get_mcp_remote_url, + load_json_file, + load_mcp_config, + save_json_file, +) + +# --------------------------------------------------------------------------- +# load_json_file +# --------------------------------------------------------------------------- + + +class TestLoadJsonFile: + def test_valid_json_returns_dict(self, tmp_path: Path): + f = tmp_path / "data.json" + f.write_text('{"key": "value"}') + assert 
load_json_file(f) == {"key": "value"} + + def test_nonexistent_file_returns_empty_dict(self, tmp_path: Path): + assert load_json_file(tmp_path / "nope.json") == {} + + def test_malformed_json_returns_empty_dict(self, tmp_path: Path): + f = tmp_path / "bad.json" + f.write_text("{invalid json") + assert load_json_file(f) == {} + + def test_oserror_returns_empty_dict(self, tmp_path: Path): + d = tmp_path / "dir.json" + d.mkdir() # reading a directory triggers OSError + assert load_json_file(d) == {} + + +# --------------------------------------------------------------------------- +# save_json_file +# --------------------------------------------------------------------------- + + +class TestSaveJsonFile: + def test_writes_valid_json_and_creates_parents(self, tmp_path: Path): + f = tmp_path / "sub" / "dir" / "out.json" + save_json_file(f, {"a": 1}) + assert json.loads(f.read_text()) == {"a": 1} + + def test_oserror_silently_swallowed(self, tmp_path: Path): + f = tmp_path / "readonly" / "out.json" + (tmp_path / "readonly").mkdir() + (tmp_path / "readonly").chmod(0o444) + try: + save_json_file(f, {"a": 1}) # should not raise + finally: + (tmp_path / "readonly").chmod(0o755) + + +# --------------------------------------------------------------------------- +# load_mcp_config +# --------------------------------------------------------------------------- + + +class TestLoadMcpConfig: + def test_workspace_config_found(self, tmp_path: Path): + ws = tmp_path / "project" + (ws / ".cursor").mkdir(parents=True) + (ws / ".cursor" / "mcp.json").write_text('{"servers": {"s1": {}}}') + result = load_mcp_config([str(ws)]) + assert result == {"servers": {"s1": {}}} + + def test_first_workspace_wins(self, tmp_path: Path): + ws1 = tmp_path / "p1" + ws2 = tmp_path / "p2" + for ws, content in [(ws1, '{"first": true}'), (ws2, '{"second": true}')]: + (ws / ".cursor").mkdir(parents=True) + (ws / ".cursor" / "mcp.json").write_text(content) + result = load_mcp_config([str(ws1), str(ws2)]) + 
assert result == {"first": True} + + def test_falls_back_to_global(self, tmp_path: Path): + global_mcp = tmp_path / ".cursor" / "mcp.json" + global_mcp.parent.mkdir(parents=True) + global_mcp.write_text('{"global": true}') + with patch( + "ggshield.verticals.ai.config.get_user_home_dir", return_value=tmp_path + ): + result = load_mcp_config([str(tmp_path / "no_workspace")]) + assert result == {"global": True} + + def test_neither_exists_returns_empty(self, tmp_path: Path): + with patch( + "ggshield.verticals.ai.config.get_user_home_dir", return_value=tmp_path + ): + result = load_mcp_config([]) + assert result == {} + + def test_list_config_skipped(self, tmp_path: Path): + ws = tmp_path / "project" + (ws / ".cursor").mkdir(parents=True) + (ws / ".cursor" / "mcp.json").write_text("[1, 2, 3]") + with patch( + "ggshield.verticals.ai.config.get_user_home_dir", return_value=tmp_path + ): + result = load_mcp_config([str(ws)]) + assert result == {} + + +# --------------------------------------------------------------------------- +# extract_host_from_config +# --------------------------------------------------------------------------- + + +class TestExtractHostFromConfig: + @pytest.mark.parametrize( + "config, expected", + [ + pytest.param( + {"env": {"CLICKHOUSE_HOST": "ch.example.com"}}, + "ch.example.com", + id="clickhouse_host", + ), + pytest.param( + {"env": {"GITLAB_API_URL": "https://gitlab.corp.com/api/v4"}}, + "gitlab.corp.com", + id="gitlab_api_url", + ), + pytest.param( + {"args": ["--config", "https://api.example.com/v1/mcp"]}, + "api.example.com", + id="https_arg", + ), + pytest.param( + {"args": ["http://localhost:8080/path"]}, + "localhost:8080", + id="http_arg", + ), + pytest.param( + {"args": ["--flag", "value"], "env": {}}, + None, + id="no_match", + ), + pytest.param(None, None, id="none_input"), + ], + ) + def test_extract_host(self, config: Dict[str, Any], expected: Optional[str]): + assert extract_host_from_config(config) == expected + + +# 
--------------------------------------------------------------------------- +# get_mcp_remote_url +# --------------------------------------------------------------------------- + + +class TestGetMcpRemoteUrl: + @pytest.mark.parametrize( + "config, expected", + [ + pytest.param( + {"args": ["--flag", "https://remote.com/mcp", "http://other.com"]}, + "https://remote.com/mcp", + id="first_url_returned", + ), + pytest.param( + {"args": ["--flag", "not-a-url"]}, + None, + id="no_url", + ), + pytest.param({"args": []}, None, id="empty_args"), + pytest.param({}, None, id="no_args_key"), + ], + ) + def test_get_mcp_remote_url(self, config: Dict[str, Any], expected: Optional[str]): + assert get_mcp_remote_url(config) == expected diff --git a/tests/unit/verticals/ai/test_discovery.py b/tests/unit/verticals/ai/test_discovery.py new file mode 100644 index 0000000000..5e922ef595 --- /dev/null +++ b/tests/unit/verticals/ai/test_discovery.py @@ -0,0 +1,313 @@ +from pathlib import Path +from typing import Any +from unittest.mock import MagicMock, patch + +import pytest + +from ggshield.verticals.ai.discovery import ( + _merge_mcp_configurations, + discover_ai_configuration, + load_discovery_cache, + refresh_and_maybe_submit_discovery, + save_discovery_cache, + submit_ai_discovery, +) +from ggshield.verticals.ai.models import ( + AIDiscovery, + MCPConfiguration, + Scope, + Transport, + UserInfo, +) + + +def _user(**kwargs: Any) -> UserInfo: + defaults = dict( + hostname="host", username="user", machine_id="mid", user_email="u@e.com" + ) + return UserInfo.model_validate(defaults | kwargs) + + +def _cfg( + name: str = "srv", agent: str = "cursor", scope: Scope = Scope.USER +) -> MCPConfiguration: + return MCPConfiguration( + name=name, agent=agent, scope=scope, transport=Transport.STDIO + ) + + +def _discovery(**kwargs: Any) -> AIDiscovery: + defaults = dict(user=_user(), servers=[], discovery_duration=0.1) + return AIDiscovery.model_validate(defaults | kwargs) + + +# 
--------------------------------------------------------------------------- +# _merge_mcp_configurations +# --------------------------------------------------------------------------- + + +class TestMergeMcpConfigurations: + def test_different_names_produce_separate_servers(self): + configs = [_cfg(name="a"), _cfg(name="b")] + servers = _merge_mcp_configurations(configs) + assert len(servers) == 2 + names = {s.name for s in servers} + assert names == {"a", "b"} + + def test_same_name_merged_under_one_server(self): + configs = [_cfg(name="x", agent="cursor"), _cfg(name="x", agent="claude-code")] + servers = _merge_mcp_configurations(configs) + assert len(servers) == 1 + assert len(servers[0].configurations) == 2 + + def test_empty_list_returns_empty(self): + assert _merge_mcp_configurations([]) == [] + + +# --------------------------------------------------------------------------- +# discover_ai_configuration +# --------------------------------------------------------------------------- + + +class TestDiscoverAIConfiguration: + @patch("ggshield.verticals.ai.discovery.get_user_info", return_value=_user()) + @patch("ggshield.verticals.ai.discovery.AGENTS") + def test_aggregates_agents(self, mock_agents, mock_user_info, tmp_path: Path): + agent1 = MagicMock() + agent1.discover_project_directories.return_value = iter([tmp_path / "p1"]) + agent1.discover_mcp_configurations.return_value = [_cfg(name="s1")] + agent1.discover_capabilities.return_value = False + + agent2 = MagicMock() + agent2.discover_project_directories.return_value = iter([]) + agent2.discover_mcp_configurations.return_value = [_cfg(name="s2")] + agent2.discover_capabilities.return_value = False + + mock_agents.values.return_value = [agent1, agent2] + + result = discover_ai_configuration() + + assert result.user == _user() + assert len(result.servers) == 2 + assert result.discovery_duration > 0 + + @patch("ggshield.verticals.ai.discovery.get_user_info", return_value=_user()) + 
@patch("ggshield.verticals.ai.discovery.AGENTS") + def test_stops_capability_discovery_at_first_success( + self, mock_agents, mock_user_info + ): + agent1 = MagicMock() + agent1.discover_project_directories.return_value = iter([]) + agent1.discover_mcp_configurations.return_value = [_cfg(name="s")] + agent1.discover_capabilities.return_value = True + + agent2 = MagicMock() + agent2.discover_project_directories.return_value = iter([]) + agent2.discover_mcp_configurations.return_value = [] + agent2.discover_capabilities.return_value = False + + mock_agents.values.return_value = [agent1, agent2] + + discover_ai_configuration() + + agent1.discover_capabilities.assert_called_once() + agent2.discover_capabilities.assert_not_called() + + +# --------------------------------------------------------------------------- +# load / save discovery cache +# --------------------------------------------------------------------------- + + +class TestLoadSaveDiscoveryCache: + def test_round_trip(self, tmp_path: Path): + discovery = _discovery() + with patch( + "ggshield.verticals.ai.discovery.get_cache_dir", return_value=tmp_path + ): + save_discovery_cache(discovery) + loaded = load_discovery_cache() + assert loaded is not None + assert loaded.user == discovery.user + + def test_load_returns_none_when_missing(self, tmp_path: Path): + with patch( + "ggshield.verticals.ai.discovery.get_cache_dir", return_value=tmp_path + ): + assert load_discovery_cache() is None + + def test_load_returns_none_on_valid_json_bad_schema(self, tmp_path: Path): + """Valid JSON that doesn't match AIDiscovery schema.""" + cache_file = tmp_path / "ai_discovery.json" + cache_file.write_text('{"foo": "bar"}') + with patch( + "ggshield.verticals.ai.discovery.get_cache_dir", return_value=tmp_path + ): + from pydantic import ValidationError + + with pytest.raises(ValidationError): + load_discovery_cache() + + def test_load_returns_none_on_invalid_json(self, tmp_path: Path): + """Valid JSON that doesn't match 
AIDiscovery schema.""" + cache_file = tmp_path / "ai_discovery.json" + cache_file.write_text("not json") + with patch( + "ggshield.verticals.ai.discovery.get_cache_dir", return_value=tmp_path + ): + from pydantic import ValidationError + + with pytest.raises(ValidationError): + load_discovery_cache() + + def test_load_returns_none_on_oserror(self, tmp_path: Path): + with patch( + "ggshield.verticals.ai.discovery.get_cache_dir", return_value=tmp_path + ): + cache_file = tmp_path / "ai_discovery.json" + cache_file.mkdir() # directory instead of file triggers OSError + assert load_discovery_cache() is None + + +# --------------------------------------------------------------------------- +# submit_ai_discovery +# --------------------------------------------------------------------------- + + +class TestSubmitAIDiscovery: + def test_successful_response(self): + discovery = _discovery() + response = MagicMock() + response.status_code = 200 + response.text = discovery.model_dump_json() + client = MagicMock() + client.post.return_value = response + + result = submit_ai_discovery(client, discovery) + assert result.user == discovery.user + + def test_non_200_raises(self): + discovery = _discovery() + response = MagicMock() + response.status_code = 500 + response.text = "Internal Server Error" + client = MagicMock() + client.post.return_value = response + + with pytest.raises(AssertionError): + submit_ai_discovery(client, discovery) + + +# --------------------------------------------------------------------------- +# refresh_and_maybe_submit_discovery +# --------------------------------------------------------------------------- + + +class TestRefreshAndMaybeSubmitDiscovery: + def _patch_all(self): + return ( + patch( + "ggshield.verticals.ai.discovery.load_discovery_cache", + ), + patch( + "ggshield.verticals.ai.discovery.discover_ai_configuration", + ), + patch( + "ggshield.verticals.ai.discovery.submit_ai_discovery", + ), + patch( + 
"ggshield.verticals.ai.discovery.save_discovery_cache", + ), + ) + + def test_no_cache_submits_and_saves(self): + p_load, p_discover, p_submit, p_save = self._patch_all() + with ( + p_load as m_load, + p_discover as m_discover, + p_submit as m_submit, + p_save as m_save, + ): + m_load.return_value = None + new_disc = _discovery() + m_discover.return_value = new_disc + submitted = _discovery(discovery_duration=0.5) + m_submit.return_value = submitted + + result = refresh_and_maybe_submit_discovery(MagicMock()) + + m_submit.assert_called_once() + m_save.assert_called_once_with(submitted) + assert result == submitted + + def test_unchanged_returns_cache_without_submission(self): + cached = _discovery() + p_load, p_discover, p_submit, p_save = self._patch_all() + with ( + p_load as m_load, + p_discover as m_discover, + p_submit as m_submit, + p_save as m_save, + ): + m_load.return_value = cached + m_discover.return_value = cached # identical discovery + + result = refresh_and_maybe_submit_discovery(MagicMock()) + + m_submit.assert_not_called() + m_save.assert_not_called() + assert result == cached + + def test_changed_submits_and_saves(self): + cached = _discovery(user=_user(hostname="old")) + new_disc = _discovery(user=_user(hostname="new")) + p_load, p_discover, p_submit, p_save = self._patch_all() + with ( + p_load as m_load, + p_discover as m_discover, + p_submit as m_submit, + p_save as m_save, + ): + m_load.return_value = cached + m_discover.return_value = new_disc + m_submit.return_value = new_disc + + result = refresh_and_maybe_submit_discovery(MagicMock()) + + m_submit.assert_called_once() + m_save.assert_called_once() + + def test_api_error_swallowed(self): + p_load, p_discover, p_submit, p_save = self._patch_all() + with ( + p_load as m_load, + p_discover as m_discover, + p_submit as m_submit, + p_save as m_save, + ): + m_load.return_value = None + new_disc = _discovery() + m_discover.return_value = new_disc + m_submit.side_effect = RuntimeError("network") + 
+ result = refresh_and_maybe_submit_discovery(MagicMock()) + + assert result == new_disc + m_save.assert_not_called() + + def test_reuses_machine_id_from_cache(self): + cached = _discovery(user=_user(machine_id="cached-id")) + p_load, p_discover, p_submit, p_save = self._patch_all() + with ( + p_load as m_load, + p_discover as m_discover, + p_submit as m_submit, + p_save as m_save, + ): + m_load.return_value = cached + m_discover.return_value = cached + + refresh_and_maybe_submit_discovery(MagicMock()) + + _, kwargs = m_discover.call_args + assert kwargs.get("machine_id") == "cached-id" diff --git a/tests/unit/verticals/secret/ai_hook/test_hooks.py b/tests/unit/verticals/ai/test_hooks.py similarity index 83% rename from tests/unit/verticals/secret/ai_hook/test_hooks.py rename to tests/unit/verticals/ai/test_hooks.py index f763cf76b4..021de1a128 100644 --- a/tests/unit/verticals/secret/ai_hook/test_hooks.py +++ b/tests/unit/verticals/ai/test_hooks.py @@ -7,18 +7,33 @@ import pytest from ggshield.utils.git_shell import Filemode +from ggshield.verticals.ai.agents import Claude, Copilot, Cursor +from ggshield.verticals.ai.hooks import AIHookScanner, find_filepaths, parse_hook_input +from ggshield.verticals.ai.models import EventType, HookPayload, HookResult, Tool from ggshield.verticals.secret import SecretScanner -from ggshield.verticals.secret.ai_hook.claude_code import Claude -from ggshield.verticals.secret.ai_hook.copilot import Copilot -from ggshield.verticals.secret.ai_hook.cursor import Cursor -from ggshield.verticals.secret.ai_hook.models import EventType, Flavor, Payload -from ggshield.verticals.secret.ai_hook.models import Result as HookResult -from ggshield.verticals.secret.ai_hook.models import Tool -from ggshield.verticals.secret.ai_hook.scanner import AIHookScanner, find_filepaths from ggshield.verticals.secret.secret_scan_collection import Result as ScanResult from ggshield.verticals.secret.secret_scan_collection import Results, Secret +def 
_dummy_payload(event_type: EventType = EventType.OTHER) -> HookPayload: + return HookPayload( + event_type=event_type, + tool=None, + content="", + identifier="", + agent=Cursor(), + raw={}, + ) + + +@pytest.fixture +def tmp_file(tmp_path: Path) -> Path: + """Create a temporary file with content.""" + file = tmp_path / "test.txt" + file.write_text("this is the content") + return file + + def _mock_scanner(matches: List[str]) -> MagicMock: """Create a mock SecretScanner that returns the given Results from scan().""" mock = MagicMock(spec=SecretScanner) @@ -62,26 +77,48 @@ def _make_secret(match_str: str = "***"): ) -def _dummy_payload(event_type: EventType = EventType.OTHER) -> Payload: - return Payload( - event_type=event_type, - tool=None, - content="", - identifier="", - flavor=Flavor(), - ) +class TestAIHookScannerScanContent: + """Unit tests for AIHookScanner._scan_content.""" + def test_no_secrets_returns_allow(self): + """When scanner returns no secrets, result has block=False and nbr_secrets=0.""" + hook_scanner = AIHookScanner(_mock_scanner([])) + payload = HookPayload( + event_type=EventType.USER_PROMPT, + tool=None, + content="safe content", + identifier="id", + agent=Cursor(), + raw={}, + ) + result = hook_scanner._scan_content(payload) + assert isinstance(result, HookResult) + assert result.block is False + assert result.nbr_secrets == 0 + assert result.message == "" -@pytest.fixture -def tmp_file(tmp_path: Path) -> Path: - """Create a temporary file with content.""" - file = tmp_path / "test.txt" - file.write_text("this is the content") - return file + def test_with_secrets_returns_block_and_message(self): + """When scanner returns secrets, result has block=True, nbr_secrets and message set.""" + hook_scanner = AIHookScanner(_mock_scanner(["sk-xxx"])) + payload = HookPayload( + event_type=EventType.USER_PROMPT, + tool=None, + content="content with sk-xxx", + identifier="id", + agent=Cursor(), + raw={}, + ) + result = hook_scanner._scan_content(payload) 
+ assert isinstance(result, HookResult) + assert result.block is True + assert result.nbr_secrets == 1 + assert "dummy-detector" in result.message + assert "secret" in result.message.lower() + assert "remove the secrets from your prompt" in result.message -class TestAIHookScannerParseInput: - """Unit tests for AIHookScanner._parse_input.""" +class TestAIHookScannerScan: + """Unit tests for the AIHookScanner.scan() method.""" def test_empty_input_raises(self): """Empty or whitespace-only input raises ValueError.""" @@ -91,23 +128,182 @@ def test_empty_input_raises(self): with pytest.raises(ValueError, match="No input received on stdin"): scanner.scan(" \n ") + def test_scan_no_secrets_returns_zero(self): + """scan() with no secrets returns 0.""" + scanner = AIHookScanner(_mock_scanner([])) + data = { + "hook_event_name": "UserPromptSubmit", + "prompt": "hello world", + "transcript_path": "/home/user/.claude/projects/foo/session.jsonl", + "cursor_version": "1.2.3", + } + code = scanner.scan(json.dumps(data)) + assert code == 0 + + @patch("ggshield.verticals.ai.hooks.AIHookScanner._send_secret_notification") + def test_scan_post_tool_use_with_secrets_sends_notification( + self, mock_notify: MagicMock + ): + """scan() on POST_TOOL_USE with secrets sends a notification and returns 0 (no block).""" + scanner = AIHookScanner(_mock_scanner(["sk-xxx"])) + data = { + "hook_event_name": "PostToolUse", + "tool_name": "Bash", + "tool_input": {"command": "echo sk-xxx"}, + "tool_response": {"stdout": "sk-xxx\n"}, + "transcript_path": "/home/user/.claude/projects/foo/session.jsonl", + "session_id": "427ae0c5-0862-4e14-aa2c-12fad909c323", + } + code = scanner.scan(json.dumps(data)) + assert code == 0 + mock_notify.assert_called_once() + args = mock_notify.call_args[0] + assert args[0] == 1 # nbr_secrets + assert args[1] == Tool.BASH # tool + + def test_scan_pre_tool_use_with_secrets_blocks(self): + """scan() on PRE_TOOL_USE with secrets returns block result.""" + scanner = 
AIHookScanner(_mock_scanner(["sk-xxx"])) + data = { + "hook_event_name": "PreToolUse", + "tool_name": "Bash", + "tool_input": {"command": "echo sk-xxx"}, + "session_id": "427ae0c5-0862-4e14-aa2c-12fad909c323", + "transcript_path": "/home/user/.claude/projects/foo/session.jsonl", + } + code = scanner.scan(json.dumps(data)) + # Claude output_result always returns 0 + assert code == 0 + + def test_scan_no_content_returns_allow(self): + """scan() with no content returns 0 (and doesn't call the API).""" + mock_scanner = _mock_scanner([]) + scanner = AIHookScanner(mock_scanner) + data = { + "hook_event_name": "PreToolUse", + "tool_name": "Read", + "tool_input": {"file_path": "doesn-t-exist"}, + "cursor_version": "1.2.3", + } + code = scanner.scan(json.dumps(data)) + assert code == 0 + mock_scanner.scan.assert_not_called() + + def test_scan_payloads_refuse_empty_list(self): + """scan() with empty list of payloads raises ValueError.""" + scanner = AIHookScanner(_mock_scanner([])) + with pytest.raises(ValueError): + scanner._scan_payloads([]) + + +class TestMessageFromSecrets: + """Unit tests for AIHookScanner._message_from_secrets with different payload types.""" + + def test_message_for_bash_tool(self): + """Message for BASH tool mentions environment variables.""" + payload = HookPayload( + event_type=EventType.PRE_TOOL_USE, + tool=Tool.BASH, + content="echo sk-xxx", + identifier="echo sk-xxx", + agent=Cursor(), + raw={}, + ) + message = AIHookScanner._message_from_secrets([_make_secret("sk-xxx")], payload) + assert "remove the secrets from the command" in message + assert "environment variables" in message + + def test_message_for_read_tool(self): + """Message for READ tool mentions file content.""" + payload = HookPayload( + event_type=EventType.PRE_TOOL_USE, + tool=Tool.READ, + content="file content with secret", + identifier="/path/to/file", + agent=Cursor(), + raw={}, + ) + message = AIHookScanner._message_from_secrets([_make_secret("sk-xxx")], payload) + assert 
"remove the secrets from" in message + + def test_message_for_other_tool(self): + """Message for OTHER tool uses generic message.""" + payload = HookPayload( + event_type=EventType.PRE_TOOL_USE, + tool=Tool.OTHER, + content="some content", + identifier="id", + agent=Cursor(), + raw={}, + ) + message = AIHookScanner._message_from_secrets([_make_secret("sk-xxx")], payload) + assert "remove the secrets from the tool input" in message + + def test_message_escapes_markdown(self): + """When escape_markdown=True, asterisks in matches are replaced with dots.""" + payload = HookPayload( + event_type=EventType.USER_PROMPT, + tool=None, + content="content", + identifier="id", + agent=Cursor(), + raw={}, + ) + message = AIHookScanner._message_from_secrets( + [_make_secret("sk-xxx")], payload, escape_markdown=True + ) + # The message itself should not contain raw asterisks from matches + # (the header uses ** for bold which is intentional) + assert "Detected" in message + + +class TestSendSecretNotification: + """Unit tests for AIHookScanner._send_secret_notification.""" + + @patch("ggshield.verticals.ai.hooks.Notify") + def test_notification_for_bash_tool(self, mock_notify_cls: MagicMock): + """Notification for BASH tool says 'running a command'.""" + AIHookScanner._send_secret_notification(1, Tool.BASH, "Claude Code") + instance = mock_notify_cls.return_value + assert "running a command" in instance.message + assert "Claude Code" in instance.message + instance.send.assert_called_once() + + @patch("ggshield.verticals.ai.hooks.Notify") + def test_notification_for_read_tool(self, mock_notify_cls: MagicMock): + """Notification for READ tool says 'reading a file'.""" + AIHookScanner._send_secret_notification(2, Tool.READ, "Cursor") + instance = mock_notify_cls.return_value + assert "reading a file" in instance.message + assert "2" in instance.message + instance.send.assert_called_once() + + @patch("ggshield.verticals.ai.hooks.Notify") + def test_notification_for_other_tool(self, 
mock_notify_cls: MagicMock): + """Notification for OTHER tool says 'using a tool'.""" + AIHookScanner._send_secret_notification(1, Tool.OTHER, "Copilot") + instance = mock_notify_cls.return_value + assert "using a tool" in instance.message + instance.send.assert_called_once() + + +class TestAIHookScannerParseInput: + """Unit tests for AIHookparse_hook_input.""" + def test_invalid_json_raises(self): """Invalid JSON raises ValueError with parse error.""" - scanner = AIHookScanner(_mock_scanner([])) with pytest.raises(ValueError, match="Failed to parse JSON"): - scanner._parse_input("not json {") + parse_hook_input("not json {") with pytest.raises(ValueError, match="Failed to parse JSON"): - scanner._parse_input("{ missing brace ") + parse_hook_input("{ missing brace ") def test_missing_event_type_raises(self): """JSON without event type raises ValueError.""" - scanner = AIHookScanner(_mock_scanner([])) - with pytest.raises(ValueError, match="couldn't find event type"): - scanner._parse_input('{"prompt": "hello"}') + with pytest.raises(ValueError): + parse_hook_input('{"prompt": "hello"}') def test_cursor_user_prompt(self): """Test Cursor beforeSubmitPrompt (user prompt) parsing.""" - scanner = AIHookScanner(_mock_scanner([])) data = { "conversation_id": "75fed8a8-2078-4e49-80d2-776b20d441c3", "generation_id": "1501ede6-b8ac-43f4-9943-0e218610c5c6", @@ -120,16 +316,15 @@ def test_cursor_user_prompt(self): "user_email": "user@example.com", "transcript_path": "/home/user1/.cursor/projects/foo/agent-transcripts/75fed8a8/75fed8a8.jsonl", } - payload = scanner._parse_input(json.dumps(data))[0] + payload = parse_hook_input(json.dumps(data))[0] assert payload.event_type == EventType.USER_PROMPT assert payload.content == "hello world" assert payload.tool is None assert payload.identifier != "" - assert isinstance(payload.flavor, Cursor) + assert isinstance(payload.agent, Cursor) def test_cursor_pre_tool_use_shell(self): """Test Cursor preToolUse with Shell (bash) parsing.""" 
- scanner = AIHookScanner(_mock_scanner([])) data = { "conversation_id": "37a17cfc-322c-47ab-88c5-e810f23f4739", "generation_id": "049f5b26-326a-4081-82c1-e5c42a63d19e", @@ -146,16 +341,15 @@ def test_cursor_pre_tool_use_shell(self): "workspace_roots": ["/home/user1/foo"], "transcript_path": "/home/user1/.cursor/projects/foo/agent-transcripts/37a17cfc/37a17cfc.jsonl", } - payload = scanner._parse_input(json.dumps(data))[0] + payload = parse_hook_input(json.dumps(data))[0] assert payload.event_type == EventType.PRE_TOOL_USE assert payload.tool == Tool.BASH assert payload.content == "whoami" assert payload.identifier == "whoami" - assert isinstance(payload.flavor, Cursor) + assert isinstance(payload.agent, Cursor) def test_cursor_pre_tool_use_read(self, tmp_file: Path): """Test Cursor preToolUse with Read (file) parsing.""" - scanner = AIHookScanner(_mock_scanner([])) data = { "conversation_id": "75fed8a8-2078-4e49-80d2-776b20d441c3", "generation_id": "1501ede6-b8ac-43f4-9943-0e218610c5c6", @@ -168,17 +362,16 @@ def test_cursor_pre_tool_use_read(self, tmp_file: Path): "workspace_roots": ["/home/user1/foo"], "transcript_path": "/home/user1/.cursor/projects/foo/agent-transcripts/75fed8a8/75fed8a8.jsonl", } - payload = scanner._parse_input(json.dumps(data))[0] + payload = parse_hook_input(json.dumps(data))[0] assert payload.event_type == EventType.PRE_TOOL_USE assert payload.tool == Tool.READ assert payload.identifier == tmp_file.as_posix() assert payload.content == "" assert payload.scannable.content == "this is the content" - assert isinstance(payload.flavor, Cursor) + assert isinstance(payload.agent, Cursor) def test_cursor_post_tool_use_shell(self): """Test Cursor postToolUse with Shell (simulated cat command result).""" - scanner = AIHookScanner(_mock_scanner([])) data = { "conversation_id": "37a17cfc-322c-47ab-88c5-e810f23f4739", "generation_id": "049f5b26-326a-4081-82c1-e5c42a63d19e", @@ -193,15 +386,14 @@ def test_cursor_post_tool_use_shell(self): 
"workspace_roots": ["/home/user1/foo"], "transcript_path": "/home/user/.cursor/projects/foo/agent-transcripts/37a17cfc/37a17cfc.jsonl", } - payload = scanner._parse_input(json.dumps(data))[0] + payload = parse_hook_input(json.dumps(data))[0] assert payload.event_type == EventType.POST_TOOL_USE assert payload.tool == Tool.BASH assert "user1" in payload.content - assert isinstance(payload.flavor, Cursor) + assert isinstance(payload.agent, Cursor) def test_claude_user_prompt(self): """Test Claude Code UserPromptSubmit parsing.""" - scanner = AIHookScanner(_mock_scanner([])) data = { "session_id": "273ad859-3608-4799-9971-fa15ecb1a65c", "transcript_path": "/home/user1/.claude/projects/foo/273ad859-3608-4799-9971-fa15ecb1a65c.jsonl", @@ -210,15 +402,14 @@ def test_claude_user_prompt(self): "hook_event_name": "UserPromptSubmit", "prompt": "hello world", } - payload = scanner._parse_input(json.dumps(data))[0] + payload = parse_hook_input(json.dumps(data))[0] assert payload.event_type == EventType.USER_PROMPT assert payload.content == "hello world" assert payload.tool is None - assert isinstance(payload.flavor, Claude) + assert isinstance(payload.agent, Claude) def test_claude_pre_tool_use_bash(self): """Test Claude Code PreToolUse with Bash parsing.""" - scanner = AIHookScanner(_mock_scanner([])) data = { "session_id": "3b7ae0c5-0862-4e14-aa2c-12fad909c323", "transcript_path": "/home/user1/.claude/projects/foo/3b7ae0c5.jsonl", @@ -232,15 +423,14 @@ def test_claude_pre_tool_use_bash(self): }, "tool_use_id": "toolu_01BPMKeZAMCqBtn1xJRNfDJw", } - payload = scanner._parse_input(json.dumps(data))[0] + payload = parse_hook_input(json.dumps(data))[0] assert payload.event_type == EventType.PRE_TOOL_USE assert payload.tool == Tool.BASH assert "whoami" in payload.content - assert isinstance(payload.flavor, Claude) + assert isinstance(payload.agent, Claude) def test_claude_pre_tool_use_read(self, tmp_file: Path): """Test Claude Code PreToolUse with Read parsing.""" - scanner = 
AIHookScanner(_mock_scanner([])) # From raw_hooks_logs: Claude PreToolUse Read data = { "session_id": "3b7ae0c5-0862-4e14-aa2c-12fad909c323", @@ -252,17 +442,16 @@ def test_claude_pre_tool_use_read(self, tmp_file: Path): "tool_input": {"file_path": tmp_file.as_posix()}, "tool_use_id": "toolu_01WabtWJpzf1ZJ8GJ3JfQEmq", } - payload = scanner._parse_input(json.dumps(data))[0] + payload = parse_hook_input(json.dumps(data))[0] assert payload.event_type == EventType.PRE_TOOL_USE assert payload.tool == Tool.READ assert payload.identifier == tmp_file.as_posix() assert payload.content == "" assert payload.scannable.content == "this is the content" - assert isinstance(payload.flavor, Claude) + assert isinstance(payload.agent, Claude) def test_claude_post_tool_use_bash(self): """Test Claude Code PostToolUse with Bash (simulated cat command result).""" - scanner = AIHookScanner(_mock_scanner([])) # From raw_hooks_logs: Claude PostToolUse Bash - tool_response has stdout data = { "session_id": "3b7ae0c5-0862-4e14-aa2c-12fad909c323", @@ -284,16 +473,15 @@ def test_claude_post_tool_use_bash(self): }, "tool_use_id": "toolu_01BPMKeZAMCqBtn1xJRNfDJw", } - payload = scanner._parse_input(json.dumps(data))[0] + payload = parse_hook_input(json.dumps(data))[0] assert payload.event_type == EventType.POST_TOOL_USE assert payload.tool == Tool.BASH # Content is json.dumps(tool_response), so the stdout is inside the string assert "user1" in payload.content - assert isinstance(payload.flavor, Claude) + assert isinstance(payload.agent, Claude) def test_claude_parse_read_files_in_prompt(self): """Test parsing "@file_path" mentions from Claude Code prompt.""" - scanner = AIHookScanner(_mock_scanner([])) data = { "session_id": "273ad859-3608-4799-9971-fa15ecb1a65c", "transcript_path": "/home/user1/.claude/projects/foo/273ad859-3608-4799-9971-fa15ecb1a65c.jsonl", @@ -302,24 +490,23 @@ def test_claude_parse_read_files_in_prompt(self): "hook_event_name": "UserPromptSubmit", "prompt": "read 
@folder/file.txt and summarize the content.", } - payloads = scanner._parse_input(json.dumps(data)) + payloads = parse_hook_input(json.dumps(data)) assert len(payloads) == 2 payload = payloads[0] assert payload.event_type == EventType.USER_PROMPT assert payload.tool == Tool.READ assert payload.identifier == "folder/file.txt" assert payload.content == "" # empty because inexistent file - assert isinstance(payload.flavor, Claude) + assert isinstance(payload.agent, Claude) payload = payloads[1] assert payload.event_type == EventType.USER_PROMPT assert payload.content == "read @folder/file.txt and summarize the content." assert payload.tool is None - assert isinstance(payload.flavor, Claude) + assert isinstance(payload.agent, Claude) def test_copilot_user_prompt(self): """Test Copilot UserPromptSubmit parsing.""" - scanner = AIHookScanner(_mock_scanner([])) data = { "timestamp": "2026-02-26T11:28:53.112Z", "hookEventName": "UserPromptSubmit", @@ -331,15 +518,14 @@ def test_copilot_user_prompt(self): "prompt": "hello world", "cwd": "/home/user1/foo", } - payload = scanner._parse_input(json.dumps(data))[0] + payload = parse_hook_input(json.dumps(data))[0] assert payload.event_type == EventType.USER_PROMPT assert "hello world" in payload.content assert payload.tool is None - assert isinstance(payload.flavor, Copilot) + assert isinstance(payload.agent, Copilot) def test_copilot_pre_tool_use_run_in_terminal(self): """Test Copilot PreToolUse with run_in_terminal (shell) parsing.""" - scanner = AIHookScanner(_mock_scanner([])) # From raw_hooks_logs: Copilot PreToolUse run_in_terminal data = { "timestamp": "2026-02-26T11:29:05.821Z", @@ -360,15 +546,14 @@ def test_copilot_pre_tool_use_run_in_terminal(self): "tool_use_id": "call_ADJcoVxpnzPtpU6uf0h9wzLR__vscode-1772105116075", "cwd": "/home/user1/foo", } - payload = scanner._parse_input(json.dumps(data))[0] + payload = parse_hook_input(json.dumps(data))[0] assert payload.event_type == EventType.PRE_TOOL_USE assert payload.tool 
== Tool.BASH assert "whoami" in payload.content - assert isinstance(payload.flavor, Copilot) + assert isinstance(payload.agent, Copilot) def test_copilot_pre_tool_use_read_file(self, tmp_file: Path): """Test Copilot PreToolUse with read_file parsing.""" - scanner = AIHookScanner(_mock_scanner([])) # From raw_hooks_logs: Copilot PreToolUse read_file (nonexistent path for deterministic test) data = { "timestamp": "2026-02-26T11:53:49.593Z", @@ -387,17 +572,16 @@ def test_copilot_pre_tool_use_read_file(self, tmp_file: Path): "tool_use_id": "call_iMFuTGETQ2z23a3xYTqcHBXp__vscode-1772105116078", "cwd": "/home/user1/foo", } - payload = scanner._parse_input(json.dumps(data))[0] + payload = parse_hook_input(json.dumps(data))[0] assert payload.event_type == EventType.PRE_TOOL_USE assert payload.tool == Tool.READ assert payload.identifier == tmp_file.as_posix() assert payload.content == "" assert payload.scannable.content == "this is the content" - assert isinstance(payload.flavor, Copilot) + assert isinstance(payload.agent, Copilot) def test_copilot_post_tool_use_run_in_terminal(self): """Test Copilot PostToolUse with run_in_terminal (simulated cat result).""" - scanner = AIHookScanner(_mock_scanner([])) # From raw_hooks_logs: Copilot PostToolUse run_in_terminal - tool_response is string data = { "timestamp": "2026-02-26T11:53:47.392Z", @@ -419,23 +603,23 @@ def test_copilot_post_tool_use_run_in_terminal(self): "tool_use_id": "call_f96KUoNCGS8jENVKnlWnSz5Q__vscode-1772105116077", "cwd": "/home/user1/foo", } - payload = scanner._parse_input(json.dumps(data))[0] + payload = parse_hook_input(json.dumps(data))[0] assert payload.event_type == EventType.POST_TOOL_USE assert payload.tool == Tool.BASH assert "user1" in payload.content - assert isinstance(payload.flavor, Copilot) + assert isinstance(payload.agent, Copilot) def test_pre_tool_use_read_with_missing_file(self): """PRE_TOOL_USE with tool_name 'read' and non-existing file yields empty content.""" - scanner = 
AIHookScanner(_mock_scanner([])) content = json.dumps( { "hook_event_name": "pretooluse", "tool_name": "read", "tool_input": {"file_path": "/nonexistent/path"}, + "cursor_version": "1.2.3", } ) - payload = scanner._parse_input(content)[0] + payload = parse_hook_input(content)[0] assert payload.event_type == EventType.PRE_TOOL_USE assert payload.tool == Tool.READ assert payload.identifier == "/nonexistent/path" @@ -443,75 +627,37 @@ def test_pre_tool_use_read_with_missing_file(self): def test_pre_tool_use_other_tool(self): """PRE_TOOL_USE with unknown tool yields Tool.OTHER and empty content.""" - scanner = AIHookScanner(_mock_scanner([])) data = { "hook_event_name": "PreToolUse", "tool_name": "SomeUnknownTool", "tool_input": {"arg": "value"}, + "cursor_version": "1.2.3", } - payload = scanner._parse_input(json.dumps(data))[0] + payload = parse_hook_input(json.dumps(data))[0] assert payload.event_type == EventType.PRE_TOOL_USE assert payload.tool == Tool.OTHER assert payload.content == "" def test_other_event_type(self): """Unknown event type yields EventType.OTHER with empty content.""" - scanner = AIHookScanner(_mock_scanner([])) data = { "hook_event_name": "SomeOtherEvent", "prompt": "hello", + "cursor_version": "1.2.3", } - payload = scanner._parse_input(json.dumps(data))[0] + payload = parse_hook_input(json.dumps(data))[0] assert payload.event_type == EventType.OTHER assert payload.content == "" assert payload.tool is None -class TestAIHookScannerScanContent: - """Unit tests for AIHookScanner._scan_content.""" - - def test_no_secrets_returns_allow(self): - """When scanner returns no secrets, result has block=False and nbr_secrets=0.""" - hook_scanner = AIHookScanner(_mock_scanner([])) - payload = Payload( - event_type=EventType.USER_PROMPT, - tool=None, - content="safe content", - identifier="id", - flavor=Flavor(), - ) - result = hook_scanner._scan_content(payload) - assert isinstance(result, HookResult) - assert result.block is False - assert 
result.nbr_secrets == 0 - assert result.message == "" - - def test_with_secrets_returns_block_and_message(self): - """When scanner returns secrets, result has block=True, nbr_secrets and message set.""" - hook_scanner = AIHookScanner(_mock_scanner(["sk-xxx"])) - payload = Payload( - event_type=EventType.USER_PROMPT, - tool=None, - content="content with sk-xxx", - identifier="id", - flavor=Flavor(), - ) - result = hook_scanner._scan_content(payload) - assert isinstance(result, HookResult) - assert result.block is True - assert result.nbr_secrets == 1 - assert "dummy-detector" in result.message - assert "secret" in result.message.lower() - assert "remove the secrets from your prompt" in result.message - - class TestFlavorOutputResult: """Unit tests for Cursor, Claude, Copilot output_result with Result objects. Mocks click.echo to capture stdout/stderr and asserts both output and return code. """ - @patch("ggshield.verticals.secret.ai_hook.cursor.click.echo") + @patch("ggshield.verticals.ai.agents.cursor.click.echo") def test_cursor_output_result_user_prompt_allow(self, mock_echo: MagicMock): """Cursor USER_PROMPT with block=False: JSON to stdout, return 0.""" result = HookResult.allow(_dummy_payload(EventType.USER_PROMPT)) @@ -524,7 +670,7 @@ def test_cursor_output_result_user_prompt_allow(self, mock_echo: MagicMock): assert out["continue"] is True assert out["user_message"] == "" - @patch("ggshield.verticals.secret.ai_hook.cursor.click.echo") + @patch("ggshield.verticals.ai.agents.cursor.click.echo") def test_cursor_output_result_user_prompt_block(self, mock_echo: MagicMock): """Cursor USER_PROMPT with block=True: JSON to stdout, return 0.""" result = HookResult( @@ -542,7 +688,7 @@ def test_cursor_output_result_user_prompt_block(self, mock_echo: MagicMock): assert out["continue"] is False assert out["user_message"] == "Remove secrets from prompt" - @patch("ggshield.verticals.secret.ai_hook.cursor.click.echo") + 
@patch("ggshield.verticals.ai.agents.cursor.click.echo") def test_cursor_output_result_pre_tool_use_allow(self, mock_echo: MagicMock): """Cursor PRE_TOOL_USE with block=False: permission allow, return 0.""" result = HookResult.allow(_dummy_payload(EventType.PRE_TOOL_USE)) @@ -554,7 +700,7 @@ def test_cursor_output_result_pre_tool_use_allow(self, mock_echo: MagicMock): out = json.loads(args[0]) assert out["permission"] == "allow" - @patch("ggshield.verticals.secret.ai_hook.cursor.click.echo") + @patch("ggshield.verticals.ai.agents.cursor.click.echo") def test_cursor_output_result_pre_tool_use_block(self, mock_echo: MagicMock): """Cursor PRE_TOOL_USE with block=True: permission deny, return 0.""" result = HookResult( @@ -572,7 +718,7 @@ def test_cursor_output_result_pre_tool_use_block(self, mock_echo: MagicMock): assert out["permission"] == "deny" assert out["user_message"] == "Secrets detected in command" - @patch("ggshield.verticals.secret.ai_hook.cursor.click.echo") + @patch("ggshield.verticals.ai.agents.cursor.click.echo") def test_cursor_output_result_post_tool_use(self, mock_echo: MagicMock): """Cursor POST_TOOL_USE: empty JSON to stdout, return 0.""" result = HookResult( @@ -588,7 +734,7 @@ def test_cursor_output_result_post_tool_use(self, mock_echo: MagicMock): assert kwargs.get("err", False) is False # stdout (default) assert json.loads(args[0]) == {} - @patch("ggshield.verticals.secret.ai_hook.cursor.click.echo") + @patch("ggshield.verticals.ai.agents.cursor.click.echo") def test_cursor_output_result_other_block(self, mock_echo: MagicMock): """Cursor OTHER event with block: empty JSON, return 2.""" result = HookResult( @@ -601,7 +747,7 @@ def test_cursor_output_result_other_block(self, mock_echo: MagicMock): assert code == 2 mock_echo.assert_called_once_with("{}") - @patch("ggshield.verticals.secret.ai_hook.cursor.click.echo") + @patch("ggshield.verticals.ai.agents.cursor.click.echo") def test_cursor_output_result_other_allow(self, mock_echo: MagicMock): 
"""Cursor OTHER event without block: empty JSON, return 0.""" result = HookResult.allow(_dummy_payload(EventType.OTHER)) @@ -609,7 +755,7 @@ def test_cursor_output_result_other_allow(self, mock_echo: MagicMock): assert code == 0 mock_echo.assert_called_once_with("{}") - @patch("ggshield.verticals.secret.ai_hook.claude_code.click.echo") + @patch("ggshield.verticals.ai.agents.claude_code.click.echo") def test_claude_output_result_allow(self, mock_echo: MagicMock): """Claude with block=False: JSON continue true to stdout, return 0.""" result = HookResult.allow(_dummy_payload(EventType.USER_PROMPT)) @@ -621,7 +767,7 @@ def test_claude_output_result_allow(self, mock_echo: MagicMock): out = json.loads(args[0]) assert out["continue"] is True - @patch("ggshield.verticals.secret.ai_hook.claude_code.click.echo") + @patch("ggshield.verticals.ai.agents.claude_code.click.echo") def test_claude_output_result_block(self, mock_echo: MagicMock): """Claude with block=True: JSON continue false and stopReason to stdout, return 0.""" result = HookResult( @@ -641,7 +787,7 @@ def test_claude_output_result_block(self, mock_echo: MagicMock): out["hookSpecificOutput"]["permissionDecisionReason"] == "Secrets in file" ) - @patch("ggshield.verticals.secret.ai_hook.claude_code.click.echo") + @patch("ggshield.verticals.ai.agents.claude_code.click.echo") def test_copilot_output_result_allow(self, mock_echo: MagicMock): """Copilot with block=False: same as Claude, JSON to stdout, return 0.""" result = HookResult.allow(_dummy_payload(EventType.USER_PROMPT)) @@ -654,7 +800,7 @@ def test_copilot_output_result_allow(self, mock_echo: MagicMock): assert out["continue"] is True assert "stopReason" not in out - @patch("ggshield.verticals.secret.ai_hook.claude_code.click.echo") + @patch("ggshield.verticals.ai.agents.claude_code.click.echo") def test_copilot_output_result_block(self, mock_echo: MagicMock): """Copilot with block=True: same as Claude, JSON to stdout, return 0.""" result = HookResult( @@ 
-672,7 +818,7 @@ def test_copilot_output_result_block(self, mock_echo: MagicMock): assert out["decision"] == "block" assert out["reason"] == "Secret in tool output" - @patch("ggshield.verticals.secret.ai_hook.claude_code.click.echo") + @patch("ggshield.verticals.ai.agents.claude_code.click.echo") def test_copilot_other_result_block(self, mock_echo: MagicMock): """Copilot with block=True, other type of event""" result = HookResult( @@ -688,201 +834,6 @@ def test_copilot_other_result_block(self, mock_echo: MagicMock): assert not out["continue"] -class TestBaseFlavor: - """Unit tests for the base Flavor class.""" - - @patch("ggshield.verticals.secret.ai_hook.models.click.echo") - def test_base_flavor_output_result_allow(self, mock_echo: MagicMock): - """Base Flavor with block=False: prints allow message, returns 0.""" - result = HookResult.allow(_dummy_payload(EventType.USER_PROMPT)) - code = Flavor().output_result(result) - assert code == 0 - mock_echo.assert_called_once_with("No secrets found. 
Good to go.") - - @patch("ggshield.verticals.secret.ai_hook.models.click.echo") - def test_base_flavor_output_result_block(self, mock_echo: MagicMock): - """Base Flavor with block=True: prints message to stderr, returns 2.""" - result = HookResult( - block=True, - message="Secrets found", - nbr_secrets=1, - payload=_dummy_payload(EventType.PRE_TOOL_USE), - ) - code = Flavor().output_result(result) - assert code == 2 - mock_echo.assert_called_once_with("Secrets found", err=True) - - def test_base_flavor_settings_path(self): - """Base Flavor settings_path returns default path.""" - assert Flavor().settings_path == Path(".agents") / "hooks.json" - - def test_base_flavor_settings_template(self): - """Base Flavor settings_template returns empty dict.""" - assert Flavor().settings_template == {} - - def test_base_flavor_settings_locate(self): - """Base Flavor settings_locate always returns None.""" - assert Flavor().settings_locate([{"a": 1}], {"a": 1}) is None - - -class TestAIHookScannerScan: - """Unit tests for the AIHookScanner.scan() method.""" - - def test_scan_no_secrets_returns_zero(self): - """scan() with no secrets returns 0.""" - scanner = AIHookScanner(_mock_scanner([])) - data = { - "hook_event_name": "UserPromptSubmit", - "prompt": "hello world", - "transcript_path": "/home/user/.claude/projects/foo/session.jsonl", - } - code = scanner.scan(json.dumps(data)) - assert code == 0 - - @patch( - "ggshield.verticals.secret.ai_hook.scanner.AIHookScanner._send_secret_notification" - ) - def test_scan_post_tool_use_with_secrets_sends_notification( - self, mock_notify: MagicMock - ): - """scan() on POST_TOOL_USE with secrets sends a notification and returns 0 (no block).""" - scanner = AIHookScanner(_mock_scanner(["sk-xxx"])) - data = { - "hook_event_name": "PostToolUse", - "tool_name": "Bash", - "tool_input": {"command": "echo sk-xxx"}, - "tool_response": {"stdout": "sk-xxx\n"}, - "transcript_path": "/home/user/.claude/projects/foo/session.jsonl", - "session_id": 
"427ae0c5-0862-4e14-aa2c-12fad909c323", - } - code = scanner.scan(json.dumps(data)) - assert code == 0 - mock_notify.assert_called_once() - args = mock_notify.call_args[0] - assert args[0] == 1 # nbr_secrets - assert args[1] == Tool.BASH # tool - - def test_scan_pre_tool_use_with_secrets_blocks(self): - """scan() on PRE_TOOL_USE with secrets returns block result.""" - scanner = AIHookScanner(_mock_scanner(["sk-xxx"])) - data = { - "hook_event_name": "PreToolUse", - "tool_name": "Bash", - "tool_input": {"command": "echo sk-xxx"}, - "session_id": "427ae0c5-0862-4e14-aa2c-12fad909c323", - "transcript_path": "/home/user/.claude/projects/foo/session.jsonl", - } - code = scanner.scan(json.dumps(data)) - # Claude output_result always returns 0 - assert code == 0 - - def test_scan_no_content_returns_allow(self): - """scan() with no content returns 0 (and doesn't call the API).""" - mock_scanner = _mock_scanner([]) - scanner = AIHookScanner(mock_scanner) - data = { - "hook_event_name": "PreToolUse", - "tool_name": "Read", - "tool_input": {"file_path": "doesn-t-exist"}, - } - code = scanner.scan(json.dumps(data)) - assert code == 0 - mock_scanner.scan.assert_not_called() - - def test_scan_payloads_refuse_empty_list(self): - """scan() with empty list of payloads raises ValueError.""" - scanner = AIHookScanner(_mock_scanner([])) - with pytest.raises(ValueError): - scanner._scan_payloads([]) - - -class TestMessageFromSecrets: - """Unit tests for AIHookScanner._message_from_secrets with different payload types.""" - - def test_message_for_bash_tool(self): - """Message for BASH tool mentions environment variables.""" - payload = Payload( - event_type=EventType.PRE_TOOL_USE, - tool=Tool.BASH, - content="echo sk-xxx", - identifier="echo sk-xxx", - flavor=Flavor(), - ) - message = AIHookScanner._message_from_secrets([_make_secret("sk-xxx")], payload) - assert "remove the secrets from the command" in message - assert "environment variables" in message - - def 
test_message_for_read_tool(self): - """Message for READ tool mentions file content.""" - payload = Payload( - event_type=EventType.PRE_TOOL_USE, - tool=Tool.READ, - content="file content with secret", - identifier="/path/to/file", - flavor=Flavor(), - ) - message = AIHookScanner._message_from_secrets([_make_secret("sk-xxx")], payload) - assert "remove the secrets from" in message - - def test_message_for_other_tool(self): - """Message for OTHER tool uses generic message.""" - payload = Payload( - event_type=EventType.PRE_TOOL_USE, - tool=Tool.OTHER, - content="some content", - identifier="id", - flavor=Flavor(), - ) - message = AIHookScanner._message_from_secrets([_make_secret("sk-xxx")], payload) - assert "remove the secrets from the tool input" in message - - def test_message_escapes_markdown(self): - """When escape_markdown=True, asterisks in matches are replaced with dots.""" - payload = Payload( - event_type=EventType.USER_PROMPT, - tool=None, - content="content", - identifier="id", - flavor=Flavor(), - ) - message = AIHookScanner._message_from_secrets( - [_make_secret("sk-xxx")], payload, escape_markdown=True - ) - # The message itself should not contain raw asterisks from matches - # (the header uses ** for bold which is intentional) - assert "Detected" in message - - -class TestSendSecretNotification: - """Unit tests for AIHookScanner._send_secret_notification.""" - - @patch("ggshield.verticals.secret.ai_hook.scanner.Notify") - def test_notification_for_bash_tool(self, mock_notify_cls: MagicMock): - """Notification for BASH tool says 'running a command'.""" - AIHookScanner._send_secret_notification(1, Tool.BASH, "Claude Code") - instance = mock_notify_cls.return_value - assert "running a command" in instance.message - assert "Claude Code" in instance.message - instance.send.assert_called_once() - - @patch("ggshield.verticals.secret.ai_hook.scanner.Notify") - def test_notification_for_read_tool(self, mock_notify_cls: MagicMock): - """Notification for READ 
tool says 'reading a file'.""" - AIHookScanner._send_secret_notification(2, Tool.READ, "Cursor") - instance = mock_notify_cls.return_value - assert "reading a file" in instance.message - assert "2" in instance.message - instance.send.assert_called_once() - - @patch("ggshield.verticals.secret.ai_hook.scanner.Notify") - def test_notification_for_other_tool(self, mock_notify_cls: MagicMock): - """Notification for OTHER tool says 'using a tool'.""" - AIHookScanner._send_secret_notification(1, Tool.OTHER, "Copilot") - instance = mock_notify_cls.return_value - assert "using a tool" in instance.message - instance.send.assert_called_once() - - @pytest.mark.parametrize( "prompt, filepaths", [ diff --git a/tests/unit/verticals/secret/ai_hook/test_installation.py b/tests/unit/verticals/ai/test_installation.py similarity index 90% rename from tests/unit/verticals/secret/ai_hook/test_installation.py rename to tests/unit/verticals/ai/test_installation.py index 5a98ef2f7b..20e2dea79d 100644 --- a/tests/unit/verticals/secret/ai_hook/test_installation.py +++ b/tests/unit/verticals/ai/test_installation.py @@ -6,10 +6,8 @@ import pytest from ggshield.core.errors import UnexpectedError -from ggshield.verticals.secret.ai_hook.claude_code import Claude -from ggshield.verticals.secret.ai_hook.copilot import Copilot -from ggshield.verticals.secret.ai_hook.cursor import Cursor -from ggshield.verticals.secret.ai_hook.installation import ( +from ggshield.verticals.ai.agents import Claude, Copilot, Cursor +from ggshield.verticals.ai.installation import ( InstallationStats, _fill_dict, install_hooks, @@ -287,16 +285,14 @@ def test_copilot_settings_path(self): class TestInstallHooks: """Unit tests for the install_hooks function.""" - @patch("ggshield.verticals.secret.ai_hook.installation.get_user_home_dir") + @patch("ggshield.verticals.ai.installation.get_user_home_dir") def test_install_cursor_local_fresh(self, mock_home: Any, tmp_path: Path): """Install Cursor hooks locally into a fresh 
directory (no existing config).""" mock_home.return_value = tmp_path settings_path = tmp_path / ".cursor" / "hooks.json" assert not settings_path.exists() - with patch( - "ggshield.verticals.secret.ai_hook.installation.Path" - ) as mock_path_cls: + with patch("ggshield.verticals.ai.installation.Path") as mock_path_cls: # Make Path(".") return tmp_path so local mode writes there mock_path_cls.side_effect = lambda *a: Path(*a) if a != (".",) else tmp_path code = install_hooks("cursor", mode="local") @@ -308,7 +304,7 @@ def test_install_cursor_local_fresh(self, mock_home: Any, tmp_path: Path): for key in ("beforeSubmitPrompt", "preToolUse", "postToolUse"): assert any("ggshield" in h["command"] for h in config["hooks"][key]) - @patch("ggshield.verticals.secret.ai_hook.installation.get_user_home_dir") + @patch("ggshield.verticals.ai.installation.get_user_home_dir") def test_install_claude_global(self, mock_home: Any, tmp_path: Path): """Install Claude Code hooks globally.""" mock_home.return_value = tmp_path @@ -322,7 +318,7 @@ def test_install_claude_global(self, mock_home: Any, tmp_path: Path): for key in ("PreToolUse", "PostToolUse", "UserPromptSubmit"): assert key in config["hooks"] - @patch("ggshield.verticals.secret.ai_hook.installation.get_user_home_dir") + @patch("ggshield.verticals.ai.installation.get_user_home_dir") def test_install_copilot_global(self, mock_home: Any, tmp_path: Path): """Install Copilot hooks globally.""" mock_home.return_value = tmp_path @@ -332,12 +328,12 @@ def test_install_copilot_global(self, mock_home: Any, tmp_path: Path): settings_path = tmp_path / ".github" / "hooks" / "hooks.json" assert settings_path.exists() - def test_install_unsupported_tool_raises(self): - """install_hooks raises ValueError for unsupported tool name.""" - with pytest.raises(ValueError, match="Unsupported tool name"): - install_hooks("unknown-tool", mode="local") + def test_install_unsupported_agent_raises(self): + """install_hooks raises ValueError for 
unsupported agent.""" + with pytest.raises(ValueError, match="Unsupported agent"): + install_hooks("unknown-agent", mode="local") - @patch("ggshield.verticals.secret.ai_hook.installation.get_user_home_dir") + @patch("ggshield.verticals.ai.installation.get_user_home_dir") def test_install_with_existing_config(self, mock_home: Any, tmp_path: Path): """Install hooks when a config file already exists (merges).""" mock_home.return_value = tmp_path @@ -345,9 +341,7 @@ def test_install_with_existing_config(self, mock_home: Any, tmp_path: Path): settings_path.parent.mkdir(parents=True) settings_path.write_text(json.dumps({"version": 1, "other_key": "keep_me"})) - with patch( - "ggshield.verticals.secret.ai_hook.installation.Path" - ) as mock_path_cls: + with patch("ggshield.verticals.ai.installation.Path") as mock_path_cls: mock_path_cls.side_effect = lambda *a: Path(*a) if a != (".",) else tmp_path code = install_hooks("cursor", mode="local") @@ -356,7 +350,7 @@ def test_install_with_existing_config(self, mock_home: Any, tmp_path: Path): assert config["other_key"] == "keep_me" assert "hooks" in config - @patch("ggshield.verticals.secret.ai_hook.installation.get_user_home_dir") + @patch("ggshield.verticals.ai.installation.get_user_home_dir") def test_install_with_corrupt_json_raises(self, mock_home: Any, tmp_path: Path): """install_hooks raises UnexpectedError when existing config is invalid JSON.""" mock_home.return_value = tmp_path @@ -364,21 +358,17 @@ def test_install_with_corrupt_json_raises(self, mock_home: Any, tmp_path: Path): settings_path.parent.mkdir(parents=True) settings_path.write_text("{ invalid json") - with patch( - "ggshield.verticals.secret.ai_hook.installation.Path" - ) as mock_path_cls: + with patch("ggshield.verticals.ai.installation.Path") as mock_path_cls: mock_path_cls.side_effect = lambda *a: Path(*a) if a != (".",) else tmp_path with pytest.raises(UnexpectedError, match="Failed to parse"): install_hooks("cursor", mode="local") - 
@patch("ggshield.verticals.secret.ai_hook.installation.get_user_home_dir") + @patch("ggshield.verticals.ai.installation.get_user_home_dir") def test_install_already_present(self, mock_home: Any, tmp_path: Path): """install_hooks when hooks are already installed reports 'already installed'.""" mock_home.return_value = tmp_path - with patch( - "ggshield.verticals.secret.ai_hook.installation.Path" - ) as mock_path_cls: + with patch("ggshield.verticals.ai.installation.Path") as mock_path_cls: mock_path_cls.side_effect = lambda *a: Path(*a) if a != (".",) else tmp_path # Install once install_hooks("cursor", mode="local") @@ -387,14 +377,12 @@ def test_install_already_present(self, mock_home: Any, tmp_path: Path): assert code == 0 - @patch("ggshield.verticals.secret.ai_hook.installation.get_user_home_dir") + @patch("ggshield.verticals.ai.installation.get_user_home_dir") def test_install_force_updates(self, mock_home: Any, tmp_path: Path): """install_hooks with force=True updates existing hooks.""" mock_home.return_value = tmp_path - with patch( - "ggshield.verticals.secret.ai_hook.installation.Path" - ) as mock_path_cls: + with patch("ggshield.verticals.ai.installation.Path") as mock_path_cls: mock_path_cls.side_effect = lambda *a: Path(*a) if a != (".",) else tmp_path install_hooks("cursor", mode="local") code = install_hooks("cursor", mode="local", force=True) diff --git a/tests/unit/verticals/ai/test_mcp_activity.py b/tests/unit/verticals/ai/test_mcp_activity.py new file mode 100644 index 0000000000..2e4d6d5afa --- /dev/null +++ b/tests/unit/verticals/ai/test_mcp_activity.py @@ -0,0 +1,122 @@ +from unittest.mock import MagicMock, patch + +import pytest +from pydantic import ValidationError +from requests.exceptions import ConnectionError as RequestsConnectionError + +from ggshield.verticals.ai.mcp import send_mcp_activity +from ggshield.verticals.ai.models import ( + AIDiscovery, + EventType, + HookPayload, + MCPActivityRequest, + MCPActivityResponse, + Tool, + 
UserInfo, +) + + +def _user() -> UserInfo: + return UserInfo( + hostname="host", username="user", machine_id="mid", user_email="u@e.com" + ) + + +def _ai_discovery() -> AIDiscovery: + return AIDiscovery(user=_user(), servers=[], discovery_duration=0.1) + + +def _mcp_activity_request() -> MCPActivityRequest: + from pathlib import Path + + return MCPActivityRequest( + user=_user(), + tool="my_tool", + server="my_server", + agent="cursor", + model="gpt-4", + cwd=Path("/tmp"), + input={"key": "value"}, + ) + + +def _payload() -> HookPayload: + agent = MagicMock() + return HookPayload( + event_type=EventType.PRE_TOOL_USE, + tool=Tool.MCP, + content="content", + identifier="id", + agent=agent, + raw={}, + ) + + +class TestSendMCPActivity: + @patch("ggshield.verticals.ai.mcp.refresh_and_maybe_submit_discovery") + def test_successful_response(self, mock_refresh): + mock_refresh.return_value = _ai_discovery() + payload = _payload() + payload.agent.parse_mcp_activity.return_value = _mcp_activity_request() + + response_data = MCPActivityResponse(allowed=False, reason="blocked by policy") + mock_response = MagicMock() + mock_response.ok = True + mock_response.text = response_data.model_dump_json() + + client = MagicMock() + client.post.return_value = mock_response + + result = send_mcp_activity(client, payload) + + assert result.allowed is False + assert result.reason == "blocked by policy" + + @pytest.mark.parametrize( + "setup_client", + [ + pytest.param("connection_error", id="connection_error"), + pytest.param("non_ok_response", id="non_ok_response"), + pytest.param("validation_error", id="validation_error"), + ], + ) + @patch("ggshield.verticals.ai.mcp.refresh_and_maybe_submit_discovery") + def test_fail_open_returns_allowed(self, mock_refresh, setup_client): + mock_refresh.return_value = _ai_discovery() + payload = _payload() + payload.agent.parse_mcp_activity.return_value = _mcp_activity_request() + + client = MagicMock() + if setup_client == "connection_error": + 
client.post.side_effect = RequestsConnectionError("network down") + elif setup_client == "non_ok_response": + mock_response = MagicMock() + mock_response.ok = False + client.post.return_value = mock_response + elif setup_client == "validation_error": + mock_response = MagicMock() + mock_response.ok = True + mock_response.text = '{"unexpected_field": true}' + client.post.return_value = mock_response + + result = send_mcp_activity(client, payload) + + assert result.allowed is True + + @patch("ggshield.verticals.ai.mcp.refresh_and_maybe_submit_discovery") + def test_refresh_called_before_submission(self, mock_refresh): + mock_refresh.return_value = _ai_discovery() + payload = _payload() + payload.agent.parse_mcp_activity.return_value = _mcp_activity_request() + + response_data = MCPActivityResponse(allowed=True, reason="") + mock_response = MagicMock() + mock_response.ok = True + mock_response.text = response_data.model_dump_json() + + client = MagicMock() + client.post.return_value = mock_response + + send_mcp_activity(client, payload) + + mock_refresh.assert_called_once_with(client) diff --git a/tests/unit/verticals/ai/test_models.py b/tests/unit/verticals/ai/test_models.py new file mode 100644 index 0000000000..82f73a4a88 --- /dev/null +++ b/tests/unit/verticals/ai/test_models.py @@ -0,0 +1,323 @@ +from pathlib import Path +from unittest.mock import patch + +import pytest + +from ggshield.core.scan import File, StringScannable +from ggshield.verticals.ai.agents import Cursor +from ggshield.verticals.ai.models import ( + AIDiscovery, + EventType, + HookPayload, + MCPConfiguration, + MCPPromptInfo, + MCPResourceInfo, + MCPServer, + MCPToolInfo, + Scope, + Tool, + Transport, + UserInfo, +) + + +def _user(**kwargs) -> UserInfo: + defaults = dict( + hostname="host", username="user", machine_id="mid", user_email="u@e.com" + ) + return UserInfo(**(defaults | kwargs)) + + +def _server( + name="srv", + tools=None, + resources=None, + prompts=None, + configurations=None, +) 
-> MCPServer: + return MCPServer( + name=name, + tools=tools or [], + resources=resources or [], + prompts=prompts or [], + configurations=configurations or [], + ) + + +def _cfg(name="srv", agent="cursor", scope=Scope.USER, project=None) -> MCPConfiguration: + return MCPConfiguration( + name=name, + agent=agent, + scope=scope, + transport=Transport.STDIO, + project=project, + ) + + +# --------------------------------------------------------------------------- +# MCPServer.has_capabilities_unknown_to +# --------------------------------------------------------------------------- + + +class TestMCPServerHasCapabilitiesUnknownTo: + @pytest.mark.parametrize( + "self_kwargs, other_kwargs, expected", + [ + pytest.param( + {"tools": [MCPToolInfo(name="t1")]}, + {"tools": [MCPToolInfo(name="t1")]}, + False, + id="same_tools", + ), + pytest.param( + {"tools": [MCPToolInfo(name="t1"), MCPToolInfo(name="t2")]}, + {"tools": [MCPToolInfo(name="t1")]}, + True, + id="extra_tool", + ), + pytest.param( + {"resources": [MCPResourceInfo(uri="r1"), MCPResourceInfo(uri="r2")]}, + {"resources": [MCPResourceInfo(uri="r1")]}, + True, + id="extra_resource", + ), + pytest.param( + {"prompts": [MCPPromptInfo(name="p1"), MCPPromptInfo(name="p2")]}, + {"prompts": [MCPPromptInfo(name="p1")]}, + True, + id="extra_prompt", + ), + pytest.param( + {"tools": [MCPToolInfo(name="t1")]}, + {"tools": [MCPToolInfo(name="t1"), MCPToolInfo(name="t2")]}, + False, + id="subset_of_other", + ), + pytest.param({}, {}, False, id="both_empty"), + ], + ) + def test_has_capabilities_unknown_to(self, self_kwargs, other_kwargs, expected): + assert _server(**self_kwargs).has_capabilities_unknown_to( + _server(**other_kwargs) + ) is expected + + +# --------------------------------------------------------------------------- +# AIDiscovery.has_changed_from +# --------------------------------------------------------------------------- + + +class TestAIDiscoveryHasChangedFrom: + def test_identical_returns_false(self): + 
cfg = _cfg() + a = AIDiscovery( + user=_user(), + servers=[_server(configurations=[cfg])], + discovery_duration=0.1, + ) + b = AIDiscovery( + user=_user(), + servers=[_server(configurations=[cfg])], + discovery_duration=0.2, + ) + assert a.has_changed_from(b) is False + + def test_different_user_returns_true(self): + cfg = _cfg() + a = AIDiscovery( + user=_user(hostname="a"), + servers=[_server(configurations=[cfg])], + discovery_duration=0.1, + ) + b = AIDiscovery( + user=_user(hostname="b"), + servers=[_server(configurations=[cfg])], + discovery_duration=0.1, + ) + assert a.has_changed_from(b) is True + + def test_different_configuration_keys_returns_true(self): + a = AIDiscovery( + user=_user(), + servers=[_server(configurations=[_cfg(name="x")])], + discovery_duration=0.1, + ) + b = AIDiscovery( + user=_user(), + servers=[_server(configurations=[_cfg(name="y")])], + discovery_duration=0.1, + ) + assert a.has_changed_from(b) is True + + def test_new_capabilities_unknown_to_all_candidates_returns_true(self): + cfg = _cfg() + a = AIDiscovery( + user=_user(), + servers=[ + _server( + configurations=[cfg], + tools=[MCPToolInfo(name="new_tool")], + ) + ], + discovery_duration=0.1, + ) + b = AIDiscovery( + user=_user(), + servers=[_server(configurations=[cfg])], + discovery_duration=0.1, + ) + assert a.has_changed_from(b) is True + + def test_capabilities_known_to_one_candidate_returns_false(self): + cfg = _cfg() + tool = MCPToolInfo(name="known") + a = AIDiscovery( + user=_user(), + servers=[_server(configurations=[cfg], tools=[tool])], + discovery_duration=0.1, + ) + b = AIDiscovery( + user=_user(), + servers=[_server(configurations=[cfg], tools=[tool])], + discovery_duration=0.1, + ) + assert a.has_changed_from(b) is False + + def test_empty_servers_returns_false(self): + a = AIDiscovery(user=_user(), servers=[], discovery_duration=0.1) + b = AIDiscovery(user=_user(), servers=[], discovery_duration=0.1) + assert a.has_changed_from(b) is False + + +# 
--------------------------------------------------------------------------- +# HookPayload.scannable +# --------------------------------------------------------------------------- + + +class TestHookPayloadScannable: + def test_read_tool_existing_text_file_returns_file(self, tmp_path: Path): + f = tmp_path / "code.py" + f.write_text("secret = 'abc'") + payload = HookPayload( + event_type=EventType.PRE_TOOL_USE, + tool=Tool.READ, + content="", + identifier=str(f), + agent=Cursor(), + raw={}, + ) + assert isinstance(payload.scannable, File) + + def test_read_tool_missing_file_returns_string_scannable(self): + payload = HookPayload( + event_type=EventType.PRE_TOOL_USE, + tool=Tool.READ, + content="some content", + identifier="/nonexistent/path.txt", + agent=Cursor(), + raw={}, + ) + assert isinstance(payload.scannable, StringScannable) + + def test_read_tool_binary_file_returns_string_scannable(self, tmp_path: Path): + f = tmp_path / "image.png" + f.write_bytes(b"\x89PNG\r\n\x1a\n" + b"\x00" * 100) + payload = HookPayload( + event_type=EventType.PRE_TOOL_USE, + tool=Tool.READ, + content="", + identifier=str(f), + agent=Cursor(), + raw={}, + ) + assert isinstance(payload.scannable, StringScannable) + + def test_non_read_tool_returns_string_scannable(self): + payload = HookPayload( + event_type=EventType.PRE_TOOL_USE, + tool=Tool.BASH, + content="echo hello", + identifier="cmd", + agent=Cursor(), + raw={}, + ) + assert isinstance(payload.scannable, StringScannable) + + +# --------------------------------------------------------------------------- +# HookPayload.empty +# --------------------------------------------------------------------------- + + +class TestHookPayloadEmpty: + @pytest.mark.parametrize( + "content, expected", + [ + pytest.param("non-empty", False, id="non_empty_content"), + pytest.param("", True, id="empty_content"), + ], + ) + def test_empty(self, content, expected): + payload = HookPayload( + event_type=EventType.USER_PROMPT, + tool=None, + 
content=content, + identifier="id", + agent=Cursor(), + raw={}, + ) + assert payload.empty is expected + + +# --------------------------------------------------------------------------- +# Agent._parse_servers_block +# --------------------------------------------------------------------------- + + +class TestParseServersBlock: + def _parse(self, data, scope=Scope.USER, project=None): + return list(Cursor()._parse_servers_block(data, scope, project)) + + def test_mcp_servers_key_stdio(self): + data = { + "mcpServers": { + "myserver": { + "command": "npx", + "args": ["-y", "mcp-server"], + "env": {"KEY": "val"}, + } + } + } + configs = self._parse(data) + assert len(configs) == 1 + cfg = configs[0] + assert cfg.name == "myserver" + assert cfg.transport == Transport.STDIO + assert cfg.command == "npx" + assert cfg.args == ["-y", "mcp-server"] + assert cfg.env == {"KEY": "val"} + + def test_servers_key(self): + data = {"servers": {"s1": {"command": "node"}}} + configs = self._parse(data) + assert len(configs) == 1 + assert configs[0].name == "s1" + + def test_url_entry_detected_as_http(self): + data = {"mcpServers": {"remote": {"url": "https://example.com/mcp"}}} + configs = self._parse(data) + assert configs[0].transport == Transport.HTTP + assert configs[0].url == "https://example.com/mcp" + + def test_url_entry_with_sse_transport(self): + data = { + "mcpServers": { + "remote": {"url": "https://example.com/sse", "transport": "sse"} + } + } + configs = self._parse(data) + assert configs[0].transport == Transport.SSE + + def test_empty_block_yields_nothing(self): + assert self._parse({}) == [] + assert self._parse({"mcpServers": {}}) == [] diff --git a/tests/unit/verticals/ai/test_user.py b/tests/unit/verticals/ai/test_user.py new file mode 100644 index 0000000000..8b828cf745 --- /dev/null +++ b/tests/unit/verticals/ai/test_user.py @@ -0,0 +1,222 @@ +import subprocess +import uuid +from pathlib import Path +from unittest.mock import MagicMock, patch + +import pytest + 
+from ggshield.verticals.ai.user import ( + _get_hostname, + _get_linux_system_id, + _get_machine_id, + _get_macos_system_id, + _get_user_email, + _get_username, + _parse_wmic_uuid, + get_user_info, +) + + +# --------------------------------------------------------------------------- +# get_user_info +# --------------------------------------------------------------------------- + + +class TestGetUserInfo: + @patch("ggshield.verticals.ai.user._get_hostname", return_value="myhost") + @patch("ggshield.verticals.ai.user._get_username", return_value="myuser") + @patch("ggshield.verticals.ai.user._get_machine_id", return_value="abc-123") + @patch("ggshield.verticals.ai.user._get_user_email", return_value="me@test.com") + def test_populates_all_fields(self, *_mocks): + info = get_user_info() + assert info.hostname == "myhost" + assert info.username == "myuser" + assert info.machine_id == "abc-123" + assert info.user_email == "me@test.com" + + @patch("ggshield.verticals.ai.user._get_hostname", return_value="h") + @patch("ggshield.verticals.ai.user._get_username", return_value="u") + @patch("ggshield.verticals.ai.user._get_machine_id", return_value="generated") + @patch("ggshield.verticals.ai.user._get_user_email", return_value=None) + def test_reuses_provided_machine_id(self, *_mocks): + info = get_user_info(machine_id="provided-id") + assert info.machine_id == "provided-id" + + +# --------------------------------------------------------------------------- +# _get_hostname +# --------------------------------------------------------------------------- + + +class TestGetHostname: + @patch("ggshield.verticals.ai.user.sys") + @patch("ggshield.verticals.ai.user.socket.gethostname", return_value="linuxbox") + def test_linux_returns_gethostname(self, _mock_host, mock_sys): + mock_sys.platform = "linux" + assert _get_hostname() == "linuxbox" + + @patch("ggshield.verticals.ai.user.sys") + @patch("ggshield.verticals.ai.user.os.environ", {"COMPUTERNAME": "WINBOX"}) + def 
test_windows_prefers_computername(self, mock_sys): + mock_sys.platform = "win32" + assert _get_hostname() == "WINBOX" + + @patch("ggshield.verticals.ai.user.sys") + @patch("ggshield.verticals.ai.user.socket.gethostname", side_effect=OSError) + def test_oserror_returns_unknown(self, _mock_host, mock_sys): + mock_sys.platform = "linux" + assert _get_hostname() == "unknown" + + +# --------------------------------------------------------------------------- +# _get_username +# --------------------------------------------------------------------------- + + +class TestGetUsername: + @patch("ggshield.verticals.ai.user.getpass.getuser", return_value="alice") + def test_returns_getuser(self, _mock): + assert _get_username() == "alice" + + @patch("ggshield.verticals.ai.user.os.getlogin", return_value="bob") + @patch("ggshield.verticals.ai.user.getpass.getuser", side_effect=Exception) + def test_falls_back_to_getlogin(self, *_mocks): + assert _get_username() == "bob" + + @patch("ggshield.verticals.ai.user.os.getlogin", side_effect=Exception) + @patch("ggshield.verticals.ai.user.getpass.getuser", side_effect=Exception) + def test_returns_unknown_when_both_fail(self, *_mocks): + assert _get_username() == "unknown" + + +# --------------------------------------------------------------------------- +# _get_user_email +# --------------------------------------------------------------------------- + + +class TestGetUserEmail: + @pytest.mark.parametrize( + "run_return, expected", + [ + pytest.param( + MagicMock(returncode=0, stdout="me@example.com\n"), + "me@example.com", + id="valid_email", + ), + pytest.param( + MagicMock(returncode=1, stdout=""), + None, + id="git_failure", + ), + pytest.param( + MagicMock(returncode=0, stdout=" \n"), + None, + id="empty_output", + ), + ], + ) + @patch("ggshield.verticals.ai.user.subprocess.run") + def test_get_user_email(self, mock_run, run_return, expected): + mock_run.return_value = run_return + assert _get_user_email() == expected + + @patch( + 
"ggshield.verticals.ai.user.subprocess.run", + side_effect=OSError("git not found"), + ) + def test_returns_none_on_oserror(self, _mock): + assert _get_user_email() is None + + +# --------------------------------------------------------------------------- +# _get_machine_id +# --------------------------------------------------------------------------- + + +class TestGetMachineId: + def test_returns_satori_cached_id(self, tmp_path): + satori_dir = tmp_path / ".satori" + satori_dir.mkdir() + (satori_dir / "machine_id").write_text("cached-uuid\n") + with patch( + "ggshield.verticals.ai.user.get_user_home_dir", return_value=tmp_path + ): + assert _get_machine_id() == "cached-uuid" + + @patch("ggshield.verticals.ai.user.platform.system", return_value="Linux") + @patch( + "ggshield.verticals.ai.user._get_linux_system_id", + return_value="linux-machine-id", + ) + def test_linux_reads_system_id(self, _mock_linux, _mock_platform, tmp_path): + with patch( + "ggshield.verticals.ai.user.get_user_home_dir", return_value=tmp_path + ): + assert _get_machine_id() == "linux-machine-id" + + @patch("ggshield.verticals.ai.user.platform.system", return_value="Darwin") + @patch( + "ggshield.verticals.ai.user._get_macos_system_id", + return_value="mac-uuid-123", + ) + def test_macos_parses_ioreg(self, _mock_mac, _mock_platform, tmp_path): + with patch( + "ggshield.verticals.ai.user.get_user_home_dir", return_value=tmp_path + ): + assert _get_machine_id() == "mac-uuid-123" + + @patch("ggshield.verticals.ai.user.platform.system", return_value="Linux") + @patch("ggshield.verticals.ai.user._get_linux_system_id", return_value=None) + @patch("ggshield.verticals.ai.user.uuid.uuid4") + def test_generates_uuid_when_all_fail( + self, mock_uuid4, _mock_linux, _mock_platform, tmp_path + ): + fixed_uuid = uuid.UUID("12345678-1234-5678-1234-567812345678") + mock_uuid4.return_value = fixed_uuid + with patch( + "ggshield.verticals.ai.user.get_user_home_dir", return_value=tmp_path + ): + result = 
_get_machine_id() + assert result == str(fixed_uuid) + assert (tmp_path / ".satori" / "machine_id").read_text().strip() == str( + fixed_uuid + ) + + @patch("ggshield.verticals.ai.user.platform.system", return_value="Linux") + @patch("ggshield.verticals.ai.user._get_linux_system_id", return_value=None) + @patch("ggshield.verticals.ai.user.uuid.uuid4") + def test_persistence_failure_still_returns_uuid( + self, mock_uuid4, _mock_linux, _mock_platform, tmp_path + ): + fixed_uuid = uuid.UUID("aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee") + mock_uuid4.return_value = fixed_uuid + # Make .satori a file so mkdir fails + (tmp_path / ".satori").write_text("block") + with patch( + "ggshield.verticals.ai.user.get_user_home_dir", return_value=tmp_path + ): + result = _get_machine_id() + assert result == str(fixed_uuid) + + +# --------------------------------------------------------------------------- +# _parse_wmic_uuid +# --------------------------------------------------------------------------- + + +class TestParseWmicUuid: + @pytest.mark.parametrize( + "stdout, expected", + [ + pytest.param( + "UUID\n4C4C4544-0044-4810-8057-B5C04F4A5331\n", + "4c4c4544-0044-4810-8057-b5c04f4a5331", + id="valid_uuid", + ), + pytest.param("UUID\n", None, id="header_only"), + pytest.param("UUID\nnot-a-uuid\n", None, id="invalid_line"), + pytest.param("", None, id="empty_string"), + ], + ) + def test_parse_wmic_uuid(self, stdout, expected): + assert _parse_wmic_uuid(stdout) == expected diff --git a/uv.lock b/uv.lock index 4e3307ee66..1a0607d1f6 100644 --- a/uv.lock +++ b/uv.lock @@ -811,6 +811,7 @@ dependencies = [ { name = "oauthlib" }, { name = "packaging" }, { name = "platformdirs" }, + { name = "pydantic" }, { name = "pygitguardian" }, { name = "pyjwt" }, { name = "python-dotenv" }, @@ -877,6 +878,7 @@ requires-dist = [ { name = "oauthlib", specifier = "~=3.2.1" }, { name = "packaging", specifier = ">=22.0" }, { name = "platformdirs", specifier = "~=3.0.0" }, + { name = "pydantic", specifier = 
">=2.11.10" }, { name = "pygitguardian", specifier = "~=1.28.0" }, { name = "pyjwt", specifier = "~=2.6.0" }, { name = "python-dotenv", specifier = "~=0.21.0" },