Skip to content
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 63 additions & 0 deletions strix/llm/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,12 +159,28 @@ async def generate(
messages = self._prepare_messages(conversation_history)
max_retries = int(Config.get("strix_llm_max_retries") or "5")

bad_request_retried = False
bad_request_truncated = False

for attempt in range(max_retries + 1):
try:
async for response in self._stream(messages):
yield response
return # noqa: TRY300
except Exception as e: # noqa: BLE001
if self._is_bad_request(e):
if not bad_request_retried:
bad_request_retried = True
await asyncio.sleep(2)
continue
truncate_enabled = Config.get("strix_truncate_on_oversize") or ""
if (
not bad_request_truncated
and truncate_enabled.lower() in ("1", "true", "yes")
and self._truncate_large_tool_results(messages)
):
bad_request_truncated = True
continue
Comment thread
bearsyankees marked this conversation as resolved.
Comment thread
bearsyankees marked this conversation as resolved.
if attempt >= max_retries or not self._should_retry(e):
self._raise_error(e)
wait = min(90, 2 * (2**attempt))
Expand Down Expand Up @@ -314,6 +330,53 @@ def _extract_cost(self, response: Any) -> float:
except Exception: # noqa: BLE001
return 0.0

@staticmethod
def _truncate_large_tool_results(
messages: list[dict[str, Any]], max_chars: int = 2000
) -> bool:
"""Aggressively truncate large tool results in messages to recover from BadRequestError.

Scans messages in reverse for tool_result XML blocks that exceed max_chars and
replaces their content with a truncated version plus a skip notice. Returns True
if any truncation was performed (caller should retry the request).
"""
import re
Comment thread
bearsyankees marked this conversation as resolved.
Outdated

truncated_any = False
pattern = re.compile(
r"(<tool_result>\s*<tool_name>[^<]*</tool_name>\s*<result>)(.*?)(</result>\s*</tool_result>)",
re.DOTALL,
)

for msg in reversed(messages):
content = msg.get("content")
if not isinstance(content, str) or "<tool_result>" not in content:
continue

def _truncate_match(m: re.Match) -> str:
prefix, body, suffix = m.group(1), m.group(2), m.group(3)
if len(body) <= max_chars:
return m.group(0)
nonlocal truncated_any
truncated_any = True
kept = body[:1000]
return (
f"{prefix}{kept}\n\n... [content truncated from {len(body)} to {len(kept)} chars "
f"due to request size limit — file requires manual review] ...{suffix}"
)

msg["content"] = pattern.sub(_truncate_match, content)
if truncated_any:
break
Comment thread
bearsyankees marked this conversation as resolved.
Outdated

return truncated_any

def _is_bad_request(self, e: Exception) -> bool:
code = getattr(e, "status_code", None) or getattr(
getattr(e, "response", None), "status_code", None
)
return code == 400

def _should_retry(self, e: Exception) -> bool:
code = getattr(e, "status_code", None) or getattr(
getattr(e, "response", None), "status_code", None
Expand Down