diff --git a/docs/advanced/configuration.mdx b/docs/advanced/configuration.mdx index 4d51f3c62..71015b3f1 100644 --- a/docs/advanced/configuration.mdx +++ b/docs/advanced/configuration.mdx @@ -35,6 +35,11 @@ Configure Strix using environment variables or a config file. Timeout in seconds for memory compression operations (context summarization). + + Custom HTTP headers to include in every LiteLLM request. Accepts a JSON object, + e.g. `{"x-my-header": "value"}`. + + ## Optional Features diff --git a/strix/config/config.py b/strix/config/config.py index 255df7c66..5e8cc5782 100644 --- a/strix/config/config.py +++ b/strix/config/config.py @@ -18,6 +18,7 @@ class Config: openai_api_base = None litellm_base_url = None ollama_api_base = None + llm_extra_headers = None strix_reasoning_effort = "high" strix_llm_max_retries = "5" strix_memory_compressor_timeout = "30" @@ -29,6 +30,7 @@ class Config: "openai_api_base", "litellm_base_url", "ollama_api_base", + "llm_extra_headers", "strix_reasoning_effort", "strix_llm_max_retries", "strix_memory_compressor_timeout", @@ -196,18 +198,20 @@ def save_current_config() -> bool: return Config.save_current() -def resolve_llm_config() -> tuple[str | None, str | None, str | None]: - """Resolve LLM model, api_key, and api_base based on STRIX_LLM prefix. +def resolve_llm_config() -> tuple[str | None, str | None, str | None, dict[str, str] | None]: + """Resolve LLM model, api_key, api_base based on STRIX_LLM prefix + and extra_headers for LiteLLM calls. 
Returns: - tuple: (model_name, api_key, api_base) + tuple: (model_name, api_key, api_base, extra_headers) - model_name: Original model name (strix/ prefix preserved for display) - api_key: LLM API key - api_base: API base URL (auto-set to STRIX_API_BASE for strix/ models) + - extra_headers: Custom HTTP headers parsed from LLM_EXTRA_HEADERS (empty dict if unset) """ model = Config.get("strix_llm") if not model: - return None, None, None + return None, None, None, None api_key = Config.get("llm_api_key") @@ -221,4 +225,16 @@ def resolve_llm_config() -> tuple[str | None, str | None, str | None]: or Config.get("ollama_api_base") ) - return model, api_key, api_base + extra_headers: dict[str, str] = {} + raw_headers = Config.get("llm_extra_headers") or "" + if raw_headers.strip(): + try: + parsed = json.loads(raw_headers) + if isinstance(parsed, dict): + extra_headers = {str(k): str(v) for k, v in parsed.items() if v is not None} + else: + raise TypeError("LLM_EXTRA_HEADERS must be a JSON object") + except json.JSONDecodeError as e: + raise ValueError(f"Invalid LLM_EXTRA_HEADERS JSON: {e}") from e + + return model, api_key, api_base, extra_headers diff --git a/strix/interface/main.py b/strix/interface/main.py index bc88da673..ce8060a4d 100644 --- a/strix/interface/main.py +++ b/strix/interface/main.py @@ -5,6 +5,7 @@ import argparse import asyncio +import json import logging import os import shutil @@ -82,6 +83,36 @@ def validate_environment() -> None: # noqa: PLR0912, PLR0915 if not Config.get("strix_reasoning_effort"): missing_optional_vars.append("STRIX_REASONING_EFFORT") + raw_headers = Config.get("llm_extra_headers") or "" + if raw_headers.strip(): + try: + parsed = json.loads(raw_headers) + if not isinstance(parsed, dict): + raise TypeError("LLM_EXTRA_HEADERS must be a JSON object, got a non-dict value") + except (json.JSONDecodeError, TypeError, ValueError) as e: + error_text = Text() + error_text.append("INVALID LLM_EXTRA_HEADERS", style="bold red") + error_text.append("\n\n", style="white") + 
error_text.append("LLM_EXTRA_HEADERS must be a valid JSON object.\n", style="white") + error_text.append(f"Error: {e}\n", style="white") + error_text.append("\nExample:\n", style="white") + error_text.append( + 'export LLM_EXTRA_HEADERS=\'{"x-my-header": "value"}\'\n', + style="dim white", + ) + + panel = Panel( + error_text, + title="[bold white]STRIX", + title_align="left", + border_style="red", + padding=(1, 2), + ) + console.print("\n") + console.print(panel) + console.print() + sys.exit(1) + if missing_required_vars: error_text = Text() error_text.append("MISSING REQUIRED ENVIRONMENT VARIABLES", style="bold red") @@ -208,7 +239,7 @@ async def warm_up_llm() -> None: console = Console() try: - model_name, api_key, api_base = resolve_llm_config() + model_name, api_key, api_base, extra_headers = resolve_llm_config() litellm_model, _ = resolve_strix_model(model_name) litellm_model = litellm_model or model_name @@ -228,6 +259,8 @@ async def warm_up_llm() -> None: completion_kwargs["api_key"] = api_key if api_base: completion_kwargs["api_base"] = api_base + if extra_headers: + completion_kwargs["extra_headers"] = extra_headers response = litellm.completion(**completion_kwargs) diff --git a/strix/llm/config.py b/strix/llm/config.py index 017c77662..d3bc1fe88 100644 --- a/strix/llm/config.py +++ b/strix/llm/config.py @@ -18,7 +18,7 @@ def __init__( reasoning_effort: str | None = None, system_prompt_context: dict[str, Any] | None = None, ): - resolved_model, self.api_key, self.api_base = resolve_llm_config() + resolved_model, self.api_key, self.api_base, self.extra_headers = resolve_llm_config() self.model_name = model_name or resolved_model if not self.model_name: diff --git a/strix/llm/dedupe.py b/strix/llm/dedupe.py index 0ea608850..b191ac38a 100644 --- a/strix/llm/dedupe.py +++ b/strix/llm/dedupe.py @@ -156,7 +156,7 @@ def check_duplicate( comparison_data = {"candidate": candidate_cleaned, "existing_reports": existing_cleaned} - model_name, api_key, api_base = 
resolve_llm_config() + model_name, api_key, api_base, extra_headers = resolve_llm_config() litellm_model, _ = resolve_strix_model(model_name) litellm_model = litellm_model or model_name @@ -181,6 +181,8 @@ def check_duplicate( completion_kwargs["api_key"] = api_key if api_base: completion_kwargs["api_base"] = api_base + if extra_headers: + completion_kwargs["extra_headers"] = extra_headers response = litellm.completion(**completion_kwargs) diff --git a/strix/llm/llm.py b/strix/llm/llm.py index 6fd727d5f..92290ab9c 100644 --- a/strix/llm/llm.py +++ b/strix/llm/llm.py @@ -286,6 +286,8 @@ def _build_completion_args(self, messages: list[dict[str, Any]]) -> dict[str, An args["api_key"] = self.config.api_key if self.config.api_base: args["api_base"] = self.config.api_base + if self.config.extra_headers: + args["extra_headers"] = self.config.extra_headers if self._supports_reasoning(): args["reasoning_effort"] = self._reasoning_effort diff --git a/strix/llm/memory_compressor.py b/strix/llm/memory_compressor.py index aea086c1c..27b7b82ec 100644 --- a/strix/llm/memory_compressor.py +++ b/strix/llm/memory_compressor.py @@ -104,7 +104,7 @@ def _summarize_messages( conversation = "\n".join(formatted) prompt = SUMMARY_PROMPT_TEMPLATE.format(conversation=conversation) - _, api_key, api_base = resolve_llm_config() + _, api_key, api_base, extra_headers = resolve_llm_config() try: completion_args: dict[str, Any] = { @@ -116,6 +116,8 @@ def _summarize_messages( completion_args["api_key"] = api_key if api_base: completion_args["api_base"] = api_base + if extra_headers: + completion_args["extra_headers"] = extra_headers response = litellm.completion(**completion_args) summary = response.choices[0].message.content or ""