Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
237 changes: 181 additions & 56 deletions src/exo/worker/runner/runner_supervisor.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import contextlib
import multiprocessing as mp
import os
import signal
from dataclasses import dataclass, field
from typing import Self
Expand Down Expand Up @@ -49,6 +50,35 @@
DECODE_TIMEOUT_SECONDS = 5


def _sigterm_handler(signum: int, frame: object) -> None:
    """Kill every direct child process, then let the signal end this process.

    Intended as the SIGTERM handler of the supervisor so that its child
    processes (the python3 MLX-runner processes) do not survive a kickstart.
    The children are discovered with ``pgrep -P <own pid>`` and each one is
    sent SIGKILL on a best-effort basis.

    Args:
        signum: Number of the signal being handled. Its default disposition
            is restored and the signal is re-sent to this process, so the
            process still exits "by signal".
        frame: Interrupted stack frame passed in by the ``signal`` machinery;
            unused.
    """
    try:
        # Imported lazily: only the shutdown path needs it.
        import subprocess

        try:
            listing = subprocess.run(
                ["pgrep", "-P", str(os.getpid())],
                capture_output=True,
                text=True,
                timeout=2,
            )
        except (OSError, subprocess.SubprocessError):
            # pgrep is missing, could not be started, or timed out: there is
            # nothing we can enumerate, so fall through to the re-raise.
            child_pids = []
        else:
            child_pids = listing.stdout.split()

        for pid_str in child_pids:
            try:
                os.kill(int(pid_str), signal.SIGKILL)
            except (OSError, ValueError):
                # Child already gone, not killable by us, or an unparsable
                # line. Skip it but keep going: one stubborn child must not
                # stop the remaining children from being killed.
                continue
    finally:
        # Always restore the default action and re-send the signal, even if
        # the cleanup above failed unexpectedly, so the process exits with
        # SIGTERM instead of lingering with the handler consumed.
        signal.signal(signum, signal.SIG_DFL)
        os.kill(os.getpid(), signum)


# Install at import time so the handler is active for the entire supervisor
# lifetime, including the finally block inside run().
# signal.signal() raises ValueError when it is called from anything but the
# main thread; tolerate that so importing this module from a worker thread
# does not fail (the handler is simply not installed in that case).
with contextlib.suppress(ValueError):
    signal.signal(signal.SIGTERM, _sigterm_handler)


@dataclass(eq=False)
class RunnerSupervisor:
shard_metadata: ShardMetadata
Expand All @@ -68,6 +98,21 @@ class RunnerSupervisor:
_cancel_watch_runner: anyio.CancelScope = field(
default_factory=anyio.CancelScope, init=False
)
# Set to True by shutdown(); run() resets it before each runner start and
# reads it afterwards to tell an intentional stop from an unexpected exit.
_shutdown_requested: bool = field(default=False, init=False)


def _runner_is_alive(self) -> bool:
try:
return self.runner_process.is_alive()
except ValueError:
return False

def _runner_exitcode(self) -> int | None:
try:
return self.runner_process.exitcode
except ValueError:
return -1


