Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 34 additions & 20 deletions src/litserve/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -510,6 +510,8 @@ def predict(self, x):

- Returns 200 when all workers are ready
- Critical for Kubernetes/Docker deployments
- Each LitAPI also gets its own health check endpoint at ``{api_path}{healthcheck_path}``
(e.g. "/predict/health") to monitor a single API independently

info_path:
Server information endpoint. Defaults to "/info".
Expand Down Expand Up @@ -1021,34 +1023,33 @@ def active_requests(self):
return sum(counter.value for counter in self.active_counters)
return None

async def _check_lit_api_health(self, lit_api: LitAPI) -> bool:
"""Check worker readiness and the user-defined ``LitAPI.health`` hook for a single LitAPI."""
endpoint = lit_api.api_path.split("/")[-1]
worker_statuses = [v for k, v in self.workers_setup_status.items() if k.rsplit("_", 1)[0] == endpoint]
workers_ready = bool(worker_statuses) and all(v == WorkerSetupStatus.READY for v in worker_statuses)

try:
health_status = lit_api.health()
if inspect.isawaitable(health_status):
health_status = await health_status
except Exception:
logger.exception(f"Health check failed for {lit_api.__class__.__name__}")
health_status = False

return workers_ready and bool(health_status)

def _register_internal_endpoints(self):
@self.app.get("/", dependencies=[Depends(self.setup_auth())])
async def index(request: Request) -> Response:
return Response(content="litserve running")

@self.app.get(self.healthcheck_path, dependencies=[Depends(self.setup_auth())])
async def health(request: Request) -> Response:
workers_ready = bool(self.workers_setup_status) and all(
v == WorkerSetupStatus.READY for v in self.workers_setup_status.values()
)

lit_api_health_status = True
for lit_api in self.litapi_connector:
try:
result = lit_api.health()
if inspect.isawaitable(result):
result = await result
if not result:
lit_api_health_status = False
break
except Exception:
logger.exception(f"Health check failed for {lit_api.__class__.__name__}")
lit_api_health_status = False
break
if workers_ready and lit_api_health_status:
return Response(content="ok", status_code=200)

return Response(content="not ready", status_code=503)
if not await self._check_lit_api_health(lit_api):
return Response(content="not ready", status_code=503)
return Response(content="ok", status_code=200)

@self.app.get(self.info_path, dependencies=[Depends(self.setup_auth())])
async def info(request: Request) -> Response:
Expand Down Expand Up @@ -1117,6 +1118,19 @@ async def endpoint_handler(request: request_type) -> response_type:
dependencies=[Depends(self.setup_auth(lit_api))],
)

# Register a per-LitAPI health check endpoint, e.g. /predict/health
async def health_endpoint(request: Request) -> Response:
if await self._check_lit_api_health(lit_api):
return Response(content="ok", status_code=200)
return Response(content="not ready", status_code=503)

self.app.add_api_route(
f"{lit_api.api_path}{self.healthcheck_path}",
health_endpoint,
methods=["GET"],
dependencies=[Depends(self.setup_auth(lit_api))],
)

# Handle specs
self._register_spec_endpoints(lit_api)

Expand Down
12 changes: 6 additions & 6 deletions tests/unit/test_lit_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -950,7 +950,7 @@ async def mock_lifespan(app):
def test_health_check_returns_503_on_health_exception(lifespan_mock, simple_litapi):
"""Health check should return 503 (not 500) when a custom health() method raises an exception."""
server = LitServer(simple_litapi, accelerator="cpu", devices=1, timeout=10)
server.workers_setup_status = {"worker-0": WorkerSetupStatus.READY}
server.workers_setup_status = {"predict_0": WorkerSetupStatus.READY}

@contextlib.asynccontextmanager
async def mock_lifespan(app):
Expand Down Expand Up @@ -982,11 +982,11 @@ async def mock_lifespan(app):
lit_api.health = MagicMock(return_value=True)

with TestClient(server.app) as client:
server.workers_setup_status = {"worker-0": WorkerSetupStatus.READY}
server.workers_setup_status = {"predict_0": WorkerSetupStatus.READY}
response = client.get("/health")
assert response.status_code == 200

server.workers_setup_status = {"worker-0": WorkerSetupStatus.ERROR}
server.workers_setup_status = {"predict_0": WorkerSetupStatus.ERROR}
response = client.get("/health")
assert response.status_code == 503, "Health check should return 503 after worker enters error state"
assert response.text == "not ready"
Expand All @@ -1007,17 +1007,17 @@ async def mock_lifespan(app):
lit_api.health = MagicMock(return_value=True)

with TestClient(server.app) as client:
server.workers_setup_status = {"worker-0": WorkerSetupStatus.READY}
server.workers_setup_status = {"predict_0": WorkerSetupStatus.READY}
response = client.get("/health")
assert response.status_code == 200

response = client.get("/health")
assert response.status_code == 200

server.workers_setup_status = {"worker-0": WorkerSetupStatus.ERROR}
server.workers_setup_status = {"predict_0": WorkerSetupStatus.ERROR}
response = client.get("/health")
assert response.status_code == 503, "Should not serve stale cached status"

server.workers_setup_status = {"worker-0": WorkerSetupStatus.READY}
server.workers_setup_status = {"predict_0": WorkerSetupStatus.READY}
response = client.get("/health")
assert response.status_code == 200, "Should recover when workers become ready again"
62 changes: 62 additions & 0 deletions tests/unit/test_simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,68 @@ def test_workers_health_with_async_health_method(use_zmq):
assert response.text == "ok"


class UnhealthyLitAPI(SimpleLitAPI):
def health(self) -> bool:
return False


@pytest.mark.parametrize("use_zmq", [True, False])
def test_per_api_health(use_zmq):
server = LitServer(
[SimpleLitAPI(api_path="/healthy"), UnhealthyLitAPI(api_path="/unhealthy")],
accelerator="cpu",
devices=1,
timeout=5,
fast_queue=use_zmq,
)

with wrap_litserve_start(server) as server, TestClient(server.app) as client:
# wait for workers to be ready
for _ in range(10):
response = client.get("/healthy/health")
if response.status_code == 200:
break
time.sleep(0.5)
assert response.status_code == 200
assert response.text == "ok"

response = client.get("/unhealthy/health")
assert response.status_code == 503
assert response.text == "not ready"

# global health check aggregates all APIs
response = client.get("/health")
assert response.status_code == 503
assert response.text == "not ready"


@pytest.mark.parametrize("use_zmq", [True, False])
def test_per_api_health_custom_path(use_zmq):
server = LitServer(
SlowSetupLitAPI(api_path="/api1"),
accelerator="cpu",
healthcheck_path="/my_server/health",
devices=1,
timeout=5,
workers_per_device=2,
fast_queue=use_zmq,
)

with wrap_litserve_start(server) as server, TestClient(server.app) as client:
response = client.get("/api1/my_server/health")
assert response.status_code == 503
assert response.text == "not ready"

# wait for workers to be ready
for _ in range(10):
response = client.get("/api1/my_server/health")
if response.status_code == 200:
break
time.sleep(0.5)
assert response.status_code == 200
assert response.text == "ok"


def make_load_request(server, outputs):
with TestClient(server.app) as client:
for i in range(100):
Expand Down