diff --git a/strix/llm/llm.py b/strix/llm/llm.py
index 4f624956a..5cabf227e 100644
--- a/strix/llm/llm.py
+++ b/strix/llm/llm.py
@@ -1,4 +1,5 @@
import asyncio
+import re
from collections.abc import AsyncIterator
from dataclasses import dataclass
from typing import Any
@@ -26,6 +27,12 @@
litellm.modify_params = True
+_TOOL_RESULT_PATTERN = re.compile(
+ r"(\s*[^<]*\s*)(.*?)(\s*)",
+ re.DOTALL,
+)
+
+
class LLMRequestFailedError(Exception):
def __init__(self, message: str, details: str | None = None):
super().__init__(message)
@@ -159,12 +166,33 @@ async def generate(
messages = self._prepare_messages(conversation_history)
max_retries = int(Config.get("strix_llm_max_retries") or "5")
+ bad_request_retried = False
+
for attempt in range(max_retries + 1):
try:
async for response in self._stream(messages):
yield response
return # noqa: TRY300
except Exception as e: # noqa: BLE001
+ if self._is_bad_request(e):
+ if not bad_request_retried:
+ bad_request_retried = True
+ if attempt >= max_retries:
+ self._raise_error(e)
+ await asyncio.sleep(2)
+ continue
+ truncate_enabled = Config.get("strix_truncate_on_oversize") or ""
+ if (
+ truncate_enabled.lower() in ("1", "true", "yes")
+ and self._truncate_large_tool_results(messages)
+ ):
+ if attempt >= max_retries:
+ self._raise_error(e)
+ # Pace the provider — matches the 2s sleep on the bare-retry
+ # path so a throttled provider isn't hit back-to-back after the
+ # original 400.
+ await asyncio.sleep(2)
+ continue
if attempt >= max_retries or not self._should_retry(e):
self._raise_error(e)
wait = min(90, 2 * (2**attempt))
@@ -314,6 +342,60 @@ def _extract_cost(self, response: Any) -> float:
except Exception: # noqa: BLE001
return 0.0
+ @staticmethod
+ def _truncate_large_tool_results(
+ messages: list[dict[str, Any]],
+ threshold_chars: int = 2000,
+ truncate_to_chars: int = 1000,
+ ) -> bool:
+ """Truncate large tool_result XML blocks to recover from BadRequestError.
+
+ Scans all messages for tool_result blocks whose body exceeds threshold_chars
+ and shrinks them to truncate_to_chars. Called repeatedly on each 400 until it
+ returns False (nothing left to truncate).
+
+ threshold_chars and truncate_to_chars are independent: the threshold decides
+ which blocks qualify for truncation, and truncate_to_chars is the size of the
+ retained prefix. They are not the same value to allow aggressive shrinking of
+ blocks that are well over the threshold without re-processing blocks that are
+ already acceptable.
+ """
+ truncated_any = False
+
+ def _truncate_match(m: re.Match) -> str:
+ nonlocal truncated_any
+ prefix, body, suffix = m.group(1), m.group(2), m.group(3)
+ if len(body) <= threshold_chars:
+ return m.group(0)
+ truncated_any = True
+ kept = body[:truncate_to_chars]
+ return (
+ f"{prefix}{kept}\n\n... [content truncated from {len(body)} to {len(kept)} chars "
+ f"due to request size limit — file requires manual review] ...{suffix}"
+ )
+
+ for msg in reversed(messages):
+ content = msg.get("content")
+
+ if isinstance(content, list):
+ for block in content:
+ if (
+ block.get("type") == "text"
+ and isinstance(block.get("text"), str)
+ and "" in block["text"]
+ ):
+ block["text"] = _TOOL_RESULT_PATTERN.sub(_truncate_match, block["text"])
+ elif isinstance(content, str) and "" in content:
+ msg["content"] = _TOOL_RESULT_PATTERN.sub(_truncate_match, content)
+
+ return truncated_any
+
+ def _is_bad_request(self, e: Exception) -> bool:
+ code = getattr(e, "status_code", None) or getattr(
+ getattr(e, "response", None), "status_code", None
+ )
+ return code == 400
+
def _should_retry(self, e: Exception) -> bool:
code = getattr(e, "status_code", None) or getattr(
getattr(e, "response", None), "status_code", None