Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/advanced/configuration.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,11 @@ Configure Strix using environment variables or a config file.
Timeout in seconds for memory compression operations (context summarization).
</ParamField>

<ParamField path="LLM_EXTRA_HEADERS" type="string">
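Custom HTTP headers to include in every LiteLLM request. Accepts a JSON object
encoded as a string, e.g. `{"x-my-header": "value"}`. Header values are converted
to strings, and entries whose value is `null` are dropped.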
</ParamField>

## Optional Features

<ParamField path="PERPLEXITY_API_KEY" type="string">
Expand Down
26 changes: 21 additions & 5 deletions strix/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ class Config:
openai_api_base = None
litellm_base_url = None
ollama_api_base = None
llm_extra_headers = None
strix_reasoning_effort = "high"
strix_llm_max_retries = "5"
strix_memory_compressor_timeout = "30"
Expand All @@ -29,6 +30,7 @@ class Config:
"openai_api_base",
"litellm_base_url",
"ollama_api_base",
"llm_extra_headers",
"strix_reasoning_effort",
"strix_llm_max_retries",
"strix_memory_compressor_timeout",
Expand Down Expand Up @@ -196,18 +198,20 @@ def save_current_config() -> bool:
return Config.save_current()


def resolve_llm_config() -> tuple[str | None, str | None, str | None]:
"""Resolve LLM model, api_key, and api_base based on STRIX_LLM prefix.
def resolve_llm_config() -> tuple[str | None, str | None, str | None, dict[str, str] | None]:
"""Resolve LLM model, api_key, api_base based on STRIX_LLM prefix
and extra_headers for LiteLLM calls.

Returns:
tuple: (model_name, api_key, api_base)
tuple: (model_name, api_key, api_base, extra_headers)
- model_name: Original model name (strix/ prefix preserved for display)
- api_key: LLM API key
- api_base: API base URL (auto-set to STRIX_API_BASE for strix/ models)
- extra_headers : Custom headers
"""
model = Config.get("strix_llm")
if not model:
return None, None, None
return None, None, None, None

api_key = Config.get("llm_api_key")

Expand All @@ -221,4 +225,16 @@ def resolve_llm_config() -> tuple[str | None, str | None, str | None]:
or Config.get("ollama_api_base")
)

return model, api_key, api_base
extra_headers: dict[str, str] = {}
raw_headers = Config.get("llm_extra_headers") or ""
if raw_headers.strip():
try:
parsed = json.loads(raw_headers)
if isinstance(parsed, dict):
extra_headers = {str(k): str(v) for k, v in parsed.items() if v is not None}
else:
raise TypeError("LLM_EXTRA_HEADERS must be a JSON object")
except json.JSONDecodeError as e:
raise ValueError(f"Invalid LLM_EXTRA_HEADERS JSON: {e}") from e

return model, api_key, api_base, extra_headers
35 changes: 34 additions & 1 deletion strix/interface/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import argparse
import asyncio
import json
import logging
import os
import shutil
Expand Down Expand Up @@ -82,6 +83,36 @@ def validate_environment() -> None: # noqa: PLR0912, PLR0915
if not Config.get("strix_reasoning_effort"):
missing_optional_vars.append("STRIX_REASONING_EFFORT")

# Validate LLM_EXTRA_HEADERS: when set, it must decode to a JSON object.
# On any failure, print an explanatory panel and exit with status 1.
raw_headers = Config.get("llm_extra_headers") or ""
if raw_headers.strip():
    try:
        parsed = json.loads(raw_headers)
        if not isinstance(parsed, dict):
            raise TypeError("LLM_EXTRA_HEADERS must be a JSON object, got a non-dict value")
    # json.JSONDecodeError is a ValueError subclass, so ValueError covers bad
    # JSON. TypeError must be listed explicitly: it is NOT a ValueError, and
    # without it the non-object case above would escape as a raw traceback
    # instead of reaching the error panel below.
    except (TypeError, ValueError) as e:
        error_text = Text()
        error_text.append("INVALID LLM_EXTRA_HEADERS", style="bold red")
        error_text.append("\n\n", style="white")
        error_text.append("LLM_EXTRA_HEADERS must be a valid JSON object.\n", style="white")
        error_text.append(f"Error: {e}\n", style="white")
        error_text.append("\nExample:\n", style="white")
        # The JSON must be single-quoted in the shell; unquoted, the shell
        # strips the double quotes and splits the value on the space.
        error_text.append(
            "export LLM_EXTRA_HEADERS='{\"x-my-header\": \"value\"}'\n",
            style="dim white",
        )

        panel = Panel(
            error_text,
            title="[bold white]STRIX",
            title_align="left",
            border_style="red",
            padding=(1, 2),
        )
        console.print("\n")
        console.print(panel)
        console.print()
        sys.exit(1)

if missing_required_vars:
error_text = Text()
error_text.append("MISSING REQUIRED ENVIRONMENT VARIABLES", style="bold red")
Expand Down Expand Up @@ -208,7 +239,7 @@ async def warm_up_llm() -> None:
console = Console()

try:
model_name, api_key, api_base = resolve_llm_config()
model_name, api_key, api_base, extra_headers = resolve_llm_config()
litellm_model, _ = resolve_strix_model(model_name)
litellm_model = litellm_model or model_name

Expand All @@ -228,6 +259,8 @@ async def warm_up_llm() -> None:
completion_kwargs["api_key"] = api_key
if api_base:
completion_kwargs["api_base"] = api_base
if extra_headers:
completion_kwargs["extra_headers"] = extra_headers

response = litellm.completion(**completion_kwargs)

Expand Down
2 changes: 1 addition & 1 deletion strix/llm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def __init__(
reasoning_effort: str | None = None,
system_prompt_context: dict[str, Any] | None = None,
):
resolved_model, self.api_key, self.api_base = resolve_llm_config()
resolved_model, self.api_key, self.api_base, self.extra_headers = resolve_llm_config()
self.model_name = model_name or resolved_model

if not self.model_name:
Expand Down
4 changes: 3 additions & 1 deletion strix/llm/dedupe.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ def check_duplicate(

comparison_data = {"candidate": candidate_cleaned, "existing_reports": existing_cleaned}

model_name, api_key, api_base = resolve_llm_config()
model_name, api_key, api_base, extra_headers = resolve_llm_config()
litellm_model, _ = resolve_strix_model(model_name)
litellm_model = litellm_model or model_name

Expand All @@ -181,6 +181,8 @@ def check_duplicate(
completion_kwargs["api_key"] = api_key
if api_base:
completion_kwargs["api_base"] = api_base
if extra_headers:
completion_kwargs["extra_headers"] = extra_headers

response = litellm.completion(**completion_kwargs)

Expand Down
2 changes: 2 additions & 0 deletions strix/llm/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,8 @@ def _build_completion_args(self, messages: list[dict[str, Any]]) -> dict[str, An
args["api_key"] = self.config.api_key
if self.config.api_base:
args["api_base"] = self.config.api_base
if self.config.extra_headers:
args["extra_headers"] = self.config.extra_headers
if self._supports_reasoning():
args["reasoning_effort"] = self._reasoning_effort

Expand Down
4 changes: 3 additions & 1 deletion strix/llm/memory_compressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def _summarize_messages(
conversation = "\n".join(formatted)
prompt = SUMMARY_PROMPT_TEMPLATE.format(conversation=conversation)

_, api_key, api_base = resolve_llm_config()
_, api_key, api_base, extra_headers = resolve_llm_config()

try:
completion_args: dict[str, Any] = {
Expand All @@ -116,6 +116,8 @@ def _summarize_messages(
completion_args["api_key"] = api_key
if api_base:
completion_args["api_base"] = api_base
if extra_headers:
completion_args["extra_headers"] = extra_headers

response = litellm.completion(**completion_args)
summary = response.choices[0].message.content or ""
Expand Down
Loading