diff --git a/src/exo/main.py b/src/exo/main.py index 9edee42096..e308afc879 100644 --- a/src/exo/main.py +++ b/src/exo/main.py @@ -1,9 +1,11 @@ import argparse +import ipaddress import multiprocessing as mp import os import resource import signal import subprocess +import sys from dataclasses import dataclass, field from datetime import datetime, timezone from typing import Self @@ -44,6 +46,7 @@ class Node: node_id: NodeId offline: bool _api_port: int + _libp2p_port: int _tg: TaskGroup = field(init=False, default_factory=TaskGroup) @classmethod @@ -144,6 +147,7 @@ async def create(cls, args: "Args") -> Self: node_id, args.offline, args.api_port, + args.libp2p_port, ) logger_set_context( node_id=node_id, role="master" if args.force_master else "node" @@ -169,6 +173,12 @@ async def run(self): tg.start_soon(self.master.run) if self.api: tg.start_soon(self.api.run) + if sys.platform == "darwin" and self._libp2p_port != 0: + tg.start_soon( + _darwin_mdns_broadcast_announcer, + self.node_id, + self._libp2p_port, + ) tg.start_soon(self._elect_loop) tg.start_soon(self._diagnostic_snapshot_loop) @@ -361,6 +371,77 @@ def _last_seen_ages(self, state: State) -> dict[str, float]: return ages +def _darwin_en0_ip_address() -> str | None: + try: + return subprocess.check_output( + ["ipconfig", "getifaddr", "en0"], + text=True, + stderr=subprocess.DEVNULL, + ).strip() + except (OSError, subprocess.CalledProcessError): + return None + + +def _darwin_en0_broadcast_address(ip_address: str) -> str | None: + try: + subnet_mask = subprocess.check_output( + ["ipconfig", "getoption", "en0", "subnet_mask"], + text=True, + stderr=subprocess.DEVNULL, + ).strip() + interface = ipaddress.IPv4Interface(f"{ip_address}/{subnet_mask}") + return str(interface.network.broadcast_address) + except (OSError, ValueError, subprocess.CalledProcessError): + return None + + +async def _darwin_mdns_broadcast_announcer( + node_id: NodeId, libp2p_port: int +) -> None: + ip_address = _darwin_en0_ip_address() + if not ip_address: + logger.debug("Darwin mDNS broadcast announcer disabled: no en0 IPv4 address") + return + + broadcast_address = _darwin_en0_broadcast_address(ip_address) + logger.debug( + f"Darwin mDNS announcer advertising {node_id} at {ip_address}:{libp2p_port}" + ) + command = [ + sys.executable, + "-m", + "exo.routing.mdns_announcer", + "--node-id", + str(node_id), + "--ip-address", + ip_address, + "--libp2p-port", + str(libp2p_port), + ] + if broadcast_address is not None: + command.extend(["--broadcast-address", broadcast_address]) + process = subprocess.Popen( + command, + start_new_session=True, + stdout=subprocess.DEVNULL, + ) + try: + while process.poll() is None: + await anyio.sleep(60) + logger.debug( + f"Darwin mDNS announcer subprocess exited with {process.returncode}" + ) + finally: + if process.poll() is None: + process.terminate() + with anyio.move_on_after(2): + while process.poll() is None: + await anyio.sleep(0.1) + if process.poll() is None: + process.kill() + await anyio.sleep(0) + + def main(): args = Args.parse() soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE) diff --git a/src/exo/master/main.py b/src/exo/master/main.py index 888c39e4c8..50c23fa5c9 100644 --- a/src/exo/master/main.py +++ b/src/exo/master/main.py @@ -481,6 +481,9 @@ async def _command_processor(self) -> None: # These plan loops are the cracks showing in our event sourcing architecture - more things could be commands async def _plan(self) -> None: + node_inactivity_timeout = timedelta(seconds=5) + tick_interval_seconds = 1.0 + while True: # kill broken instances connected_node_ids = set(self.state.topology.list_nodes()) @@ -503,7 +506,7 @@ async def _plan(self) -> None: # time out dead nodes for node_id, time in self.state.last_seen.items(): now = datetime.now(tz=timezone.utc) - if now - time > timedelta(seconds=30): + if now - time > node_inactivity_timeout: impacted_instances = [ str(instance_id) for instance_id, instance in self.state.instances.items() @@ -520,7 +523,7 @@ async def _plan(self) -> None: ) await self.event_sender.send(NodeTimedOut(node_id=node_id)) - await anyio.sleep(10) + await anyio.sleep(tick_interval_seconds) async def _event_processor(self) -> None: with self.local_event_receiver as local_events: diff --git a/src/exo/routing/mdns_announcer.py b/src/exo/routing/mdns_announcer.py new file mode 100644 index 0000000000..cafd1d3acc --- /dev/null +++ b/src/exo/routing/mdns_announcer.py @@ -0,0 +1,97 @@ +import argparse +import contextlib +import random +import socket +import string +import struct +import sys +import time +from typing import final + + +def _dns_qname(name: bytes) -> bytes: + return b"".join(bytes([len(part)]) + part for part in name.split(b".")) + b"\0" + + +def _build_response_packet(node_id: str, ip_address: str, libp2p_port: int) -> bytes: + service_name = b"_p2p._udp.local" + peer_name = ( + "".join(random.choice(string.ascii_letters + string.digits) for _ in range(32)) + + "._p2p._udp.local" + ).encode() + txt_record = f"dnsaddr=/ip4/{ip_address}/tcp/{libp2p_port}/p2p/{node_id}".encode() + + peer_qname = _dns_qname(peer_name) + packet = bytearray() + packet += struct.pack("!HHHHHH", 0, 0x8400, 0, 1, 0, 1) + packet += _dns_qname(service_name) + packet += struct.pack("!HHI", 12, 1, 120) + packet += struct.pack("!H", len(peer_qname)) + packet += peer_qname + packet += peer_qname + packet += struct.pack("!HHI", 16, 1, 120) + packet += struct.pack("!H", len(txt_record) + 1) + packet += bytes([len(txt_record)]) + packet += txt_record + return bytes(packet) + + +@final +class Args(argparse.Namespace): + node_id: str + ip_address: str + libp2p_port: int + broadcast_address: str | None + count: int + + @staticmethod + def parse() -> "Args": + parser = argparse.ArgumentParser() + parser.add_argument("--node-id", required=True) + parser.add_argument("--ip-address", required=True) + parser.add_argument("--libp2p-port", required=True, type=int) + parser.add_argument("--broadcast-address") + parser.add_argument("--count", default=0, type=int) + return parser.parse_args(namespace=Args()) + + +def main() -> None: + args = Args.parse() + sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + with contextlib.suppress(OSError): + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1) + sock.bind((args.ip_address, 0)) + + sent_count = 0 + while True: + packet = _build_response_packet( + args.node_id, args.ip_address, args.libp2p_port + ) + errors: list[str] = [] + destinations: list[tuple[str, int]] = [] + if args.broadcast_address is not None: + destinations.append((args.broadcast_address, 5353)) + destinations.extend([("255.255.255.255", 5353), ("224.0.0.251", 5353)]) + sent = False + for destination in destinations: + try: + sock.sendto(packet, destination) + sent = True + except OSError as err: + errors.append(f"{destination}: {err}") + if not sent: + print( + f"mDNS announcer send failed: {'; '.join(errors)}", + file=sys.stderr, + flush=True, + ) + sent_count += 1 + if args.count > 0 and sent_count >= args.count: + return + time.sleep(1.0 if sent_count < 60 else 10.0) + + +if __name__ == "__main__": + main() diff --git a/src/exo/routing/router.py b/src/exo/routing/router.py index ebe0ea8d90..5b679fe192 100644 --- a/src/exo/routing/router.py +++ b/src/exo/routing/router.py @@ -298,8 +298,6 @@ def get_node_id_keypair( Obtains the :class:`Keypair` associated with this node-ID. Obtain the :class:`PeerId` by from it. """ - # TODO(evan): bring back node id persistence once we figure out how to deal with duplicates - return Keypair.generate() def lock_path(path: str | bytes | PathLike[str] | PathLike[bytes]) -> Path: return Path(str(path) + ".lock") diff --git a/src/exo/shared/constants.py b/src/exo/shared/constants.py index d79354184b..ad76a88ffb 100644 --- a/src/exo/shared/constants.py +++ b/src/exo/shared/constants.py @@ -8,12 +8,12 @@ def _get_xdg_dir(env_var: str, fallback: str) -> Path: - """Get XDG directory, prioritising EXO_HOME environment variable if its set. On non-Linux platforms, default to ~/.exo.""" + """Get XDG directory, prioritising EXO_HOME environment variable if its set. On non-Linux platforms, default to ~/.exo. Cache home always prefers .cache/exo""" if _EXO_HOME_ENV is not None: return Path.home() / _EXO_HOME_ENV - if sys.platform != "linux": + if sys.platform != "linux" and env_var != "XDG_CACHE_HOME": return Path.home() / ".exo" xdg_value = os.environ.get(env_var, None) @@ -68,10 +68,9 @@ def _parse_colon_dirs(env_var: str) -> tuple[Path, ...]: # Log files (data/logs or cache) EXO_LOG_DIR = EXO_CACHE_HOME / "exo_log" EXO_LOG = EXO_LOG_DIR / "exo.log" -EXO_TEST_LOG = EXO_CACHE_HOME / "exo_test.log" # Identity (config) -EXO_NODE_ID_KEYPAIR = EXO_CONFIG_HOME / "node_id.keypair" +EXO_NODE_ID_KEYPAIR = EXO_CACHE_HOME / "node_id.keypair" EXO_CONFIG_FILE = EXO_CONFIG_HOME / "config.toml" # libp2p topics for event forwarding diff --git a/src/exo/shared/tests/test_xdg_paths.py b/src/exo/shared/tests/test_xdg_paths.py index f3b82ebffd..dce2c7d7c1 100644 --- a/src/exo/shared/tests/test_xdg_paths.py +++ b/src/exo/shared/tests/test_xdg_paths.py @@ -94,7 +94,27 @@ def test_macos_uses_traditional_paths(): home = Path.home() assert home / ".exo" == constants.EXO_CONFIG_HOME assert home / ".exo" == constants.EXO_DATA_HOME - assert home / ".exo" == constants.EXO_CACHE_HOME + assert home / ".cache" / "exo" == constants.EXO_CACHE_HOME + + +def test_exo_home_env(): + """Test that macOS uses traditional ~/.exo directory.""" + # Remove EXO_HOME to ensure we test the default behavior + env = {k: v for k, v in os.environ.items() if k != "EXO_HOME"} + env["EXO_HOME"] = "/exo" + with ( + mock.patch.dict(os.environ, env, clear=True), + mock.patch.object(sys, "platform", "darwin"), + ): + import importlib + + import exo.shared.constants as constants + + importlib.reload(constants) + + assert Path("/exo") == constants.EXO_CONFIG_HOME + assert Path("/exo") == constants.EXO_DATA_HOME + assert Path("/exo") == constants.EXO_CACHE_HOME def test_node_id_in_config_dir(): diff --git a/src/exo/utils/keyed_backoff.py b/src/exo/utils/keyed_backoff.py index 4d7c9a66ed..a95fe5c5f7 100644 --- a/src/exo/utils/keyed_backoff.py +++ b/src/exo/utils/keyed_backoff.py @@ -29,6 +29,10 @@ def attempts(self, key: K) -> int: """Return the number of recorded attempts for a key.""" return self._attempts.get(key, 0) + def tracked_keys(self) -> set[K]: + """Return keys that currently have recorded backoff state.""" + return set(self._attempts) | set(self._last_time) + def reset(self, key: K) -> None: """Reset backoff state for a key (e.g., on success).""" self._attempts.pop(key, None) diff --git a/src/exo/utils/tests/test_keyed_backoff.py b/src/exo/utils/tests/test_keyed_backoff.py new file mode 100644 index 0000000000..b592a4fabd --- /dev/null +++ b/src/exo/utils/tests/test_keyed_backoff.py @@ -0,0 +1,13 @@ +from exo.utils.keyed_backoff import KeyedBackoff + + +def test_tracked_keys_reports_and_resets_backoff_state() -> None: + backoff = KeyedBackoff[str]() + + backoff.record_attempt("instance-a") + + assert backoff.tracked_keys() == {"instance-a"} + + backoff.reset("instance-a") + + assert backoff.tracked_keys() == set() diff --git a/src/exo/worker/main.py b/src/exo/worker/main.py index b35f946aac..68f329cc7c 100644 --- a/src/exo/worker/main.py +++ b/src/exo/worker/main.py @@ -3,7 +3,7 @@ from datetime import datetime, timezone import anyio -from anyio import fail_after, to_thread +from anyio import fail_after, move_on_after, to_thread from loguru import logger from exo.api.types import ImageEditsTaskParams @@ -109,6 +109,7 @@ async def run(self): tg.start_soon(self._forward_info, info_recv) tg.start_soon(self.plan_step) tg.start_soon(self._event_applier) + tg.start_soon(self._reconcile_instance_backoff) tg.start_soon(self._poll_connection_updates) finally: # Actual shutdown code - waits for all tasks to complete before executing. @@ -179,6 +180,17 @@ async def _event_applier(self): if isinstance(event, CustomModelCardDeleted): await delete_custom_card(event.model_id) + async def _reconcile_instance_backoff(self) -> None: + while True: + await anyio.sleep(1) + self._reconcile_instance_backoff_once() + + def _reconcile_instance_backoff_once(self) -> None: + live_instances = set(self.state.instances) + for instance_id in self._instance_backoff.tracked_keys(): + if instance_id not in live_instances: + self._instance_backoff.reset(instance_id) + async def plan_step(self): while True: await anyio.sleep(0.1) @@ -356,8 +368,16 @@ async def plan_step(self): await self._start_runner_task(task) async def shutdown(self): + self.event_sender.close() + self.command_sender.close() + self.download_command_sender.close() + for runner in self.runners.values(): + runner.shutdown() self._tg.cancel_tasks() - await self._stopped.wait() + with move_on_after(5) as scope: + await self._stopped.wait() + if scope.cancel_called: + logger.warning("Timed out waiting for Worker shutdown") async def _start_runner_task(self, task: Task): if (instance := self.state.instances.get(task.instance_id)) is not None: diff --git a/src/exo/worker/tests/unittests/test_worker_instance_backoff.py b/src/exo/worker/tests/unittests/test_worker_instance_backoff.py new file mode 100644 index 0000000000..b0052c1eb7 --- /dev/null +++ b/src/exo/worker/tests/unittests/test_worker_instance_backoff.py @@ -0,0 +1,36 @@ +# pyright: reportPrivateUsage=false + +from exo.shared.types.common import ModelId, NodeId +from exo.shared.types.state import State +from exo.shared.types.worker.instances import InstanceId, MlxRingInstance +from exo.shared.types.worker.runners import ShardAssignments +from exo.utils.keyed_backoff import KeyedBackoff +from exo.worker.main import Worker + + +def _make_instance(instance_id: InstanceId) -> MlxRingInstance: + return MlxRingInstance( + instance_id=instance_id, + shard_assignments=ShardAssignments( + model_id=ModelId("test-model"), + node_to_runner={}, + runner_to_shard={}, + ), + hosts_by_node={NodeId("node-1"): []}, + ephemeral_port=1, + ) + + +def test_worker_reconciles_instance_backoff_from_state() -> None: + live_instance_id = InstanceId("inst-live") + deleted_instance_id = InstanceId("inst-deleted") + worker = object.__new__(Worker) + worker.state = State(instances={live_instance_id: _make_instance(live_instance_id)}) + worker._instance_backoff = KeyedBackoff[InstanceId]() + worker._instance_backoff.record_attempt(live_instance_id) + worker._instance_backoff.record_attempt(deleted_instance_id) + + worker._reconcile_instance_backoff_once() + + assert worker._instance_backoff.attempts(live_instance_id) == 1 + assert worker._instance_backoff.attempts(deleted_instance_id) == 0