diff --git a/strix/agents/base_agent.py b/strix/agents/base_agent.py index c759f9a94..85da3dc6f 100644 --- a/strix/agents/base_agent.py +++ b/strix/agents/base_agent.py @@ -116,6 +116,17 @@ def __init__(self, config: dict[str, Any]): self._add_to_agents_graph() + self._conv_log = None + if self.state.parent_id is None and tracer: + from strix.telemetry.conversation_log import ConversationLog + + self._conv_log = ConversationLog(tracer.get_run_dir(), tracer.run_name or "") + is_resume = config.get("state") is not None + if not is_resume: + self._conv_log.write_session_start(tracer.scan_config or {}) + self.state.set_conversation_log(self._conv_log) + tracer._conversation_log = self._conv_log + def _add_to_agents_graph(self) -> None: from strix.tools.agents_graph import agents_graph_actions @@ -216,6 +227,11 @@ async def agent_loop(self, task: str) -> dict[str, Any]: # noqa: PLR0912, PLR09 should_finish = await iteration_task self._current_task = None + if self._conv_log is not None: + self._conv_log.append_iteration_end( + self.state.iteration, self.state.context, self.state.completed + ) + if should_finish is None and self.interactive: await self._enter_waiting_state(tracer, text_response=True) continue diff --git a/strix/agents/state.py b/strix/agents/state.py index da04ee7f9..7c66ee03f 100644 --- a/strix/agents/state.py +++ b/strix/agents/state.py @@ -1,8 +1,11 @@ import uuid from datetime import UTC, datetime -from typing import Any +from typing import TYPE_CHECKING, Any -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, PrivateAttr + +if TYPE_CHECKING: + from strix.telemetry.conversation_log import ConversationLog def _generate_agent_id() -> str: @@ -40,6 +43,11 @@ class AgentState(BaseModel): errors: list[str] = Field(default_factory=list) + _conversation_log: "ConversationLog | None" = PrivateAttr(default=None) + + def set_conversation_log(self, log: "ConversationLog") -> None: + self._conversation_log = log + def 
increment_iteration(self) -> None: self.iteration += 1 self.last_updated = datetime.now(UTC).isoformat() @@ -52,6 +60,13 @@ def add_message( message["thinking_blocks"] = thinking_blocks self.messages.append(message) self.last_updated = datetime.now(UTC).isoformat() + if self._conversation_log is not None: + self._conversation_log.append_message( + role=role, + content=content, + iteration=self.iteration, + thinking_blocks=thinking_blocks, + ) def add_action(self, action: dict[str, Any]) -> None: self.actions_taken.append( diff --git a/strix/interface/cli.py b/strix/interface/cli.py index ec853b3b8..6e95b9af6 100644 --- a/strix/interface/cli.py +++ b/strix/interface/cli.py @@ -61,9 +61,11 @@ async def run_cli(args: Any) -> None: # noqa: PLR0915 padding=(1, 2), ) - console.print("\n") - console.print(startup_panel) - console.print() + is_resume = getattr(args, "resumed_state", None) is not None + if not is_resume: + console.print("\n") + console.print(startup_panel) + console.print() scan_mode = getattr(args, "scan_mode", "deep") @@ -72,6 +74,7 @@ async def run_cli(args: Any) -> None: # noqa: PLR0915 "targets": args.targets_info, "user_instructions": args.instruction or "", "run_name": args.run_name, + "scan_mode": scan_mode, "diff_scope": getattr(args, "diff_scope", {"active": False}), } @@ -87,6 +90,11 @@ async def run_cli(args: Any) -> None: # noqa: PLR0915 if getattr(args, "local_sources", None): agent_config["local_sources"] = args.local_sources + if is_resume: + from strix.sessions import merge_into_agent_config + + merge_into_agent_config(agent_config, args.resume_bundle) + tracer = Tracer(args.run_name) tracer.set_scan_config(scan_config) diff --git a/strix/interface/main.py b/strix/interface/main.py index bc88da673..2cdc77060 100644 --- a/strix/interface/main.py +++ b/strix/interface/main.py @@ -310,11 +310,34 @@ def parse_arguments() -> argparse.Namespace: "-t", "--target", type=str, - required=True, + required=False, action="append", help="Target to test 
(URL, repository, local directory path, domain name, or IP address). " "Can be specified multiple times for multi-target scans.", ) + + parser.add_argument( + "--resume", + nargs="?", + const="__PICK__", + default=None, + metavar="RUN_NAME", + help="Resume a past scan session. Omit RUN_NAME to open an interactive picker.", + ) + + parser.add_argument( + "-c", + "--continue", + dest="continue_recent", + action="store_true", + help="Resume the most recent scan session.", + ) + + parser.add_argument( + "--list-sessions", + action="store_true", + help="List all past scan sessions and exit.", + ) parser.add_argument( "--instruction", type=str, @@ -389,6 +412,11 @@ def parse_arguments() -> argparse.Namespace: args = parser.parse_args() + # Resume flags make --target optional + _resume_flags = args.resume or args.continue_recent or args.list_sessions + if not args.target and not _resume_flags: + parser.error("the following arguments are required: -t/--target") + if args.instruction and args.instruction_file: parser.error( "Cannot specify both --instruction and --instruction-file. Use one or the other." 
def _handle_resume_bootstrap(args: argparse.Namespace) -> None:
    """Resolve the resume/list-sessions flags early, before docker/env setup.

    Behaviour by flag:

    * ``--list-sessions`` prints the session table and exits with status 0.
    * ``--continue`` resumes the most recent session that ``most_recent()`` returns.
    * ``--resume`` without a name sets ``args.resume_pick`` so the TUI can show
      its picker; in non-interactive mode this is an error.
    * ``--resume RUN_NAME`` resumes the named session.

    When a session is resolved, ``apply_resume_to_args`` populates *args* and a
    one-line status message is printed. Any failure prints a message and exits
    with status 1.
    """
    from rich.console import Console

    from strix.sessions import ResumeError, apply_resume_to_args, load_resume_bundle, most_recent
    from strix.sessions.listing import list_sessions

    console = Console()

    # --list-sessions: print table and exit
    if args.list_sessions:
        from strix.interface.session_picker_cli import print_session_table

        print_session_table(list_sessions(), console)
        sys.exit(0)

    # Work out which run (if any) to load; the actual load happens once, below.
    run_name: str | None = None

    if args.continue_recent:
        row = most_recent()
        if row is None:
            console.print("[red]No resumable sessions found.[/red]")
            sys.exit(1)
        run_name = row.run_name

    elif args.resume == "__PICK__":
        if args.non_interactive:
            console.print(
                "[red]Interactive session picker unavailable in non-interactive mode.[/red]\n"
                "Use [bold]--resume RUN_NAME[/bold] or [bold]--continue[/bold] instead."
            )
            sys.exit(1)
        # TUI mode: defer to the TUI to push SessionPickerScreen
        args.resume_pick = True
        return

    elif args.resume:
        run_name = args.resume

    if run_name is None:
        return

    try:
        bundle = load_resume_bundle(run_name)
    except ResumeError as exc:
        console.print(f"[red]Resume failed:[/red] {exc}")
        sys.exit(1)

    apply_resume_to_args(args, bundle)
    mode_label = "Reopening" if bundle.mode == "reopen" else "Resuming"
    console.print(
        f"[green]{mode_label}[/green] session [bold cyan]{bundle.run_name}[/bold cyan] "
        f"— iteration [bold]{bundle.agent_state.iteration}[/bold], "
        f"mode [bold]{bundle.mode}[/bold]"
    )
def print_session_table(
    rows: list[SessionRow], console: Console | None = None
) -> None:
    """Print a summary table of *rows*, or a dim notice when there are none.

    A fresh ``Console`` is created when *console* is not supplied. At most two
    targets are shown per session; any remainder is summarised as ``+N``.
    """
    out = console or Console()
    if not rows:
        out.print("[dim]No sessions found.[/dim]")
        return

    # (header, add_column keyword arguments), in display order.
    column_specs: list[tuple[str, dict[str, Any]]] = [
        ("#", {"style": "dim", "width": 3, "justify": "right"}),
        ("Run", {"style": "bold cyan", "min_width": 20}),
        ("Status", {"width": 10}),
        ("Mode", {"width": 9}),
        ("Targets", {"min_width": 24}),
        ("Iter", {"width": 6, "justify": "right"}),
        ("Vulns", {"width": 5, "justify": "right"}),
        ("Updated", {"width": 16}),
    ]
    status_colours = {"running": "yellow", "completed": "green", "errored": "red"}

    grid = Table(show_header=True, header_style="bold white", box=None, padding=(0, 1))
    for header, options in column_specs:
        grid.add_column(header, **options)

    for position, session in enumerate(rows, 1):
        info = session.meta
        state = info.get("status", "unknown")
        colour = status_colours.get(state, "dim")

        targets = info.get("targets", [])
        shown = [
            t.get("original", str(t)) if isinstance(t, dict) else str(t)
            for t in targets[:2]
        ]
        target_label = ", ".join(shown)
        if len(targets) > 2:
            target_label += f" +{len(targets) - 2}"

        age_label = _relative_time(session.last_updated_dt)

        grid.add_row(
            str(position),
            session.run_name,
            Text(state, style=colour),
            info.get("scan_mode", ""),
            target_label or "[dim]unknown[/dim]",
            str(info.get("iteration_count", "?")),
            str(info.get("vulnerability_count", "?")),
            age_label,
        )

    out.print(grid)
+ ) + + console = Console() + rows = [r for r in list_sessions(query=query) if r.has_conversation_log] + + if not rows: + console.print( + "\n[yellow]No resumable sessions found.[/yellow] " + "Run a scan first, then use --resume to continue it.\n" + ) + return None + + console.print() + print_session_table(rows, console) + console.print() + + while True: + choice = Prompt.ask( + "[bold]Select session[/bold] [dim](# or run name, q to quit)[/dim]", + console=console, + ).strip() + + if choice.lower() in ("q", "quit", ""): + return None + + # Numeric index + if choice.isdigit(): + idx = int(choice) - 1 + if 0 <= idx < len(rows): + return rows[idx] + console.print(f"[red]Invalid number. Enter 1–{len(rows)}.[/red]") + continue + + # Run name + match = next((r for r in rows if r.run_name == choice), None) + if match: + return match + + console.print("[red]Session not found. Try the number or exact run name.[/red]") + + +# ------------------------------------------------------------------ +# Helpers +# ------------------------------------------------------------------ + + +def _relative_time(dt: Any) -> str: + from datetime import UTC, datetime + + try: + now = datetime.now(UTC) + delta = now - dt + secs = int(delta.total_seconds()) + if secs < 60: + return "just now" + if secs < 3600: + return f"{secs // 60}m ago" + if secs < 86400: + return f"{secs // 3600}h ago" + return f"{secs // 86400}d ago" + except Exception: + return "" diff --git a/strix/interface/session_picker_tui.py b/strix/interface/session_picker_tui.py new file mode 100644 index 000000000..8192927dd --- /dev/null +++ b/strix/interface/session_picker_tui.py @@ -0,0 +1,171 @@ +from __future__ import annotations + +from textual.app import ComposeResult +from textual.binding import Binding +from textual.screen import ModalScreen +from textual.widgets import Button, DataTable, Input, Label, Static + +from strix.sessions.listing import SessionRow, list_sessions + + +class 
def compose(self) -> ComposeResult:
    """Lay out the picker: title, search box, session table, empty notice, buttons.

    The search input is pre-filled with the query passed to the constructor.
    """
    from textual.containers import Container, Horizontal

    # Only the real containers are entered with ``with``. The previous leading
    # ``with self.app.focused if False else self: pass`` always evaluated to
    # ``with self:`` and served no layout purpose, so it has been removed.
    with Container(id="picker-container"):
        yield Label("Resume a session", id="picker-title")
        yield Input(
            placeholder="Search sessions…",
            value=self._initial_query,
            id="session-search",
        )
        yield DataTable(id="session-table", show_cursor=True, zebra_stripes=True)
        yield Static("", id="empty-state")
        with Horizontal(id="picker-buttons"):
            yield Button("Resume", id="resume-btn", variant="primary")
            yield Button("Cancel", id="cancel-btn", variant="default")
def _refresh_table(self, query: str) -> None:
    """Reload the resumable sessions matching *query* and repopulate the table.

    Only sessions with a conversation log are listed. When nothing matches,
    the table is hidden and the empty-state notice is shown instead.
    """
    from datetime import UTC, datetime

    sessions = [
        candidate
        for candidate in list_sessions(query=query or None)
        if candidate.has_conversation_log
    ]
    self._rows = sessions

    grid = self.query_one("#session-table", DataTable)
    notice = self.query_one("#empty-state", Static)
    grid.clear()

    if not sessions:
        notice.update("No resumable sessions yet. Run a scan first.")
        notice.display = True
        grid.display = False
        return

    notice.display = False
    grid.display = True

    reference = datetime.now(UTC)
    for session in sessions:
        info = session.meta

        targets = info.get("targets", [])
        shown = [
            t.get("original", str(t)) if isinstance(t, dict) else str(t)
            for t in targets[:2]
        ]
        target_label = ", ".join(shown)
        if len(targets) > 2:
            target_label += f" +{len(targets) - 2}"

        # Largest unit first; anything under a minute is "just now".
        age = int((reference - session.last_updated_dt).total_seconds())
        if age >= 86400:
            age_label = f"{age // 86400}d ago"
        elif age >= 3600:
            age_label = f"{age // 3600}h ago"
        elif age >= 60:
            age_label = f"{age // 60}m ago"
        else:
            age_label = "just now"

        grid.add_row(
            session.run_name,
            info.get("status", "?"),
            target_label or "unknown",
            str(info.get("iteration_count", "?")),
            str(info.get("vulnerability_count", "?")),
            age_label,
            key=session.run_name,
        )
def _open_session_picker(self) -> None:
    """Show the session picker and resume whichever session the user selects.

    Cancelling the picker, or failing to load the chosen session, exits the
    app. On success the selected bundle is applied to ``self.args`` and
    ``self.agent_config``, the scan config is rebuilt, the tracer's run name is
    updated, and the splash screen is hidden.
    """
    from strix.interface.session_picker_tui import SessionPickerScreen
    from strix.sessions import (
        ResumeError,
        SessionRow,
        apply_resume_to_args,
        load_resume_bundle,
        merge_into_agent_config,
    )

    def _on_session_picked(row: SessionRow | None) -> None:
        if row is None:
            self.exit()
            return

        try:
            bundle = load_resume_bundle(row.run_name)
        except ResumeError as exc:
            # NOTE(review): exit() follows immediately, so this notification may
            # never be visible to the user — confirm the intended behaviour.
            self.notify(f"Resume failed: {exc}", severity="error")
            self.exit()
            return

        apply_resume_to_args(self.args, bundle)
        merge_into_agent_config(self.agent_config, bundle)
        # NOTE(review): the rebuilt scan_config is stored on self only; it is
        # not passed to the tracer again here — verify that is intentional.
        self.scan_config = self._build_scan_config(self.args)
        self.tracer.set_run_name(bundle.run_name)
        self._hide_splash_screen()

    self.push_screen(SessionPickerScreen(), _on_session_picked)
def get_session(run_name: str, runs_root: Path | None = None) -> SessionRow | None:
    """Look up a single session by its run name.

    *run_name* must name a direct child of the runs root (``runs_root`` or
    ``./strix_runs``). Empty names, ``.``, ``..``, absolute paths and names with
    path separators return ``None`` rather than resolving to a directory
    outside (or equal to) the runs root. A trailing separator is tolerated.

    Returns ``None`` when the name is invalid or no such directory exists.
    """
    root = runs_root or (Path.cwd() / "strix_runs")

    # Exactly one path component, and not a parent reference.
    candidate = Path(run_name)
    if len(candidate.parts) != 1 or candidate.name in ("", ".."):
        return None

    run_dir = root / candidate.name
    if not run_dir.is_dir():
        return None
    return _load_row(run_dir)
last_updated_dt=last_updated_dt, + ) + + +def _synthesize_meta_from_legacy(run_dir: Path) -> dict[str, Any]: + """Best-effort metadata for runs created before session_meta.json existed.""" + vuln_count = 0 + vuln_dir = run_dir / "vulnerabilities" + if vuln_dir.is_dir(): + vuln_count = sum(1 for f in vuln_dir.iterdir() if f.suffix == ".md") + + mtime = _mtime(run_dir) + return { + "schema_version": 0, + "run_name": run_dir.name, + "created_at": mtime.isoformat(), + "last_updated": mtime.isoformat(), + "status": "unknown", + "targets": [], + "first_prompt_summary": "", + "vulnerability_count": vuln_count, + "has_conversation_log": (run_dir / "conversation.jsonl").exists(), + } + + +def _matches(row: SessionRow, q: str) -> bool: + haystack = " ".join( + [ + row.run_name, + row.meta.get("first_prompt_summary", ""), + row.meta.get("title") or "", + " ".join(row.meta.get("tags", [])), + " ".join( + t.get("original", "") if isinstance(t, dict) else str(t) + for t in row.meta.get("targets", []) + ), + ] + ).lower() + return q in haystack + + +def _parse_dt(value: str | None) -> datetime | None: + if not value: + return None + try: + dt = datetime.fromisoformat(value) + if dt.tzinfo is None: + dt = dt.replace(tzinfo=UTC) + return dt + except ValueError: + return None + + +def _mtime(path: Path) -> datetime: + try: + return datetime.fromtimestamp(path.stat().st_mtime, tz=UTC) + except OSError: + return datetime.now(UTC) diff --git a/strix/sessions/resume.py b/strix/sessions/resume.py new file mode 100644 index 000000000..ddc3d9a57 --- /dev/null +++ b/strix/sessions/resume.py @@ -0,0 +1,98 @@ +from __future__ import annotations + +import argparse +from dataclasses import dataclass +from pathlib import Path +from typing import TYPE_CHECKING, Any, Literal + +if TYPE_CHECKING: + from strix.agents.state import AgentState + + +class ResumeError(Exception): + pass + + +@dataclass +class ResumeBundle: + run_name: str + run_dir: Path + agent_state: "AgentState" + scan_config: 
dict[str, Any] + meta: dict[str, Any] + mode: Literal["continue", "reopen"] + + +def load_resume_bundle( + run_name: str, runs_root: Path | None = None +) -> ResumeBundle: + """Load a past session from conversation.jsonl and reconstruct AgentState.""" + from strix.agents.state import AgentState + from strix.sessions.listing import get_session + from strix.telemetry.conversation_log import ConversationLog, ReplayError + + row = get_session(run_name, runs_root) + if row is None: + raise ResumeError(f"Session '{run_name}' not found in strix_runs/") + if not row.has_conversation_log: + raise ResumeError( + f"Session '{run_name}' has no conversation log and cannot be resumed. " + "Only sessions created with Strix ≥ the resume feature can be resumed." + ) + + try: + result = ConversationLog.replay(row.run_dir) + except ReplayError as exc: + raise ResumeError(str(exc)) from exc + + mode: Literal["continue", "reopen"] = "reopen" if result.completed else "continue" + + state = AgentState( + messages=result.messages, + iteration=result.iteration, + context=result.context, + completed=False, + stop_requested=False, + ) + + if mode == "reopen": + state.add_message( + "user", + "The previous scan session has been reopened. 
Please summarize the key " + "findings so far and ask what to investigate or test next.", + ) + + return ResumeBundle( + run_name=row.run_name, + run_dir=row.run_dir, + agent_state=state, + scan_config=result.scan_config, + meta=row.meta, + mode=mode, + ) + + +def apply_resume_to_args(args: argparse.Namespace, bundle: ResumeBundle) -> None: + """Populate *args* from a ResumeBundle so main() can proceed normally.""" + args.run_name = bundle.run_name + args.resumed_state = bundle.agent_state + args.resume_mode = bundle.mode + args.resume_bundle = bundle + + # Only override targets/instruction if not explicitly set by the user + if not getattr(args, "target", None): + args.targets_info = bundle.scan_config.get("targets", []) + if not getattr(args, "instruction", None): + args.instruction = bundle.scan_config.get("user_instructions") or None + + scan_mode = bundle.scan_config.get("scan_mode") + if scan_mode and not getattr(args, "_scan_mode_explicit", False): + args.scan_mode = scan_mode + + +def merge_into_agent_config( + agent_config: dict[str, Any], bundle: ResumeBundle +) -> dict[str, Any]: + """Inject the restored AgentState into agent_config.""" + agent_config["state"] = bundle.agent_state + return agent_config diff --git a/strix/telemetry/conversation_log.py b/strix/telemetry/conversation_log.py new file mode 100644 index 000000000..76e901911 --- /dev/null +++ b/strix/telemetry/conversation_log.py @@ -0,0 +1,171 @@ +import json +import threading +from dataclasses import dataclass +from datetime import UTC, datetime +from pathlib import Path +from typing import Any + +SCHEMA_VERSION = 1 + + +@dataclass +class ReplayResult: + messages: list[dict[str, Any]] + scan_config: dict[str, Any] + iteration: int + context: dict[str, Any] + completed: bool + final_result: dict[str, Any] | None + schema_version: int + + +class ReplayError(Exception): + pass + + +class ConversationLog: + """Append-only JSONL log of every LLM message in a scan session. 
+ + Writes each entry synchronously under a lock so any crash leaves + all previously-written messages intact. + """ + + def __init__(self, run_dir: Path, run_name: str) -> None: + self._path = run_dir / "conversation.jsonl" + self._run_name = run_name + self._lock = threading.Lock() + + # ------------------------------------------------------------------ + # Write helpers + # ------------------------------------------------------------------ + + def _append(self, record: dict[str, Any]) -> None: + line = json.dumps(record, ensure_ascii=False, default=str) + "\n" + with self._lock: + with self._path.open("a", encoding="utf-8") as fh: + fh.write(line) + + def write_session_start(self, scan_config: dict[str, Any]) -> None: + self._append( + { + "type": "session_start", + "schema_version": SCHEMA_VERSION, + "run_name": self._run_name, + "scan_config": scan_config, + "timestamp": datetime.now(UTC).isoformat(), + } + ) + + def append_message( + self, + role: str, + content: Any, + *, + iteration: int, + thinking_blocks: list[dict[str, Any]] | None = None, + ) -> None: + record: dict[str, Any] = { + "type": "message", + "role": role, + "content": content, + "iteration": iteration, + "timestamp": datetime.now(UTC).isoformat(), + } + if thinking_blocks: + record["thinking_blocks"] = thinking_blocks + self._append(record) + + def append_iteration_end( + self, iteration: int, context: dict[str, Any], completed: bool + ) -> None: + self._append( + { + "type": "iteration_end", + "iteration": iteration, + "context": context, + "completed": completed, + "timestamp": datetime.now(UTC).isoformat(), + } + ) + + def write_session_end( + self, completed: bool, final_result: dict[str, Any] | None = None + ) -> None: + self._append( + { + "type": "session_end", + "completed": completed, + "final_result": final_result, + "timestamp": datetime.now(UTC).isoformat(), + } + ) + + # ------------------------------------------------------------------ + # Replay + # 
------------------------------------------------------------------ + + @classmethod + def replay(cls, run_dir: Path) -> ReplayResult: + """Reconstruct AgentState fields by replaying conversation.jsonl.""" + path = run_dir / "conversation.jsonl" + if not path.exists(): + raise ReplayError(f"No conversation log found at {path}") + + messages: list[dict[str, Any]] = [] + scan_config: dict[str, Any] = {} + iteration = 0 + context: dict[str, Any] = {} + completed = False + final_result: dict[str, Any] | None = None + schema_version = SCHEMA_VERSION + + try: + with path.open(encoding="utf-8") as fh: + for raw_line in fh: + raw_line = raw_line.strip() + if not raw_line: + continue + try: + entry = json.loads(raw_line) + except json.JSONDecodeError: + continue # skip corrupt line, keep partial state + + entry_type = entry.get("type") + + if entry_type == "session_start": + scan_config = entry.get("scan_config", {}) + schema_version = entry.get("schema_version", SCHEMA_VERSION) + + elif entry_type == "message": + msg: dict[str, Any] = { + "role": entry["role"], + "content": entry["content"], + } + if "thinking_blocks" in entry: + msg["thinking_blocks"] = entry["thinking_blocks"] + messages.append(msg) + + elif entry_type == "iteration_end": + iteration = entry.get("iteration", iteration) + context = entry.get("context", context) + completed = entry.get("completed", completed) + + elif entry_type == "session_end": + completed = entry.get("completed", completed) + final_result = entry.get("final_result") + + except OSError as e: + raise ReplayError(f"Failed to read conversation log: {e}") from e + + if not scan_config and not messages: + raise ReplayError("Conversation log is empty or unreadable") + + return ReplayResult( + messages=messages, + scan_config=scan_config, + iteration=iteration, + context=context, + completed=completed, + final_result=final_result, + schema_version=schema_version, + ) diff --git a/strix/telemetry/session_meta.py b/strix/telemetry/session_meta.py 
new file mode 100644 index 000000000..88068529e --- /dev/null +++ b/strix/telemetry/session_meta.py @@ -0,0 +1,43 @@ +import json +import os +from datetime import UTC, datetime +from pathlib import Path +from typing import Any + +SCHEMA_VERSION = 1 + + +def write_session_meta(run_dir: Path, meta: dict[str, Any]) -> None: + """Atomically merge *meta* into session_meta.json, preserving user fields.""" + path = run_dir / "session_meta.json" + existing = _read_raw(path) or {} + + merged = {**existing, **meta} + # Always preserve user-editable fields from existing file + merged["title"] = existing.get("title", merged.get("title")) + merged["tags"] = existing.get("tags", merged.get("tags", [])) + merged["last_updated"] = datetime.now(UTC).isoformat() + + tmp = path.with_suffix(".tmp") + tmp.write_text(json.dumps(merged, ensure_ascii=False, indent=2), encoding="utf-8") + os.replace(tmp, path) + + +def read_session_meta(run_dir: Path) -> dict[str, Any] | None: + return _read_raw(run_dir / "session_meta.json") + + +def update_status( + run_dir: Path, status: str, *, ended_at: str | None = None +) -> None: + update: dict[str, Any] = {"status": status} + if ended_at is not None: + update["ended_at"] = ended_at + write_session_meta(run_dir, update) + + +def _read_raw(path: Path) -> dict[str, Any] | None: + try: + return json.loads(path.read_text(encoding="utf-8")) + except (OSError, json.JSONDecodeError): + return None diff --git a/strix/telemetry/tracer.py b/strix/telemetry/tracer.py index 3f3ca6c69..8ace54d7e 100644 --- a/strix/telemetry/tracer.py +++ b/strix/telemetry/tracer.py @@ -79,6 +79,7 @@ def __init__(self, run_name: str | None = None): self._next_message_id = 1 self._saved_vuln_ids: set[str] = set() self._run_completed_emitted = False + self._conversation_log: "Any" = None # ConversationLog set by BaseAgent self._telemetry_enabled = is_otel_enabled() self._sanitizer = TelemetrySanitizer() @@ -610,6 +611,7 @@ def set_scan_config(self, config: dict[str, Any]) -> None: 
# Tracer methods (strix/telemetry/tracer.py) that keep session_meta.json in sync.


def _write_initial_session_meta(self, config: dict[str, Any]) -> None:
    """Write the initial session_meta.json record for this run.

    The record carries the run name, start time, scan mode, iteration limit,
    targets and a short summary: the first 160 characters of the user
    instructions, or the comma-joined targets when no instructions were given.

    Best-effort: any failure is logged at debug level and never propagates,
    so metadata problems cannot block a scan.
    """
    import logging

    try:
        from strix.telemetry.session_meta import SCHEMA_VERSION, write_session_meta

        instructions = config.get("user_instructions", "") or ""
        targets = config.get("targets") or []

        summary = (instructions[:160] + "…") if len(instructions) > 160 else instructions
        if not summary:
            # str() keeps the join from failing on a non-string "original".
            summary = ", ".join(
                str(t.get("original", t)) if isinstance(t, dict) else str(t)
                for t in targets
            )

        meta = {
            "schema_version": SCHEMA_VERSION,
            "run_name": self.run_name or self.run_id,
            "created_at": self.start_time,
            "status": "running",
            "scan_mode": config.get("scan_mode", "deep"),
            "max_iterations": config.get("max_iterations", 300),
            "targets": targets,
            "first_prompt_summary": summary,
            "has_conversation_log": True,
        }
        write_session_meta(self.get_run_dir(), meta)
    except Exception:  # noqa: BLE001 - metadata is best-effort; never block a scan
        logging.getLogger(__name__).debug(
            "Failed to write initial session metadata", exc_info=True
        )


def _finalize_session_meta(self, completed: bool) -> None:
    """Merge the end-of-run fields into session_meta.json.

    Writes the final status ("completed" or "errored"), the end timestamp, the
    highest iteration found in the agents' tool executions, and the
    vulnerability and agent counts.  Best-effort: failures are logged at debug
    level and never propagate.
    """
    import logging

    try:
        from strix.telemetry.session_meta import write_session_meta

        # Skip entries that are not mappings or carry a non-int iteration, so
        # one malformed record cannot stop the status from being finalized.
        # NOTE(review): if "tool_executions" holds execution ids rather than
        # records, this always yields 0 - confirm the intended data source.
        iterations = (
            execution.get("iteration", 0)
            for agent in self.agents.values()
            if isinstance(agent, dict)
            for execution in agent.get("tool_executions", [])
            if isinstance(execution, dict)
        )
        iteration_count = max(
            (value for value in iterations if isinstance(value, int)), default=0
        )

        write_session_meta(
            self.get_run_dir(),
            {
                "status": "completed" if completed else "errored",
                "ended_at": datetime.now(UTC).isoformat(),
                "iteration_count": iteration_count,
                "vulnerability_count": len(self.vulnerability_reports),
                "agent_count": len(self.agents),
            },
        )
    except Exception:  # noqa: BLE001 - metadata is best-effort
        logging.getLogger(__name__).debug(
            "Failed to finalize session metadata", exc_info=True
        )


def cleanup(self) -> None:
    """Close out the run: end the conversation log, finalize metadata, save data.

    The run counts as completed when the run metadata status is "completed" or
    a final scan result exists.  A failure while writing the conversation-log
    end marker is logged at debug level and does not stop the remaining steps.
    """
    import logging

    completed = self.run_metadata.get("status") == "completed" or bool(
        self.final_scan_result
    )
    if self._conversation_log is not None:
        try:
            self._conversation_log.write_session_end(
                completed=completed, final_result=None
            )
        except Exception:  # noqa: BLE001 - log teardown must not block cleanup
            logging.getLogger(__name__).debug(
                "Failed to write conversation log session end", exc_info=True
            )
    self._finalize_session_meta(completed)
    self.save_run_data(mark_complete=True)
"""Unit tests for the resume session feature (no heavy deps required)."""

import importlib.util
import json
import sys
import types
from pathlib import Path

import pytest

_REPO_ROOT = Path(__file__).parents[1]


def _load_by_path(mod_name: str, rel_path: str) -> types.ModuleType:
    """Execute one source file as a module without importing its package."""
    spec = importlib.util.spec_from_file_location(mod_name, _REPO_ROOT / rel_path)
    mod = importlib.util.module_from_spec(spec)  # type: ignore[arg-type]
    spec.loader.exec_module(mod)  # type: ignore[union-attr]
    return mod


# ---------------------------------------------------------------------------
# conversation_log
# ---------------------------------------------------------------------------


def _import_conv_log():
    """Import ConversationLog without triggering strix.telemetry.__init__."""
    return _load_by_path("conversation_log", "strix/telemetry/conversation_log.py")


def _import_session_meta():
    """Import session_meta without triggering strix.telemetry.__init__."""
    return _load_by_path("session_meta", "strix/telemetry/session_meta.py")


class TestConversationLog:
    @pytest.fixture(autouse=True)
    def _setup(self, tmp_path):
        # tmp_path is removed by pytest, unlike a bare tempfile.mkdtemp().
        self.tmp = tmp_path
        mod = _import_conv_log()
        self.ConversationLog = mod.ConversationLog
        self.ReplayError = mod.ReplayError
        self.SCHEMA_VERSION = mod.SCHEMA_VERSION

    def test_roundtrip(self):
        log = self.ConversationLog(self.tmp, "test-run")
        scan_config = {"targets": [{"original": "http://example.com"}], "scan_mode": "deep"}
        log.write_session_start(scan_config)
        log.append_message("user", "Hello, find vulns", iteration=1)
        log.append_message("assistant", [{"type": "text", "text": "Starting scan"}], iteration=1)
        log.append_iteration_end(1, {"found": True}, completed=False)
        log.append_message("user", "any XSS?", iteration=2)
        log.append_iteration_end(2, {"found": True}, completed=False)
        log.write_session_end(completed=True, final_result={"vulns": 1})

        result = self.ConversationLog.replay(self.tmp)

        assert len(result.messages) == 3
        assert result.messages[0] == {"role": "user", "content": "Hello, find vulns"}
        assert result.messages[1]["role"] == "assistant"
        assert result.messages[2] == {"role": "user", "content": "any XSS?"}
        assert result.scan_config == scan_config
        assert result.iteration == 2
        assert result.context == {"found": True}
        assert result.completed is True
        assert result.schema_version == self.SCHEMA_VERSION

    def test_thinking_blocks_preserved(self):
        log = self.ConversationLog(self.tmp, "test-run")
        log.write_session_start({})
        log.append_message(
            "assistant",
            "Result",
            iteration=1,
            thinking_blocks=[{"type": "thinking", "thinking": "deep thought"}],
        )

        result = self.ConversationLog.replay(self.tmp)
        msg = result.messages[0]
        assert msg["thinking_blocks"] == [{"type": "thinking", "thinking": "deep thought"}]

    def test_corrupt_lines_skipped(self):
        log_path = self.tmp / "conversation.jsonl"
        log_path.write_text(
            json.dumps({"type": "session_start", "scan_config": {"targets": []}, "schema_version": 1})
            + "\n"
            + "CORRUPTED LINE\n"
            + json.dumps({"type": "message", "role": "user", "content": "hi", "iteration": 1})
            + "\n",
            encoding="utf-8",
        )
        result = self.ConversationLog.replay(self.tmp)
        assert len(result.messages) == 1

    def test_replay_missing_file(self):
        empty_dir = self.tmp / "empty"
        empty_dir.mkdir()
        # NOTE(review): presumably ReplayError or an OSError - narrow this once
        # the exception contract of ConversationLog.replay is confirmed.
        with pytest.raises(Exception):  # noqa: B017, PT011
            self.ConversationLog.replay(empty_dir)

    def test_replay_empty_file(self):
        (self.tmp / "conversation.jsonl").write_text("", encoding="utf-8")
        # NOTE(review): same as above - exact exception type not confirmed.
        with pytest.raises(Exception):  # noqa: B017, PT011
            self.ConversationLog.replay(self.tmp)

    def test_crash_safe_partial_replay(self):
        """Simulates a crash after a few messages — partial data still recoverable."""
        log = self.ConversationLog(self.tmp, "run")
        log.write_session_start({"targets": []})
        log.append_message("user", "start", iteration=1)
        log.append_message("assistant", "ok", iteration=1)
        # No session_end written (simulated crash)

        result = self.ConversationLog.replay(self.tmp)
        assert len(result.messages) == 2
        assert result.completed is False


class TestSessionMeta:
    @pytest.fixture(autouse=True)
    def _setup(self, tmp_path):
        self.tmp = tmp_path
        mod = _import_session_meta()
        self.write = mod.write_session_meta
        self.read = mod.read_session_meta
        self.update_status = mod.update_status

    def test_write_and_read(self):
        meta = {"run_name": "abc", "status": "running", "targets": []}
        self.write(self.tmp, meta)
        result = self.read(self.tmp)
        assert result is not None
        assert result["run_name"] == "abc"
        assert result["status"] == "running"
        assert "last_updated" in result

    def test_merge_preserves_user_fields(self):
        self.write(self.tmp, {"title": "My custom title", "tags": ["pentest"]})
        self.write(self.tmp, {"status": "completed"})  # second write should preserve title
        result = self.read(self.tmp)
        assert result["title"] == "My custom title"
        assert result["tags"] == ["pentest"]
        assert result["status"] == "completed"

    def test_update_status(self):
        self.write(self.tmp, {"status": "running"})
        self.update_status(self.tmp, "completed", ended_at="2026-04-26T00:00:00Z")
        result = self.read(self.tmp)
        assert result["status"] == "completed"
        assert result["ended_at"] == "2026-04-26T00:00:00Z"

    def test_missing_file_returns_none(self):
        empty = self.tmp / "noexist"
        empty.mkdir()
        assert self.read(empty) is None

    def test_corrupt_file_returns_none(self):
        (self.tmp / "session_meta.json").write_text("NOT JSON", encoding="utf-8")
        assert self.read(self.tmp) is None

    def test_atomic_write(self):
        """Verify no tmp file left behind after write."""
        self.write(self.tmp, {"status": "running"})
        tmp_files = list(self.tmp.glob("*.tmp"))
        assert len(tmp_files) == 0


# ---------------------------------------------------------------------------
# listing (needs session_meta, no tracer)
# ---------------------------------------------------------------------------


def _make_run_dir(root: Path, name: str, **meta_kw) -> Path:
    """Helper to create a fake run dir with session_meta.json."""
    mod = _import_session_meta()
    run_dir = root / name
    run_dir.mkdir()
    mod.write_session_meta(run_dir, {"run_name": name, "status": "completed", **meta_kw})
    return run_dir


class TestListing:
    @pytest.fixture(autouse=True)
    def _setup(self, tmp_path, monkeypatch):
        self.tmp = tmp_path
        # Every sys.modules change goes through monkeypatch so it is undone
        # after each test and cannot shadow the real package for later tests.
        self._monkeypatch = monkeypatch

    def _listing(self):
        """Load strix.sessions.listing in isolation and return the module."""

        def _load(mod_name: str, rel_path: str) -> types.ModuleType:
            spec = importlib.util.spec_from_file_location(mod_name, _REPO_ROOT / rel_path)
            mod = importlib.util.module_from_spec(spec)  # type: ignore[arg-type]
            # Registered before execution, as the module may import siblings.
            self._monkeypatch.setitem(sys.modules, mod_name, mod)
            spec.loader.exec_module(mod)  # type: ignore[union-attr]
            return mod

        # Register stub packages before loading real modules; packages that
        # are already imported are left untouched.
        for pkg in ("strix", "strix.sessions", "strix.telemetry"):
            if pkg not in sys.modules:
                self._monkeypatch.setitem(sys.modules, pkg, types.ModuleType(pkg))

        _load("strix.telemetry.session_meta", "strix/telemetry/session_meta.py")
        return _load("strix.sessions.listing", "strix/sessions/listing.py")

    def test_empty_root(self):
        mod = self._listing()
        rows = mod.list_sessions(runs_root=self.tmp)
        assert rows == []

    def test_returns_session_rows(self):
        _make_run_dir(self.tmp, "run-a")
        _make_run_dir(self.tmp, "run-b")
        mod = self._listing()
        rows = mod.list_sessions(runs_root=self.tmp)
        names = {r.run_name for r in rows}
        assert "run-a" in names
        assert "run-b" in names

    def test_most_recent_requires_conv_log(self):
        _make_run_dir(self.tmp, "run-no-log")
        # no conversation.jsonl → has_conversation_log=False
        mod = self._listing()
        result = mod.most_recent(runs_root=self.tmp)
        assert result is None

    def test_most_recent_with_conv_log(self):
        run_dir = _make_run_dir(self.tmp, "run-with-log")
        (run_dir / "conversation.jsonl").write_text("{}", encoding="utf-8")
        mod = self._listing()
        result = mod.most_recent(runs_root=self.tmp)
        assert result is not None
        assert result.run_name == "run-with-log"

    def test_query_filter(self):
        _make_run_dir(self.tmp, "example-com-abc", first_prompt_summary="Focus on XSS")
        _make_run_dir(self.tmp, "github-repo-xyz", first_prompt_summary="Check auth flow")
        mod = self._listing()
        rows = mod.list_sessions(runs_root=self.tmp, query="xss")
        assert len(rows) == 1
        assert rows[0].run_name == "example-com-abc"

    def test_get_session_by_name(self):
        _make_run_dir(self.tmp, "my-scan")
        mod = self._listing()
        row = mod.get_session("my-scan", runs_root=self.tmp)
        assert row is not None
        assert row.run_name == "my-scan"

    def test_get_session_missing(self):
        mod = self._listing()
        assert mod.get_session("nonexistent", runs_root=self.tmp) is None