Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 82 additions & 0 deletions strix/llm/llm.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
import re
from collections.abc import AsyncIterator
from dataclasses import dataclass
from typing import Any
Expand Down Expand Up @@ -26,6 +27,12 @@
litellm.modify_params = True


_TOOL_RESULT_PATTERN = re.compile(
r"(<tool_result>\s*<tool_name>[^<]*</tool_name>\s*<result>)(.*?)(</result>\s*</tool_result>)",
re.DOTALL,
)


class LLMRequestFailedError(Exception):
def __init__(self, message: str, details: str | None = None):
super().__init__(message)
Expand Down Expand Up @@ -159,12 +166,33 @@ async def generate(
messages = self._prepare_messages(conversation_history)
max_retries = int(Config.get("strix_llm_max_retries") or "5")

bad_request_retried = False

for attempt in range(max_retries + 1):
try:
async for response in self._stream(messages):
yield response
return # noqa: TRY300
except Exception as e: # noqa: BLE001
if self._is_bad_request(e):
if not bad_request_retried:
bad_request_retried = True
if attempt >= max_retries:
self._raise_error(e)
await asyncio.sleep(2)
continue
truncate_enabled = Config.get("strix_truncate_on_oversize") or ""
if (
truncate_enabled.lower() in ("1", "true", "yes")
and self._truncate_large_tool_results(messages)
):
if attempt >= max_retries:
self._raise_error(e)
# Pace the provider — matches the 2s sleep on the bare-retry
# path so a throttled provider isn't hit back-to-back after the
# original 400.
await asyncio.sleep(2)
continue
if attempt >= max_retries or not self._should_retry(e):
self._raise_error(e)
wait = min(90, 2 * (2**attempt))
Expand Down Expand Up @@ -314,6 +342,60 @@ def _extract_cost(self, response: Any) -> float:
except Exception: # noqa: BLE001
return 0.0

@staticmethod
def _truncate_large_tool_results(
messages: list[dict[str, Any]],
threshold_chars: int = 2000,
truncate_to_chars: int = 1000,
) -> bool:
"""Truncate large tool_result XML blocks to recover from BadRequestError.

Scans all messages for tool_result blocks whose body exceeds threshold_chars
and shrinks them to truncate_to_chars. Called repeatedly on each 400 until it
returns False (nothing left to truncate).

threshold_chars and truncate_to_chars are independent: the threshold decides
which blocks qualify for truncation, and truncate_to_chars is the size of the
retained prefix. They are not the same value to allow aggressive shrinking of
blocks that are well over the threshold without re-processing blocks that are
already acceptable.
"""
truncated_any = False

def _truncate_match(m: re.Match) -> str:
nonlocal truncated_any
prefix, body, suffix = m.group(1), m.group(2), m.group(3)
if len(body) <= threshold_chars:
return m.group(0)
truncated_any = True
kept = body[:truncate_to_chars]
return (
f"{prefix}{kept}\n\n... [content truncated from {len(body)} to {len(kept)} chars "
f"due to request size limit — file requires manual review] ...{suffix}"
)

for msg in reversed(messages):
content = msg.get("content")

if isinstance(content, list):
for block in content:
if (
block.get("type") == "text"
and isinstance(block.get("text"), str)
and "<tool_result>" in block["text"]
):
block["text"] = _TOOL_RESULT_PATTERN.sub(_truncate_match, block["text"])
elif isinstance(content, str) and "<tool_result>" in content:
msg["content"] = _TOOL_RESULT_PATTERN.sub(_truncate_match, content)

return truncated_any

def _is_bad_request(self, e: Exception) -> bool:
code = getattr(e, "status_code", None) or getattr(
getattr(e, "response", None), "status_code", None
)
return code == 400

def _should_retry(self, e: Exception) -> bool:
code = getattr(e, "status_code", None) or getattr(
getattr(e, "response", None), "status_code", None
Expand Down