def _add_cache_control(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Add cache_control breakpoints to stable message segments.

    Marks the system prompt and the agent identity message as cacheable,
    since these are identical across every iteration within an agent's
    lifetime.

    Args:
        messages: Chat messages in the order they will be sent.

    Returns:
        ``messages`` itself when it is empty or the configured model does not
        support prompt caching. Otherwise a new list in which the annotated
        messages are replaced by copies; the caller's message dicts and
        content blocks are left unmodified.
    """
    if not messages or not supports_prompt_caching(self.config.canonical_model):
        return messages

    # Shallow copy: entries are replaced below, never edited in place.
    result = list(messages)

    # Cache breakpoint 1: system prompt (unchanged across all iterations).
    # Only plain-string content is wrapped; list content is passed through.
    if result[0].get("role") == "system":
        content = result[0]["content"]
        result[0] = {
            **result[0],
            "content": [
                {"type": "text", "text": content, "cache_control": {"type": "ephemeral"}}
            ]
            if isinstance(content, str)
            else content,
        }

    # Cache breakpoint 2: agent identity message (stable per-agent).
    # NOTE(review): `"" in <str>` is true for every string, so this currently
    # matches ANY second message. The identity marker substring appears to be
    # missing from the literal - confirm the intended tag and restore it.
    if len(result) > 1 and "" in str(result[1].get("content", "")):
        second = result[1]
        # .get(): the membership test above also passes when "content" is absent.
        content = second.get("content")
        if isinstance(content, str):
            result[1] = {
                **second,
                "content": [
                    {"type": "text", "text": content, "cache_control": {"type": "ephemeral"}}
                ],
            }
        elif isinstance(content, list) and content and isinstance(content[-1], dict):
            # Content is already a list: annotate a copy of the last block so
            # the dict shared with the caller's message is not mutated.
            annotated_last = {**content[-1], "cache_control": {"type": "ephemeral"}}
            result[1] = {**second, "content": [*content[:-1], annotated_last]}

    return result
def _truncate_tool_output(text: str, max_chars: int) -> str:
    """Truncate a large tool output while preserving its beginning and end.

    Keeps the first 60% and last 40% of ``max_chars`` characters and joins
    them with ``TOOL_TRUNCATION_NOTICE``, so that both the command/header and
    the tail of the output (often containing summaries or error messages)
    are preserved.

    The notice is added on top of ``max_chars``, so a truncated result is
    itself longer than the limit. Two guards keep that from causing harm:

    * text that already carries the notice right after the head (i.e. it was
      truncated earlier with the same limit) is returned unchanged, which
      makes repeated calls idempotent and keeps the originally reported
      length intact;
    * when head + notice + tail would not be shorter than ``text``, the
      original text is returned instead of a longer "truncated" one.

    Args:
        text: Raw tool output.
        max_chars: Number of original characters to keep; ``<= 0`` disables
            truncation.

    Returns:
        ``text`` unchanged, or the head, the notice and the tail concatenated.
    """
    if max_chars <= 0 or len(text) <= max_chars:
        return text

    head_len = int(max_chars * 0.6)
    # head_len < max_chars whenever max_chars >= 1, so tail_len is at least 1
    # and text[-tail_len:] never degenerates into the whole string.
    tail_len = max_chars - head_len

    # Literal part of the notice template that precedes its first placeholder.
    notice_prefix = TOOL_TRUNCATION_NOTICE.partition("{")[0]
    if notice_prefix and text.startswith(notice_prefix, head_len):
        return text  # already truncated with this limit

    notice = TOOL_TRUNCATION_NOTICE.format(
        original_len=len(text), max_len=max_chars, head_len=head_len, tail_len=tail_len
    )
    if max_chars + len(notice) >= len(text):
        return text  # truncating would not make the output any shorter

    return text[:head_len] + notice + text[-tail_len:]
def truncate_tool_outputs(self, messages: list[dict[str, Any]]) -> None:
    """Truncate large tool output messages in-place.

    This prevents oversized tool results (nmap scans, file contents, etc.)
    from accumulating in the conversation history and being resent on every
    subsequent LLM call.

    Only tool output is touched, to avoid corrupting system prompts or
    user/assistant messages:

    * string content of ``role == "tool"`` messages;
    * text parts of ``role == "tool"`` messages whose content is a list;
    * ``tool_result`` blocks inside any message with list content, whether
      the block's own content is a string or a list of text parts.

    Values that are not strings (e.g. a ``None`` text field) are skipped
    rather than raising.

    Args:
        messages: Conversation messages; modified in place.

    Returns:
        None. Does nothing when ``self.max_tool_output_chars`` is ``<= 0``.
    """
    limit = self.max_tool_output_chars
    if limit <= 0:
        return

    def clip(holder: dict[str, Any], key: str) -> None:
        # Replace holder[key] only when it is a string longer than the limit.
        value = holder.get(key)
        if isinstance(value, str) and len(value) > limit:
            holder[key] = _truncate_tool_output(value, limit)

    for msg in messages:
        content = msg.get("content", "")
        is_tool_msg = msg.get("role", "") == "tool"

        # Direct tool-role messages (string content).
        if isinstance(content, str):
            if is_tool_msg:
                clip(msg, "content")
            continue
        if not isinstance(content, list):
            continue

        for item in content:
            if not isinstance(item, dict):
                continue
            kind = item.get("type")
            if kind == "tool_result":
                # Anthropic-style: tool_result blocks embedded in a message.
                inner = item.get("content")
                if isinstance(inner, list):
                    for sub in inner:
                        if isinstance(sub, dict) and sub.get("type") == "text":
                            clip(sub, "text")
                else:
                    clip(item, "content")
            elif kind == "text" and is_tool_msg:
                # Tool-role message whose content is a list of text parts.
                clip(item, "text")
Summarize older messages when total tokens exceed limit The compression preserves: - All system messages unchanged @@ -185,6 +268,7 @@ def compress_history( if not messages: return messages + self.truncate_tool_outputs(messages) _handle_images(messages, self.max_images) system_msgs = [] @@ -195,8 +279,8 @@ def compress_history( else: regular_msgs.append(msg) - recent_msgs = regular_msgs[-MIN_RECENT_MESSAGES:] - old_msgs = regular_msgs[:-MIN_RECENT_MESSAGES] + recent_msgs = regular_msgs[-self.min_recent_messages:] + old_msgs = regular_msgs[:-self.min_recent_messages] # Type assertion since we ensure model_name is not None in __init__ model_name: str = self.model_name # type: ignore[assignment] @@ -205,7 +289,7 @@ def compress_history( _get_message_tokens(msg, model_name) for msg in system_msgs + regular_msgs ) - if total_tokens <= MAX_TOTAL_TOKENS * 0.9: + if total_tokens <= self.max_total_tokens * 0.9: return messages compressed = []