Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions backend/consts/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ class ProviderEnum(str, Enum):
MODELENGINE = "modelengine"
DASHSCOPE = "dashscope"
TOKENPONY = "tokenpony"
LITELLM = "litellm"


# Silicon Flow
Expand Down
4 changes: 4 additions & 0 deletions backend/services/model_provider_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from services.providers.base import AbstractModelProvider
from services.providers.silicon_provider import SiliconModelProvider
from services.providers.tokenpony_provider import TokenPonyModelProvider
from services.providers.litellm_provider import LiteLLMModelProvider
from services.providers.dashscope_provider import DashScopeModelProvider
from services.providers.modelengine_provider import ModelEngineProvider, get_model_engine_raw_url, MODEL_ENGINE_NORTH_PREFIX
from utils.model_name_utils import split_repo_name, add_repo_to_name
Expand Down Expand Up @@ -48,6 +49,9 @@ async def get_provider_models(model_data: dict) -> List[dict]:
elif model_data["provider"] == ProviderEnum.TOKENPONY.value:
provider = TokenPonyModelProvider()
model_list = await provider.get_models(model_data)
elif model_data["provider"] == ProviderEnum.LITELLM.value:
provider = LiteLLMModelProvider()
model_list = await provider.get_models(model_data)

return model_list

Expand Down
77 changes: 77 additions & 0 deletions backend/services/providers/litellm_provider.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
import logging
from typing import Dict, List

from consts.const import DEFAULT_LLM_MAX_TOKENS
from services.providers.base import AbstractModelProvider, _classify_provider_error

logger = logging.getLogger("model_provider")


class LiteLLMModelProvider(AbstractModelProvider):
    """Provider that discovers models through a LiteLLM-compatible models endpoint.

    LiteLLM supports 100+ LLM providers (OpenAI, Anthropic, Google Gemini,
    Azure, Bedrock, Ollama, etc.) through a unified interface. When pointed
    at a LiteLLM proxy, this provider fetches the available model catalog.

    For direct SDK usage (no proxy), users should add models manually with
    the ``litellm`` provider and use LiteLLM model identifiers like
    ``anthropic/claude-sonnet-4-20250514`` or ``gemini/gemini-2.5-flash``.
    """

    # Tag attached to every returned entry, keyed by the requested model_type.
    # Model types that are not listed here get no ``model_tag`` key at all.
    _MODEL_TAGS = {
        "llm": "chat",
        "vlm": "chat",
        "embedding": "embedding",
        "multi_embedding": "embedding",
        "rerank": "rerank",
    }

    async def get_models(self, provider_config: Dict) -> List[Dict]:
        """Fetch the model catalog from ``<base_url>/models``.

        The request goes to ``f"{base_url}/models"``, so ``base_url``
        presumably already carries any version prefix such as ``/v1``
        (NOTE(review): confirm against how callers populate it).

        Args:
            provider_config: Configuration dict. Keys read here:
                ``model_type`` (default ``"llm"``), ``api_key`` (optional,
                sent as a Bearer token), ``base_url`` (required), and
                ``ssl_verify`` (optional; only the literal ``False`` turns
                TLS certificate verification off, e.g. for a self-signed
                proxy).

        Returns:
            One dict per discovered model with ``id``, ``model_type``,
            ``max_tokens`` and, for known model types, ``model_tag``.
            An empty list when no ``base_url`` is configured. On any
            failure, whatever ``_classify_provider_error`` returns.
        """
        import httpx

        try:
            model_type: str = provider_config.get("model_type", "llm")
            api_key: str = provider_config.get("api_key", "")
            base_url: str = (provider_config.get("base_url") or "").rstrip("/")

            if not base_url:
                return []

            headers = {}
            if api_key:
                headers["Authorization"] = f"Bearer {api_key}"

            # Certificate validation stays on unless the caller explicitly
            # passes ssl_verify=False; any other value keeps it enabled so a
            # typo can never silently downgrade the connection.
            verify_tls = provider_config.get("ssl_verify", True) is not False

            async with httpx.AsyncClient(verify=verify_tls, timeout=15.0) as client:
                response = await client.get(f"{base_url}/models", headers=headers)
                response.raise_for_status()
                # ``"data": null`` is treated the same as a missing key.
                data = response.json().get("data") or []

            model_tag = self._MODEL_TAGS.get(model_type)

            model_list = []
            for item in data:
                # Skip malformed rows instead of failing the whole listing.
                if not isinstance(item, dict):
                    continue
                model_id = item.get("id", "")
                if not model_id:
                    continue

                model_entry = {
                    "id": model_id,
                    "model_type": model_type,
                    "max_tokens": DEFAULT_LLM_MAX_TOKENS,
                }
                if model_tag is not None:
                    model_entry["model_tag"] = model_tag

                model_list.append(model_entry)

            return model_list

        except Exception as e:
            # Broad on purpose: every failure (network, HTTP status, bad JSON)
            # is handed to the shared classifier rather than raised.
            return _classify_provider_error("LiteLLM", exception=e)
