-
-
Notifications
You must be signed in to change notification settings - Fork 2.3k
fix: cannot automatically get embedding dim when create embedding provider #5442
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 6 commits
b13bf36
64ae120
5b31476
fd3c337
0e5946a
4763cb5
e28a75f
6620f2b
227eb1a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -23,12 +23,16 @@ def __init__(self, provider_config: dict, provider_settings: dict) -> None: | |
| if proxy: | ||
| logger.info(f"[OpenAI Embedding] 使用代理: {proxy}") | ||
| http_client = httpx.AsyncClient(proxy=proxy) | ||
| api_base = provider_config.get("embedding_api_base", "").strip() | ||
| if not api_base: | ||
| api_base = "https://api.openai.com/v1" | ||
| else: | ||
| api_base = api_base.removesuffix("/") | ||
| if not api_base.endswith("/v1"): | ||
| api_base = f"{api_base}/v1" | ||
| self.client = AsyncOpenAI( | ||
| api_key=provider_config.get("embedding_api_key"), | ||
| base_url=provider_config.get( | ||
| "embedding_api_base", | ||
| "https://api.openai.com/v1", | ||
| ), | ||
| base_url=api_base, | ||
| timeout=int(provider_config.get("timeout", 20)), | ||
| http_client=http_client, | ||
| ) | ||
|
|
@@ -52,6 +56,64 @@ async def get_embeddings(self, text: list[str]) -> list[list[float]]: | |
| ) | ||
| return [item.embedding for item in embeddings.data] | ||
|
|
||
| async def detect_dim(self) -> int: | ||
| """探测模型可用的最大向量维度""" | ||
|
|
||
| async def _request_dim(dimensions: int | None) -> int: | ||
| kwargs = { | ||
| "input": "echo", | ||
| "model": self.model, | ||
| } | ||
| if dimensions is not None: | ||
| kwargs["dimensions"] = dimensions | ||
| embedding = await self.client.embeddings.create(**kwargs) | ||
| return len(embedding.data[0].embedding) | ||
|
|
||
| # 1) 默认调用,获取当前默认维度 | ||
| base_dim = await _request_dim(None) | ||
|
|
||
| # 2) 先判断 dimensions 参数是否可调 | ||
| probe_dim = base_dim + 1 | ||
| try: | ||
| probe_result = await _request_dim(probe_dim) | ||
| if probe_result != probe_dim: | ||
| return base_dim | ||
| except Exception: | ||
| return base_dim | ||
|
|
||
| # 3) 可调时探测上界:指数扩张 + 二分 | ||
| max_cap = 32768 | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
| low = probe_dim | ||
| high = max(base_dim * 2, probe_dim + 1) | ||
| if high > max_cap: | ||
| high = max_cap | ||
|
|
||
| while high < max_cap: | ||
| try: | ||
| result_dim = await _request_dim(high) | ||
| if result_dim != high: | ||
| break | ||
| low = high | ||
| high = min(high * 2, max_cap) | ||
| except Exception: | ||
| break | ||
|
|
||
| left = low + 1 | ||
| right = high - 1 | ||
| while left <= right: | ||
| mid = (left + right) // 2 | ||
| try: | ||
| result_dim = await _request_dim(mid) | ||
| if result_dim == mid: | ||
| low = mid | ||
| left = mid + 1 | ||
| else: | ||
| right = mid - 1 | ||
| except Exception: | ||
| right = mid - 1 | ||
|
|
||
| return low | ||
|
|
||
| def get_dim(self) -> int: | ||
| """获取向量的维度""" | ||
| return int(self.provider_config.get("embedding_dimensions", 1024)) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -754,6 +754,16 @@ async def get_embedding_dim(self): | |
| if not provider_type: | ||
| return Response().error("provider_config 缺少 type 字段").__dict__ | ||
|
|
||
| # 首次添加某类提供商时,provider_cls_map 可能尚未注册该适配器 | ||
| if provider_type not in provider_cls_map: | ||
| try: | ||
| self.core_lifecycle.provider_manager.dynamic_import_provider( | ||
| provider_type, | ||
| ) | ||
|
Comment on lines
+760
to
+762
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The Recommendation: Validate |
||
| except ImportError as e: | ||
| logger.error(traceback.format_exc()) | ||
| return Response().error(f"动态导入提供商适配器失败: {e!s}").__dict__ | ||
|
|
||
| # 获取对应的 provider 类 | ||
| if provider_type not in provider_cls_map: | ||
| return ( | ||
|
|
@@ -779,9 +789,8 @@ async def get_embedding_dim(self): | |
| if inspect.iscoroutinefunction(init_fn): | ||
| await init_fn() | ||
|
|
||
| # 获取嵌入向量维度 | ||
| vec = await inst.get_embedding("echo") | ||
| dim = len(vec) | ||
| # 探测嵌入向量维度(优先使用 provider 的原生探测逻辑) | ||
| dim = await inst.detect_dim() | ||
|
|
||
| logger.info( | ||
| f"检测到 {provider_config.get('id', 'unknown')} 的嵌入向量维度为 {dim}", | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The
embedding_api_baseis taken directly from the user-provided configuration and used as the base URL for the OpenAI client without any validation. This allows an attacker to perform Server-Side Request Forgery (SSRF) by providing internal IP addresses or malicious domains, which the server will then attempt to connect to during the dimension detection process.Recommendation: Validate the
api_baseURL to ensure it does not point to internal or reserved IP addresses.