Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions server-https/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,11 @@
MILVUS_VECTOR_FIELD = os.getenv("MILVUS_VECTOR_FIELD", "vector")
EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL", "sentence-transformers/all-mpnet-base-v2")

# Initialize embedding model once at module level to avoid reloading on every request
print(f"Loading embedding model: {EMBEDDING_MODEL}...")
_embedding_model = SentenceTransformer(EMBEDDING_MODEL)
print("Embedding model loaded successfully.")

# System prompt (same as WebSocket version)
SYSTEM_PROMPT = """
You are the Kubeflow Docs Assistant.
Expand Down Expand Up @@ -119,9 +124,8 @@ def milvus_search(query: str, top_k: int = 5) -> Dict[str, Any]:
collection = Collection(MILVUS_COLLECTION)
collection.load()

# Encoder (same model as pipeline)
encoder = SentenceTransformer(EMBEDDING_MODEL)
query_vec = encoder.encode(query).tolist()
# Encode query using cached model
query_vec = _embedding_model.encode(query).tolist()

search_params = {"metric_type": "COSINE", "params": {"nprobe": 32}}
results = collection.search(
Expand Down
10 changes: 7 additions & 3 deletions server/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,11 @@
MILVUS_VECTOR_FIELD = os.getenv("MILVUS_VECTOR_FIELD", "vector")
EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL", "sentence-transformers/all-mpnet-base-v2")

# Initialize embedding model once at module level to avoid reloading on every request
print(f"Loading embedding model: {EMBEDDING_MODEL}...")
_embedding_model = SentenceTransformer(EMBEDDING_MODEL)
print("Embedding model loaded successfully.")

# System prompt
SYSTEM_PROMPT = """
You are the Kubeflow Docs Assistant.
Expand Down Expand Up @@ -73,9 +78,8 @@ def milvus_search(query: str, top_k: int = 5) -> Dict[str, Any]:
collection = Collection(MILVUS_COLLECTION)
collection.load()

# Encoder (same model as pipeline)
encoder = SentenceTransformer(EMBEDDING_MODEL)
query_vec = encoder.encode(query).tolist()
# Encode query using cached model
query_vec = _embedding_model.encode(query).tolist()

search_params = {"metric_type": "COSINE", "params": {"nprobe": 32}}
results = collection.search(
Expand Down