diff --git a/.gitignore b/.gitignore index 94adf2890..93a323c54 100644 --- a/.gitignore +++ b/.gitignore @@ -97,3 +97,8 @@ Thumbs.db schema.graphql .opencode/ + +# Local project config and scratch artifacts +/config.json +/IMPLEMENTATION_PLAN.md +/rendered_root_agent_system_prompt.txt diff --git a/README.md b/README.md index 8f5997c6d..14501ca35 100644 --- a/README.md +++ b/README.md @@ -1,264 +1,419 @@ -

- - Strix Banner - -

- -
- # Strix -### Open-source AI hackers to find and fix your app’s vulnerabilities. - -
- - -Docs -Website -[![](https://dcbadge.limes.pink/api/server/strix-ai)](https://discord.gg/strix-ai) - -Ask DeepWiki -GitHub Stars -License -PyPI Version +开源 AI 安全代理,用于对 Web 应用、代码仓库和本地项目进行自动化安全评估、漏洞验证和结果归档。 +## 项目定位 -Join Discord -Follow on X +Strix 不是传统的静态扫描器。它会像真实安全研究员一样运行目标、调用浏览器和终端、编写与执行 PoC,并把发现结果整理成结构化事件、报告和漏洞产物。适合以下场景: +- 应用安全测试 +- 灰盒或白盒渗透测试 +- 漏洞赏金研究 +- CI/CD 安全门禁 +- 需要流式跟踪过程的自动化评估平台 -usestrix/strix | Trendshift +## 核心能力 -
+- 多代理协作,支持任务拆分、验证和汇总 +- 同时覆盖代码仓库、本地目录、在线应用等多种目标 +- 浏览器、HTTP、终端、Python runtime 等工具链开箱即用 +- 通过 PoC 验证结果,尽量减少“只报不证”的误报 +- CLI、TUI、Web API、内置 Web Demo 共用同一套扫描执行链 +- 任务产物、事件流、最终报告统一落盘,便于二次集成 +## 快速开始 +### 前置要求 -> [!TIP] -> **New!** Strix integrates seamlessly with GitHub Actions and CI/CD pipelines. Automatically scan for vulnerabilities on every pull request and block insecure code before it reaches production! +- Python 3.12+ +- Docker 已安装并处于运行状态 +- 可用的 LLM 提供商凭据 ---- +### 安装 +生产环境或快速体验: -## Strix Overview +```bash +curl -sSL https://strix.ai/install | bash +``` -Strix are autonomous AI agents that act just like real hackers - they run your code dynamically, find vulnerabilities, and validate them through actual proof-of-concepts. Built for developers and security teams who need fast, accurate security testing without the overhead of manual pentesting or the false positives of static analysis tools. +本地开发环境: -**Key Capabilities:** +```bash +git clone https://github.com/usestrix/strix.git +cd strix/strix_api +uv pip install -e . +``` -- **Full hacker toolkit** out of the box -- **Teams of agents** that collaborate and scale -- **Real validation** with PoCs, not false positives -- **Developer‑first** CLI with actionable reports -- **Auto‑fix & reporting** to accelerate remediation +如果你已经使用 Poetry,也可以执行: +```bash +poetry install +``` -
+### 配置 +Strix 运行时配置统一从 JSON 配置文件读取,默认路径为 `~/.strix/config.json`。 -
- - Strix Demo - -
+最小可用配置: +```json +{ + "llm": { + "model": "openai/gpt-5.4", + "api_key": "your-api-key" + } +} +``` -## Use Cases +更完整的示例可以直接参考 [config.example.json](/Users/tao/Documents/docker/strix/strix/strix_api/config.example.json)。 -- **Application Security Testing** - Detect and validate critical vulnerabilities in your applications -- **Rapid Penetration Testing** - Get penetration tests done in hours, not weeks, with compliance reports -- **Bug Bounty Automation** - Automate bug bounty research and generate PoCs for faster reporting -- **CI/CD Integration** - Run tests in CI/CD to block vulnerabilities before reaching production +常用配置项: -## 🚀 Quick Start +- `llm.model`:LiteLLM 模型标识,例如 `openai/gpt-5.4` +- `llm.api_key`:LLM 提供商 API Key +- `llm.api_base`:自定义网关或本地模型地址 +- `llm.openai_compatible_provider`:显式声明 OpenAI 兼容 provider 名称 +- `llm.reasoning_effort`:推理强度 +- `features.perplexity_api_key`:联网研究能力所需的可选 Key +- `runtime.image`:沙箱镜像 +- `api.host` / `api.port` / `api.auth_token`:Web API 服务配置 -**Prerequisites:** -- Docker (running) -- An LLM API key from any [supported provider](https://docs.strix.ai/llm-providers/overview) (OpenAI, Anthropic, Google, etc.) +> Strix 不再依赖环境变量作为正常运行时配置来源。CLI、TUI 和 Web API 都优先读取配置文件。 -### Installation & First Scan +### 第一次扫描 ```bash -# Install Strix -curl -sSL https://strix.ai/install | bash - -# Configure your AI provider -export STRIX_LLM="openai/gpt-5.4" -export LLM_API_KEY="your-api-key" - -# Run your first security assessment strix --target ./app-directory ``` -> [!NOTE] -> First run automatically pulls the sandbox Docker image. Results are saved to `strix_runs/` +常用示例: ---- +```bash +# 扫描本地目录 +strix --target ./app-directory -## ☁️ Strix Platform +# 扫描在线应用 +strix --target https://example.com -Try the Strix full-stack security platform at **[app.strix.ai](https://app.strix.ai)** — sign up for free, connect your repos and domains, and launch a pentest in minutes. +# 白盒 + 黑盒联合测试 +strix --target https://github.com/org/repo --target https://staging.example.com -- **Validated findings with PoCs** and reproduction steps -- **One-click autofix** as ready-to-merge pull requests -- **Continuous monitoring** across code, cloud, and infrastructure -- **Integrations** with GitHub, Slack, Jira, Linear, and CI/CD pipelines -- **Continuous learning** that builds on past findings and remediations +# 指定指令 +strix --target https://example.com --instruction "重点看认证、IDOR 和业务逻辑" -[**Start your first pentest →**](https://app.strix.ai) +# 从文件读取详细指令 +strix --target https://example.com --instruction-file ./instruction.md ---- +# 非交互模式,适合自动化 +strix -n --target https://example.com --scan-mode quick -## ✨ Features +# 指定运行目录名称 +strix --target ./app-directory --run-name audit-20260325 +``` -### Agentic Security Tools +## Web API -Strix agents come equipped with a comprehensive security testing toolkit: +启动 API 服务: -- **Full HTTP Proxy** - Full request/response manipulation and analysis -- **Browser Automation** - Multi-tab browser for testing of XSS, CSRF, auth flows -- **Terminal Environments** - Interactive shells for command execution and testing -- **Python Runtime** - Custom exploit development and validation -- **Reconnaissance** - Automated OSINT and attack surface mapping -- **Code Analysis** - Static and dynamic analysis capabilities -- **Knowledge Management** - Structured findings and attack documentation +```bash +strix-api +``` -### Comprehensive Vulnerability Detection +也可以覆盖配置文件和监听地址: -Strix can identify and validate a wide range of security vulnerabilities: +```bash +strix-api --config ~/.strix/config.json --host 0.0.0.0 --port 8787 +``` -- **Access Control** - IDOR, privilege escalation, auth bypass -- **Injection Attacks** - SQL, NoSQL, command injection -- **Server-Side** - SSRF, XXE, deserialization flaws -- **Client-Side** - XSS, prototype pollution, DOM vulnerabilities -- **Business Logic** - Race conditions, workflow manipulation -- **Authentication** - JWT vulnerabilities, session management -- **Infrastructure** - Misconfigurations, exposed services +默认地址为 `http://127.0.0.1:8787`,任务接口统一挂在 `/api/v1` 下。 -### Graph of Agents +### 认证 -Advanced multi-agent orchestration for comprehensive security testing: +如果配置了 `api.auth_token`,所有 `/api/v1/tasks*` 接口都需要携带 Bearer Token: -- **Distributed Workflows** - Specialized agents for different attacks and assets -- **Scalable Testing** - Parallel execution for fast comprehensive coverage -- **Dynamic Coordination** - Agents collaborate and share discoveries +```text +Authorization: Bearer +``` ---- +如果未配置 `api.auth_token`,则不校验认证。 -## Usage Examples +### 文档入口 -### Basic Usage +当 `api.enable_docs != false` 时,FastAPI 会暴露: -```bash -# Scan a local codebase -strix --target ./app-directory +- `GET /docs` +- `GET /redoc` +- `GET /openapi.json` -# Security review of a GitHub repository -strix --target https://github.com/org/repo +### 创建任务请求 -# Black-box web application assessment -strix --target https://your-app.com -``` +创建任务接口使用 JSON 请求体: -### Advanced Testing Scenarios +```json +{ + "targets": ["https://example.com"], + "instruction": "重点看认证和 IDOR", + "scan_mode": "deep", + "task_id": "example-deep-scan" +} +``` -```bash -# Grey-box authenticated testing -strix --target https://your-app.com --instruction "Perform authenticated testing using credentials: user:pass" +字段说明: + +- `targets`:必填,至少 1 项,支持本地目录、仓库 URL、在线应用 URL +- `instruction`:可选,直接传入扫描指令 +- `instruction_file`:可选,服务端本地指令文件路径 +- `scan_mode`:可选,`quick`、`standard`、`deep`,默认 `deep` +- `task_id`:可选,任务唯一标识 +- `run_name`:可选,运行名称 +- `config_path`:可选,该任务单独使用的配置文件路径 + +约束: + +- `instruction` 与 `instruction_file` 不能同时传 +- 不允许传入未定义字段 +- 超过并发限制、任务重名、目标校验失败或配置文件非法时会返回 `400` + +### 任务状态 + +任务状态枚举: + +- `queued` +- `running` +- `cancelling` +- `completed` +- `failed` +- `cancelled` + +### 接口速查 + +| 方法 | 路径 | 说明 | 响应要点 | +| --- | --- | --- | --- | +| `GET` | `/health` | 健康检查 | `{"status":"ok"}` | +| `POST` | `/api/v1/tasks` | 创建扫描任务 | 返回 `{"task": ...}` | +| `GET` | `/api/v1/tasks` | 列出任务 | 返回 `{"tasks": [...]}` | +| `GET` | `/api/v1/tasks/{task_id}` | 获取任务详情 | 实际返回完整结果对象 `{task, scan_state, artifacts}` | +| `GET` | `/api/v1/tasks/{task_id}/result` | 获取结构化结果 | 返回 `{task, scan_state, artifacts}` | +| `GET` | `/api/v1/tasks/{task_id}/results` | `/result` 别名 | 与 `/result` 相同 | +| `POST` | `/api/v1/tasks/{task_id}/cancel` | 取消任务 | 返回 `{"task": ...}` | +| `GET` | `/api/v1/tasks/{task_id}/events` | 获取历史事件 | 支持 `limit=1..5000` | +| `GET` | `/api/v1/tasks/{task_id}/stream` | SSE 流式订阅事件 | `text/event-stream` | +| `GET` | `/api/v1/tasks/{task_id}/artifacts` | 获取产物列表 | 返回文件路径数组 | +| `GET` | `/api/v1/tasks/{task_id}/report` | 获取最终报告 | 返回纯文本 Markdown | + +### 响应结构 + +任务对象常见字段: + +```json +{ + "task_id": "example-deep-scan", + "run_name": "example-deep-scan", + "status": "queued", + "created_at": "2026-03-25T00:00:00+00:00", + "updated_at": "2026-03-25T00:00:00+00:00", + "started_at": null, + "finished_at": null, + "completed_at": null, + "pid": 12345, + "exit_code": null, + "error": null, + "request": {}, + "run_dir": "/abs/path/strix_runs/example-deep-scan", + "worker_log_path": "/abs/path/strix_runs/example-deep-scan/worker.log", + "scan_state_path": "/abs/path/strix_runs/example-deep-scan/scan_state.json", + "events_path": "/abs/path/strix_runs/example-deep-scan/events.jsonl" +} +``` -# Multi-target testing (source code + deployed app) -strix -t https://github.com/org/app -t https://your-app.com +结果对象结构: + +```json +{ + "task": {}, + "scan_state": { + "run_metadata": {}, + "scan_config": {}, + "scan_results": {}, + "final_scan_result": "markdown text", + "vulnerability_reports": [], + "agents": {}, + "tool_executions": {}, + "chat_messages": [] + }, + "artifacts": [ + "/abs/path/strix_runs/example-deep-scan/events.jsonl", + "/abs/path/strix_runs/example-deep-scan/scan_state.json" + ] +} +``` -# Focused testing with custom instructions -strix --target api.your-app.com --instruction "Focus on business logic flaws and IDOR vulnerabilities" +说明: -# Provide detailed instructions through file (e.g., rules of engagement, scope, exclusions) -strix --target api.your-app.com --instruction-file ./instruction.md -``` +- `artifacts` 当前返回的是本地文件路径,不是下载 URL +- `/api/v1/tasks/{task_id}` 与 `/api/v1/tasks/{task_id}/result` 当前返回结构一致 +- `report` 接口返回纯文本响应,而不是 JSON -### Headless Mode +### 流式事件 -Run Strix programmatically without interactive UI using the `-n/--non-interactive` flag—perfect for servers and automated jobs. The CLI prints real-time vulnerability findings, and the final report before exiting. Exits with non-zero code when vulnerabilities are found. +事件流接口: ```bash -strix -n --target https://your-app.com +curl -N http://127.0.0.1:8787/api/v1/tasks//stream \ + -H 'Authorization: Bearer optional-api-token' ``` -### CI/CD (GitHub Actions) +支持查询参数: + +- `follow=true`:默认持续跟随新事件 +- `from_offset=0`:从 `events.jsonl` 的字节偏移处继续读取 -Strix can be added to your pipeline to run a security test on pull requests with a lightweight GitHub Actions workflow: +SSE 行为: -```yaml -name: strix-penetration-test +1. 建立连接后先发送 `stream.connected` +2. 随后把 `events.jsonl` 中的事件按 `event_type` 转成 SSE 事件名 +3. 任务运行中会发送 `: keep-alive` +4. 任务结束后额外发送 `task.finished` -on: - pull_request: +典型事件类型: -jobs: - security-scan: - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v6 +- `run.started` +- `run.configured` +- `agent.created` +- `agent.status.updated` +- `chat.streaming` +- `chat.message` +- `tool.execution.started` +- `tool.execution.updated` +- `finding.created` +- `finding.reviewed` +- `run.completed` +- `task.finished` - - name: Install Strix - run: curl -sSL https://strix.ai/install | bash +SSE 示例: - - name: Run Strix - env: - STRIX_LLM: ${{ secrets.STRIX_LLM }} - LLM_API_KEY: ${{ secrets.LLM_API_KEY }} +```text +event: stream.connected +data: {"task_id":"example-deep-scan","offset":0} - run: strix -n -t ./ --scan-mode quick +event: chat.message +data: {"offset":1234,"payload":{"event_type":"chat.message","payload":{"content":"hello"}}} + +event: task.finished +data: {"task_id":"example-deep-scan","status":"completed"} ``` -### Configuration +### 常见状态码 -```bash -export STRIX_LLM="openai/gpt-5.4" -export LLM_API_KEY="your-api-key" +- `200 OK`:请求成功 +- `201 Created`:任务创建成功 +- `400 Bad Request`:业务校验失败 +- `401 Unauthorized`:Token 缺失或错误 +- `404 Not Found`:任务或报告不存在 +- `422 Unprocessable Entity`:请求体或查询参数校验失败 + +### 常用调用示例 -# Optional -export LLM_API_BASE="your-api-base-url" # if using a local model, e.g. Ollama, LMStudio -export PERPLEXITY_API_KEY="your-api-key" # for search capabilities -export STRIX_REASONING_EFFORT="high" # control thinking effort (default: high, quick scan: medium) +```bash +# 创建任务 +curl -X POST http://127.0.0.1:8787/api/v1/tasks \ + -H 'Authorization: Bearer optional-api-token' \ + -H 'Content-Type: application/json' \ + -d '{ + "targets": ["https://example.com"], + "instruction": "重点看认证和 IDOR", + "scan_mode": "deep", + "task_id": "example-deep-scan" + }' + +# 查看任务列表 +curl http://127.0.0.1:8787/api/v1/tasks \ + -H 'Authorization: Bearer optional-api-token' + +# 查看结构化结果 +curl http://127.0.0.1:8787/api/v1/tasks//result \ + -H 'Authorization: Bearer optional-api-token' + +# 拉取历史事件 +curl http://127.0.0.1:8787/api/v1/tasks//events?limit=200 \ + -H 'Authorization: Bearer optional-api-token' + +# 获取产物列表 +curl http://127.0.0.1:8787/api/v1/tasks//artifacts \ + -H 'Authorization: Bearer optional-api-token' + +# 获取最终 Markdown 报告 +curl http://127.0.0.1:8787/api/v1/tasks//report \ + -H 'Authorization: Bearer optional-api-token' + +# 取消任务 +curl -X POST http://127.0.0.1:8787/api/v1/tasks//cancel \ + -H 'Authorization: Bearer optional-api-token' ``` -> [!NOTE] -> Strix automatically saves your configuration to `~/.strix/cli-config.json`, so you don't have to re-enter it on every run. +更完整的接口说明请看: -**Recommended models for best results:** +- [docs/api/web-api.mdx](/Users/tao/Documents/docker/strix/strix/strix_api/docs/api/web-api.mdx) -- [OpenAI GPT-5.4](https://openai.com/api/) — `openai/gpt-5.4` -- [Anthropic Claude Sonnet 4.6](https://claude.com/platform/api) — `anthropic/claude-sonnet-4-6` -- [Google Gemini 3 Pro Preview](https://cloud.google.com/vertex-ai) — `vertex_ai/gemini-3-pro-preview` +## Web Demo -See the [LLM Providers documentation](https://docs.strix.ai/llm-providers/overview) for all supported providers including Vertex AI, Bedrock, Azure, and local models. +内置 Demo 页面用于展示任务管理、事件流和结果查看能力: -## Enterprise +1. 启动 `strix-api` +2. 打开 `http://127.0.0.1:8787/demo` +3. 输入 API 地址和 Bearer Token +4. 在页面中创建任务、查看结果、回放事件、订阅流式输出 -Get the same Strix experience with [enterprise-grade](https://strix.ai/demo) controls: SSO (SAML/OIDC), custom compliance reports, dedicated support & SLA, custom deployment options (VPC/self-hosted), BYOK model support, and tailored agents optimized for your environment. [Learn more](https://strix.ai/demo). +Demo 当前支持: -## Documentation +- 创建任务 +- 查看任务列表与详情 +- 获取 `/result` 和 `/results` +- 查看 `/events`、`/artifacts`、`/report` +- 取消任务 +- 通过 SSE 流式查看执行过程 -Full documentation is available at **[docs.strix.ai](https://docs.strix.ai)** — including detailed guides for usage, CI/CD integrations, skills, and advanced configuration. +## 输出目录 -## Contributing +默认情况下,每次运行都会写入: -We welcome contributions of code, docs, and new skills - check out our [Contributing Guide](https://docs.strix.ai/contributing) to get started or open a [pull request](https://github.com/usestrix/strix/pulls)/[issue](https://github.com/usestrix/strix/issues). +```text +strix_runs// +``` -## Join Our Community +常见文件: -Have questions? Found a bug? Want to contribute? **[Join our Discord!](https://discord.gg/strix-ai)** +- `task_state.json`:任务生命周期状态 +- `events.jsonl`:事件流历史 +- `scan_state.json`:结构化扫描状态与汇总结果 +- `penetration_test_report.md`:最终 Markdown 报告 +- `vulnerabilities/`:漏洞明细目录 +- `vulnerabilities.csv`:漏洞索引 +- `worker.log`:worker 标准输出与错误输出 -## Support the Project +## 文档索引 -**Love Strix?** Give us a ⭐ on GitHub! +- [docs/api/web-api.mdx](/Users/tao/Documents/docker/strix/strix/strix_api/docs/api/web-api.mdx):Web API 中文文档 +- [docs/advanced/configuration.mdx](/Users/tao/Documents/docker/strix/strix/strix_api/docs/advanced/configuration.mdx):配置说明 +- [docs/usage/cli.mdx](/Users/tao/Documents/docker/strix/strix/strix_api/docs/usage/cli.mdx):CLI 参数说明 +- [docs/integrations/github-actions.mdx](/Users/tao/Documents/docker/strix/strix/strix_api/docs/integrations/github-actions.mdx):GitHub Actions 集成 -## Acknowledgements +## 开发与测试 -Strix builds on the incredible work of open-source projects like [LiteLLM](https://github.com/BerriAI/litellm), [Caido](https://github.com/caido/caido), [Nuclei](https://github.com/projectdiscovery/nuclei), [Playwright](https://github.com/microsoft/playwright), and [Textual](https://github.com/Textualize/textual). Huge thanks to their maintainers! +安装开发依赖后,可以执行: +```bash +make check-all +``` + +或按需运行: + +```bash +uv run pytest -o addopts='' tests/api/test_server.py tests/api/test_task_store.py +python3 -m compileall strix tests +``` -> [!WARNING] -> Only test apps you own or have permission to test. You are responsible for using Strix ethically and legally. +## 合规与免责声明 - +请仅测试你拥有或已获得明确授权的系统。使用者需自行确保测试行为符合当地法律、合同约束和组织安全规范。 diff --git a/config.example.json b/config.example.json new file mode 100644 index 000000000..807816f80 --- /dev/null +++ b/config.example.json @@ -0,0 +1,38 @@ +{ + "llm": { + "model": "openai/gpt-5.4", + "api_key": "your-api-key-here", + "api_base": null, + "openai_compatible_provider": null, + "reasoning_effort": "high", + "timeout": 300 + }, + "features": { + "perplexity_api_key": null, + "disable_browser": false + }, + "runtime": { + "backend": "docker", + "image": "ghcr.io/usestrix/strix-sandbox:0.1.13", + "sandbox_execution_timeout": 120, + "sandbox_connect_timeout": 10, + "docker_host": null, + "caido_api_token": null + }, + "telemetry": { + "enabled": true, + "otel_enabled": null, + "posthog_enabled": null, + "traceloop_base_url": null, + "traceloop_api_key": null, + "traceloop_headers": null + }, + "api": { + "host": "127.0.0.1", + "port": 8787, + "auth_token": null, + "max_concurrent_tasks": 1, + "enable_docs": true, + "stream_poll_interval_ms": 500 + } +} diff --git a/containers/docker-entrypoint.sh b/containers/docker-entrypoint.sh index cbef471ef..b82e21663 100644 --- a/containers/docker-entrypoint.sh +++ b/containers/docker-entrypoint.sh @@ -154,17 +154,30 @@ echo "✅ CA added to browser trust store" echo "Starting tool server..." cd /app export PYTHONPATH=/app -export STRIX_SANDBOX_MODE=true export POETRY_VIRTUALENVS_CREATE=false export TOOL_SERVER_TIMEOUT="${STRIX_SANDBOX_EXECUTION_TIMEOUT:-120}" TOOL_SERVER_LOG="/tmp/tool_server.log" +RUNTIME_CONFIG="/tmp/strix-runtime-config.json" + +jq -n \ + --arg caido_api_token "$CAIDO_API_TOKEN" \ + --argjson sandbox_execution_timeout "$TOOL_SERVER_TIMEOUT" \ + '{ + runtime: { + sandbox_mode: true, + caido_api_token: $caido_api_token, + sandbox_execution_timeout: $sandbox_execution_timeout + } + }' > "$RUNTIME_CONFIG" sudo -E -u pentester \ poetry run python -m strix.runtime.tool_server \ --token="$TOOL_SERVER_TOKEN" \ --host=0.0.0.0 \ --port="$TOOL_SERVER_PORT" \ - --timeout="$TOOL_SERVER_TIMEOUT" > "$TOOL_SERVER_LOG" 2>&1 & + --timeout="$TOOL_SERVER_TIMEOUT" \ + --config="$RUNTIME_CONFIG" \ + --sandbox-mode > "$TOOL_SERVER_LOG" 2>&1 & for i in {1..10}; do if curl -s "http://127.0.0.1:$TOOL_SERVER_PORT/health" | grep -q '"status":"healthy"'; then diff --git a/docs/README.md b/docs/README.md index cc1989533..a7362e423 100644 --- a/docs/README.md +++ b/docs/README.md @@ -1,8 +1,8 @@ -# Strix Documentation +# Strix 文档目录 -Documentation source files for Strix, powered by [Mintlify](https://mintlify.com). +这里存放的是 Strix 的文档源文件,使用 [Mintlify](https://mintlify.com) 渲染。 -## Local Preview +## 本地预览 ```bash npm i -g mintlify diff --git a/docs/advanced/configuration.mdx b/docs/advanced/configuration.mdx index 4d51f3c62..e5155eb86 100644 --- a/docs/advanced/configuration.mdx +++ b/docs/advanced/configuration.mdx @@ -1,138 +1,147 @@ --- title: "Configuration" -description: "Environment variables for Strix" +description: "Configure Strix with JSON files" --- -Configure Strix using environment variables or a config file. +Strix runtime configuration is file-based. The default path is: -## LLM Configuration - - - Model name in LiteLLM format (e.g., `openai/gpt-5.4`, `anthropic/claude-sonnet-4-6`). - - - - API key for your LLM provider. Not required for local models or cloud provider auth (Vertex AI, AWS Bedrock). - - - - Custom API base URL. Also accepts `OPENAI_API_BASE`, `LITELLM_BASE_URL`, or `OLLAMA_API_BASE`. - - - - Request timeout in seconds for LLM calls. - - - - Maximum number of retries for LLM API calls on transient failures. - - - - Control thinking effort for reasoning models. Valid values: `none`, `minimal`, `low`, `medium`, `high`, `xhigh`. Defaults to `medium` for quick scan mode. - - - - Timeout in seconds for memory compression operations (context summarization). - - -## Optional Features - - - API key for Perplexity AI. Enables real-time web search during scans for OSINT and vulnerability research. - - - - Disable browser automation tools. - - - - Global telemetry default toggle. Set to `0`, `false`, `no`, or `off` to disable both PostHog and OTEL unless overridden by per-channel flags below. - - - - Enable/disable OpenTelemetry run observability independently. When unset, falls back to `STRIX_TELEMETRY`. - - - - Enable/disable PostHog product telemetry independently. When unset, falls back to `STRIX_TELEMETRY`. - - - - OTLP/Traceloop base URL for remote OpenTelemetry export. If unset, Strix keeps traces local only. - - - - API key used for remote trace export. Remote export is enabled only when both `TRACELOOP_BASE_URL` and `TRACELOOP_API_KEY` are set. - - - - Optional custom OTEL headers (JSON object or `key=value,key2=value2`). Useful for Langfuse or custom/self-hosted OTLP gateways. - - -When remote OTEL vars are not set, Strix still writes complete run telemetry locally to: - -```bash -strix_runs//events.jsonl +```text +~/.strix/config.json ``` -When remote vars are set, Strix dual-writes telemetry to both local JSONL and the remote OTEL endpoint. - -## Docker Configuration - - - Docker image to use for the sandbox container. - - - - Docker daemon socket path. Use for remote Docker hosts or custom configurations. - - - - Runtime backend for the sandbox environment. - - -## Sandbox Configuration - - - Maximum execution time in seconds for sandbox operations. - - - - Timeout in seconds for connecting to the sandbox container. - - -## Config File - -Strix stores configuration in `~/.strix/cli-config.json`. You can also specify a custom config file: +You can also point Strix or `strix-api` at a custom file: ```bash strix --target ./app --config /path/to/config.json +strix-api --config /path/to/config.json ``` -**Config file format:** +## Minimal Example ```json { - "env": { - "STRIX_LLM": "openai/gpt-5.4", - "LLM_API_KEY": "sk-...", - "STRIX_REASONING_EFFORT": "high" + "llm": { + "model": "openai/gpt-5.4", + "api_key": "your-api-key" } } ``` -## Example Setup +## Full Example -```bash -# Required -export STRIX_LLM="openai/gpt-5.4" -export LLM_API_KEY="sk-..." - -# Optional: Enable web search -export PERPLEXITY_API_KEY="pplx-..." +```json +{ + "llm": { + "model": "openai/gpt-5.4", + "api_key": "your-api-key", + "api_base": null, + "openai_compatible_provider": null, + "reasoning_effort": "high", + "max_retries": 5, + "memory_compressor_timeout": 30, + "timeout": 300 + }, + "features": { + "perplexity_api_key": null, + "disable_browser": false + }, + "runtime": { + "backend": "docker", + "image": "ghcr.io/usestrix/strix-sandbox:0.1.13", + "sandbox_execution_timeout": 120, + "sandbox_connect_timeout": 10, + "docker_host": null, + "caido_api_token": null + }, + "telemetry": { + "enabled": true, + "otel_enabled": null, + "posthog_enabled": null, + "traceloop_base_url": null, + "traceloop_api_key": null, + "traceloop_headers": null + }, + "api": { + "host": "127.0.0.1", + "port": 8787, + "auth_token": null, + "max_concurrent_tasks": 1, + "enable_docs": true, + "stream_poll_interval_ms": 500 + } +} +``` -# Optional: Custom timeouts -export LLM_TIMEOUT="600" -export STRIX_SANDBOX_EXECUTION_TIMEOUT="300" +## `llm` + +| Key | Type | Description | +| --- | --- | --- | +| `model` | `string` | LiteLLM model identifier such as `openai/gpt-5.4` | +| `api_key` | `string \| null` | Provider API key | +| `api_base` | `string \| null` | Custom gateway or self-hosted endpoint | +| `openai_compatible_provider` | `string \| null` | Explicit provider name for OpenAI-compatible gateways not built into LiteLLM | +| `reasoning_effort` | `string` | Reasoning level, typically `none`, `minimal`, `low`, `medium`, `high`, or `xhigh` | +| `max_retries` | `integer` | Retry count for transient LLM failures | +| `memory_compressor_timeout` | `integer` | Timeout for memory compression operations | +| `timeout` | `integer` | Request timeout in seconds | + +## `features` + +| Key | Type | Description | +| --- | --- | --- | +| `perplexity_api_key` | `string \| null` | Optional key for live web research | +| `disable_browser` | `boolean` | Disable browser automation tools | + +## `runtime` + +| Key | Type | Description | +| --- | --- | --- | +| `backend` | `string` | Sandbox backend, default `docker` | +| `image` | `string` | Sandbox image name | +| `sandbox_execution_timeout` | `integer` | Sandbox execution timeout in seconds | +| `sandbox_connect_timeout` | `integer` | Sandbox connection timeout in seconds | +| `docker_host` | `string \| null` | Custom Docker daemon endpoint | +| `caido_api_token` | `string \| null` | Optional Caido token | + +## `telemetry` + +| Key | Type | Description | +| --- | --- | --- | +| `enabled` | `boolean` | Global telemetry switch | +| `otel_enabled` | `boolean \| null` | Override OTEL telemetry separately | +| `posthog_enabled` | `boolean \| null` | Override PostHog telemetry separately | +| `traceloop_base_url` | `string \| null` | Remote OTEL endpoint | +| `traceloop_api_key` | `string \| null` | Remote OTEL API key | +| `traceloop_headers` | `string \| null` | Custom OTEL headers | + +## `api` + +| Key | Type | Description | +| --- | --- | --- | +| `host` | `string` | API bind host | +| `port` | `integer` | API bind port | +| `auth_token` | `string \| null` | Optional Bearer token for `/api/v1/tasks*` | +| `max_concurrent_tasks` | `integer` | Maximum number of active tasks | +| `enable_docs` | `boolean` | Enable `/docs`, `/redoc`, and `/openapi.json` | +| `stream_poll_interval_ms` | `integer` | SSE polling interval in milliseconds | + +## OpenAI-Compatible Providers + +If you are using an OpenAI-compatible gateway that LiteLLM does not recognize out of the box, configure both `api_base` and `openai_compatible_provider`: +```json +{ + "llm": { + "model": "astron-code-latest", + "api_key": "your-api-key", + "api_base": "https://maas-coding-api.example.com/v2", + "openai_compatible_provider": "AstronCodingPlan" + } +} ``` + +## Notes + +- Normal Strix runtime configuration no longer depends on environment variables. +- If you use cloud providers such as Vertex AI or Bedrock, their SDK-level credentials may still follow the provider's own authentication mechanism. +- Older `cli-config.json` or `env`-style configuration layouts are legacy compatibility paths and are not recommended for new setups. diff --git a/docs/api/web-api.mdx b/docs/api/web-api.mdx new file mode 100644 index 000000000..4d5111468 --- /dev/null +++ b/docs/api/web-api.mdx @@ -0,0 +1,692 @@ +--- +title: "Web API" +description: "Strix Web API 中文参考文档" +--- + +## 概览 + +Strix Web API 用于把扫描能力以服务形式暴露给外部系统。它支持: + +- 创建扫描任务 +- 查询任务状态与结果 +- 获取事件历史 +- 通过 SSE 流式订阅任务执行过程 +- 获取最终 Markdown 报告和全部产物文件路径 +- 取消仍在运行中的任务 + +服务入口默认是: + +```text +http://127.0.0.1:8787 +``` + +公开接口: + +- `GET /health` +- `GET /demo` + +任务接口统一位于: + +```text +/api/v1/tasks +``` + +## 启动方式 + +```bash +strix-api +``` + +覆盖配置文件或监听地址: + +```bash +strix-api --config ~/.strix/config.json --host 0.0.0.0 --port 8787 +``` + +## 鉴权 + +如果配置文件中设置了 `api.auth_token`,所有 `/api/v1/tasks*` 接口都需要 Bearer Token: + +```text +Authorization: Bearer +``` + +未配置 `api.auth_token` 时,这些接口默认不鉴权。 + +典型未授权响应: + +```json +{ + "detail": "Invalid or missing API token" +} +``` + +## 任务状态 + +| 状态值 | 含义 | +| --- | --- | +| `queued` | 任务已创建,worker 已启动或等待执行 | +| `running` | worker 已进入实际扫描阶段 | +| `cancelling` | 已收到取消请求,正在结束执行 | +| `completed` | 已正常完成 | +| `failed` | 执行失败 | +| `cancelled` | 已取消 | + +## 核心数据结构 + +### 1. 创建任务请求体 `ScanTaskRequest` + +| 字段 | 类型 | 必填 | 说明 | 约束 | +| --- | --- | --- | --- | --- | +| `targets` | `string[]` | 是 | 扫描目标列表。支持本地目录、仓库 URL、在线应用 URL、域名或 IP | 至少 1 项 | +| `instruction` | `string \| null` | 否 | 直接内联传入的附加扫描说明 | 不能与 `instruction_file` 同时传入 | +| `instruction_file` | `string \| null` | 否 | 服务端本地文件路径,读取文件内容作为扫描说明 | 不能与 `instruction` 同时传入;文件必须存在且非空 | +| `scan_mode` | `quick \| standard \| deep` | 否 | 扫描深度 | 默认 `deep` | +| `task_id` | `string \| null` | 否 | 外部系统自定义任务 ID | 必须唯一;若重复会返回 `400` | +| `run_name` | `string \| null` | 否 | 运行目录名称 | 未传时默认等于任务 ID | +| `config_path` | `string \| null` | 否 | 当前任务单独使用的配置文件路径 | 文件必须可读且通过校验 | + +请求示例: + +```json +{ + "targets": [ + "https://example.com", + "https://github.com/org/repo" + ], + "instruction": "重点验证认证、IDOR、业务逻辑和高危注入问题", + "scan_mode": "deep", + "task_id": "example-deep-scan", + "run_name": "example-20260325" +} +``` + +### 2. 任务记录对象 `ScanTaskRecord` + +`POST /api/v1/tasks`、`GET /api/v1/tasks` 和取消接口都会返回这个对象或包含这个对象。 + +| 字段 | 类型 | 说明 | +| --- | --- | --- | +| `task_id` | `string` | 任务唯一标识 | +| `run_name` | `string \| null` | 运行目录名,默认与 `task_id` 相同 | +| `status` | `string` | 任务状态,取值见上文 | +| `created_at` | `string` | 任务创建时间,ISO 8601 | +| `updated_at` | `string` | 最近更新时间,ISO 8601 | +| `started_at` | `string \| null` | worker 进入执行阶段的时间 | +| `finished_at` | `string \| null` | 任务结束时间 | +| `completed_at` | `string \| null` | 与 `finished_at` 兼容的结束时间字段 | +| `pid` | `integer \| null` | worker 进程 ID | +| `exit_code` | `integer \| null` | worker 退出码 | +| `error` | `string \| null` | 失败原因或取消说明 | +| `request` | `object` | 原始请求体,包含标准化后的参数 | +| `run_dir` | `string` | 任务输出目录绝对路径 | +| `worker_log_path` | `string` | worker 日志文件绝对路径 | +| `scan_state_path` | `string` | `scan_state.json` 绝对路径 | +| `events_path` | `string` | `events.jsonl` 绝对路径 | + +### 3. 结构化结果对象 `ScanTaskResult` + +| 字段 | 类型 | 说明 | +| --- | --- | --- | +| `task` | `ScanTaskRecord` | 当前任务记录 | +| `scan_state` | `object \| null` | 结构化扫描状态;未生成时为 `null` | +| `artifacts` | `string[]` | 运行目录下所有文件的绝对路径列表 | + +### 4. `scan_state` 常见字段 + +`scan_state` 的具体内容会随扫描结果变化,但通常包含: + +| 字段 | 类型 | 说明 | +| --- | --- | --- | +| `run_metadata` | `object` | 任务运行元信息,例如运行名、开始/结束时间、目标、状态 | +| `scan_config` | `object` | 本次实际使用的扫描配置 | +| `scan_results` | `object` | 汇总分析结果,例如摘要、方法论、建议 | +| `final_scan_result` | `string` | 最终 Markdown 报告正文 | +| `vulnerability_reports` | `object[]` | 漏洞明细列表 | +| `agents` | `object` | 代理执行信息 | +| `tool_executions` | `object[]` | 工具调用历史 | +| `chat_messages` | `object[]` | 聊天消息历史 | + +### 5. 事件对象 + +`/events` 和 `/stream` 都围绕同一套事件结构工作。单条事件常见字段如下: + +| 字段 | 类型 | 说明 | +| --- | --- | --- | +| `timestamp` | `string` | 事件时间 | +| `event_type` | `string` | 事件类型 | +| `run_id` | `string` | 当前任务 ID | +| `trace_id` | `string \| null` | 链路追踪 ID | +| `span_id` | `string \| null` | 当前 span ID | +| `parent_span_id` | `string \| null` | 父 span ID | +| `actor` | `object \| null` | 产生事件的 agent、tool 或角色信息 | +| `payload` | `object \| null` | 事件主体内容 | +| `status` | `string \| null` | 事件状态 | +| `error` | `object \| string \| null` | 错误信息 | +| `source` | `string \| null` | 事件来源 | + +常见 `event_type` 示例: + +- `run.started` +- `run.configured` +- `agent.created` +- `chat.streaming` +- `chat.message` +- `tool.execution.started` +- `tool.execution.updated` +- `finding.created` +- `run.completed` +- `task.finished` + +## 接口明细 + +### GET `/health` + +用途:健康检查。 + +鉴权:否。 + +响应: + +```json +{ + "status": "ok" +} +``` + +状态码: + +- `200`:服务可用 + +### GET `/demo` + +用途:返回内置 Web Demo 页面。 + +鉴权:否。 + +响应类型:`text/html` + +状态码: + +- `200`:返回页面内容 + +### POST `/api/v1/tasks` + +用途:创建扫描任务并异步拉起 worker。 + +鉴权:是,除非未配置 `api.auth_token`。 + +请求头: + +- `Content-Type: application/json` +- `Authorization: Bearer `(如果启用了鉴权) + +请求体:见上文 `ScanTaskRequest`。 + +成功响应: + +```json +{ + "task": { + "task_id": "example-deep-scan", + "run_name": "example-20260325", + "status": "queued", + "created_at": "2026-03-25T06:00:00+00:00", + "updated_at": "2026-03-25T06:00:00+00:00", + "started_at": null, + "finished_at": null, + "completed_at": null, + "pid": 12345, + "exit_code": null, + "error": null, + "request": { + "targets": ["https://example.com"], + "instruction": "重点看认证和 IDOR", + "instruction_file": null, + "scan_mode": "deep", + "task_id": "example-deep-scan", + "run_name": "example-20260325", + "config_path": "/Users/example/.strix/config.json" + }, + "run_dir": "/path/to/strix_runs/example-deep-scan", + "worker_log_path": "/path/to/strix_runs/example-deep-scan/worker.log", + "scan_state_path": "/path/to/strix_runs/example-deep-scan/scan_state.json", + "events_path": "/path/to/strix_runs/example-deep-scan/events.jsonl" + } +} +``` + +状态码: + +- `201`:创建成功 +- `400`:业务校验失败 +- `401`:鉴权失败 +- `422`:请求体验证失败 + +典型 `400` 场景: + +- 超过 `api.max_concurrent_tasks` +- `task_id` 已存在 +- `instruction_file` 不存在、不可读或为空 +- `targets` 无法通过目标校验 +- `config_path` 指向的配置文件无效 + +### GET `/api/v1/tasks` + +用途:列出所有任务。 + +鉴权:是。 + +请求参数:无。 + +响应: + +```json +{ + "tasks": [ + { + "task_id": "example-deep-scan", + "run_name": "example-20260325", + "status": "running", + "created_at": "2026-03-25T06:00:00+00:00", + "updated_at": "2026-03-25T06:01:00+00:00", + "started_at": "2026-03-25T06:00:10+00:00", + "finished_at": null, + "completed_at": null, + "pid": 12345, + "exit_code": null, + "error": null, + "request": { + "targets": ["https://example.com"], + "instruction": null, + "instruction_file": null, + "scan_mode": "deep", + "task_id": "example-deep-scan", + "run_name": "example-20260325", + "config_path": "/Users/example/.strix/config.json" + }, + "run_dir": "/path/to/strix_runs/example-deep-scan", + "worker_log_path": "/path/to/strix_runs/example-deep-scan/worker.log", + "scan_state_path": "/path/to/strix_runs/example-deep-scan/scan_state.json", + "events_path": "/path/to/strix_runs/example-deep-scan/events.jsonl" + } + ] +} +``` + +补充说明: + +- 返回顺序按 `created_at` 倒序排列 +- 服务端在读取时会刷新任务状态 + +状态码: + +- `200` +- `401` + +### GET `/api/v1/tasks/{task_id}` + +用途:获取任务详情。 + +鉴权:是。 + +路径参数: + +| 参数 | 类型 | 说明 | +| --- | --- | --- | +| `task_id` | `string` | 任务 ID | + +响应结构: + +```json +{ + "task": {}, + "scan_state": {}, + "artifacts": [] +} +``` + +注意:这个接口当前返回的是完整结果对象,而不只是 `task` 元数据。 + +状态码: + +- `200` +- `401` +- `404` + +### GET `/api/v1/tasks/{task_id}/result` + +用途:获取结构化结果。 + +鉴权:是。 + +路径参数: + +| 参数 | 类型 | 说明 | +| --- | --- | --- | +| `task_id` | `string` | 任务 ID | + +响应: + +```json +{ + "task": {}, + "scan_state": {}, + "artifacts": [] +} +``` + +状态码: + +- `200` +- `401` +- `404` + +### GET `/api/v1/tasks/{task_id}/results` + +用途:`/result` 的别名。 + +鉴权:是。 + +路径参数: + +| 参数 | 类型 | 说明 | +| --- | --- | --- | +| `task_id` | `string` | 任务 ID | + +响应与状态码:同 `/result`。 + +### POST `/api/v1/tasks/{task_id}/cancel` + +用途:请求取消任务。 + +鉴权:是。 + +路径参数: + +| 参数 | 类型 | 说明 | +| --- | --- | --- | +| `task_id` | `string` | 任务 ID | + +请求体:无。 + +响应: + +```json +{ + "task": { + "task_id": "example-deep-scan", + "status": "cancelling" + } +} +``` + +补充说明: + +- 如果任务仍在执行,通常会先进入 `cancelling` +- 最终状态会由 worker 或存储刷新为 `cancelled` +- 如果任务本身已经结束,这个接口会直接返回当前最终状态,具备幂等特征 + +状态码: + +- `200` +- `401` +- `404` + +### GET `/api/v1/tasks/{task_id}/events` + +用途:读取事件历史。 + +鉴权:是。 + +路径参数: + +| 参数 | 类型 | 说明 | +| --- | --- | --- | +| `task_id` | `string` | 任务 ID | + +查询参数: + +| 参数 | 类型 | 默认值 | 说明 | 约束 | +| --- | --- | --- | --- | --- | +| `limit` | `integer` | `200` | 返回最近多少条事件 | 范围 `1..5000` | + +响应: + +```json +{ + "task_id": "example-deep-scan", + "events": [ + { + "timestamp": "2026-03-25T06:00:12+00:00", + "event_type": "run.started", + "run_id": "example-deep-scan", + "payload": { + "run_name": "example-20260325" + }, + "status": "running", + "source": "strix.tracer" + } + ] +} +``` + +状态码: + +- `200` +- `401` +- `404` +- `422` + +### GET `/api/v1/tasks/{task_id}/stream` + +用途:以 SSE 方式流式订阅任务事件。 + +鉴权:是。 + +路径参数: + +| 参数 | 类型 | 说明 | +| --- | --- | --- | +| `task_id` | `string` | 任务 ID | + +查询参数: + +| 参数 | 类型 | 默认值 | 说明 | +| --- | --- | --- | --- | +| `follow` | `boolean` | `true` | 是否持续跟随任务直到结束 | +| `from_offset` | `integer` | `0` | 从 `events.jsonl` 的哪个字节偏移继续读取 | + +响应类型: + +```text +text/event-stream +``` + +流式行为: + +1. 首先发送 `stream.connected` +2. 然后把 `events.jsonl` 中新增事件逐条转发为 SSE +3. 任务运行中会定时发送 `: keep-alive` 注释行保持连接 +4. 任务进入终态后,服务端额外补发 `task.finished` + +首条事件示例: + +```text +event: stream.connected +data: {"task_id":"example-deep-scan","offset":0} +``` + +普通事件示例: + +```text +event: chat.message +data: {"offset":1234,"payload":{"timestamp":"2026-03-25T06:00:15+00:00","event_type":"chat.message","payload":{"message_id":1,"content":"开始执行 reconnaissance"}}} +``` + +结束事件示例: + +```text +event: task.finished +data: {"task_id":"example-deep-scan","status":"completed", ...} +``` + +补充说明: + +- `from_offset` 是文件字节偏移,不是事件序号 +- `follow=false` 更适合“回放模式” +- 如果在连接建立后任务很快结束,仍会收到 `task.finished` + +状态码: + +- `200` +- `401` +- `404` +- `422` + +### GET `/api/v1/tasks/{task_id}/artifacts` + +用途:获取运行目录中的全部文件路径。 + +鉴权:是。 + +路径参数: + +| 参数 | 类型 | 说明 | +| --- | --- | --- | +| `task_id` | `string` | 任务 ID | + +响应: + +```json +{ + "task_id": "example-deep-scan", + "artifacts": [ + "/path/to/strix_runs/example-deep-scan/task_state.json", + "/path/to/strix_runs/example-deep-scan/events.jsonl", + "/path/to/strix_runs/example-deep-scan/scan_state.json", + "/path/to/strix_runs/example-deep-scan/penetration_test_report.md" + ] +} +``` + +补充说明: + +- 返回的是服务器文件绝对路径,不是下载 URL +- 常见文件包括 `task_state.json`、`events.jsonl`、`scan_state.json`、`worker.log`、`penetration_test_report.md`、`vulnerabilities/*.md`、`vulnerabilities.csv` + +状态码: + +- `200` +- `401` +- `404` + +### GET `/api/v1/tasks/{task_id}/report` + +用途:获取最终 Markdown 报告正文。 + +鉴权:是。 + +路径参数: + +| 参数 | 类型 | 说明 | +| --- | --- | --- | +| `task_id` | `string` | 任务 ID | + +响应类型: + +```text +text/plain +``` + +返回内容是 Markdown 文本,不是 JSON。 + +状态码: + +- `200` +- `401` +- `404` + +`404` 有两种常见情况: + +- 任务不存在 +- 任务存在,但报告文件尚未生成 + +## 错误与约束 + +### 401 未授权 + +仅当配置了 `api.auth_token` 时会出现。 + +```json +{ + "detail": "Invalid or missing API token" +} +``` + +### 404 资源不存在 + +典型返回: + +```json +{ + "detail": "Task 'example-deep-scan' not found" +} +``` + +报告接口的典型返回: + +```json +{ + "detail": "Task 'example-deep-scan' report not found" +} +``` + +### 422 请求参数校验失败 + +典型场景: + +- `targets` 为空 +- 出现额外未定义字段 +- `instruction` 与 `instruction_file` 同时传入 +- `scan_mode` 不在允许枚举中 +- `limit` 或 `from_offset` 越界 + +### 400 业务校验失败 + +目前主要出现在 `POST /api/v1/tasks`: + +- 任务并发数超过 `api.max_concurrent_tasks` +- 任务 ID 已存在 +- 指令文件读取失败 +- 指令文件为空 +- 配置文件无效 +- 目标校验失败 + +## 运行目录与文件说明 + +每个任务默认写入: + +```text +strix_runs// +``` + +目录下常见文件: + +| 文件 | 说明 | +| --- | --- | +| `task_state.json` | API 任务元数据 | +| `events.jsonl` | 事件历史,SSE 也基于它进行转发 | +| `scan_state.json` | 结构化扫描状态和最终结果 | +| `penetration_test_report.md` | 最终报告 | +| `worker.log` | worker 进程日志 | +| `vulnerabilities/` | 单个漏洞报告目录 | +| `vulnerabilities.csv` | 漏洞索引 | + +## 调用建议 + +- 如果你只需要“是否完成”和最终结果,用 `/result` 即可 +- 如果你要做任务面板,建议同时使用 `/tasks`、`/events` 和 `/stream` +- 如果你要做断点续传式消费,请记录 SSE 返回的 `offset` +- 如果你要接入企业网关,请优先配置 `api.auth_token` +- 如果你要把任务结果保存到外部系统,请同时归档 `scan_state.json` 和 `penetration_test_report.md` diff --git a/docs/contributing.mdx b/docs/contributing.mdx index 50964ccad..af3a66ee9 100644 --- a/docs/contributing.mdx +++ b/docs/contributing.mdx @@ -30,10 +30,17 @@ description: "Contribute to Strix development" poetry run pre-commit install ``` - + ```bash - export STRIX_LLM="openai/gpt-5.4" - export LLM_API_KEY="your-api-key" + mkdir -p ~/.strix + cat > ~/.strix/config.json <<'EOF' + { + "llm": { + "model": "openai/gpt-5.4", + "api_key": "your-api-key" + } + } + EOF ``` diff --git a/docs/docs.json b/docs/docs.json index e15b496d5..9ee8c2a1b 100644 --- a/docs/docs.json +++ b/docs/docs.json @@ -24,6 +24,7 @@ "group": "Usage", "pages": [ "usage/cli", + "api/web-api", "usage/scan-modes", "usage/instructions" ] diff --git a/docs/index.mdx b/docs/index.mdx index 2d4014893..67ced878a 100644 --- a/docs/index.mdx +++ b/docs/index.mdx @@ -74,14 +74,16 @@ Strix uses a graph of specialized agents for comprehensive security testing: ## Quick Example ```bash -# Install curl -sSL https://strix.ai/install | bash - -# Configure -export STRIX_LLM="openai/gpt-5.4" -export LLM_API_KEY="your-api-key" - -# Scan +mkdir -p ~/.strix +cat > ~/.strix/config.json <<'EOF' +{ + "llm": { + "model": "openai/gpt-5.4", + "api_key": "your-api-key" + } +} +EOF strix --target ./your-app ``` diff --git a/docs/integrations/ci-cd.mdx b/docs/integrations/ci-cd.mdx index 48213e7b8..e0d31bbe5 100644 --- a/docs/integrations/ci-cd.mdx +++ b/docs/integrations/ci-cd.mdx @@ -28,11 +28,18 @@ security-scan: image: docker:latest services: - docker:dind - variables: - STRIX_LLM: $STRIX_LLM - LLM_API_KEY: $LLM_API_KEY script: - curl -sSL https://strix.ai/install | bash + - mkdir -p ~/.strix + - | + cat > ~/.strix/config.json < ~/.strix/config.json < ~/.strix/config.json < -All CI platforms require Docker access. Ensure your runner has Docker available. +All CI platforms require Docker access. Ensure your runner has Docker available. Secrets can be injected by the CI platform, but Strix itself should still be configured through `~/.strix/config.json`. diff --git a/docs/integrations/github-actions.mdx b/docs/integrations/github-actions.mdx index 63991449c..445106eeb 100644 --- a/docs/integrations/github-actions.mdx +++ b/docs/integrations/github-actions.mdx @@ -22,10 +22,19 @@ jobs: - name: Install Strix run: curl -sSL https://strix.ai/install | bash + - name: Write Strix Config + run: | + mkdir -p ~/.strix + cat > ~/.strix/config.json <<'EOF' + { + "llm": { + "model": "${{ secrets.STRIX_MODEL }}", + "api_key": "${{ secrets.LLM_API_KEY }}" + } + } + EOF + - name: Run Security Scan - env: - STRIX_LLM: ${{ secrets.STRIX_LLM }} - LLM_API_KEY: ${{ secrets.LLM_API_KEY }} run: strix -n -t ./ --scan-mode quick ``` @@ -35,9 +44,13 @@ Add these secrets to your repository: | Secret | Description | |--------|-------------| -| `STRIX_LLM` | Model name (e.g., `openai/gpt-5.4`) | +| `STRIX_MODEL` | Model name (e.g., `openai/gpt-5.4`) | | `LLM_API_KEY` | API key for your LLM provider | + +The CI platform can still use secrets or environment variables to template the file, but Strix itself reads `~/.strix/config.json` at runtime. + + ## Exit Codes The workflow fails when vulnerabilities are found: diff --git a/docs/llm-providers/anthropic.mdx b/docs/llm-providers/anthropic.mdx index da32b13ab..7cf2b039d 100644 --- a/docs/llm-providers/anthropic.mdx +++ b/docs/llm-providers/anthropic.mdx @@ -5,9 +5,13 @@ description: "Configure Strix with Claude models" ## Setup -```bash -export STRIX_LLM="anthropic/claude-sonnet-4-6" -export LLM_API_KEY="sk-ant-..." +```json +{ + "llm": { + "model": "anthropic/claude-sonnet-4-6", + "api_key": "sk-ant-..." + } +} ``` ## Available Models diff --git a/docs/llm-providers/azure.mdx b/docs/llm-providers/azure.mdx index 1a9be0084..265401e4b 100644 --- a/docs/llm-providers/azure.mdx +++ b/docs/llm-providers/azure.mdx @@ -5,31 +5,40 @@ description: "Configure Strix with OpenAI models via Azure" ## Setup -```bash -export STRIX_LLM="azure/your-gpt5-deployment" -export AZURE_API_KEY="your-azure-api-key" -export AZURE_API_BASE="https://your-resource.openai.azure.com" -export AZURE_API_VERSION="2025-11-01-preview" +```json +{ + "llm": { + "model": "azure/your-gpt5-deployment", + "api_key": "your-azure-api-key", + "api_base": "https://your-resource.openai.azure.com" + } +} ``` ## Configuration -| Variable | Description | +| Key | Description | |----------|-------------| -| `STRIX_LLM` | `azure/` | -| `AZURE_API_KEY` | Your Azure OpenAI API key | -| `AZURE_API_BASE` | Your Azure OpenAI endpoint URL | -| `AZURE_API_VERSION` | API version (e.g., `2025-11-01-preview`) | +| `llm.model` | `azure/` | +| `llm.api_key` | Your Azure OpenAI API key | +| `llm.api_base` | Your Azure OpenAI endpoint URL | ## Example -```bash -export STRIX_LLM="azure/gpt-5.4-deployment" -export AZURE_API_KEY="abc123..." -export AZURE_API_BASE="https://mycompany.openai.azure.com" -export AZURE_API_VERSION="2025-11-01-preview" +```json +{ + "llm": { + "model": "azure/gpt-5.4-deployment", + "api_key": "abc123...", + "api_base": "https://mycompany.openai.azure.com" + } +} ``` + +If your Azure setup requires additional provider-specific settings such as an explicit API version, follow the Azure and LiteLLM guidance used by your deployment environment or gateway. + + ## Prerequisites 1. Create an Azure OpenAI resource diff --git a/docs/llm-providers/bedrock.mdx b/docs/llm-providers/bedrock.mdx index 2189e9876..06e2dbcd6 100644 --- a/docs/llm-providers/bedrock.mdx +++ b/docs/llm-providers/bedrock.mdx @@ -5,11 +5,15 @@ description: "Configure Strix with models via AWS Bedrock" ## Setup -```bash -export STRIX_LLM="bedrock/anthropic.claude-4-5-sonnet-20251022-v1:0" +```json +{ + "llm": { + "model": "bedrock/anthropic.claude-4-5-sonnet-20251022-v1:0" + } +} ``` -No API key required—uses AWS credentials from environment. +No Strix API key is required here. Model selection lives in the config file, while authentication follows the standard AWS credential chain. ## Authentication diff --git a/docs/llm-providers/local.mdx b/docs/llm-providers/local.mdx index 8a899a5d2..56b7bc187 100644 --- a/docs/llm-providers/local.mdx +++ b/docs/llm-providers/local.mdx @@ -32,9 +32,13 @@ For critical assessments, we strongly recommend using state-of-the-art cloud mod ollama pull qwen3-vl ``` 3. Configure Strix: - ```bash - export STRIX_LLM="ollama/qwen3-vl" - export LLM_API_BASE="http://localhost:11434" + ```json + { + "llm": { + "model": "ollama/qwen3-vl", + "api_base": "http://localhost:11434" + } + } ``` ### Recommended Models @@ -50,7 +54,11 @@ We recommend these models for the best balance of reasoning and tool use: If you use LM Studio, vLLM, or other runners: -```bash -export STRIX_LLM="openai/local-model" -export LLM_API_BASE="http://localhost:1234/v1" # Adjust port as needed +```json +{ + "llm": { + "model": "openai/local-model", + "api_base": "http://localhost:1234/v1" + } +} ``` diff --git a/docs/llm-providers/openai.mdx b/docs/llm-providers/openai.mdx index c8a486778..719e4a3ad 100644 --- a/docs/llm-providers/openai.mdx +++ b/docs/llm-providers/openai.mdx @@ -5,9 +5,13 @@ description: "Configure Strix with OpenAI models" ## Setup -```bash -export STRIX_LLM="openai/gpt-5.4" -export LLM_API_KEY="sk-..." +```json +{ + "llm": { + "model": "openai/gpt-5.4", + "api_key": "sk-..." + } +} ``` ## Available Models @@ -24,8 +28,12 @@ See [OpenAI Models Documentation](https://platform.openai.com/docs/models) for t For OpenAI-compatible APIs: -```bash -export STRIX_LLM="openai/gpt-5.4" -export LLM_API_KEY="your-key" -export LLM_API_BASE="https://your-proxy.com/v1" +```json +{ + "llm": { + "model": "openai/gpt-5.4", + "api_key": "your-key", + "api_base": "https://your-proxy.com/v1" + } +} ``` diff --git a/docs/llm-providers/openrouter.mdx b/docs/llm-providers/openrouter.mdx index 2b816e90d..1bd599754 100644 --- a/docs/llm-providers/openrouter.mdx +++ b/docs/llm-providers/openrouter.mdx @@ -7,9 +7,13 @@ description: "Configure Strix with models via OpenRouter" ## Setup -```bash -export STRIX_LLM="openrouter/openai/gpt-5.4" -export LLM_API_KEY="sk-or-..." +```json +{ + "llm": { + "model": "openrouter/openai/gpt-5.4", + "api_key": "sk-or-..." + } +} ``` ## Available Models diff --git a/docs/llm-providers/overview.mdx b/docs/llm-providers/overview.mdx index 8c0d5002e..0c765fcba 100644 --- a/docs/llm-providers/overview.mdx +++ b/docs/llm-providers/overview.mdx @@ -7,7 +7,7 @@ Strix uses [LiteLLM](https://docs.litellm.ai/docs/providers) for model compatibi ## Configuration -Set your model and API key: +Set your model and API key in `~/.strix/config.json`: | Model | Provider | Configuration | | ----------------- | ------------- | -------------------------------- | @@ -15,18 +15,26 @@ Set your model and API key: | Claude Sonnet 4.6 | Anthropic | `anthropic/claude-sonnet-4-6` | | Gemini 3 Pro | Google Vertex | `vertex_ai/gemini-3-pro-preview` | -```bash -export STRIX_LLM="openai/gpt-5.4" -export LLM_API_KEY="your-api-key" +```json +{ + "llm": { + "model": "openai/gpt-5.4", + "api_key": "your-api-key" + } +} ``` ## Local Models Run models locally with [Ollama](https://ollama.com), [LM Studio](https://lmstudio.ai), or any OpenAI-compatible server: -```bash -export STRIX_LLM="ollama/llama4" -export LLM_API_BASE="http://localhost:11434" +```json +{ + "llm": { + "model": "ollama/llama4", + "api_base": "http://localhost:11434" + } +} ``` See the [Local Models guide](/llm-providers/local) for setup instructions and recommended models. diff --git a/docs/llm-providers/vertex.mdx b/docs/llm-providers/vertex.mdx index d7ed9710e..1bfadc306 100644 --- a/docs/llm-providers/vertex.mdx +++ b/docs/llm-providers/vertex.mdx @@ -13,11 +13,15 @@ pipx install "strix-agent[vertex]" ## Setup -```bash -export STRIX_LLM="vertex_ai/gemini-3-pro-preview" +```json +{ + "llm": { + "model": "vertex_ai/gemini-3-pro-preview" + } +} ``` -No API key required—uses Google Cloud Application Default Credentials. +No Strix API key is required here. Model selection lives in the config file, while authentication follows Google Cloud Application Default Credentials. ## Authentication diff --git a/docs/quickstart.mdx b/docs/quickstart.mdx index 681bf02de..127e6a0ca 100644 --- a/docs/quickstart.mdx +++ b/docs/quickstart.mdx @@ -25,11 +25,15 @@ description: "Install Strix and run your first security scan" ## Configuration -Set your LLM provider: - -```bash -export STRIX_LLM="openai/gpt-5.4" -export LLM_API_KEY="your-api-key" +Create a config file at `~/.strix/config.json`: + +```json +{ + "llm": { + "model": "openai/gpt-5.4", + "api_key": "your-api-key" + } +} ``` diff --git a/docs/usage/cli.mdx b/docs/usage/cli.mdx index bfb4e1523..e8f4996bd 100644 --- a/docs/usage/cli.mdx +++ b/docs/usage/cli.mdx @@ -20,7 +20,7 @@ strix --target [options] - Path to a file containing detailed instructions. + Path to a file containing detailed instructions. Mutually exclusive with `--instruction`. @@ -32,7 +32,11 @@ strix --target [options] - Path to a custom config file (JSON) to use instead of `~/.strix/cli-config.json`. + Path to a custom config file (JSON) to use instead of `~/.strix/config.json`. + + + + Override the output directory name under `strix_runs/`. ## Examples @@ -52,6 +56,9 @@ strix -n --target ./ --scan-mode quick # Multi-target white-box testing strix -t https://github.com/org/app -t https://staging.example.com + +# Custom output directory name +strix --target https://example.com --run-name example-20260325 ``` ## Exit Codes diff --git a/pyproject.toml b/pyproject.toml index 2c974f30f..faf8f6471 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,6 +36,7 @@ include = [ "LICENSE", "README.md", "strix/agents/**/*.jinja", + "strix/**/*.html", "strix/skills/**/*.md", "strix/**/*.xml", "strix/**/*.tcss" @@ -43,6 +44,7 @@ include = [ [tool.poetry.scripts] strix = "strix.interface.main:main" +strix-api = "strix.api.server:main" [tool.poetry.dependencies] python = "^3.12" @@ -59,13 +61,13 @@ cvss = "^3.2" traceloop-sdk = "^0.53.0" opentelemetry-exporter-otlp-proto-http = "^1.40.0" scrubadub = "^2.0.1" +fastapi = "*" +uvicorn = "*" # Optional LLM provider dependencies google-cloud-aiplatform = { version = ">=1.38", optional = true } # Sandbox-only dependencies (only needed inside Docker container) -fastapi = { version = "*", optional = true } -uvicorn = { version = "*", optional = true } ipython = { version = "^9.3.0", optional = true } openhands-aci = { version = "^0.3.0", optional = true } playwright = { version = "^1.48.0", optional = true } @@ -77,7 +79,7 @@ defusedxml = "^0.7.1" [tool.poetry.extras] vertex = ["google-cloud-aiplatform"] -sandbox = ["fastapi", "uvicorn", "ipython", "openhands-aci", "playwright", "gql", "pyte", "libtmux", "numpydoc"] +sandbox = ["ipython", "openhands-aci", "playwright", "gql", "pyte", "libtmux", "numpydoc"] [tool.poetry.group.dev.dependencies] # Type checking and static analysis diff --git a/strix/agents/StrixAgent/system_prompt.jinja b/strix/agents/StrixAgent/system_prompt.jinja index 2dd146662..91c2d7a81 100644 --- a/strix/agents/StrixAgent/system_prompt.jinja +++ b/strix/agents/StrixAgent/system_prompt.jinja @@ -15,6 +15,14 @@ CLI OUTPUT: - Use line breaks and indentation for structure - NEVER use "Strix" or any identifiable names/markers in HTTP requests, payloads, user-agents, or any inputs +LANGUAGE PREFERENCE: +- Prefer Simplified Chinese for all user-facing natural language output whenever doing so does not break tooling or machine-readable formats +- This includes status updates, explanations, findings summaries, remediation guidance, final conclusions, and other narrative text intended for users or operators +- When technical precision matters, prefer Chinese narration with original technical terms preserved in backticks +- Keep tool/function XML tags, tool names, parameter names, required enum values, booleans, identifiers, schema-defined literals, URLs, paths, payloads, code, CVE IDs, HTTP methods, and any machine-readable tokens exactly as required +- Do NOT translate any value that must match a tool contract, parser expectation, protocol requirement, or downstream validation rule +- If you are unsure whether a field is machine-consumed, preserve the original literal format instead of translating it + INTER-AGENT MESSAGES: - NEVER echo inter_agent_message or agent_completion_report blocks that are sent to you in your output. - Process these internally without displaying them diff --git a/strix/agents/base_agent.py b/strix/agents/base_agent.py index 74fe21ef5..3e5bee992 100644 --- a/strix/agents/base_agent.py +++ b/strix/agents/base_agent.py @@ -15,6 +15,7 @@ from strix.llm import LLM, LLMConfig, LLMRequestFailedError from strix.llm.utils import clean_content +from strix.runtime.context import is_sandbox_mode from strix.runtime import SandboxInitializationError from strix.tools import process_tool_invocations from strix.utils.resource_paths import get_strix_resource_path @@ -328,9 +329,7 @@ async def _enter_waiting_state( ) async def _initialize_sandbox_and_state(self, task: str) -> None: - import os - - sandbox_mode = os.getenv("STRIX_SANDBOX_MODE", "false").lower() == "true" + sandbox_mode = is_sandbox_mode() if not sandbox_mode and self.state.sandbox_id is None: from strix.runtime import get_runtime diff --git a/strix/api/__init__.py b/strix/api/__init__.py new file mode 100644 index 000000000..4ef6ab2db --- /dev/null +++ b/strix/api/__init__.py @@ -0,0 +1,4 @@ +from .server import create_app + + +__all__ = ["create_app"] diff --git a/strix/api/common.py b/strix/api/common.py new file mode 100644 index 000000000..31ae6492c --- /dev/null +++ b/strix/api/common.py @@ -0,0 +1,16 @@ +from __future__ import annotations + +from strix.scan import PreparedScan, ScanRequest, build_targets_info, generate_scan_id, prepare_scan + + +def generate_task_id(raw_targets: list[str]) -> str: + return generate_scan_id(raw_targets) + + +__all__ = [ + "PreparedScan", + "ScanRequest", + "build_targets_info", + "generate_task_id", + "prepare_scan", +] diff --git a/strix/api/demo/index.html b/strix/api/demo/index.html new file mode 100644 index 000000000..ba50aa484 --- /dev/null +++ b/strix/api/demo/index.html @@ -0,0 +1,1801 @@ + + + + + + Strix API Demo + + + +
+
+
+
+
Strix API Demo
+

Operate scans, artifacts, results, and live telemetry from one page.

+

+ This demo targets every public API capability: create tasks, inspect queue state, + fetch results, pull reports and artifacts, replay event history, and stream + execution output in real time. +

+
+
Health unknown
+
+ +
+ + + + Open docs + OpenAPI +
+
+ +
+ + +
+
+
+
+

Launch a task

+

+ Use one or more targets. Enter either inline instruction text or a server-side + instruction file path. +

+
+
+
+ + + + + + + +
+
+ + +
+
+ +
+
+
+

Task detail

+

+ Select a task to inspect its lifecycle and outputs. +

+
+
+
+ + + + + + + +
+
+
+
Selection
+
Choose a task from the left rail.
+
+
+
+
+
+

Artifacts

+

Files emitted into the run directory.

+
+
+
    +
  • Artifacts will appear here.
  • +
+
+
+
+ + +
+
+ + + + diff --git a/strix/api/main.py b/strix/api/main.py new file mode 100644 index 000000000..f8278316a --- /dev/null +++ b/strix/api/main.py @@ -0,0 +1,4 @@ +from strix.api.server import main + + +__all__ = ["main"] diff --git a/strix/api/models.py b/strix/api/models.py new file mode 100644 index 000000000..7b8023437 --- /dev/null +++ b/strix/api/models.py @@ -0,0 +1,93 @@ +from datetime import UTC, datetime +from enum import Enum +from typing import Any, Literal + +from pydantic import BaseModel, ConfigDict, Field, model_validator + + +def utc_now_iso() -> str: + return datetime.now(UTC).isoformat() + + +class TaskStatus(str, Enum): + QUEUED = "queued" + RUNNING = "running" + CANCELLING = "cancelling" + COMPLETED = "completed" + FAILED = "failed" + CANCELLED = "cancelled" + + +class ScanTaskRequest(BaseModel): + model_config = ConfigDict(extra="forbid") + + targets: list[str] = Field(min_length=1) + instruction: str | None = None + instruction_file: str | None = None + scan_mode: Literal["quick", "standard", "deep"] = "deep" + task_id: str | None = None + run_name: str | None = None + config_path: str | None = None + + @model_validator(mode="after") + def validate_instruction_inputs(self) -> "ScanTaskRequest": + if self.instruction and self.instruction_file: + raise ValueError("instruction and instruction_file cannot be used together") + return self + + +class ScanTaskRecord(BaseModel): + model_config = ConfigDict(extra="ignore") + + task_id: str + run_name: str | None = None + status: TaskStatus = TaskStatus.QUEUED + created_at: str = Field(default_factory=utc_now_iso) + updated_at: str = Field(default_factory=utc_now_iso) + started_at: str | None = None + finished_at: str | None = None + completed_at: str | None = None + pid: int | None = None + exit_code: int | None = None + error: str | None = None + request: ScanTaskRequest + run_dir: str + worker_log_path: str + scan_state_path: str + events_path: str + + @model_validator(mode="after") + def normalize_record(self) -> "ScanTaskRecord": + if self.finished_at and not self.completed_at: + self.completed_at = self.finished_at + if self.completed_at and not self.finished_at: + self.finished_at = self.completed_at + if not self.run_name: + self.run_name = self.request.run_name or self.task_id + return self + + +class ScanTaskResult(BaseModel): + model_config = ConfigDict(extra="ignore") + + task: ScanTaskRecord + scan_state: dict[str, Any] | None = None + artifacts: list[str] = Field(default_factory=list) + + +class TaskCollectionResponse(BaseModel): + tasks: list[ScanTaskRecord] + + +class TaskEventsResponse(BaseModel): + task_id: str + events: list[dict[str, Any]] + + +class TaskArtifactsResponse(BaseModel): + task_id: str + artifacts: list[Any] + + +CreateTaskRequest = ScanTaskRequest +TaskRecord = ScanTaskRecord diff --git a/strix/api/server.py b/strix/api/server.py new file mode 100644 index 000000000..0c8ff1dd4 --- /dev/null +++ b/strix/api/server.py @@ -0,0 +1,266 @@ +from __future__ import annotations + +import argparse +import asyncio +import json +from pathlib import Path +from typing import AsyncIterator + +import uvicorn +from fastapi import Depends, FastAPI, HTTPException, Query +from fastapi.responses import HTMLResponse, PlainTextResponse, StreamingResponse +from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer + +from strix.api.models import ScanTaskRecord, ScanTaskRequest, ScanTaskResult +from strix.api.task_manager import TaskManager +from strix.config import Config + + +security = HTTPBearer(auto_error=False) +DEMO_PAGE_PATH = Path(__file__).resolve().parent / "demo" / "index.html" + + +def create_app(task_manager: TaskManager | None = None) -> FastAPI: + manager = task_manager or TaskManager() + auth_token = Config.get_str("api_auth_token") + enable_docs = Config.get_bool("api_enable_docs") + poll_interval_ms = Config.get_int("api_stream_poll_interval_ms") or 500 + + app = FastAPI( + title="Strix API", + version="1.0.0", + docs_url="/docs" if enable_docs is not False else None, + redoc_url="/redoc" if enable_docs is not False else None, + openapi_url="/openapi.json" if enable_docs is not False else None, + ) + app.state.task_manager = manager + app.state.auth_token = auth_token + app.state.poll_interval_ms = poll_interval_ms + + def _get_task_or_404(task_id: str) -> ScanTaskRecord: + try: + task = app.state.task_manager.get_task(task_id) + except KeyError as exc: + raise HTTPException(status_code=404, detail=f"Task '{task_id}' not found") from exc + if task is None: + raise HTTPException(status_code=404, detail=f"Task '{task_id}' not found") + return task + + def _get_result_or_404(task_id: str) -> ScanTaskResult: + try: + result = app.state.task_manager.get_result(task_id) + except KeyError as exc: + raise HTTPException(status_code=404, detail=f"Task '{task_id}' not found") from exc + if result is None: + raise HTTPException(status_code=404, detail=f"Task '{task_id}' not found") + return result + + async def require_auth( + credentials: HTTPAuthorizationCredentials | None = Depends(security), + ) -> None: + expected = app.state.auth_token + if not expected: + return + if ( + credentials is None + or credentials.scheme.lower() != "bearer" + or credentials.credentials != expected + ): + raise HTTPException(status_code=401, detail="Invalid or missing API token") + + @app.get("/health") + async def health() -> dict[str, str]: + return {"status": "ok"} + + @app.get("/demo", include_in_schema=False) + async def demo() -> HTMLResponse: + return HTMLResponse(DEMO_PAGE_PATH.read_text(encoding="utf-8")) + + @app.post("/api/v1/tasks", status_code=201, dependencies=[Depends(require_auth)]) + async def create_task(request: ScanTaskRequest) -> dict[str, object]: + try: + record = app.state.task_manager.create_task(request) + except ValueError as exc: + raise HTTPException(status_code=400, detail=str(exc)) from exc + return {"task": record.model_dump(mode="python")} + + @app.get("/api/v1/tasks", dependencies=[Depends(require_auth)]) + async def list_tasks() -> dict[str, object]: + tasks = [record.model_dump(mode="python") for record in app.state.task_manager.list_tasks()] + return {"tasks": tasks} + + @app.get("/api/v1/tasks/{task_id}", dependencies=[Depends(require_auth)]) + async def get_task(task_id: str) -> dict[str, object]: + result = _get_result_or_404(task_id) + return result.model_dump(mode="python") + + @app.get("/api/v1/tasks/{task_id}/result", dependencies=[Depends(require_auth)]) + async def get_task_result(task_id: str) -> dict[str, object]: + result = _get_result_or_404(task_id) + return { + "task": result.task.model_dump(mode="python"), + "scan_state": result.scan_state, + "artifacts": result.artifacts, + } + + @app.get("/api/v1/tasks/{task_id}/results", dependencies=[Depends(require_auth)]) + async def get_task_results_alias(task_id: str) -> dict[str, object]: + return await get_task_result(task_id) + + @app.post("/api/v1/tasks/{task_id}/cancel", dependencies=[Depends(require_auth)]) + async def cancel_task(task_id: str) -> dict[str, object]: + try: + record = app.state.task_manager.cancel_task(task_id) + except KeyError as exc: + raise HTTPException(status_code=404, detail=f"Task '{task_id}' not found") from exc + if record is None: + raise HTTPException(status_code=404, detail=f"Task '{task_id}' not found") + return {"task": record.model_dump(mode="python")} + + @app.get("/api/v1/tasks/{task_id}/events", dependencies=[Depends(require_auth)]) + async def get_task_events( + task_id: str, + limit: int = Query(default=200, ge=1, le=5000), + ) -> dict[str, object]: + try: + events = app.state.task_manager.get_events(task_id, limit=limit) + except KeyError as exc: + raise HTTPException(status_code=404, detail=f"Task '{task_id}' not found") from exc + if events is None: + raise HTTPException(status_code=404, detail=f"Task '{task_id}' not found") + return {"task_id": task_id, "events": events} + + @app.get("/api/v1/tasks/{task_id}/stream", dependencies=[Depends(require_auth)]) + async def stream_task_events( + task_id: str, + follow: bool = Query(default=True), + from_offset: int = Query(default=0, ge=0), + ) -> StreamingResponse: + _get_task_or_404(task_id) + + return StreamingResponse( + _stream_events( + app.state.task_manager, + task_id, + follow=follow, + from_offset=from_offset, + poll_interval_ms=app.state.poll_interval_ms, + ), + media_type="text/event-stream", + ) + + @app.get("/api/v1/tasks/{task_id}/artifacts", dependencies=[Depends(require_auth)]) + async def get_task_artifacts(task_id: str) -> dict[str, object]: + try: + artifacts = app.state.task_manager.get_artifacts(task_id) + except KeyError as exc: + raise HTTPException(status_code=404, detail=f"Task '{task_id}' not found") from exc + if artifacts is None: + raise HTTPException(status_code=404, detail=f"Task '{task_id}' not found") + return {"task_id": task_id, "artifacts": artifacts} + + @app.get("/api/v1/tasks/{task_id}/report", dependencies=[Depends(require_auth)]) + async def get_task_report(task_id: str) -> PlainTextResponse: + try: + report = app.state.task_manager.get_report_text(task_id) + except KeyError as exc: + raise HTTPException(status_code=404, detail=f"Task '{task_id}' report not found") from exc + if report is None: + raise HTTPException(status_code=404, detail=f"Task '{task_id}' report not found") + return PlainTextResponse(report) + + return app + + +async def _stream_events( + task_manager: TaskManager, + task_id: str, + *, + follow: bool, + from_offset: int, + poll_interval_ms: int, +) -> AsyncIterator[str]: + events_path = task_manager.store.events_file(task_id) + offset = from_offset + sent_terminal = False + yield ( + "event: stream.connected\n" + f"data: {json.dumps({'task_id': task_id, 'offset': offset}, ensure_ascii=False)}\n\n" + ) + + while True: + if events_path.exists(): + with events_path.open("r", encoding="utf-8") as file_obj: + file_obj.seek(offset) + while True: + line = file_obj.readline() + if not line: + break + offset = file_obj.tell() + stripped = line.strip() + if not stripped: + continue + try: + payload = json.loads(stripped) + except json.JSONDecodeError: + payload = {"raw": stripped} + yield ( + f"event: {payload.get('event_type', 'message')}\n" + f"data: {json.dumps({'offset': offset, 'payload': payload}, ensure_ascii=False)}\n\n" + ) + + try: + task = task_manager.get_task(task_id) + except KeyError: + break + if task is None: + break + try: + status = task.status.value + except AttributeError: + status = str(task.status) + if status in {"completed", "failed", "cancelled"}: + if not sent_terminal: + payload = task.model_dump(mode="python") if hasattr(task, "model_dump") else task + yield "event: task.finished\n" f"data: {json.dumps(payload, ensure_ascii=False)}\n\n" + sent_terminal = True + if not follow: + break + await asyncio.sleep(poll_interval_ms / 1000) + if events_path.exists() and offset < events_path.stat().st_size: + continue + break + + if not follow: + break + + yield ": keep-alive\n\n" + await asyncio.sleep(poll_interval_ms / 1000) + + +def build_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser(description="Start the Strix API server") + parser.add_argument("--config", type=str, help="Path to a Strix config JSON file") + parser.add_argument("--host", type=str, help="Override server host") + parser.add_argument("--port", type=int, help="Override server port") + return parser + + +def main() -> None: + args = build_parser().parse_args() + + if args.config: + config_path = Path(args.config).expanduser().resolve() + Config.validate_file(config_path) + Config.set_config_file(config_path) + else: + Config.reload() + + host = args.host or Config.get_str("api_host") or "127.0.0.1" + port = args.port or Config.get_int("api_port") or 8787 + + uvicorn.run(create_app(), host=host, port=port, log_level="info") + + +if __name__ == "__main__": + main() diff --git a/strix/api/task_manager.py b/strix/api/task_manager.py new file mode 100644 index 000000000..13e02faf4 --- /dev/null +++ b/strix/api/task_manager.py @@ -0,0 +1,171 @@ +from __future__ import annotations + +import json +import os +import signal +import subprocess +import sys +from pathlib import Path +from typing import Any + +from strix.api.common import build_targets_info, generate_task_id +from strix.api.models import ScanTaskRecord, ScanTaskRequest, ScanTaskResult, TaskStatus +from strix.api.task_store import TERMINAL_TASK_STATUSES, TaskStore +from strix.config import Config + + +class TaskManager: + def __init__(self, store: TaskStore | None = None): + self.store = store or TaskStore() + + def list_tasks(self) -> list[ScanTaskRecord]: + return [self.store.refresh(record) for record in self.store.list()] + + def get_task(self, task_id: str) -> ScanTaskRecord: + record = self.store.load(task_id) + if record is None: + raise KeyError(task_id) + return self.store.refresh(record) + + def get_result(self, task_id: str) -> ScanTaskResult: + record = self.get_task(task_id) + + result = self.store.result(task_id) + if result is None: + return ScanTaskResult( + task=record, + scan_state=self.store.load_scan_state(task_id), + artifacts=self.get_artifacts(task_id), + ) + + result.task = record + return result + + def create_task(self, request: ScanTaskRequest) -> ScanTaskRecord: + build_targets_info(request.targets) + + max_concurrent_tasks = Config.get_int("api_max_concurrent_tasks") or 1 + active_tasks = [ + task + for task in self.list_tasks() + if task.status in {TaskStatus.QUEUED, TaskStatus.RUNNING, TaskStatus.CANCELLING} + ] + if len(active_tasks) >= max_concurrent_tasks: + raise ValueError( + "Maximum concurrent task limit reached. " + f"Current limit: {max_concurrent_tasks}" + ) + + task_id = request.task_id or request.run_name or generate_task_id(request.targets) + if self.store.load(task_id) is not None: + raise ValueError(f"Task '{task_id}' already exists") + + instruction = request.instruction + if request.instruction_file: + instruction_path = Path(request.instruction_file).expanduser().resolve() + try: + instruction = instruction_path.read_text(encoding="utf-8").strip() + except OSError as exc: + raise ValueError( + f"Failed to read instruction file '{instruction_path}': {exc}" + ) from exc + if not instruction: + raise ValueError(f"Instruction file '{instruction_path}' is empty") + + config_path = ( + Path(request.config_path).expanduser().resolve() + if request.config_path + else Config.active_config_path().resolve() + ) + Config.validate_file(config_path) + + request_data = request.model_copy( + update={ + "task_id": task_id, + "run_name": request.run_name or task_id, + "config_path": str(config_path), + "instruction": instruction, + } + ) + record = self.store.create_record(task_id, request_data) + self.store.save(record) + + worker_log_path = Path(record.worker_log_path) + worker_log_path.parent.mkdir(parents=True, exist_ok=True) + with worker_log_path.open("ab") as log_file: + process = subprocess.Popen( # noqa: S603 + [ + sys.executable, + "-m", + "strix.api.worker", + "--task-id", + task_id, + "--config", + str(config_path), + ], + cwd=Path.cwd(), + stdout=log_file, + stderr=subprocess.STDOUT, + start_new_session=True, + ) + + record.pid = process.pid + record.status = TaskStatus.QUEUED + return self.store.save(record) + + def cancel_task(self, task_id: str) -> ScanTaskRecord: + record = self.get_task(task_id) + + if record.status in TERMINAL_TASK_STATUSES: + return record + + if record.pid: + try: + if os.name == "posix": + os.killpg(record.pid, signal.SIGTERM) + else: + os.kill(record.pid, signal.SIGTERM) + except OSError: + pass + + record.status = TaskStatus.CANCELLING + record.error = record.error or "Task cancellation requested" + return self.store.save(record) + + def get_scan_state(self, task_id: str) -> dict[str, Any]: + self.get_task(task_id) + return self.store.load_scan_state(task_id) or {} + + def get_events(self, task_id: str, limit: int = 200) -> list[dict[str, Any]]: + record = self.get_task(task_id) + + events_path = Path(record.events_path) + if not events_path.exists(): + return [] + + events: list[dict[str, Any]] = [] + with events_path.open("r", encoding="utf-8") as file_obj: + for line in file_obj: + stripped = line.strip() + if not stripped: + continue + try: + events.append(json.loads(stripped)) + except ValueError: + continue + + return events[-limit:] + + def get_report_text(self, task_id: str) -> str | None: + record = self.get_task(task_id) + report_path = Path(record.run_dir) / "penetration_test_report.md" + if not report_path.exists(): + return None + return report_path.read_text(encoding="utf-8") + + def get_artifacts(self, task_id: str) -> list[str]: + record = self.get_task(task_id) + return sorted(str(path) for path in Path(record.run_dir).glob("**/*") if path.is_file()) + + +ScanTaskManager = TaskManager diff --git a/strix/api/task_store.py b/strix/api/task_store.py new file mode 100644 index 000000000..bf140ce2d --- /dev/null +++ b/strix/api/task_store.py @@ -0,0 +1,175 @@ +import json +import os +import subprocess +from pathlib import Path +from typing import Any + +from strix.api.models import ScanTaskRecord, ScanTaskResult, ScanTaskRequest, TaskStatus, utc_now_iso + + +ACTIVE_TASK_STATUSES = {TaskStatus.QUEUED, TaskStatus.RUNNING, TaskStatus.CANCELLING} +TERMINAL_TASK_STATUSES = {TaskStatus.COMPLETED, TaskStatus.FAILED, TaskStatus.CANCELLED} + + +class TaskStore: + def __init__(self, base_dir: Path | None = None): + self.base_dir = base_dir or (Path.cwd() / "strix_runs") + self.base_dir.mkdir(parents=True, exist_ok=True) + + def run_dir(self, task_id: str) -> Path: + return self.base_dir / task_id + + def task_file(self, task_id: str) -> Path: + return self.run_dir(task_id) / "task_state.json" + + def events_file(self, task_id: str) -> Path: + return self.run_dir(task_id) / "events.jsonl" + + def scan_state_file(self, task_id: str) -> Path: + return self.run_dir(task_id) / "scan_state.json" + + def worker_log_file(self, task_id: str) -> Path: + return self.run_dir(task_id) / "worker.log" + + def create_record(self, task_id: str, request: ScanTaskRequest) -> ScanTaskRecord: + run_dir = self.run_dir(task_id) + run_dir.mkdir(parents=True, exist_ok=True) + return ScanTaskRecord( + task_id=task_id, + request=request, + run_dir=str(run_dir), + worker_log_path=str(self.worker_log_file(task_id)), + scan_state_path=str(self.scan_state_file(task_id)), + events_path=str(self.events_file(task_id)), + ) + + def save(self, record: ScanTaskRecord) -> ScanTaskRecord: + record.updated_at = utc_now_iso() + task_path = self.task_file(record.task_id) + task_path.parent.mkdir(parents=True, exist_ok=True) + with task_path.open("w", encoding="utf-8") as file_obj: + json.dump(record.model_dump(mode="python"), file_obj, indent=2, ensure_ascii=False) + return record + + def load(self, task_id: str) -> ScanTaskRecord | None: + task_path = self.task_file(task_id) + if not task_path.exists(): + return None + with task_path.open("r", encoding="utf-8") as file_obj: + data = json.load(file_obj) + return ScanTaskRecord.model_validate(data) + + def list(self) -> list[ScanTaskRecord]: + tasks: list[ScanTaskRecord] = [] + for task_file in sorted(self.base_dir.glob("*/task_state.json")): + try: + with task_file.open("r", encoding="utf-8") as file_obj: + tasks.append(ScanTaskRecord.model_validate(json.load(file_obj))) + except (OSError, ValueError, json.JSONDecodeError): + continue + + tasks.sort(key=lambda item: item.created_at, reverse=True) + return tasks + + def load_scan_state(self, task_id: str) -> dict[str, Any] | None: + scan_state_path = self.scan_state_file(task_id) + if not scan_state_path.exists(): + return None + with scan_state_path.open("r", encoding="utf-8") as file_obj: + return json.load(file_obj) + + def result(self, task_id: str) -> ScanTaskResult | None: + record = self.load(task_id) + if record is None: + return None + + artifacts = sorted(str(path) for path in self.run_dir(task_id).glob("**/*") if path.is_file()) + return ScanTaskResult( + task=record, + scan_state=self.load_scan_state(task_id), + artifacts=artifacts, + ) + + def refresh(self, record: ScanTaskRecord) -> ScanTaskRecord: + if record.status in TERMINAL_TASK_STATUSES: + return record + + scan_state = self.load_scan_state(record.task_id) + if scan_state and (scan_state.get("run_metadata") or {}).get("status") == "completed": + record.status = TaskStatus.COMPLETED + record.finished_at = (scan_state.get("run_metadata") or {}).get("end_time") + return self.save(record) + + if record.pid: + exit_code = _poll_process_exit_code(record.pid) + else: + exit_code = None + + if exit_code is not None: + record.finished_at = record.finished_at or utc_now_iso() + record.exit_code = exit_code + if record.status == TaskStatus.CANCELLING: + record.status = TaskStatus.CANCELLED + record.error = record.error or "Task cancelled" + else: + record.status = TaskStatus.FAILED + if exit_code == 0: + record.error = record.error or "Worker exited without producing scan output" + else: + record.error = record.error or f"Worker exited with code {exit_code}" + return self.save(record) + + return record + + +def _process_exists(pid: int) -> bool: + try: + os.kill(pid, 0) + except OSError: + return False + return True + + +def _poll_process_exit_code(pid: int) -> int | None: + if os.name == "posix": + try: + waited_pid, status = os.waitpid(pid, os.WNOHANG) + except ChildProcessError: + waited_pid, status = 0, 0 + except OSError: + return 1 + else: + if waited_pid == pid: + if os.WIFEXITED(status): + return os.WEXITSTATUS(status) + if os.WIFSIGNALED(status): + return 128 + os.WTERMSIG(status) + return 1 + + if _is_zombie_process(pid): + return 1 + + if not _process_exists(pid): + return 1 + + return None + + +def _is_zombie_process(pid: int) -> bool: + if os.name != "posix": + return False + + try: + result = subprocess.run( # noqa: S603 + ["ps", "-o", "stat=", "-p", str(pid)], + capture_output=True, + check=False, + text=True, + ) + except OSError: + return False + + if result.returncode != 0: + return False + + return "Z" in result.stdout.strip().upper() diff --git a/strix/api/worker.py b/strix/api/worker.py new file mode 100644 index 000000000..11e08bb77 --- /dev/null +++ b/strix/api/worker.py @@ -0,0 +1,167 @@ +from __future__ import annotations + +import argparse +import asyncio +import signal +import traceback +from pathlib import Path +from typing import Any + +from strix.agents.StrixAgent import StrixAgent +from strix.api.common import ScanRequest, prepare_scan +from strix.api.models import TaskStatus, utc_now_iso +from strix.api.task_store import TaskStore +from strix.config import Config +from strix.interface.main import ( + check_docker_installed, + pull_docker_image, + validate_environment, + warm_up_llm, +) +from strix.runtime import cleanup_runtime +from strix.runtime.context import configure_runtime_context +from strix.telemetry import posthog +from strix.telemetry.tracer import Tracer, get_global_tracer, set_global_tracer +from strix.tools.agents_graph.agents_graph_actions import reset_agent_graph_state + + +_CURRENT_AGENT: StrixAgent | None = None +_CANCEL_REQUESTED = False + + +def _handle_termination(_signum: int, _frame: Any) -> None: + global _CANCEL_REQUESTED + _CANCEL_REQUESTED = True + if _CURRENT_AGENT is not None: + _CURRENT_AGENT.cancel_current_execution() + + +def build_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser(description="Run a Strix scan worker") + parser.add_argument("--task-id", required=True) + parser.add_argument("--config", required=True) + return parser + + +async def _run_worker(task_id: str, config_path: Path) -> int: + global _CANCEL_REQUESTED, _CURRENT_AGENT + + _CANCEL_REQUESTED = False + + store = TaskStore() + record = store.load(task_id) + if record is None: + raise ValueError(f"Task '{task_id}' not found") + + Config.set_config_file(config_path) + configure_runtime_context( + sandbox_mode=False, + caido_api_token=Config.get_str("caido_api_token"), + ) + + check_docker_installed() + pull_docker_image() + validate_environment() + await warm_up_llm() + + prepared = prepare_scan( + ScanRequest( + targets=record.request.targets, + instruction=record.request.instruction or "", + scan_mode=record.request.scan_mode, + run_name=task_id, + ) + ) + + record.status = TaskStatus.RUNNING + record.started_at = record.started_at or utc_now_iso() + store.save(record) + + posthog.start( + model=Config.get_str("strix_llm"), + scan_mode=prepared.request.scan_mode, + is_whitebox=bool(prepared.local_sources), + interactive=False, + has_instructions=bool(prepared.request.instruction), + ) + + exit_code = 1 + exit_reason = "finished" + try: + reset_agent_graph_state() + tracer = Tracer(task_id) + set_global_tracer(tracer) + tracer.set_scan_config(prepared.build_scan_config()) + + _CURRENT_AGENT = StrixAgent(prepared.build_agent_config(interactive=False)) + result = await _CURRENT_AGENT.execute_scan(prepared.build_scan_config()) + + if _CANCEL_REQUESTED: + record.status = TaskStatus.CANCELLED + record.error = "Task cancelled" + exit_reason = "cancelled" + exit_code = 130 + return exit_code + + if isinstance(result, dict) and not result.get("success", True): + record.status = TaskStatus.FAILED + record.error = result.get("error", "Scan failed") + exit_reason = "failed" + exit_code = 1 + return exit_code + + record.status = TaskStatus.COMPLETED + record.error = None + exit_code = 0 + return exit_code + except asyncio.CancelledError: + record.status = TaskStatus.CANCELLED + record.error = "Task cancelled" + exit_reason = "cancelled" + exit_code = 130 + return exit_code + except SystemExit as exc: + record.status = TaskStatus.FAILED + record.error = f"Worker bootstrap exited with code {exc.code}" + exit_reason = "failed" + exit_code = int(exc.code) if isinstance(exc.code, int) else 1 + return exit_code + except Exception as exc: # noqa: BLE001 + record.status = TaskStatus.FAILED + record.error = "".join(traceback.format_exception_only(type(exc), exc)).strip() + exit_reason = "failed" + exit_code = 1 + return exit_code + finally: + _CURRENT_AGENT = None + if record.status in {TaskStatus.COMPLETED, TaskStatus.FAILED, TaskStatus.CANCELLED}: + record.finished_at = utc_now_iso() + record.exit_code = exit_code + store.save(record) + + tracer = get_global_tracer() + if tracer: + if record.finished_at: + tracer.end_time = record.finished_at + tracer.run_metadata["end_time"] = record.finished_at + tracer.run_metadata["status"] = record.status.value + tracer.save_run_data(mark_complete=record.status == TaskStatus.COMPLETED) + posthog.end(tracer, exit_reason=exit_reason) + set_global_tracer(None) + reset_agent_graph_state() + cleanup_runtime() + + +def main() -> None: + args = build_parser().parse_args() + config_path = Path(args.config).expanduser().resolve() + + signal.signal(signal.SIGTERM, _handle_termination) + signal.signal(signal.SIGINT, _handle_termination) + + exit_code = asyncio.run(_run_worker(args.task_id, config_path)) + raise SystemExit(exit_code) + + +if __name__ == "__main__": + main() diff --git a/strix/config/__init__.py b/strix/config/__init__.py index 328c13898..825b78ae4 100644 --- a/strix/config/__init__.py +++ b/strix/config/__init__.py @@ -1,4 +1,5 @@ from strix.config.config import ( + AppConfig, Config, apply_saved_config, save_current_config, @@ -6,6 +7,7 @@ __all__ = [ + "AppConfig", "Config", "apply_saved_config", "save_current_config", diff --git a/strix/config/config.py b/strix/config/config.py index 782101ddb..f842f56a1 100644 --- a/strix/config/config.py +++ b/strix/config/config.py @@ -1,182 +1,419 @@ import contextlib import json -import os +from copy import deepcopy from pathlib import Path from typing import Any +from pydantic import BaseModel, ConfigDict, Field, ValidationError + STRIX_API_BASE = "https://models.strix.ai/api/v1" +def _parse_bool(value: Any) -> bool: + if isinstance(value, bool): + return value + if value is None: + return False + return str(value).strip().lower() in {"1", "true", "yes", "on"} + + +def _get_nested(config: dict[str, Any], path: str) -> Any: + current: Any = config + for part in path.split("."): + if not isinstance(current, dict): + return None + current = current.get(part) + return current + + +def _set_nested(config: dict[str, Any], path: str, value: Any) -> None: + parts = path.split(".") + current = config + for part in parts[:-1]: + next_value = current.get(part) + if not isinstance(next_value, dict): + next_value = {} + current[part] = next_value + current = next_value + current[parts[-1]] = value + + +def _deep_merge(base: dict[str, Any], override: dict[str, Any]) -> dict[str, Any]: + result = deepcopy(base) + for key, value in override.items(): + if isinstance(value, dict) and isinstance(result.get(key), dict): + result[key] = _deep_merge(result[key], value) + else: + result[key] = value + return result + + +class LLMSettings(BaseModel): + model_config = ConfigDict(extra="ignore") + + model: str | None = None + api_key: str | None = None + api_base: str | None = None + openai_compatible_provider: str | None = None + openai_api_base: str | None = None + litellm_base_url: str | None = None + ollama_api_base: str | None = None + reasoning_effort: str = "high" + max_retries: int = 5 + memory_compressor_timeout: int = 30 + timeout: int = 300 + + +class FeatureSettings(BaseModel): + model_config = ConfigDict(extra="ignore") + + perplexity_api_key: str | None = None + disable_browser: bool = False + + +class RuntimeSettings(BaseModel): + model_config = ConfigDict(extra="ignore") + + image: str = "ghcr.io/usestrix/strix-sandbox:0.1.13" + backend: str = "docker" + sandbox_execution_timeout: int = 120 + sandbox_connect_timeout: int = 10 + sandbox_mode: bool = False + docker_host: str | None = None + caido_api_token: str | None = None + + +class TelemetrySettings(BaseModel): + model_config = ConfigDict(extra="ignore") + + enabled: bool = True + otel_enabled: bool | None = None + posthog_enabled: bool | None = None + traceloop_base_url: str | None = None + traceloop_api_key: str | None = None + traceloop_headers: str | None = None + + +class APISettings(BaseModel): + model_config = ConfigDict(extra="ignore") + + host: str = "127.0.0.1" + port: int = 8787 + auth_token: str | None = None + max_concurrent_tasks: int = 1 + enable_docs: bool = True + stream_poll_interval_ms: int = 500 + + +class AppConfig(BaseModel): + model_config = ConfigDict(extra="ignore") + + llm: LLMSettings = Field(default_factory=LLMSettings) + features: FeatureSettings = Field(default_factory=FeatureSettings) + runtime: RuntimeSettings = Field(default_factory=RuntimeSettings) + telemetry: TelemetrySettings = Field(default_factory=TelemetrySettings) + api: APISettings = Field(default_factory=APISettings) + + class Config: - """Configuration Manager for Strix.""" - - # LLM Configuration - strix_llm = None - llm_api_key = None - llm_api_base = None - openai_api_base = None - litellm_base_url = None - ollama_api_base = None - strix_reasoning_effort = "high" - strix_llm_max_retries = "5" - strix_memory_compressor_timeout = "30" - llm_timeout = "300" - _LLM_CANONICAL_NAMES = ( - "strix_llm", - "llm_api_key", - "llm_api_base", - "openai_api_base", - "litellm_base_url", - "ollama_api_base", - "strix_reasoning_effort", - "strix_llm_max_retries", - "strix_memory_compressor_timeout", - "llm_timeout", - ) - - # Tool & Feature Configuration - perplexity_api_key = None - strix_disable_browser = "false" - - # Runtime Configuration - strix_image = "ghcr.io/usestrix/strix-sandbox:0.1.13" - strix_runtime_backend = "docker" - strix_sandbox_execution_timeout = "120" - strix_sandbox_connect_timeout = "10" - - # Telemetry - strix_telemetry = "1" - strix_otel_telemetry = None - strix_posthog_telemetry = None - traceloop_base_url = None - traceloop_api_key = None - traceloop_headers = None - - # Config file override (set via --config CLI arg) + """Structured configuration manager backed by JSON files only.""" + _config_file_override: Path | None = None + _cached_config: AppConfig | None = None + + _LEGACY_ENV_TO_PATH = { + "STRIX_LLM": "llm.model", + "LLM_API_KEY": "llm.api_key", + "LLM_API_BASE": "llm.api_base", + "STRIX_OPENAI_COMPATIBLE_PROVIDER": "llm.openai_compatible_provider", + "OPENAI_API_BASE": "llm.openai_api_base", + "LITELLM_BASE_URL": "llm.litellm_base_url", + "OLLAMA_API_BASE": "llm.ollama_api_base", + "STRIX_REASONING_EFFORT": "llm.reasoning_effort", + "STRIX_LLM_MAX_RETRIES": "llm.max_retries", + "STRIX_MEMORY_COMPRESSOR_TIMEOUT": "llm.memory_compressor_timeout", + "LLM_TIMEOUT": "llm.timeout", + "PERPLEXITY_API_KEY": "features.perplexity_api_key", + "STRIX_DISABLE_BROWSER": "features.disable_browser", + "STRIX_IMAGE": "runtime.image", + "STRIX_RUNTIME_BACKEND": "runtime.backend", + "STRIX_SANDBOX_EXECUTION_TIMEOUT": "runtime.sandbox_execution_timeout", + "STRIX_SANDBOX_CONNECT_TIMEOUT": "runtime.sandbox_connect_timeout", + "STRIX_SANDBOX_MODE": "runtime.sandbox_mode", + "DOCKER_HOST": "runtime.docker_host", + "CAIDO_API_TOKEN": "runtime.caido_api_token", + "STRIX_TELEMETRY": "telemetry.enabled", + "STRIX_OTEL_TELEMETRY": "telemetry.otel_enabled", + "STRIX_POSTHOG_TELEMETRY": "telemetry.posthog_enabled", + "TRACELOOP_BASE_URL": "telemetry.traceloop_base_url", + "TRACELOOP_API_KEY": "telemetry.traceloop_api_key", + "TRACELOOP_HEADERS": "telemetry.traceloop_headers", + "STRIX_API_HOST": "api.host", + "STRIX_API_PORT": "api.port", + "STRIX_API_AUTH_TOKEN": "api.auth_token", + "STRIX_API_MAX_CONCURRENT_TASKS": "api.max_concurrent_tasks", + "STRIX_API_ENABLE_DOCS": "api.enable_docs", + "STRIX_API_STREAM_POLL_INTERVAL_MS": "api.stream_poll_interval_ms", + } + _CONFIG_KEY_PATHS = { + "strix_llm": "llm.model", + "llm_api_key": "llm.api_key", + "llm_api_base": "llm.api_base", + "llm_openai_compatible_provider": "llm.openai_compatible_provider", + "openai_api_base": "llm.openai_api_base", + "litellm_base_url": "llm.litellm_base_url", + "ollama_api_base": "llm.ollama_api_base", + "strix_reasoning_effort": "llm.reasoning_effort", + "strix_llm_max_retries": "llm.max_retries", + "strix_memory_compressor_timeout": "llm.memory_compressor_timeout", + "llm_timeout": "llm.timeout", + "perplexity_api_key": "features.perplexity_api_key", + "strix_disable_browser": "features.disable_browser", + "strix_image": "runtime.image", + "strix_runtime_backend": "runtime.backend", + "strix_sandbox_execution_timeout": "runtime.sandbox_execution_timeout", + "strix_sandbox_connect_timeout": "runtime.sandbox_connect_timeout", + "strix_sandbox_mode": "runtime.sandbox_mode", + "docker_host": "runtime.docker_host", + "caido_api_token": "runtime.caido_api_token", + "strix_telemetry": "telemetry.enabled", + "strix_otel_telemetry": "telemetry.otel_enabled", + "strix_posthog_telemetry": "telemetry.posthog_enabled", + "traceloop_base_url": "telemetry.traceloop_base_url", + "traceloop_api_key": "telemetry.traceloop_api_key", + "traceloop_headers": "telemetry.traceloop_headers", + "api_host": "api.host", + "api_port": "api.port", + "api_auth_token": "api.auth_token", + "api_max_concurrent_tasks": "api.max_concurrent_tasks", + "api_enable_docs": "api.enable_docs", + "api_stream_poll_interval_ms": "api.stream_poll_interval_ms", + } + _BOOL_PATHS = { + "features.disable_browser", + "runtime.sandbox_mode", + "telemetry.enabled", + "telemetry.otel_enabled", + "telemetry.posthog_enabled", + "api.enable_docs", + } + _INT_PATHS = { + "llm.max_retries", + "llm.memory_compressor_timeout", + "llm.timeout", + "runtime.sandbox_execution_timeout", + "runtime.sandbox_connect_timeout", + "api.port", + "api.max_concurrent_tasks", + "api.stream_poll_interval_ms", + } @classmethod - def _tracked_names(cls) -> list[str]: - return [ - k - for k, v in vars(cls).items() - if not k.startswith("_") and k[0].islower() and (v is None or isinstance(v, str)) - ] + def tracked_vars(cls) -> list[str]: + return sorted(cls._LEGACY_ENV_TO_PATH.keys()) @classmethod - def tracked_vars(cls) -> list[str]: - return [name.upper() for name in cls._tracked_names()] + def config_dir(cls) -> Path: + return Path.home() / ".strix" @classmethod - def _llm_env_vars(cls) -> set[str]: - return {name.upper() for name in cls._LLM_CANONICAL_NAMES} + def legacy_config_file(cls) -> Path: + return cls.config_dir() / "cli-config.json" @classmethod - def _llm_env_changed(cls, saved_env: dict[str, Any]) -> bool: - for var_name in cls._llm_env_vars(): - current = os.getenv(var_name) - if current is None: - continue - if saved_env.get(var_name) != current: - return True - return False + def config_file(cls) -> Path: + if cls._config_file_override is not None: + return cls._config_file_override + return cls.config_dir() / "config.json" @classmethod - def get(cls, name: str) -> str | None: - env_name = name.upper() - default = getattr(cls, name, None) - return os.getenv(env_name, default) + def active_config_path(cls) -> Path: + if cls._config_file_override is not None: + return cls._config_file_override + + primary = cls.config_file() + if primary.exists(): + return primary + + legacy = cls.legacy_config_file() + if legacy.exists(): + return legacy + + return primary @classmethod - def config_dir(cls) -> Path: - return Path.home() / ".strix" + def set_config_file(cls, path: Path) -> None: + cls._config_file_override = path + cls.reload() @classmethod - def config_file(cls) -> Path: - if cls._config_file_override is not None: - return cls._config_file_override - return cls.config_dir() / "cli-config.json" + def reload(cls) -> AppConfig: + cls._cached_config = None + return cls.load_model() @classmethod - def load(cls) -> dict[str, Any]: - path = cls.config_file() + def _read_json_file(cls, path: Path) -> dict[str, Any]: if not path.exists(): return {} + with path.open("r", encoding="utf-8") as file_obj: + data = json.load(file_obj) + if not isinstance(data, dict): + raise ValueError("Config file must contain a JSON object") + return data + + @classmethod + def _coerce_legacy_value(cls, path: str, value: Any) -> Any: + if value in ("", None): + return None + if path in cls._BOOL_PATHS: + return _parse_bool(value) + if path in cls._INT_PATHS: + return int(value) + return value + + @classmethod + def _normalize_legacy_env(cls, env_data: dict[str, Any]) -> dict[str, Any]: + normalized: dict[str, Any] = {} + for raw_key, raw_value in env_data.items(): + if not isinstance(raw_key, str): + continue + path = cls._LEGACY_ENV_TO_PATH.get(raw_key.upper()) + if not path: + continue + value = cls._coerce_legacy_value(path, raw_value) + if value is None: + continue + _set_nested(normalized, path, value) + return normalized + + @classmethod + def _normalize_dict(cls, data: dict[str, Any]) -> dict[str, Any]: + normalized = AppConfig().model_dump(mode="python") + body = dict(data) + + legacy_env = body.pop("env", None) + if isinstance(legacy_env, dict): + normalized = _deep_merge(normalized, cls._normalize_legacy_env(legacy_env)) + + normalized = _deep_merge(normalized, body) + return normalized + + @classmethod + def _load_from_file(cls, path: Path) -> AppConfig: try: - with path.open("r", encoding="utf-8") as f: - data: dict[str, Any] = json.load(f) - return data - except (json.JSONDecodeError, OSError): - return {} + raw_data = cls._read_json_file(path) + except FileNotFoundError: + return AppConfig() + except json.JSONDecodeError as exc: + raise ValueError(f"Invalid JSON in config file: {exc}") from exc + + normalized = cls._normalize_dict(raw_data) + try: + return AppConfig.model_validate(normalized) + except ValidationError as exc: + raise ValueError(f"Invalid config structure: {exc}") from exc + + @classmethod + def validate_file(cls, path: Path) -> Path: + cls._load_from_file(path) + return path + + @classmethod + def load_model(cls) -> AppConfig: + if cls._cached_config is not None: + return cls._cached_config + + primary = cls.config_file() + if primary.exists(): + cls._cached_config = cls._load_from_file(primary) + return cls._cached_config + + legacy = cls.legacy_config_file() + if cls._config_file_override is None and legacy.exists(): + cls._cached_config = cls._load_from_file(legacy) + return cls._cached_config + + cls._cached_config = AppConfig() + return cls._cached_config + + @classmethod + def load(cls) -> dict[str, Any]: + return cls.load_model().model_dump(mode="python") + + @classmethod + def get(cls, name: str) -> Any: + path = cls._CONFIG_KEY_PATHS.get(name, name) + return _get_nested(cls.load(), path) + + @classmethod + def get_str(cls, name: str) -> str | None: + value = cls.get(name) + if value is None: + return None + if isinstance(value, bool): + return "true" if value else "false" + return str(value) + + @classmethod + def get_int(cls, name: str) -> int | None: + value = cls.get(name) + if value is None: + return None + return int(value) + + @classmethod + def get_bool(cls, name: str) -> bool | None: + value = cls.get(name) + if value is None: + return None + return _parse_bool(value) + + @classmethod + def _legacy_snapshot(cls) -> dict[str, str]: + config = cls.load() + snapshot: dict[str, str] = {} + for env_key, path in cls._LEGACY_ENV_TO_PATH.items(): + value = _get_nested(config, path) + if value is None: + continue + if isinstance(value, bool): + snapshot[env_key] = "true" if value else "false" + else: + snapshot[env_key] = str(value) + return snapshot @classmethod def save(cls, config: dict[str, Any]) -> bool: try: - cls.config_dir().mkdir(parents=True, exist_ok=True) - config_path = cls.config_dir() / "cli-config.json" - with config_path.open("w", encoding="utf-8") as f: - json.dump(config, f, indent=2) - except OSError: + normalized = cls._normalize_dict(config) + validated = AppConfig.model_validate(normalized) + path = cls.config_file() + path.parent.mkdir(parents=True, exist_ok=True) + with path.open("w", encoding="utf-8") as file_obj: + json.dump(validated.model_dump(mode="python"), file_obj, indent=2) + except (OSError, ValidationError, ValueError): return False + with contextlib.suppress(OSError): - config_path.chmod(0o600) # may fail on Windows - return True + path.chmod(0o600) - @classmethod - def apply_saved(cls, force: bool = False) -> dict[str, str]: - saved = cls.load() - env_vars = saved.get("env", {}) - if not isinstance(env_vars, dict): - env_vars = {} - cleared_vars = { - var_name - for var_name in cls.tracked_vars() - if var_name in os.environ and os.environ.get(var_name) == "" - } - if cleared_vars: - for var_name in cleared_vars: - env_vars.pop(var_name, None) - if cls._config_file_override is None: - cls.save({"env": env_vars}) - if cls._llm_env_changed(env_vars): - for var_name in cls._llm_env_vars(): - env_vars.pop(var_name, None) - if cls._config_file_override is None: - cls.save({"env": env_vars}) - applied = {} - - for var_name, var_value in env_vars.items(): - if var_name in cls.tracked_vars() and (force or var_name not in os.environ): - os.environ[var_name] = var_value - applied[var_name] = var_value - - return applied + cls._cached_config = validated + return True @classmethod def capture_current(cls) -> dict[str, Any]: - env_vars = {} - for var_name in cls.tracked_vars(): - value = os.getenv(var_name) - if value: - env_vars[var_name] = value - return {"env": env_vars} + return cls.load() @classmethod def save_current(cls) -> bool: - existing = cls.load().get("env", {}) - merged = dict(existing) + return cls.save(cls.load()) - for var_name in cls.tracked_vars(): - value = os.getenv(var_name) - if value is None: - pass - elif value == "": - merged.pop(var_name, None) - else: - merged[var_name] = value - - return cls.save({"env": merged}) + @classmethod + def apply_saved(cls, force: bool = False) -> dict[str, str]: + del force + cls.reload() + return cls._legacy_snapshot() def apply_saved_config(force: bool = False) -> dict[str, str]: @@ -187,29 +424,22 @@ def save_current_config() -> bool: return Config.save_current() -def resolve_llm_config() -> tuple[str | None, str | None, str | None]: - """Resolve LLM model, api_key, and api_base based on STRIX_LLM prefix. - - Returns: - tuple: (model_name, api_key, api_base) - - model_name: Original model name (strix/ prefix preserved for display) - - api_key: LLM API key - - api_base: API base URL (auto-set to STRIX_API_BASE for strix/ models) - """ - model = Config.get("strix_llm") +def resolve_llm_config() -> tuple[str | None, str | None, str | None, str | None]: + model = Config.get_str("strix_llm") if not model: - return None, None, None + return None, None, None, None - api_key = Config.get("llm_api_key") + api_key = Config.get_str("llm_api_key") + openai_compatible_provider = Config.get_str("llm_openai_compatible_provider") if model.startswith("strix/"): api_base: str | None = STRIX_API_BASE else: api_base = ( - Config.get("llm_api_base") - or Config.get("openai_api_base") - or Config.get("litellm_base_url") - or Config.get("ollama_api_base") + Config.get_str("llm_api_base") + or Config.get_str("openai_api_base") + or Config.get_str("litellm_base_url") + or Config.get_str("ollama_api_base") ) - return model, api_key, api_base + return model, api_key, api_base, openai_compatible_provider diff --git a/strix/interface/__init__.py b/strix/interface/__init__.py index b0f97407c..3427b731a 100644 --- a/strix/interface/__init__.py +++ b/strix/interface/__init__.py @@ -1,4 +1,10 @@ -from .main import main +from typing import Any + + +def main(*args: Any, **kwargs: Any) -> Any: + from .main import main as interface_main + + return interface_main(*args, **kwargs) __all__ = ["main"] diff --git a/strix/interface/cli.py b/strix/interface/cli.py index 430eebcf3..ded0cc595 100644 --- a/strix/interface/cli.py +++ b/strix/interface/cli.py @@ -11,7 +11,7 @@ from rich.text import Text from strix.agents.StrixAgent import StrixAgent -from strix.llm.config import LLMConfig +from strix.scan import PreparedScan, ScanRequest, build_agent_config, build_scan_config from strix.telemetry.tracer import Tracer, set_global_tracer from .utils import ( @@ -66,22 +66,19 @@ async def run_cli(args: Any) -> None: # noqa: PLR0915 console.print() scan_mode = getattr(args, "scan_mode", "deep") - - scan_config = { - "scan_id": args.run_name, - "targets": args.targets_info, - "user_instructions": args.instruction or "", - "run_name": args.run_name, - } - - llm_config = LLMConfig(scan_mode=scan_mode) - agent_config = { - "llm_config": llm_config, - "max_iterations": 300, - } - - if getattr(args, "local_sources", None): - agent_config["local_sources"] = args.local_sources + prepared_scan = PreparedScan( + request=ScanRequest( + targets=[], + instruction=args.instruction or "", + scan_mode=scan_mode, + run_name=args.run_name, + ), + run_name=args.run_name, + targets_info=args.targets_info, + local_sources=getattr(args, "local_sources", None) or [], + ) + scan_config = build_scan_config(prepared_scan) + agent_config = build_agent_config(prepared_scan, interactive=False) tracer = Tracer(args.run_name) tracer.set_scan_config(scan_config) diff --git a/strix/interface/main.py b/strix/interface/main.py index 56873f1fd..0316a1786 100644 --- a/strix/interface/main.py +++ b/strix/interface/main.py @@ -17,7 +17,7 @@ from rich.panel import Panel from rich.text import Text -from strix.config import Config, apply_saved_config, save_current_config +from strix.config import Config, apply_saved_config from strix.config.config import resolve_llm_config from strix.llm.utils import resolve_strix_model @@ -40,6 +40,7 @@ validate_config_file, validate_llm_response, ) +from strix.runtime.context import configure_runtime_context # noqa: E402 from strix.runtime.docker_runtime import HOST_GATEWAY_HOSTNAME # noqa: E402 from strix.telemetry import posthog # noqa: E402 from strix.telemetry.tracer import get_global_tracer # noqa: E402 @@ -50,121 +51,106 @@ def validate_environment() -> None: # noqa: PLR0912, PLR0915 console = Console() - missing_required_vars = [] - missing_optional_vars = [] + missing_required_fields: list[tuple[str, str]] = [] + missing_optional_fields: list[tuple[str, str]] = [] + config_path = Config.config_file() - strix_llm = Config.get("strix_llm") + strix_llm = Config.get_str("strix_llm") uses_strix_models = strix_llm and strix_llm.startswith("strix/") if not strix_llm: - missing_required_vars.append("STRIX_LLM") + missing_required_fields.append( + ("llm.model", "Model name to use with LiteLLM (for example `openai/gpt-5.4`)"), + ) has_base_url = uses_strix_models or any( [ - Config.get("llm_api_base"), - Config.get("openai_api_base"), - Config.get("litellm_base_url"), - Config.get("ollama_api_base"), + Config.get_str("llm_api_base"), + Config.get_str("openai_api_base"), + Config.get_str("litellm_base_url"), + Config.get_str("ollama_api_base"), ] ) - if not Config.get("llm_api_key"): - missing_optional_vars.append("LLM_API_KEY") + if not Config.get_str("llm_api_key"): + missing_optional_fields.append( + ( + "llm.api_key", + "API key for the LLM provider (not needed for some local or cloud providers)", + ), + ) if not has_base_url: - missing_optional_vars.append("LLM_API_BASE") + missing_optional_fields.append( + ( + "llm.api_base", + "Custom API base when using local or self-hosted providers such as Ollama", + ), + ) - if not Config.get("perplexity_api_key"): - missing_optional_vars.append("PERPLEXITY_API_KEY") + if not Config.get_str("perplexity_api_key"): + missing_optional_fields.append( + ( + "features.perplexity_api_key", + "Perplexity API key for live web research", + ), + ) - if not Config.get("strix_reasoning_effort"): - missing_optional_vars.append("STRIX_REASONING_EFFORT") + if not Config.get_str("strix_reasoning_effort"): + missing_optional_fields.append( + ( + "llm.reasoning_effort", + "Reasoning effort level: none, minimal, low, medium, high, xhigh", + ), + ) - if missing_required_vars: + if missing_required_fields: error_text = Text() - error_text.append("MISSING REQUIRED ENVIRONMENT VARIABLES", style="bold red") + error_text.append("MISSING REQUIRED CONFIGURATION", style="bold red") + error_text.append("\n\n", style="white") + error_text.append("Config file", style="dim") + error_text.append(" ") + error_text.append(str(config_path), style="bold white") error_text.append("\n\n", style="white") - for var in missing_required_vars: - error_text.append(f"• {var}", style="bold yellow") - error_text.append(" is not set\n", style="white") + for field_name, _ in missing_required_fields: + error_text.append(f"• {field_name}", style="bold yellow") + error_text.append(" is missing\n", style="white") - if missing_optional_vars: - error_text.append("\nOptional environment variables:\n", style="dim white") - for var in missing_optional_vars: - error_text.append(f"• {var}", style="dim yellow") + if missing_optional_fields: + error_text.append("\nOptional config fields:\n", style="dim white") + for field_name, _ in missing_optional_fields: + error_text.append(f"• {field_name}", style="dim yellow") error_text.append(" is not set\n", style="dim white") - error_text.append("\nRequired environment variables:\n", style="white") - for var in missing_required_vars: - if var == "STRIX_LLM": + error_text.append("\nRequired config fields:\n", style="white") + for field_name, description in missing_required_fields: + error_text.append("• ", style="white") + error_text.append(field_name, style="bold cyan") + error_text.append(f" - {description}\n", style="white") + + if missing_optional_fields: + error_text.append("\nOptional config fields:\n", style="white") + for field_name, description in missing_optional_fields: error_text.append("• ", style="white") - error_text.append("STRIX_LLM", style="bold cyan") - error_text.append( - " - Model name to use with litellm (e.g., 'openai/gpt-5.4')\n", - style="white", - ) - - if missing_optional_vars: - error_text.append("\nOptional environment variables:\n", style="white") - for var in missing_optional_vars: - if var == "LLM_API_KEY": - error_text.append("• ", style="white") - error_text.append("LLM_API_KEY", style="bold cyan") - error_text.append( - " - API key for the LLM provider " - "(not needed for local models, Vertex AI, AWS, etc.)\n", - style="white", - ) - elif var == "LLM_API_BASE": - error_text.append("• ", style="white") - error_text.append("LLM_API_BASE", style="bold cyan") - error_text.append( - " - Custom API base URL if using local models (e.g., Ollama, LMStudio)\n", - style="white", - ) - elif var == "PERPLEXITY_API_KEY": - error_text.append("• ", style="white") - error_text.append("PERPLEXITY_API_KEY", style="bold cyan") - error_text.append( - " - API key for Perplexity AI web search (enables real-time research)\n", - style="white", - ) - elif var == "STRIX_REASONING_EFFORT": - error_text.append("• ", style="white") - error_text.append("STRIX_REASONING_EFFORT", style="bold cyan") - error_text.append( - " - Reasoning effort level: none, minimal, low, medium, high, xhigh " - "(default: high)\n", - style="white", - ) + error_text.append(field_name, style="bold cyan") + error_text.append(f" - {description}\n", style="white") error_text.append("\nExample setup:\n", style="white") - error_text.append("export STRIX_LLM='openai/gpt-5.4'\n", style="dim white") - - if missing_optional_vars: - for var in missing_optional_vars: - if var == "LLM_API_KEY": - error_text.append( - "export LLM_API_KEY='your-api-key-here' " - "# not needed for local models, Vertex AI, AWS, etc.\n", - style="dim white", - ) - elif var == "LLM_API_BASE": - error_text.append( - "export LLM_API_BASE='http://localhost:11434' " - "# needed for local models only\n", - style="dim white", - ) - elif var == "PERPLEXITY_API_KEY": - error_text.append( - "export PERPLEXITY_API_KEY='your-perplexity-key-here'\n", style="dim white" - ) - elif var == "STRIX_REASONING_EFFORT": - error_text.append( - "export STRIX_REASONING_EFFORT='high'\n", - style="dim white", - ) + error_text.append( + '{\n' + ' "llm": {\n' + ' "model": "openai/gpt-5.4",\n' + ' "api_key": "your-api-key-here",\n' + ' "api_base": "http://localhost:11434",\n' + ' "reasoning_effort": "high"\n' + ' },\n' + ' "features": {\n' + ' "perplexity_api_key": "your-perplexity-key-here"\n' + " }\n" + "}\n", + style="dim white", + ) panel = Panel( error_text, @@ -206,8 +192,12 @@ async def warm_up_llm() -> None: console = Console() try: - model_name, api_key, api_base = resolve_llm_config() - litellm_model, _ = resolve_strix_model(model_name) + model_name, api_key, api_base, openai_compatible_provider = resolve_llm_config() + litellm_model, _ = resolve_strix_model( + model_name, + api_base=api_base, + openai_compatible_provider=openai_compatible_provider, + ) litellm_model = litellm_model or model_name test_messages = [ @@ -215,7 +205,7 @@ async def warm_up_llm() -> None: {"role": "user", "content": "Reply with just 'OK'."}, ] - llm_timeout = int(Config.get("llm_timeout") or "300") + llm_timeout = Config.get_int("llm_timeout") or 300 completion_kwargs: dict[str, Any] = { "model": litellm_model, @@ -360,7 +350,13 @@ def parse_arguments() -> argparse.Namespace: parser.add_argument( "--config", type=str, - help="Path to a custom config file (JSON) to use instead of ~/.strix/cli-config.json", + help="Path to a custom config file (JSON) to use instead of ~/.strix/config.json", + ) + + parser.add_argument( + "--run-name", + type=str, + help="Override the generated run name. Useful for API-triggered or externally tracked runs.", ) args = parser.parse_args() @@ -463,12 +459,13 @@ def display_completion_message(args: argparse.Namespace, results_path: Path) -> def pull_docker_image() -> None: console = Console() client = check_docker_connection() + image_name = Config.get_str("strix_image") - if image_exists(client, Config.get("strix_image")): # type: ignore[arg-type] + if image_exists(client, image_name): # type: ignore[arg-type] return console.print() - console.print(f"[dim]Pulling image[/] {Config.get('strix_image')}") + console.print(f"[dim]Pulling image[/] {image_name}") console.print("[dim yellow]This only happens on first run and may take a few minutes...[/]") console.print() @@ -477,7 +474,7 @@ def pull_docker_image() -> None: layers_info: dict[str, str] = {} last_update = "" - for line in client.api.pull(Config.get("strix_image"), stream=True, decode=True): + for line in client.api.pull(image_name, stream=True, decode=True): last_update = process_pull_line(line, layers_info, status, last_update) except DockerException as e: @@ -485,7 +482,7 @@ def pull_docker_image() -> None: error_text = Text() error_text.append("FAILED TO PULL IMAGE", style="bold red") error_text.append("\n\n", style="white") - error_text.append(f"Could not download: {Config.get('strix_image')}\n", style="white") + error_text.append(f"Could not download: {image_name}\n", style="white") error_text.append(str(e), style="dim red") panel = Panel( @@ -505,13 +502,7 @@ def pull_docker_image() -> None: def apply_config_override(config_path: str) -> None: - Config._config_file_override = validate_config_file(config_path) - apply_saved_config(force=True) - - -def persist_config() -> None: - if Config._config_file_override is None: - save_current_config() + Config.set_config_file(validate_config_file(config_path)) def main() -> None: @@ -522,6 +513,13 @@ def main() -> None: if args.config: apply_config_override(args.config) + else: + Config.reload() + + configure_runtime_context( + sandbox_mode=False, + caido_api_token=Config.get_str("caido_api_token"), + ) check_docker_installed() pull_docker_image() @@ -529,9 +527,7 @@ def main() -> None: validate_environment() asyncio.run(warm_up_llm()) - persist_config() - - args.run_name = generate_run_name(args.targets_info) + args.run_name = args.run_name or generate_run_name(args.targets_info) for target_info in args.targets_info: if target_info["type"] == "repository": @@ -545,7 +541,7 @@ def main() -> None: is_whitebox = bool(args.local_sources) posthog.start( - model=Config.get("strix_llm"), + model=Config.get_str("strix_llm"), scan_mode=args.scan_mode, is_whitebox=is_whitebox, interactive=not args.non_interactive, diff --git a/strix/interface/tui.py b/strix/interface/tui.py index f3665621d..15483a35f 100644 --- a/strix/interface/tui.py +++ b/strix/interface/tui.py @@ -34,7 +34,7 @@ from strix.interface.tool_components.registry import get_tool_renderer from strix.interface.tool_components.user_message_renderer import UserMessageRenderer from strix.interface.utils import build_tui_stats_text -from strix.llm.config import LLMConfig +from strix.scan import PreparedScan, ScanRequest, build_agent_config, build_scan_config from strix.telemetry.tracer import Tracer, set_global_tracer @@ -737,26 +737,32 @@ def __init__(self, args: argparse.Namespace): self._setup_cleanup_handlers() def _build_scan_config(self, args: argparse.Namespace) -> dict[str, Any]: - return { - "scan_id": args.run_name, - "targets": args.targets_info, - "user_instructions": args.instruction or "", - "run_name": args.run_name, - } + prepared_scan = PreparedScan( + request=ScanRequest( + targets=[], + instruction=args.instruction or "", + scan_mode=getattr(args, "scan_mode", "deep"), + run_name=args.run_name, + ), + run_name=args.run_name, + targets_info=args.targets_info, + local_sources=getattr(args, "local_sources", None) or [], + ) + return build_scan_config(prepared_scan) def _build_agent_config(self, args: argparse.Namespace) -> dict[str, Any]: - scan_mode = getattr(args, "scan_mode", "deep") - llm_config = LLMConfig(scan_mode=scan_mode, interactive=True) - - config = { - "llm_config": llm_config, - "max_iterations": 300, - } - - if getattr(args, "local_sources", None): - config["local_sources"] = args.local_sources - - return config + prepared_scan = PreparedScan( + request=ScanRequest( + targets=[], + instruction=args.instruction or "", + scan_mode=getattr(args, "scan_mode", "deep"), + run_name=args.run_name, + ), + run_name=args.run_name, + targets_info=args.targets_info, + local_sources=getattr(args, "local_sources", None) or [], + ) + return build_agent_config(prepared_scan, interactive=True) def _setup_cleanup_handlers(self) -> None: def cleanup_on_exit() -> None: diff --git a/strix/interface/utils.py b/strix/interface/utils.py index 5b9e52b4e..591367eb3 100644 --- a/strix/interface/utils.py +++ b/strix/interface/utils.py @@ -18,6 +18,9 @@ from rich.panel import Panel from rich.text import Text +from strix.config import Config +from strix.runtime.docker_client import create_docker_client + # Token formatting utilities def format_token_count(count: float) -> str: @@ -742,7 +745,9 @@ def clone_repository(repo_url: str, run_name: str, dest_name: str | None = None) # Docker utilities def check_docker_connection() -> Any: try: - return docker.from_env() + client = create_docker_client(timeout=60) + client.ping() + return client except DockerException: console = Console() error_text = Text() @@ -834,18 +839,12 @@ def validate_config_file(config_path: str) -> Path: sys.exit(1) try: - with path.open("r", encoding="utf-8") as f: - data = json.load(f) + Config.validate_file(path) except json.JSONDecodeError as e: console.print(f"[bold red]Error:[/] Invalid JSON in config file: {e}") sys.exit(1) - - if not isinstance(data, dict): - console.print("[bold red]Error:[/] Config file must contain a JSON object") - sys.exit(1) - - if "env" not in data or not isinstance(data.get("env"), dict): - console.print("[bold red]Error:[/] Config file must have an 'env' object") + except ValueError as e: + console.print(f"[bold red]Error:[/] {e}") sys.exit(1) return path diff --git a/strix/llm/config.py b/strix/llm/config.py index 9c4757a1e..ef6490893 100644 --- a/strix/llm/config.py +++ b/strix/llm/config.py @@ -17,20 +17,29 @@ def __init__( reasoning_effort: str | None = None, system_prompt_context: dict[str, Any] | None = None, ): - resolved_model, self.api_key, self.api_base = resolve_llm_config() + ( + resolved_model, + self.api_key, + self.api_base, + openai_compatible_provider, + ) = resolve_llm_config() self.model_name = model_name or resolved_model if not self.model_name: - raise ValueError("STRIX_LLM environment variable must be set and not empty") + raise ValueError("LLM model must be configured in the Strix config file") - api_model, canonical = resolve_strix_model(self.model_name) + api_model, canonical = resolve_strix_model( + self.model_name, + api_base=self.api_base, + openai_compatible_provider=openai_compatible_provider, + ) self.litellm_model: str = api_model or self.model_name self.canonical_model: str = canonical or self.model_name self.enable_prompt_caching = enable_prompt_caching self.skills = skills or [] - self.timeout = timeout or int(Config.get("llm_timeout") or "300") + self.timeout = timeout or (Config.get_int("llm_timeout") or 300) self.scan_mode = scan_mode if scan_mode in ["quick", "standard", "deep"] else "deep" diff --git a/strix/llm/dedupe.py b/strix/llm/dedupe.py index 0ea608850..d0c420241 100644 --- a/strix/llm/dedupe.py +++ b/strix/llm/dedupe.py @@ -156,8 +156,12 @@ def check_duplicate( comparison_data = {"candidate": candidate_cleaned, "existing_reports": existing_cleaned} - model_name, api_key, api_base = resolve_llm_config() - litellm_model, _ = resolve_strix_model(model_name) + model_name, api_key, api_base, openai_compatible_provider = resolve_llm_config() + litellm_model, _ = resolve_strix_model( + model_name, + api_base=api_base, + openai_compatible_provider=openai_compatible_provider, + ) litellm_model = litellm_model or model_name messages = [ diff --git a/strix/llm/llm.py b/strix/llm/llm.py index c8827e3d9..533c9dced 100644 --- a/strix/llm/llm.py +++ b/strix/llm/llm.py @@ -71,7 +71,7 @@ def __init__(self, config: LLMConfig, agent_name: str | None = None): self.memory_compressor = MemoryCompressor(model_name=config.litellm_model) self.system_prompt = self._load_system_prompt(agent_name) - reasoning = Config.get("strix_reasoning_effort") + reasoning = Config.get_str("strix_reasoning_effort") if reasoning: self._reasoning_effort = reasoning elif config.reasoning_effort: @@ -154,7 +154,7 @@ async def generate( self, conversation_history: list[dict[str, Any]] ) -> AsyncIterator[LLMResponse]: messages = self._prepare_messages(conversation_history) - max_retries = int(Config.get("strix_llm_max_retries") or "5") + max_retries = Config.get_int("strix_llm_max_retries") or 5 for attempt in range(max_retries + 1): try: diff --git a/strix/llm/memory_compressor.py b/strix/llm/memory_compressor.py index 8cad51078..6ebf745a4 100644 --- a/strix/llm/memory_compressor.py +++ b/strix/llm/memory_compressor.py @@ -104,7 +104,7 @@ def _summarize_messages( conversation = "\n".join(formatted) prompt = SUMMARY_PROMPT_TEMPLATE.format(conversation=conversation) - _, api_key, api_base = resolve_llm_config() + _, api_key, api_base, _ = resolve_llm_config() try: completion_args: dict[str, Any] = { @@ -157,11 +157,11 @@ def __init__( timeout: int | None = None, ): self.max_images = max_images - self.model_name = model_name or Config.get("strix_llm") - self.timeout = timeout or int(Config.get("strix_memory_compressor_timeout") or "120") + self.model_name = model_name or Config.get_str("strix_llm") + self.timeout = timeout or (Config.get_int("strix_memory_compressor_timeout") or 120) if not self.model_name: - raise ValueError("STRIX_LLM environment variable must be set and not empty") + raise ValueError("LLM model must be configured in the Strix config file") def compress_history( self, diff --git a/strix/llm/utils.py b/strix/llm/utils.py index 9771854f7..b4c5d16f9 100644 --- a/strix/llm/utils.py +++ b/strix/llm/utils.py @@ -43,8 +43,40 @@ def normalize_tool_format(content: str) -> str: "glm-4.7": "openrouter/z-ai/glm-4.7", } +KNOWN_PROVIDER_PREFIXES: set[str] = { + "anthropic", + "azure", + "azure_ai", + "bedrock", + "cerebras", + "claude", + "cohere", + "deepseek", + "fireworks_ai", + "gemini", + "github", + "google", + "groq", + "huggingface", + "mistral", + "ollama", + "openai", + "openrouter", + "perplexity", + "replicate", + "sambanova", + "vertex_ai", + "voyage", + "watsonx", + "xai", +} + -def resolve_strix_model(model_name: str | None) -> tuple[str | None, str | None]: +def resolve_strix_model( + model_name: str | None, + api_base: str | None = None, + openai_compatible_provider: str | None = None, +) -> tuple[str | None, str | None]: """Resolve a strix/ model into names for API calls and capability lookups. Returns (api_model, canonical_model): @@ -52,7 +84,24 @@ def resolve_strix_model(model_name: str | None) -> tuple[str | None, str | None] - canonical_model: actual provider model name for litellm capability lookups Non-strix models return the same name for both. """ - if not model_name or not model_name.startswith("strix/"): + if not model_name: + return None, None + + if not model_name.startswith("strix/"): + if api_base and openai_compatible_provider: + provider_model = _apply_openai_compatible_provider( + model_name, + openai_compatible_provider, + ) + if _register_openai_compatible_provider(provider_model, api_base): + return provider_model, provider_model + return f"openai/{provider_model}", provider_model + + if api_base and _looks_like_openai_compatible_model(model_name): + inferred_provider_model = model_name + if _register_openai_compatible_provider(inferred_provider_model, api_base): + return inferred_provider_model, inferred_provider_model + return f"openai/{model_name}", model_name return model_name, model_name base_model = model_name[6:] @@ -61,6 +110,67 @@ def resolve_strix_model(model_name: str | None) -> tuple[str | None, str | None] return api_model, canonical_model +def _looks_like_openai_compatible_model(model_name: str) -> bool: + if "/" not in model_name or model_name.startswith("openai/"): + return False + + provider_prefix = model_name.split("/", 1)[0].lower() + return provider_prefix not in KNOWN_PROVIDER_PREFIXES + + +def _apply_openai_compatible_provider(model_name: str, provider_name: str) -> str: + normalized_provider = provider_name.strip() + normalized_model = model_name.strip() + if not normalized_provider or not normalized_model: + return model_name + + if "/" not in normalized_model: + return f"{normalized_provider}/{normalized_model}" + + existing_provider, provider_model = normalized_model.split("/", 1) + if existing_provider.lower() == normalized_provider.lower(): + return f"{normalized_provider}/{provider_model}" + + if existing_provider.lower() in KNOWN_PROVIDER_PREFIXES: + return normalized_model + + return f"{normalized_provider}/{normalized_model}" + + +def _register_openai_compatible_provider(model_name: str, api_base: str) -> bool: + provider_slug = model_name.split("/", 1)[0].strip() + normalized_base = api_base.strip() + if not provider_slug or not normalized_base: + return False + + try: + from litellm.llms.openai_like.json_loader import JSONProviderRegistry, SimpleProviderConfig + except Exception: # noqa: BLE001 + return False + + provider_data = { + "base_url": normalized_base, + # LiteLLM requires this field for JSON providers, but Strix passes api_key directly. + "api_key_env": _provider_api_key_env(provider_slug), + "base_class": "openai_gpt", + } + + aliases = {provider_slug, provider_slug.lower()} + for alias in aliases: + existing = JSONProviderRegistry.get(alias) + if existing and existing.base_url == normalized_base: + continue + JSONProviderRegistry._providers[alias] = SimpleProviderConfig(alias, provider_data) + + JSONProviderRegistry._loaded = True + return True + + +def _provider_api_key_env(provider_slug: str) -> str: + sanitized = re.sub(r"[^A-Za-z0-9]+", "_", provider_slug).strip("_") + return f"{sanitized.upper()}_API_KEY" if sanitized else "STRIX_OPENAI_COMPATIBLE_API_KEY" + + def _truncate_to_first_function(content: str) -> str: if not content: return content diff --git a/strix/runtime/__init__.py b/strix/runtime/__init__.py index 5d0cbda45..37bf3c9f3 100644 --- a/strix/runtime/__init__.py +++ b/strix/runtime/__init__.py @@ -18,7 +18,7 @@ def __init__(self, message: str, details: str | None = None): def get_runtime() -> AbstractRuntime: global _global_runtime # noqa: PLW0603 - runtime_backend = Config.get("strix_runtime_backend") + runtime_backend = Config.get_str("strix_runtime_backend") if runtime_backend == "docker": from .docker_runtime import DockerRuntime diff --git a/strix/runtime/context.py b/strix/runtime/context.py new file mode 100644 index 000000000..625132276 --- /dev/null +++ b/strix/runtime/context.py @@ -0,0 +1,31 @@ +from dataclasses import dataclass + +_UNSET = object() + + +@dataclass +class RuntimeContext: + sandbox_mode: bool = False + caido_api_token: str | None = None + + +_runtime_context = RuntimeContext() + + +def configure_runtime_context( + *, + sandbox_mode: bool | None = None, + caido_api_token: str | None | object = _UNSET, +) -> None: + if sandbox_mode is not None: + _runtime_context.sandbox_mode = sandbox_mode + if caido_api_token is not _UNSET: + _runtime_context.caido_api_token = caido_api_token + + +def is_sandbox_mode() -> bool: + return _runtime_context.sandbox_mode + + +def get_caido_api_token() -> str | None: + return _runtime_context.caido_api_token diff --git a/strix/runtime/docker_client.py b/strix/runtime/docker_client.py new file mode 100644 index 000000000..a3bfccfa5 --- /dev/null +++ b/strix/runtime/docker_client.py @@ -0,0 +1,20 @@ +import sys + +import docker + +from strix.config import Config + + +def resolve_docker_base_url() -> str: + configured = Config.get_str("docker_host") + if configured: + return configured + + if sys.platform == "win32": + return "npipe:////./pipe/docker_engine" + + return "unix:///var/run/docker.sock" + + +def create_docker_client(timeout: int) -> docker.DockerClient: + return docker.DockerClient(base_url=resolve_docker_base_url(), timeout=timeout) diff --git a/strix/runtime/docker_runtime.py b/strix/runtime/docker_runtime.py index d57d35827..7b4d23413 100644 --- a/strix/runtime/docker_runtime.py +++ b/strix/runtime/docker_runtime.py @@ -1,5 +1,4 @@ import contextlib -import os import secrets import socket import time @@ -16,6 +15,7 @@ from strix.config import Config from . import SandboxInitializationError +from .docker_client import create_docker_client from .runtime import AbstractRuntime, SandboxInfo @@ -28,7 +28,7 @@ class DockerRuntime(AbstractRuntime): def __init__(self) -> None: try: - self.client = docker.from_env(timeout=DOCKER_TIMEOUT) + self.client = create_docker_client(timeout=DOCKER_TIMEOUT) except (DockerException, RequestsConnectionError, RequestsTimeout) as e: raise SandboxInitializationError( "Docker is not available", @@ -110,9 +110,9 @@ def _wait_for_tool_server(self, max_retries: int = 30, timeout: int = 5) -> None def _create_container(self, scan_id: str, max_retries: int = 2) -> Container: container_name = f"strix-scan-{scan_id}" - image_name = Config.get("strix_image") + image_name = Config.get_str("strix_image") if not image_name: - raise ValueError("STRIX_IMAGE must be configured") + raise ValueError("runtime.image must be configured in the config file") self._verify_image_available(image_name) @@ -129,7 +129,9 @@ def _create_container(self, scan_id: str, max_retries: int = 2) -> Container: self._tool_server_port = self._find_available_port() self._caido_port = self._find_available_port() self._tool_server_token = secrets.token_urlsafe(32) - execution_timeout = Config.get("strix_sandbox_execution_timeout") or "120" + execution_timeout = str( + Config.get_int("strix_sandbox_execution_timeout") or 120 + ) container = self.client.containers.run( image_name, @@ -310,7 +312,7 @@ async def get_sandbox_url(self, container_id: str, port: int) -> str: raise ValueError(f"Container {container_id} not found.") from None def _resolve_docker_host(self) -> str: - docker_host = os.getenv("DOCKER_HOST", "") + docker_host = Config.get_str("docker_host") or "" if docker_host: from urllib.parse import urlparse diff --git a/strix/runtime/tool_server.py b/strix/runtime/tool_server.py index ee5fb49a8..f92264c89 100644 --- a/strix/runtime/tool_server.py +++ b/strix/runtime/tool_server.py @@ -2,44 +2,66 @@ import argparse import asyncio -import os -import signal -import sys +from pathlib import Path from typing import Any import uvicorn -from fastapi import Depends, FastAPI, HTTPException, status +from fastapi import Depends, FastAPI, HTTPException, Request, status from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer from pydantic import BaseModel, ValidationError +from strix.config import Config +from strix.runtime.context import configure_runtime_context -SANDBOX_MODE = os.getenv("STRIX_SANDBOX_MODE", "false").lower() == "true" -if not SANDBOX_MODE: - raise RuntimeError("Tool server should only run in sandbox mode (STRIX_SANDBOX_MODE=true)") -parser = argparse.ArgumentParser(description="Start Strix tool server") -parser.add_argument("--token", required=True, help="Authentication token") -parser.add_argument("--host", default="0.0.0.0", help="Host to bind to") # nosec -parser.add_argument("--port", type=int, required=True, help="Port to bind to") -parser.add_argument( - "--timeout", - type=int, - default=120, - help="Hard timeout in seconds for each request execution (default: 120)", -) - -args = parser.parse_args() -EXPECTED_TOKEN = args.token -REQUEST_TIMEOUT = args.timeout - -app = FastAPI() security = HTTPBearer() security_dependency = Depends(security) -agent_tasks: dict[str, asyncio.Task[Any]] = {} + +class ToolExecutionRequest(BaseModel): + agent_id: str + tool_name: str + kwargs: dict[str, Any] -def verify_token(credentials: HTTPAuthorizationCredentials) -> str: +class ToolExecutionResponse(BaseModel): + result: Any | None = None + error: str | None = None + + +def build_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser(description="Start Strix tool server") + parser.add_argument("--token", required=True, help="Authentication token") + parser.add_argument("--host", default="0.0.0.0", help="Host to bind to") # nosec + parser.add_argument("--port", type=int, required=True, help="Port to bind to") + parser.add_argument( + "--timeout", + type=int, + default=120, + help="Hard timeout in seconds for each request execution (default: 120)", + ) + parser.add_argument( + "--config", + type=str, + help="Path to the runtime config file used inside the sandbox", + ) + parser.add_argument( + "--sandbox-mode", + action="store_true", + help="Mark this tool server as running inside sandbox mode", + ) + parser.add_argument( + "--caido-api-token", + type=str, + help="Internal Caido API token for proxy tooling inside the sandbox", + ) + return parser + + +def verify_token( + credentials: HTTPAuthorizationCredentials, + expected_token: str, +) -> str: if not credentials or credentials.scheme != "Bearer": raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, @@ -47,7 +69,7 @@ def verify_token(credentials: HTTPAuthorizationCredentials) -> str: headers={"WWW-Authenticate": "Bearer"}, ) - if credentials.credentials != EXPECTED_TOKEN: + if credentials.credentials != expected_token: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid authentication token", @@ -57,17 +79,6 @@ def verify_token(credentials: HTTPAuthorizationCredentials) -> str: return credentials.credentials -class ToolExecutionRequest(BaseModel): - agent_id: str - tool_name: str - kwargs: dict[str, Any] - - -class ToolExecutionResponse(BaseModel): - result: Any | None = None - error: str | None = None - - async def _run_tool(agent_id: str, tool_name: str, kwargs: dict[str, Any]) -> Any: from strix.tools.argument_parser import convert_arguments from strix.tools.context import set_current_agent_id @@ -83,83 +94,113 @@ async def _run_tool(agent_id: str, tool_name: str, kwargs: dict[str, Any]) -> An return await asyncio.to_thread(tool_func, **converted_kwargs) -@app.post("/execute", response_model=ToolExecutionResponse) -async def execute_tool( - request: ToolExecutionRequest, credentials: HTTPAuthorizationCredentials = security_dependency -) -> ToolExecutionResponse: - verify_token(credentials) - - agent_id = request.agent_id - - if agent_id in agent_tasks: - old_task = agent_tasks[agent_id] - if not old_task.done(): - old_task.cancel() - - task = asyncio.create_task( - asyncio.wait_for( - _run_tool(agent_id, request.tool_name, request.kwargs), timeout=REQUEST_TIMEOUT +def create_app( + expected_token: str, + request_timeout: int, + sandbox_mode: bool, +) -> FastAPI: + app = FastAPI() + app.state.expected_token = expected_token + app.state.request_timeout = request_timeout + app.state.sandbox_mode = sandbox_mode + app.state.agent_tasks = {} + + @app.on_event("shutdown") + async def shutdown_event() -> None: + for task in list(app.state.agent_tasks.values()): + task.cancel() + + @app.post("/execute", response_model=ToolExecutionResponse) + async def execute_tool( + request: ToolExecutionRequest, + http_request: Request, + credentials: HTTPAuthorizationCredentials = security_dependency, + ) -> ToolExecutionResponse: + verify_token(credentials, http_request.app.state.expected_token) + + agent_id = request.agent_id + agent_tasks: dict[str, asyncio.Task[Any]] = http_request.app.state.agent_tasks + + if agent_id in agent_tasks: + old_task = agent_tasks[agent_id] + if not old_task.done(): + old_task.cancel() + + task = asyncio.create_task( + asyncio.wait_for( + _run_tool(agent_id, request.tool_name, request.kwargs), + timeout=http_request.app.state.request_timeout, + ) ) + agent_tasks[agent_id] = task + + try: + result = await task + return ToolExecutionResponse(result=result) + except asyncio.CancelledError: + return ToolExecutionResponse(error="Cancelled by newer request") + except TimeoutError: + return ToolExecutionResponse( + error=f"Tool timed out after {http_request.app.state.request_timeout}s" + ) + except ValidationError as exc: + return ToolExecutionResponse(error=f"Invalid arguments: {exc}") + except (ValueError, RuntimeError, ImportError) as exc: + return ToolExecutionResponse(error=f"Tool execution error: {exc}") + except Exception as exc: # noqa: BLE001 + return ToolExecutionResponse(error=f"Unexpected error: {exc}") + finally: + if agent_tasks.get(agent_id) is task: + del agent_tasks[agent_id] + + @app.post("/register_agent") + async def register_agent( + agent_id: str, + http_request: Request, + credentials: HTTPAuthorizationCredentials = security_dependency, + ) -> dict[str, str]: + verify_token(credentials, http_request.app.state.expected_token) + return {"status": "registered", "agent_id": agent_id} + + @app.get("/health") + async def health_check(http_request: Request) -> dict[str, Any]: + agent_tasks: dict[str, asyncio.Task[Any]] = http_request.app.state.agent_tasks + return { + "status": "healthy", + "sandbox_mode": str(http_request.app.state.sandbox_mode).lower(), + "environment": "sandbox" if http_request.app.state.sandbox_mode else "main", + "auth_configured": "true" if http_request.app.state.expected_token else "false", + "active_agents": len(agent_tasks), + "agents": list(agent_tasks.keys()), + } + + return app + + +def main() -> None: + args = build_parser().parse_args() + + if args.config: + Config.set_config_file(Path(args.config)) + else: + Config.reload() + + sandbox_mode = args.sandbox_mode + if not sandbox_mode: + raise RuntimeError("Tool server should only run in sandbox mode") + + configure_runtime_context( + sandbox_mode=sandbox_mode, + caido_api_token=args.caido_api_token or Config.get_str("caido_api_token"), ) - agent_tasks[agent_id] = task - - try: - result = await task - return ToolExecutionResponse(result=result) - - except asyncio.CancelledError: - return ToolExecutionResponse(error="Cancelled by newer request") - - except TimeoutError: - return ToolExecutionResponse(error=f"Tool timed out after {REQUEST_TIMEOUT}s") - - except ValidationError as e: - return ToolExecutionResponse(error=f"Invalid arguments: {e}") - - except (ValueError, RuntimeError, ImportError) as e: - return ToolExecutionResponse(error=f"Tool execution error: {e}") - except Exception as e: # noqa: BLE001 - return ToolExecutionResponse(error=f"Unexpected error: {e}") - - finally: - if agent_tasks.get(agent_id) is task: - del agent_tasks[agent_id] - - -@app.post("/register_agent") -async def register_agent( - agent_id: str, credentials: HTTPAuthorizationCredentials = security_dependency -) -> dict[str, str]: - verify_token(credentials) - return {"status": "registered", "agent_id": agent_id} - - -@app.get("/health") -async def health_check() -> dict[str, Any]: - return { - "status": "healthy", - "sandbox_mode": str(SANDBOX_MODE), - "environment": "sandbox" if SANDBOX_MODE else "main", - "auth_configured": "true" if EXPECTED_TOKEN else "false", - "active_agents": len(agent_tasks), - "agents": list(agent_tasks.keys()), - } - - -def signal_handler(_signum: int, _frame: Any) -> None: - if hasattr(signal, "SIGPIPE"): - signal.signal(signal.SIGPIPE, signal.SIG_IGN) - for task in agent_tasks.values(): - task.cancel() - sys.exit(0) - - -if hasattr(signal, "SIGPIPE"): - signal.signal(signal.SIGPIPE, signal.SIG_IGN) + app = create_app( + expected_token=args.token, + request_timeout=args.timeout, + sandbox_mode=sandbox_mode, + ) + uvicorn.run(app, host=args.host, port=args.port, log_level="info") -signal.signal(signal.SIGTERM, signal_handler) -signal.signal(signal.SIGINT, signal_handler) if __name__ == "__main__": - uvicorn.run(app, host=args.host, port=args.port, log_level="info") + main() diff --git a/strix/scan/__init__.py b/strix/scan/__init__.py new file mode 100644 index 000000000..73b7129e8 --- /dev/null +++ b/strix/scan/__init__.py @@ -0,0 +1,24 @@ +from strix.scan.service import ( + PreparedScan, + ScanExecutionResult, + ScanRequest, + build_agent_config, + build_scan_config, + build_targets_info, + execute_prepared_scan, + generate_scan_id, + prepare_scan, +) + + +__all__ = [ + "PreparedScan", + "ScanExecutionResult", + "ScanRequest", + "build_agent_config", + "build_scan_config", + "build_targets_info", + "execute_prepared_scan", + "generate_scan_id", + "prepare_scan", +] diff --git a/strix/scan/service.py b/strix/scan/service.py new file mode 100644 index 000000000..92767ff03 --- /dev/null +++ b/strix/scan/service.py @@ -0,0 +1,151 @@ +from __future__ import annotations + +import asyncio +from dataclasses import dataclass +from typing import Any + +from strix.agents.StrixAgent import StrixAgent +from strix.llm.config import LLMConfig +from strix.runtime import cleanup_runtime +from strix.runtime.docker_runtime import HOST_GATEWAY_HOSTNAME +from strix.telemetry.tracer import Tracer, set_global_tracer +from strix.tools.agents_graph.agents_graph_actions import reset_agent_graph_state +from strix.interface.utils import ( + assign_workspace_subdirs, + clone_repository, + collect_local_sources, + generate_run_name, + infer_target_type, + rewrite_localhost_targets, +) + + +@dataclass +class ScanRequest: + targets: list[str] + instruction: str = "" + scan_mode: str = "deep" + run_name: str | None = None + + +@dataclass +class PreparedScan: + request: ScanRequest + run_name: str + targets_info: list[dict[str, Any]] + local_sources: list[dict[str, str]] + + def build_scan_config(self) -> dict[str, Any]: + return build_scan_config(self) + + def build_agent_config(self, *, interactive: bool = False) -> dict[str, Any]: + return build_agent_config(self, interactive=interactive) + + +@dataclass +class ScanExecutionResult: + prepared_scan: PreparedScan + tracer: Tracer + result: dict[str, Any] + + +def build_targets_info(raw_targets: list[str]) -> list[dict[str, Any]]: + targets_info: list[dict[str, Any]] = [] + for target in raw_targets: + target_type, target_dict = infer_target_type(target) + display_target = target_dict.get("target_path", target) if target_type == "local_code" else target + targets_info.append( + { + "type": target_type, + "details": target_dict, + "original": display_target, + } + ) + + assign_workspace_subdirs(targets_info) + rewrite_localhost_targets(targets_info, HOST_GATEWAY_HOSTNAME) + return targets_info + + +def generate_scan_id(raw_targets: list[str]) -> str: + return generate_run_name(build_targets_info(raw_targets)) + + +def prepare_scan(request: ScanRequest) -> PreparedScan: + targets_info = build_targets_info(request.targets) + + run_name = request.run_name or generate_run_name(targets_info) + + for target_info in targets_info: + if target_info["type"] != "repository": + continue + repo_url = target_info["details"]["target_repo"] + dest_name = target_info["details"].get("workspace_subdir") + cloned_path = clone_repository(repo_url, run_name, dest_name) + target_info["details"]["cloned_repo_path"] = cloned_path + + local_sources = collect_local_sources(targets_info) + return PreparedScan( + request=request, + run_name=run_name, + targets_info=targets_info, + local_sources=local_sources, + ) + + +def build_scan_config(prepared_scan: PreparedScan) -> dict[str, Any]: + return { + "scan_id": prepared_scan.run_name, + "targets": prepared_scan.targets_info, + "user_instructions": prepared_scan.request.instruction, + "run_name": prepared_scan.run_name, + } + + +def build_agent_config( + prepared_scan: PreparedScan, + *, + interactive: bool = False, +) -> dict[str, Any]: + agent_config: dict[str, Any] = { + "llm_config": LLMConfig( + scan_mode=prepared_scan.request.scan_mode, + interactive=interactive, + ), + "max_iterations": 300, + } + if prepared_scan.local_sources: + agent_config["local_sources"] = prepared_scan.local_sources + return agent_config + + +async def execute_prepared_scan( + prepared_scan: PreparedScan, + *, + interactive: bool = False, + cleanup_after_run: bool = True, +) -> ScanExecutionResult: + tracer = Tracer(prepared_scan.run_name) + scan_config = build_scan_config(prepared_scan) + agent_config = build_agent_config(prepared_scan, interactive=interactive) + + reset_agent_graph_state() + tracer.set_scan_config(scan_config) + set_global_tracer(tracer) + + try: + agent = StrixAgent(agent_config) + result = await agent.execute_scan(scan_config) + return ScanExecutionResult( + prepared_scan=prepared_scan, + tracer=tracer, + result=result, + ) + except asyncio.CancelledError: + raise + finally: + if cleanup_after_run: + tracer.cleanup() + cleanup_runtime() + set_global_tracer(None) + reset_agent_graph_state() diff --git a/strix/telemetry/flags.py b/strix/telemetry/flags.py index bae942724..4c2138da1 100644 --- a/strix/telemetry/flags.py +++ b/strix/telemetry/flags.py @@ -4,20 +4,20 @@ _DISABLED_VALUES = {"0", "false", "no", "off"} -def _is_enabled(raw_value: str | None, default: str = "1") -> bool: - value = (raw_value if raw_value is not None else default).strip().lower() +def _is_enabled(raw_value: bool | str | None, default: str = "1") -> bool: + value = str(raw_value if raw_value is not None else default).strip().lower() return value not in _DISABLED_VALUES def is_otel_enabled() -> bool: - explicit = Config.get("strix_otel_telemetry") + explicit = Config.get_bool("strix_otel_telemetry") if explicit is not None: return _is_enabled(explicit) - return _is_enabled(Config.get("strix_telemetry"), default="1") + return _is_enabled(Config.get_bool("strix_telemetry"), default="1") def is_posthog_enabled() -> bool: - explicit = Config.get("strix_posthog_telemetry") + explicit = Config.get_bool("strix_posthog_telemetry") if explicit is not None: return _is_enabled(explicit) - return _is_enabled(Config.get("strix_telemetry"), default="1") + return _is_enabled(Config.get_bool("strix_telemetry"), default="1") diff --git a/strix/telemetry/tracer.py b/strix/telemetry/tracer.py index bde9750ab..849c37f16 100644 --- a/strix/telemetry/tracer.py +++ b/strix/telemetry/tracer.py @@ -40,7 +40,7 @@ def get_global_tracer() -> Optional["Tracer"]: return _global_tracer -def set_global_tracer(tracer: "Tracer") -> None: +def set_global_tracer(tracer: Optional["Tracer"]) -> None: global _global_tracer # noqa: PLW0603 _global_tracer = tracer @@ -57,6 +57,7 @@ def __init__(self, run_name: str | None = None): self.chat_messages: list[dict[str, Any]] = [] self.streaming_content: dict[str, str] = {} self.interrupted_content: dict[str, str] = {} + self._last_emitted_streaming_content: dict[str, str] = {} self.vulnerability_reports: list[dict[str, Any]] = [] self.final_scan_result: str | None = None @@ -114,16 +115,17 @@ def _active_run_metadata(self) -> dict[str, Any]: def _setup_telemetry(self) -> None: global _OTEL_BOOTSTRAPPED, _OTEL_REMOTE_ENABLED + run_dir = self.get_run_dir() + self._events_file_path = run_dir / "events.jsonl" + if not self._telemetry_enabled: self._otel_tracer = None self._remote_export_enabled = False return - run_dir = self.get_run_dir() - self._events_file_path = run_dir / "events.jsonl" - base_url = (Config.get("traceloop_base_url") or "").strip() - api_key = (Config.get("traceloop_api_key") or "").strip() - headers_raw = Config.get("traceloop_headers") or "" + base_url = (Config.get_str("traceloop_base_url") or "").strip() + api_key = (Config.get_str("traceloop_api_key") or "").strip() + headers_raw = Config.get_str("traceloop_headers") or "" ( self._otel_tracer, @@ -192,9 +194,6 @@ def _emit_event( source: str = "strix.tracer", include_run_metadata: bool = False, ) -> None: - if not self._telemetry_enabled: - return - enriched_actor = self._enrich_actor(actor) sanitized_actor = self._sanitize_data(enriched_actor) if enriched_actor else None sanitized_payload = self._sanitize_data(payload) if payload is not None else None @@ -208,7 +207,7 @@ def _emit_event( if isinstance(current_context, SpanContext) and current_context.is_valid: parent_span_id = format_span_id(current_context.span_id) - if self._otel_tracer is not None: + if self._telemetry_enabled and self._otel_tracer is not None: try: with self._otel_tracer.start_as_current_span( f"strix.{event_type}", @@ -277,9 +276,6 @@ def set_run_name(self, run_name: str) -> None: self._emit_run_started_event() def _emit_run_started_event(self) -> None: - if not self._telemetry_enabled: - return - self._emit_event( "run.started", payload={ @@ -618,6 +614,24 @@ def save_run_data(self, mark_complete: bool = False) -> None: self.run_metadata["end_time"] = self.end_time self.run_metadata["status"] = "completed" + scan_state_file = run_dir / "scan_state.json" + with scan_state_file.open("w", encoding="utf-8") as f: + json.dump( + { + "run_metadata": self.run_metadata, + "scan_config": self.scan_config, + "scan_results": self.scan_results, + "final_scan_result": self.final_scan_result, + "vulnerability_reports": self.vulnerability_reports, + "agents": self.agents, + "tool_executions": self.tool_executions, + "chat_messages": self.chat_messages, + }, + f, + indent=2, + ensure_ascii=False, + ) + if self.final_scan_result: penetration_test_report_file = run_dir / "penetration_test_report.md" with penetration_test_report_file.open("w", encoding="utf-8") as f: @@ -825,9 +839,24 @@ def get_total_llm_stats(self) -> dict[str, Any]: def update_streaming_content(self, agent_id: str, content: str) -> None: self.streaming_content[agent_id] = content + if not content: + return + + if self._last_emitted_streaming_content.get(agent_id) == content: + return + + self._last_emitted_streaming_content[agent_id] = content + self._emit_event( + "chat.streaming", + actor={"agent_id": agent_id, "role": "assistant"}, + payload={"content": content}, + status="streaming", + source="strix.chat", + ) def clear_streaming_content(self, agent_id: str) -> None: self.streaming_content.pop(agent_id, None) + self._last_emitted_streaming_content.pop(agent_id, None) def get_streaming_content(self, agent_id: str) -> str | None: return self.streaming_content.get(agent_id) diff --git a/strix/tools/agents_graph/agents_graph_actions.py b/strix/tools/agents_graph/agents_graph_actions.py index d4425b734..6530616c7 100644 --- a/strix/tools/agents_graph/agents_graph_actions.py +++ b/strix/tools/agents_graph/agents_graph_actions.py @@ -21,6 +21,18 @@ _agent_states: dict[str, Any] = {} +def reset_agent_graph_state() -> None: + global _root_agent_id # noqa: PLW0603 + + _agent_graph["nodes"].clear() + _agent_graph["edges"].clear() + _agent_messages.clear() + _running_agents.clear() + _agent_instances.clear() + _agent_states.clear() + _root_agent_id = None + + def _run_agent_in_thread( agent: Any, state: Any, inherited_messages: list[dict[str, Any]] ) -> dict[str, Any]: diff --git a/strix/tools/executor.py b/strix/tools/executor.py index 1c2408777..321705259 100644 --- a/strix/tools/executor.py +++ b/strix/tools/executor.py @@ -1,14 +1,14 @@ import inspect -import os from typing import Any import httpx from strix.config import Config +from strix.runtime.context import is_sandbox_mode from strix.telemetry import posthog -if os.getenv("STRIX_SANDBOX_MODE", "false").lower() == "false": +if not is_sandbox_mode(): from strix.runtime import get_runtime from .argument_parser import convert_arguments @@ -21,14 +21,14 @@ ) -_SERVER_TIMEOUT = float(Config.get("strix_sandbox_execution_timeout") or "120") +_SERVER_TIMEOUT = float(Config.get_int("strix_sandbox_execution_timeout") or 120) SANDBOX_EXECUTION_TIMEOUT = _SERVER_TIMEOUT + 30 -SANDBOX_CONNECT_TIMEOUT = float(Config.get("strix_sandbox_connect_timeout") or "10") +SANDBOX_CONNECT_TIMEOUT = float(Config.get_int("strix_sandbox_connect_timeout") or 10) async def execute_tool(tool_name: str, agent_state: Any | None = None, **kwargs: Any) -> Any: execute_in_sandbox = should_execute_in_sandbox(tool_name) - sandbox_mode = os.getenv("STRIX_SANDBOX_MODE", "false").lower() == "true" + sandbox_mode = is_sandbox_mode() if execute_in_sandbox and not sandbox_mode: return await _execute_tool_in_sandbox(tool_name, agent_state, **kwargs) diff --git a/strix/tools/proxy/proxy_manager.py b/strix/tools/proxy/proxy_manager.py index c028e6c6c..ca69c10f8 100644 --- a/strix/tools/proxy/proxy_manager.py +++ b/strix/tools/proxy/proxy_manager.py @@ -1,5 +1,4 @@ import base64 -import os import re import time from typing import TYPE_CHECKING, Any @@ -11,6 +10,9 @@ from gql.transport.requests import RequestsHTTPTransport from requests.exceptions import ProxyError, RequestException, Timeout +from strix.config import Config +from strix.runtime.context import get_caido_api_token + if TYPE_CHECKING: from collections.abc import Callable @@ -27,7 +29,7 @@ def __init__(self, auth_token: str | None = None): "http": f"http://{host}:{CAIDO_PORT}", "https": f"http://{host}:{CAIDO_PORT}", } - self.auth_token = auth_token or os.getenv("CAIDO_API_TOKEN") + self.auth_token = auth_token or get_caido_api_token() or Config.get_str("caido_api_token") def _get_client(self) -> Client: transport = RequestsHTTPTransport( diff --git a/strix/tools/registry.py b/strix/tools/registry.py index 614197aae..852a3a36d 100644 --- a/strix/tools/registry.py +++ b/strix/tools/registry.py @@ -1,6 +1,5 @@ import inspect import logging -import os from collections.abc import Callable from functools import wraps from inspect import signature @@ -9,6 +8,8 @@ import defusedxml.ElementTree as DefusedET +from strix.config import Config +from strix.runtime.context import is_sandbox_mode from strix.utils.resource_paths import get_strix_resource_path @@ -150,26 +151,15 @@ def _get_schema_path(func: Callable[..., Any]) -> Path | None: def _is_sandbox_mode() -> bool: - return os.getenv("STRIX_SANDBOX_MODE", "false").lower() == "true" + return is_sandbox_mode() def _is_browser_disabled() -> bool: - if os.getenv("STRIX_DISABLE_BROWSER", "").lower() == "true": - return True - - from strix.config import Config - - val: str = Config.load().get("env", {}).get("STRIX_DISABLE_BROWSER", "") - return str(val).lower() == "true" + return Config.get_bool("strix_disable_browser") is True def _has_perplexity_api() -> bool: - if os.getenv("PERPLEXITY_API_KEY"): - return True - - from strix.config import Config - - return bool(Config.load().get("env", {}).get("PERPLEXITY_API_KEY")) + return bool(Config.get_str("perplexity_api_key")) def _should_register_tool( diff --git a/strix/tools/web_search/web_search_actions.py b/strix/tools/web_search/web_search_actions.py index e88eba707..4ac38d56b 100644 --- a/strix/tools/web_search/web_search_actions.py +++ b/strix/tools/web_search/web_search_actions.py @@ -1,8 +1,8 @@ -import os from typing import Any import requests +from strix.config import Config from strix.tools.registry import register_tool @@ -34,11 +34,11 @@ @register_tool(sandbox_execution=False, requires_web_search_mode=True) def web_search(query: str) -> dict[str, Any]: try: - api_key = os.getenv("PERPLEXITY_API_KEY") + api_key = Config.get_str("perplexity_api_key") if not api_key: return { "success": False, - "message": "PERPLEXITY_API_KEY environment variable not set", + "message": "Perplexity API key is not configured in the config file", "results": [], } diff --git a/tests/api/test_server.py b/tests/api/test_server.py new file mode 100644 index 000000000..60f476202 --- /dev/null +++ b/tests/api/test_server.py @@ -0,0 +1,138 @@ +import json + +from fastapi.testclient import TestClient + +from strix.api.models import ScanTaskRecord, ScanTaskRequest, ScanTaskResult, TaskStatus +from strix.api.server import create_app + + +class FakeStore: + def __init__(self, events_path: str): + self._events_path = events_path + + def events_file(self, _task_id: str) -> str: + from pathlib import Path + + return Path(self._events_path) + + +class FakeTaskManager: + def __init__( + self, + record: ScanTaskRecord, + result: ScanTaskResult, + events: list[dict[str, object]], + ) -> None: + self.record = record + self.result = result + self.events = events + self.store = FakeStore(record.events_path) + + def create_task(self, payload: ScanTaskRequest) -> ScanTaskRecord: + self.record.request = payload + return self.record + + def list_tasks(self) -> list[ScanTaskRecord]: + return [self.record] + + def get_task(self, task_id: str) -> ScanTaskRecord | None: + return self.record if task_id == self.record.task_id else None + + def get_result(self, task_id: str) -> ScanTaskResult | None: + return self.result if task_id == self.record.task_id else None + + def cancel_task(self, task_id: str) -> ScanTaskRecord | None: + if task_id != self.record.task_id: + return None + self.record.status = TaskStatus.CANCELLED + return self.record + + def get_events(self, task_id: str, limit: int = 200) -> list[dict[str, object]] | None: + return self.events[-limit:] if task_id == self.record.task_id else None + + def get_artifacts(self, task_id: str) -> list[str] | None: + return [self.record.events_path] if task_id == self.record.task_id else None + + def get_report_text(self, task_id: str) -> str | None: + return "# report" if task_id == self.record.task_id else None + + +def test_task_endpoints_and_sse_stream(tmp_path, write_config) -> None: + write_config({}) + + run_dir = tmp_path / "strix_runs" / "task_1234" + run_dir.mkdir(parents=True) + events_path = run_dir / "events.jsonl" + events = [ + { + "event_type": "chat.message", + "payload": {"content": "hello"}, + "run_id": "task_1234", + } + ] + events_path.write_text("\n".join(json.dumps(event) for event in events), encoding="utf-8") + + request = ScanTaskRequest(targets=["https://example.com"]) + record = ScanTaskRecord( + task_id="task_1234", + run_name="task_1234", + status=TaskStatus.COMPLETED, + created_at="2026-03-25T00:00:00+00:00", + finished_at="2026-03-25T00:01:00+00:00", + request=request, + run_dir=str(run_dir), + worker_log_path=str(run_dir / "worker.log"), + scan_state_path=str(run_dir / "scan_state.json"), + events_path=str(events_path), + ) + result = ScanTaskResult( + task=record, + scan_state={"final_scan_result": "done"}, + artifacts=[str(events_path)], + ) + fake_manager = FakeTaskManager(record=record, result=result, events=events) + + client = TestClient(create_app(task_manager=fake_manager)) + + demo_response = client.get("/demo") + assert demo_response.status_code == 200 + assert "Strix API Demo" in demo_response.text + + create_response = client.post( + "/api/v1/tasks", + json={"targets": ["https://example.com"], "scan_mode": "deep"}, + ) + assert create_response.status_code == 201 + assert create_response.json()["task"]["task_id"] == "task_1234" + + list_response = client.get("/api/v1/tasks") + assert list_response.status_code == 200 + assert len(list_response.json()["tasks"]) == 1 + + task_response = client.get("/api/v1/tasks/task_1234") + assert task_response.status_code == 200 + assert task_response.json()["task"]["task_id"] == "task_1234" + + result_response = client.get("/api/v1/tasks/task_1234/results") + assert result_response.status_code == 200 + assert result_response.json()["scan_state"]["final_scan_result"] == "done" + + events_response = client.get("/api/v1/tasks/task_1234/events") + assert events_response.status_code == 200 + assert events_response.json()["events"][0]["event_type"] == "chat.message" + + artifacts_response = client.get("/api/v1/tasks/task_1234/artifacts") + assert artifacts_response.status_code == 200 + assert artifacts_response.json()["artifacts"][0].endswith("events.jsonl") + + report_response = client.get("/api/v1/tasks/task_1234/report") + assert report_response.status_code == 200 + assert "# report" in report_response.text + + with client.stream("GET", "/api/v1/tasks/task_1234/stream") as response: + body = "".join(response.iter_text()) + + assert response.status_code == 200 + assert "event: stream.connected" in body + assert "event: chat.message" in body + assert "event: task.finished" in body diff --git a/tests/api/test_task_store.py b/tests/api/test_task_store.py new file mode 100644 index 000000000..c16018acd --- /dev/null +++ b/tests/api/test_task_store.py @@ -0,0 +1,27 @@ +from pathlib import Path + +from strix.api.models import ScanTaskRequest, TaskStatus +from strix.api.task_store import TaskStore + + +def test_refresh_marks_exited_worker_failed( + tmp_path: Path, + monkeypatch, +) -> None: + store = TaskStore(base_dir=tmp_path / "strix_runs") + record = store.create_record( + "task_1234", + ScanTaskRequest(targets=["https://example.com"]), + ) + record.pid = 4321 + record.status = TaskStatus.QUEUED + store.save(record) + + monkeypatch.setattr("strix.api.task_store._poll_process_exit_code", lambda _pid: 1) + + refreshed = store.refresh(record) + + assert refreshed.status == TaskStatus.FAILED + assert refreshed.exit_code == 1 + assert refreshed.finished_at is not None + assert refreshed.error == "Worker exited with code 1" diff --git a/tests/config/test_config_telemetry.py b/tests/config/test_config_telemetry.py index 89af42f95..6f9a3c1c2 100644 --- a/tests/config/test_config_telemetry.py +++ b/tests/config/test_config_telemetry.py @@ -1,6 +1,6 @@ import json -from strix.config.config import Config +from strix.config.config import Config, resolve_llm_config def test_traceloop_vars_are_tracked() -> None: @@ -13,7 +13,7 @@ def test_traceloop_vars_are_tracked() -> None: assert "TRACELOOP_HEADERS" in tracked -def test_apply_saved_uses_saved_traceloop_vars(monkeypatch, tmp_path) -> None: +def test_apply_saved_uses_legacy_env_style_config(monkeypatch, tmp_path) -> None: config_path = tmp_path / "cli-config.json" config_path.write_text( json.dumps( @@ -29,27 +29,60 @@ def test_apply_saved_uses_saved_traceloop_vars(monkeypatch, tmp_path) -> None: ) monkeypatch.setattr(Config, "_config_file_override", config_path) - monkeypatch.delenv("TRACELOOP_BASE_URL", raising=False) - monkeypatch.delenv("TRACELOOP_API_KEY", raising=False) - monkeypatch.delenv("TRACELOOP_HEADERS", raising=False) + monkeypatch.setattr(Config, "_cached_config", None) applied = Config.apply_saved() assert applied["TRACELOOP_BASE_URL"] == "https://otel.example.com" assert applied["TRACELOOP_API_KEY"] == "api-key" assert applied["TRACELOOP_HEADERS"] == "x-test=value" + assert Config.get_str("traceloop_base_url") == "https://otel.example.com" -def test_apply_saved_respects_existing_env_traceloop_vars(monkeypatch, tmp_path) -> None: - config_path = tmp_path / "cli-config.json" +def test_config_values_ignore_process_environment(monkeypatch, tmp_path) -> None: + config_path = tmp_path / "config.json" config_path.write_text( - json.dumps({"env": {"TRACELOOP_BASE_URL": "https://otel.example.com"}}), + json.dumps( + { + "telemetry": { + "traceloop_base_url": "https://file.example.com", + } + } + ), encoding="utf-8", ) monkeypatch.setattr(Config, "_config_file_override", config_path) + monkeypatch.setattr(Config, "_cached_config", None) monkeypatch.setenv("TRACELOOP_BASE_URL", "https://env.example.com") - applied = Config.apply_saved(force=False) + Config.reload() + + assert Config.get_str("traceloop_base_url") == "https://file.example.com" + + +def test_resolve_llm_config_reads_openai_compatible_provider(monkeypatch, tmp_path) -> None: + config_path = tmp_path / "config.json" + config_path.write_text( + json.dumps( + { + "llm": { + "model": "astron-code-latest", + "api_key": "test-key", + "api_base": "https://maas-coding-api.cn-huabei-1.xf-yun.com/v2", + "openai_compatible_provider": "AstronCodingPlan", + } + } + ), + encoding="utf-8", + ) + + monkeypatch.setattr(Config, "_config_file_override", config_path) + monkeypatch.setattr(Config, "_cached_config", None) + + model, api_key, api_base, provider = resolve_llm_config() - assert "TRACELOOP_BASE_URL" not in applied + assert model == "astron-code-latest" + assert api_key == "test-key" + assert api_base == "https://maas-coding-api.cn-huabei-1.xf-yun.com/v2" + assert provider == "AstronCodingPlan" diff --git a/tests/conftest.py b/tests/conftest.py index e698403a5..303e662e6 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1 +1,31 @@ """Pytest configuration and shared fixtures for Strix tests.""" + +from __future__ import annotations + +import json +from collections.abc import Callable, Iterator +from pathlib import Path +from typing import Any + +import pytest + +from strix.config import Config + + +@pytest.fixture +def write_config( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> Iterator[Callable[[dict[str, Any]], Path]]: + config_path = tmp_path / "config.json" + + def _write(data: dict[str, Any]) -> Path: + config_path.write_text(json.dumps(data), encoding="utf-8") + monkeypatch.setattr(Config, "_config_file_override", config_path) + Config.reload() + return config_path + + _write({}) + yield _write + monkeypatch.setattr(Config, "_config_file_override", None) + Config.reload() diff --git a/tests/llm/test_llm_otel.py b/tests/llm/test_llm_otel.py index a11ffa5ad..b125af732 100644 --- a/tests/llm/test_llm_otel.py +++ b/tests/llm/test_llm_otel.py @@ -5,9 +5,11 @@ from strix.llm.llm import LLM -def test_llm_does_not_modify_litellm_callbacks(monkeypatch: pytest.MonkeyPatch) -> None: - monkeypatch.setenv("STRIX_TELEMETRY", "1") - monkeypatch.setenv("STRIX_OTEL_TELEMETRY", "1") +def test_llm_does_not_modify_litellm_callbacks( + monkeypatch: pytest.MonkeyPatch, + write_config, +) -> None: + write_config({"telemetry": {"enabled": True, "otel_enabled": True}}) monkeypatch.setattr(litellm, "callbacks", ["custom-callback"]) llm = LLM(LLMConfig(model_name="openai/gpt-5.4"), agent_name=None) diff --git a/tests/llm/test_model_resolution.py b/tests/llm/test_model_resolution.py new file mode 100644 index 000000000..802a5d8d0 --- /dev/null +++ b/tests/llm/test_model_resolution.py @@ -0,0 +1,59 @@ +from litellm import get_llm_provider +from litellm.llms.openai_like.json_loader import JSONProviderRegistry + +from strix.llm.utils import resolve_strix_model + + +def test_resolve_custom_openai_compatible_model_with_api_base() -> None: + api_model, canonical_model = resolve_strix_model( + "AstronCodingPlan/astron-code-latest", + api_base="https://maas-coding-api.cn-huabei-1.xf-yun.com/v2", + ) + + assert api_model == "AstronCodingPlan/astron-code-latest" + assert canonical_model == "AstronCodingPlan/astron-code-latest" + assert JSONProviderRegistry.exists("AstronCodingPlan") + + resolved_model, provider, dynamic_api_key, api_base = get_llm_provider( + model=api_model, + api_base="https://maas-coding-api.cn-huabei-1.xf-yun.com/v2", + api_key="test-key", + ) + + assert resolved_model == "astron-code-latest" + assert provider == "AstronCodingPlan" + assert dynamic_api_key == "test-key" + assert api_base == "https://maas-coding-api.cn-huabei-1.xf-yun.com/v2" + + +def test_resolve_explicit_openai_compatible_provider_from_config() -> None: + api_model, canonical_model = resolve_strix_model( + "astron-code-latest", + api_base="https://maas-coding-api.cn-huabei-1.xf-yun.com/v2", + openai_compatible_provider="AstronCodingPlan", + ) + + assert api_model == "AstronCodingPlan/astron-code-latest" + assert canonical_model == "AstronCodingPlan/astron-code-latest" + assert JSONProviderRegistry.exists("AstronCodingPlan") + + resolved_model, provider, dynamic_api_key, api_base = get_llm_provider( + model=api_model, + api_base="https://maas-coding-api.cn-huabei-1.xf-yun.com/v2", + api_key="test-key", + ) + + assert resolved_model == "astron-code-latest" + assert provider == "AstronCodingPlan" + assert dynamic_api_key == "test-key" + assert api_base == "https://maas-coding-api.cn-huabei-1.xf-yun.com/v2" + + +def test_resolve_known_provider_model_with_api_base_keeps_provider() -> None: + api_model, canonical_model = resolve_strix_model( + "anthropic/claude-sonnet-4-6", + api_base="https://example.com/v1", + ) + + assert api_model == "anthropic/claude-sonnet-4-6" + assert canonical_model == "anthropic/claude-sonnet-4-6" diff --git a/tests/telemetry/test_flags.py b/tests/telemetry/test_flags.py index a7f8e4350..1d25c95b3 100644 --- a/tests/telemetry/test_flags.py +++ b/tests/telemetry/test_flags.py @@ -1,28 +1,22 @@ from strix.telemetry.flags import is_otel_enabled, is_posthog_enabled -def test_flags_fallback_to_strix_telemetry(monkeypatch) -> None: - monkeypatch.delenv("STRIX_OTEL_TELEMETRY", raising=False) - monkeypatch.delenv("STRIX_POSTHOG_TELEMETRY", raising=False) - monkeypatch.setenv("STRIX_TELEMETRY", "0") +def test_flags_fallback_to_strix_telemetry(write_config) -> None: + write_config({"telemetry": {"enabled": False}}) assert is_otel_enabled() is False assert is_posthog_enabled() is False -def test_otel_flag_overrides_global_telemetry(monkeypatch) -> None: - monkeypatch.setenv("STRIX_TELEMETRY", "0") - monkeypatch.setenv("STRIX_OTEL_TELEMETRY", "1") - monkeypatch.delenv("STRIX_POSTHOG_TELEMETRY", raising=False) +def test_otel_flag_overrides_global_telemetry(write_config) -> None: + write_config({"telemetry": {"enabled": False, "otel_enabled": True}}) assert is_otel_enabled() is True assert is_posthog_enabled() is False -def test_posthog_flag_overrides_global_telemetry(monkeypatch) -> None: - monkeypatch.setenv("STRIX_TELEMETRY", "0") - monkeypatch.setenv("STRIX_POSTHOG_TELEMETRY", "1") - monkeypatch.delenv("STRIX_OTEL_TELEMETRY", raising=False) +def test_posthog_flag_overrides_global_telemetry(write_config) -> None: + write_config({"telemetry": {"enabled": False, "posthog_enabled": True}}) assert is_otel_enabled() is False assert is_posthog_enabled() is True diff --git a/tests/telemetry/test_tracer.py b/tests/telemetry/test_tracer.py index 10f887e9a..d06d7978b 100644 --- a/tests/telemetry/test_tracer.py +++ b/tests/telemetry/test_tracer.py @@ -18,17 +18,12 @@ def _load_events(events_path: Path) -> list[dict[str, Any]]: @pytest.fixture(autouse=True) -def _reset_tracer_globals(monkeypatch) -> None: +def _reset_tracer_globals(monkeypatch, write_config) -> None: monkeypatch.setattr(tracer_module, "_global_tracer", None) monkeypatch.setattr(tracer_module, "_OTEL_BOOTSTRAPPED", False) monkeypatch.setattr(tracer_module, "_OTEL_REMOTE_ENABLED", False) telemetry_utils.reset_events_write_locks() - monkeypatch.delenv("STRIX_TELEMETRY", raising=False) - monkeypatch.delenv("STRIX_OTEL_TELEMETRY", raising=False) - monkeypatch.delenv("STRIX_POSTHOG_TELEMETRY", raising=False) - monkeypatch.delenv("TRACELOOP_BASE_URL", raising=False) - monkeypatch.delenv("TRACELOOP_API_KEY", raising=False) - monkeypatch.delenv("TRACELOOP_HEADERS", raising=False) + write_config({}) def test_tracer_local_mode_writes_jsonl_with_correlation(monkeypatch, tmp_path) -> None: @@ -88,7 +83,7 @@ def test_tracer_redacts_sensitive_payloads(monkeypatch, tmp_path) -> None: assert "[REDACTED]" in serialized -def test_tracer_remote_mode_configures_traceloop_export(monkeypatch, tmp_path) -> None: +def test_tracer_remote_mode_configures_traceloop_export(monkeypatch, tmp_path, write_config) -> None: monkeypatch.chdir(tmp_path) class FakeTraceloop: @@ -103,9 +98,15 @@ def set_association_properties(properties: dict[str, Any]) -> None: # noqa: ARG return None monkeypatch.setattr(tracer_module, "Traceloop", FakeTraceloop) - monkeypatch.setenv("TRACELOOP_BASE_URL", "https://otel.example.com") - monkeypatch.setenv("TRACELOOP_API_KEY", "test-api-key") - monkeypatch.setenv("TRACELOOP_HEADERS", '{"x-custom":"header"}') + write_config( + { + "telemetry": { + "traceloop_base_url": "https://otel.example.com", + "traceloop_api_key": "test-api-key", + "traceloop_headers": '{"x-custom":"header"}', + } + } + ) tracer = Tracer("remote-observability") set_global_tracer(tracer) @@ -156,12 +157,18 @@ def set_association_properties(properties: dict[str, Any]) -> None: # noqa: ARG assert tracer._remote_export_enabled is False -def test_otlp_fallback_includes_auth_and_custom_headers(monkeypatch, tmp_path) -> None: +def test_otlp_fallback_includes_auth_and_custom_headers(monkeypatch, tmp_path, write_config) -> None: monkeypatch.chdir(tmp_path) monkeypatch.setattr(tracer_module, "Traceloop", None) - monkeypatch.setenv("TRACELOOP_BASE_URL", "https://otel.example.com") - monkeypatch.setenv("TRACELOOP_API_KEY", "test-api-key") - monkeypatch.setenv("TRACELOOP_HEADERS", '{"x-custom":"header"}') + write_config( + { + "telemetry": { + "traceloop_base_url": "https://otel.example.com", + "traceloop_api_key": "test-api-key", + "traceloop_headers": '{"x-custom":"header"}', + } + } + ) captured: dict[str, Any] = {} @@ -337,9 +344,9 @@ def set_association_properties(properties: dict[str, Any]) -> None: assert FakeTraceloop.associations[-1]["run_name"] == "renamed-run" -def test_events_write_locks_are_scoped_by_events_file(monkeypatch, tmp_path) -> None: +def test_events_write_locks_are_scoped_by_events_file(monkeypatch, tmp_path, write_config) -> None: monkeypatch.chdir(tmp_path) - monkeypatch.setenv("STRIX_TELEMETRY", "0") + write_config({"telemetry": {"enabled": False}}) tracer_one = Tracer("lock-run-a") tracer_two = Tracer("lock-run-b") @@ -352,9 +359,9 @@ def test_events_write_locks_are_scoped_by_events_file(monkeypatch, tmp_path) -> assert lock_a_from_one is not lock_b -def test_tracer_skips_jsonl_when_telemetry_disabled(monkeypatch, tmp_path) -> None: +def test_tracer_skips_jsonl_when_telemetry_disabled(monkeypatch, tmp_path, write_config) -> None: monkeypatch.chdir(tmp_path) - monkeypatch.setenv("STRIX_TELEMETRY", "0") + write_config({"telemetry": {"enabled": False}}) tracer = Tracer("telemetry-disabled") set_global_tracer(tracer) @@ -365,10 +372,9 @@ def test_tracer_skips_jsonl_when_telemetry_disabled(monkeypatch, tmp_path) -> No assert not events_path.exists() -def test_tracer_otel_flag_overrides_global_telemetry(monkeypatch, tmp_path) -> None: +def test_tracer_otel_flag_overrides_global_telemetry(monkeypatch, tmp_path, write_config) -> None: monkeypatch.chdir(tmp_path) - monkeypatch.setenv("STRIX_TELEMETRY", "0") - monkeypatch.setenv("STRIX_OTEL_TELEMETRY", "1") + write_config({"telemetry": {"enabled": False, "otel_enabled": True}}) tracer = Tracer("otel-enabled") set_global_tracer(tracer) diff --git a/tests/tools/test_tool_registration_modes.py b/tests/tools/test_tool_registration_modes.py index b50d26725..aca245a28 100644 --- a/tests/tools/test_tool_registration_modes.py +++ b/tests/tools/test_tool_registration_modes.py @@ -4,11 +4,16 @@ from typing import Any from strix.config import Config +from strix.runtime.context import configure_runtime_context from strix.tools.registry import clear_registry -def _empty_config_load(_cls: type[Config]) -> dict[str, dict[str, str]]: - return {"env": {}} +def _config_without_web_search(_cls: type[Config]) -> dict[str, Any]: + return { + "features": { + "disable_browser": True, + } + } def _reload_tools_module() -> ModuleType: @@ -24,10 +29,8 @@ def _reload_tools_module() -> ModuleType: def test_non_sandbox_registers_agents_graph_but_not_browser_or_web_search_when_disabled( monkeypatch: Any, ) -> None: - monkeypatch.setenv("STRIX_SANDBOX_MODE", "false") - monkeypatch.setenv("STRIX_DISABLE_BROWSER", "true") - monkeypatch.delenv("PERPLEXITY_API_KEY", raising=False) - monkeypatch.setattr(Config, "load", classmethod(_empty_config_load)) + configure_runtime_context(sandbox_mode=False, caido_api_token=None) + monkeypatch.setattr(Config, "load", classmethod(_config_without_web_search)) tools = _reload_tools_module() names = set(tools.get_tool_names()) @@ -40,10 +43,8 @@ def test_non_sandbox_registers_agents_graph_but_not_browser_or_web_search_when_d def test_sandbox_registers_sandbox_tools_but_not_non_sandbox_tools( monkeypatch: Any, ) -> None: - monkeypatch.setenv("STRIX_SANDBOX_MODE", "true") - monkeypatch.setenv("STRIX_DISABLE_BROWSER", "true") - monkeypatch.delenv("PERPLEXITY_API_KEY", raising=False) - monkeypatch.setattr(Config, "load", classmethod(_empty_config_load)) + configure_runtime_context(sandbox_mode=True, caido_api_token=None) + monkeypatch.setattr(Config, "load", classmethod(_config_without_web_search)) tools = _reload_tools_module() names = set(tools.get_tool_names()) @@ -61,10 +62,8 @@ def test_sandbox_registers_sandbox_tools_but_not_non_sandbox_tools( def test_load_skill_import_does_not_register_create_agent_in_sandbox( monkeypatch: Any, ) -> None: - monkeypatch.setenv("STRIX_SANDBOX_MODE", "true") - monkeypatch.setenv("STRIX_DISABLE_BROWSER", "true") - monkeypatch.delenv("PERPLEXITY_API_KEY", raising=False) - monkeypatch.setattr(Config, "load", classmethod(_empty_config_load)) + configure_runtime_context(sandbox_mode=True, caido_api_token=None) + monkeypatch.setattr(Config, "load", classmethod(_config_without_web_search)) clear_registry() for name in list(sys.modules):