diff --git a/docs/advanced/configuration.mdx b/docs/advanced/configuration.mdx
index 4d51f3c62..71015b3f1 100644
--- a/docs/advanced/configuration.mdx
+++ b/docs/advanced/configuration.mdx
@@ -35,6 +35,11 @@ Configure Strix using environment variables or a config file.
Timeout in seconds for memory compression operations (context summarization).
+
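+ Custom HTTP headers to include in every LiteLLM request. Accepts a JSON object,
+ e.g. `{"x-my-header": "value"}`. When setting it from a shell, wrap the JSON in
+ single quotes: `LLM_EXTRA_HEADERS='{"x-my-header": "value"}'`.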
+
+
## Optional Features
diff --git a/strix/config/config.py b/strix/config/config.py
index 255df7c66..5e8cc5782 100644
--- a/strix/config/config.py
+++ b/strix/config/config.py
@@ -18,6 +18,7 @@ class Config:
openai_api_base = None
litellm_base_url = None
ollama_api_base = None
+ llm_extra_headers = None
strix_reasoning_effort = "high"
strix_llm_max_retries = "5"
strix_memory_compressor_timeout = "30"
@@ -29,6 +30,7 @@ class Config:
"openai_api_base",
"litellm_base_url",
"ollama_api_base",
+ "llm_extra_headers",
"strix_reasoning_effort",
"strix_llm_max_retries",
"strix_memory_compressor_timeout",
@@ -196,18 +198,20 @@ def save_current_config() -> bool:
return Config.save_current()
-def resolve_llm_config() -> tuple[str | None, str | None, str | None]:
- """Resolve LLM model, api_key, and api_base based on STRIX_LLM prefix.
+def resolve_llm_config() -> tuple[str | None, str | None, str | None, dict[str, str] | None]:
+ """Resolve the LLM model, api_key, and api_base based on the STRIX_LLM prefix,
+ plus the extra_headers to send with LiteLLM calls.
Returns:
- tuple: (model_name, api_key, api_base)
+ tuple: (model_name, api_key, api_base, extra_headers)
- model_name: Original model name (strix/ prefix preserved for display)
- api_key: LLM API key
- api_base: API base URL (auto-set to STRIX_API_BASE for strix/ models)
+ - extra_headers: Custom HTTP headers parsed from LLM_EXTRA_HEADERS (empty dict if unset)
"""
model = Config.get("strix_llm")
if not model:
- return None, None, None
+ return None, None, None, None
api_key = Config.get("llm_api_key")
@@ -221,4 +225,16 @@ def resolve_llm_config() -> tuple[str | None, str | None, str | None]:
or Config.get("ollama_api_base")
)
- return model, api_key, api_base
+    # Parse the optional LLM_EXTRA_HEADERS setting, which must be a JSON object.
+    # An unset or blank value yields an empty dict.
+    extra_headers: dict[str, str] = {}
+    raw_headers = Config.get("llm_extra_headers") or ""
+    if raw_headers.strip():
+        # Keep the try body limited to the one call that can raise JSONDecodeError.
+        try:
+            parsed = json.loads(raw_headers)
+        except json.JSONDecodeError as e:
+            raise ValueError(f"Invalid LLM_EXTRA_HEADERS JSON: {e}") from e
+        if not isinstance(parsed, dict):
+            # Raise the same exception type as for malformed JSON so that every
+            # invalid LLM_EXTRA_HEADERS value surfaces as a ValueError.
+            raise ValueError("LLM_EXTRA_HEADERS must be a JSON object")
+        # Keys and values are coerced to str; entries whose value is null are dropped.
+        extra_headers = {str(k): str(v) for k, v in parsed.items() if v is not None}
+
+    return model, api_key, api_base, extra_headers
diff --git a/strix/interface/main.py b/strix/interface/main.py
index bc88da673..ce8060a4d 100644
--- a/strix/interface/main.py
+++ b/strix/interface/main.py
@@ -5,6 +5,7 @@
import argparse
import asyncio
+import json
import logging
import os
import shutil
@@ -82,6 +83,36 @@ def validate_environment() -> None: # noqa: PLR0912, PLR0915
if not Config.get("strix_reasoning_effort"):
missing_optional_vars.append("STRIX_REASONING_EFFORT")
+    # Validate LLM_EXTRA_HEADERS: when it is set but is not a JSON object,
+    # print an error panel and exit with status 1.
+    raw_headers = Config.get("llm_extra_headers") or ""
+    if raw_headers.strip():
+        try:
+            parsed = json.loads(raw_headers)
+            if not isinstance(parsed, dict):
+                raise TypeError("LLM_EXTRA_HEADERS must be a JSON object, got a non-dict value")
+        # TypeError is listed explicitly because it is not a ValueError subclass;
+        # without it, valid JSON that is not an object would escape this handler.
+        except (json.JSONDecodeError, TypeError, ValueError) as e:
+            error_text = Text()
+            error_text.append("INVALID LLM_EXTRA_HEADERS", style="bold red")
+            error_text.append("\n\n", style="white")
+            error_text.append("LLM_EXTRA_HEADERS must be a valid JSON object.\n", style="white")
+            error_text.append(f"Error: {e}\n", style="white")
+            error_text.append("\nExample:\n", style="white")
+            # The JSON is single-quoted so a shell keeps the inner double quotes
+            # and does not split the value on the space.
+            error_text.append(
+                "export LLM_EXTRA_HEADERS='{\"x-my-header\": \"value\"}'\n",
+                style="dim white",
+            )
+
+            panel = Panel(
+                error_text,
+                title="[bold white]STRIX",
+                title_align="left",
+                border_style="red",
+                padding=(1, 2),
+            )
+            console.print("\n")
+            console.print(panel)
+            console.print()
+            sys.exit(1)
+
if missing_required_vars:
error_text = Text()
error_text.append("MISSING REQUIRED ENVIRONMENT VARIABLES", style="bold red")
@@ -208,7 +239,7 @@ async def warm_up_llm() -> None:
console = Console()
try:
- model_name, api_key, api_base = resolve_llm_config()
+ model_name, api_key, api_base, extra_headers = resolve_llm_config()
litellm_model, _ = resolve_strix_model(model_name)
litellm_model = litellm_model or model_name
@@ -228,6 +259,8 @@ async def warm_up_llm() -> None:
completion_kwargs["api_key"] = api_key
if api_base:
completion_kwargs["api_base"] = api_base
+ if extra_headers:
+ completion_kwargs["extra_headers"] = extra_headers
response = litellm.completion(**completion_kwargs)
diff --git a/strix/llm/config.py b/strix/llm/config.py
index 017c77662..d3bc1fe88 100644
--- a/strix/llm/config.py
+++ b/strix/llm/config.py
@@ -18,7 +18,7 @@ def __init__(
reasoning_effort: str | None = None,
system_prompt_context: dict[str, Any] | None = None,
):
- resolved_model, self.api_key, self.api_base = resolve_llm_config()
+ resolved_model, self.api_key, self.api_base, self.extra_headers = resolve_llm_config()
self.model_name = model_name or resolved_model
if not self.model_name:
diff --git a/strix/llm/dedupe.py b/strix/llm/dedupe.py
index 0ea608850..b191ac38a 100644
--- a/strix/llm/dedupe.py
+++ b/strix/llm/dedupe.py
@@ -156,7 +156,7 @@ def check_duplicate(
comparison_data = {"candidate": candidate_cleaned, "existing_reports": existing_cleaned}
- model_name, api_key, api_base = resolve_llm_config()
+ model_name, api_key, api_base, extra_headers = resolve_llm_config()
litellm_model, _ = resolve_strix_model(model_name)
litellm_model = litellm_model or model_name
@@ -181,6 +181,8 @@ def check_duplicate(
completion_kwargs["api_key"] = api_key
if api_base:
completion_kwargs["api_base"] = api_base
+ if extra_headers:
+ completion_kwargs["extra_headers"] = extra_headers
response = litellm.completion(**completion_kwargs)
diff --git a/strix/llm/llm.py b/strix/llm/llm.py
index 6fd727d5f..92290ab9c 100644
--- a/strix/llm/llm.py
+++ b/strix/llm/llm.py
@@ -286,6 +286,8 @@ def _build_completion_args(self, messages: list[dict[str, Any]]) -> dict[str, An
args["api_key"] = self.config.api_key
if self.config.api_base:
args["api_base"] = self.config.api_base
+ if self.config.extra_headers:
+ args["extra_headers"] = self.config.extra_headers
if self._supports_reasoning():
args["reasoning_effort"] = self._reasoning_effort
diff --git a/strix/llm/memory_compressor.py b/strix/llm/memory_compressor.py
index aea086c1c..27b7b82ec 100644
--- a/strix/llm/memory_compressor.py
+++ b/strix/llm/memory_compressor.py
@@ -104,7 +104,7 @@ def _summarize_messages(
conversation = "\n".join(formatted)
prompt = SUMMARY_PROMPT_TEMPLATE.format(conversation=conversation)
- _, api_key, api_base = resolve_llm_config()
+ _, api_key, api_base, extra_headers = resolve_llm_config()
try:
completion_args: dict[str, Any] = {
@@ -116,6 +116,8 @@ def _summarize_messages(
completion_args["api_key"] = api_key
if api_base:
completion_args["api_base"] = api_base
+ if extra_headers:
+ completion_args["extra_headers"] = extra_headers
response = litellm.completion(**completion_args)
summary = response.choices[0].message.content or ""