From 90979ac71f112802b3881c9c602d2305eccdbce1 Mon Sep 17 00:00:00 2001 From: RheagalFire Date: Tue, 2 Jun 2026 23:51:38 +0530 Subject: [PATCH 1/2] feat: add LiteLLM as unified LLM provider --- backend/consts/provider.py | 1 + backend/services/model_provider_service.py | 4 + .../services/providers/litellm_provider.py | 77 +++++++ sdk/nexent/core/models/__init__.py | 2 + sdk/nexent/core/models/litellm_llm.py | 201 ++++++++++++++++++ 5 files changed, 285 insertions(+) create mode 100644 backend/services/providers/litellm_provider.py create mode 100644 sdk/nexent/core/models/litellm_llm.py diff --git a/backend/consts/provider.py b/backend/consts/provider.py index 38bbc4027..6562ad2d4 100644 --- a/backend/consts/provider.py +++ b/backend/consts/provider.py @@ -8,6 +8,7 @@ class ProviderEnum(str, Enum): MODELENGINE = "modelengine" DASHSCOPE = "dashscope" TOKENPONY = "tokenpony" + LITELLM = "litellm" # Silicon Flow diff --git a/backend/services/model_provider_service.py b/backend/services/model_provider_service.py index dbff17082..8ac274d6e 100644 --- a/backend/services/model_provider_service.py +++ b/backend/services/model_provider_service.py @@ -12,6 +12,7 @@ from services.providers.base import AbstractModelProvider from services.providers.silicon_provider import SiliconModelProvider from services.providers.tokenpony_provider import TokenPonyModelProvider +from services.providers.litellm_provider import LiteLLMModelProvider from services.providers.dashscope_provider import DashScopeModelProvider from services.providers.modelengine_provider import ModelEngineProvider, get_model_engine_raw_url, MODEL_ENGINE_NORTH_PREFIX from utils.model_name_utils import split_repo_name, add_repo_to_name @@ -48,6 +49,9 @@ async def get_provider_models(model_data: dict) -> List[dict]: elif model_data["provider"] == ProviderEnum.TOKENPONY.value: provider = TokenPonyModelProvider() model_list = await provider.get_models(model_data) + elif model_data["provider"] == ProviderEnum.LITELLM.value: + provider = LiteLLMModelProvider() + model_list = await provider.get_models(model_data) return model_list diff --git a/backend/services/providers/litellm_provider.py b/backend/services/providers/litellm_provider.py new file mode 100644 index 000000000..580b755ab --- /dev/null +++ b/backend/services/providers/litellm_provider.py @@ -0,0 +1,77 @@ +import logging +from typing import Dict, List + +from consts.const import DEFAULT_LLM_MAX_TOKENS +from services.providers.base import AbstractModelProvider, _classify_provider_error + +logger = logging.getLogger("model_provider") + + +class LiteLLMModelProvider(AbstractModelProvider): + """Provider that discovers models via LiteLLM's /v1/models endpoint. + + LiteLLM supports 100+ LLM providers (OpenAI, Anthropic, Google Gemini, + Azure, Bedrock, Ollama, etc.) through a unified interface. When pointed + at a LiteLLM proxy, this provider fetches the available model catalog. + + For direct SDK usage (no proxy), users should add models manually with + the ``litellm`` provider and use LiteLLM model identifiers like + ``anthropic/claude-sonnet-4-20250514`` or ``gemini/gemini-2.5-flash``. + """ + + async def get_models(self, provider_config: Dict) -> List[Dict]: + """ + Fetch models from a LiteLLM-compatible /v1/models endpoint. + + Args: + provider_config: Configuration dict containing model_type, api_key, and base_url + + Returns: + List of models with canonical fields. + """ + import httpx + + try: + model_type: str = provider_config.get("model_type", "llm") + api_key: str = provider_config.get("api_key", "") + base_url: str = provider_config.get("base_url", "").rstrip("/") + + if not base_url: + return [] + + headers = {} + if api_key: + headers["Authorization"] = f"Bearer {api_key}" + + models_url = f"{base_url}/models" + + async with httpx.AsyncClient(verify=False, timeout=15.0) as client: + response = await client.get(models_url, headers=headers) + response.raise_for_status() + data = response.json().get("data", []) + + model_list = [] + for item in data: + model_id = item.get("id", "") + if not model_id: + continue + + model_entry = { + "id": model_id, + "model_type": model_type, + "max_tokens": DEFAULT_LLM_MAX_TOKENS, + } + + if model_type in ("llm", "vlm"): + model_entry["model_tag"] = "chat" + elif model_type in ("embedding", "multi_embedding"): + model_entry["model_tag"] = "embedding" + elif model_type == "rerank": + model_entry["model_tag"] = "rerank" + + model_list.append(model_entry) + + return model_list + + except Exception as e: + return _classify_provider_error("LiteLLM", exception=e) diff --git a/sdk/nexent/core/models/__init__.py b/sdk/nexent/core/models/__init__.py index fa15fb3d4..97ebf2282 100644 --- a/sdk/nexent/core/models/__init__.py +++ b/sdk/nexent/core/models/__init__.py @@ -1,6 +1,7 @@ from .openai_llm import OpenAIModel from .openai_vlm import OpenAIVLModel from .openai_long_context_model import OpenAILongContextModel +from .litellm_llm import LiteLLMModel from .stt_model import BaseSTTModel from .ali_stt_model import AliSTTModel, AliSTTConfig from .volc_stt_model import VolcSTTModel, VolcSTTConfig @@ -8,6 +9,7 @@ "OpenAIModel", "OpenAIVLModel", "OpenAILongContextModel", + "LiteLLMModel", "BaseSTTModel", "AliSTTModel", "AliSTTConfig", diff --git a/sdk/nexent/core/models/litellm_llm.py b/sdk/nexent/core/models/litellm_llm.py new file mode 100644 index 000000000..d56f7d411 --- /dev/null +++ b/sdk/nexent/core/models/litellm_llm.py @@ -0,0 +1,201 @@ +"""LiteLLM-backed LLM model for nexent. + +Provides access to 100+ LLM providers (OpenAI, Anthropic, Google Gemini, +Azure, Bedrock, Ollama, etc.) through ``litellm.completion()`` as an SDK +dependency. Follows the same interface as ``OpenAIModel``. +""" + +import logging +import threading +import time +from typing import Any, Dict, List, Optional + +from smolagents import Tool +from smolagents.models import ChatMessage, MessageRole + +from ..utils.observer import MessageObserver, ProcessType + +logger = logging.getLogger("litellm_llm") + + +class LiteLLMModel: + """LLM model backed by LiteLLM SDK. + + Uses ``litellm.completion()`` directly, supporting any model identifier + that LiteLLM recognizes (e.g. ``anthropic/claude-sonnet-4-20250514``, + ``gemini/gemini-2.5-flash``, ``azure/gpt-4o``). + + See https://docs.litellm.ai/docs/providers for the full provider list. + """ + + def __init__( + self, + model_id: str, + api_key: Optional[str] = None, + api_base: Optional[str] = None, + temperature: float = 0.2, + top_p: float = 0.95, + observer: MessageObserver = MessageObserver, + display_name: Optional[str] = None, + **kwargs: Any, + ): + self.model_id = model_id + self.api_key = api_key + self.api_base = api_base + self.temperature = temperature + self.top_p = top_p + self.observer = observer + self.display_name = display_name + self.stop_event = threading.Event() + self.last_input_token_count = 0 + self.last_output_token_count = 0 + + def __call__( + self, + messages: List[Dict[str, Any]], + stop_sequences: Optional[List[str]] = None, + response_format: Optional[Dict[str, str]] = None, + tools_to_call_from: Optional[List[Tool]] = None, + **kwargs: Any, + ) -> ChatMessage: + try: + import litellm + except ImportError as e: + raise ImportError( + "litellm is required for LiteLLMModel. " + "Install it with: pip install 'litellm>=1.80,<1.87'" + ) from e + + # Normalize messages to dicts + normalized: List[Dict[str, Any]] = [] + for msg in messages or []: + if isinstance(msg, ChatMessage): + normalized.append({ + "role": msg.role.value if hasattr(msg.role, "value") else str(msg.role), + "content": msg.content, + }) + elif isinstance(msg, dict): + normalized.append(msg) + else: + raise TypeError("Messages must be ChatMessage or dict objects.") + + completion_kwargs: Dict[str, Any] = { + "model": self.model_id, + "messages": normalized, + "temperature": self.temperature, + "stream": True, + "drop_params": True, + "stream_options": {"include_usage": True}, + } + + if self.api_key: + completion_kwargs["api_key"] = self.api_key + if self.api_base: + completion_kwargs["api_base"] = self.api_base + if stop_sequences: + completion_kwargs["stop"] = stop_sequences + if response_format: + completion_kwargs["response_format"] = response_format + + # Handle tool calling + if tools_to_call_from: + tool_definitions = [] + for tool in tools_to_call_from: + if hasattr(tool, "to_openai_tool"): + tool_definitions.append(tool.to_openai_tool()) + if tool_definitions: + completion_kwargs["tools"] = tool_definitions + + current_request = litellm.completion(**completion_kwargs) + + # Process streaming response + chunk_list = [] + token_join = [] + role = None + + self.observer.current_mode = ProcessType.MODEL_OUTPUT_THINKING + + try: + for chunk in current_request: + if not hasattr(chunk, "choices") or not chunk.choices: + chunk_list.append(chunk) + continue + + delta = chunk.choices[0].delta + new_token = getattr(delta, "content", None) + reasoning_content = getattr(delta, "reasoning_content", None) + + if reasoning_content is not None: + self.observer.add_model_reasoning_content(reasoning_content) + + if new_token is not None: + self.observer.add_model_new_token(new_token) + token_join.append(new_token) + role = getattr(delta, "role", None) or role + + chunk_list.append(chunk) + if self.stop_event.is_set(): + raise RuntimeError("Model is interrupted by stop event") + + self.observer.flush_remaining_tokens() + model_output = "".join(token_join) + + # Extract token usage from the last chunk + input_tokens = 0 + output_tokens = 0 + if chunk_list and hasattr(chunk_list[-1], "usage") and chunk_list[-1].usage is not None: + usage = chunk_list[-1].usage + input_tokens = getattr(usage, "prompt_tokens", 0) or 0 + output_tokens = getattr(usage, "completion_tokens", 0) or 0 + + self.last_input_token_count = input_tokens + self.last_output_token_count = output_tokens + + from openai.types.chat.chat_completion_message import ChatCompletionMessage + + message = ChatMessage.from_dict( + ChatCompletionMessage( + role=role if role else "assistant", + content=model_output, + ).model_dump(include={"role", "content", "tool_calls"}) + ) + + from smolagents.monitoring import TokenUsage + + if input_tokens > 0 or output_tokens > 0: + message.token_usage = TokenUsage( + input_tokens=input_tokens, + output_tokens=output_tokens, + ) + + message.raw = current_request + message.role = MessageRole.ASSISTANT + return message + + except Exception as e: + if "context_length_exceeded" in str(e): + raise ValueError(f"Token limit exceeded: {str(e)}") + raise + + async def check_connectivity(self) -> bool: + """Test if the LLM provider connection works.""" + try: + import litellm + import asyncio + + kwargs: Dict[str, Any] = { + "model": self.model_id, + "messages": [{"role": "user", "content": "Hello"}], + "max_tokens": 5, + "drop_params": True, + } + if self.api_key: + kwargs["api_key"] = self.api_key + if self.api_base: + kwargs["api_base"] = self.api_base + + await litellm.acompletion(**kwargs) + return True + except Exception as e: + logger.error(f"LiteLLM connectivity check failed: {e}") + return False From 776f886adfc40bf5116c0b329105290e72352d26 Mon Sep 17 00:00:00 2001 From: RheagalFire Date: Tue, 2 Jun 2026 23:58:19 +0530 Subject: [PATCH 2/2] test: add comprehensive tests for LiteLLMModel --- test/sdk/test_litellm_model.py | 213 +++++++++++++++++++++++++++++++++ 1 file changed, 213 insertions(+) create mode 100644 test/sdk/test_litellm_model.py diff --git a/test/sdk/test_litellm_model.py b/test/sdk/test_litellm_model.py new file mode 100644 index 000000000..8770ca623 --- /dev/null +++ b/test/sdk/test_litellm_model.py @@ -0,0 +1,213 @@ +"""Tests for LiteLLMModel SDK integration.""" + +import os +import sys +import types as _types +from types import SimpleNamespace +from unittest.mock import patch + +import pytest + +# Stub the heavy nexent import chain before importing LiteLLMModel +sdk_path = os.path.join(os.path.dirname(__file__), '..', '..', 'sdk') +sys.path.insert(0, sdk_path) + + +class _FakeObserver: + current_mode = None + @staticmethod + def add_model_new_token(t): pass + @staticmethod + def add_model_reasoning_content(t): pass + @staticmethod + def flush_remaining_tokens(): pass + + +class _FakeProcessType: + MODEL_OUTPUT_THINKING = "thinking" + + +# Stub observer module +_observer_mod = _types.ModuleType("nexent.core.utils.observer") +_observer_mod.MessageObserver = _FakeObserver +_observer_mod.ProcessType = _FakeProcessType + +# Build nexent package hierarchy as stubs +_models_dir = os.path.join(sdk_path, "nexent", "core", "models") +for name in [ + "nexent", "nexent.core", "nexent.core.utils", + "nexent.core.utils.observer", "nexent.core.models", + "nexent.memory", "nexent.container", "nexent.monitor", + "nexent.monitor.monitoring", + "nexent.core.utils.token_estimation", +]: + if name not in sys.modules: + mod = _types.ModuleType(name) + if name == "nexent.core.utils.observer": + mod = _observer_mod + # Mark packages so submodule imports work + if not name.endswith((".observer", ".monitoring", ".token_estimation")): + mod.__path__ = [os.path.join(sdk_path, *name.split("."))] + sys.modules[name] = mod + +# Now we can import the module (relative imports resolve against stubs) +from nexent.core.models.litellm_llm import LiteLLMModel + + +def _make_stream_chunks(content="hello", include_usage=True): + """Create fake streaming chunks.""" + chunks = [] + for token in list(content): + chunks.append(SimpleNamespace( + choices=[SimpleNamespace( + delta=SimpleNamespace(content=token, role="assistant", reasoning_content=None), + finish_reason=None, + )], + usage=None, + )) + usage = SimpleNamespace(prompt_tokens=10, completion_tokens=len(content)) if include_usage else None + chunks.append(SimpleNamespace( + choices=[SimpleNamespace( + delta=SimpleNamespace(content=None, role=None, reasoning_content=None), + finish_reason="stop", + )], + usage=usage, + )) + return chunks + + +class TestInit: + def test_basic(self): + m = LiteLLMModel(model_id="anthropic/claude-sonnet-4-20250514") + assert m.model_id == "anthropic/claude-sonnet-4-20250514" + + def test_with_credentials(self): + m = LiteLLMModel(model_id="azure/gpt-4o", api_key="sk-test", api_base="https://x.com") + assert m.api_key == "sk-test" + assert m.api_base == "https://x.com" + + +class TestCall: + def test_streaming_output(self): + m = LiteLLMModel(model_id="gpt-4o-mini") + with patch("litellm.completion", return_value=iter(_make_stream_chunks("OK"))): + r = m([{"role": "user", "content": "hi"}]) + assert r.content == "OK" + + def test_api_key_forwarded(self): + m = LiteLLMModel(model_id="gpt-4o-mini", api_key="sk-test") + with patch("litellm.completion", return_value=iter(_make_stream_chunks("x"))) as mock: + m([{"role": "user", "content": "hi"}]) + assert mock.call_args.kwargs["api_key"] == "sk-test" + + def test_api_key_omitted_when_none(self): + m = LiteLLMModel(model_id="gpt-4o-mini") + with patch("litellm.completion", return_value=iter(_make_stream_chunks("x"))) as mock: + m([{"role": "user", "content": "hi"}]) + assert "api_key" not in mock.call_args.kwargs + + def test_api_base_forwarded(self): + m = LiteLLMModel(model_id="azure/gpt-4o", api_base="https://x.com") + with patch("litellm.completion", return_value=iter(_make_stream_chunks("x"))) as mock: + m([{"role": "user", "content": "hi"}]) + assert mock.call_args.kwargs["api_base"] == "https://x.com" + + def test_drop_params_set(self): + m = LiteLLMModel(model_id="gpt-4o-mini") + with patch("litellm.completion", return_value=iter(_make_stream_chunks("x"))) as mock: + m([{"role": "user", "content": "hi"}]) + assert mock.call_args.kwargs["drop_params"] is True + + def test_response_format(self): + m = LiteLLMModel(model_id="gpt-4o-mini") + with patch("litellm.completion", return_value=iter(_make_stream_chunks("{}"))) as mock: + m([{"role": "user", "content": "json"}], response_format={"type": "json_object"}) + assert mock.call_args.kwargs["response_format"] == {"type": "json_object"} + + def test_stop_sequences(self): + m = LiteLLMModel(model_id="gpt-4o-mini") + with patch("litellm.completion", return_value=iter(_make_stream_chunks("x"))) as mock: + m([{"role": "user", "content": "hi"}], stop_sequences=["END"]) + assert mock.call_args.kwargs["stop"] == ["END"] + + def test_token_tracking(self): + m = LiteLLMModel(model_id="gpt-4o-mini") + with patch("litellm.completion", return_value=iter(_make_stream_chunks("hello"))): + m([{"role": "user", "content": "hi"}]) + assert m.last_input_token_count == 10 + assert m.last_output_token_count == 5 + + +class TestEdgeCases: + def test_empty_stream(self): + m = LiteLLMModel(model_id="gpt-4o-mini") + chunks = [SimpleNamespace( + choices=[SimpleNamespace( + delta=SimpleNamespace(content=None, role="assistant", reasoning_content=None), + finish_reason="stop", + )], usage=SimpleNamespace(prompt_tokens=5, completion_tokens=0), + )] + with patch("litellm.completion", return_value=iter(chunks)): + r = m([{"role": "user", "content": "hi"}]) + assert r.content == "" + + def test_chunk_without_choices(self): + m = LiteLLMModel(model_id="gpt-4o-mini") + chunks = [SimpleNamespace(usage=None), *_make_stream_chunks("OK")] + with patch("litellm.completion", return_value=iter(chunks)): + r = m([{"role": "user", "content": "hi"}]) + assert r.content == "OK" + + def test_context_length_exceeded(self): + """context_length_exceeded during streaming is converted to ValueError.""" + m = LiteLLMModel(model_id="gpt-4o-mini") + + def _exploding_stream(**kwargs): + yield SimpleNamespace( + choices=[SimpleNamespace( + delta=SimpleNamespace(content="x", role="assistant", reasoning_content=None), + finish_reason=None, + )], usage=None, + ) + raise Exception("context_length_exceeded") + + with patch("litellm.completion", side_effect=lambda **kw: _exploding_stream(**kw)): + with pytest.raises(ValueError, match="Token limit exceeded"): + m([{"role": "user", "content": "hi"}]) + + def test_auth_error_propagates(self): + m = LiteLLMModel(model_id="gpt-4o-mini") + with patch("litellm.completion", side_effect=ValueError("invalid api key")): + with pytest.raises(ValueError, match="invalid api key"): + m([{"role": "user", "content": "hi"}]) + + def test_import_error(self): + m = LiteLLMModel(model_id="gpt-4o-mini") + with patch.dict("sys.modules", {"litellm": None}): + with pytest.raises(ImportError, match="litellm is required"): + m([{"role": "user", "content": "hi"}]) + + def test_stop_event_interrupts(self): + m = LiteLLMModel(model_id="gpt-4o-mini") + m.stop_event.set() + with patch("litellm.completion", return_value=iter(_make_stream_chunks("long text"))): + with pytest.raises(RuntimeError, match="interrupted"): + m([{"role": "user", "content": "hi"}]) + + +@pytest.mark.skipif( + "ANTHROPIC_FOUNDRY_API_KEY" not in os.environ, + reason="Live E2E requires ANTHROPIC_FOUNDRY_API_KEY", +) +class TestLiveE2E: + def test_live_streaming(self): + m = LiteLLMModel( + model_id="anthropic/" + os.environ.get("ANTHROPIC_DEFAULT_SONNET_MODEL", "claude-sonnet-4-20250514"), + api_key=os.environ["ANTHROPIC_FOUNDRY_API_KEY"], + api_base=os.environ.get("ANTHROPIC_FOUNDRY_BASE_URL"), + temperature=0.7, + ) + r = m([{"role": "user", "content": "Say OK and nothing else."}]) + assert isinstance(r.content, str) + assert len(r.content) > 0 + print(f"Live E2E response: {r.content!r}")