Skip to content
Merged
6 changes: 4 additions & 2 deletions astrbot/core/config/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -1463,6 +1463,7 @@ class ChatProviderTemplate(TypedDict):
"type": "openai_embedding",
"provider": "openai",
"provider_type": "embedding",
"hint": "provider_group.provider.openai_embedding.hint",
"enable": True,
"embedding_api_key": "",
"embedding_api_base": "",
Expand All @@ -1476,6 +1477,7 @@ class ChatProviderTemplate(TypedDict):
"type": "gemini_embedding",
"provider": "google",
"provider_type": "embedding",
"hint": "provider_group.provider.gemini_embedding.hint",
"enable": True,
"embedding_api_key": "",
"embedding_api_base": "",
Expand Down Expand Up @@ -2192,9 +2194,9 @@ class ChatProviderTemplate(TypedDict):
"type": "string",
},
"proxy": {
"description": "代理地址",
"description": "provider_group.provider.proxy.description",
"type": "string",
"hint": "HTTP/HTTPS 代理地址,格式如 http://127.0.0.1:7890。仅对该提供商的 API 请求生效,不影响 Docker 内网通信。",
"hint": "provider_group.provider.proxy.hint",
},
"model": {
"description": "模型 ID",
Expand Down
4 changes: 4 additions & 0 deletions astrbot/core/provider/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,10 @@ def get_dim(self) -> int:
"""获取向量的维度"""
...

async def detect_dim(self) -> int:
"""探测模型原生向量维度(默认实现)"""
return len(await self.get_embedding("astrbot"))

async def test(self) -> None:
await self.get_embedding("astrbot")

Expand Down
13 changes: 13 additions & 0 deletions astrbot/core/provider/sources/gemini_embedding_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,19 @@ async def get_embeddings(self, text: list[str]) -> list[list[float]]:
except APIError as e:
raise Exception(f"Gemini Embedding API批量请求失败: {e.message}")

async def detect_dim(self) -> int:
"""探测模型原生向量维度(不传 output_dimensionality)"""
try:
result = await self.client.models.embed_content(
model=self.model,
contents="echo",
)
assert result.embeddings is not None
assert result.embeddings[0].values is not None
return len(result.embeddings[0].values)
except APIError as e:
raise Exception(f"Gemini Embedding 维度探测失败: {e.message}")

def get_dim(self) -> int:
"""获取向量的维度"""
return int(self.provider_config.get("embedding_dimensions", 768))
Expand Down
70 changes: 66 additions & 4 deletions astrbot/core/provider/sources/openai_embedding_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,16 @@ def __init__(self, provider_config: dict, provider_settings: dict) -> None:
if proxy:
logger.info(f"[OpenAI Embedding] 使用代理: {proxy}")
http_client = httpx.AsyncClient(proxy=proxy)
api_base = provider_config.get("embedding_api_base", "").strip()
if not api_base:
api_base = "https://api.openai.com/v1"
else:
api_base = api_base.removesuffix("/")
if not api_base.endswith("/v1"):
api_base = f"{api_base}/v1"
Comment on lines +26 to +32
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

security-high high

The embedding_api_base is taken directly from the user-provided configuration and used as the base URL for the OpenAI client without any validation. This allows an attacker to perform Server-Side Request Forgery (SSRF) by providing internal IP addresses or malicious domains, which the server will then attempt to connect to during the dimension detection process.

Recommendation: Validate the api_base URL to ensure it does not point to internal or reserved IP addresses.

self.client = AsyncOpenAI(
api_key=provider_config.get("embedding_api_key"),
base_url=provider_config.get(
"embedding_api_base",
"https://api.openai.com/v1",
),
base_url=api_base,
timeout=int(provider_config.get("timeout", 20)),
http_client=http_client,
)
Expand All @@ -52,6 +56,64 @@ async def get_embeddings(self, text: list[str]) -> list[list[float]]:
)
return [item.embedding for item in embeddings.data]

async def detect_dim(self) -> int:
"""探测模型可用的最大向量维度"""

async def _request_dim(dimensions: int | None) -> int:
kwargs = {
"input": "echo",
"model": self.model,
}
if dimensions is not None:
kwargs["dimensions"] = dimensions
embedding = await self.client.embeddings.create(**kwargs)
return len(embedding.data[0].embedding)

