diff --git a/src/litserve/server.py b/src/litserve/server.py
index ccefb5c2..256bf2d4 100644
--- a/src/litserve/server.py
+++ b/src/litserve/server.py
@@ -510,6 +510,8 @@ def predict(self, x):
 
             - Returns 200 when all workers are ready
             - Critical for Kubernetes/Docker deployments
+            - Each LitAPI also gets its own health check endpoint at ``{api_path}{healthcheck_path}``
+              (e.g. "/predict/health") to monitor a single API independently
 
         info_path:
             Server information endpoint. Defaults to "/info".
@@ -1021,6 +1023,22 @@ def active_requests(self):
             return sum(counter.value for counter in self.active_counters)
         return None
 
+    async def _check_lit_api_health(self, lit_api: LitAPI) -> bool:
+        """Return True only when every worker of ``lit_api`` is READY and its ``health()`` hook returns truthy."""
+        endpoint = lit_api.api_path.split("/")[-1]  # last path segment, matched against "<endpoint>_<suffix>" keys
+        worker_statuses = [v for k, v in self.workers_setup_status.items() if k.rsplit("_", 1)[0] == endpoint]
+        workers_ready = bool(worker_statuses) and all(v == WorkerSetupStatus.READY for v in worker_statuses)
+
+        try:  # the user hook may be sync or async; an exception is logged and treated as unhealthy
+            health_status = lit_api.health()
+            if inspect.isawaitable(health_status):
+                health_status = await health_status
+        except Exception:
+            logger.exception(f"Health check failed for {lit_api.__class__.__name__}")
+            health_status = False
+
+        return workers_ready and bool(health_status)  # also False when no worker status matched
+
     def _register_internal_endpoints(self):
         @self.app.get("/", dependencies=[Depends(self.setup_auth())])
         async def index(request: Request) -> Response:
@@ -1028,27 +1046,10 @@ async def index(request: Request) -> Response:
 
         @self.app.get(self.healthcheck_path, dependencies=[Depends(self.setup_auth())])
         async def health(request: Request) -> Response:
-            workers_ready = bool(self.workers_setup_status) and all(
-                v == WorkerSetupStatus.READY for v in self.workers_setup_status.values()
-            )
-
-            lit_api_health_status = True
             for lit_api in self.litapi_connector:
-                try:
-                    result = lit_api.health()
-                    if inspect.isawaitable(result):
-                        result = await result
-                    if not result:
-                        lit_api_health_status = False
-                        break
-                except Exception:
-                    logger.exception(f"Health check failed for {lit_api.__class__.__name__}")
-                    lit_api_health_status = False
-                    break
-            if workers_ready and lit_api_health_status:
-                return Response(content="ok", status_code=200)
-
-            return Response(content="not ready", status_code=503)
+                if not await self._check_lit_api_health(lit_api):
+                    return Response(content="not ready", status_code=503)
+            return Response(content="ok", status_code=200)
 
         @self.app.get(self.info_path, dependencies=[Depends(self.setup_auth())])
         async def info(request: Request) -> Response:
@@ -1117,6 +1118,19 @@ async def endpoint_handler(request: request_type) -> response_type:
             dependencies=[Depends(self.setup_auth(lit_api))],
         )
 
+        # Register a per-LitAPI health check at {api_path}{healthcheck_path}, e.g. /predict/health
+        async def health_endpoint(request: Request) -> Response:
+            if await self._check_lit_api_health(lit_api):
+                return Response(content="ok", status_code=200)
+            return Response(content="not ready", status_code=503)
+
+        self.app.add_api_route(
+            f"{lit_api.api_path}{self.healthcheck_path}",
+            health_endpoint,
+            methods=["GET"],
+            dependencies=[Depends(self.setup_auth(lit_api))],
+        )
+
         # Handle specs
         self._register_spec_endpoints(lit_api)
 
diff --git a/tests/unit/test_lit_server.py b/tests/unit/test_lit_server.py
index d319ae19..be908124 100644
--- a/tests/unit/test_lit_server.py
+++ b/tests/unit/test_lit_server.py
@@ -950,7 +950,7 @@ async def mock_lifespan(app):
 def test_health_check_returns_503_on_health_exception(lifespan_mock, simple_litapi):
     """Health check should return 503 (not 500) when a custom health() method raises an exception."""
     server = LitServer(simple_litapi, accelerator="cpu", devices=1, timeout=10)