2 changes: 2 additions & 0 deletions sdk/nexent/core/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
from .openai_llm import OpenAIModel
from .openai_vlm import OpenAIVLModel
from .openai_long_context_model import OpenAILongContextModel
from .litellm_llm import LiteLLMModel
from .stt_model import BaseSTTModel
from .ali_stt_model import AliSTTModel, AliSTTConfig
from .volc_stt_model import VolcSTTModel, VolcSTTConfig
__all__ = [
"OpenAIModel",
"OpenAIVLModel",
"OpenAILongContextModel",
"LiteLLMModel",
"BaseSTTModel",
"AliSTTModel",
"AliSTTConfig",
Expand Down
201 changes: 201 additions & 0 deletions sdk/nexent/core/models/litellm_llm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,201 @@
"""LiteLLM-backed LLM model for nexent.

Provides access to 100+ LLM providers (OpenAI, Anthropic, Google Gemini,
Azure, Bedrock, Ollama, etc.) through ``litellm.completion()`` as an SDK
dependency. Follows the same interface as ``OpenAIModel``.
"""

import logging
import threading
import time
from typing import Any, Dict, List, Optional

from smolagents import Tool
from smolagents.models import ChatMessage, MessageRole

from ..utils.observer import MessageObserver, ProcessType

logger = logging.getLogger("litellm_llm")


class LiteLLMModel:
    """LLM model backed by LiteLLM SDK.

    Uses ``litellm.completion()`` directly, supporting any model identifier
    that LiteLLM recognizes (e.g. ``anthropic/claude-sonnet-4-20250514``,
    ``gemini/gemini-2.5-flash``, ``azure/gpt-4o``).

    See https://docs.litellm.ai/docs/providers for the full provider list.

    Notes:
        * ``top_p`` is stored on the instance but is not forwarded to
          ``litellm.completion()`` by this class.
        * Tool definitions are forwarded when a tool exposes
          ``to_openai_tool()``, but streamed tool-call deltas are not
          accumulated: the returned message carries text content only.
    """

    def __init__(
        self,
        model_id: str,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        temperature: float = 0.2,
        top_p: float = 0.95,
        observer: MessageObserver = MessageObserver,
        display_name: Optional[str] = None,
        **kwargs: Any,
    ):
        """Store the connection and sampling settings.

        Args:
            model_id: LiteLLM model identifier passed as ``model``.
            api_key: Optional key forwarded as ``api_key`` when truthy.
            api_base: Optional endpoint forwarded as ``api_base`` when truthy.
            temperature: Sampling temperature sent with every completion.
            top_p: Stored only; not sent with requests.
            observer: Sink for streamed tokens. The default is the
                ``MessageObserver`` class object itself, not an instance;
                callers presumably pass an instance (TODO confirm).
            display_name: Optional human-readable name, stored only.
            **kwargs: Accepted for signature compatibility and ignored.
        """
        self.model_id = model_id
        self.api_key = api_key
        self.api_base = api_base
        self.temperature = temperature
        self.top_p = top_p
        self.observer = observer
        self.display_name = display_name
        # Set from outside to abort an in-flight streaming call.
        self.stop_event = threading.Event()
        self.last_input_token_count = 0
        self.last_output_token_count = 0

    def __call__(
        self,
        messages: List[Dict[str, Any]],
        stop_sequences: Optional[List[str]] = None,
        response_format: Optional[Dict[str, str]] = None,
        tools_to_call_from: Optional[List[Tool]] = None,
        **kwargs: Any,
    ) -> ChatMessage:
        """Run one streaming completion and return the assembled message.

        Args:
            messages: Conversation as ``ChatMessage`` objects and/or dicts.
            stop_sequences: Optional stop strings, sent as ``stop``.
            response_format: Optional response format, forwarded verbatim.
            tools_to_call_from: Optional tools; those exposing
                ``to_openai_tool()`` are sent as ``tools``.
            **kwargs: Accepted for signature compatibility and ignored.

        Returns:
            An assistant ``ChatMessage`` holding the concatenated streamed
            text, with ``token_usage`` set when the stream reported usage
            and ``raw`` set to the stream object.

        Raises:
            ImportError: ``litellm`` is not installed.
            TypeError: A message is neither ``ChatMessage`` nor dict.
            ValueError: The provider reported ``context_length_exceeded``.
            RuntimeError: ``stop_event`` was set while streaming.
        """
        try:
            import litellm
        except ImportError as e:
            raise ImportError(
                "litellm is required for LiteLLMModel. "
                "Install it with: pip install 'litellm>=1.80,<1.87'"
            ) from e

        completion_kwargs = self._build_completion_kwargs(
            self._normalize_messages(messages),
            stop_sequences,
            response_format,
            tools_to_call_from,
        )

        try:
            # The request itself sits inside the try so that a context-length
            # rejection raised before the first chunk is translated as well.
            current_request = litellm.completion(**completion_kwargs)

            self.observer.current_mode = ProcessType.MODEL_OUTPUT_THINKING

            chunk_list, model_output, role = self._consume_stream(current_request)

            input_tokens, output_tokens = self._extract_usage(chunk_list)
            self.last_input_token_count = input_tokens
            self.last_output_token_count = output_tokens

            message = self._build_message(
                role, model_output, input_tokens, output_tokens
            )
            message.raw = current_request
            message.role = MessageRole.ASSISTANT
            return message

        except Exception as e:
            if "context_length_exceeded" in str(e):
                # Chain the cause so the provider's original error stays visible.
                raise ValueError(f"Token limit exceeded: {str(e)}") from e
            raise

    @staticmethod
    def _normalize_messages(messages: Any) -> List[Dict[str, Any]]:
        """Convert ChatMessage objects to role/content dicts; pass dicts through."""
        normalized: List[Dict[str, Any]] = []
        for msg in messages or []:
            if isinstance(msg, ChatMessage):
                normalized.append({
                    "role": msg.role.value if hasattr(msg.role, "value") else str(msg.role),
                    "content": msg.content,
                })
            elif isinstance(msg, dict):
                normalized.append(msg)
            else:
                raise TypeError("Messages must be ChatMessage or dict objects.")
        return normalized

    def _build_completion_kwargs(
        self,
        normalized: List[Dict[str, Any]],
        stop_sequences: Optional[List[str]],
        response_format: Optional[Dict[str, str]],
        tools_to_call_from: Optional[List[Any]],
    ) -> Dict[str, Any]:
        """Assemble the keyword arguments for ``litellm.completion()``."""
        completion_kwargs: Dict[str, Any] = {
            "model": self.model_id,
            "messages": normalized,
            "temperature": self.temperature,
            "stream": True,
            "drop_params": True,
            "stream_options": {"include_usage": True},
        }

        # Optional settings are only sent when they carry a truthy value.
        optional = {
            "api_key": self.api_key,
            "api_base": self.api_base,
            "stop": stop_sequences,
            "response_format": response_format,
        }
        completion_kwargs.update(
            {key: value for key, value in optional.items() if value}
        )

        if tools_to_call_from:
            # Tools without to_openai_tool() are silently left out.
            tool_definitions = [
                tool.to_openai_tool()
                for tool in tools_to_call_from
                if hasattr(tool, "to_openai_tool")
            ]
            if tool_definitions:
                completion_kwargs["tools"] = tool_definitions

        return completion_kwargs

    def _consume_stream(self, stream: Any) -> tuple:
        """Drain the stream into the observer; return (chunks, text, role)."""
        chunk_list = []
        token_join = []
        role = None

        for chunk in stream:
            # Chunks without choices (e.g. the trailing usage chunk) are kept
            # for usage extraction but carry no text.
            if not hasattr(chunk, "choices") or not chunk.choices:
                chunk_list.append(chunk)
                continue

            delta = chunk.choices[0].delta
            new_token = getattr(delta, "content", None)
            reasoning_content = getattr(delta, "reasoning_content", None)

            if reasoning_content is not None:
                self.observer.add_model_reasoning_content(reasoning_content)

            if new_token is not None:
                self.observer.add_model_new_token(new_token)
                token_join.append(new_token)
                role = getattr(delta, "role", None) or role

            chunk_list.append(chunk)
            if self.stop_event.is_set():
                raise RuntimeError("Model is interrupted by stop event")

        self.observer.flush_remaining_tokens()
        return chunk_list, "".join(token_join), role

    @staticmethod
    def _extract_usage(chunk_list: List[Any]) -> tuple:
        """Read (prompt_tokens, completion_tokens) from the last chunk, else zeros."""
        usage = getattr(chunk_list[-1], "usage", None) if chunk_list else None
        if usage is None:
            return 0, 0
        input_tokens = getattr(usage, "prompt_tokens", 0) or 0
        output_tokens = getattr(usage, "completion_tokens", 0) or 0
        return input_tokens, output_tokens

    @staticmethod
    def _build_message(
        role: Optional[str],
        model_output: str,
        input_tokens: int,
        output_tokens: int,
    ) -> ChatMessage:
        """Wrap the streamed text in a ChatMessage, attaching usage when reported."""
        from openai.types.chat.chat_completion_message import ChatCompletionMessage

        message = ChatMessage.from_dict(
            ChatCompletionMessage(
                role=role if role else "assistant",
                content=model_output,
            ).model_dump(include={"role", "content", "tool_calls"})
        )

        from smolagents.monitoring import TokenUsage

        if input_tokens > 0 or output_tokens > 0:
            message.token_usage = TokenUsage(
                input_tokens=input_tokens,
                output_tokens=output_tokens,
            )
        return message

    async def check_connectivity(self) -> bool:
        """Test if the LLM provider connection works.

        Sends a minimal non-streaming completion ("Hello", ``max_tokens=5``).

        Returns:
            ``True`` when the call completes, ``False`` on any exception
            (including ``litellm`` not being installed); the failure is
            logged with its traceback.
        """
        try:
            import litellm

            kwargs: Dict[str, Any] = {
                "model": self.model_id,
                "messages": [{"role": "user", "content": "Hello"}],
                "max_tokens": 5,
                "drop_params": True,
            }
            if self.api_key:
                kwargs["api_key"] = self.api_key
            if self.api_base:
                kwargs["api_base"] = self.api_base

            await litellm.acompletion(**kwargs)
            return True
        except Exception as e:
            # logger.exception keeps the original message and adds the traceback.
            logger.exception("LiteLLM connectivity check failed: %s", e)
            return False
Loading