diff --git a/README.md b/README.md index b387d19..6713c75 100644 --- a/README.md +++ b/README.md @@ -584,6 +584,51 @@ if data.get('citations'): sentence-transformers/all-mpnet-base-v2 Embedding model + +RERANK_ENABLED +true +Enable lightweight post-retrieval reranking + + +RERANK_CANDIDATE_MULTIPLIER +3 +Multiplier for initial candidate pool before reranking + + +RERANK_SIMILARITY_WEIGHT +0.7 +Weight for vector similarity score in final rerank score + + +RERANK_KEYWORD_WEIGHT +0.2 +Weight for query term overlap with chunk content + + +RERANK_METADATA_WEIGHT +0.1 +Weight for query term overlap with file path and citation URL + + +RERANK_MAX_CANDIDATES +50 +Upper bound for candidate pool size fetched before reranking + + +RERANK_MIN_TOKEN_LEN +3 +Minimum token length used for query/content term overlap scoring + + +RERANK_DEBUG_LOG +false +Enable before/after reranking logs with component scores + + +RERANK_LOG_TOP_N +5 +Number of documents to show in reranking debug logs + diff --git a/eval_retrieval.py b/eval_retrieval.py new file mode 100644 index 0000000..fc2d85e --- /dev/null +++ b/eval_retrieval.py @@ -0,0 +1,140 @@ +import argparse +import os +from typing import Any, Dict, List + +from pymilvus import Collection, connections +from sentence_transformers import SentenceTransformer + +from shared.reranking import candidate_pool_limit, load_rerank_config_from_env, rerank_documents + + +DEFAULT_QUERIES = [ + "How do I create a Kubeflow Pipeline?", + "How to deploy an InferenceService in KServe?", + "Kubeflow Notebook setup requirements", +] + + +def build_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Evaluate retrieval results before/after reranking.") + parser.add_argument("--queries", nargs="*", default=DEFAULT_QUERIES, help="Queries to evaluate") + parser.add_argument("--top-k", type=int, default=5, help="Final top-k results to keep") + parser.add_argument( + "--show-content-chars", + type=int, + default=180, + help="Number of content characters to print per result", + ) + return parser.parse_args() + + +def _print_docs(title: str, docs: List[Dict[str, Any]], show_content_chars: int) -> None: + print(f"\n{title}") + if not docs: + print(" (no results)") + return + + for idx, doc in enumerate(docs, start=1): + content = (doc.get("content_text") or "").replace("\n", " ").strip() + if len(content) > show_content_chars: + content = content[:show_content_chars] + "..." + + print( + f" {idx}. 
score={doc.get('rerank_score', doc.get('similarity', 0.0)):.4f} " + f"sim={doc.get('similarity', 0.0):.4f} " + f"keyword={doc.get('keyword_score', 0.0):.4f} " + f"metadata={doc.get('metadata_score', 0.0):.4f}" + ) + print(f" file={doc.get('file_path', '')}") + print(f" url={doc.get('citation_url', '')}") + print(f" text={content}") + + +def retrieve_candidates( + query: str, + model: SentenceTransformer, + collection: Collection, + top_k: int, + candidate_limit: int, + vector_field: str, +) -> List[Dict[str, Any]]: + query_vec = model.encode(query).tolist() + search_params = {"metric_type": "COSINE", "params": {"nprobe": 32}} + + results = collection.search( + data=[query_vec], + anns_field=vector_field, + param=search_params, + limit=candidate_limit, + output_fields=["file_path", "content_text", "citation_url"], + ) + + docs: List[Dict[str, Any]] = [] + for hit in results[0]: + entity = hit.entity + docs.append( + { + "similarity": 1.0 - float(hit.distance), + "file_path": entity.get("file_path"), + "citation_url": entity.get("citation_url"), + "content_text": entity.get("content_text") or "", + } + ) + + return docs + + +def main() -> None: + args = build_args() + + milvus_host = os.getenv("MILVUS_HOST", "my-release-milvus.docs-agent.svc.cluster.local") + milvus_port = os.getenv("MILVUS_PORT", "19530") + milvus_collection = os.getenv("MILVUS_COLLECTION", "docs_rag") + milvus_vector_field = os.getenv("MILVUS_VECTOR_FIELD", "vector") + embedding_model_name = os.getenv("EMBEDDING_MODEL", "sentence-transformers/all-mpnet-base-v2") + + rerank_config = load_rerank_config_from_env() + requested_top_k = max(1, int(args.top_k)) + candidate_limit = candidate_pool_limit(requested_top_k, rerank_config) + + print("Retrieval evaluation configuration") + print(f"- collection: {milvus_collection}") + print(f"- top_k: {requested_top_k}") + print(f"- candidate_limit: {candidate_limit}") + print(f"- rerank_enabled: {rerank_config.enabled}") + + connections.connect(alias="default", host=milvus_host, port=milvus_port) + try: + collection = Collection(milvus_collection) + collection.load() + model = SentenceTransformer(embedding_model_name) + + for query in args.queries: + print("\n" + "=" * 100) + print(f"Query: {query}") + + candidates = retrieve_candidates( + query=query, + model=model, + collection=collection, + top_k=requested_top_k, + candidate_limit=candidate_limit, + vector_field=milvus_vector_field, + ) + + before_docs = candidates[:requested_top_k] + after_docs = rerank_documents( + query=query, + docs=candidates, + config=rerank_config, + top_k=requested_top_k, + ) + + _print_docs("Before reranking", before_docs, args.show_content_chars) + _print_docs("After reranking", after_docs, args.show_content_chars) + finally: + connections.disconnect(alias="default") + + +if __name__ == "__main__": + main() diff --git a/git_logs_temp.txt b/git_logs_temp.txt new file mode 100644 index 0000000..b2fd484 Binary files /dev/null and b/git_logs_temp.txt differ diff --git a/git_msg.txt b/git_msg.txt new file mode 100644 index 0000000..c6bf362 --- /dev/null +++ b/git_msg.txt @@ -0,0 +1,12 @@ +bd1725fdd3d161db68f327df783abe7f8d464cf7 +Ayush-kathil +Add thread-safe model init and Milvus search +Introduce thread-safe lazy initialization for SentenceTransformer in server-https/app.py and server/app.py using a lock and double-checked locking, with timing/info logs to avoid repeated model loading. Add threading and time imports. 
In server/app.py add milvus_search(query, top_k) to perform Milvus connection, load collection, encode query with the cached model, execute search, format results (similarity, file_path, citation_url, truncated content_text), handle errors, and ensure disconnect. These changes reduce initialization overhead and centralize Milvus search logic. + +--- +f05614a123776a4009d7104125931289d23bcdcc +Ayush Gupta +fix(server-https): preserve multi-hop citations in stream_llm_response +Signed-off-by: Ayush-kathil + +--- \ No newline at end of file diff --git a/kagent-feast-mcp/mcp-server/server.py b/kagent-feast-mcp/mcp-server/server.py index 44a9fcb..f286c7c 100644 --- a/kagent-feast-mcp/mcp-server/server.py +++ b/kagent-feast-mcp/mcp-server/server.py @@ -1,26 +1,67 @@ from fastmcp import FastMCP +import logging +import os +import sys +import threading +from pathlib import Path +from typing import Any, Dict, List from pymilvus import MilvusClient from sentence_transformers import SentenceTransformer -MILVUS_URI = "http://milvus..svc.cluster.local:19530" -MILVUS_USER = "root" -MILVUS_PASSWORD = "Milvus" -COLLECTION_NAME = "kubeflow_docs_docs_rag" -EMBEDDING_MODEL = "sentence-transformers/all-mpnet-base-v2" -PORT = 8000 +REPO_ROOT = Path(__file__).resolve().parents[2] +if str(REPO_ROOT) not in sys.path: + sys.path.append(str(REPO_ROOT)) + +from shared.reranking import candidate_pool_limit, load_rerank_config_from_env, rerank_documents + +MILVUS_URI = os.getenv("MILVUS_URI", "http://milvus..svc.cluster.local:19530") +MILVUS_USER = os.getenv("MILVUS_USER", "root") +MILVUS_PASSWORD = os.getenv("MILVUS_PASSWORD", "Milvus") +COLLECTION_NAME = os.getenv("COLLECTION_NAME", "kubeflow_docs_docs_rag") +EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL", "sentence-transformers/all-mpnet-base-v2") +PORT = int(os.getenv("PORT", "8000")) +RERANK_CONFIG = load_rerank_config_from_env() mcp = FastMCP("Kubeflow Docs MCP Server") +logger = logging.getLogger(__name__) +logging.basicConfig(level=os.getenv("LOG_LEVEL", "INFO").upper()) model: SentenceTransformer = None client: MilvusClient = None +_initialized = False +_init_lock = threading.Lock() def _init(): - global model, client - if model is None: - model = SentenceTransformer(EMBEDDING_MODEL) - if client is None: - client = MilvusClient(uri=MILVUS_URI, user=MILVUS_USER, password=MILVUS_PASSWORD) + """Initialize shared model/client exactly once. + + Synchronization strategy: + - Fast path: return immediately after successful initialization. + - Slow path: take a process-local lock and re-check state (double-checked locking). + + This guarantees that concurrent callers block until initialization completes, + and all callers observe the same initialized instances. + """ + global model, client, _initialized + + if _initialized: + return + + with _init_lock: + if _initialized: + return + + logger.info("Initializing shared MCP resources") + + # Build local instances first, then publish atomically under the lock. 
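+        # If either constructor raises, nothing is published: model/client
+        # stay None and _initialized stays False, so a later caller can
+        # retry instead of observing a half-built state.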
+ local_model = SentenceTransformer(EMBEDDING_MODEL) + local_client = MilvusClient(uri=MILVUS_URI, user=MILVUS_USER, password=MILVUS_PASSWORD) + + model = local_model + client = local_client + _initialized = True + + logger.info("Shared MCP resources initialized") @mcp.tool() @@ -37,24 +78,54 @@ def search_kubeflow_docs(query: str, top_k: int = 5) -> str: _init() embedding = model.encode(query).tolist() + requested_top_k = max(1, int(top_k)) + candidate_limit = candidate_pool_limit(requested_top_k, RERANK_CONFIG) hits = client.search( collection_name=COLLECTION_NAME, data=[embedding], - limit=top_k, + limit=candidate_limit, output_fields=["content_text", "citation_url", "file_path"], )[0] if not hits: return "No results found for your query." + docs: List[Dict[str, Any]] = [] + for hit in hits: + entity = hit.get("entity", {}) + content_text = entity.get("content_text") or "" + if isinstance(content_text, str) and len(content_text) > 400: + content_text = content_text[:400] + "..." + + docs.append( + { + "distance": float(hit.get("distance", 0.0)), + "similarity": 1.0 - float(hit.get("distance", 0.0)), + "file_path": entity.get("file_path"), + "citation_url": entity.get("citation_url"), + "content_text": content_text, + } + ) + + selected_docs = rerank_documents( + query=query, + docs=docs, + config=RERANK_CONFIG, + top_k=requested_top_k, + logger=logger, + log_prefix="mcp_search", + ) + results = [] - for i, hit in enumerate(hits, 1): - entity = hit["entity"] - entry = f"### Result {i} (score: {hit['distance']:.4f})" - entry += f"\n**Source:** {entity.get('citation_url', '')}" - entry += f"\n**File:** {entity.get('file_path', '')}" - entry += f"\n\n{entity.get('content_text', '')}\n" + for i, doc in enumerate(selected_docs, 1): + entry = f"### Result {i} (rerank_score: {doc.get('rerank_score', 0.0):.4f})" + entry += f"\n**Similarity:** {doc.get('similarity', 0.0):.4f}" + entry += f"\n**Keyword Score:** {doc.get('keyword_score', 0.0):.4f}" + entry += f"\n**Metadata Score:** {doc.get('metadata_score', 0.0):.4f}" + entry += f"\n**Source:** {doc.get('citation_url', '')}" + entry += f"\n**File:** {doc.get('file_path', '')}" + entry += f"\n\n{doc.get('content_text', '')}\n" results.append(entry) return "\n---\n".join(results) diff --git a/kagent-feast-mcp/mcp-server/tests/test_init_concurrency.py b/kagent-feast-mcp/mcp-server/tests/test_init_concurrency.py new file mode 100644 index 0000000..e1f2aef --- /dev/null +++ b/kagent-feast-mcp/mcp-server/tests/test_init_concurrency.py @@ -0,0 +1,129 @@ +import importlib.util +import logging +import sys +import threading +import types +from pathlib import Path + + +class _FastMCPStub: + def __init__(self, _name: str): + self._name = _name + + def tool(self): + def _decorator(func): + return func + + return _decorator + + def run(self, **_kwargs): + return None + + +def _load_server_module(): + module_name = "mcp_server_under_test" + module_path = Path(__file__).resolve().parents[1] / "server.py" + + fastmcp_stub = types.ModuleType("fastmcp") + fastmcp_stub.FastMCP = _FastMCPStub + + pymilvus_stub = types.ModuleType("pymilvus") + + class _MilvusClientStub: + def __init__(self, **_kwargs): + pass + + pymilvus_stub.MilvusClient = _MilvusClientStub + + sentence_stub = types.ModuleType("sentence_transformers") + + class _SentenceTransformerStub: + def __init__(self, *_args, **_kwargs): + pass + + sentence_stub.SentenceTransformer = _SentenceTransformerStub + + sys.modules["fastmcp"] = fastmcp_stub + sys.modules["pymilvus"] = pymilvus_stub + 
sys.modules["sentence_transformers"] = sentence_stub + + spec = importlib.util.spec_from_file_location(module_name, module_path) + module = importlib.util.module_from_spec(spec) + assert spec is not None and spec.loader is not None + spec.loader.exec_module(module) + return module + + +def test_init_is_thread_safe_and_idempotent(monkeypatch, caplog): + server = _load_server_module() + + model_init_count = 0 + client_init_count = 0 + init_count_lock = threading.Lock() + + class FakeModel: + pass + + class FakeClient: + pass + + def fake_sentence_transformer(_model_name): + nonlocal model_init_count + with init_count_lock: + model_init_count += 1 + return FakeModel() + + def fake_milvus_client(**_kwargs): + nonlocal client_init_count + with init_count_lock: + client_init_count += 1 + return FakeClient() + + monkeypatch.setattr(server, "SentenceTransformer", fake_sentence_transformer) + monkeypatch.setattr(server, "MilvusClient", fake_milvus_client) + + server.model = None + server.client = None + server._initialized = False + + workers = 32 + barrier = threading.Barrier(workers) + errors = [] + seen_models = [] + seen_clients = [] + seen_lock = threading.Lock() + + def worker(): + try: + barrier.wait() + server._init() + with seen_lock: + seen_models.append(server.model) + seen_clients.append(server.client) + except Exception as exc: # pragma: no cover + with seen_lock: + errors.append(exc) + + with caplog.at_level(logging.INFO, logger=server.__name__): + threads = [threading.Thread(target=worker) for _ in range(workers)] + for thread in threads: + thread.start() + for thread in threads: + thread.join(timeout=5) + + assert not errors + assert model_init_count == 1 + assert client_init_count == 1 + assert server._initialized is True + + first_model = seen_models[0] + first_client = seen_clients[0] + + assert first_model is not None + assert first_client is not None + assert all(model is first_model for model in seen_models) + assert all(client is first_client for client in seen_clients) + + messages = [record.getMessage() for record in caplog.records] + assert messages.count("Initializing shared MCP resources") == 1 + assert messages.count("Shared MCP resources initialized") == 1 diff --git a/server-https/app.py b/server-https/app.py index 694af8a..0c6b118 100644 --- a/server-https/app.py +++ b/server-https/app.py @@ -1,14 +1,46 @@ import os import json import httpx +import logging from fastapi import FastAPI, HTTPException from fastapi.responses import StreamingResponse from fastapi.middleware.cors import CORSMiddleware from pydantic import BaseModel import uvicorn +import sys +from pathlib import Path from typing import Dict, Any, List, Optional, AsyncGenerator from sentence_transformers import SentenceTransformer from pymilvus import connections, Collection +import threading +import time + +REPO_ROOT = Path(__file__).resolve().parents[1] +if str(REPO_ROOT) not in sys.path: + sys.path.append(str(REPO_ROOT)) + +from shared.reranking import candidate_pool_limit, load_rerank_config_from_env, rerank_documents + +_model_lock = threading.Lock() +_embedding_model = None + + +def get_embedding_model(): + """Thread-safe lazy initialization of SentenceTransformer. + + The model is cached once per process because loading it on every request + adds avoidable latency and memory churn. 
+ """ + global _embedding_model + if _embedding_model is None: + with _model_lock: + # Double-checked locking + if _embedding_model is None: + start_t = time.perf_counter() + logger.info("Lazy loading SentenceTransformer model '%s'", EMBEDDING_MODEL) + _embedding_model = SentenceTransformer(EMBEDDING_MODEL) + logger.info("Model loaded in %.3f seconds", time.perf_counter() - start_t) + return _embedding_model # Config KSERVE_URL = os.getenv("KSERVE_URL", "http://llama.docs-agent.svc.cluster.local/openai/v1/chat/completions") @@ -22,6 +54,9 @@ MILVUS_VECTOR_FIELD = os.getenv("MILVUS_VECTOR_FIELD", "vector") EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL", "sentence-transformers/all-mpnet-base-v2") +logger = logging.getLogger(__name__) +RERANK_CONFIG = load_rerank_config_from_env() + # System prompt (same as WebSocket version) SYSTEM_PROMPT = """ You are the Kubeflow Docs Assistant. @@ -113,6 +148,7 @@ class ChatRequest(BaseModel): def milvus_search(query: str, top_k: int = 5) -> Dict[str, Any]: """Execute a semantic search in Milvus and return structured JSON serializable results.""" + try: # Connect to Milvus connections.connect(alias="default", host=MILVUS_HOST, port=MILVUS_PORT) @@ -120,15 +156,18 @@ def milvus_search(query: str, top_k: int = 5) -> Dict[str, Any]: collection.load() # Encoder (same model as pipeline) - encoder = SentenceTransformer(EMBEDDING_MODEL) - query_vec = encoder.encode(query).tolist() + model = get_embedding_model() + query_vec = model.encode(query).tolist() + + requested_top_k = max(1, int(top_k)) + candidate_limit = candidate_pool_limit(requested_top_k, RERANK_CONFIG) search_params = {"metric_type": "COSINE", "params": {"nprobe": 32}} results = collection.search( data=[query_vec], anns_field=MILVUS_VECTOR_FIELD, param=search_params, - limit=int(top_k), + limit=candidate_limit, output_fields=["file_path", "content_text", "citation_url"], ) @@ -146,15 +185,25 @@ def milvus_search(query: str, top_k: int = 5) -> Dict[str, Any]: "citation_url": entity.get("citation_url"), "content_text": content_text, }) + + hits = rerank_documents( + query=query, + docs=hits, + config=RERANK_CONFIG, + top_k=requested_top_k, + logger=logger, + log_prefix="https_search", + ) + return {"results": hits} - except Exception as e: - print(f"[ERROR] Milvus search failed: {e}") - return {"results": []} + except Exception: + logger.exception("Milvus search failed for query='%s' top_k=%s", query, top_k) + raise finally: try: connections.disconnect(alias="default") - except Exception: - pass + except Exception as disconnect_error: + logger.warning("Milvus disconnect failed: %s", disconnect_error) async def execute_tool(tool_call: Dict[str, Any]) -> tuple[str, List[str]]: """Execute a tool call and return the result and citations""" @@ -166,7 +215,7 @@ async def execute_tool(tool_call: Dict[str, Any]) -> tuple[str, List[str]]: query = arguments.get("query", "") top_k = arguments.get("top_k", 5) - print(f"[TOOL] Executing Milvus search for: '{query}' (top_k={top_k})") + logger.info("Executing Milvus search for query='%s' top_k=%s", query, top_k) result = milvus_search(query, top_k) # Collect citations @@ -191,19 +240,21 @@ async def execute_tool(tool_call: Dict[str, Any]) -> tuple[str, List[str]]: return f"Unknown tool: {function_name}", [] except Exception as e: - print(f"[ERROR] Tool execution failed: {e}") + logger.exception("Tool execution failed") return f"Tool execution failed: {e}", [] -async def stream_llm_response(payload: Dict[str, Any]) -> AsyncGenerator[str, None]: +async def 
stream_llm_response(payload: Dict[str, Any], citations_collector: Optional[List[str]] = None) -> AsyncGenerator[str, None]: """Stream response from LLM and handle tool calls, yielding SSE events""" - citations_collector = [] + is_outermost = citations_collector is None + if citations_collector is None: + citations_collector = [] try: async with httpx.AsyncClient(timeout=120) as client: async with client.stream("POST", KSERVE_URL, json=payload) as response: if response.status_code != 200: error_msg = f"LLM service error: HTTP {response.status_code}" - print(f"[ERROR] {error_msg}") + logger.error(error_msg) yield f"data: {json.dumps({'type': 'error', 'content': error_msg})}\n\n" return @@ -262,14 +313,14 @@ async def stream_llm_response(payload: Dict[str, Any]) -> AsyncGenerator[str, No # Handle finish reason - execute tools if needed if finish_reason == "tool_calls": - print(f"[TOOL] Finish reason: tool_calls, executing {len(tool_calls_buffer)} tools") + logger.info("Finish reason=tool_calls; executing %s tools", len(tool_calls_buffer)) # Execute all accumulated tool calls for tool_call in tool_calls_buffer.values(): if tool_call["function"]["name"] and tool_call["function"]["arguments"]: try: - print(f"[TOOL] Executing: {tool_call['function']['name']}") - print(f"[TOOL] Arguments: {tool_call['function']['arguments']}") + logger.info("Executing tool=%s", tool_call["function"]["name"]) + logger.debug("Tool arguments=%s", tool_call["function"]["arguments"]) result, tool_citations = await execute_tool(tool_call) @@ -284,18 +335,18 @@ async def stream_llm_response(payload: Dict[str, Any]) -> AsyncGenerator[str, No yield follow_up_chunk except Exception as e: - print(f"[ERROR] Tool execution error: {e}") + logger.exception("Tool execution error") yield f"data: {json.dumps({'type': 'error', 'content': f'Tool execution failed: {e}'})}\n\n" tool_calls_buffer.clear() break # Tool execution complete, exit streaming loop except json.JSONDecodeError as e: - print(f"[ERROR] JSON decode error: {e}, line: {line}") + logger.warning("JSON decode error: %s line=%s", e, line) continue # Send citations if any were collected - if citations_collector: + if is_outermost and citations_collector: # Remove duplicates while preserving order unique_citations = [] for citation in citations_collector: @@ -305,16 +356,17 @@ async def stream_llm_response(payload: Dict[str, Any]) -> AsyncGenerator[str, No yield f"data: {json.dumps({'type': 'citations', 'citations': unique_citations})}\n\n" # Send completion signal - yield f"data: {json.dumps({'type': 'done'})}\n\n" + if is_outermost: + yield f"data: {json.dumps({'type': 'done'})}\n\n" except Exception as e: - print(f"[ERROR] Streaming failed: {e}") + logger.exception("Streaming failed") yield f"data: {json.dumps({'type': 'error', 'content': f'Streaming failed: {e}'})}\n\n" async def handle_tool_follow_up(original_payload: Dict[str, Any], tool_call: Dict[str, Any], tool_result: str, citations_collector: List[str]) -> AsyncGenerator[str, None]: """Handle follow-up request after tool execution""" try: - print("[TOOL] Handling follow-up request with tool results") + logger.info("Handling follow-up request with tool results") # Create messages with tool call and result messages = original_payload["messages"].copy() @@ -341,11 +393,11 @@ async def handle_tool_follow_up(original_payload: Dict[str, Any], tool_call: Dic } # Stream the follow-up response - async for chunk in stream_llm_response(follow_up_payload): + async for chunk in stream_llm_response(follow_up_payload, 
citations_collector=citations_collector): yield chunk except Exception as e: - print(f"[ERROR] Tool follow-up failed: {e}") + logger.exception("Tool follow-up failed") yield f"data: {json.dumps({'type': 'error', 'content': f'Tool follow-up failed: {e}'})}\n\n" async def get_non_streaming_response(payload: Dict[str, Any]) -> tuple[str, List[str]]: @@ -397,7 +449,7 @@ async def options_health(): async def chat(request: ChatRequest): """Chat endpoint with RAG capabilities - supports both streaming and non-streaming""" try: - print(f"[CHAT] Processing message: {request.message[:100]}...") + logger.info("Processing chat message preview=%s", request.message[:100]) # Create initial payload payload = { @@ -440,15 +492,19 @@ async def chat(request: ChatRequest): } except Exception as e: - print(f"[ERROR] Chat handling failed: {e}") + logger.exception("Chat handling failed") raise HTTPException(status_code=500, detail=f"Request failed: {e}") if __name__ == "__main__": - print("🚀 Starting Kubeflow Docs HTTP API Server") - print(f" Port: {PORT}") - print(f" LLM Service: {KSERVE_URL}") - print(f" Milvus: {MILVUS_HOST}:{MILVUS_PORT}") - print(f" Collection: {MILVUS_COLLECTION}") + logging.basicConfig( + level=os.getenv("LOG_LEVEL", "INFO").upper(), + format="%(asctime)s %(levelname)s %(name)s - %(message)s", + ) + logger.info("Starting Kubeflow Docs HTTP API Server") + logger.info("Port: %s", PORT) + logger.info("LLM Service: %s", KSERVE_URL) + logger.info("Milvus: %s:%s", MILVUS_HOST, MILVUS_PORT) + logger.info("Collection: %s", MILVUS_COLLECTION) uvicorn.run( app, diff --git a/server/app.py b/server/app.py index 96b277c..d63b697 100644 --- a/server/app.py +++ b/server/app.py @@ -6,9 +6,19 @@ from websockets.server import serve from websockets.exceptions import ConnectionClosedError import logging +import sys +from pathlib import Path from typing import Dict, Any, List from sentence_transformers import SentenceTransformer from pymilvus import connections, Collection +import threading +import time + +REPO_ROOT = Path(__file__).resolve().parents[1] +if str(REPO_ROOT) not in sys.path: + sys.path.append(str(REPO_ROOT)) + +from shared.reranking import candidate_pool_limit, load_rerank_config_from_env, rerank_documents # Config KSERVE_URL = os.getenv("KSERVE_URL", "http://llama.docs-agent.svc.cluster.local/openai/v1/chat/completions") @@ -22,6 +32,9 @@ MILVUS_VECTOR_FIELD = os.getenv("MILVUS_VECTOR_FIELD", "vector") EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL", "sentence-transformers/all-mpnet-base-v2") +logger = logging.getLogger(__name__) +RERANK_CONFIG = load_rerank_config_from_env() + # System prompt SYSTEM_PROMPT = """ You are the Kubeflow Docs Assistant. @@ -65,6 +78,27 @@ + +_model_lock = threading.Lock() +_embedding_model = None + +def get_embedding_model(): + """Thread-safe lazy initialization of SentenceTransformer. + + The model is cached once per process because loading it on every request + adds avoidable latency and memory churn. 
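+
+    A double-checked lock keeps the fast path cheap: only the first caller
+    loads the model; concurrent callers block briefly and share the instance.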
+ """ + global _embedding_model + if _embedding_model is None: + with _model_lock: + # Double-checked locking + if _embedding_model is None: + start_t = time.perf_counter() + logger.info("Lazy loading SentenceTransformer model '%s'", EMBEDDING_MODEL) + _embedding_model = SentenceTransformer(EMBEDDING_MODEL) + logger.info("Model loaded in %.3f seconds", time.perf_counter() - start_t) + return _embedding_model + def milvus_search(query: str, top_k: int = 5) -> Dict[str, Any]: """Execute a semantic search in Milvus and return structured JSON serializable results.""" try: @@ -73,22 +107,24 @@ def milvus_search(query: str, top_k: int = 5) -> Dict[str, Any]: collection = Collection(MILVUS_COLLECTION) collection.load() - # Encoder (same model as pipeline) - encoder = SentenceTransformer(EMBEDDING_MODEL) - query_vec = encoder.encode(query).tolist() + # Thread-safe cached encoder + model = get_embedding_model() + query_vec = model.encode(query).tolist() + + requested_top_k = max(1, int(top_k)) + candidate_limit = candidate_pool_limit(requested_top_k, RERANK_CONFIG) search_params = {"metric_type": "COSINE", "params": {"nprobe": 32}} results = collection.search( data=[query_vec], anns_field=MILVUS_VECTOR_FIELD, param=search_params, - limit=int(top_k), + limit=candidate_limit, output_fields=["file_path", "content_text", "citation_url"], ) hits = [] for hit in results[0]: - # similarity = 1 - distance for COSINE in Milvus similarity = 1.0 - float(hit.distance) entity = hit.entity content_text = entity.get("content_text") or "" @@ -100,15 +136,26 @@ def milvus_search(query: str, top_k: int = 5) -> Dict[str, Any]: "citation_url": entity.get("citation_url"), "content_text": content_text, }) + + hits = rerank_documents( + query=query, + docs=hits, + config=RERANK_CONFIG, + top_k=requested_top_k, + logger=logger, + log_prefix="websocket_search", + ) + return {"results": hits} - except Exception as e: - print(f"[ERROR] Milvus search failed: {e}") - return {"results": []} + except Exception: + logger.exception("Milvus search failed for query='%s' top_k=%s", query, top_k) + raise finally: try: connections.disconnect(alias="default") - except Exception: - pass + except Exception as disconnect_error: + logger.warning("Milvus disconnect failed: %s", disconnect_error) + TOOLS = [ { "type": "function", @@ -153,7 +200,7 @@ async def execute_tool(tool_call: Dict[str, Any]) -> tuple[str, List[str]]: query = arguments.get("query", "") top_k = arguments.get("top_k", 5) - print(f"[TOOL] Executing Milvus search for: '{query}' (top_k={top_k})") + logger.info("Executing Milvus search for query='%s' top_k=%s", query, top_k) result = milvus_search(query, top_k) # Collect citations @@ -178,7 +225,7 @@ async def execute_tool(tool_call: Dict[str, Any]) -> tuple[str, List[str]]: return f"Unknown tool: {function_name}", [] except Exception as e: - print(f"[ERROR] Tool execution failed: {e}") + logger.exception("Tool execution failed") return f"Tool execution failed: {e}", [] async def stream_llm_response(payload: Dict[str, Any], websocket, citations_collector: List[str] = None) -> None: @@ -190,7 +237,7 @@ async def stream_llm_response(payload: Dict[str, Any], websocket, citations_coll async with client.stream("POST", KSERVE_URL, json=payload) as response: if response.status_code != 200: error_msg = f"LLM service error: HTTP {response.status_code}" - print(f"[ERROR] {error_msg}") + logger.error(error_msg) await websocket.send(json.dumps({"type": "error", "content": error_msg})) return @@ -252,14 +299,14 @@ async def 
stream_llm_response(payload: Dict[str, Any], websocket, citations_coll # Handle finish reason - execute tools if needed if finish_reason == "tool_calls": - print(f"[TOOL] Finish reason: tool_calls, executing {len(tool_calls_buffer)} tools") + logger.info("Finish reason=tool_calls; executing %s tools", len(tool_calls_buffer)) # Execute all accumulated tool calls for tool_call in tool_calls_buffer.values(): if tool_call["function"]["name"] and tool_call["function"]["arguments"]: try: - print(f"[TOOL] Executing: {tool_call['function']['name']}") - print(f"[TOOL] Arguments: {tool_call['function']['arguments']}") + logger.info("Executing tool=%s", tool_call["function"]["name"]) + logger.debug("Tool arguments=%s", tool_call["function"]["arguments"]) result, tool_citations = await execute_tool(tool_call) @@ -277,7 +324,7 @@ async def stream_llm_response(payload: Dict[str, Any], websocket, citations_coll await handle_tool_follow_up(payload, tool_call, result, websocket, citations_collector) except Exception as e: - print(f"[ERROR] Tool execution error: {e}") + logger.exception("Tool execution error") await websocket.send(json.dumps({ "type": "error", "content": f"Tool execution failed: {e}" @@ -287,11 +334,11 @@ async def stream_llm_response(payload: Dict[str, Any], websocket, citations_coll break # Tool execution complete, exit streaming loop except json.JSONDecodeError as e: - print(f"[ERROR] JSON decode error: {e}, line: {line}") + logger.warning("JSON decode error: %s line=%s", e, line) continue except Exception as e: - print(f"[ERROR] Streaming failed: {e}") + logger.exception("Streaming failed") await websocket.send(json.dumps({"type": "error", "content": f"Streaming failed: {e}"})) async def handle_tool_follow_up(original_payload: Dict[str, Any], tool_call: Dict[str, Any], tool_result: str, websocket, citations_collector: List[str] = None) -> None: @@ -299,7 +346,7 @@ async def handle_tool_follow_up(original_payload: Dict[str, Any], tool_call: Dic if citations_collector is None: citations_collector = [] try: - print("[TOOL] Handling follow-up request with tool results") + logger.info("Handling follow-up request with tool results") # Create messages with tool call and result messages = original_payload["messages"].copy() @@ -329,13 +376,13 @@ async def handle_tool_follow_up(original_payload: Dict[str, Any], tool_call: Dic await stream_llm_response(follow_up_payload, websocket, citations_collector) except Exception as e: - print(f"[ERROR] Tool follow-up failed: {e}") + logger.exception("Tool follow-up failed") await websocket.send(json.dumps({"type": "error", "content": f"Tool follow-up failed: {e}"})) async def handle_chat(message: str, websocket) -> None: """Handle chat with tool calling support""" try: - print(f"[CHAT] Processing message: {message[:100]}...") + logger.info("Processing chat message preview=%s", message[:100]) # Create initial payload payload = { @@ -373,12 +420,12 @@ async def handle_chat(message: str, websocket) -> None: await websocket.send(json.dumps({"type": "done"})) except Exception as e: - print(f"[ERROR] Chat handling failed: {e}") + logger.exception("Chat handling failed") await websocket.send(json.dumps({"type": "error", "content": f"Request failed: {e}"})) async def handle_websocket(websocket, path): """Handle WebSocket connections""" - print(f"[WS] New connection from {websocket.remote_address}") + logger.info("New websocket connection from %s", websocket.remote_address) try: # Send welcome message @@ -402,20 +449,20 @@ async def handle_websocket(websocket, 
path): # Treat as plain text message pass - print(f"[WS] Received: {message[:100]}...") + logger.info("Received websocket message preview=%s", message[:100]) await handle_chat(message, websocket) except Exception as e: - print(f"[ERROR] Message processing error: {e}") + logger.exception("Message processing error") await websocket.send(json.dumps({ "type": "error", "content": f"Message processing failed: {e}" })) except ConnectionClosedError: - print("[WS] Connection closed") + logger.info("Websocket connection closed") except Exception as e: - print(f"[ERROR] WebSocket error: {e}") + logger.exception("Websocket error") async def health_check(path, request_headers): """Handle HTTP health checks""" @@ -425,11 +472,11 @@ async def health_check(path, request_headers): async def main(): """Start the WebSocket server""" - print("🚀 Starting Kubeflow Docs WebSocket Server") - print(f" Port: {PORT}") - print(f" LLM Service: {KSERVE_URL}") - print(f" Milvus: {MILVUS_HOST}:{MILVUS_PORT}") - print(f" Collection: {MILVUS_COLLECTION}") + logger.info("Starting Kubeflow Docs WebSocket Server") + logger.info("Port: %s", PORT) + logger.info("LLM Service: %s", KSERVE_URL) + logger.info("Milvus: %s:%s", MILVUS_HOST, MILVUS_PORT) + logger.info("Collection: %s", MILVUS_COLLECTION) # Configure logging logging.getLogger("websockets").setLevel(logging.WARNING) @@ -443,12 +490,16 @@ async def main(): ping_interval=30, ping_timeout=10 ): - print("✅ WebSocket server is running...") - print(f" WebSocket: ws://localhost:{PORT}") - print(f" Health: http://localhost:{PORT}/health") + logger.info("WebSocket server is running") + logger.info("WebSocket: ws://localhost:%s", PORT) + logger.info("Health: http://localhost:%s/health", PORT) # Keep server running await asyncio.Future() if __name__ == "__main__": + logging.basicConfig( + level=os.getenv("LOG_LEVEL", "INFO").upper(), + format="%(asctime)s %(levelname)s %(name)s - %(message)s", + ) asyncio.run(main()) diff --git a/shared/__init__.py b/shared/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/shared/reranking.py b/shared/reranking.py new file mode 100644 index 0000000..e051816 --- /dev/null +++ b/shared/reranking.py @@ -0,0 +1,238 @@ +import logging +import os +import re +from dataclasses import dataclass +from typing import Any, Dict, List, Optional + + +LOGGER = logging.getLogger(__name__) + + +@dataclass(frozen=True) +class RerankConfig: + """Reranking defaults favor a lightweight, recall-first blend. + + The similarity weight remains the primary signal, while keyword and metadata + weights stay small so reranking improves relevance without overpowering the + original vector score. 
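+
+    Every field can be overridden through its matching RERANK_* environment
+    variable (see load_rerank_config_from_env), so deployments can retune the
+    blend without code changes.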
+ """ + + enabled: bool = True + candidate_multiplier: int = 3 + similarity_weight: float = 0.7 + keyword_weight: float = 0.2 + metadata_weight: float = 0.1 + max_candidates: int = 50 + min_token_len: int = 3 + debug_logging: bool = False + log_top_n: int = 5 + + +def _parse_bool(value: str, default: bool) -> bool: + if value is None: + return default + + normalized = value.strip().lower() + if normalized in {"1", "true", "yes", "on"}: + return True + if normalized in {"0", "false", "no", "off"}: + return False + + LOGGER.warning("Invalid boolean value '%s'; using default=%s", value, default) + return default + + +def _parse_int_env(name: str, default: int, minimum: int = 1) -> int: + raw = os.getenv(name) + if raw is None: + return default + + try: + return max(minimum, int(raw)) + except (TypeError, ValueError): + LOGGER.warning("Invalid integer for %s=%s; using default=%s", name, raw, default) + return default + + +def _parse_float_env(name: str, default: float) -> float: + raw = os.getenv(name) + if raw is None: + return default + + try: + return float(raw) + except (TypeError, ValueError): + LOGGER.warning("Invalid float for %s=%s; using default=%s", name, raw, default) + return default + + +def load_rerank_config_from_env() -> RerankConfig: + return RerankConfig( + enabled=_parse_bool(os.getenv("RERANK_ENABLED", "true"), True), + candidate_multiplier=_parse_int_env("RERANK_CANDIDATE_MULTIPLIER", 3), + similarity_weight=_parse_float_env("RERANK_SIMILARITY_WEIGHT", 0.7), + keyword_weight=_parse_float_env("RERANK_KEYWORD_WEIGHT", 0.2), + metadata_weight=_parse_float_env("RERANK_METADATA_WEIGHT", 0.1), + max_candidates=_parse_int_env("RERANK_MAX_CANDIDATES", 50), + min_token_len=_parse_int_env("RERANK_MIN_TOKEN_LEN", 3), + debug_logging=_parse_bool(os.getenv("RERANK_DEBUG_LOG", "false"), False), + log_top_n=_parse_int_env("RERANK_LOG_TOP_N", 5), + ) + + +def candidate_pool_limit(top_k: int, config: RerankConfig) -> int: + try: + requested_top_k = max(1, int(top_k)) + except (TypeError, ValueError): + LOGGER.warning("Invalid top_k=%s; using fallback top_k=1", top_k) + requested_top_k = 1 + if not config.enabled: + return requested_top_k + expanded = requested_top_k * max(1, config.candidate_multiplier) + return min(max(requested_top_k, expanded), max(1, config.max_candidates)) + + +def _tokenize_text(text: str) -> List[str]: + return re.findall(r"[a-zA-Z0-9_]+", (text or "").lower()) + + +def _query_terms(query: str, min_token_len: int) -> set: + return {token for token in _tokenize_text(query) if len(token) >= min_token_len} + + +def _keyword_overlap_score(query_terms: set, content: str) -> float: + if not query_terms: + return 0.0 + content_terms = set(_tokenize_text(content)) + overlap = query_terms.intersection(content_terms) + return len(overlap) / len(query_terms) + + +def _metadata_score(query_terms: set, file_path: str, citation_url: str) -> float: + if not query_terms: + return 0.0 + metadata_terms = set(_tokenize_text(file_path)) | set(_tokenize_text(citation_url)) + overlap = query_terms.intersection(metadata_terms) + return len(overlap) / len(query_terms) + + +def _extract_similarity(doc: Dict[str, Any]) -> float: + if doc.get("similarity") is not None: + return float(doc.get("similarity", 0.0)) + if doc.get("distance") is not None: + return 1.0 - float(doc.get("distance", 0.0)) + return 0.0 + + +def _log_docs( + logger: Optional[logging.Logger], + enabled: bool, + stage: str, + docs: List[Dict[str, Any]], + top_n: int, + log_prefix: str, +) -> None: + if not logger or not 
enabled: + return + + logger.info("[%s] %s top %s documents", log_prefix, stage, min(top_n, len(docs))) + for idx, doc in enumerate(docs[:top_n], start=1): + logger.info( + "[%s] %s #%s path=%s similarity=%.4f rerank=%.4f keyword=%.4f metadata=%.4f", + log_prefix, + stage, + idx, + doc.get("file_path", ""), + float(doc.get("similarity", 0.0)), + float(doc.get("rerank_score", 0.0)), + float(doc.get("keyword_score", 0.0)), + float(doc.get("metadata_score", 0.0)), + ) + + +def rerank_documents( + query: str, + docs: List[Dict[str, Any]], + config: RerankConfig, + top_k: int, + logger: Optional[logging.Logger] = None, + log_prefix: str = "retrieval", +) -> List[Dict[str, Any]]: + try: + requested_top_k = max(1, int(top_k)) + except (TypeError, ValueError): + LOGGER.warning("Invalid top_k=%s; using fallback top_k=1", top_k) + requested_top_k = 1 + if not docs: + return [] + + query_terms = _query_terms(query, config.min_token_len) + + normalized_docs: List[Dict[str, Any]] = [] + for doc in docs: + normalized = dict(doc) + normalized["similarity"] = _extract_similarity(normalized) + normalized_docs.append(normalized) + + _log_docs( + logger, + config.debug_logging, + "before_rerank", + normalized_docs, + config.log_top_n, + log_prefix, + ) + + if not config.enabled: + selected = normalized_docs[:requested_top_k] + for doc in selected: + doc["keyword_score"] = 0.0 + doc["metadata_score"] = 0.0 + doc["rerank_score"] = round(float(doc.get("similarity", 0.0)), 4) + _log_docs( + logger, + config.debug_logging, + "after_rerank_disabled", + selected, + config.log_top_n, + log_prefix, + ) + return selected + + reranked: List[Dict[str, Any]] = [] + for doc in normalized_docs: + keyword_score = _keyword_overlap_score(query_terms, doc.get("content_text", "")) + metadata_score = _metadata_score(query_terms, doc.get("file_path", ""), doc.get("citation_url", "")) + final_score = ( + config.similarity_weight * float(doc.get("similarity", 0.0)) + + config.keyword_weight * keyword_score + + config.metadata_weight * metadata_score + ) + + doc["keyword_score"] = round(keyword_score, 4) + doc["metadata_score"] = round(metadata_score, 4) + doc["rerank_score"] = round(final_score, 4) + reranked.append(doc) + + # Sort by score and stable tie-breakers for deterministic ordering in tests and production. 
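+    # list.sort is stable and the negated floats sort descending, so a higher
+    # rerank_score wins, similarity breaks score ties, and the string fields
+    # make exact ties deterministic.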
+ reranked.sort( + key=lambda item: ( + -float(item.get("rerank_score", 0.0)), + -float(item.get("similarity", 0.0)), + str(item.get("file_path", "")), + str(item.get("citation_url", "")), + str(item.get("content_text", "")), + ) + ) + selected = reranked[:requested_top_k] + + _log_docs( + logger, + config.debug_logging, + "after_rerank", + selected, + config.log_top_n, + log_prefix, + ) + + return selected diff --git a/shared/tests/test_reranking.py b/shared/tests/test_reranking.py new file mode 100644 index 0000000..19dbe7d --- /dev/null +++ b/shared/tests/test_reranking.py @@ -0,0 +1,119 @@ +import logging +import sys +from pathlib import Path + +REPO_ROOT = Path(__file__).resolve().parents[2] +if str(REPO_ROOT) not in sys.path: + sys.path.append(str(REPO_ROOT)) + +from shared.reranking import RerankConfig, candidate_pool_limit, load_rerank_config_from_env, rerank_documents + + +def test_candidate_pool_limit_expands_and_caps(): + config = RerankConfig(enabled=True, candidate_multiplier=3, max_candidates=10) + assert candidate_pool_limit(2, config) == 6 + assert candidate_pool_limit(4, config) == 10 + + +def test_candidate_pool_limit_handles_invalid_top_k(caplog): + config = RerankConfig(enabled=True, candidate_multiplier=3, max_candidates=10) + + with caplog.at_level(logging.WARNING): + limit = candidate_pool_limit("invalid", config) + + assert limit == 3 + assert "Invalid top_k=invalid" in caplog.text + + +def test_rerank_documents_returns_empty_for_empty_input(): + config = RerankConfig() + assert rerank_documents("kubeflow", [], config, top_k=5) == [] + + +def test_rerank_documents_scoring_and_ordering(): + config = RerankConfig( + enabled=True, + similarity_weight=0.7, + keyword_weight=0.2, + metadata_weight=0.1, + min_token_len=3, + ) + + docs = [ + { + "similarity": 0.9, + "file_path": "docs/kserve.md", + "citation_url": "https://kubeflow.org/docs/components/kserve", + "content_text": "kserve inference service gpu deployment", + }, + { + "similarity": 0.95, + "file_path": "docs/pipelines.md", + "citation_url": "https://kubeflow.org/docs/components/pipelines", + "content_text": "pipeline runs and scheduling", + }, + ] + + ranked = rerank_documents("kserve gpu", docs, config, top_k=2) + + assert len(ranked) == 2 + assert ranked[0]["file_path"] == "docs/kserve.md" + assert ranked[0]["rerank_score"] >= ranked[1]["rerank_score"] + + +def test_rerank_documents_deterministic_tie_breaking(): + config = RerankConfig(enabled=True, similarity_weight=1.0, keyword_weight=0.0, metadata_weight=0.0) + + docs = [ + { + "similarity": 0.5, + "file_path": "z.md", + "citation_url": "https://example.com/z", + "content_text": "same", + }, + { + "similarity": 0.5, + "file_path": "a.md", + "citation_url": "https://example.com/a", + "content_text": "same", + }, + ] + + ranked = rerank_documents("anything", docs, config, top_k=2) + + assert [doc["file_path"] for doc in ranked] == ["a.md", "z.md"] + + +def test_load_rerank_config_from_env_fallbacks(monkeypatch, caplog): + monkeypatch.setenv("RERANK_ENABLED", "not-a-bool") + monkeypatch.setenv("RERANK_CANDIDATE_MULTIPLIER", "nan") + monkeypatch.setenv("RERANK_SIMILARITY_WEIGHT", "bad") + monkeypatch.setenv("RERANK_KEYWORD_WEIGHT", "bad") + monkeypatch.setenv("RERANK_METADATA_WEIGHT", "bad") + monkeypatch.setenv("RERANK_MAX_CANDIDATES", "bad") + monkeypatch.setenv("RERANK_MIN_TOKEN_LEN", "bad") + monkeypatch.setenv("RERANK_DEBUG_LOG", "bad") + monkeypatch.setenv("RERANK_LOG_TOP_N", "bad") + + with caplog.at_level(logging.WARNING): + config = 
load_rerank_config_from_env() + + assert config == RerankConfig() + assert "Invalid boolean value" in caplog.text + assert "Invalid integer" in caplog.text + assert "Invalid float" in caplog.text + + +def test_rerank_disabled_preserves_similarity_order(): + config = RerankConfig(enabled=False) + + docs = [ + {"similarity": 0.8, "file_path": "one.md", "citation_url": "", "content_text": ""}, + {"similarity": 0.7, "file_path": "two.md", "citation_url": "", "content_text": ""}, + ] + + ranked = rerank_documents("kubeflow", docs, config, top_k=2) + + assert [doc["file_path"] for doc in ranked] == ["one.md", "two.md"] + assert ranked[0]["keyword_score"] == 0.0 + assert ranked[0]["metadata_score"] == 0.0
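+
+
+# A minimal end-to-end sketch (it assumes nothing beyond the helpers imported
+# above) showing how a caller is expected to wire candidate_pool_limit and
+# rerank_documents together.
+def test_candidate_pool_then_rerank_roundtrip():
+    config = RerankConfig(enabled=True, candidate_multiplier=2, max_candidates=4)
+
+    # top_k=2 expands to 2 * 2 = 4 candidates, capped by max_candidates.
+    limit = candidate_pool_limit(2, config)
+    assert limit == 4
+
+    docs = [
+        {
+            "similarity": 0.9 - i * 0.1,
+            "file_path": f"doc{i}.md",
+            "citation_url": "",
+            "content_text": "kubeflow pipelines",
+        }
+        for i in range(limit)
+    ]
+
+    ranked = rerank_documents("kubeflow pipelines", docs, config, top_k=2)
+
+    # Only top_k documents survive, each annotated with component scores,
+    # and the highest-similarity candidates win when keyword/metadata tie.
+    assert len(ranked) == 2
+    assert ranked[0]["similarity"] >= ranked[1]["similarity"]
+    assert all("rerank_score" in doc for doc in ranked)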