diff --git a/strix/llm/llm.py b/strix/llm/llm.py index 4f624956a..5cabf227e 100644 --- a/strix/llm/llm.py +++ b/strix/llm/llm.py @@ -1,4 +1,5 @@ import asyncio +import re from collections.abc import AsyncIterator from dataclasses import dataclass from typing import Any @@ -26,6 +27,12 @@ litellm.modify_params = True +_TOOL_RESULT_PATTERN = re.compile( + r"(\s*[^<]*\s*)(.*?)(\s*)", + re.DOTALL, +) + + class LLMRequestFailedError(Exception): def __init__(self, message: str, details: str | None = None): super().__init__(message) @@ -159,12 +166,33 @@ async def generate( messages = self._prepare_messages(conversation_history) max_retries = int(Config.get("strix_llm_max_retries") or "5") + bad_request_retried = False + for attempt in range(max_retries + 1): try: async for response in self._stream(messages): yield response return # noqa: TRY300 except Exception as e: # noqa: BLE001 + if self._is_bad_request(e): + if not bad_request_retried: + bad_request_retried = True + if attempt >= max_retries: + self._raise_error(e) + await asyncio.sleep(2) + continue + truncate_enabled = Config.get("strix_truncate_on_oversize") or "" + if ( + truncate_enabled.lower() in ("1", "true", "yes") + and self._truncate_large_tool_results(messages) + ): + if attempt >= max_retries: + self._raise_error(e) + # Pace the provider — matches the 2s sleep on the bare-retry + # path so a throttled provider isn't hit back-to-back after the + # original 400. + await asyncio.sleep(2) + continue if attempt >= max_retries or not self._should_retry(e): self._raise_error(e) wait = min(90, 2 * (2**attempt)) @@ -314,6 +342,60 @@ def _extract_cost(self, response: Any) -> float: except Exception: # noqa: BLE001 return 0.0 + @staticmethod + def _truncate_large_tool_results( + messages: list[dict[str, Any]], + threshold_chars: int = 2000, + truncate_to_chars: int = 1000, + ) -> bool: + """Truncate large tool_result XML blocks to recover from BadRequestError. + + Scans all messages for tool_result blocks whose body exceeds threshold_chars + and shrinks them to truncate_to_chars. Called repeatedly on each 400 until it + returns False (nothing left to truncate). + + threshold_chars and truncate_to_chars are independent: the threshold decides + which blocks qualify for truncation, and truncate_to_chars is the size of the + retained prefix. They are not the same value to allow aggressive shrinking of + blocks that are well over the threshold without re-processing blocks that are + already acceptable. + """ + truncated_any = False + + def _truncate_match(m: re.Match) -> str: + nonlocal truncated_any + prefix, body, suffix = m.group(1), m.group(2), m.group(3) + if len(body) <= threshold_chars: + return m.group(0) + truncated_any = True + kept = body[:truncate_to_chars] + return ( + f"{prefix}{kept}\n\n... [content truncated from {len(body)} to {len(kept)} chars " + f"due to request size limit — file requires manual review] ...{suffix}" + ) + + for msg in reversed(messages): + content = msg.get("content") + + if isinstance(content, list): + for block in content: + if ( + block.get("type") == "text" + and isinstance(block.get("text"), str) + and "" in block["text"] + ): + block["text"] = _TOOL_RESULT_PATTERN.sub(_truncate_match, block["text"]) + elif isinstance(content, str) and "" in content: + msg["content"] = _TOOL_RESULT_PATTERN.sub(_truncate_match, content) + + return truncated_any + + def _is_bad_request(self, e: Exception) -> bool: + code = getattr(e, "status_code", None) or getattr( + getattr(e, "response", None), "status_code", None + ) + return code == 400 + def _should_retry(self, e: Exception) -> bool: code = getattr(e, "status_code", None) or getattr( getattr(e, "response", None), "status_code", None