Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions strix/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@ class Config:
strix_reasoning_effort = "high"
strix_llm_max_retries = "5"
strix_memory_compressor_timeout = "30"
strix_max_context_tokens = None # Default: 100000
strix_min_recent_messages = None # Default: 15
strix_max_tool_output_chars = None # Default: 0 (no truncation)
llm_timeout = "300"
_LLM_CANONICAL_NAMES = (
"strix_llm",
Expand Down
24 changes: 24 additions & 0 deletions strix/llm/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,11 +360,18 @@ def _strip_images(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
return result

def _add_cache_control(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
"""Add cache_control breakpoints to stable message segments.

Caches the system prompt and the agent identity message since these
are identical across every iteration within an agent's lifetime.
Cache hits cost ~90% less than re-processing on Anthropic models.
"""
if not messages or not supports_prompt_caching(self.config.canonical_model):
return messages

result = list(messages)

# Cache breakpoint 1: system prompt (unchanged across all iterations)
if result[0].get("role") == "system":
content = result[0]["content"]
result[0] = {
Expand All @@ -375,4 +382,21 @@ def _add_cache_control(self, messages: list[dict[str, Any]]) -> list[dict[str, A
if isinstance(content, str)
else content,
}

# Cache breakpoint 2: agent identity message (stable per-agent)
if len(result) > 1 and "<agent_identity>" in str(result[1].get("content", "")):
content = result[1]["content"]
if isinstance(content, str):
result[1] = {
**result[1],
"content": [
{"type": "text", "text": content, "cache_control": {"type": "ephemeral"}}
],
}
elif isinstance(content, list) and content:
# Content is already a list — add cache_control to the last item
last = content[-1]
if isinstance(last, dict):
last["cache_control"] = {"type": "ephemeral"}

return result
102 changes: 93 additions & 9 deletions strix/llm/memory_compressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,15 @@
logger = logging.getLogger(__name__)


MAX_TOTAL_TOKENS = 100_000
MIN_RECENT_MESSAGES = 15
# Fallback limits for the memory compressor. Each one is used when the matching
# strix_* config value is unset or empty.

# Token limit for the history; compress_history returns the messages unchanged
# while the total stays at or below 90% of the configured limit.
DEFAULT_MAX_TOTAL_TOKENS = 100_000
# Number of trailing messages compress_history keeps as "recent".
DEFAULT_MIN_RECENT_MESSAGES = 15
# Per-output character limit for tool results.
DEFAULT_MAX_TOOL_OUTPUT_CHARS = 0 # 0 = no truncation (backwards compatible)

# Template for the marker that _truncate_tool_output places between the kept
# head and tail of a shortened output. Placeholders: head_len, tail_len,
# original_len and max_len (all character counts).
TOOL_TRUNCATION_NOTICE = (
"\n\n[Output truncated: showing first {head_len} and last {tail_len} characters "
"of {original_len}-character output (limit: {max_len}). "
"The middle portion has been permanently removed.]"
)

SUMMARY_PROMPT_TEMPLATE = """You are an agent performing context
condensation for a security agent. Your job is to compress scan data while preserving
Expand Down Expand Up @@ -131,6 +138,24 @@ def _summarize_messages(
return messages[0]


def _truncate_tool_output(text: str, max_chars: int) -> str:
"""Truncate large tool outputs while preserving the beginning and end.

Keeps the first 60% and last 40% of the allowed length so that both
the command/header and the tail of the output (often containing summaries
or error messages) are preserved.
"""
if max_chars <= 0 or len(text) <= max_chars:
return text

head_len = int(max_chars * 0.6)
tail_len = max_chars - head_len
notice = TOOL_TRUNCATION_NOTICE.format(
original_len=len(text), max_len=max_chars, head_len=head_len, tail_len=tail_len
)
return text[:head_len] + notice + text[-tail_len:]


def _handle_images(messages: list[dict[str, Any]], max_images: int) -> None:
image_count = 0
for msg in reversed(messages):
Expand Down Expand Up @@ -160,20 +185,78 @@ def __init__(
self.model_name = model_name or Config.get("strix_llm")
self.timeout = timeout or int(Config.get("strix_memory_compressor_timeout") or "120")

self.max_total_tokens = int(
Config.get("strix_max_context_tokens") or str(DEFAULT_MAX_TOTAL_TOKENS)
)
self.min_recent_messages = int(
Config.get("strix_min_recent_messages") or str(DEFAULT_MIN_RECENT_MESSAGES)
)
self.max_tool_output_chars = int(
Config.get("strix_max_tool_output_chars") or str(DEFAULT_MAX_TOOL_OUTPUT_CHARS)
)

if not self.model_name:
raise ValueError("STRIX_LLM environment variable must be set and not empty")

def truncate_tool_outputs(self, messages: list[dict[str, Any]]) -> None:
    """Truncate large tool output messages in-place.

    This prevents oversized tool results (nmap scans, file contents, etc.)
    from accumulating in the conversation history and being resent on every
    subsequent LLM call.

    Only tool output is touched, so system prompts and ordinary
    user/assistant text are never altered:

    * ``role == "tool"`` messages, whether their content is a plain string
      or a list of ``{"type": "text", ...}`` parts;
    * ``tool_result`` blocks embedded in any message's content list, whose
      own content is a string or a list of text parts.

    Values that are not strings (``None``, nested structures) are skipped
    rather than raising.

    Args:
        messages: Conversation history; matching entries are modified
            in place.

    Returns:
        None. Does nothing when ``self.max_tool_output_chars`` is <= 0.
    """
    limit = self.max_tool_output_chars
    if limit <= 0:
        return

    def _oversized(value: Any) -> bool:
        # Guard the len() call: block payloads are not guaranteed to be str.
        return isinstance(value, str) and len(value) > limit

    def _shrink_text_blocks(blocks: list[Any]) -> None:
        for block in blocks:
            if (
                isinstance(block, dict)
                and block.get("type") == "text"
                and _oversized(block.get("text"))
            ):
                block["text"] = _truncate_tool_output(block["text"], limit)

    for msg in messages:
        content = msg.get("content", "")
        is_tool_msg = msg.get("role", "") == "tool"

        # Direct tool-role message with string content.
        if is_tool_msg and _oversized(content):
            msg["content"] = _truncate_tool_output(content, limit)
            continue

        if not isinstance(content, list):
            continue

        # Tool-role message whose content is a list of text parts.
        if is_tool_msg:
            _shrink_text_blocks(content)

        # Anthropic-style: tool_result blocks embedded in a content list.
        for item in content:
            if not isinstance(item, dict) or item.get("type") != "tool_result":
                continue
            payload = item.get("content")
            if _oversized(payload):
                item["content"] = _truncate_tool_output(payload, limit)
            elif isinstance(payload, list):
                _shrink_text_blocks(payload)

def compress_history(
self,
messages: list[dict[str, Any]],
) -> list[dict[str, Any]]:
"""Compress conversation history to stay within token limits.

Strategy:
1. Handle image limits first
2. Keep all system messages
3. Keep minimum recent messages
4. Summarize older messages when total tokens exceed limit
1. Truncate oversized tool outputs first
2. Handle image limits
3. Keep all system messages
4. Keep minimum recent messages
5. Summarize older messages when total tokens exceed limit

The compression preserves:
- All system messages unchanged
Expand All @@ -185,6 +268,7 @@ def compress_history(
if not messages:
return messages

self.truncate_tool_outputs(messages)
_handle_images(messages, self.max_images)

system_msgs = []
Expand All @@ -195,8 +279,8 @@ def compress_history(
else:
regular_msgs.append(msg)

recent_msgs = regular_msgs[-MIN_RECENT_MESSAGES:]
old_msgs = regular_msgs[:-MIN_RECENT_MESSAGES]
recent_msgs = regular_msgs[-self.min_recent_messages:]
old_msgs = regular_msgs[:-self.min_recent_messages]

# Type assertion since we ensure model_name is not None in __init__
model_name: str = self.model_name # type: ignore[assignment]
Expand All @@ -205,7 +289,7 @@ def compress_history(
_get_message_tokens(msg, model_name) for msg in system_msgs + regular_msgs
)

if total_tokens <= MAX_TOTAL_TOKENS * 0.9:
if total_tokens <= self.max_total_tokens * 0.9:
return messages

compressed = []
Expand Down