Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions src/litserve/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -1208,6 +1208,12 @@ def _perform_graceful_shutdown(
logger.warning(f"{log_prefix}: Already not alive.")
continue
try:
if isinstance(uw, threading.Thread):
getattr(uw, "_litserve_server").should_exit = True
uw.join(timeout=self.uvicorn_graceful_timeout)
if uw.is_alive():
logger.warning(f"{log_prefix}: Did not terminate gracefully.")
continue
uw.terminate()
uw.join(timeout=self.uvicorn_graceful_timeout)
if uw.is_alive():
Expand Down Expand Up @@ -1577,6 +1583,7 @@ def _start_server(self, port, num_uvicorn_servers, log_level, sockets, uvicorn_w
w = threading.Thread(
target=server.run, args=(response_queue_id, sockets), name=f"LitServer-{response_queue_id}"
)
setattr(w, "_litserve_server", server)
Comment on lines 1583 to +1586

@bhimrazy bhimrazy Jun 3, 2026

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The setattr/getattr approach works but I think a small subclass makes the intent clearer and the handle typed:

class _UvicornServerThread(threading.Thread):
    def __init__(self, server: _Server, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.server = server
w = _UvicornServerThread(
    server, target=server.run, args=(response_queue_id, sockets), name=f"LitServer-{response_queue_id}"
)

wdyt ?

cc: @andyland

else:
raise ValueError("Invalid value for api_server_worker_type. Must be 'process' or 'thread'")
w.start()
Expand Down
23 changes: 23 additions & 0 deletions tests/unit/test_lit_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import json
import os
import sys
import threading
import time
from time import sleep
from unittest.mock import MagicMock, patch
Expand Down Expand Up @@ -340,6 +341,28 @@ def test_server_terminate():
server._transport.close.assert_called()


def test_graceful_shutdown_stops_uvicorn_thread(simple_litapi):
server = LitServer(simple_litapi)
server._transport = MagicMock()
server.inference_workers = []
uvicorn_server = MagicMock(should_exit=False)

def wait_for_shutdown():
while not uvicorn_server.should_exit:
sleep(0.01)

worker = threading.Thread(target=wait_for_shutdown, name="LitServer-0")
setattr(worker, "_litserve_server", uvicorn_server)
worker.start()

manager = MagicMock()
server._perform_graceful_shutdown(manager, {0: worker})

assert uvicorn_server.should_exit is True
assert not worker.is_alive()
manager.shutdown.assert_called_once()


@pytest.mark.parametrize(("disable_openapi_url", "should_print"), [(False, True), (True, False)])
@patch("builtins.print")
@patch("litserve.server.uvicorn")
Expand Down
Loading