Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 81 additions & 0 deletions src/exo/main.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import argparse
import ipaddress
import multiprocessing as mp
import os
import resource
import signal
import subprocess
import sys
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import Self
Expand Down Expand Up @@ -44,6 +46,7 @@ class Node:
node_id: NodeId
offline: bool
_api_port: int
_libp2p_port: int
_tg: TaskGroup = field(init=False, default_factory=TaskGroup)

@classmethod
Expand Down Expand Up @@ -144,6 +147,7 @@ async def create(cls, args: "Args") -> Self:
node_id,
args.offline,
args.api_port,
args.libp2p_port,
)
logger_set_context(
node_id=node_id, role="master" if args.force_master else "node"
Expand All @@ -169,6 +173,12 @@ async def run(self):
tg.start_soon(self.master.run)
if self.api:
tg.start_soon(self.api.run)
if sys.platform == "darwin" and self._libp2p_port != 0:
tg.start_soon(
_darwin_mdns_broadcast_announcer,
self.node_id,
self._libp2p_port,
)
tg.start_soon(self._elect_loop)
tg.start_soon(self._diagnostic_snapshot_loop)

Expand Down Expand Up @@ -361,6 +371,77 @@ def _last_seen_ages(self, state: State) -> dict[str, float]:
return ages


def _darwin_en0_ip_address() -> str | None:
    """Return the IPv4 address reported by ``ipconfig getifaddr en0``.

    Returns:
        The trimmed command output, or ``None`` when the command is missing,
        exits non-zero, does not finish within the timeout, or prints nothing.
    """
    try:
        output = subprocess.check_output(
            ["ipconfig", "getifaddr", "en0"],
            text=True,
            stderr=subprocess.DEVNULL,
            # This helper is invoked from an async task; bound the blocking
            # call so a wedged ``ipconfig`` cannot stall the caller forever.
            timeout=5,
        )
    except (OSError, subprocess.SubprocessError):
        # SubprocessError covers both CalledProcessError and TimeoutExpired.
        return None
    # Normalise empty output to None so the return type means what it says.
    return output.strip() or None


def _darwin_en0_broadcast_address(ip_address: str) -> str | None:
    """Compute the directed broadcast address of en0's IPv4 network.

    The subnet mask is read with ``ipconfig getoption en0 subnet_mask`` and
    combined with ``ip_address`` to derive the network's broadcast address.

    Args:
        ip_address: IPv4 address assigned to en0, in dotted-quad form.

    Returns:
        The broadcast address as a string, or ``None`` when the mask cannot be
        obtained (command missing, failing, or timing out) or when the
        address/mask pair does not form a valid IPv4 interface.
    """
    try:
        subnet_mask = subprocess.check_output(
            ["ipconfig", "getoption", "en0", "subnet_mask"],
            text=True,
            stderr=subprocess.DEVNULL,
            # Bound the blocking call; this runs on the event-loop thread.
            timeout=5,
        ).strip()
        interface = ipaddress.IPv4Interface(f"{ip_address}/{subnet_mask}")
    except (OSError, ValueError, subprocess.SubprocessError):
        # ValueError covers the ipaddress parsing errors; SubprocessError
        # covers both CalledProcessError and TimeoutExpired.
        return None
    return str(interface.network.broadcast_address)


async def _darwin_mdns_broadcast_announcer(
    node_id: NodeId, libp2p_port: int
) -> None:
    """Run the ``exo.routing.mdns_announcer`` helper process for this node.

    Looks up en0's IPv4 address (and, when available, its broadcast address),
    spawns the announcer module as a child process, and then waits until that
    child exits or this task is cancelled. On the way out, a still-running
    child is terminated, given a short grace period, and killed if needed.

    Args:
        node_id: Identifier advertised by the announcer.
        libp2p_port: TCP port advertised alongside the en0 address.
    """
    ip_address = _darwin_en0_ip_address()
    if not ip_address:
        logger.debug("Darwin mDNS broadcast announcer disabled: no en0 IPv4 address")
        return

    broadcast_address = _darwin_en0_broadcast_address(ip_address)
    logger.debug(
        f"Darwin mDNS announcer advertising {node_id} at {ip_address}:{libp2p_port}"
    )
    command = [
        sys.executable,
        "-m",
        "exo.routing.mdns_announcer",
        "--node-id",
        str(node_id),
        "--ip-address",
        ip_address,
        "--libp2p-port",
        str(libp2p_port),
    ]
    if broadcast_address is not None:
        command.extend(["--broadcast-address", broadcast_address])
    # The child runs in its own session, so nothing else will stop it for us
    # if the cleanup below is skipped.
    process = subprocess.Popen(
        command,
        start_new_session=True,
        stdout=subprocess.DEVNULL,
    )
    try:
        while process.poll() is None:
            await anyio.sleep(60)
        logger.debug(
            f"Darwin mDNS announcer subprocess exited with {process.returncode}"
        )
    finally:
        if process.poll() is None:
            process.terminate()
            # This block usually runs because the task was cancelled. Without
            # shielding, the first checkpoint would re-raise the cancellation
            # and neither the grace period nor the kill below would happen.
            with anyio.move_on_after(2, shield=True):
                while process.poll() is None:
                    await anyio.sleep(0.1)
            if process.poll() is None:
                process.kill()
                # Reap the killed child so it does not linger as a zombie.
                with anyio.move_on_after(2, shield=True):
                    while process.poll() is None:
                        await anyio.sleep(0.1)


def main():
args = Args.parse()
soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
Expand Down
7 changes: 5 additions & 2 deletions src/exo/master/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -481,6 +481,9 @@ async def _command_processor(self) -> None:

# These plan loops are the cracks showing in our event sourcing architecture - more things could be commands
async def _plan(self) -> None:
node_inactivity_timeout = timedelta(seconds=5)
tick_interval_seconds = 1.0

while True:
# kill broken instances
connected_node_ids = set(self.state.topology.list_nodes())
Expand All @@ -503,7 +506,7 @@ async def _plan(self) -> None:
# time out dead nodes
for node_id, time in self.state.last_seen.items():
now = datetime.now(tz=timezone.utc)
if now - time > timedelta(seconds=30):
if now - time > node_inactivity_timeout:
impacted_instances = [
str(instance_id)
for instance_id, instance in self.state.instances.items()
Expand All @@ -520,7 +523,7 @@ async def _plan(self) -> None:
)
await self.event_sender.send(NodeTimedOut(node_id=node_id))

await anyio.sleep(10)
await anyio.sleep(tick_interval_seconds)

async def _event_processor(self) -> None:
with self.local_event_receiver as local_events:
Expand Down
97 changes: 97 additions & 0 deletions src/exo/routing/mdns_announcer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
import argparse
import contextlib
import random
import socket
import string
import struct
import sys
import time
from typing import final


def _dns_qname(name: bytes) -> bytes:
return b"".join(bytes([len(part)]) + part for part in name.split(b".")) + b"\0"


def _build_response_packet(node_id: str, ip_address: str, libp2p_port: int) -> bytes:
    """Build one unsolicited DNS response advertising this peer.

    The packet carries a PTR answer for ``_p2p._udp.local`` pointing at a
    freshly generated 32-character peer name, plus one additional TXT record
    on that peer name holding a ``dnsaddr=`` multiaddr.

    Args:
        node_id: Peer identifier placed at the end of the multiaddr.
        ip_address: IPv4 address placed in the multiaddr.
        libp2p_port: TCP port placed in the multiaddr.

    Returns:
        The complete packet as bytes.
    """
    alphabet = string.ascii_letters + string.digits
    # A new random instance label is drawn on every call.
    instance_label = "".join(random.choice(alphabet) for _ in range(32))
    peer_qname = _dns_qname((instance_label + "._p2p._udp.local").encode())
    service_qname = _dns_qname(b"_p2p._udp.local")

    txt_payload = f"dnsaddr=/ip4/{ip_address}/tcp/{libp2p_port}/p2p/{node_id}".encode()
    # TXT rdata is a length-prefixed character string.
    txt_rdata = bytes([len(txt_payload)]) + txt_payload

    # Header: id 0, flags 0x8400 (response + authoritative), 1 answer, 1 additional.
    header = struct.pack("!HHHHHH", 0, 0x8400, 0, 1, 0, 1)
    # Answer: PTR (type 12), class IN (1), TTL 120, rdata = peer name.
    ptr_record = (
        service_qname
        + struct.pack("!HHIH", 12, 1, 120, len(peer_qname))
        + peer_qname
    )
    # Additional: TXT (type 16), class IN (1), TTL 120.
    txt_record = (
        peer_qname
        + struct.pack("!HHIH", 16, 1, 120, len(txt_rdata))
        + txt_rdata
    )
    return header + ptr_record + txt_record


@final
class Args(argparse.Namespace):
node_id: str
ip_address: str
libp2p_port: int
broadcast_address: str | None
count: int

@staticmethod
def parse() -> "Args":
parser = argparse.ArgumentParser()
parser.add_argument("--node-id", required=True)
parser.add_argument("--ip-address", required=True)
parser.add_argument("--libp2p-port", required=True, type=int)
parser.add_argument("--broadcast-address")
parser.add_argument("--count", default=0, type=int)
return parser.parse_args(namespace=Args())


def main() -> None:
    """Repeatedly send the peer announcement over UDP to port 5353.

    Each round builds a fresh packet and sends it to the optional
    ``--broadcast-address``, the limited broadcast address, and the mDNS
    multicast group. A round in which every send fails is reported on stderr.
    Rounds run once per second until 60 have been sent, then every ten
    seconds. With ``--count N`` (N > 0) the function returns after N rounds;
    otherwise it loops forever.
    """
    args = Args.parse()

    # The destination set never changes, so build it once up front.
    destinations: list[tuple[str, int]] = []
    if args.broadcast_address is not None:
        destinations.append((args.broadcast_address, 5353))
    destinations.extend([("255.255.255.255", 5353), ("224.0.0.251", 5353)])

    # The context manager closes the socket on every exit path, including
    # the early return for --count and any unexpected exception.
    with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock:
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1)
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        # Best effort: not every platform accepts SO_REUSEPORT.
        with contextlib.suppress(OSError):
            sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
        # Bind to the advertised address with an OS-chosen source port.
        sock.bind((args.ip_address, 0))

        sent_count = 0
        while True:
            packet = _build_response_packet(
                args.node_id, args.ip_address, args.libp2p_port
            )
            errors: list[str] = []
            sent = False
            for destination in destinations:
                try:
                    sock.sendto(packet, destination)
                    sent = True
                except OSError as err:
                    errors.append(f"{destination}: {err}")
            # Partial failures are tolerated; only a fully failed round is reported.
            if not sent:
                print(
                    f"mDNS announcer send failed: {'; '.join(errors)}",
                    file=sys.stderr,
                    flush=True,
                )
            # Counts rounds attempted, whether or not any send succeeded.
            sent_count += 1
            if args.count > 0 and sent_count >= args.count:
                return
            time.sleep(1.0 if sent_count < 60 else 10.0)


