diff --git a/src/exo/worker/runner/runner_supervisor.py b/src/exo/worker/runner/runner_supervisor.py index 8a48c6bcdc..13cb194109 100644 --- a/src/exo/worker/runner/runner_supervisor.py +++ b/src/exo/worker/runner/runner_supervisor.py @@ -1,5 +1,6 @@ import contextlib import multiprocessing as mp +import os import signal from dataclasses import dataclass, field from typing import Self @@ -49,6 +50,35 @@ DECODE_TIMEOUT_SECONDS = 5 +def _sigterm_handler(signum, frame): + """ + SIGTERM handler: forcibly SIGKILL all direct child processes so that + orphaned python3 MLX-runner processes do not survive a kickstart. + Re-raises default SIGTERM so the supervisor itself still exits cleanly. + """ + try: + import subprocess + result = subprocess.run( + ["pgrep", "-P", str(os.getpid())], + capture_output=True, text=True, timeout=2 + ) + for pid_str in result.stdout.strip().splitlines(): + try: + os.kill(int(pid_str), signal.SIGKILL) + except (ProcessLookupError, ValueError): + pass + except Exception: + pass + # Restore default and re-raise so the process exits with SIGTERM. + signal.signal(signum, signal.SIG_DFL) + os.kill(os.getpid(), signum) + + +# Install at import time so the handler is active for the entire supervisor +# lifetime, including the finally block inside run(). +signal.signal(signal.SIGTERM, _sigterm_handler) + + @dataclass(eq=False) class RunnerSupervisor: shard_metadata: ShardMetadata @@ -68,6 +98,21 @@ class RunnerSupervisor: _cancel_watch_runner: anyio.CancelScope = field( default_factory=anyio.CancelScope, init=False ) + _shutdown_requested: bool = field(default=False, init=False) + + + def _runner_is_alive(self) -> bool: + try: + return self.runner_process.is_alive() + except ValueError: + return False + + def _runner_exitcode(self) -> int | None: + try: + return self.runner_process.exitcode + except ValueError: + return -1 + @classmethod def create( @@ -108,64 +153,144 @@ def create( return self + # ------------------------------------------------------------------ + # Non-blocking helpers — each offloads a blocking call to a thread so + # the asyncio event loop (and therefore the API server) stays alive. + # ------------------------------------------------------------------ + + async def _join_runner(self, timeout: float) -> None: + """Join the runner process without blocking the event loop.""" + await to_thread.run_sync( + lambda: self.runner_process.join(timeout), abandon_on_cancel=True + ) + + async def _terminate_runner(self) -> None: + """Send SIGTERM to the runner without blocking the event loop.""" + await to_thread.run_sync( + self.runner_process.terminate, abandon_on_cancel=True + ) + + async def _kill_runner(self) -> None: + """Send SIGKILL to the runner without blocking the event loop.""" + await to_thread.run_sync( + self.runner_process.kill, abandon_on_cancel=True + ) + async def run(self): - self.runner_process.start() - try: - async with self._tg as tg: - tg.start_soon(self._watch_runner) - tg.start_soon(self._forward_events) - finally: - logger.info("Runner supervisor shutting down") - if not self._cancel_watch_runner.cancel_called: - self._cancel_watch_runner.cancel() - with contextlib.suppress(ClosedResourceError): - self._ev_recv.close() - with contextlib.suppress(ClosedResourceError): - self._task_sender.close() - with contextlib.suppress(ClosedResourceError): - self._event_sender.close() - with contextlib.suppress(ClosedResourceError): - self._cancel_sender.send(CANCEL_ALL_TASKS) - with contextlib.suppress(ClosedResourceError): - self._cancel_sender.close() - - await to_thread.run_sync(self.runner_process.join, 5) - - if self.runner_process.is_alive(): - logger.warning( - "Runner process didn't shutdown succesfully, terminating" - ) - self.runner_process.terminate() - self.runner_process.join(timeout=10) + MAX_RESTARTS = 5 + restart_count = 0 - if not self.runner_process.is_alive(): - logger.warning("Terminated nicely in the first attempt!") + while True: + self._shutdown_requested = False + self.runner_process.start() + try: + async with self._tg as tg: + tg.start_soon(self._watch_runner) + tg.start_soon(self._forward_events) + finally: + logger.info("Runner supervisor shutting down" if self._shutdown_requested else "Runner process exited unexpectedly, cleaning up") + if not self._cancel_watch_runner.cancel_called: + self._cancel_watch_runner.cancel() + with contextlib.suppress(ClosedResourceError): + self._ev_recv.close() + with contextlib.suppress(ClosedResourceError): + self._task_sender.close() + # Only close the event sender on intentional shutdown — on crash-restart, + # keep it open so the rest of exo keeps receiving events after reload + if self._shutdown_requested: + with contextlib.suppress(ClosedResourceError): + self._event_sender.close() + with contextlib.suppress(ClosedResourceError, TimeoutError, Exception): + with anyio.move_on_after(2.0): + await self._cancel_sender.send_async(CANCEL_ALL_TASKS) + with contextlib.suppress(ClosedResourceError): + self._cancel_sender.close() - else: - # Try really hard to terminate - for i in range(2, 11): - self.runner_process.terminate() - self.runner_process.join(timeout=2) - if not self.runner_process.is_alive(): - logger.warning(f"That took {i} attempts :)") - break - # Try even harder to kill + await self._join_runner(5) + + if self._runner_is_alive(): + logger.warning( + "Runner process didn't shutdown succesfully, terminating" + ) + await self._terminate_runner() + await self._join_runner(10) + + if not self._runner_is_alive(): + logger.warning("Terminated nicely in the first attempt!") else: - logger.critical( - "Runner process didn't respond to SIGTERM, killing" - ) - j = 0 - while self.runner_process.is_alive(): - j += 1 - self.runner_process.kill() - self.runner_process.join(timeout=5) - logger.warning(f"That took {j} attempts :(") - else: - logger.info("Runner process succesfully terminated") + for i in range(2, 11): + await self._terminate_runner() + await self._join_runner(2) + if not self._runner_is_alive(): + logger.warning(f"That took {i} attempts :)") + break + else: + logger.critical( + "Runner process didn't respond to SIGTERM, killing" + ) + j = 0 + while self._runner_is_alive(): + j += 1 + await self._kill_runner() + await self._join_runner(5) + logger.warning(f"That took {j} attempts :(") + else: + logger.info("Runner process succesfully terminated") + + self.runner_process.close() + + if self._shutdown_requested: + logger.info("Runner supervisor: intentional shutdown, not restarting") + break - self.runner_process.close() + restart_count += 1 + if restart_count > MAX_RESTARTS: + logger.critical( + f"Runner crashed {MAX_RESTARTS} times without recovery, giving up" + ) + break + + delay = min(2.0 * (2 ** (restart_count - 1)), 10.0) + logger.warning( + f"Runner crashed (attempt {restart_count}/{MAX_RESTARTS}), " + f"restarting in {delay:.0f}s" + ) + await anyio.sleep(delay) + self._reset_for_restart() def shutdown(self): + self._shutdown_requested = True + self._tg.cancel_tasks() + + def _reset_for_restart(self) -> None: + """Recreate channels and process for a runner restart after an unexpected crash.""" + ev_send, ev_recv = mp_channel[Event]() + task_sender, task_recv = mp_channel[Task]() + cancel_sender, cancel_recv = mp_channel[TaskId]() + + self.runner_process = mp.Process( + target=entrypoint, + args=(self.bound_instance, ev_send, task_recv, cancel_recv, logger), + daemon=True, + ) + self._ev_recv = ev_recv + self._task_sender = task_sender + self._cancel_sender = cancel_sender + self._tg = TaskGroup() + self._cancel_watch_runner = anyio.CancelScope() + self.status = RunnerIdle() + self.pending = {} + self.in_progress = {} + self.completed = set() + self.cancelled = set() + # _event_sender is intentionally NOT reset — it connects to the rest of exo + # and must remain open across restarts + + def _cancel_tg(self) -> None: + """Cancel the running task group without marking this as an intentional shutdown. + Used by _check_runner to tear down the current run() iteration so the restart + loop can start a fresh runner subprocess. + """ self._tg.cancel_tasks() async def start_task(self, task: Task): @@ -246,17 +371,17 @@ async def _watch_runner(self) -> None: with self._cancel_watch_runner: while True: await anyio.sleep(5) - if not self.runner_process.is_alive(): + if not self._runner_is_alive(): await self._check_runner(RuntimeError("Runner found to be dead")) async def _check_runner(self, e: Exception) -> None: if not self._cancel_watch_runner.cancel_called: self._cancel_watch_runner.cancel() logger.info("Checking runner's status") - if self.runner_process.is_alive(): + if self._runner_is_alive(): logger.info("Runner was found to be alive, attempting to join process") - await to_thread.run_sync(self.runner_process.join, 5) - rc = self.runner_process.exitcode + await self._join_runner(5) + rc = self._runner_exitcode() logger.info(f"Runner exited with exit code {rc}") if rc == 0: return @@ -301,4 +426,4 @@ async def _check_runner(self, e: Exception) -> None: logger.warning( "Event sender already closed, unable to report runner failure" ) - self.shutdown() + self._cancel_tg()