# 1) 默认调用,获取当前默认维度
base_dim = await _request_dim(None)

# 2) 先判断 dimensions 参数是否可调
probe_dim = base_dim + 1
try:
probe_result = await _request_dim(probe_dim)
if probe_result != probe_dim:
return base_dim
except Exception:
return base_dim

# 3) 可调时探测上界:指数扩张 + 二分
max_cap = 32768
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

这里的 max_cap = 32768 是一个魔术数字。为了提高代码的可读性和可维护性,建议将其定义为一个具名常量,例如 MAX_DIMENSION_CAP

Suggested change
max_cap = 32768
MAX_DIMENSION_CAP = 32768

low = probe_dim
high = max(base_dim * 2, probe_dim + 1)
if high > max_cap:
high = max_cap

while high < max_cap:
try:
result_dim = await _request_dim(high)
if result_dim != high:
break
low = high
high = min(high * 2, max_cap)
except Exception:
break

left = low + 1
right = high - 1
while left <= right:
mid = (left + right) // 2
try:
result_dim = await _request_dim(mid)
if result_dim == mid:
low = mid
left = mid + 1
else:
right = mid - 1
except Exception:
right = mid - 1

return low

def get_dim(self) -> int:
"""获取向量的维度"""
return int(self.provider_config.get("embedding_dimensions", 1024))
Expand Down
15 changes: 12 additions & 3 deletions astrbot/dashboard/routes/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -754,6 +754,16 @@ async def get_embedding_dim(self):
if not provider_type:
return Response().error("provider_config 缺少 type 字段").__dict__

# 首次添加某类提供商时,provider_cls_map 可能尚未注册该适配器
if provider_type not in provider_cls_map:
try:
self.core_lifecycle.provider_manager.dynamic_import_provider(
provider_type,
)
Comment on lines +760 to +762
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

security-medium medium

The provider_type is taken from the user-provided configuration and passed directly to dynamic_import_provider without validation. This could allow an attacker to trigger the loading of arbitrary modules if the provider_type is not properly sanitized or checked against an allow-list.

Recommendation: Validate provider_type against a list of known, safe provider types before attempting to import it.

except ImportError as e:
logger.error(traceback.format_exc())
return Response().error(f"动态导入提供商适配器失败: {e!s}").__dict__

