Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 56 additions & 1 deletion strix/llm/memory_compressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@

# NOTE(review): consumers of the next two constants are outside this view.
# Presumably MAX_TOTAL_TOKENS is the token budget above which history gets
# compressed, and MIN_RECENT_MESSAGES is how many of the newest messages are
# kept verbatim -- confirm against compress_history.
MAX_TOTAL_TOKENS = 100_000
MIN_RECENT_MESSAGES = 15
# Per-message character limit handed to _truncate when _build_fallback_summary
# renders a message preview.
FALLBACK_MESSAGE_PREVIEW_CHARS = 1_500
# Maximum number of messages previewed by _build_fallback_summary; longer
# chunks keep only a head slice and a tail slice adding up to this count.
FALLBACK_SUMMARY_MAX_MESSAGES = 12

SUMMARY_PROMPT_TEMPLATE = """You are an agent performing context
condensation for a security agent. Your job is to compress scan data while preserving
Expand Down Expand Up @@ -83,6 +85,58 @@ def _extract_message_text(msg: dict[str, Any]) -> str:
return str(content)


def _truncate(text: str, limit: int) -> str:
text = text.strip()
if len(text) <= limit:
return text
return f"{text[:limit].rstrip()}\n[...truncated...]"


def _build_fallback_summary(messages: list[dict[str, Any]]) -> dict[str, Any]:
    """Build an extractive stand-in summary without calling any model.

    Intended for the case where the LLM-based summarization fails: instead of
    issuing further model calls, the chunk is condensed locally into a single
    ``user`` message wrapped in a ``<context_summary ... fallback='true'>`` tag.

    Each retained message is rendered as a preview (its extracted text cut to
    ``FALLBACK_MESSAGE_PREVIEW_CHARS``) labelled with its role and its 1-based
    position in ``messages``. When the chunk holds more than
    ``FALLBACK_SUMMARY_MAX_MESSAGES`` messages, only a slice from the start
    and a slice from the end are retained, and a line stating how many middle
    messages were dropped is added to the header.

    Args:
        messages: The chunk of conversation messages to condense.

    Returns:
        A message dict with ``role`` set to ``"user"`` and the wrapped summary
        text as ``content``. An empty ``messages`` yields a fixed
        "No messages to summarize" summary.
    """
    if not messages:
        return {
            "role": "user",
            "content": "<context_summary message_count='0'>No messages to summarize</context_summary>",
        }

    total = len(messages)
    if total > FALLBACK_SUMMARY_MAX_MESSAGES:
        # Split the preview budget between the start and the end of the chunk;
        # the tail gets the extra slot when the budget is odd.
        head_count = FALLBACK_SUMMARY_MAX_MESSAGES // 2
        tail_count = FALLBACK_SUMMARY_MAX_MESSAGES - head_count
        kept = list(enumerate(messages[:head_count], start=1))
        kept.extend(enumerate(messages[-tail_count:], start=total - tail_count + 1))
    else:
        kept = list(enumerate(messages, start=1))

    report = [
        "LLM summarization failed; using local extractive fallback.",
        "Preserved message previews follow in original order.",
    ]
    omitted = total - len(kept)
    if omitted > 0:
        report.append(f"{omitted} middle message(s) omitted by fallback compression.")

    for position, entry in kept:
        speaker = entry.get("role", "unknown")
        preview = _truncate(_extract_message_text(entry), FALLBACK_MESSAGE_PREVIEW_CHARS)
        report.append(f"\n[{position}/{total}] {speaker}:\n{preview}")

    body = "\n".join(report)
    return {
        "role": "user",
        "content": f"<context_summary message_count='{total}' fallback='true'>{body}</context_summary>",
    }


def _summarize_messages(
messages: list[dict[str, Any]],
model: str,
Expand Down Expand Up @@ -111,6 +165,7 @@ def _summarize_messages(
"model": model,
"messages": [{"role": "user", "content": prompt}],
"timeout": timeout,
"num_retries": 0,
}
if api_key:
completion_args["api_key"] = api_key
Expand All @@ -128,7 +183,7 @@ def _summarize_messages(
}
except Exception:
logger.exception("Failed to summarize messages")
return messages[0]
return _build_fallback_summary(messages)


def _handle_images(messages: list[dict[str, Any]], max_images: int) -> None:
Expand Down
86 changes: 86 additions & 0 deletions tests/llm/test_memory_compressor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
from typing import Any

import pytest

from strix.llm import memory_compressor
from strix.llm.memory_compressor import MemoryCompressor


def _message(index: int) -> dict[str, str]:
return {"role": "user", "content": f"message {index} " + ("x" * 200)}


def test_summarizer_disables_litellm_retries(monkeypatch: pytest.MonkeyPatch) -> None:
    """The summarizer must pass ``num_retries=0`` and the caller's timeout to litellm."""
    seen_kwargs: dict[str, Any] = {}

    # Minimal stand-ins exposing ``choices[0].message.content`` on the response.
    class _StubMessage:
        content = "summary"

    class _StubChoice:
        message = _StubMessage()

    class _StubResponse:
        choices = [_StubChoice()]

    def _record_call(**call_kwargs: Any) -> _StubResponse:
        # Remember what the summarizer asked for so it can be asserted on below.
        seen_kwargs.update(call_kwargs)
        return _StubResponse()

    def _static_llm_config() -> tuple[str, str, None]:
        return ("openai/gpt-5.4", "test-key", None)

    monkeypatch.setattr(memory_compressor.litellm, "completion", _record_call)
    monkeypatch.setattr(memory_compressor, "resolve_llm_config", _static_llm_config)

    history = [_message(1)]
    summary = memory_compressor._summarize_messages(history, "openai/gpt-5.4", 5)

    assert seen_kwargs["num_retries"] == 0
    assert seen_kwargs["timeout"] == 5
    assert "summary" in summary["content"]


