-
Notifications
You must be signed in to change notification settings - Fork 0
Cluster discovery + lifecycle stability fixes #17
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,97 @@ | ||
| import argparse | ||
| import contextlib | ||
| import random | ||
| import socket | ||
| import string | ||
| import struct | ||
| import sys | ||
| import time | ||
| from typing import final | ||
|
|
||
|
|
||
| def _dns_qname(name: bytes) -> bytes: | ||
| return b"".join(bytes([len(part)]) + part for part in name.split(b".")) + b"\0" | ||
|
|
||
|
|
||
| def _build_response_packet(node_id: str, ip_address: str, libp2p_port: int) -> bytes: | ||
| service_name = b"_p2p._udp.local" | ||
| peer_name = ( | ||
| "".join(random.choice(string.ascii_letters + string.digits) for _ in range(32)) | ||
| + "._p2p._udp.local" | ||
| ).encode() | ||
| txt_record = f"dnsaddr=/ip4/{ip_address}/tcp/{libp2p_port}/p2p/{node_id}".encode() | ||
|
|
||
| peer_qname = _dns_qname(peer_name) | ||
| packet = bytearray() | ||
| packet += struct.pack("!HHHHHH", 0, 0x8400, 0, 1, 0, 1) | ||
| packet += _dns_qname(service_name) | ||
| packet += struct.pack("!HHI", 12, 1, 120) | ||
| packet += struct.pack("!H", len(peer_qname)) | ||
| packet += peer_qname | ||
| packet += peer_qname | ||
| packet += struct.pack("!HHI", 16, 1, 120) | ||
| packet += struct.pack("!H", len(txt_record) + 1) | ||
| packet += bytes([len(txt_record)]) | ||
| packet += txt_record | ||
| return bytes(packet) | ||
|
|
||
|
|
||
| @final | ||
| class Args(argparse.Namespace): | ||
| node_id: str | ||
| ip_address: str | ||
| libp2p_port: int | ||
| broadcast_address: str | None | ||
| count: int | ||
|
|
||
| @staticmethod | ||
| def parse() -> "Args": | ||
| parser = argparse.ArgumentParser() | ||
| parser.add_argument("--node-id", required=True) | ||
| parser.add_argument("--ip-address", required=True) | ||
| parser.add_argument("--libp2p-port", required=True, type=int) | ||
| parser.add_argument("--broadcast-address") | ||
| parser.add_argument("--count", default=0, type=int) | ||
| return parser.parse_args(namespace=Args()) | ||
|
|
||
|
|
||
| def main() -> None: | ||
| args = Args.parse() | ||
| sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) | ||
| sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1) | ||
| sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) | ||
| with contextlib.suppress(OSError): | ||
| sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1) | ||
| sock.bind((args.ip_address, 0)) | ||
|
|
||
| sent_count = 0 | ||
| while True: | ||
| packet = _build_response_packet( | ||
| args.node_id, args.ip_address, args.libp2p_port | ||
| ) | ||
| errors: list[str] = [] | ||
| destinations: list[tuple[str, int]] = [] | ||
| if args.broadcast_address is not None: | ||
| destinations.append((args.broadcast_address, 5353)) | ||
| destinations.extend([("255.255.255.255", 5353), ("224.0.0.251", 5353)]) | ||
| sent = False | ||
| for destination in destinations: | ||
| try: | ||
| sock.sendto(packet, destination) | ||
| sent = True | ||
| except OSError as err: | ||
| errors.append(f"{destination}: {err}") | ||
| if not sent: | ||
| print( | ||
| f"mDNS announcer send failed: {'; '.join(errors)}", | ||
| file=sys.stderr, | ||
| flush=True, | ||
| ) | ||
| sent_count += 1 | ||
| if args.count > 0 and sent_count >= args.count: | ||
| return | ||
| time.sleep(1.0 if sent_count < 60 else 10.0) | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| main() |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,13 @@ | ||
| from exo.utils.keyed_backoff import KeyedBackoff | ||
|
|
||
|
|
||
| def test_tracked_keys_reports_and_resets_backoff_state() -> None: | ||
| backoff = KeyedBackoff[str]() | ||
|
|
||
| backoff.record_attempt("instance-a") | ||
|
|
||
| assert backoff.tracked_keys() == {"instance-a"} | ||
|
|
||
| backoff.reset("instance-a") | ||
|
|
||
| assert backoff.tracked_keys() == set() |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -3,7 +3,7 @@ | |
| from datetime import datetime, timezone | ||
|
|
||
| import anyio | ||
| from anyio import fail_after, to_thread | ||
| from anyio import fail_after, move_on_after, to_thread | ||
| from loguru import logger | ||
|
|
||
| from exo.api.types import ImageEditsTaskParams | ||
|
|
@@ -109,6 +109,7 @@ async def run(self): | |
| tg.start_soon(self._forward_info, info_recv) | ||
| tg.start_soon(self.plan_step) | ||
| tg.start_soon(self._event_applier) | ||
| tg.start_soon(self._reconcile_instance_backoff) | ||
| tg.start_soon(self._poll_connection_updates) | ||
| finally: | ||
| # Actual shutdown code - waits for all tasks to complete before executing. | ||
|
|
@@ -179,6 +180,17 @@ async def _event_applier(self): | |
| if isinstance(event, CustomModelCardDeleted): | ||
| await delete_custom_card(event.model_id) | ||
|
|
||
| async def _reconcile_instance_backoff(self) -> None: | ||
| while True: | ||
| await anyio.sleep(1) | ||
| self._reconcile_instance_backoff_once() | ||
|
|
||
| def _reconcile_instance_backoff_once(self) -> None: | ||
| live_instances = set(self.state.instances) | ||
| for instance_id in self._instance_backoff.tracked_keys(): | ||
| if instance_id not in live_instances: | ||
| self._instance_backoff.reset(instance_id) | ||
|
|
||
| async def plan_step(self): | ||
| while True: | ||
| await anyio.sleep(0.1) | ||
|
|
@@ -356,8 +368,16 @@ async def plan_step(self): | |
| await self._start_runner_task(task) | ||
|
|
||
| async def shutdown(self): | ||
| self.event_sender.close() | ||
| self.command_sender.close() | ||
| self.download_command_sender.close() | ||
| for runner in self.runners.values(): | ||
| runner.shutdown() | ||
|
Comment on lines
+371
to
+375
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Closing Useful? React with 👍 / 👎. |
||
| self._tg.cancel_tasks() | ||
| await self._stopped.wait() | ||
| with move_on_after(5) as scope: | ||
| await self._stopped.wait() | ||
| if scope.cancel_called: | ||
| logger.warning("Timed out waiting for Worker shutdown") | ||
|
|
||
| async def _start_runner_task(self, task: Task): | ||
| if (instance := self.state.instances.get(task.instance_id)) is not None: | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Changing
EXO_NODE_ID_KEYPAIRfrom config to cache without a fallback migration means any installation that already hasnode_id.keypairin the previous config location will generate a new peer identity after upgrade. That breaks node identity continuity across restarts/upgrades and can invalidate cluster membership/reconnect behavior for existing deployments. Load the legacy config-path key when the new cache-path key is absent, then migrate it.Useful? React with 👍 / 👎.