# 获取对应的 provider 类
if provider_type not in provider_cls_map:
return (
Expand All @@ -779,9 +789,8 @@ async def get_embedding_dim(self):
if inspect.iscoroutinefunction(init_fn):
await init_fn()

# 获取嵌入向量维度
vec = await inst.get_embedding("echo")
dim = len(vec)
# 探测嵌入向量维度(优先使用 provider 的原生探测逻辑)
dim = await inst.detect_dim()

logger.info(
f"检测到 {provider_config.get('id', 'unknown')} 的嵌入向量维度为 {dim}",
Expand Down
42 changes: 38 additions & 4 deletions dashboard/src/components/shared/AstrBotConfig.vue
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,40 @@ const filteredIterable = computed(() => {
return rest
})

const providerHint = computed(() => {
const hint = props.iterable?.hint
if (typeof hint !== 'string' || !hint) return ''

if (
hint === 'provider_group.provider.openai_embedding.hint'
|| hint === 'provider_group.provider.gemini_embedding.hint'
) {
return ''
}

return hint
})

const getItemHint = (itemKey, itemMeta) => {
if (itemMeta?.hint) return itemMeta.hint

if (itemKey !== 'embedding_api_base') return ''

const providerType = props.iterable?.type
if (providerType === 'openai_embedding') {
return getRaw('provider_group.provider.openai_embedding.hint')
? 'provider_group.provider.openai_embedding.hint'
: ''
}
if (providerType === 'gemini_embedding') {
return getRaw('provider_group.provider.gemini_embedding.hint')
? 'provider_group.provider.gemini_embedding.hint'
: ''
}

return ''
}

const dialog = ref(false)
const currentEditingKey = ref('')
const currentEditingLanguage = ref('json')
Expand Down Expand Up @@ -153,14 +187,14 @@ function hasVisibleItemsAfter(items, currentIndex) {
<div v-if="metadata[metadataKey]?.type === 'object' || metadata[metadataKey]?.config_template" class="object-config">
<!-- Provider-level hint -->
<v-alert
v-if="iterable.hint && !isEditing"
v-if="providerHint"
type="info"
variant="tonal"
class="mb-4"
border="start"
density="compact"
>
{{ iterable.hint }}
{{ translateIfKey(providerHint) }}
</v-alert>

<div v-for="(val, key, index) in filteredIterable" :key="key" class="config-item">
Expand Down Expand Up @@ -218,9 +252,9 @@ function hasVisibleItemsAfter(items, currentIndex) {
</v-list-item-title>

<v-list-item-subtitle class="property-hint">
<span v-if="metadata[metadataKey].items[key]?.obvious_hint && metadata[metadataKey].items[key]?.hint"
<span v-if="metadata[metadataKey].items[key]?.obvious_hint && getItemHint(key, metadata[metadataKey].items[key])"
class="important-hint">‼️</span>
{{ translateIfKey(metadata[metadataKey].items[key]?.hint) }}
{{ translateIfKey(getItemHint(key, metadata[metadataKey].items[key])) }}
</v-list-item-subtitle>
</v-list-item>
</v-col>
Expand Down
10 changes: 10 additions & 0 deletions dashboard/src/i18n/locales/en-US/features/config-metadata.json
Original file line number Diff line number Diff line change
Expand Up @@ -1086,6 +1086,12 @@
"embedding_api_base": {
"description": "API Base URL"
},
"openai_embedding": {
"hint": "OpenAI Embedding automatically appends /v1 at request time."
},
"gemini_embedding": {
"hint": "Gemini Embedding does not require manually adding /v1beta."
},
"volcengine_cluster": {
"description": "Volcengine cluster",
"hint": "For voice cloning models, choose volcano_icl or volcano_icl_concurr; default is volcano_tts."
Expand Down Expand Up @@ -1313,6 +1319,10 @@
"api_base": {
"description": "API Base URL"
},
"proxy": {
"description": "Proxy address",
"hint": "HTTP/HTTPS proxy URL, e.g. http://127.0.0.1:7890. Applies only to this provider's API requests and does not affect Docker internal networking."
},
"model": {
"description": "Model ID",
"hint": "Model name, e.g., gpt-4o-mini, deepseek-chat."
Expand Down
10 changes: 10 additions & 0 deletions dashboard/src/i18n/locales/zh-CN/features/config-metadata.json
Original file line number Diff line number Diff line change
Expand Up @@ -1089,6 +1089,12 @@
"embedding_api_base": {
"description": "API Base URL"
},
"openai_embedding": {
"hint": "OpenAI Embedding 会在请求时自动补上 /v1。"
},
"gemini_embedding": {
"hint": "Gemini Embedding 无需手动添加 /v1beta。"
},
"volcengine_cluster": {
"description": "火山引擎集群",
"hint": "若使用语音复刻大模型,可选volcano_icl或volcano_icl_concurr,默认使用volcano_tts"
Expand Down Expand Up @@ -1316,6 +1322,10 @@
"api_base": {
"description": "API Base URL"
},
"proxy": {
"description": "代理地址",
"hint": "HTTP/HTTPS 代理地址,格式如 http://127.0.0.1:7890。仅对该提供商的 API 请求生效,不影响 Docker 内网通信。"
},
"model": {
"description": "模型 ID",
"hint": "模型名称,如 gpt-4o-mini, deepseek-chat。"
Expand Down
6 changes: 4 additions & 2 deletions dashboard/src/views/Settings.vue
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,7 @@ const loadApiKeys = async () => {
const tryExecCommandCopy = (text) => {
let textArea = null;
try {
if (typeof document === 'undefined') return false;
if (typeof document === 'undefined' || !document.body) return false;
textArea = document.createElement('textarea');
textArea.value = text;
textArea.setAttribute('readonly', '');
Expand All @@ -353,7 +353,9 @@ const tryExecCommandCopy = (text) => {
return false;
} finally {
try {
textArea?.remove?.();
if (textArea?.parentNode) {
textArea.parentNode.removeChild(textArea);
}
} catch (_) {
// ignore cleanup errors
}
Expand Down