def test_summarizer_uses_local_fallback_on_timeout(monkeypatch: pytest.MonkeyPatch) -> None:
    """A timed-out completion call must yield the local extractive fallback summary."""

    def _always_time_out(**_: Any) -> None:
        raise TimeoutError("request timed out")

    def _static_llm_config() -> tuple[str, str, None]:
        return ("openai/gpt-5.4", "test-key", None)

    monkeypatch.setattr(memory_compressor.litellm, "completion", _always_time_out)
    monkeypatch.setattr(memory_compressor, "resolve_llm_config", _static_llm_config)

    # 20 messages exceeds the fallback preview cap, so middle ones are dropped.
    history = [_message(i) for i in range(20)]
    summary = memory_compressor._summarize_messages(history, "m", 1)

    expected_fragments = (
        "fallback='true'",
        "LLM summarization failed",
        "message 0",
        "message 19",
        "middle message(s) omitted",
    )
    for fragment in expected_fragments:
        assert fragment in summary["content"]


def test_compress_history_falls_back_without_returning_raw_old_messages(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Older messages must be replaced by the summary while the newest two stay verbatim."""

    def _flat_token_count(*_):
        # Every message "costs" 100 tokens, far above the patched budget of 10.
        return 100

    def _stub_summary(messages, *_):
        return {
            "role": "user",
            "content": (
                f"<context_summary fallback='true'>compressed {len(messages)} messages"
                "</context_summary>"
            ),
        }

    monkeypatch.setattr(memory_compressor, "MAX_TOTAL_TOKENS", 10)
    monkeypatch.setattr(memory_compressor, "MIN_RECENT_MESSAGES", 2)
    monkeypatch.setattr(memory_compressor, "get_message_tokens", _flat_token_count)
    monkeypatch.setattr(memory_compressor, "_summarize_messages", _stub_summary)
    monkeypatch.setenv("STRIX_LLM", "openai/gpt-5.4")

    messages = [_message(i) for i in range(8)]
    compressor = MemoryCompressor(timeout=1)
    compressed = compressor.compress_history(messages)

    assert any("compressed 6 messages" in entry["content"] for entry in compressed)
    assert compressed[-2:] == messages[-2:]