-    server.workers_setup_status = {"worker-0": WorkerSetupStatus.READY}
+    server.workers_setup_status = {"predict_0": WorkerSetupStatus.READY}
 
     @contextlib.asynccontextmanager
     async def mock_lifespan(app):
@@ -982,11 +982,11 @@ async def mock_lifespan(app):
     lit_api.health = MagicMock(return_value=True)
 
     with TestClient(server.app) as client:
-        server.workers_setup_status = {"worker-0": WorkerSetupStatus.READY}
+        server.workers_setup_status = {"predict_0": WorkerSetupStatus.READY}
         response = client.get("/health")
         assert response.status_code == 200
 
-        server.workers_setup_status = {"worker-0": WorkerSetupStatus.ERROR}
+        server.workers_setup_status = {"predict_0": WorkerSetupStatus.ERROR}
         response = client.get("/health")
         assert response.status_code == 503, "Health check should return 503 after worker enters error state"
         assert response.text == "not ready"
@@ -1007,17 +1007,17 @@ async def mock_lifespan(app):
     lit_api.health = MagicMock(return_value=True)
 
     with TestClient(server.app) as client:
-        server.workers_setup_status = {"worker-0": WorkerSetupStatus.READY}
+        server.workers_setup_status = {"predict_0": WorkerSetupStatus.READY}
         response = client.get("/health")
         assert response.status_code == 200
 
         response = client.get("/health")
         assert response.status_code == 200
 
-        server.workers_setup_status = {"worker-0": WorkerSetupStatus.ERROR}
+        server.workers_setup_status = {"predict_0": WorkerSetupStatus.ERROR}
         response = client.get("/health")
         assert response.status_code == 503, "Should not serve stale cached status"
 
-        server.workers_setup_status = {"worker-0": WorkerSetupStatus.READY}
+        server.workers_setup_status = {"predict_0": WorkerSetupStatus.READY}
         response = client.get("/health")
         assert response.status_code == 200, "Should recover when workers become ready again"
diff --git a/tests/unit/test_simple.py b/tests/unit/test_simple.py
index bcc078f3..bf8936aa 100644
--- a/tests/unit/test_simple.py
+++ b/tests/unit/test_simple.py
@@ -177,6 +177,68 @@ def test_workers_health_with_async_health_method(use_zmq):
         assert response.text == "ok"
 
 
+class UnhealthyLitAPI(SimpleLitAPI):  # LitAPI whose health() hook always reports unhealthy
+    def health(self) -> bool:
+        return False
+
+
+@pytest.mark.parametrize("use_zmq", [True, False])
+def test_per_api_health(use_zmq):
+    server = LitServer(
+        [SimpleLitAPI(api_path="/healthy"), UnhealthyLitAPI(api_path="/unhealthy")],
+        accelerator="cpu",
+        devices=1,
+        timeout=5,
+        fast_queue=use_zmq,
+    )
+
+    with wrap_litserve_start(server) as server, TestClient(server.app) as client:
+        # Poll for up to ~5s (10 attempts, 0.5s apart) until the workers report ready
+        for _ in range(10):
+            response = client.get("/healthy/health")
+            if response.status_code == 200:
+                break
+            time.sleep(0.5)
+        assert response.status_code == 200
+        assert response.text == "ok"
+
+        response = client.get("/unhealthy/health")
+        assert response.status_code == 503
+        assert response.text == "not ready"
+
+        # The global health check aggregates all APIs, so one unhealthy API makes it return 503
+        response = client.get("/health")
+        assert response.status_code == 503
+        assert response.text == "not ready"
+
+
+@pytest.mark.parametrize("use_zmq", [True, False])
+def test_per_api_health_custom_path(use_zmq):
+    server = LitServer(
+        SlowSetupLitAPI(api_path="/api1"),
+        accelerator="cpu",
+        healthcheck_path="/my_server/health",  # appended to api_path: /api1/my_server/health
+        devices=1,
+        timeout=5,
+        workers_per_device=2,
+        fast_queue=use_zmq,
+    )
+
+    with wrap_litserve_start(server) as server, TestClient(server.app) as client:
+        response = client.get("/api1/my_server/health")
+        assert response.status_code == 503
+        assert response.text == "not ready"
+
+        # Poll for up to ~5s (10 attempts, 0.5s apart) until the workers report ready
+        for _ in range(10):
+            response = client.get("/api1/my_server/health")
+            if response.status_code == 200:
+                break
+            time.sleep(0.5)
+        assert response.status_code == 200
+        assert response.text == "ok"
+
+
 def make_load_request(server, outputs):
     with TestClient(server.app) as client:
         for i in range(100):