diff --git a/README.md b/README.md
index b387d19..6713c75 100644
--- a/README.md
+++ b/README.md
@@ -584,6 +584,51 @@ if data.get('citations'):
sentence-transformers/all-mpnet-base-v2 |
Embedding model |
+| RERANK_ENABLED | true | Enable lightweight post-retrieval reranking |
+| RERANK_CANDIDATE_MULTIPLIER | 3 | Multiplier for initial candidate pool before reranking |
+| RERANK_SIMILARITY_WEIGHT | 0.7 | Weight for vector similarity score in final rerank score |
+| RERANK_KEYWORD_WEIGHT | 0.2 | Weight for query term overlap with chunk content |
+| RERANK_METADATA_WEIGHT | 0.1 | Weight for query term overlap with file path and citation URL |
+| RERANK_MAX_CANDIDATES | 50 | Upper bound for candidate pool size fetched before reranking |
+| RERANK_MIN_TOKEN_LEN | 3 | Minimum token length used for query/content term overlap scoring |
+| RERANK_DEBUG_LOG | false | Enable before/after reranking logs with component scores |
+| RERANK_LOG_TOP_N | 5 | Number of documents to show in reranking debug logs |
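+
+For example, to bias the blend toward keyword overlap and log rerank decisions
+(sample values; tune them for your deployment):
+
+```bash
+export RERANK_SIMILARITY_WEIGHT=0.6
+export RERANK_KEYWORD_WEIGHT=0.3
+export RERANK_METADATA_WEIGHT=0.1
+export RERANK_DEBUG_LOG=true
+```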
diff --git a/eval_retrieval.py b/eval_retrieval.py
new file mode 100644
index 0000000..fc2d85e
--- /dev/null
+++ b/eval_retrieval.py
@@ -0,0 +1,140 @@
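+"""Compare retrieval results before and after lightweight reranking.
+
+Example invocation (assumes a reachable Milvus instance and the MILVUS_*/RERANK_*
+environment variables documented in the README):
+
+    python eval_retrieval.py --queries "How do I deploy KServe?" --top-k 5
+"""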
+import argparse
+import os
+from typing import Any, Dict, List
+
+from pymilvus import Collection, connections
+from sentence_transformers import SentenceTransformer
+
+from shared.reranking import candidate_pool_limit, load_rerank_config_from_env, rerank_documents
+
+
+DEFAULT_QUERIES = [
+ "How do I create a Kubeflow Pipeline?",
+ "How to deploy an InferenceService in KServe?",
+ "Kubeflow Notebook setup requirements",
+]
+
+
+def build_args() -> argparse.Namespace:
+ parser = argparse.ArgumentParser(description="Evaluate retrieval results before/after reranking.")
+ parser.add_argument("--queries", nargs="*", default=DEFAULT_QUERIES, help="Queries to evaluate")
+ parser.add_argument("--top-k", type=int, default=5, help="Final top-k results to keep")
+ parser.add_argument(
+ "--show-content-chars",
+ type=int,
+ default=180,
+ help="Number of content characters to print per result",
+ )
+ return parser.parse_args()
+
+
+def _print_docs(title: str, docs: List[Dict[str, Any]], show_content_chars: int) -> None:
+ print(f"\n{title}")
+ if not docs:
+ print(" (no results)")
+ return
+
+ for idx, doc in enumerate(docs, start=1):
+ content = (doc.get("content_text") or "").replace("\n", " ").strip()
+ if len(content) > show_content_chars:
+ content = content[:show_content_chars] + "..."
+
+ print(
+ f" {idx}. score={doc.get('rerank_score', doc.get('similarity', 0.0)):.4f} "
+ f"sim={doc.get('similarity', 0.0):.4f} "
+ f"keyword={doc.get('keyword_score', 0.0):.4f} "
+ f"metadata={doc.get('metadata_score', 0.0):.4f}"
+ )
+ print(f" file={doc.get('file_path', '')}")
+ print(f" url={doc.get('citation_url', '')}")
+ print(f" text={content}")
+
+
+def retrieve_candidates(
+ query: str,
+ model: SentenceTransformer,
+ collection: Collection,
+ candidate_limit: int,
+ vector_field: str,
+) -> List[Dict[str, Any]]:
+ query_vec = model.encode(query).tolist()
+ search_params = {"metric_type": "COSINE", "params": {"nprobe": 32}}
+
+ results = collection.search(
+ data=[query_vec],
+ anns_field=vector_field,
+ param=search_params,
+ limit=candidate_limit,
+ output_fields=["file_path", "content_text", "citation_url"],
+ )
+
+ docs: List[Dict[str, Any]] = []
+ for hit in results[0]:
+ entity = hit.entity
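+        # Milvus returns COSINE distance; similarity = 1 - distance.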
+ docs.append(
+ {
+ "similarity": 1.0 - float(hit.distance),
+ "file_path": entity.get("file_path"),
+ "citation_url": entity.get("citation_url"),
+ "content_text": entity.get("content_text") or "",
+ }
+ )
+
+ return docs
+
+
+def main() -> None:
+ args = build_args()
+
+ milvus_host = os.getenv("MILVUS_HOST", "my-release-milvus.docs-agent.svc.cluster.local")
+ milvus_port = os.getenv("MILVUS_PORT", "19530")
+ milvus_collection = os.getenv("MILVUS_COLLECTION", "docs_rag")
+ milvus_vector_field = os.getenv("MILVUS_VECTOR_FIELD", "vector")
+ embedding_model_name = os.getenv("EMBEDDING_MODEL", "sentence-transformers/all-mpnet-base-v2")
+
+ rerank_config = load_rerank_config_from_env()
+ requested_top_k = max(1, int(args.top_k))
+ candidate_limit = candidate_pool_limit(requested_top_k, rerank_config)
+
+ print("Retrieval evaluation configuration")
+ print(f"- collection: {milvus_collection}")
+ print(f"- top_k: {requested_top_k}")
+ print(f"- candidate_limit: {candidate_limit}")
+ print(f"- rerank_enabled: {rerank_config.enabled}")
+
+ connections.connect(alias="default", host=milvus_host, port=milvus_port)
+ try:
+ collection = Collection(milvus_collection)
+ collection.load()
+ model = SentenceTransformer(embedding_model_name)
+
+ for query in args.queries:
+ print("\n" + "=" * 100)
+ print(f"Query: {query}")
+
+ candidates = retrieve_candidates(
+ query=query,
+ model=model,
+ collection=collection,
+ candidate_limit=candidate_limit,
+ vector_field=milvus_vector_field,
+ )
+
+ before_docs = candidates[:requested_top_k]
+ after_docs = rerank_documents(
+ query=query,
+ docs=candidates,
+ config=rerank_config,
+ top_k=requested_top_k,
+ )
+
+ _print_docs("Before reranking", before_docs, args.show_content_chars)
+ _print_docs("After reranking", after_docs, args.show_content_chars)
+ finally:
+ connections.disconnect(alias="default")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/kagent-feast-mcp/mcp-server/server.py b/kagent-feast-mcp/mcp-server/server.py
index 44a9fcb..f286c7c 100644
--- a/kagent-feast-mcp/mcp-server/server.py
+++ b/kagent-feast-mcp/mcp-server/server.py
@@ -1,26 +1,67 @@
from fastmcp import FastMCP
+import logging
+import os
+import sys
+import threading
+from pathlib import Path
+from typing import Any, Dict, List
from pymilvus import MilvusClient
from sentence_transformers import SentenceTransformer
-MILVUS_URI = "http://milvus..svc.cluster.local:19530"
-MILVUS_USER = "root"
-MILVUS_PASSWORD = "Milvus"
-COLLECTION_NAME = "kubeflow_docs_docs_rag"
-EMBEDDING_MODEL = "sentence-transformers/all-mpnet-base-v2"
-PORT = 8000
+REPO_ROOT = Path(__file__).resolve().parents[2]
+if str(REPO_ROOT) not in sys.path:
+ sys.path.append(str(REPO_ROOT))
+
+from shared.reranking import candidate_pool_limit, load_rerank_config_from_env, rerank_documents
+
+MILVUS_URI = os.getenv("MILVUS_URI", "http://milvus..svc.cluster.local:19530")
+MILVUS_USER = os.getenv("MILVUS_USER", "root")
+MILVUS_PASSWORD = os.getenv("MILVUS_PASSWORD", "Milvus")
+COLLECTION_NAME = os.getenv("COLLECTION_NAME", "kubeflow_docs_docs_rag")
+EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL", "sentence-transformers/all-mpnet-base-v2")
+PORT = int(os.getenv("PORT", "8000"))
+RERANK_CONFIG = load_rerank_config_from_env()
mcp = FastMCP("Kubeflow Docs MCP Server")
+logger = logging.getLogger(__name__)
+logging.basicConfig(level=os.getenv("LOG_LEVEL", "INFO").upper())
model: SentenceTransformer = None
client: MilvusClient = None
+_initialized = False
+_init_lock = threading.Lock()
def _init():
- global model, client
- if model is None:
- model = SentenceTransformer(EMBEDDING_MODEL)
- if client is None:
- client = MilvusClient(uri=MILVUS_URI, user=MILVUS_USER, password=MILVUS_PASSWORD)
+ """Initialize shared model/client exactly once.
+
+ Synchronization strategy:
+ - Fast path: return immediately after successful initialization.
+ - Slow path: take a process-local lock and re-check state (double-checked locking).
+
+ This guarantees that concurrent callers block until initialization completes,
+ and all callers observe the same initialized instances.
+ """
+ global model, client, _initialized
+
+ if _initialized:
+ return
+
+ with _init_lock:
+ if _initialized:
+ return
+
+ logger.info("Initializing shared MCP resources")
+
+ # Build local instances first, then publish atomically under the lock.
+ local_model = SentenceTransformer(EMBEDDING_MODEL)
+ local_client = MilvusClient(uri=MILVUS_URI, user=MILVUS_USER, password=MILVUS_PASSWORD)
+
+ model = local_model
+ client = local_client
+ _initialized = True
+
+ logger.info("Shared MCP resources initialized")
@mcp.tool()
@@ -37,24 +78,54 @@ def search_kubeflow_docs(query: str, top_k: int = 5) -> str:
_init()
embedding = model.encode(query).tolist()
+ requested_top_k = max(1, int(top_k))
+ candidate_limit = candidate_pool_limit(requested_top_k, RERANK_CONFIG)
hits = client.search(
collection_name=COLLECTION_NAME,
data=[embedding],
- limit=top_k,
+ limit=candidate_limit,
output_fields=["content_text", "citation_url", "file_path"],
)[0]
if not hits:
return "No results found for your query."
+ docs: List[Dict[str, Any]] = []
+ for hit in hits:
+ entity = hit.get("entity", {})
+ content_text = entity.get("content_text") or ""
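+        # Truncate long chunks so the formatted tool response stays compact.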
+ if isinstance(content_text, str) and len(content_text) > 400:
+ content_text = content_text[:400] + "..."
+
+ docs.append(
+ {
+ "distance": float(hit.get("distance", 0.0)),
+ "similarity": 1.0 - float(hit.get("distance", 0.0)),
+ "file_path": entity.get("file_path"),
+ "citation_url": entity.get("citation_url"),
+ "content_text": content_text,
+ }
+ )
+
+ selected_docs = rerank_documents(
+ query=query,
+ docs=docs,
+ config=RERANK_CONFIG,
+ top_k=requested_top_k,
+ logger=logger,
+ log_prefix="mcp_search",
+ )
+
results = []
- for i, hit in enumerate(hits, 1):
- entity = hit["entity"]
- entry = f"### Result {i} (score: {hit['distance']:.4f})"
- entry += f"\n**Source:** {entity.get('citation_url', '')}"
- entry += f"\n**File:** {entity.get('file_path', '')}"
- entry += f"\n\n{entity.get('content_text', '')}\n"
+ for i, doc in enumerate(selected_docs, 1):
+ entry = f"### Result {i} (rerank_score: {doc.get('rerank_score', 0.0):.4f})"
+ entry += f"\n**Similarity:** {doc.get('similarity', 0.0):.4f}"
+ entry += f"\n**Keyword Score:** {doc.get('keyword_score', 0.0):.4f}"
+ entry += f"\n**Metadata Score:** {doc.get('metadata_score', 0.0):.4f}"
+ entry += f"\n**Source:** {doc.get('citation_url', '')}"
+ entry += f"\n**File:** {doc.get('file_path', '')}"
+ entry += f"\n\n{doc.get('content_text', '')}\n"
results.append(entry)
return "\n---\n".join(results)
diff --git a/kagent-feast-mcp/mcp-server/tests/test_init_concurrency.py b/kagent-feast-mcp/mcp-server/tests/test_init_concurrency.py
new file mode 100644
index 0000000..e1f2aef
--- /dev/null
+++ b/kagent-feast-mcp/mcp-server/tests/test_init_concurrency.py
@@ -0,0 +1,129 @@
+import importlib.util
+import logging
+import sys
+import threading
+import types
+from pathlib import Path
+
+
+class _FastMCPStub:
+ def __init__(self, _name: str):
+ self._name = _name
+
+ def tool(self):
+ def _decorator(func):
+ return func
+
+ return _decorator
+
+ def run(self, **_kwargs):
+ return None
+
+
+def _load_server_module():
+ module_name = "mcp_server_under_test"
+ module_path = Path(__file__).resolve().parents[1] / "server.py"
+
+ fastmcp_stub = types.ModuleType("fastmcp")
+ fastmcp_stub.FastMCP = _FastMCPStub
+
+ pymilvus_stub = types.ModuleType("pymilvus")
+
+ class _MilvusClientStub:
+ def __init__(self, **_kwargs):
+ pass
+
+ pymilvus_stub.MilvusClient = _MilvusClientStub
+
+ sentence_stub = types.ModuleType("sentence_transformers")
+
+ class _SentenceTransformerStub:
+ def __init__(self, *_args, **_kwargs):
+ pass
+
+ sentence_stub.SentenceTransformer = _SentenceTransformerStub
+
+ sys.modules["fastmcp"] = fastmcp_stub
+ sys.modules["pymilvus"] = pymilvus_stub
+ sys.modules["sentence_transformers"] = sentence_stub
+
+    spec = importlib.util.spec_from_file_location(module_name, module_path)
+    assert spec is not None and spec.loader is not None
+    module = importlib.util.module_from_spec(spec)
+    spec.loader.exec_module(module)
+ return module
+
+
+def test_init_is_thread_safe_and_idempotent(monkeypatch, caplog):
+ server = _load_server_module()
+
+ model_init_count = 0
+ client_init_count = 0
+ init_count_lock = threading.Lock()
+
+ class FakeModel:
+ pass
+
+ class FakeClient:
+ pass
+
+ def fake_sentence_transformer(_model_name):
+ nonlocal model_init_count
+ with init_count_lock:
+ model_init_count += 1
+ return FakeModel()
+
+ def fake_milvus_client(**_kwargs):
+ nonlocal client_init_count
+ with init_count_lock:
+ client_init_count += 1
+ return FakeClient()
+
+ monkeypatch.setattr(server, "SentenceTransformer", fake_sentence_transformer)
+ monkeypatch.setattr(server, "MilvusClient", fake_milvus_client)
+
+ server.model = None
+ server.client = None
+ server._initialized = False
+
+ workers = 32
+ barrier = threading.Barrier(workers)
+ errors = []
+ seen_models = []
+ seen_clients = []
+ seen_lock = threading.Lock()
+
+ def worker():
+ try:
+ barrier.wait()
+ server._init()
+ with seen_lock:
+ seen_models.append(server.model)
+ seen_clients.append(server.client)
+ except Exception as exc: # pragma: no cover
+ with seen_lock:
+ errors.append(exc)
+
+ with caplog.at_level(logging.INFO, logger=server.__name__):
+ threads = [threading.Thread(target=worker) for _ in range(workers)]
+ for thread in threads:
+ thread.start()
+        for thread in threads:
+            thread.join(timeout=5)
+            assert not thread.is_alive(), "worker thread did not finish within timeout"
+
+ assert not errors
+ assert model_init_count == 1
+ assert client_init_count == 1
+ assert server._initialized is True
+
+ first_model = seen_models[0]
+ first_client = seen_clients[0]
+
+ assert first_model is not None
+ assert first_client is not None
+ assert all(model is first_model for model in seen_models)
+ assert all(client is first_client for client in seen_clients)
+
+ messages = [record.getMessage() for record in caplog.records]
+ assert messages.count("Initializing shared MCP resources") == 1
+ assert messages.count("Shared MCP resources initialized") == 1
diff --git a/server-https/app.py b/server-https/app.py
index 694af8a..0c6b118 100644
--- a/server-https/app.py
+++ b/server-https/app.py
@@ -1,14 +1,46 @@
import os
import json
import httpx
+import logging
from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import uvicorn
+import sys
+from pathlib import Path
from typing import Dict, Any, List, Optional, AsyncGenerator
from sentence_transformers import SentenceTransformer
from pymilvus import connections, Collection
+import threading
+import time
+
+REPO_ROOT = Path(__file__).resolve().parents[1]
+if str(REPO_ROOT) not in sys.path:
+ sys.path.append(str(REPO_ROOT))
+
+from shared.reranking import candidate_pool_limit, load_rerank_config_from_env, rerank_documents
+
+_model_lock = threading.Lock()
+_embedding_model = None
+
+
+def get_embedding_model():
+ """Thread-safe lazy initialization of SentenceTransformer.
+
+ The model is cached once per process because loading it on every request
+ adds avoidable latency and memory churn.
+ """
+ global _embedding_model
+ if _embedding_model is None:
+ with _model_lock:
+ # Double-checked locking
+ if _embedding_model is None:
+ start_t = time.perf_counter()
+ logger.info("Lazy loading SentenceTransformer model '%s'", EMBEDDING_MODEL)
+ _embedding_model = SentenceTransformer(EMBEDDING_MODEL)
+ logger.info("Model loaded in %.3f seconds", time.perf_counter() - start_t)
+ return _embedding_model
# Config
KSERVE_URL = os.getenv("KSERVE_URL", "http://llama.docs-agent.svc.cluster.local/openai/v1/chat/completions")
@@ -22,6 +54,9 @@
MILVUS_VECTOR_FIELD = os.getenv("MILVUS_VECTOR_FIELD", "vector")
EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL", "sentence-transformers/all-mpnet-base-v2")
+logger = logging.getLogger(__name__)
+RERANK_CONFIG = load_rerank_config_from_env()
+
# System prompt (same as WebSocket version)
SYSTEM_PROMPT = """
You are the Kubeflow Docs Assistant.
@@ -113,6 +148,7 @@ class ChatRequest(BaseModel):
def milvus_search(query: str, top_k: int = 5) -> Dict[str, Any]:
"""Execute a semantic search in Milvus and return structured JSON serializable results."""
+
try:
# Connect to Milvus
connections.connect(alias="default", host=MILVUS_HOST, port=MILVUS_PORT)
@@ -120,15 +156,18 @@ def milvus_search(query: str, top_k: int = 5) -> Dict[str, Any]:
collection.load()
# Encoder (same model as pipeline)
- encoder = SentenceTransformer(EMBEDDING_MODEL)
- query_vec = encoder.encode(query).tolist()
+ model = get_embedding_model()
+ query_vec = model.encode(query).tolist()
+
+ requested_top_k = max(1, int(top_k))
+ candidate_limit = candidate_pool_limit(requested_top_k, RERANK_CONFIG)
search_params = {"metric_type": "COSINE", "params": {"nprobe": 32}}
results = collection.search(
data=[query_vec],
anns_field=MILVUS_VECTOR_FIELD,
param=search_params,
- limit=int(top_k),
+ limit=candidate_limit,
output_fields=["file_path", "content_text", "citation_url"],
)
@@ -146,15 +185,25 @@ def milvus_search(query: str, top_k: int = 5) -> Dict[str, Any]:
"citation_url": entity.get("citation_url"),
"content_text": content_text,
})
+
+ hits = rerank_documents(
+ query=query,
+ docs=hits,
+ config=RERANK_CONFIG,
+ top_k=requested_top_k,
+ logger=logger,
+ log_prefix="https_search",
+ )
+
return {"results": hits}
- except Exception as e:
- print(f"[ERROR] Milvus search failed: {e}")
- return {"results": []}
+ except Exception:
+ logger.exception("Milvus search failed for query='%s' top_k=%s", query, top_k)
+ raise
finally:
try:
connections.disconnect(alias="default")
- except Exception:
- pass
+ except Exception as disconnect_error:
+ logger.warning("Milvus disconnect failed: %s", disconnect_error)
async def execute_tool(tool_call: Dict[str, Any]) -> tuple[str, List[str]]:
"""Execute a tool call and return the result and citations"""
@@ -166,7 +215,7 @@ async def execute_tool(tool_call: Dict[str, Any]) -> tuple[str, List[str]]:
query = arguments.get("query", "")
top_k = arguments.get("top_k", 5)
- print(f"[TOOL] Executing Milvus search for: '{query}' (top_k={top_k})")
+ logger.info("Executing Milvus search for query='%s' top_k=%s", query, top_k)
result = milvus_search(query, top_k)
# Collect citations
@@ -191,19 +240,21 @@ async def execute_tool(tool_call: Dict[str, Any]) -> tuple[str, List[str]]:
return f"Unknown tool: {function_name}", []
except Exception as e:
- print(f"[ERROR] Tool execution failed: {e}")
+ logger.exception("Tool execution failed")
return f"Tool execution failed: {e}", []
-async def stream_llm_response(payload: Dict[str, Any]) -> AsyncGenerator[str, None]:
+async def stream_llm_response(payload: Dict[str, Any], citations_collector: Optional[List[str]] = None) -> AsyncGenerator[str, None]:
"""Stream response from LLM and handle tool calls, yielding SSE events"""
- citations_collector = []
+ is_outermost = citations_collector is None
+ if citations_collector is None:
+ citations_collector = []
try:
async with httpx.AsyncClient(timeout=120) as client:
async with client.stream("POST", KSERVE_URL, json=payload) as response:
if response.status_code != 200:
error_msg = f"LLM service error: HTTP {response.status_code}"
- print(f"[ERROR] {error_msg}")
+ logger.error(error_msg)
yield f"data: {json.dumps({'type': 'error', 'content': error_msg})}\n\n"
return
@@ -262,14 +313,14 @@ async def stream_llm_response(payload: Dict[str, Any]) -> AsyncGenerator[str, No
# Handle finish reason - execute tools if needed
if finish_reason == "tool_calls":
- print(f"[TOOL] Finish reason: tool_calls, executing {len(tool_calls_buffer)} tools")
+ logger.info("Finish reason=tool_calls; executing %s tools", len(tool_calls_buffer))
# Execute all accumulated tool calls
for tool_call in tool_calls_buffer.values():
if tool_call["function"]["name"] and tool_call["function"]["arguments"]:
try:
- print(f"[TOOL] Executing: {tool_call['function']['name']}")
- print(f"[TOOL] Arguments: {tool_call['function']['arguments']}")
+ logger.info("Executing tool=%s", tool_call["function"]["name"])
+ logger.debug("Tool arguments=%s", tool_call["function"]["arguments"])
result, tool_citations = await execute_tool(tool_call)
@@ -284,18 +335,18 @@ async def stream_llm_response(payload: Dict[str, Any]) -> AsyncGenerator[str, No
yield follow_up_chunk
except Exception as e:
- print(f"[ERROR] Tool execution error: {e}")
+ logger.exception("Tool execution error")
yield f"data: {json.dumps({'type': 'error', 'content': f'Tool execution failed: {e}'})}\n\n"
tool_calls_buffer.clear()
break # Tool execution complete, exit streaming loop
except json.JSONDecodeError as e:
- print(f"[ERROR] JSON decode error: {e}, line: {line}")
+ logger.warning("JSON decode error: %s line=%s", e, line)
continue
# Send citations if any were collected
- if citations_collector:
+ if is_outermost and citations_collector:
# Remove duplicates while preserving order
unique_citations = []
for citation in citations_collector:
@@ -305,16 +356,17 @@ async def stream_llm_response(payload: Dict[str, Any]) -> AsyncGenerator[str, No
yield f"data: {json.dumps({'type': 'citations', 'citations': unique_citations})}\n\n"
# Send completion signal
- yield f"data: {json.dumps({'type': 'done'})}\n\n"
+ if is_outermost:
+ yield f"data: {json.dumps({'type': 'done'})}\n\n"
except Exception as e:
- print(f"[ERROR] Streaming failed: {e}")
+ logger.exception("Streaming failed")
yield f"data: {json.dumps({'type': 'error', 'content': f'Streaming failed: {e}'})}\n\n"
async def handle_tool_follow_up(original_payload: Dict[str, Any], tool_call: Dict[str, Any], tool_result: str, citations_collector: List[str]) -> AsyncGenerator[str, None]:
"""Handle follow-up request after tool execution"""
try:
- print("[TOOL] Handling follow-up request with tool results")
+ logger.info("Handling follow-up request with tool results")
# Create messages with tool call and result
messages = original_payload["messages"].copy()
@@ -341,11 +393,11 @@ async def handle_tool_follow_up(original_payload: Dict[str, Any], tool_call: Dic
}
# Stream the follow-up response
- async for chunk in stream_llm_response(follow_up_payload):
+ async for chunk in stream_llm_response(follow_up_payload, citations_collector=citations_collector):
yield chunk
except Exception as e:
- print(f"[ERROR] Tool follow-up failed: {e}")
+ logger.exception("Tool follow-up failed")
yield f"data: {json.dumps({'type': 'error', 'content': f'Tool follow-up failed: {e}'})}\n\n"
async def get_non_streaming_response(payload: Dict[str, Any]) -> tuple[str, List[str]]:
@@ -397,7 +449,7 @@ async def options_health():
async def chat(request: ChatRequest):
"""Chat endpoint with RAG capabilities - supports both streaming and non-streaming"""
try:
- print(f"[CHAT] Processing message: {request.message[:100]}...")
+ logger.info("Processing chat message preview=%s", request.message[:100])
# Create initial payload
payload = {
@@ -440,15 +492,19 @@ async def chat(request: ChatRequest):
}
except Exception as e:
- print(f"[ERROR] Chat handling failed: {e}")
+ logger.exception("Chat handling failed")
raise HTTPException(status_code=500, detail=f"Request failed: {e}")
if __name__ == "__main__":
- print("🚀 Starting Kubeflow Docs HTTP API Server")
- print(f" Port: {PORT}")
- print(f" LLM Service: {KSERVE_URL}")
- print(f" Milvus: {MILVUS_HOST}:{MILVUS_PORT}")
- print(f" Collection: {MILVUS_COLLECTION}")
+ logging.basicConfig(
+ level=os.getenv("LOG_LEVEL", "INFO").upper(),
+ format="%(asctime)s %(levelname)s %(name)s - %(message)s",
+ )
+ logger.info("Starting Kubeflow Docs HTTP API Server")
+ logger.info("Port: %s", PORT)
+ logger.info("LLM Service: %s", KSERVE_URL)
+ logger.info("Milvus: %s:%s", MILVUS_HOST, MILVUS_PORT)
+ logger.info("Collection: %s", MILVUS_COLLECTION)
uvicorn.run(
app,
diff --git a/server/app.py b/server/app.py
index 96b277c..d63b697 100644
--- a/server/app.py
+++ b/server/app.py
@@ -6,9 +6,19 @@
from websockets.server import serve
from websockets.exceptions import ConnectionClosedError
import logging
+import sys
+from pathlib import Path
from typing import Dict, Any, List
from sentence_transformers import SentenceTransformer
from pymilvus import connections, Collection
+import threading
+import time
+
+REPO_ROOT = Path(__file__).resolve().parents[1]
+if str(REPO_ROOT) not in sys.path:
+ sys.path.append(str(REPO_ROOT))
+
+from shared.reranking import candidate_pool_limit, load_rerank_config_from_env, rerank_documents
# Config
KSERVE_URL = os.getenv("KSERVE_URL", "http://llama.docs-agent.svc.cluster.local/openai/v1/chat/completions")
@@ -22,6 +32,9 @@
MILVUS_VECTOR_FIELD = os.getenv("MILVUS_VECTOR_FIELD", "vector")
EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL", "sentence-transformers/all-mpnet-base-v2")
+logger = logging.getLogger(__name__)
+RERANK_CONFIG = load_rerank_config_from_env()
+
# System prompt
SYSTEM_PROMPT = """
You are the Kubeflow Docs Assistant.
@@ -65,6 +78,27 @@
+
+_model_lock = threading.Lock()
+_embedding_model = None
+
+def get_embedding_model():
+ """Thread-safe lazy initialization of SentenceTransformer.
+
+ The model is cached once per process because loading it on every request
+ adds avoidable latency and memory churn.
+ """
+ global _embedding_model
+ if _embedding_model is None:
+ with _model_lock:
+ # Double-checked locking
+ if _embedding_model is None:
+ start_t = time.perf_counter()
+ logger.info("Lazy loading SentenceTransformer model '%s'", EMBEDDING_MODEL)
+ _embedding_model = SentenceTransformer(EMBEDDING_MODEL)
+ logger.info("Model loaded in %.3f seconds", time.perf_counter() - start_t)
+ return _embedding_model
+
def milvus_search(query: str, top_k: int = 5) -> Dict[str, Any]:
"""Execute a semantic search in Milvus and return structured JSON serializable results."""
try:
@@ -73,22 +107,24 @@ def milvus_search(query: str, top_k: int = 5) -> Dict[str, Any]:
collection = Collection(MILVUS_COLLECTION)
collection.load()
- # Encoder (same model as pipeline)
- encoder = SentenceTransformer(EMBEDDING_MODEL)
- query_vec = encoder.encode(query).tolist()
+ # Thread-safe cached encoder
+ model = get_embedding_model()
+ query_vec = model.encode(query).tolist()
+
+ requested_top_k = max(1, int(top_k))
+ candidate_limit = candidate_pool_limit(requested_top_k, RERANK_CONFIG)
search_params = {"metric_type": "COSINE", "params": {"nprobe": 32}}
results = collection.search(
data=[query_vec],
anns_field=MILVUS_VECTOR_FIELD,
param=search_params,
- limit=int(top_k),
+ limit=candidate_limit,
output_fields=["file_path", "content_text", "citation_url"],
)
hits = []
for hit in results[0]:
- # similarity = 1 - distance for COSINE in Milvus
similarity = 1.0 - float(hit.distance)
entity = hit.entity
content_text = entity.get("content_text") or ""
@@ -100,15 +136,26 @@ def milvus_search(query: str, top_k: int = 5) -> Dict[str, Any]:
"citation_url": entity.get("citation_url"),
"content_text": content_text,
})
+
+ hits = rerank_documents(
+ query=query,
+ docs=hits,
+ config=RERANK_CONFIG,
+ top_k=requested_top_k,
+ logger=logger,
+ log_prefix="websocket_search",
+ )
+
return {"results": hits}
- except Exception as e:
- print(f"[ERROR] Milvus search failed: {e}")
- return {"results": []}
+ except Exception:
+ logger.exception("Milvus search failed for query='%s' top_k=%s", query, top_k)
+ raise
finally:
try:
connections.disconnect(alias="default")
- except Exception:
- pass
+ except Exception as disconnect_error:
+ logger.warning("Milvus disconnect failed: %s", disconnect_error)
+
TOOLS = [
{
"type": "function",
@@ -153,7 +200,7 @@ async def execute_tool(tool_call: Dict[str, Any]) -> tuple[str, List[str]]:
query = arguments.get("query", "")
top_k = arguments.get("top_k", 5)
- print(f"[TOOL] Executing Milvus search for: '{query}' (top_k={top_k})")
+ logger.info("Executing Milvus search for query='%s' top_k=%s", query, top_k)
result = milvus_search(query, top_k)
# Collect citations
@@ -178,7 +225,7 @@ async def execute_tool(tool_call: Dict[str, Any]) -> tuple[str, List[str]]:
return f"Unknown tool: {function_name}", []
except Exception as e:
- print(f"[ERROR] Tool execution failed: {e}")
+ logger.exception("Tool execution failed")
return f"Tool execution failed: {e}", []
async def stream_llm_response(payload: Dict[str, Any], websocket, citations_collector: List[str] = None) -> None:
@@ -190,7 +237,7 @@ async def stream_llm_response(payload: Dict[str, Any], websocket, citations_coll
async with client.stream("POST", KSERVE_URL, json=payload) as response:
if response.status_code != 200:
error_msg = f"LLM service error: HTTP {response.status_code}"
- print(f"[ERROR] {error_msg}")
+ logger.error(error_msg)
await websocket.send(json.dumps({"type": "error", "content": error_msg}))
return
@@ -252,14 +299,14 @@ async def stream_llm_response(payload: Dict[str, Any], websocket, citations_coll
# Handle finish reason - execute tools if needed
if finish_reason == "tool_calls":
- print(f"[TOOL] Finish reason: tool_calls, executing {len(tool_calls_buffer)} tools")
+ logger.info("Finish reason=tool_calls; executing %s tools", len(tool_calls_buffer))
# Execute all accumulated tool calls
for tool_call in tool_calls_buffer.values():
if tool_call["function"]["name"] and tool_call["function"]["arguments"]:
try:
- print(f"[TOOL] Executing: {tool_call['function']['name']}")
- print(f"[TOOL] Arguments: {tool_call['function']['arguments']}")
+ logger.info("Executing tool=%s", tool_call["function"]["name"])
+ logger.debug("Tool arguments=%s", tool_call["function"]["arguments"])
result, tool_citations = await execute_tool(tool_call)
@@ -277,7 +324,7 @@ async def stream_llm_response(payload: Dict[str, Any], websocket, citations_coll
await handle_tool_follow_up(payload, tool_call, result, websocket, citations_collector)
except Exception as e:
- print(f"[ERROR] Tool execution error: {e}")
+ logger.exception("Tool execution error")
await websocket.send(json.dumps({
"type": "error",
"content": f"Tool execution failed: {e}"
@@ -287,11 +334,11 @@ async def stream_llm_response(payload: Dict[str, Any], websocket, citations_coll
break # Tool execution complete, exit streaming loop
except json.JSONDecodeError as e:
- print(f"[ERROR] JSON decode error: {e}, line: {line}")
+ logger.warning("JSON decode error: %s line=%s", e, line)
continue
except Exception as e:
- print(f"[ERROR] Streaming failed: {e}")
+ logger.exception("Streaming failed")
await websocket.send(json.dumps({"type": "error", "content": f"Streaming failed: {e}"}))
async def handle_tool_follow_up(original_payload: Dict[str, Any], tool_call: Dict[str, Any], tool_result: str, websocket, citations_collector: List[str] = None) -> None:
@@ -299,7 +346,7 @@ async def handle_tool_follow_up(original_payload: Dict[str, Any], tool_call: Dic
if citations_collector is None:
citations_collector = []
try:
- print("[TOOL] Handling follow-up request with tool results")
+ logger.info("Handling follow-up request with tool results")
# Create messages with tool call and result
messages = original_payload["messages"].copy()
@@ -329,13 +376,13 @@ async def handle_tool_follow_up(original_payload: Dict[str, Any], tool_call: Dic
await stream_llm_response(follow_up_payload, websocket, citations_collector)
except Exception as e:
- print(f"[ERROR] Tool follow-up failed: {e}")
+ logger.exception("Tool follow-up failed")
await websocket.send(json.dumps({"type": "error", "content": f"Tool follow-up failed: {e}"}))
async def handle_chat(message: str, websocket) -> None:
"""Handle chat with tool calling support"""
try:
- print(f"[CHAT] Processing message: {message[:100]}...")
+ logger.info("Processing chat message preview=%s", message[:100])
# Create initial payload
payload = {
@@ -373,12 +420,12 @@ async def handle_chat(message: str, websocket) -> None:
await websocket.send(json.dumps({"type": "done"}))
except Exception as e:
- print(f"[ERROR] Chat handling failed: {e}")
+ logger.exception("Chat handling failed")
await websocket.send(json.dumps({"type": "error", "content": f"Request failed: {e}"}))
async def handle_websocket(websocket, path):
"""Handle WebSocket connections"""
- print(f"[WS] New connection from {websocket.remote_address}")
+ logger.info("New websocket connection from %s", websocket.remote_address)
try:
# Send welcome message
@@ -402,20 +449,20 @@ async def handle_websocket(websocket, path):
# Treat as plain text message
pass
- print(f"[WS] Received: {message[:100]}...")
+ logger.info("Received websocket message preview=%s", message[:100])
await handle_chat(message, websocket)
except Exception as e:
- print(f"[ERROR] Message processing error: {e}")
+ logger.exception("Message processing error")
await websocket.send(json.dumps({
"type": "error",
"content": f"Message processing failed: {e}"
}))
except ConnectionClosedError:
- print("[WS] Connection closed")
+ logger.info("Websocket connection closed")
except Exception as e:
- print(f"[ERROR] WebSocket error: {e}")
+ logger.exception("Websocket error")
async def health_check(path, request_headers):
"""Handle HTTP health checks"""
@@ -425,11 +472,11 @@ async def health_check(path, request_headers):
async def main():
"""Start the WebSocket server"""
- print("🚀 Starting Kubeflow Docs WebSocket Server")
- print(f" Port: {PORT}")
- print(f" LLM Service: {KSERVE_URL}")
- print(f" Milvus: {MILVUS_HOST}:{MILVUS_PORT}")
- print(f" Collection: {MILVUS_COLLECTION}")
+ logger.info("Starting Kubeflow Docs WebSocket Server")
+ logger.info("Port: %s", PORT)
+ logger.info("LLM Service: %s", KSERVE_URL)
+ logger.info("Milvus: %s:%s", MILVUS_HOST, MILVUS_PORT)
+ logger.info("Collection: %s", MILVUS_COLLECTION)
# Configure logging
logging.getLogger("websockets").setLevel(logging.WARNING)
@@ -443,12 +490,16 @@ async def main():
ping_interval=30,
ping_timeout=10
):
- print("✅ WebSocket server is running...")
- print(f" WebSocket: ws://localhost:{PORT}")
- print(f" Health: http://localhost:{PORT}/health")
+ logger.info("WebSocket server is running")
+ logger.info("WebSocket: ws://localhost:%s", PORT)
+ logger.info("Health: http://localhost:%s/health", PORT)
# Keep server running
await asyncio.Future()
if __name__ == "__main__":
+ logging.basicConfig(
+ level=os.getenv("LOG_LEVEL", "INFO").upper(),
+ format="%(asctime)s %(levelname)s %(name)s - %(message)s",
+ )
asyncio.run(main())
diff --git a/shared/__init__.py b/shared/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/shared/reranking.py b/shared/reranking.py
new file mode 100644
index 0000000..e051816
--- /dev/null
+++ b/shared/reranking.py
@@ -0,0 +1,238 @@
+import logging
+import os
+import re
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional
+
+
+LOGGER = logging.getLogger(__name__)
+
+
+@dataclass(frozen=True)
+class RerankConfig:
+ """Reranking defaults favor a lightweight, recall-first blend.
+
+ The similarity weight remains the primary signal, while keyword and metadata
+ weights stay small so reranking improves relevance without overpowering the
+ original vector score.
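+
+    With the default weights the blended score is effectively:
+        rerank_score = 0.7 * similarity + 0.2 * keyword_overlap + 0.1 * metadata_overlap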
+ """
+
+ enabled: bool = True
+ candidate_multiplier: int = 3
+ similarity_weight: float = 0.7
+ keyword_weight: float = 0.2
+ metadata_weight: float = 0.1
+ max_candidates: int = 50
+ min_token_len: int = 3
+ debug_logging: bool = False
+ log_top_n: int = 5
+
+
+def _parse_bool(value: Optional[str], default: bool) -> bool:
+ if value is None:
+ return default
+
+ normalized = value.strip().lower()
+ if normalized in {"1", "true", "yes", "on"}:
+ return True
+ if normalized in {"0", "false", "no", "off"}:
+ return False
+
+ LOGGER.warning("Invalid boolean value '%s'; using default=%s", value, default)
+ return default
+
+
+def _parse_int_env(name: str, default: int, minimum: int = 1) -> int:
+ raw = os.getenv(name)
+ if raw is None:
+ return default
+
+ try:
+ return max(minimum, int(raw))
+ except (TypeError, ValueError):
+ LOGGER.warning("Invalid integer for %s=%s; using default=%s", name, raw, default)
+ return default
+
+
+def _parse_float_env(name: str, default: float) -> float:
+ raw = os.getenv(name)
+ if raw is None:
+ return default
+
+ try:
+ return float(raw)
+ except (TypeError, ValueError):
+ LOGGER.warning("Invalid float for %s=%s; using default=%s", name, raw, default)
+ return default
+
+
+def load_rerank_config_from_env() -> RerankConfig:
+ return RerankConfig(
+ enabled=_parse_bool(os.getenv("RERANK_ENABLED", "true"), True),
+ candidate_multiplier=_parse_int_env("RERANK_CANDIDATE_MULTIPLIER", 3),
+ similarity_weight=_parse_float_env("RERANK_SIMILARITY_WEIGHT", 0.7),
+ keyword_weight=_parse_float_env("RERANK_KEYWORD_WEIGHT", 0.2),
+ metadata_weight=_parse_float_env("RERANK_METADATA_WEIGHT", 0.1),
+ max_candidates=_parse_int_env("RERANK_MAX_CANDIDATES", 50),
+ min_token_len=_parse_int_env("RERANK_MIN_TOKEN_LEN", 3),
+ debug_logging=_parse_bool(os.getenv("RERANK_DEBUG_LOG", "false"), False),
+ log_top_n=_parse_int_env("RERANK_LOG_TOP_N", 5),
+ )
+
+
+def candidate_pool_limit(top_k: int, config: RerankConfig) -> int:
+ try:
+ requested_top_k = max(1, int(top_k))
+ except (TypeError, ValueError):
+ LOGGER.warning("Invalid top_k=%s; using fallback top_k=1", top_k)
+ requested_top_k = 1
+ if not config.enabled:
+ return requested_top_k
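+    # e.g. top_k=5 with candidate_multiplier=3 requests 15 candidates,
+    # and the pool is always capped at max_candidates (50 by default).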
+ expanded = requested_top_k * max(1, config.candidate_multiplier)
+ return min(max(requested_top_k, expanded), max(1, config.max_candidates))
+
+
+def _tokenize_text(text: str) -> List[str]:
+ return re.findall(r"[a-zA-Z0-9_]+", (text or "").lower())
+
+
+def _query_terms(query: str, min_token_len: int) -> set:
+ return {token for token in _tokenize_text(query) if len(token) >= min_token_len}
+
+
+def _keyword_overlap_score(query_terms: set, content: str) -> float:
+ if not query_terms:
+ return 0.0
+ content_terms = set(_tokenize_text(content))
+ overlap = query_terms.intersection(content_terms)
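+    # Fraction of query terms present in the content (query-term recall, not Jaccard).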
+ return len(overlap) / len(query_terms)
+
+
+def _metadata_score(query_terms: set, file_path: str, citation_url: str) -> float:
+ if not query_terms:
+ return 0.0
+ metadata_terms = set(_tokenize_text(file_path)) | set(_tokenize_text(citation_url))
+ overlap = query_terms.intersection(metadata_terms)
+ return len(overlap) / len(query_terms)
+
+
+def _extract_similarity(doc: Dict[str, Any]) -> float:
+ if doc.get("similarity") is not None:
+ return float(doc.get("similarity", 0.0))
+ if doc.get("distance") is not None:
+ return 1.0 - float(doc.get("distance", 0.0))
+ return 0.0
+
+
+def _log_docs(
+ logger: Optional[logging.Logger],
+ enabled: bool,
+ stage: str,
+ docs: List[Dict[str, Any]],
+ top_n: int,
+ log_prefix: str,
+) -> None:
+ if not logger or not enabled:
+ return
+
+ logger.info("[%s] %s top %s documents", log_prefix, stage, min(top_n, len(docs)))
+ for idx, doc in enumerate(docs[:top_n], start=1):
+ logger.info(
+ "[%s] %s #%s path=%s similarity=%.4f rerank=%.4f keyword=%.4f metadata=%.4f",
+ log_prefix,
+ stage,
+ idx,
+ doc.get("file_path", ""),
+ float(doc.get("similarity", 0.0)),
+ float(doc.get("rerank_score", 0.0)),
+ float(doc.get("keyword_score", 0.0)),
+ float(doc.get("metadata_score", 0.0)),
+ )
+
+
+def rerank_documents(
+ query: str,
+ docs: List[Dict[str, Any]],
+ config: RerankConfig,
+ top_k: int,
+ logger: Optional[logging.Logger] = None,
+ log_prefix: str = "retrieval",
+) -> List[Dict[str, Any]]:
+ try:
+ requested_top_k = max(1, int(top_k))
+ except (TypeError, ValueError):
+ LOGGER.warning("Invalid top_k=%s; using fallback top_k=1", top_k)
+ requested_top_k = 1
+ if not docs:
+ return []
+
+ query_terms = _query_terms(query, config.min_token_len)
+
+ normalized_docs: List[Dict[str, Any]] = []
+ for doc in docs:
+ normalized = dict(doc)
+ normalized["similarity"] = _extract_similarity(normalized)
+ normalized_docs.append(normalized)
+
+ _log_docs(
+ logger,
+ config.debug_logging,
+ "before_rerank",
+ normalized_docs,
+ config.log_top_n,
+ log_prefix,
+ )
+
+ if not config.enabled:
+ selected = normalized_docs[:requested_top_k]
+ for doc in selected:
+ doc["keyword_score"] = 0.0
+ doc["metadata_score"] = 0.0
+ doc["rerank_score"] = round(float(doc.get("similarity", 0.0)), 4)
+ _log_docs(
+ logger,
+ config.debug_logging,
+ "after_rerank_disabled",
+ selected,
+ config.log_top_n,
+ log_prefix,
+ )
+ return selected
+
+ reranked: List[Dict[str, Any]] = []
+ for doc in normalized_docs:
+ keyword_score = _keyword_overlap_score(query_terms, doc.get("content_text", ""))
+ metadata_score = _metadata_score(query_terms, doc.get("file_path", ""), doc.get("citation_url", ""))
+ final_score = (
+ config.similarity_weight * float(doc.get("similarity", 0.0))
+ + config.keyword_weight * keyword_score
+ + config.metadata_weight * metadata_score
+ )
+
+ doc["keyword_score"] = round(keyword_score, 4)
+ doc["metadata_score"] = round(metadata_score, 4)
+ doc["rerank_score"] = round(final_score, 4)
+ reranked.append(doc)
+
+ # Sort by score and stable tie-breakers for deterministic ordering in tests and production.
+ reranked.sort(
+ key=lambda item: (
+ -float(item.get("rerank_score", 0.0)),
+ -float(item.get("similarity", 0.0)),
+ str(item.get("file_path", "")),
+ str(item.get("citation_url", "")),
+ str(item.get("content_text", "")),
+ )
+ )
+ selected = reranked[:requested_top_k]
+
+ _log_docs(
+ logger,
+ config.debug_logging,
+ "after_rerank",
+ selected,
+ config.log_top_n,
+ log_prefix,
+ )
+
+ return selected
diff --git a/shared/tests/test_reranking.py b/shared/tests/test_reranking.py
new file mode 100644
index 0000000..19dbe7d
--- /dev/null
+++ b/shared/tests/test_reranking.py
@@ -0,0 +1,119 @@
+import logging
+import sys
+from pathlib import Path
+
+REPO_ROOT = Path(__file__).resolve().parents[2]
+if str(REPO_ROOT) not in sys.path:
+ sys.path.append(str(REPO_ROOT))
+
+from shared.reranking import RerankConfig, candidate_pool_limit, load_rerank_config_from_env, rerank_documents
+
+
+def test_candidate_pool_limit_expands_and_caps():
+ config = RerankConfig(enabled=True, candidate_multiplier=3, max_candidates=10)
+ assert candidate_pool_limit(2, config) == 6
+ assert candidate_pool_limit(4, config) == 10
+
+
+def test_candidate_pool_limit_handles_invalid_top_k(caplog):
+ config = RerankConfig(enabled=True, candidate_multiplier=3, max_candidates=10)
+
+ with caplog.at_level(logging.WARNING):
+ limit = candidate_pool_limit("invalid", config)
+
+ assert limit == 3
+ assert "Invalid top_k=invalid" in caplog.text
+
+
+def test_rerank_documents_returns_empty_for_empty_input():
+ config = RerankConfig()
+ assert rerank_documents("kubeflow", [], config, top_k=5) == []
+
+
+def test_rerank_documents_scoring_and_ordering():
+ config = RerankConfig(
+ enabled=True,
+ similarity_weight=0.7,
+ keyword_weight=0.2,
+ metadata_weight=0.1,
+ min_token_len=3,
+ )
+
+ docs = [
+ {
+ "similarity": 0.9,
+ "file_path": "docs/kserve.md",
+ "citation_url": "https://kubeflow.org/docs/components/kserve",
+ "content_text": "kserve inference service gpu deployment",
+ },
+ {
+ "similarity": 0.95,
+ "file_path": "docs/pipelines.md",
+ "citation_url": "https://kubeflow.org/docs/components/pipelines",
+ "content_text": "pipeline runs and scheduling",
+ },
+ ]
+
+ ranked = rerank_documents("kserve gpu", docs, config, top_k=2)
+
+ assert len(ranked) == 2
+ assert ranked[0]["file_path"] == "docs/kserve.md"
+ assert ranked[0]["rerank_score"] >= ranked[1]["rerank_score"]
+
+
+def test_rerank_documents_deterministic_tie_breaking():
+ config = RerankConfig(enabled=True, similarity_weight=1.0, keyword_weight=0.0, metadata_weight=0.0)
+
+ docs = [
+ {
+ "similarity": 0.5,
+ "file_path": "z.md",
+ "citation_url": "https://example.com/z",
+ "content_text": "same",
+ },
+ {
+ "similarity": 0.5,
+ "file_path": "a.md",
+ "citation_url": "https://example.com/a",
+ "content_text": "same",
+ },
+ ]
+
+ ranked = rerank_documents("anything", docs, config, top_k=2)
+
+ assert [doc["file_path"] for doc in ranked] == ["a.md", "z.md"]
+
+
+def test_load_rerank_config_from_env_fallbacks(monkeypatch, caplog):
+ monkeypatch.setenv("RERANK_ENABLED", "not-a-bool")
+ monkeypatch.setenv("RERANK_CANDIDATE_MULTIPLIER", "nan")
+ monkeypatch.setenv("RERANK_SIMILARITY_WEIGHT", "bad")
+ monkeypatch.setenv("RERANK_KEYWORD_WEIGHT", "bad")
+ monkeypatch.setenv("RERANK_METADATA_WEIGHT", "bad")
+ monkeypatch.setenv("RERANK_MAX_CANDIDATES", "bad")
+ monkeypatch.setenv("RERANK_MIN_TOKEN_LEN", "bad")
+ monkeypatch.setenv("RERANK_DEBUG_LOG", "bad")
+ monkeypatch.setenv("RERANK_LOG_TOP_N", "bad")
+
+ with caplog.at_level(logging.WARNING):
+ config = load_rerank_config_from_env()
+
+ assert config == RerankConfig()
+ assert "Invalid boolean value" in caplog.text
+ assert "Invalid integer" in caplog.text
+ assert "Invalid float" in caplog.text
+
+
+def test_rerank_disabled_preserves_similarity_order():
+ config = RerankConfig(enabled=False)
+
+ docs = [
+ {"similarity": 0.8, "file_path": "one.md", "citation_url": "", "content_text": ""},
+ {"similarity": 0.7, "file_path": "two.md", "citation_url": "", "content_text": ""},
+ ]
+
+ ranked = rerank_documents("kubeflow", docs, config, top_k=2)
+
+ assert [doc["file_path"] for doc in ranked] == ["one.md", "two.md"]
+ assert ranked[0]["keyword_score"] == 0.0
+ assert ranked[0]["metadata_score"] == 0.0