if __name__ == "__main__":
main()
2 changes: 0 additions & 2 deletions src/exo/routing/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,8 +298,6 @@ def get_node_id_keypair(
Obtains the :class:`Keypair` associated with this node-ID.
Obtain the :class:`PeerId` from it.
"""
# TODO(evan): bring back node id persistence once we figure out how to deal with duplicates
return Keypair.generate()

def lock_path(path: str | bytes | PathLike[str] | PathLike[bytes]) -> Path:
return Path(str(path) + ".lock")
Expand Down
7 changes: 3 additions & 4 deletions src/exo/shared/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@


def _get_xdg_dir(env_var: str, fallback: str) -> Path:
"""Get XDG directory, prioritising EXO_HOME environment variable if its set. On non-Linux platforms, default to ~/.exo."""
"""Get XDG directory, prioritising the EXO_HOME environment variable if it is set. On non-Linux platforms, default to ~/.exo. Cache home always prefers .cache/exo."""

if _EXO_HOME_ENV is not None:
return Path.home() / _EXO_HOME_ENV

if sys.platform != "linux":
if sys.platform != "linux" and env_var != "XDG_CACHE_HOME":
return Path.home() / ".exo"

xdg_value = os.environ.get(env_var, None)
Expand Down Expand Up @@ -68,10 +68,9 @@ def _parse_colon_dirs(env_var: str) -> tuple[Path, ...]:
# Log files (data/logs or cache)
EXO_LOG_DIR = EXO_CACHE_HOME / "exo_log"
EXO_LOG = EXO_LOG_DIR / "exo.log"
EXO_TEST_LOG = EXO_CACHE_HOME / "exo_test.log"

# Identity (config)
EXO_NODE_ID_KEYPAIR = EXO_CONFIG_HOME / "node_id.keypair"
EXO_NODE_ID_KEYPAIR = EXO_CACHE_HOME / "node_id.keypair"
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Preserve legacy node-id key path when enabling persistence

Changing EXO_NODE_ID_KEYPAIR from config to cache without a fallback migration means any installation that already has node_id.keypair in the previous config location will generate a new peer identity after upgrade. That breaks node identity continuity across restarts/upgrades and can invalidate cluster membership/reconnect behavior for existing deployments. Load the legacy config-path key when the new cache-path key is absent, then migrate it.

Useful? React with 👍 / 👎.

EXO_CONFIG_FILE = EXO_CONFIG_HOME / "config.toml"

# libp2p topics for event forwarding
Expand Down
22 changes: 21 additions & 1 deletion src/exo/shared/tests/test_xdg_paths.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,27 @@ def test_macos_uses_traditional_paths():
home = Path.home()
assert home / ".exo" == constants.EXO_CONFIG_HOME
assert home / ".exo" == constants.EXO_DATA_HOME
assert home / ".exo" == constants.EXO_CACHE_HOME
assert home / ".cache" / "exo" == constants.EXO_CACHE_HOME


def test_exo_home_env():
    """Test that EXO_HOME overrides the config, data and cache directories."""
    # Force EXO_HOME to an absolute path; every other variable is inherited.
    env = {**os.environ, "EXO_HOME": "/exo"}
    with (
        mock.patch.dict(os.environ, env, clear=True),
        mock.patch.object(sys, "platform", "darwin"),
    ):
        import importlib

        import exo.shared.constants as constants

        # Reload so the module-level paths are recomputed under the patches.
        importlib.reload(constants)

        assert Path("/exo") == constants.EXO_CONFIG_HOME
        assert Path("/exo") == constants.EXO_DATA_HOME
        assert Path("/exo") == constants.EXO_CACHE_HOME


def test_node_id_in_config_dir():
Expand Down
4 changes: 4 additions & 0 deletions src/exo/utils/keyed_backoff.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@ def attempts(self, key: K) -> int:
"""Return the number of recorded attempts for a key."""
return self._attempts.get(key, 0)

def tracked_keys(self) -> set[K]:
    """Return the union of keys present in ``_attempts`` and ``_last_time``."""
    keys = set(self._attempts)
    keys.update(self._last_time)
    return keys

def reset(self, key: K) -> None:
"""Reset backoff state for a key (e.g., on success)."""
self._attempts.pop(key, None)
Expand Down
13 changes: 13 additions & 0 deletions src/exo/utils/tests/test_keyed_backoff.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from exo.utils.keyed_backoff import KeyedBackoff


def test_tracked_keys_reports_and_resets_backoff_state() -> None:
    """A key appears in tracked_keys() after an attempt and is gone after reset()."""
    key = "instance-a"
    tracker = KeyedBackoff[str]()

    tracker.record_attempt(key)
    tracked_after_attempt = tracker.tracked_keys()
    assert tracked_after_attempt == {key}

    tracker.reset(key)
    tracked_after_reset = tracker.tracked_keys()
    assert tracked_after_reset == set()
24 changes: 22 additions & 2 deletions src/exo/worker/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from datetime import datetime, timezone

import anyio
from anyio import fail_after, to_thread
from anyio import fail_after, move_on_after, to_thread
from loguru import logger

from exo.api.types import ImageEditsTaskParams
Expand Down Expand Up @@ -109,6 +109,7 @@ async def run(self):
tg.start_soon(self._forward_info, info_recv)
tg.start_soon(self.plan_step)
tg.start_soon(self._event_applier)
tg.start_soon(self._reconcile_instance_backoff)
tg.start_soon(self._poll_connection_updates)
finally:
# Actual shutdown code - waits for all tasks to complete before executing.
Expand Down Expand Up @@ -179,6 +180,17 @@ async def _event_applier(self):
if isinstance(event, CustomModelCardDeleted):
await delete_custom_card(event.model_id)

async def _reconcile_instance_backoff(self) -> None:
    """Run one backoff reconciliation pass per second, forever."""
    interval_seconds = 1
    while True:
        # Sleep first, so the first pass happens one interval after start.
        await anyio.sleep(interval_seconds)
        self._reconcile_instance_backoff_once()

def _reconcile_instance_backoff_once(self) -> None:
live_instances = set(self.state.instances)
for instance_id in self._instance_backoff.tracked_keys():
if instance_id not in live_instances:
self._instance_backoff.reset(instance_id)

async def plan_step(self):
while True:
await anyio.sleep(0.1)
Expand Down Expand Up @@ -356,8 +368,16 @@ async def plan_step(self):
await self._start_runner_task(task)

async def shutdown(self):
self.event_sender.close()
self.command_sender.close()
self.download_command_sender.close()
for runner in self.runners.values():
runner.shutdown()
Comment on lines +371 to +375
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Delay sender closure until worker tasks finish shutting down

Closing event_sender/command_sender before canceling worker tasks and runner supervisors can drop terminal status events during shutdown. RunnerSupervisor attempts to emit failure/completion updates while winding down, but those sends will raise ClosedResourceError once this early close runs, so the master can keep stale running state until timeout-based cleanup. Cancel and drain runner/task shutdown first, then close senders after _stopped is set.

Useful? React with 👍 / 👎.

self._tg.cancel_tasks()
await self._stopped.wait()
with move_on_after(5) as scope:
await self._stopped.wait()
if scope.cancel_called:
logger.warning("Timed out waiting for Worker shutdown")

async def _start_runner_task(self, task: Task):
if (instance := self.state.instances.get(task.instance_id)) is not None:
Expand Down
Loading