diff --git a/strix/llm/memory_compressor.py b/strix/llm/memory_compressor.py index aea086c1c..a3cb89451 100644 --- a/strix/llm/memory_compressor.py +++ b/strix/llm/memory_compressor.py @@ -11,6 +11,8 @@ MAX_TOTAL_TOKENS = 100_000 MIN_RECENT_MESSAGES = 15 +FALLBACK_MESSAGE_PREVIEW_CHARS = 1_500 +FALLBACK_SUMMARY_MAX_MESSAGES = 12 SUMMARY_PROMPT_TEMPLATE = """You are an agent performing context condensation for a security agent. Your job is to compress scan data while preserving @@ -83,6 +85,58 @@ def _extract_message_text(msg: dict[str, Any]) -> str: return str(content) +def _truncate(text: str, limit: int) -> str: + text = text.strip() + if len(text) <= limit: + return text + return f"{text[:limit].rstrip()}\n[...truncated...]" + + +def _build_fallback_summary(messages: list[dict[str, Any]]) -> dict[str, Any]: + """Create a local, extractive summary when the compressor LLM is unavailable. + + The compressor runs while the main agent is already close to the context limit. + If the summarization LLM times out, retrying more model calls can trap long + scans in a failure loop. This fallback keeps the scan moving by retaining + role/order, the beginning of the chunk, and the most recent messages where + operational state is usually concentrated. + """ + if not messages: + return { + "role": "user", + "content": "No messages to summarize", + } + + if len(messages) <= FALLBACK_SUMMARY_MAX_MESSAGES: + selected = list(enumerate(messages, start=1)) + else: + head_count = FALLBACK_SUMMARY_MAX_MESSAGES // 2 + tail_count = FALLBACK_SUMMARY_MAX_MESSAGES - head_count + selected = [ + *list(enumerate(messages[:head_count], start=1)), + *list(enumerate(messages[-tail_count:], start=len(messages) - tail_count + 1)), + ] + + lines = [ + "LLM summarization failed; using local extractive fallback.", + "Preserved message previews follow in original order.", + ] + for idx, msg in selected: + role = msg.get("role", "unknown") + text = _truncate(_extract_message_text(msg), FALLBACK_MESSAGE_PREVIEW_CHARS) + lines.append(f"\n[{idx}/{len(messages)}] {role}:\n{text}") + + skipped = len(messages) - len(selected) + if skipped > 0: + lines.insert(2, f"{skipped} middle message(s) omitted by fallback compression.") + + summary_msg = "{text}" + return { + "role": "user", + "content": summary_msg.format(count=len(messages), text="\n".join(lines)), + } + + def _summarize_messages( messages: list[dict[str, Any]], model: str, @@ -111,6 +165,7 @@ def _summarize_messages( "model": model, "messages": [{"role": "user", "content": prompt}], "timeout": timeout, + "num_retries": 0, } if api_key: completion_args["api_key"] = api_key @@ -128,7 +183,7 @@ def _summarize_messages( } except Exception: logger.exception("Failed to summarize messages") - return messages[0] + return _build_fallback_summary(messages) def _handle_images(messages: list[dict[str, Any]], max_images: int) -> None: diff --git a/tests/llm/test_memory_compressor.py b/tests/llm/test_memory_compressor.py new file mode 100644 index 000000000..0ab9d25fa --- /dev/null +++ b/tests/llm/test_memory_compressor.py @@ -0,0 +1,86 @@ +from typing import Any + +import pytest + +from strix.llm import memory_compressor +from strix.llm.memory_compressor import MemoryCompressor + + +def _message(index: int) -> dict[str, str]: + return {"role": "user", "content": f"message {index} " + ("x" * 200)} + + +def test_summarizer_disables_litellm_retries(monkeypatch: pytest.MonkeyPatch) -> None: + captured: dict[str, Any] = {} + + class Message: + content = "summary" + + class Choice: + message = Message() + + class Response: + choices = [Choice()] + + def fake_completion(**kwargs: Any) -> Response: + captured.update(kwargs) + return Response() + + monkeypatch.setattr(memory_compressor.litellm, "completion", fake_completion) + monkeypatch.setattr( + memory_compressor, + "resolve_llm_config", + lambda: ("openai/gpt-5.4", "test-key", None), + ) + + summary = memory_compressor._summarize_messages([_message(1)], "openai/gpt-5.4", 5) + + assert captured["num_retries"] == 0 + assert captured["timeout"] == 5 + assert "summary" in summary["content"] + + +def test_summarizer_uses_local_fallback_on_timeout(monkeypatch: pytest.MonkeyPatch) -> None: + def fake_completion(**_: Any) -> None: + raise TimeoutError("request timed out") + + monkeypatch.setattr(memory_compressor.litellm, "completion", fake_completion) + monkeypatch.setattr( + memory_compressor, + "resolve_llm_config", + lambda: ("openai/gpt-5.4", "test-key", None), + ) + + summary = memory_compressor._summarize_messages([_message(i) for i in range(20)], "m", 1) + + assert "fallback='true'" in summary["content"] + assert "LLM summarization failed" in summary["content"] + assert "message 0" in summary["content"] + assert "message 19" in summary["content"] + assert "middle message(s) omitted" in summary["content"] + + +def test_compress_history_falls_back_without_returning_raw_old_messages( + monkeypatch: pytest.MonkeyPatch, +) -> None: + monkeypatch.setattr(memory_compressor, "MAX_TOTAL_TOKENS", 10) + monkeypatch.setattr(memory_compressor, "MIN_RECENT_MESSAGES", 2) + monkeypatch.setattr(memory_compressor, "get_message_tokens", lambda *_: 100) + monkeypatch.setattr( + memory_compressor, + "_summarize_messages", + lambda messages, *_: { + "role": "user", + "content": ( + f"compressed {len(messages)} messages" + "" + ), + }, + ) + monkeypatch.setenv("STRIX_LLM", "openai/gpt-5.4") + + messages = [_message(i) for i in range(8)] + compressed = MemoryCompressor(timeout=1).compress_history(messages) + + assert any("compressed 6 messages" in msg["content"] for msg in compressed) + assert compressed[-2:] == messages[-2:]