diff --git a/server-https/app.py b/server-https/app.py
index 694af8a..bca3ba8 100644
--- a/server-https/app.py
+++ b/server-https/app.py
@@ -22,6 +22,11 @@
 MILVUS_VECTOR_FIELD = os.getenv("MILVUS_VECTOR_FIELD", "vector")
 EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL", "sentence-transformers/all-mpnet-base-v2")
 
+# Initialize embedding model once at module level to avoid reloading on every request
+print(f"Loading embedding model: {EMBEDDING_MODEL}...")
+_embedding_model = SentenceTransformer(EMBEDDING_MODEL)
+print("Embedding model loaded successfully.")
+
 # System prompt (same as WebSocket version)
 SYSTEM_PROMPT = """
 You are the Kubeflow Docs Assistant.
@@ -119,9 +124,8 @@ def milvus_search(query: str, top_k: int = 5) -> Dict[str, Any]:
     collection = Collection(MILVUS_COLLECTION)
     collection.load()
 
-    # Encoder (same model as pipeline)
-    encoder = SentenceTransformer(EMBEDDING_MODEL)
-    query_vec = encoder.encode(query).tolist()
+    # Encode query using cached model
+    query_vec = _embedding_model.encode(query).tolist()
 
     search_params = {"metric_type": "COSINE", "params": {"nprobe": 32}}
     results = collection.search(
diff --git a/server/app.py b/server/app.py
index 96b277c..4d3fa56 100644
--- a/server/app.py
+++ b/server/app.py
@@ -22,6 +22,11 @@
 MILVUS_VECTOR_FIELD = os.getenv("MILVUS_VECTOR_FIELD", "vector")
 EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL", "sentence-transformers/all-mpnet-base-v2")
 
+# Initialize embedding model once at module level to avoid reloading on every request
+print(f"Loading embedding model: {EMBEDDING_MODEL}...")
+_embedding_model = SentenceTransformer(EMBEDDING_MODEL)
+print("Embedding model loaded successfully.")
+
 # System prompt
 SYSTEM_PROMPT = """
 You are the Kubeflow Docs Assistant.
@@ -73,9 +78,8 @@ def milvus_search(query: str, top_k: int = 5) -> Dict[str, Any]:
     collection = Collection(MILVUS_COLLECTION)
     collection.load()
 
-    # Encoder (same model as pipeline)
-    encoder = SentenceTransformer(EMBEDDING_MODEL)
-    query_vec = encoder.encode(query).tolist()
+    # Encode query using cached model
+    query_vec = _embedding_model.encode(query).tolist()
 
     search_params = {"metric_type": "COSINE", "params": {"nprobe": 32}}
     results = collection.search(