20 changes: 20 additions & 0 deletions README.md
@@ -311,6 +311,26 @@ A better improvement would be using the embedding model as a service where users
- Enable better caching and optimization
- Improve scalability

### Shared embedding and vector services

The docs-agent now centralizes its embedding model and Milvus client:

- `server/vector_services/embedding.py` exposes a shared `SentenceTransformer`
instance with a small in-process cache for repeated texts.
- `server/vector_services/milvus_client.py` manages a shared Milvus connection
and collection handle, and provides a helper for vector search.

This design:

- Reduces cold-start latency and repeated model loads.
- Encourages reuse of a single “vector layer” across agent tools and KFP
components.
- Makes it easy to evolve toward running the embedding model as a separate
  service (e.g., via KServe or MCP) by swapping only the implementation of
  `vector_services/embedding.py` (see the sketch below).
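
For example, an agent tool or KFP component can consume this vector layer
through the two helpers alone. A minimal sketch (the `doc_id`/`path` output
fields match those used by the `/search_docs` endpoint in `server/app.py`):

```python
from server.vector_services.embedding import embed_text
from server.vector_services.milvus_client import search_vectors


def find_related_docs(question: str, top_k: int = 3):
    # Both calls reuse the shared, process-wide model and Milvus handle.
    results = search_vectors(
        query_vectors=[embed_text(question)],
        top_k=top_k,
        output_fields=["doc_id", "path"],
    )
    return [
        (hit.entity.get("path"), float(hit.distance)) for hit in results[0]
    ]
```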

### API Server

Two API implementations are provided for different use cases:
37 changes: 36 additions & 1 deletion server/app.py
@@ -7,9 +7,44 @@
from websockets.exceptions import ConnectionClosedError
import logging
from typing import Dict, Any, List

from fastapi import FastAPI, HTTPException
from pymilvus import connections, Collection
from sentence_transformers import SentenceTransformer

from server.vector_services.embedding import embed_text
from server.vector_services.milvus_client import search_vectors

app = FastAPI()


@app.get("/search_docs")
def search_docs(q: str, top_k: int = 5):
    # Sync endpoint: FastAPI runs it in a worker thread, so the blocking
    # embedding and Milvus calls below do not stall the event loop.
    if not q:
        raise HTTPException(status_code=400, detail="Query text 'q' is required")

# 1) embed query via shared SentenceTransformer
query_embedding = embed_text(q)

# 2) search Milvus via shared client
milvus_results = search_vectors(
query_vectors=[query_embedding],
top_k=top_k,
output_fields=["doc_id", "title", "path"],
)

hits = []
for hit in milvus_results[0]:
hits.append(
{
"doc_id": hit.entity.get("doc_id"),
"title": hit.entity.get("title"),
"path": hit.entity.get("path"),
"score": float(hit.distance),
}
)

    return {"query": q, "results": hits}


# Config
KSERVE_URL = os.getenv("KSERVE_URL", "http://llama.docs-agent.svc.cluster.local/openai/v1/chat/completions")
MODEL = os.getenv("MODEL", "llama3.1-8B")
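
The new endpoint can be exercised with any HTTP client. For example (a
sketch, assuming the server is reachable at `localhost:8000` and the
`requests` package is installed):

```python
import requests

resp = requests.get(
    "http://localhost:8000/search_docs",
    params={"q": "How do I deploy a pipeline?", "top_k": 3},
)
resp.raise_for_status()
for hit in resp.json()["results"]:
    print(hit["score"], hit["path"])
```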
63 changes: 63 additions & 0 deletions server/vector_services/embedding.py
@@ -0,0 +1,63 @@
# server/vector_services/embedding.py

from functools import lru_cache
import os
from typing import List

from sentence_transformers import SentenceTransformer
import numpy as np


EMBEDDING_MODEL_ENV = "EMBEDDING_MODEL"
DEFAULT_EMBEDDING_MODEL = "sentence-transformers/all-mpnet-base-v2"


@lru_cache(maxsize=1)
def get_sentence_transformer() -> SentenceTransformer:
"""
Return a shared SentenceTransformer instance.

Uses an LRU cache to ensure a single model instance per process.
"""
model_name = os.getenv(EMBEDDING_MODEL_ENV, DEFAULT_EMBEDDING_MODEL)
model = SentenceTransformer(model_name)
return model


@lru_cache(maxsize=8192)
def _embed_text_cached_single(text: str) -> np.ndarray:
"""
Embed a single text with caching.

Internal helper; use embed_text or embed_texts from callers.
"""
model = get_sentence_transformer()
emb = model.encode([text], convert_to_numpy=True)[0]
return emb


def embed_text(text: str) -> List[float]:
"""
Public helper for embedding a single text as a Python list[float].
"""
emb = _embed_text_cached_single(text)
return emb.tolist()


def embed_texts(texts: List[str]) -> List[List[float]]:
"""
Embed a list of texts.

Uses the shared model instance; for repeated single texts, the
internal cache will avoid recomputation.
"""
if not texts:
return []

model = get_sentence_transformer()
embs = model.encode(texts, convert_to_numpy=True)
if isinstance(embs, np.ndarray):
return [row.tolist() for row in embs]

# Fallback if encode returns a list-like object
return [np.asarray(row).tolist() for row in embs]
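
The cache behavior of the embedding helpers can be seen in a short sketch
(the first call downloads/loads the model; repeated single texts hit the
`lru_cache`):

```python
from server.vector_services.embedding import embed_text, embed_texts

v1 = embed_text("hello world")  # loads the model and encodes once
v2 = embed_text("hello world")  # served from the in-process cache
assert v1 == v2

batch = embed_texts(["first doc", "second doc"])
print(len(v1), len(batch))  # embedding dimension, number of rows
```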
84 changes: 84 additions & 0 deletions server/vector_services/milvus_client.py
@@ -0,0 +1,84 @@
# server/vector_services/milvus_client.py

from functools import lru_cache
import os
from typing import List, Dict, Any

from pymilvus import (
connections,
Collection,
utility,
)


MILVUS_HOST_ENV = "MILVUS_HOST"
MILVUS_PORT_ENV = "MILVUS_PORT"
MILVUS_COLLECTION_ENV = "MILVUS_COLLECTION"

DEFAULT_MILVUS_HOST = "milvus"
DEFAULT_MILVUS_PORT = "19530"
DEFAULT_MILVUS_COLLECTION = "docs_rag"


@lru_cache(maxsize=1)
def _connect_default() -> None:
"""
Establish a shared connection to Milvus.

This is called implicitly by get_milvus_collection.
"""
host = os.getenv(MILVUS_HOST_ENV, DEFAULT_MILVUS_HOST)
port = os.getenv(MILVUS_PORT_ENV, DEFAULT_MILVUS_PORT)

connections.connect(
alias="default",
host=host,
port=port,
)


@lru_cache(maxsize=1)
def get_milvus_collection() -> Collection:
"""
Return a shared Collection handle.

Lazily connects to Milvus and opens the configured collection.
"""
_connect_default()
collection_name = os.getenv(MILVUS_COLLECTION_ENV, DEFAULT_MILVUS_COLLECTION)

if not utility.has_collection(collection_name):
raise RuntimeError(f"Milvus collection '{collection_name}' does not exist")

    collection = Collection(collection_name)
    # Load the collection into memory; Milvus rejects searches against an
    # unloaded collection, and repeated load() calls are harmless.
    collection.load()
    return collection


def search_vectors(
query_vectors: List[List[float]],
top_k: int = 5,
search_params: Dict[str, Any] | None = None,
output_fields: List[str] | None = None,
):
"""
Convenience wrapper around Milvus collection.search.
"""
collection = get_milvus_collection()

if search_params is None:
search_params = {
"metric_type": "IP",
"params": {"nprobe": 10},
}

if output_fields is None:
output_fields = []

results = collection.search(
data=query_vectors,
anns_field="embedding",
param=search_params,
limit=top_k,
output_fields=output_fields,
)

return results
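
Callers can override the defaults per search, for example trading latency
for recall with a larger `nprobe` (a sketch; the output fields must exist
in the configured collection's schema):

```python
from server.vector_services.embedding import embed_text
from server.vector_services.milvus_client import search_vectors

results = search_vectors(
    query_vectors=[embed_text("What is Milvus?")],
    top_k=10,
    search_params={"metric_type": "IP", "params": {"nprobe": 32}},
    output_fields=["doc_id", "title"],
)
for hit in results[0]:
    print(float(hit.distance), hit.entity.get("title"))
```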