@classmethod
def create(
Expand Down Expand Up @@ -108,64 +153,144 @@ def create(

return self

# ------------------------------------------------------------------
# Non-blocking helpers — each offloads a blocking call to a thread so
# the asyncio event loop (and therefore the API server) stays alive.
# ------------------------------------------------------------------

async def _join_runner(self, timeout: float) -> None:
    """Wait for the runner process to exit without blocking the event loop.

    Args:
        timeout: Value passed straight to ``runner_process.join``.
    """

    def _blocking_join() -> None:
        # Looked up inside the worker thread, exactly when the join happens.
        self.runner_process.join(timeout)

    await to_thread.run_sync(_blocking_join, abandon_on_cancel=True)

async def _terminate_runner(self) -> None:
    """Call ``runner_process.terminate`` (SIGTERM) from a worker thread."""
    send_sigterm = self.runner_process.terminate
    await to_thread.run_sync(send_sigterm, abandon_on_cancel=True)

async def _kill_runner(self) -> None:
    """Call ``runner_process.kill`` (SIGKILL) from a worker thread."""
    send_sigkill = self.runner_process.kill
    await to_thread.run_sync(send_sigkill, abandon_on_cancel=True)

# Supervise the runner subprocess: start it, run _watch_runner and
# _forward_events in the task group, and once the task group exits close the
# channels and stop the process (join, then terminate, then kill).
# The newer lines wrap this in a restart loop bounded by MAX_RESTARTS.
# NOTE(review): this span is a rendered diff with indentation stripped and
# removed/added lines interleaved (e.g. `self.runner_process.start()` appears
# both right under the signature and inside `while True:`). Block nesting
# cannot be recovered from here; check the real file before relying on the
# notes below.
async def run(self):
self.runner_process.start()
try:
async with self._tg as tg:
tg.start_soon(self._watch_runner)
tg.start_soon(self._forward_events)
finally:
logger.info("Runner supervisor shutting down")
if not self._cancel_watch_runner.cancel_called:
self._cancel_watch_runner.cancel()
with contextlib.suppress(ClosedResourceError):
self._ev_recv.close()
with contextlib.suppress(ClosedResourceError):
self._task_sender.close()
with contextlib.suppress(ClosedResourceError):
self._event_sender.close()
with contextlib.suppress(ClosedResourceError):
self._cancel_sender.send(CANCEL_ALL_TASKS)
with contextlib.suppress(ClosedResourceError):
self._cancel_sender.close()

await to_thread.run_sync(self.runner_process.join, 5)

if self.runner_process.is_alive():
logger.warning(
"Runner process didn't shutdown succesfully, terminating"
)
self.runner_process.terminate()
self.runner_process.join(timeout=10)
# Upper bound on automatic restarts after an unexpected runner exit.
MAX_RESTARTS = 5
restart_count = 0

if not self.runner_process.is_alive():
logger.warning("Terminated nicely in the first attempt!")
while True:
# Cleared before every start; shutdown() sets it to mark an intentional stop.
# NOTE(review): a shutdown() issued during the backoff sleep at the end of
# the loop would be overwritten by this reset -- confirm this is acceptable.
self._shutdown_requested = False
self.runner_process.start()
try:
async with self._tg as tg:
tg.start_soon(self._watch_runner)
tg.start_soon(self._forward_events)
finally:
logger.info("Runner supervisor shutting down" if self._shutdown_requested else "Runner process exited unexpectedly, cleaning up")
if not self._cancel_watch_runner.cancel_called:
self._cancel_watch_runner.cancel()
with contextlib.suppress(ClosedResourceError):
self._ev_recv.close()
with contextlib.suppress(ClosedResourceError):
self._task_sender.close()
# Only close the event sender on intentional shutdown — on crash-restart,
# keep it open so the rest of exo keeps receiving events after reload
if self._shutdown_requested:
with contextlib.suppress(ClosedResourceError):
self._event_sender.close()
# Best-effort cancel broadcast, bounded to 2 seconds by move_on_after.
# NOTE(review): `Exception` makes the other two entries redundant, so every
# error raised by the send is silently dropped here.
with contextlib.suppress(ClosedResourceError, TimeoutError, Exception):
with anyio.move_on_after(2.0):
await self._cancel_sender.send_async(CANCEL_ALL_TASKS)
with contextlib.suppress(ClosedResourceError):
self._cancel_sender.close()

else:
# Try really hard to terminate
for i in range(2, 11):
self.runner_process.terminate()
self.runner_process.join(timeout=2)
if not self.runner_process.is_alive():
logger.warning(f"That took {i} attempts :)")
break
# Try even harder to kill
# Escalation: give the runner 5s to exit by itself, then terminate it
# (waiting up to 10s), then retry terminate, and finally kill.
await self._join_runner(5)

if self._runner_is_alive():
logger.warning(
"Runner process didn't shutdown succesfully, terminating"
)
await self._terminate_runner()
await self._join_runner(10)

if not self._runner_is_alive():
logger.warning("Terminated nicely in the first attempt!")
else:
logger.critical(
"Runner process didn't respond to SIGTERM, killing"
)
j = 0
while self.runner_process.is_alive():
j += 1
self.runner_process.kill()
self.runner_process.join(timeout=5)
logger.warning(f"That took {j} attempts :(")
else:
logger.info("Runner process succesfully terminated")
# Attempts 2..10: terminate and wait 2s each time, stopping at the first
# attempt after which the runner is no longer alive.
for i in range(2, 11):
await self._terminate_runner()
await self._join_runner(2)
if not self._runner_is_alive():
logger.warning(f"That took {i} attempts :)")
break
else:
logger.critical(
"Runner process didn't respond to SIGTERM, killing"
)
# Kill repeatedly (5s wait per round) until the runner is gone; j counts rounds.
j = 0
while self._runner_is_alive():
j += 1
await self._kill_runner()
await self._join_runner(5)
logger.warning(f"That took {j} attempts :(")
else:
logger.info("Runner process succesfully terminated")

# NOTE(review): runner_process.close() shows up here and again after the
# shutdown check below -- confirm whether both calls are intended.
self.runner_process.close()

if self._shutdown_requested:
logger.info("Runner supervisor: intentional shutdown, not restarting")
break

self.runner_process.close()
restart_count += 1
if restart_count > MAX_RESTARTS:
logger.critical(
f"Runner crashed {MAX_RESTARTS} times without recovery, giving up"
)
break

# Backoff doubles from 2s with each restart and is capped at 10s.
delay = min(2.0 * (2 ** (restart_count - 1)), 10.0)
logger.warning(
f"Runner crashed (attempt {restart_count}/{MAX_RESTARTS}), "
f"restarting in {delay:.0f}s"
)
await anyio.sleep(delay)
# Fresh process, channels and task group for the next loop iteration.
self._reset_for_restart()

def shutdown(self):
    """Request an intentional stop of the supervisor.

    Sets ``_shutdown_requested`` and then cancels the tasks of the current
    task group. The flag is set first so that it is already visible when the
    cancellation takes effect.
    """
    self._shutdown_requested = True
    task_group = self._tg
    task_group.cancel_tasks()

def _reset_for_restart(self) -> None:
    """Recreate channels and process for a runner restart after an unexpected crash.

    Builds three new channels and a new daemon process, then replaces the
    task group, the watcher cancel scope and all task bookkeeping with empty
    ones. ``_event_sender`` is intentionally NOT reset: it connects to the
    rest of exo and must remain open across restarts.
    """
    event_tx, event_rx = mp_channel[Event]()
    task_tx, task_rx = mp_channel[Task]()
    cancel_tx, cancel_rx = mp_channel[TaskId]()

    # The new process is handed the event sender plus the task/cancel
    # receivers; the supervisor keeps the opposite end of each channel.
    self.runner_process = mp.Process(
        target=entrypoint,
        args=(self.bound_instance, event_tx, task_rx, cancel_rx, logger),
        daemon=True,
    )
    self._ev_recv, self._task_sender, self._cancel_sender = (
        event_rx,
        task_tx,
        cancel_tx,
    )

    # The previous task group and cancel scope were already used, so the next
    # run() iteration gets new ones.
    self._tg = TaskGroup()
    self._cancel_watch_runner = anyio.CancelScope()

    # Start from a clean slate: idle status and no tracked tasks.
    self.status = RunnerIdle()
    self.pending, self.in_progress = {}, {}
    self.completed, self.cancelled = set(), set()

def _cancel_tg(self) -> None:
"""Cancel the running task group without marking this as an intentional shutdown.
Used by _check_runner to tear down the current run() iteration so the restart
loop can start a fresh runner subprocess.
"""
self._tg.cancel_tasks()

async def start_task(self, task: Task):
Expand Down Expand Up @@ -246,17 +371,17 @@ async def _watch_runner(self) -> None:
with self._cancel_watch_runner:
while True:
await anyio.sleep(5)
if not self.runner_process.is_alive():
if not self._runner_is_alive():
await self._check_runner(RuntimeError("Runner found to be dead"))

async def _check_runner(self, e: Exception) -> None:
if not self._cancel_watch_runner.cancel_called:
self._cancel_watch_runner.cancel()
logger.info("Checking runner's status")
if self.runner_process.is_alive():
if self._runner_is_alive():
logger.info("Runner was found to be alive, attempting to join process")
await to_thread.run_sync(self.runner_process.join, 5)
rc = self.runner_process.exitcode
await self._join_runner(5)
rc = self._runner_exitcode()
logger.info(f"Runner exited with exit code {rc}")
if rc == 0:
return
Expand Down Expand Up @@ -301,4 +426,4 @@ async def _check_runner(self, e: Exception) -> None:
logger.warning(
"Event sender already closed, unable to report runner failure"
)
self.shutdown()
self._cancel_tg()