diff --git a/README.md b/README.md
index b387d19..b557bcf 100644
--- a/README.md
+++ b/README.md
@@ -311,6 +311,26 @@ A better improvement would be using the embedding model as a service where users
 - Enable better caching and optimization
 - Improve scalability
+
+
+### Shared embedding and vector services
+
+The docs-agent now centralizes its embedding model and Milvus client:
+
+- `server/vector_services/embedding.py` exposes a shared `SentenceTransformer`
+  instance with a small in-process cache for repeated texts.
+- `server/vector_services/milvus_client.py` manages a shared Milvus connection
+  and collection handle, and provides a helper for vector search.
+
+This design:
+
+- Reduces cold-start latency and repeated model loads.
+- Encourages reuse of a single “vector layer” across agent tools and KFP
+  components.
+- Makes it easy to evolve toward running the embedding model as a separate
+  service (e.g., via KServe or MCP) by swapping only the implementation of
+  `vector_services/embedding.py`.
+
 
 ### API Server
 
 Two API implementations are provided for different use cases:
diff --git a/server/app.py b/server/app.py
index 96b277c..dac9a58 100644
--- a/server/app.py
+++ b/server/app.py
@@ -7,9 +7,44 @@ from websockets.exceptions import ConnectionClosedError
 import logging
 from typing import Dict, Any, List
-from sentence_transformers import SentenceTransformer
+
+from fastapi import FastAPI, HTTPException
 from pymilvus import connections, Collection
+from sentence_transformers import SentenceTransformer
+
+from server.vector_services.embedding import embed_text
+from server.vector_services.milvus_client import search_vectors
+
+app = FastAPI()
+
+
+@app.get("/search_docs")
+async def search_docs(q: str, top_k: int = 5):
+    if not q:
+        raise HTTPException(status_code=400, detail="Query text 'q' is required")
+
+    # 1) embed query via shared SentenceTransformer
+    query_embedding = embed_text(q)
+
+    # 2) search Milvus via shared client
+    milvus_results = search_vectors(
+        query_vectors=[query_embedding],
+        top_k=top_k,
+        output_fields=["doc_id", "title", "path"],
+    )
+
+    hits = []
+    for hit in milvus_results[0]:
+        hits.append(
+            {
+                "doc_id": hit.entity.get("doc_id"),
+                "title": hit.entity.get("title"),
+                "path": hit.entity.get("path"),
+                "score": float(hit.distance),
+            }
+        )
+    return {"query": q, "results": hits}
 
 
 # Config
 KSERVE_URL = os.getenv("KSERVE_URL", "http://llama.docs-agent.svc.cluster.local/openai/v1/chat/completions")
 MODEL = os.getenv("MODEL", "llama3.1-8B")
diff --git a/server/vector_services/embedding.py b/server/vector_services/embedding.py
new file mode 100644
index 0000000..79d4d34
--- /dev/null
+++ b/server/vector_services/embedding.py
@@ -0,0 +1,63 @@
+# server/vector_services/embedding.py
+
+from functools import lru_cache
+import os
+from typing import List
+
+from sentence_transformers import SentenceTransformer
+import numpy as np
+
+
+EMBEDDING_MODEL_ENV = "EMBEDDING_MODEL"
+DEFAULT_EMBEDDING_MODEL = "sentence-transformers/all-mpnet-base-v2"
+
+
+@lru_cache(maxsize=1)
+def get_sentence_transformer() -> SentenceTransformer:
+    """
+    Return a shared SentenceTransformer instance.
+
+    Uses an LRU cache to ensure a single model instance per process.
+    """
+    model_name = os.getenv(EMBEDDING_MODEL_ENV, DEFAULT_EMBEDDING_MODEL)
+    model = SentenceTransformer(model_name)
+    return model
+
+
+@lru_cache(maxsize=8192)
+def _embed_text_cached_single(text: str) -> np.ndarray:
+    """
+    Embed a single text with caching.
+
+    Internal helper; use embed_text or embed_texts from callers.
+ """ + model = get_sentence_transformer() + emb = model.encode([text], convert_to_numpy=True)[0] + return emb + + +def embed_text(text: str) -> List[float]: + """ + Public helper for embedding a single text as a Python list[float]. + """ + emb = _embed_text_cached_single(text) + return emb.tolist() + + +def embed_texts(texts: List[str]) -> List[List[float]]: + """ + Embed a list of texts. + + Uses the shared model instance; for repeated single texts, the + internal cache will avoid recomputation. + """ + if not texts: + return [] + + model = get_sentence_transformer() + embs = model.encode(texts, convert_to_numpy=True) + if isinstance(embs, np.ndarray): + return [row.tolist() for row in embs] + + # Fallback if encode returns a list-like object + return [np.asarray(row).tolist() for row in embs] \ No newline at end of file diff --git a/server/vector_services/milvus_client.py b/server/vector_services/milvus_client.py new file mode 100644 index 0000000..5f43d8d --- /dev/null +++ b/server/vector_services/milvus_client.py @@ -0,0 +1,84 @@ +# server/vector_services/milvus_client.py + +from functools import lru_cache +import os +from typing import List, Dict, Any + +from pymilvus import ( + connections, + Collection, + utility, +) + + +MILVUS_HOST_ENV = "MILVUS_HOST" +MILVUS_PORT_ENV = "MILVUS_PORT" +MILVUS_COLLECTION_ENV = "MILVUS_COLLECTION" + +DEFAULT_MILVUS_HOST = "milvus" +DEFAULT_MILVUS_PORT = "19530" +DEFAULT_MILVUS_COLLECTION = "docs_rag" + + +@lru_cache(maxsize=1) +def _connect_default() -> None: + """ + Establish a shared connection to Milvus. + + This is called implicitly by get_milvus_collection. + """ + host = os.getenv(MILVUS_HOST_ENV, DEFAULT_MILVUS_HOST) + port = os.getenv(MILVUS_PORT_ENV, DEFAULT_MILVUS_PORT) + + connections.connect( + alias="default", + host=host, + port=port, + ) + + +@lru_cache(maxsize=1) +def get_milvus_collection() -> Collection: + """ + Return a shared Collection handle. + + Lazily connects to Milvus and opens the configured collection. + """ + _connect_default() + collection_name = os.getenv(MILVUS_COLLECTION_ENV, DEFAULT_MILVUS_COLLECTION) + + if not utility.has_collection(collection_name): + raise RuntimeError(f"Milvus collection '{collection_name}' does not exist") + + return Collection(collection_name) + + +def search_vectors( + query_vectors: List[List[float]], + top_k: int = 5, + search_params: Dict[str, Any] | None = None, + output_fields: List[str] | None = None, +): + """ + Convenience wrapper around Milvus collection.search. + """ + collection = get_milvus_collection() + + if search_params is None: + search_params = { + "metric_type": "IP", + "params": {"nprobe": 10}, + } + + if output_fields is None: + output_fields = [] + + results = collection.search( + data=query_vectors, + anns_field="embedding", + param=search_params, + limit=top_k, + output_fields=output_fields, + ) + + return results \ No newline at end of file