Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,11 @@ name: strix-penetration-test
on:
pull_request:

permissions:
security-events: write
actions: read
contents: read

jobs:
security-scan:
runs-on: ubuntu-latest
Expand All @@ -214,13 +219,23 @@ jobs:
STRIX_LLM: ${{ secrets.STRIX_LLM }}
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}

run: strix -n -t ./ --scan-mode quick
run: strix -n -t ./ --scan-mode quick --sarif-output results.sarif

- name: Upload SARIF to GitHub code scanning
if: always()
uses: github/codeql-action/upload-sarif@v4
with:
sarif_file: results.sarif
```

> [!TIP]
> In CI pull request runs, Strix automatically scopes quick reviews to changed files.
> If diff-scope cannot resolve, ensure checkout uses full history (`fetch-depth: 0`) or pass
> `--diff-base` explicitly.
> `--sarif-output` writes GitHub-compatible SARIF before Strix exits with code `2` for
> confirmed vulnerabilities, so `if: always()` preserves code scanning upload on findings.
> Use `--sarif` to write `strix_runs/<run-name>/results.sarif`, or
> `--sarif-output <path>` to choose the upload path.

### Configuration

Expand Down
35 changes: 31 additions & 4 deletions strix/interface/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,12 @@
import os
import shutil
import sys
from importlib.metadata import version
from pathlib import Path
from typing import Any

import litellm
from docker.errors import DockerException
from docker.errors import DockerException # type: ignore[import-untyped]
from rich.console import Console
from rich.panel import Panel
from rich.text import Text
Expand Down Expand Up @@ -44,6 +45,7 @@
)
from strix.runtime.docker_runtime import HOST_GATEWAY_HOSTNAME # noqa: E402
from strix.telemetry import posthog # noqa: E402
from strix.telemetry.sarif import write_sarif_report # noqa: E402
from strix.telemetry.tracer import get_global_tracer # noqa: E402


Expand Down Expand Up @@ -257,8 +259,6 @@ async def warm_up_llm() -> None:

def get_version() -> str:
try:
from importlib.metadata import version

return version("strix-agent")
except Exception: # noqa: BLE001
return "unknown"
Expand Down Expand Up @@ -387,6 +387,21 @@ def parse_arguments() -> argparse.Namespace:
help="Path to a custom config file (JSON) to use instead of ~/.strix/cli-config.json",
)

parser.add_argument(
"--sarif",
action="store_true",
help="Write GitHub code scanning SARIF results after the scan completes.",
)

parser.add_argument(
"--sarif-output",
type=str,
help=(
"Path for SARIF output. Defaults to strix_runs/<run-name>/results.sarif "
"when --sarif is enabled."
),
)

args = parser.parse_args()

if args.instruction and args.instruction_file:
Expand Down Expand Up @@ -453,7 +468,7 @@ def display_completion_message(args: argparse.Namespace, results_path: Path) ->

stats_text = build_final_stats_text(tracer)

panel_parts = [completion_text, "\n\n", target_text]
panel_parts: list[str | Text] = [completion_text, "\n\n", target_text]

if stats_text.plain:
panel_parts.extend(["\n", stats_text])
Expand Down Expand Up @@ -484,6 +499,17 @@ def display_completion_message(args: argparse.Namespace, results_path: Path) ->
console.print()


def write_requested_sarif_output(args: argparse.Namespace, results_path: Path) -> Path | None:
if not args.sarif and not args.sarif_output:
return None

output_path = Path(args.sarif_output) if args.sarif_output else results_path / "results.sarif"
tracer = get_global_tracer()
vulnerability_reports = tracer.vulnerability_reports if tracer else []
write_sarif_report(output_path, vulnerability_reports, tool_version=get_version())
return output_path


def pull_docker_image() -> None:
console = Console()
client = check_docker_connection()
Expand Down Expand Up @@ -632,6 +658,7 @@ def main() -> None: # noqa: PLR0912, PLR0915
posthog.end(tracer, exit_reason=exit_reason)

results_path = Path("strix_runs") / args.run_name
write_requested_sarif_output(args, results_path)
display_completion_message(args, results_path)

if args.non_interactive:
Expand Down
256 changes: 256 additions & 0 deletions strix/telemetry/sarif.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,256 @@
from __future__ import annotations

import json
from pathlib import PurePosixPath
from typing import TYPE_CHECKING, Any, cast


if TYPE_CHECKING:
from pathlib import Path


SARIF_SCHEMA = "https://json.schemastore.org/sarif-2.1.0.json"
SARIF_VERSION = "2.1.0"
TOOL_NAME = "Strix"
TOOL_INFORMATION_URI = "https://strix.ai"


def build_sarif_report(
vulnerability_reports: list[dict[str, Any]],
*,
tool_version: str | None = None,
) -> dict[str, Any]:
rules_by_id: dict[str, dict[str, Any]] = {}
results: list[dict[str, Any]] = []
locationless_findings: list[dict[str, Any]] = []
dropped_unsafe_location_findings: list[dict[str, Any]] = []

for report in vulnerability_reports:
locations, dropped_location_count = _build_locations(report.get("code_locations"))
if dropped_location_count:
dropped_unsafe_location_findings.append(
_dropped_location_summary(report, dropped_location_count)
)
if not locations:
locationless_findings.append(_locationless_summary(report))
continue

rule_id = _rule_id(report)
rules_by_id.setdefault(rule_id, _build_rule(rule_id, report))
results.append(_build_result(rule_id, report, locations))

driver: dict[str, Any] = {
"name": TOOL_NAME,
"informationUri": TOOL_INFORMATION_URI,
"rules": list(rules_by_id.values()),
}
if tool_version:
driver["version"] = tool_version

run: dict[str, Any] = {
"tool": {"driver": driver},
"results": results,
}
if locationless_findings:
run["properties"] = {
"locationlessFindingCount": len(locationless_findings),
"locationlessFindings": locationless_findings,
}
if dropped_unsafe_location_findings:
properties = run.setdefault("properties", {})
properties["droppedUnsafeLocationCount"] = sum(
finding["droppedLocationCount"] for finding in dropped_unsafe_location_findings
)
properties["droppedUnsafeLocationFindings"] = dropped_unsafe_location_findings

return {
"version": SARIF_VERSION,
"$schema": SARIF_SCHEMA,
"runs": [run],
}


def write_sarif_report(
output_path: Path,
vulnerability_reports: list[dict[str, Any]],
*,
tool_version: str | None = None,
) -> None:
output_path.parent.mkdir(parents=True, exist_ok=True)
sarif = build_sarif_report(vulnerability_reports, tool_version=tool_version)
with output_path.open("w", encoding="utf-8") as sarif_file:
json.dump(sarif, sarif_file, ensure_ascii=False, indent=2)
sarif_file.write("\n")


def _build_rule(rule_id: str, report: dict[str, Any]) -> dict[str, Any]:
title = _string_value(report.get("title")) or rule_id
full_description = _string_value(report.get("description")) or title
rule: dict[str, Any] = {
"id": rule_id,
"name": title,
"shortDescription": {"text": title},
"fullDescription": {"text": full_description},
"help": {"text": _help_text(report, full_description)},
}

tags = [_string_value(report.get(key)) for key in ("cwe", "cve")]
rule_tags = [tag for tag in tags if tag]
if rule_tags:
rule["properties"] = {"tags": rule_tags}

return rule


def _build_result(
rule_id: str,
report: dict[str, Any],
locations: list[dict[str, Any]],
) -> dict[str, Any]:
title = _string_value(report.get("title")) or rule_id
return {
"ruleId": rule_id,
"level": _sarif_level(report.get("severity")),
"message": {"text": title},
"locations": locations,
"properties": _result_properties(report),
}


def _result_properties(report: dict[str, Any]) -> dict[str, Any]:
properties: dict[str, Any] = {}
for key in (
"id",
"severity",
"cvss",
"target",
"endpoint",
"method",
"cve",
"cwe",
"impact",
"remediation_steps",
):
value = report.get(key)
if value not in (None, ""):
properties[key] = value
return properties


def _build_locations(raw_locations: Any) -> tuple[list[dict[str, Any]], int]:
if not isinstance(raw_locations, list):
return [], 0

raw_locations_list = cast("list[Any]", raw_locations) # type: ignore[redundant-cast]
locations: list[dict[str, Any]] = []
dropped_location_count = 0
for raw_location in raw_locations_list:
if not isinstance(raw_location, dict):
dropped_location_count += 1
continue

location = cast("dict[str, Any]", raw_location)
file_path = _string_value(location.get("file"))
start_line = location.get("start_line")
end_line = location.get("end_line")
if not file_path or type(start_line) is not int or start_line < 1:
dropped_location_count += 1
continue
uri = _sarif_uri(file_path)
if uri is None:
dropped_location_count += 1
continue

region: dict[str, Any] = {"startLine": start_line}
if type(end_line) is not int or end_line < start_line:
dropped_location_count += 1
continue
region["endLine"] = end_line

snippet = _string_value(location.get("snippet"))
if snippet:
region["snippet"] = {"text": snippet}

locations.append(
{
"physicalLocation": {
"artifactLocation": {"uri": uri},
"region": region,
}
}
)

return locations, dropped_location_count


def _rule_id(report: dict[str, Any]) -> str:
for key in ("cwe", "cve", "id"):
value = _string_value(report.get(key))
if value:
return value

title = _string_value(report.get("title")) or "strix-finding"
return _slugify(title)


def _sarif_level(severity: Any) -> str:
normalized = (_string_value(severity) or "").lower()
if normalized in {"critical", "high"}:
return "error"
if normalized == "medium":
return "warning"
return "note"


def _sarif_uri(file_path: str) -> str | None:
uri = PurePosixPath(file_path.replace("\\", "/")).as_posix()
parts = PurePosixPath(uri).parts
if not uri or uri.startswith("/") or not parts:
return None
if ":" in parts[0] or any(part == ".." for part in parts):
return None
return uri


def _string_value(value: Any) -> str | None:
if isinstance(value, str):
stripped = value.strip()
return stripped or None
return None


def _slugify(value: str) -> str:
chars = [char.lower() if char.isalnum() else "-" for char in value]
slug = "-".join(part for part in "".join(chars).split("-") if part)
return slug or "strix-finding"


def _help_text(report: dict[str, Any], fallback: str) -> str:
sections = [
_string_value(report.get("description")),
_string_value(report.get("impact")),
_string_value(report.get("remediation_steps")),
]
help_text = "\n\n".join(section for section in sections if section)
return help_text or fallback


def _locationless_summary(report: dict[str, Any]) -> dict[str, Any]:
summary: dict[str, Any] = {}
for key in ("id", "title", "severity", "cwe", "cve", "target", "endpoint", "method"):
value = report.get(key)
if value not in (None, ""):
summary[key] = value
return summary


def _dropped_location_summary(
report: dict[str, Any],
dropped_location_count: int,
) -> dict[str, Any]:
summary: dict[str, Any] = {"droppedLocationCount": dropped_location_count}
for key in ("id", "title"):
value = report.get(key)
if value not in (None, ""):
summary[key] = value
return summary
Loading