diff --git a/src/exo/api/main.py b/src/exo/api/main.py index e7b61bf0fb..4ac9f605e2 100644 --- a/src/exo/api/main.py +++ b/src/exo/api/main.py @@ -730,6 +730,7 @@ async def get_placement( topology=self.state.topology, current_instances=self.state.instances, download_status=self.state.downloads, + node_rdma_ctl=self.state.node_rdma_ctl, ) except ValueError as exc: raise HTTPException(status_code=400, detail=str(exc)) from exc @@ -794,6 +795,7 @@ async def get_placement_previews( allowed_nodes=allowed_nodes, allow_single_node_total_memory=allowed_nodes is not None, download_status=self.state.downloads, + node_rdma_ctl=self.state.node_rdma_ctl, ) except ValueError as exc: if (model_card.model_id, sharding, instance_meta, 0) not in seen: diff --git a/src/exo/download/coordinator.py b/src/exo/download/coordinator.py index 1f5a99f8ef..e566df1640 100644 --- a/src/exo/download/coordinator.py +++ b/src/exo/download/coordinator.py @@ -15,6 +15,7 @@ map_repo_download_progress_to_download_progress_data, resolve_existing_model, ) +from exo.download.peer_shard_downloader import PeerAwareShardDownloader from exo.download.shard_downloader import ShardDownloader from exo.shared.constants import EXO_DEFAULT_MODELS_DIR, EXO_MODELS_READ_ONLY_DIRS from exo.shared.models.model_cards import ModelCard, ModelId, get_model_cards @@ -225,7 +226,15 @@ async def _command_processor(self) -> None: continue match cmd.command: - case StartDownload(shard_metadata=shard): + case StartDownload(shard_metadata=shard, available_peers=peers): + # Pass peer endpoints to the shard downloader if it supports it + if isinstance(self.shard_downloader, PeerAwareShardDownloader): + self.shard_downloader.set_available_peers(shard, peers) + elif hasattr(self.shard_downloader, "shard_downloader") and isinstance( + self.shard_downloader.shard_downloader, PeerAwareShardDownloader # type: ignore[union-attr] + ): + # Unwrap SingletonShardDownloader + self.shard_downloader.shard_downloader.set_available_peers(shard, 
async def _write_partial_meta(
    partial_path: Path, safe_bytes: int, total: int, etag: str
) -> None:
    """Write the companion ``.partial.meta`` file for peer download streaming.

    The file is a small JSON document with three keys: ``safe_bytes`` (how
    many bytes of the ``.partial`` file have been flushed and may be served),
    ``total`` (the full size of the file being downloaded) and ``etag``.

    Args:
        partial_path: Path of the ``.partial`` file the metadata describes.
        safe_bytes: Number of bytes already flushed to ``partial_path``.
        total: Expected final size of the file in bytes.
        etag: Remote hash / etag of the file being downloaded.
    """
    meta_path = Path(f"{partial_path}.meta")
    tmp_path = Path(f"{partial_path}.meta.tmp")
    meta = json.dumps({"safe_bytes": safe_bytes, "total": total, "etag": etag})
    # Write to a temp file first so readers never observe a half-written JSON.
    async with aiofiles.open(tmp_path, "w") as f:
        await f.write(meta)
    # This function runs once per downloaded chunk, so the destination
    # usually already exists.  ``replace`` overwrites it atomically on every
    # platform, whereas ``rename`` raises ``FileExistsError`` on Windows.
    await aios.replace(tmp_path, meta_path)
async def download_file_from_peer(
    peer_host: str,
    peer_port: int,
    model_id_normalized: str,
    file_path: str,
    target_dir: Path,
    expected_size: int,
    on_progress: Callable[[int, int, bool], None] = lambda _a, _b, _c: None,
    max_poll_attempts: int = 60,
    poll_interval: float = 3.0,
) -> Path | None:
    """Download a single file from a peer's file server.

    Supports streaming relay: if the peer is still downloading the file,
    we fetch the bytes it has, wait, and poll for more until the file is
    complete.  An existing ``.partial`` file is resumed with a ``Range``
    request.

    Args:
        peer_host: Host name or IP of the peer file server.
        peer_port: TCP port of the peer file server.
        model_id_normalized: Model identifier as used in the peer's URLs.
        file_path: Path of the file relative to the model directory.
        target_dir: Local model directory to write into.
        expected_size: Exact number of bytes the finished file must have.
        on_progress: Called as ``(bytes_so_far, expected_size, finished)``.
        max_poll_attempts: Consecutive polls without new data before giving up.
        poll_interval: Seconds to sleep between such polls.

    Returns:
        The final file path on success, or ``None`` on failure (caller
        should fall back to HuggingFace).
    """
    target_path = target_dir / file_path
    partial_path = target_dir / f"{file_path}.partial"

    # A correctly sized local copy needs no transfer at all.
    if await aios.path.exists(target_path):
        local_size = (await aios.stat(target_path)).st_size
        if local_size == expected_size:
            on_progress(expected_size, expected_size, True)
            return target_path

    await aios.makedirs(target_path.parent, exist_ok=True)

    url = f"http://{peer_host}:{peer_port}/files/{model_id_normalized}/{file_path}"
    n_read = 0

    # Resume from an existing partial, unless it is larger than the file can
    # legitimately be.  An oversized partial would skip the download loop and
    # be renamed into place unverified, so drop it and start from zero.
    if await aios.path.exists(partial_path):
        existing_size = (await aios.stat(partial_path)).st_size
        if existing_size > expected_size:
            logger.warning(
                f"Discarding stale oversized peer partial for {file_path} "
                f"({existing_size} > expected {expected_size}); "
                "restarting download from zero"
            )
            await aios.remove(partial_path)
        else:
            n_read = existing_size

    poll_count = 0
    # Number of times the partial was thrown away because the peer violated
    # the protocol.  Bounded so a misbehaving peer cannot keep us looping.
    restarts = 0
    max_restarts = 3
    chunk_size = 8 * 1024 * 1024  # 8MB, matching HF download
    # No overall deadline: large weight files legitimately take longer than
    # any fixed total.  A stalled peer is still detected by ``sock_read``.
    timeout = aiohttp.ClientTimeout(total=None, sock_connect=30, sock_read=60)

    try:
        # One session for every poll so the connection can be reused.
        async with aiohttp.ClientSession(timeout=timeout) as session:
            while n_read < expected_size and poll_count < max_poll_attempts:
                range_was_requested = n_read > 0
                headers: dict[str, str] = (
                    {"Range": f"bytes={n_read}-"} if range_was_requested else {}
                )
                got_bytes = False
                discard_partial = False

                async with session.get(url, headers=headers) as r:
                    if r.status == 416:
                        # Range not satisfiable - peer doesn't have more yet
                        pass
                    elif range_was_requested and r.status == 200:
                        # The peer ignored our Range header; appending the
                        # full body would duplicate the prefix we already
                        # have.
                        logger.warning(
                            f"Peer {peer_host} ignored Range header for "
                            f"{file_path} (returned 200 instead of 206); "
                            "discarding partial and restarting from zero"
                        )
                        discard_partial = True
                    elif r.status in (200, 206):
                        async with aiofiles.open(
                            partial_path, "ab" if n_read > 0 else "wb"
                        ) as f:
                            while True:
                                remaining = expected_size - n_read
                                if remaining <= 0:
                                    # Budget exhausted: probe one byte to
                                    # detect a peer sending more than the
                                    # advertised size.
                                    if await r.content.read(1):
                                        discard_partial = True
                                    break
                                chunk = await r.content.read(
                                    min(chunk_size, remaining)
                                )
                                if not chunk:
                                    break
                                n_read += await f.write(chunk)
                                got_bytes = True
                                on_progress(n_read, expected_size, False)
                        if discard_partial:
                            logger.warning(
                                f"Peer {peer_host} returned oversized response for "
                                f"{file_path} (advertised {expected_size} bytes, "
                                "stream still had data when budget was exhausted); "
                                "discarding partial and restarting from zero"
                            )
                    elif r.status == 404:
                        logger.debug(f"File {file_path} not found on peer {peer_host}")
                        return None
                    else:
                        logger.warning(
                            f"Unexpected status {r.status} from peer {peer_host}"
                        )
                        return None

                if discard_partial:
                    await aios.remove(partial_path)
                    n_read = 0
                    # Discarded bytes are not progress: without this the poll
                    # counter would reset and a peer that always over-supplies
                    # would keep us downloading forever.
                    got_bytes = False
                    restarts += 1
                    if restarts > max_restarts:
                        logger.warning(
                            f"Peer {peer_host} repeatedly violated the transfer "
                            f"protocol for {file_path}; giving up on this peer"
                        )
                        return None

                # Check if we're done
                if n_read >= expected_size:
                    break

                if not got_bytes:
                    # No new bytes: the peer might still be downloading.
                    poll_count += 1
                    logger.debug(
                        f"Waiting for peer {peer_host} to download more of {file_path} "
                        f"({n_read}/{expected_size}, poll {poll_count}/{max_poll_attempts})"
                    )
                    await asyncio.sleep(poll_interval)
                else:
                    # Got data, reset poll counter
                    poll_count = 0

        if n_read < expected_size:
            logger.warning(
                f"Peer download incomplete for {file_path}: {n_read}/{expected_size}"
            )
            return None

        # A zero-byte file never enters the loop above, so its partial may
        # not exist yet; create it so the rename below has something to move.
        if expected_size == 0 and not await aios.path.exists(partial_path):
            async with aiofiles.open(partial_path, "wb"):
                pass

        # Rename partial to final
        await aios.rename(partial_path, target_path)
        on_progress(expected_size, expected_size, True)
        logger.info(
            f"Downloaded {file_path} from peer {peer_host} ({expected_size} bytes)"
        )
        return target_path

    except Exception as e:
        logger.warning(f"Peer download failed for {file_path} from {peer_host}: {e}")
        return None
class PeerFileServer:
    """HTTP server that exposes local model files for peer download.

    Routes:
        ``GET /status/{model_id}``            - list files held for a model.
        ``GET /files/{model_id}/{file_path}`` - stream one file (Range aware).
        ``GET /health``                       - liveness probe.

    Every directory in ``models_dirs`` is searched, in the order given by the
    caller, so a model stored in any configured root can be served.
    """

    def __init__(self, host: str, port: int, models_dirs: Sequence[Path]) -> None:
        """Configure routes; the listener is only opened by :meth:`run`.

        Raises:
            ValueError: If ``models_dirs`` is empty.
        """
        if not models_dirs:
            raise ValueError("PeerFileServer requires at least one models directory")
        self.host = host
        self.port = port
        # Preserve caller order so callers can prefer writable dirs over
        # read-only dirs without us re-sorting them.
        self.models_dirs: tuple[Path, ...] = tuple(models_dirs)
        self._app = web.Application()
        self._app.router.add_get("/status/{model_id}", self._handle_status)
        self._app.router.add_get("/files/{model_id}/{file_path:.+}", self._handle_file)
        self._app.router.add_get("/health", self._handle_health)
        self._runner: web.AppRunner | None = None

    async def run(self) -> None:
        """Start the server and keep this task alive until it is cancelled.

        The coroutine blocks in ``anyio.sleep_forever`` so the owning task
        group has a live task to cancel.  Cleanup runs in a shielded
        ``finally`` so the listening socket is released before the task ends.
        Startup is inside the same ``try`` so a failure in ``setup()`` or
        ``site.start()`` (for example the port already being in use) still
        releases the runner.
        """
        runner = web.AppRunner(self._app)
        self._runner = runner
        try:
            await runner.setup()
            site = web.TCPSite(runner, self.host, self.port)
            await site.start()
            logger.info(f"PeerFileServer listening on {self.host}:{self.port}")
            await anyio.sleep_forever()
        finally:
            # Shield cleanup from the cancellation that woke us; otherwise
            # the cleanup itself is cancelled and the socket stays bound.
            with anyio.CancelScope(shield=True):
                # Re-read self._runner so a concurrent ``shutdown()`` call
                # does not lead to the runner being cleaned up twice.
                live_runner = cast(web.AppRunner | None, self._runner)
                if live_runner is not None:
                    self._runner = None
                    await live_runner.cleanup()
            logger.info(f"PeerFileServer on {self.host}:{self.port} stopped")

    async def shutdown(self) -> None:
        """Release the listener if it is running; safe to call repeatedly."""
        # Detach the runner *before* awaiting so that ``run()``'s ``finally``
        # (or a second ``shutdown()``) cannot start a second cleanup while
        # this one is still in flight.
        runner, self._runner = self._runner, None
        if runner is not None:
            await runner.cleanup()

    async def _handle_health(self, request: web.Request) -> web.Response:
        """Return a constant ``{"status": "ok"}`` body."""
        return web.json_response({"status": "ok"})

    async def _handle_status(self, request: web.Request) -> web.Response:
        """Return the status of all files for a model (complete + in-progress).

        The response is the union over every configured root that contains
        the model.  When a file appears in more than one root, the most
        complete entry wins: a complete file beats any partial, and between
        two partials the one with more ``safe_bytes`` wins.

        ``.partial`` files without readable metadata are not reported, and
        the ``.partial.meta`` / ``.partial.meta.tmp`` bookkeeping files are
        never reported as model files.
        """
        model_id = request.match_info["model_id"]
        model_dirs = await self._locate_all_model_dirs(model_id)
        if not model_dirs:
            return web.json_response({"files": []})

        # path -> best entry seen so far
        merged: dict[str, dict[str, object]] = {}

        def merge(entry: dict[str, object]) -> None:
            path = cast(str, entry["path"])
            existing = merged.get(path)
            if existing is None:
                merged[path] = entry
                return
            existing_complete = bool(existing["complete"])
            new_complete = bool(entry["complete"])
            new_partial_is_more_complete = (
                not new_complete
                and not existing_complete
                and cast(int, entry["safe_bytes"]) > cast(int, existing["safe_bytes"])
            )
            if (new_complete and not existing_complete) or (
                new_partial_is_more_complete
            ):
                merged[path] = entry
            # complete-vs-complete: keep the first entry.

        # The ``.tmp`` sibling exists briefly while metadata is rewritten;
        # without this filter it would be listed as a complete model file.
        bookkeeping_suffixes = (".partial.meta", ".partial.meta.tmp")

        for model_dir in model_dirs:
            for item in model_dir.rglob("*"):
                relative_path = item.relative_to(model_dir).as_posix()
                if item.is_dir() or relative_path.endswith(bookkeeping_suffixes):
                    continue
                # Skip entries that resolve outside the model directory.
                if _resolve_child(model_dir, relative_path) is None:
                    continue

                if relative_path.endswith(".partial"):
                    meta = await _read_partial_meta(item)
                    if meta:
                        merge(
                            {
                                "path": relative_path.removesuffix(".partial"),
                                "size": _meta_int(meta, "total"),
                                "complete": False,
                                "safe_bytes": _meta_int(meta, "safe_bytes"),
                            }
                        )
                else:
                    try:
                        stat = await aios.stat(item)
                    except OSError:
                        # The file was renamed or removed between the walk
                        # and the stat; just leave it out of the report.
                        continue
                    merge(
                        {
                            "path": relative_path,
                            "size": stat.st_size,
                            "complete": True,
                            "safe_bytes": stat.st_size,
                        }
                    )

        return web.json_response({"files": list(merged.values())})

    @staticmethod
    def _parse_range(range_header: str, limit: int) -> tuple[int, int] | None:
        """Parse the first ``bytes=start-[end]`` range of a Range header.

        Returns ``(start, end_exclusive)`` with the end clamped to ``limit``,
        or ``None`` when the header is malformed (including suffix ranges
        such as ``bytes=-500``, which are not supported).
        """
        spec = range_header.strip()
        if spec.lower().startswith("bytes="):
            spec = spec[len("bytes=") :]
        first = spec.split(",", 1)[0]
        start_text, _, end_text = first.partition("-")
        try:
            start = int(start_text)
            if start < 0:
                return None
            if not end_text.strip():
                return start, limit
            last = int(end_text)
        except ValueError:
            return None
        if last < start:
            return None
        # HTTP range ends are inclusive; convert to an exclusive bound.
        return start, min(last + 1, limit)

    async def _handle_file(self, request: web.Request) -> web.StreamResponse:
        """Serve a model file with Range request support.

        A complete copy is preferred: the first root (in ``models_dirs``
        order) holding the file complete wins.  Otherwise the ``.partial``
        copy with the most ``safe_bytes`` is served, limited to its safe
        (flushed) byte range.  Responds 404 when no root has the file and
        416 for a malformed range or one starting at/after the safe boundary.
        """
        model_id = request.match_info["model_id"]
        file_path = request.match_info["file_path"]

        model_dirs = await self._locate_all_model_dirs(model_id)
        if not model_dirs:
            return web.Response(status=404, text="Model not found")

        complete_hit: Path | None = None
        best_partial: tuple[Path, PartialMeta] | None = None

        for model_dir in model_dirs:
            complete_candidate = _resolve_child(model_dir, file_path)
            partial_candidate = _resolve_child(model_dir, f"{file_path}.partial")
            if complete_candidate is None or partial_candidate is None:
                continue
            # ``isfile`` rather than ``exists``: a directory of that name is
            # not a servable file and would fail when opened below.
            if await aios.path.isfile(complete_candidate):
                complete_hit = complete_candidate
                break
            if await aios.path.exists(partial_candidate):
                meta = await _read_partial_meta(partial_candidate)
                if (
                    meta
                    and _meta_int(meta, "safe_bytes") > 0
                    and (
                        best_partial is None
                        or _meta_int(meta, "safe_bytes")
                        > _meta_int(best_partial[1], "safe_bytes")
                    )
                ):
                    best_partial = (partial_candidate, meta)

        if complete_hit is not None:
            serve_path = complete_hit
            file_size = (await aios.stat(complete_hit)).st_size
            safe_bytes = file_size
            is_complete = True
        elif best_partial is not None:
            serve_path, meta = best_partial
            file_size = _meta_int(meta, "total")
            safe_bytes = _meta_int(meta, "safe_bytes")
            is_complete = False
        else:
            return web.Response(status=404, text="File not found")

        # Default: everything up to the safe boundary.
        start, end = 0, safe_bytes
        ranged = False
        range_header = request.headers.get("Range")
        if range_header:
            parsed = self._parse_range(range_header, safe_bytes)
            if parsed is None:
                return web.Response(status=416, text="Invalid range")
            start, end = parsed
            ranged = True

        # An empty but complete file is a valid (empty) response, not an
        # unsatisfiable range.
        empty_complete = is_complete and file_size == 0 and start == 0
        if start >= safe_bytes and not empty_complete:
            return web.Response(status=416, text="Range not satisfiable")

        content_length = max(end - start, 0)
        partial_content = ranged and not empty_complete

        headers = {
            "Content-Type": "application/octet-stream",
            "Content-Length": str(content_length),
            "Accept-Ranges": "bytes",
            "X-Exo-Safe-Bytes": str(safe_bytes),
            "X-Exo-Total-Size": str(file_size),
            "X-Exo-Complete": "true" if is_complete else "false",
        }
        if partial_content:
            # Content-Range only belongs on a 206 response.
            headers["Content-Range"] = f"bytes {start}-{end - 1}/{file_size}"

        response = web.StreamResponse(
            status=206 if partial_content else 200,
            headers=headers,
        )
        await response.prepare(request)

        chunk_size = 8 * 1024 * 1024  # 8MB chunks matching HF download
        async with aiofiles.open(serve_path, "rb") as f:
            await f.seek(start)
            remaining = content_length
            while remaining > 0:
                chunk = await f.read(min(chunk_size, remaining))
                if not chunk:
                    break
                await response.write(chunk)
                remaining -= len(chunk)

        await response.write_eof()
        return response

    async def _locate_model_dir(self, model_id: str) -> Path | None:
        """Return the first configured directory that contains ``model_id``.

        Roots are probed in ``models_dirs`` order and each candidate is
        path-traversal-checked before the filesystem is touched.  Callers
        that need to merge contents across roots should use
        :meth:`_locate_all_model_dirs` instead.
        """
        for root in self.models_dirs:
            candidate = _resolve_child(root, model_id)
            if candidate is None:
                continue
            if await aios.path.exists(candidate):
                return candidate
        return None

    async def _locate_all_model_dirs(self, model_id: str) -> list[Path]:
        """Return every configured directory that contains ``model_id``.

        Matches are returned in ``models_dirs`` order.  Each candidate is
        path-traversal-checked before the filesystem is touched.
        """
        matches: list[Path] = []
        for root in self.models_dirs:
            candidate = _resolve_child(root, model_id)
            if candidate is None:
                continue
            if await aios.path.exists(candidate):
                matches.append(candidate)
        return matches
async def _run_progress_callback(
    callback: Callable[[ShardMetadata, RepoDownloadProgress], Awaitable[None]],
    shard: ShardMetadata,
    progress: RepoDownloadProgress,
) -> None:
    """Invoke ``callback(shard, progress)`` and await the result.

    Wrapping the call in a coroutine function gives callers a coroutine
    object to hand to task-scheduling APIs even though ``callback`` is only
    typed as returning an awaitable.
    """
    pending = callback(shard, progress)
    await pending
Pre-fix the + # peer path hard-coded ``skip_internet=False`` and would raise + # on cold/offline nodes that lacked a cached file list, ending + # the peer attempt before it could even start. Codex flagged + # this as a P1 (PR #16 round 2). + self._offline = offline + self._progress_callbacks: list[ + Callable[[ShardMetadata, RepoDownloadProgress], Awaitable[None]] + ] = [] + # Peers are set per-download by the coordinator before calling ensure_shard. + self._peers_by_shard: defaultdict[ShardPeerKey, deque[list[PeerEndpoint]]] = ( + defaultdict(deque) + ) + + def set_available_peers( + self, shard: ShardMetadata, peers: list[PeerEndpoint] + ) -> None: + """Set the peers to try for a specific ensure_shard call. + + Called by DownloadCoordinator before triggering a download, based + on the peers embedded in the StartDownload command. + """ + self._peers_by_shard[_peer_key(shard)].append(list(peers)) + + def on_progress( + self, + callback: Callable[[ShardMetadata, RepoDownloadProgress], Awaitable[None]], + ) -> None: + self._inner.on_progress(callback) + self._progress_callbacks.append(callback) + + async def ensure_shard( + self, shard: ShardMetadata, config_only: bool = False + ) -> Path: + if config_only: + return await self._inner.ensure_shard(shard, config_only=True) + + model_id = shard.model_card.model_id + normalized = model_id.normalize() + peers = self._pop_available_peers(shard) + + if not peers: + logger.debug( + f"No peers available for {model_id}, downloading from HuggingFace" + ) + return await self._inner.ensure_shard(shard, config_only=False) + + # Try each peer (already sorted by priority: RDMA first, completed first) + for peer in peers: + logger.info( + f"Attempting peer download of {model_id} from " + f"{peer.ip}:{peer.port} (status: {peer.status}, link: {peer.connection_type})" + ) + result = await self._try_peer_download( + shard, peer.ip, peer.port, normalized + ) + if result is not None: + logger.info(f"Successfully downloaded {model_id} from 
async def _try_peer_download(
    self,
    shard: ShardMetadata,
    peer_ip: str,
    peer_port: int,
    model_id_normalized: str,
) -> Path | None:
    """Attempt to download every file needed for ``shard`` from one peer.

    Files are selected the same way ``download_shard`` selects them (allow
    patterns, the ``original/*`` / ``metal/*`` ignore set, and the
    image-model filter that drops top-level ``.safetensors`` files). The
    transfer is all-or-nothing: if a required file has an unknown size, is
    not available on the peer, fails to download, or fails verification,
    ``None`` is returned so the caller can try the next peer or fall back
    to HuggingFace.

    Args:
        shard: Shard whose model files should be fetched.
        peer_ip: Address of the peer's file server.
        peer_port: Port of the peer's file server.
        model_id_normalized: Normalized model ID used in peer requests.

    Returns:
        The path of the ``.gguf`` file when the filtered file list contains
        one, otherwise the model directory; ``None`` on any failure.
    """
    # Ask the peer what it has; no status (or an empty one) means no transfer.
    peer_files = await get_peer_file_status(peer_ip, peer_port, model_id_normalized)
    if not peer_files:
        return None
    peer_file_map = {f.path: f for f in peer_files}

    revision = "main"
    model_id = shard.model_card.model_id
    target_dir = await resolve_model_dir(model_id)

    try:
        file_list = await fetch_file_list_with_cache(
            model_id,
            revision,
            recursive=True,
            # Honor the offline setting so a cold offline node can still
            # satisfy a peer download from the LAN without contacting
            # HuggingFace for the file list.
            skip_internet=self._offline,
        )
    except Exception as exc:
        logger.debug(
            f"Could not fetch file list for {model_id_normalized}: {exc}; "
            f"skipping peer {peer_ip}"
        )
        return None

    allow_patterns = await resolve_allow_patterns(shard)
    # Keep this ignore list in sync with ``download_shard``: the peer will
    # not have files that the regular download path never fetches, and the
    # strict "peer must have every file" check below would otherwise abort.
    ignore_patterns = ["original/*", "metal/*"]
    filtered_file_list: list[FileListEntry] = list(
        filter_repo_objects(
            file_list,
            allow_patterns=allow_patterns,
            ignore_patterns=ignore_patterns,
            key=lambda x: x.path,
        )
    )
    if is_image_model(shard):
        filtered_file_list = [
            f
            for f in filtered_file_list
            if "/" in f.path or not f.path.endswith(".safetensors")
        ]

    if not any(f.path in peer_file_map for f in filtered_file_list):
        logger.debug(f"Peer has no files we need for {model_id_normalized}")
        return None

    # Single validation pass, done before any coroutine is created so that
    # an early bail-out never leaves un-awaited coroutines behind.
    #   * size is None -> size unknown (distinct from an empty file); the
    #     file cannot be safely transferred, so abort.
    #   * size == 0    -> empty marker file; created locally only after
    #     every real transfer has succeeded.
    #   * otherwise    -> the peer must hold at least some safe bytes.
    zero_byte_files: list[str] = []
    to_download: list[tuple[str, int]] = []
    for f in filtered_file_list:
        if f.size is None:
            logger.info(
                f"Peer transfer for {model_id_normalized} aborted: "
                f"unknown-size entry {f.path!r} (size=None) cannot "
                f"be safely transferred over peer; falling back to HF"
            )
            return None
        if f.size == 0:
            zero_byte_files.append(f.path)
            continue
        peer_info = peer_file_map.get(f.path)
        if not peer_info or peer_info.safe_bytes <= 0:
            # Real-size file the peer doesn't have => abort transfer.
            return None
        to_download.append((f.path, f.size))

    all_start_time = time.time()
    # Seed a "not_started" entry for every file so repo-level progress
    # accounts for the whole snapshot from the first callback on.
    file_progress: dict[str, RepoFileDownloadProgress] = {
        f.path: RepoFileDownloadProgress(
            repo_id=str(model_id),
            repo_revision=revision,
            file_path=f.path,
            downloaded=Memory.from_bytes(0),
            downloaded_this_session=Memory.from_bytes(0),
            total=Memory.from_bytes(f.size or 0),
            speed=0,
            eta=timedelta(0),
            status="not_started",
            start_time=all_start_time,
        )
        for f in filtered_file_list
    }
    semaphore = asyncio.Semaphore(8)
    # Strong references to in-flight progress-callback tasks: the event loop
    # only keeps weak references, so an unreferenced task may be collected
    # before it finishes.
    callback_tasks: set[asyncio.Task[Any]] = set()

    async def download_one(file_path: str, expected_size: int) -> bool:
        """Fetch one file from the peer and verify it; True on success."""

        def on_file_progress(
            curr_bytes: int, total_bytes: int, is_renamed: bool
        ) -> None:
            elapsed = max(time.time() - all_start_time, 0.1)
            speed = curr_bytes / elapsed
            file_progress[file_path] = RepoFileDownloadProgress(
                repo_id=str(model_id),
                repo_revision=revision,
                file_path=file_path,
                downloaded=Memory.from_bytes(curr_bytes),
                downloaded_this_session=Memory.from_bytes(curr_bytes),
                total=Memory.from_bytes(total_bytes),
                speed=speed,
                eta=timedelta(
                    seconds=(total_bytes - curr_bytes) / max(speed, 0.1)
                ),
                status="complete" if is_renamed else "in_progress",
                start_time=all_start_time,
            )
            progress = calculate_repo_progress(
                shard, model_id, revision, file_progress, all_start_time
            )
            for cb in self._progress_callbacks:
                task = asyncio.create_task(
                    _run_progress_callback(cb, shard, progress)
                )
                callback_tasks.add(task)
                task.add_done_callback(callback_tasks.discard)

        async with semaphore:
            result = await download_file_from_peer(
                peer_ip,
                peer_port,
                model_id_normalized,
                file_path,
                target_dir,
                expected_size,
                on_progress=on_file_progress,
            )
            if result is None:
                return False

            # Offline deployments have opted out of contacting HuggingFace:
            # trust the LAN peer's bytes (size is enforced by
            # ``download_file_from_peer``) and skip the hash check.
            if self._offline:
                return True

            # Online: validate against HuggingFace's hash rather than one
            # advertised by the peer, so a peer serving wrong bytes of the
            # right length is not accepted. The file is removed on any
            # failure so the HF fallback starts from a clean slate.
            try:
                _expected_size, expected_etag = await file_meta(
                    model_id, revision, file_path
                )
            except Exception as exc:
                logger.warning(
                    f"Peer download integrity-check failed: could not "
                    f"fetch HF metadata for {file_path}: {exc}; "
                    f"discarding peer-downloaded copy"
                )
                try:
                    await aios.remove(result)
                except Exception as cleanup_exc:
                    logger.debug(
                        f"Could not remove unverified peer download "
                        f"{result}: {cleanup_exc}"
                    )
                return False

            hash_type: Literal["sha1", "sha256"] = (
                "sha256" if len(expected_etag) == 64 else "sha1"
            )
            final_hash = await calc_hash(result, hash_type=hash_type)
            if final_hash != expected_etag:
                logger.warning(
                    f"Peer-downloaded {file_path} from {peer_ip} has "
                    f"hash {final_hash} but HF authoritative hash is "
                    f"{expected_etag} ({hash_type}); discarding and "
                    f"falling back to HF"
                )
                try:
                    await aios.remove(result)
                except Exception as exc:
                    logger.error(
                        f"Failed to remove corrupt peer download {result}: {exc}"
                    )
                return False
            return True

    results = await asyncio.gather(
        *(download_one(path, size) for path, size in to_download),
        return_exceptions=True,
    )
    for outcome in results:
        if isinstance(outcome, BaseException):
            logger.warning(f"Peer download from {peer_ip} raised {outcome!r}")
    # Only an explicit ``True`` counts as success. With
    # ``return_exceptions=True`` the results may contain ``BaseException``
    # instances (e.g. ``CancelledError``) that an ``Exception`` check misses.
    if not all(outcome is True for outcome in results):
        return None

    for marker_path in zero_byte_files:
        full_path = target_dir / marker_path
        try:
            await aios.makedirs(full_path.parent, exist_ok=True)
            # Skip the touch when the marker already exists; append mode
            # never truncates an existing file.
            if not await aios.path.exists(full_path):
                async with aiofiles.open(full_path, mode="a"):
                    pass
        except Exception as exc:
            logger.warning(
                f"Could not materialize zero-byte marker file "
                f"{full_path} after peer transfer: {exc}; "
                f"falling back to HF for full snapshot integrity"
            )
            return None
        # ``on_file_progress`` never runs for zero-byte files, so mark the
        # seeded entry complete here; otherwise the final repo progress
        # would not reach an overall "complete" status.
        file_progress[marker_path] = RepoFileDownloadProgress(
            repo_id=str(model_id),
            repo_revision=revision,
            file_path=marker_path,
            downloaded=Memory.from_bytes(0),
            downloaded_this_session=Memory.from_bytes(0),
            total=Memory.from_bytes(0),
            speed=0,
            eta=timedelta(0),
            status="complete",
            start_time=all_start_time,
        )

    # Let queued in-progress callbacks finish first so none of them can be
    # delivered after the final update below.
    if callback_tasks:
        await asyncio.gather(*callback_tasks, return_exceptions=True)

    final_progress = calculate_repo_progress(
        shard, model_id, revision, file_progress, all_start_time
    )
    for cb in self._progress_callbacks:
        await cb(shard, final_progress)

    gguf = next((f for f in filtered_file_list if f.path.endswith(".gguf")), None)
    return (target_dir / gguf.path) if gguf else target_dir
def discover_peers_for_model(
    node_id: NodeId,
    state: State,
    model_id_normalized: str,
    peer_download_port: int,
) -> list[PeerEndpoint]:
    """Find peers that have a specific model (complete or in-progress).

    Called by the Worker when emitting a StartDownload command. Each peer
    appears at most once, even when it has several download entries for the
    model (e.g. one per shard); its status is "complete" if any matching
    entry is a completed download, otherwise "ongoing".

    Args:
        node_id: This node's ID (excluded from results).
        state: The global State object (owned by Worker).
        model_id_normalized: Normalized model ID (e.g. "org--model").
        peer_download_port: Port where peers run their PeerFileServer.

    Returns:
        List of PeerEndpoint sorted with RDMA connections first, then
        completed downloads before ongoing ones.
    """
    # Best status per peer, in first-seen order. One entry per peer keeps a
    # failing peer from being retried once per matching download entry.
    best_status: dict[NodeId, str] = {}
    for peer_node_id, download_list in state.downloads.items():
        if peer_node_id == node_id:
            continue
        for dl in download_list:
            if dl.shard_metadata.model_card.model_id.normalize() != model_id_normalized:
                continue
            if isinstance(dl, DownloadCompleted):
                best_status[peer_node_id] = "complete"
                break  # nothing can outrank a completed download
            if isinstance(dl, DownloadOngoing):
                best_status.setdefault(peer_node_id, "ongoing")

    peers: list[PeerEndpoint] = []
    for peer_node_id, status in best_status.items():
        endpoint = _resolve_peer_endpoint(
            node_id, peer_node_id, state, peer_download_port, status
        )
        if endpoint:
            peers.append(endpoint)

    # Stable sort: RDMA before socket, then complete before ongoing.
    peers.sort(
        key=lambda p: (
            0 if p.connection_type == "rdma" else 1,
            0 if p.status == "complete" else 1,
        )
    )
    return peers


def _resolve_peer_endpoint(
    node_id: NodeId,
    peer_node_id: NodeId,
    state: State,
    peer_download_port: int,
    status: str,
) -> PeerEndpoint | None:
    """Resolve a peer's IP address and connection type from the topology.

    All edges from ``node_id`` to ``peer_node_id`` are scanned, because edge
    order does not guarantee RDMA edges come first. The IP always comes from
    a socket edge; the endpoint is labelled "rdma" when any RDMA edge to the
    peer exists, otherwise "socket". Returns ``None`` when there is no socket
    edge or when resolution raises.
    """
    try:
        edges = [
            conn
            for conn in state.topology.out_edges(node_id)
            if conn.sink == peer_node_id
        ]
        socket_ip = next(
            (
                conn.edge.sink_multiaddr.ip_address
                for conn in edges
                if isinstance(conn.edge, SocketConnection)
            ),
            None,
        )
        if not socket_ip:
            # No address to connect to, even if an RDMA edge exists.
            return None
        has_rdma = any(isinstance(conn.edge, RDMAConnection) for conn in edges)
        return PeerEndpoint(
            node_id=peer_node_id,
            ip=socket_ip,
            port=peer_download_port,
            status=status,
            connection_type="rdma" if has_rdma else "socket",
        )
    except Exception as e:
        logger.debug(f"Could not resolve endpoint for peer {peer_node_id}: {e}")
    return None
@pytest.fixture
async def peer_server(temp_models_dir: Path) -> AsyncIterator[PeerFileServer]:
    """Yield a PeerFileServer listening on an OS-assigned local port."""
    from aiohttp import web

    server = PeerFileServer(host="127.0.0.1", port=0, models_dirs=[temp_models_dir])

    # Build the runner by hand and bind to port 0 so the OS picks a free port.
    runner = web.AppRunner(server._app)
    server._runner = runner
    await runner.setup()
    site = web.TCPSite(runner, "127.0.0.1", 0)
    await site.start()

    # Record the port that was actually bound so tests can connect to it.
    bound_address = site._server.sockets[0].getsockname()  # type: ignore[union-attr]
    server.port = bound_address[1]

    yield server
    await server.shutdown()
async def test_health_check(self, peer_server: PeerFileServer) -> None:
    """The health endpoint answers 200 with a status of "ok"."""
    import aiohttp

    health_url = f"http://127.0.0.1:{peer_server.port}/health"
    async with aiohttp.ClientSession() as session:
        async with session.get(health_url) as response:
            assert response.status == 200
            payload = cast(dict[str, object], await response.json())
            assert payload["status"] == "ok"
async def test_status_includes_nested_files(
    self, peer_server: PeerFileServer, temp_models_dir: Path
) -> None:
    """Nested complete and partial files both show up in the status listing."""
    snapshot_dir = temp_models_dir / "test--model" / "snapshots" / "abc123"
    await aios.makedirs(snapshot_dir, exist_ok=True)

    # One finished file ...
    async with aiofiles.open(snapshot_dir / "config.json", "wb") as complete_file:
        await complete_file.write(b"{}")
    # ... and one partial with its companion metadata.
    async with aiofiles.open(
        snapshot_dir / "model.safetensors.partial", "wb"
    ) as partial_file:
        await partial_file.write(b"x" * 512)
    partial_meta = json.dumps({"safe_bytes": 512, "total": 2048})
    async with aiofiles.open(
        snapshot_dir / "model.safetensors.partial.meta", "w"
    ) as meta_file:
        await meta_file.write(partial_meta)

    files = await get_peer_file_status("127.0.0.1", peer_server.port, "test--model")
    assert files is not None

    by_path = {entry.path: entry for entry in files}
    config_entry = by_path["snapshots/abc123/config.json"]
    weights_entry = by_path["snapshots/abc123/model.safetensors"]
    assert config_entry.complete is True
    assert weights_entry.complete is False
    assert weights_entry.safe_bytes == 512
async def test_rejects_path_traversal(
    self, peer_server: PeerFileServer, temp_models_dir: Path
) -> None:
    """A percent-encoded ``..`` must not reach files outside the model directory."""
    import aiohttp

    await aios.makedirs(temp_models_dir / "test--model", exist_ok=True)

    # File that lives next to, not inside, the model directory.
    async with aiofiles.open(temp_models_dir / "outside.txt", "wb") as outside:
        await outside.write(b"outside")

    traversal_url = (
        f"http://127.0.0.1:{peer_server.port}/files/test--model/"
        "%2E%2E/outside.txt"
    )
    async with aiohttp.ClientSession() as session:
        async with session.get(traversal_url) as response:
            assert response.status == 404
            assert await response.text() != "outside"
async def test_file_not_found(self, peer_server: PeerFileServer) -> None:
    """Requesting a file the server does not have yields a 404."""
    import aiohttp

    missing_url = f"http://127.0.0.1:{peer_server.port}/files/test--model/missing.bin"
    async with aiohttp.ClientSession() as session:
        async with session.get(missing_url) as response:
            assert response.status == 404
async def test_serves_model_from_secondary_writable_dir(
    self, tmp_path: Path
) -> None:
    """A model stored only in the second configured root is still listed."""
    from aiohttp import web

    primary = tmp_path / "primary"
    secondary = tmp_path / "secondary"
    for root in (primary, secondary):
        await aios.makedirs(root, exist_ok=True)

    # The model exists only under the secondary root.
    model_dir = secondary / "test--model"
    await aios.makedirs(model_dir, exist_ok=True)
    async with aiofiles.open(model_dir / "config.json", "wb") as config_file:
        await config_file.write(b'{"hello":"world"}')

    server = PeerFileServer(
        host="127.0.0.1", port=0, models_dirs=[primary, secondary]
    )
    runner = web.AppRunner(server._app)
    server._runner = runner
    await runner.setup()
    site = web.TCPSite(runner, "127.0.0.1", 0)
    await site.start()
    port_int: int = cast(int, site._server.sockets[0].getsockname()[1])  # type: ignore[union-attr]
    server.port = port_int

    try:
        files = await get_peer_file_status("127.0.0.1", port_int, "test--model")
        assert files is not None
        listed_paths = {entry.path for entry in files}
        assert listed_paths == {"config.json"}
    finally:
        await server.shutdown()
async def test_status_unions_partial_in_first_root_with_complete_in_second(
    self, tmp_path: Path
) -> None:
    """``/status`` must surface a complete copy from a later root.

    If an earlier root has a stale/incomplete model directory and a later
    root has a complete copy, the status listing must report the complete
    file -- otherwise peers see the file as missing and fall back to
    HuggingFace despite the local node having a canonical copy on a
    different mount.
    """
    from aiohttp import web

    first = tmp_path / "first"
    second = tmp_path / "second"
    await aios.makedirs(first / "test--model", exist_ok=True)
    await aios.makedirs(second / "test--model", exist_ok=True)

    # First root has only a partial of weights.bin (incomplete).
    partial_path = first / "test--model" / "weights.bin.partial"
    canonical = b"the canonical model weights"
    async with aiofiles.open(partial_path, "wb") as f:
        await f.write(canonical[: len(canonical) // 2])
    # Companion meta marking 50% safe.
    meta_path = first / "test--model" / "weights.bin.partial.meta"
    async with aiofiles.open(meta_path, "w") as f:
        await f.write(
            json.dumps(
                {
                    "total": len(canonical),
                    "safe_bytes": len(canonical) // 2,
                }
            )
        )

    # Second root has the full canonical file (complete).
    async with aiofiles.open(second / "test--model" / "weights.bin", "wb") as f:
        await f.write(canonical)

    server = PeerFileServer(host="127.0.0.1", port=0, models_dirs=[first, second])
    server._runner = web.AppRunner(server._app)
    await server._runner.setup()
    site = web.TCPSite(server._runner, "127.0.0.1", 0)
    await site.start()
    port_int: int = cast(int, site._server.sockets[0].getsockname()[1])  # type: ignore[union-attr]
    server.port = port_int
    try:
        files = await get_peer_file_status("127.0.0.1", port_int, "test--model")
        assert files is not None
        file_map = {f.path: f for f in files}
        # The second fragment is an f-string so the failure message shows
        # the paths that were actually returned.
        assert "weights.bin" in file_map, (
            "complete copy in the second root must surface in /status; "
            f"got files={file_map.keys()}"
        )
        assert file_map["weights.bin"].complete is True, (
            "complete copy in the second root must dominate the "
            "partial in the first root; otherwise peers will fall "
            "back to HuggingFace"
        )
        assert file_map["weights.bin"].size == len(canonical)
    finally:
        await server.shutdown()
+ partial_path = first / "test--model" / "weights.bin.partial" + async with aiofiles.open(partial_path, "wb") as f: + await f.write(canonical[: len(canonical) // 2]) + meta_path = first / "test--model" / "weights.bin.partial.meta" + async with aiofiles.open(meta_path, "w") as f: + await f.write( + json.dumps( + { + "total": len(canonical), + "safe_bytes": len(canonical) // 2, + } + ) + ) + # Second root has the complete file. + async with aiofiles.open(second / "test--model" / "weights.bin", "wb") as f: + await f.write(canonical) + + server = PeerFileServer(host="127.0.0.1", port=0, models_dirs=[first, second]) + server._runner = web.AppRunner(server._app) + await server._runner.setup() + site = web.TCPSite(server._runner, "127.0.0.1", 0) + await site.start() + port_int: int = cast(int, site._server.sockets[0].getsockname()[1]) # type: ignore[union-attr] + server.port = port_int + try: + url = f"http://127.0.0.1:{port_int}/files/test--model/weights.bin" + async with ( + aiohttp.ClientSession() as session, + session.get(url) as r, + ): + assert r.status == 200, ( + f"expected 200 from /files when complete copy exists in " + f"a later root; got {r.status}" + ) + body = await r.read() + assert body == canonical, ( + f"expected canonical bytes from later root; got " + f"{len(body)} bytes (expected {len(canonical)})" + ) + # Sanity: X-Exo-Complete header should mark this as a + # complete serving (not a partial-bytes fragment). 
+ assert r.headers.get("X-Exo-Complete") == "true" + finally: + await server.shutdown() + + +class TestPeerDownloadClient: + """Tests for downloading files from a peer server.""" + + async def test_download_complete_file( + self, peer_server: PeerFileServer, temp_models_dir: Path, tmp_path: Path + ) -> None: + """Should download a complete file from peer.""" + # Set up source file on the peer server + model_dir = temp_models_dir / "test--model" + await aios.makedirs(model_dir, exist_ok=True) + + content = b"model weights data " * 100 + async with aiofiles.open(model_dir / "weights.bin", "wb") as f: + await f.write(content) + + # Download to a different directory + download_dir = tmp_path / "downloads" / "test--model" + await aios.makedirs(download_dir, exist_ok=True) + + progress_calls: list[tuple[int, int, bool]] = [] + + result = await download_file_from_peer( + "127.0.0.1", + peer_server.port, + "test--model", + "weights.bin", + download_dir, + len(content), + on_progress=lambda c, t, r: progress_calls.append((c, t, r)), + ) + + assert result is not None + assert result == download_dir / "weights.bin" + async with aiofiles.open(result, "rb") as f: + downloaded = await f.read() + assert downloaded == content + # Should have progress calls including final + assert len(progress_calls) > 0 + assert progress_calls[-1][2] is True # is_renamed + + async def test_download_returns_none_on_missing( + self, peer_server: PeerFileServer, tmp_path: Path + ) -> None: + """Should return None when file doesn't exist on peer.""" + download_dir = tmp_path / "downloads" / "test--model" + await aios.makedirs(download_dir, exist_ok=True) + + result = await download_file_from_peer( + "127.0.0.1", + peer_server.port, + "test--model", + "nonexistent.bin", + download_dir, + 1000, + ) + assert result is None + + async def test_download_returns_none_on_unreachable_peer( + self, tmp_path: Path + ) -> None: + """Should return None when peer is unreachable.""" + download_dir = tmp_path / 
"downloads" / "test--model" + await aios.makedirs(download_dir, exist_ok=True) + + result = await download_file_from_peer( + "127.0.0.1", + 19999, # Nobody listening + "test--model", + "weights.bin", + download_dir, + 1000, + ) + assert result is None + + async def test_oversized_stale_partial_is_discarded_and_retransferred( + self, peer_server: PeerFileServer, temp_models_dir: Path, tmp_path: Path + ) -> None: + """Codex P1 (PR #16 round 5): a stale ``.partial`` larger than + ``expected_size`` left over from a previous run must be + rejected, NOT silently renamed as the successful download. + + Pre-fix the resume loop ran ``while n_read < expected_size``, + so an oversized partial skipped the loop entirely and the + final ``rename`` accepted bad bytes. In offline mode (where + hash verification is intentionally skipped) this would + permanently poison the model cache without any warning. + Post-fix the oversized partial is discarded and the file is + re-fetched from the peer. + """ + model_dir = temp_models_dir / "test--model" + await aios.makedirs(model_dir, exist_ok=True) + canonical = b"the canonical model weights" + async with aiofiles.open(model_dir / "weights.bin", "wb") as f: + await f.write(canonical) + + download_dir = tmp_path / "downloads" / "test--model" + await aios.makedirs(download_dir, exist_ok=True) + # Stale partial from a "previous run" -- bigger than the + # canonical file and full of junk bytes. Pre-fix, this would + # be the file that ended up renamed as ``weights.bin``. 
+ stale_partial = download_dir / "weights.bin.partial" + stale_bytes = b"\xde\xad\xbe\xef" * (len(canonical) * 2) + async with aiofiles.open(stale_partial, "wb") as f: + await f.write(stale_bytes) + assert (await aios.stat(stale_partial)).st_size > len(canonical) + + result = await download_file_from_peer( + "127.0.0.1", + peer_server.port, + "test--model", + "weights.bin", + download_dir, + len(canonical), + ) + + assert result is not None + assert result == download_dir / "weights.bin" + async with aiofiles.open(result, "rb") as f: + downloaded = await f.read() + assert downloaded == canonical, ( + "stale oversized partial must NOT be accepted as the " + "downloaded file; the fix must redownload from the peer" + ) + assert not stale_partial.exists() + + async def test_resume_with_200_response_discards_partial_and_restarts( + self, tmp_path: Path + ) -> None: + """Codex P1 (PR #16 round-(N+3), peer_download.py:162): when + the client resumes a download (``n_read > 0``) it sends a + ``Range`` header, but a non-compliant server is permitted to + ignore it and return full content with HTTP 200 instead of + 206. Pre-fix the client appended the full body to the + partial, pushing ``n_read`` past ``expected_size`` and + renaming the oversized file as the "successful" download. + In offline mode hash verification is intentionally skipped, + so the bad bytes silently poisoned the model cache. + + We stand up a tiny aiohttp server that returns full content + with 200 even when ``Range`` is set, prime a partial file, + and assert the client discards the partial, restarts from + zero, and lands the canonical bytes (matching ``expected_size``). + """ + from aiohttp import web + + canonical = b"the canonical model weights" + + async def handler(request: web.Request) -> web.Response: + # Always return full content with HTTP 200, ignoring any + # ``Range`` header. This simulates the non-compliant + # peer server the codex finding flagged. 
+ del request + return web.Response(body=canonical, status=200) + + app = web.Application() + # Path must match the client's URL template: + # ``http://host:port/files//`` + _ = app.router.add_get("/files/test/weights.bin", handler) + runner = web.AppRunner(app) + await runner.setup() + site = web.TCPSite(runner, "127.0.0.1", 0) + await site.start() + try: + # Mirror the ``peer_server`` fixture: ``aiohttp.web.TCPSite`` + # surfaces the kernel-assigned port through its private + # ``_server.sockets`` attribute. The module-level + # ``reportPrivateUsage=false`` and ``type: ignore`` here + # match the existing fixture's access pattern. + port: int = cast( + int, + site._server.sockets[0].getsockname()[1], # type: ignore[union-attr] + ) + + download_dir = tmp_path / "downloads" / "test" + await aios.makedirs(download_dir, exist_ok=True) + # Prime a stale partial with bogus content to force the + # resume codepath (Range header) on the first attempt. + partial_path = download_dir / "weights.bin.partial" + stale_prefix = b"\xff" * (len(canonical) // 2) + async with aiofiles.open(partial_path, "wb") as f: + await f.write(stale_prefix) + assert (await aios.stat(partial_path)).st_size > 0 + + result = await download_file_from_peer( + "127.0.0.1", + port, + "test", + "weights.bin", + download_dir, + len(canonical), + ) + + assert result is not None, ( + "the client should ultimately succeed by discarding the " + "stale partial and restarting from zero on the second " + "request" + ) + assert result == download_dir / "weights.bin" + async with aiofiles.open(result, "rb") as f: + downloaded = await f.read() + assert downloaded == canonical, ( + "200-on-resume must trigger a partial restart; the final " + "file must be the canonical bytes, not a duplicate-prefix " + "concatenation" + ) + assert not partial_path.exists(), ( + "successful download must remove the partial path" + ) + finally: + await runner.cleanup() + + async def 
test_oversized_peer_response_is_rejected_and_restarted( + self, tmp_path: Path + ) -> None: + """Codex P1 (PR #16 round-(N+8), peer_download.py:187): the + download loop used to keep appending bytes until EOF and only + check ``n_read < expected_size`` afterwards. A non-compliant + peer that serves *more* bytes than the advertised + ``expected_size`` would push ``n_read`` past it, the file + would be renamed as a successful download, and -- because + offline mode skips hash verification -- silently poison the + model cache. + + We stand up a tiny aiohttp server that always returns + ``len(canonical) + 8`` bytes regardless of how much was + requested. Pre-fix this would land a corrupt file in the + cache. Post-fix the client must discard each oversized + response and never end up with a final file containing extra + bytes.""" + from aiohttp import web + + canonical = b"the canonical model weights" + # The payload the bad peer always serves: the canonical + # bytes plus extra trailing bytes the peer claimed wouldn't + # exist. This is the attack/bug the fix guards against. + oversized_payload = canonical + b"POISONED" + request_count = 0 + max_requests = 4 # keep test fast: client retries a few times + + async def handler(request: web.Request) -> web.Response: + nonlocal request_count + request_count += 1 + del request + if request_count > max_requests: + # Surface a definitive failure if the client keeps + # hammering the bad peer; that means the fix + # regressed and we'd otherwise hang. 
+ return web.Response(body=b"", status=500) + return web.Response(body=oversized_payload, status=200) + + app = web.Application() + _ = app.router.add_get("/files/test/weights.bin", handler) + runner = web.AppRunner(app) + await runner.setup() + site = web.TCPSite(runner, "127.0.0.1", 0) + await site.start() + try: + port: int = cast( + int, + site._server.sockets[0].getsockname()[1], # type: ignore[union-attr] + ) + + download_dir = tmp_path / "downloads" / "test" + await aios.makedirs(download_dir, exist_ok=True) + + result = await download_file_from_peer( + "127.0.0.1", + port, + "test", + "weights.bin", + download_dir, + len(canonical), + ) + + # The bad peer never serves a well-bounded response, so + # the client cannot complete. The contract is "no + # corrupt data lands in the cache". We tolerate either + # outcome: + # 1. ``result is None`` (client gave up after retries); or + # 2. ``result == canonical`` (a future improvement + # where we keep the canonical-prefix bytes after + # stripping the over-supply). + # The forbidden outcome is the final file containing + # the trailing "POISONED" bytes. + partial_path = download_dir / "weights.bin.partial" + target_path = download_dir / "weights.bin" + + if result is not None: + async with aiofiles.open(result, "rb") as f: + downloaded = await f.read() + assert downloaded == canonical, ( + "if the client claims success, the final file MUST " + "be exactly the canonical bytes; oversized peer " + "responses must never land trailing junk in the " + f"cache. got len={len(downloaded)} bytes: {downloaded!r}" + ) + # In the giving-up branch, neither file should remain + # poisoned. The partial is removed every time we detect + # over-supply, and we never rename to ``target_path`` + # without a clean-budgeted final write. 
+ if target_path.exists(): + async with aiofiles.open(target_path, "rb") as f: + final = await f.read() + assert final == canonical, ( + f"target path was renamed but contains " + f"{len(final)} bytes (expected {len(canonical)}); " + "oversized response made it into the cache" + ) + if partial_path.exists(): + size = (await aios.stat(partial_path)).st_size + assert size <= len(canonical), ( + f"partial path retains {size} bytes after " + f"oversized response (expected <= {len(canonical)}); " + "over-supply must be discarded, not preserved" + ) + finally: + await runner.cleanup() + + async def test_skip_already_complete( + self, peer_server: PeerFileServer, temp_models_dir: Path, tmp_path: Path + ) -> None: + """Should skip download if file already exists locally with correct size.""" + model_dir = temp_models_dir / "test--model" + await aios.makedirs(model_dir, exist_ok=True) + + content = b"existing content" + # File already exists in target + download_dir = tmp_path / "downloads" / "test--model" + await aios.makedirs(download_dir, exist_ok=True) + async with aiofiles.open(download_dir / "config.json", "wb") as f: + await f.write(content) + + result = await download_file_from_peer( + "127.0.0.1", + peer_server.port, + "test--model", + "config.json", + download_dir, + len(content), + ) + + assert result is not None + assert result == download_dir / "config.json" + + +class TestPeerAwareShardDownloader: + """Tests for peer selection handoff into peer-aware downloads.""" + + def test_peers_are_queued_per_shard(self) -> None: + """Concurrent downloads should not overwrite each other's peer list.""" + downloader = PeerAwareShardDownloader(NoopShardDownloader()) + shard_a = _make_shard(ModelId("test-org/model-a")) + shard_b = _make_shard(ModelId("test-org/model-b")) + peer_a = PeerEndpoint( + node_id=NodeId("aaaaaaaa-aaaa-4aaa-8aaa-aaaaaaaaaaaa"), + ip="10.0.0.1", + port=52415, + ) + peer_b = PeerEndpoint( + node_id=NodeId("bbbbbbbb-bbbb-4bbb-8bbb-bbbbbbbbbbbb"), + 
ip="10.0.0.2", + port=52415, + ) + + downloader.set_available_peers(shard_a, [peer_a]) + downloader.set_available_peers(shard_b, [peer_b]) + + assert downloader._pop_available_peers(shard_b) == [peer_b] + assert downloader._pop_available_peers(shard_a) == [peer_a] + assert downloader._pop_available_peers(shard_a) == [] + + def test_peers_for_same_shard_are_not_overwritten(self) -> None: + """Repeated commands for one shard should be consumed FIFO.""" + downloader = PeerAwareShardDownloader(NoopShardDownloader()) + shard = _make_shard(ModelId("test-org/model-a")) + peer_a = PeerEndpoint( + node_id=NodeId("aaaaaaaa-aaaa-4aaa-8aaa-aaaaaaaaaaaa"), + ip="10.0.0.1", + port=52415, + ) + peer_b = PeerEndpoint( + node_id=NodeId("bbbbbbbb-bbbb-4bbb-8bbb-bbbbbbbbbbbb"), + ip="10.0.0.2", + port=52415, + ) + + downloader.set_available_peers(shard, [peer_a]) + downloader.set_available_peers(shard, [peer_b]) + + assert downloader._pop_available_peers(shard) == [peer_a] + assert downloader._pop_available_peers(shard) == [peer_b] + assert downloader._pop_available_peers(shard) == [] + + +class TestPeerSelectionRespectsOfflineAndIgnorePatterns: + """Codex P1s on PR #16 round 2: peer selection must mirror + ``download_shard``'s logic exactly (``ignore_patterns`` for + ``original/*`` / ``metal/*``) and must propagate the coordinator's + offline mode into ``fetch_file_list_with_cache`` so a cold offline + node can still complete a peer download without reaching out to + HuggingFace for the initial file list. 
+ """ + + def test_offline_flag_defaults_to_false(self) -> None: + downloader = PeerAwareShardDownloader(NoopShardDownloader()) + assert downloader._offline is False + + def test_offline_flag_propagates(self) -> None: + downloader = PeerAwareShardDownloader(NoopShardDownloader(), offline=True) + assert downloader._offline is True + + async def test_try_peer_download_passes_offline_to_fetch_file_list( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + """``_try_peer_download`` must thread ``self._offline`` into + ``fetch_file_list_with_cache`` instead of always passing + ``skip_internet=False``. We capture the kwargs by patching + the import binding inside ``peer_shard_downloader``. + """ + from exo.download import peer_shard_downloader as psd + from exo.download.peer_download import PeerFileInfo + from exo.shared.types.worker.downloads import FileListEntry + + captured: dict[str, object] = {} + + async def fake_fetch(*args: object, **kwargs: object) -> list[FileListEntry]: + captured["args"] = args + captured["kwargs"] = kwargs + # Empty list -> no required files -> ``failed`` short- + # circuit -> we get out cleanly with the call kwargs + # captured. 
+ return [] + + async def fake_peer_status( + peer_host: str, + peer_port: int, + model_id_normalized: str, + timeout: float = 5.0, + ) -> list[PeerFileInfo] | None: + return [ + PeerFileInfo( + path="model-00001-of-00002.safetensors", + size=10, + complete=True, + safe_bytes=10, + ) + ] + + async def fake_resolve_dir(model_id: ModelId) -> Path: + return Path("/tmp/fake-model") + + async def fake_resolve_allow(shard: ShardMetadata) -> list[str]: + return ["*.safetensors"] + + monkeypatch.setattr(psd, "fetch_file_list_with_cache", fake_fetch) + monkeypatch.setattr(psd, "get_peer_file_status", fake_peer_status) + monkeypatch.setattr(psd, "resolve_model_dir", fake_resolve_dir) + monkeypatch.setattr(psd, "resolve_allow_patterns", fake_resolve_allow) + + downloader = PeerAwareShardDownloader(NoopShardDownloader(), offline=True) + shard = _make_shard(ModelId("test-org/model-a")) + + result = await downloader._try_peer_download( + shard, + peer_ip="10.0.0.1", + peer_port=52415, + model_id_normalized="test-org/model-a", + ) + # Empty file list short-circuits to ``failed`` path and returns + # None, but that's beside the point -- we just need the kwargs. + assert result is None + assert captured["kwargs"] == { + "recursive": True, + "skip_internet": True, + }, f"skip_internet must reflect downloader.offline (got {captured['kwargs']!r})" + + async def test_try_peer_download_filters_ignore_patterns( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + """Files under ``original/*`` and ``metal/*`` are excluded by + ``download_shard``; the peer path must skip them too. Pre-fix + the peer path filtered only ``allow_patterns``, leaving these + in the required-files list. The peer doesn't have them + locally (HF never downloads them), the strict + ``peer_info missing => fail`` check fired, and every download + fell back to HuggingFace. 
+ """ + from exo.download import peer_shard_downloader as psd + from exo.download.peer_download import PeerFileInfo + from exo.shared.types.worker.downloads import FileListEntry + + served = [ + FileListEntry( + type="file", + path="model-00001-of-00002.safetensors", + size=100, + ), + FileListEntry(type="file", path="config.json", size=10), + # These two should NOT show up on the peer's required-files + # list once the fix lands. Pre-fix they did, the peer didn't + # have them, and the whole transfer fell back to HF. + FileListEntry(type="file", path="original/consolidated.00.pth", size=999), + FileListEntry(type="file", path="metal/dist.bin", size=999), + ] + + async def fake_fetch(*_args: object, **_kwargs: object) -> list[FileListEntry]: + return served + + # The peer reports ONLY the canonical files, exactly the shape + # production peers are in (HF never downloaded ``original/*`` or + # ``metal/*`` for them either). + peer_paths = ("model-00001-of-00002.safetensors", "config.json") + + async def fake_peer_status( + peer_host: str, + peer_port: int, + model_id_normalized: str, + timeout: float = 5.0, + ) -> list[PeerFileInfo] | None: + return [ + PeerFileInfo(path=p, size=100, complete=True, safe_bytes=100) + for p in peer_paths + ] + + async def fake_resolve_dir(model_id: ModelId) -> Path: + return Path("/tmp/fake-model") + + async def fake_resolve_allow(shard: ShardMetadata) -> list[str]: + # Match the production allow set permissively; the legacy + # bug was that ``allow_patterns`` admitted ``original/*`` / + # ``metal/*`` whenever the repo allow-list was loose. 
+ return ["*"] + + async def fake_download( + peer_ip: str, + peer_port: int, + model_id_normalized: str, + file_path: str, + target_dir: Path, + expected_size: int, + on_progress: object = None, + ) -> Path | None: + return None + + captured_kwargs: list[object] = [] + real_filter = psd.filter_repo_objects + + def recording_filter( + items: Iterable[FileListEntry], + *, + allow_patterns: list[str] | str | None = None, + ignore_patterns: list[str] | str | None = None, + key: Callable[[FileListEntry], str] | None = None, + ) -> Generator[FileListEntry, None, None]: + captured_kwargs.append(ignore_patterns) + yield from real_filter( + items, + allow_patterns=allow_patterns, + ignore_patterns=ignore_patterns, + key=key, + ) + + monkeypatch.setattr(psd, "fetch_file_list_with_cache", fake_fetch) + monkeypatch.setattr(psd, "get_peer_file_status", fake_peer_status) + monkeypatch.setattr(psd, "resolve_model_dir", fake_resolve_dir) + monkeypatch.setattr(psd, "resolve_allow_patterns", fake_resolve_allow) + monkeypatch.setattr(psd, "download_file_from_peer", fake_download) + monkeypatch.setattr(psd, "filter_repo_objects", recording_filter) + + downloader = PeerAwareShardDownloader(NoopShardDownloader()) + shard = _make_shard(ModelId("test-org/model-a")) + + await downloader._try_peer_download( + shard, + peer_ip="10.0.0.1", + peer_port=52415, + model_id_normalized="test-org/model-a", + ) + + assert captured_kwargs == [["original/*", "metal/*"]], ( + "peer download must apply the same ``ignore_patterns`` set " + "as ``download_shard`` (download_utils.py:983) so peers " + "that don't have ``original/*`` / ``metal/*`` aren't " + "incorrectly judged incomplete; got " + f"{captured_kwargs!r}" + ) + + +class TestPeerDownloadIntegrityCheckRespectsOfflineMode: + """Codex P1 on PR #16 round 3: ``_try_peer_download`` was calling + ``file_meta(...)`` against HuggingFace for every file, even when the + coordinator was started with ``--offline`` / ``EXO_OFFLINE=true``. 
+ Any failure to reach HF (the entire point of offline mode) was + treated as an integrity-check failure, the peer-fetched bytes were + deleted, and the cold node was left with no path to complete model + sync. The fix: when the downloader is in offline mode, trust the + LAN peer's bytes and skip the HF metadata call entirely. + """ + + async def test_offline_mode_skips_file_meta_and_keeps_peer_bytes( + self, monkeypatch: pytest.MonkeyPatch, tmp_path: Path + ) -> None: + from exo.download import peer_shard_downloader as psd + from exo.download.peer_download import PeerFileInfo + from exo.shared.types.worker.downloads import FileListEntry + + async def fake_fetch(*_args: object, **_kwargs: object) -> list[FileListEntry]: + return [ + FileListEntry( + type="file", + path="model.safetensors", + size=10, + ), + ] + + async def fake_peer_status( + peer_host: str, + peer_port: int, + model_id_normalized: str, + timeout: float = 5.0, + ) -> list[PeerFileInfo] | None: + return [ + PeerFileInfo( + path="model.safetensors", + size=10, + complete=True, + safe_bytes=10, + ) + ] + + async def fake_resolve_dir(model_id: ModelId) -> Path: + return tmp_path + + async def fake_resolve_allow(shard: ShardMetadata) -> list[str]: + return ["*"] + + target_path = tmp_path / "model.safetensors" + + async def fake_download( + peer_ip: str, + peer_port: int, + model_id_normalized: str, + file_path: str, + target_dir: Path, + expected_size: int, + on_progress: object = None, + ) -> Path | None: + async with aiofiles.open(target_path, "wb") as f: + await f.write(b"0123456789") + return target_path + + async def file_meta_should_not_be_called( + *_args: object, **_kwargs: object + ) -> tuple[int, str]: + raise AssertionError( + "file_meta must not be called in offline mode -- the " + "operator opted into trusting LAN peers" + ) + + monkeypatch.setattr(psd, "fetch_file_list_with_cache", fake_fetch) + monkeypatch.setattr(psd, "get_peer_file_status", fake_peer_status) + monkeypatch.setattr(psd, 
"resolve_model_dir", fake_resolve_dir) + monkeypatch.setattr(psd, "resolve_allow_patterns", fake_resolve_allow) + monkeypatch.setattr(psd, "download_file_from_peer", fake_download) + monkeypatch.setattr(psd, "file_meta", file_meta_should_not_be_called) + + downloader = PeerAwareShardDownloader(NoopShardDownloader(), offline=True) + shard = _make_shard(ModelId("test-org/model-a")) + + result = await downloader._try_peer_download( + shard, + peer_ip="10.0.0.1", + peer_port=52415, + model_id_normalized="test-org/model-a", + ) + assert result is not None, ( + "offline peer download must succeed without consulting HF; " + "got None which means the integrity check fired and the " + "peer bytes were discarded" + ) + assert await aios.path.exists(target_path), ( + "peer-downloaded file must be retained when offline mode " + "skips the HF integrity check" + ) + + async def test_online_mode_still_calls_file_meta( + self, monkeypatch: pytest.MonkeyPatch, tmp_path: Path + ) -> None: + from exo.download import peer_shard_downloader as psd + from exo.download.peer_download import PeerFileInfo + from exo.shared.types.worker.downloads import FileListEntry + + async def fake_fetch(*_args: object, **_kwargs: object) -> list[FileListEntry]: + return [ + FileListEntry( + type="file", + path="model.safetensors", + size=10, + ), + ] + + async def fake_peer_status( + peer_host: str, + peer_port: int, + model_id_normalized: str, + timeout: float = 5.0, + ) -> list[PeerFileInfo] | None: + return [ + PeerFileInfo( + path="model.safetensors", + size=10, + complete=True, + safe_bytes=10, + ) + ] + + async def fake_resolve_dir(model_id: ModelId) -> Path: + return tmp_path + + async def fake_resolve_allow(shard: ShardMetadata) -> list[str]: + return ["*"] + + target_path = tmp_path / "model.safetensors" + + async def fake_download( + peer_ip: str, + peer_port: int, + model_id_normalized: str, + file_path: str, + target_dir: Path, + expected_size: int, + on_progress: object = None, + ) -> Path | 
None: + async with aiofiles.open(target_path, "wb") as f: + await f.write(b"0123456789") + return target_path + + meta_calls: list[tuple[object, ...]] = [] + + async def recording_meta(*args: object, **_kwargs: object) -> tuple[int, str]: + meta_calls.append(args) + # Return mismatched etag -> downloader will discard. + return (10, "deadbeef" * 5) + + monkeypatch.setattr(psd, "fetch_file_list_with_cache", fake_fetch) + monkeypatch.setattr(psd, "get_peer_file_status", fake_peer_status) + monkeypatch.setattr(psd, "resolve_model_dir", fake_resolve_dir) + monkeypatch.setattr(psd, "resolve_allow_patterns", fake_resolve_allow) + monkeypatch.setattr(psd, "download_file_from_peer", fake_download) + monkeypatch.setattr(psd, "file_meta", recording_meta) + + downloader = PeerAwareShardDownloader(NoopShardDownloader(), offline=False) + shard = _make_shard(ModelId("test-org/model-a")) + + await downloader._try_peer_download( + shard, + peer_ip="10.0.0.1", + peer_port=52415, + model_id_normalized="test-org/model-a", + ) + + assert len(meta_calls) == 1, ( + "online mode must continue calling file_meta to validate " + "peer-downloaded bytes against HF's authoritative hash; " + f"got meta_calls={meta_calls!r}" + ) + + +class TestPeerDownloadZeroByteFiles: + """Codex P2 (PR #16 round-(N+10), peer_shard_downloader.py:354): + The peer transfer path skipped every file whose declared size was + 0 (e.g. ``.gitattributes`` markers, empty ``__init__.py`` shims), + so DownloadCompleted was published with an incomplete local + snapshot. Loaders that probe for those marker files at runtime + (chat-template adapters, processor configs that expect an empty + sentinel) then failed in ways that didn't point back at the peer + step. The fix materializes the zero-byte files locally after the + rest of the peer transfer succeeds. 
+ """ + + async def test_zero_byte_marker_files_materialized_after_peer_transfer( + self, monkeypatch: pytest.MonkeyPatch, tmp_path: Path + ) -> None: + """A repo containing canonical bytes plus an empty marker file + must end the peer transfer with BOTH on disk -- the marker is + a zero-byte file that pre-fix was silently dropped. + """ + from exo.download import peer_shard_downloader as psd + from exo.download.peer_download import PeerFileInfo + from exo.shared.types.worker.downloads import FileListEntry + + served = [ + FileListEntry(type="file", path="model.safetensors", size=10), + # Zero-byte sentinel; pre-fix the peer path silently + # skipped this and the local snapshot was incomplete. + FileListEntry(type="file", path=".gitattributes", size=0), + # Empty shim that loaders sometimes probe for. + FileListEntry(type="file", path="empty/__init__.py", size=0), + ] + + async def fake_fetch(*_args: object, **_kwargs: object) -> list[FileListEntry]: + return served + + async def fake_peer_status( + peer_host: str, # noqa: ARG001 + peer_port: int, # noqa: ARG001 + model_id_normalized: str, # noqa: ARG001 + timeout: float = 5.0, # noqa: ARG001 + ) -> list[PeerFileInfo] | None: + # The peer reports only the canonical bytes (mirrors + # production peers; HF-shard listings do not include + # zero-byte markers either). 
+ return [ + PeerFileInfo( + path="model.safetensors", size=10, complete=True, safe_bytes=10 + ) + ] + + async def fake_resolve_dir(_model_id: ModelId) -> Path: + return tmp_path + + async def fake_resolve_allow(_shard: ShardMetadata) -> list[str]: + return ["*"] + + target_path = tmp_path / "model.safetensors" + + async def fake_download( + peer_ip: str, # noqa: ARG001 + peer_port: int, # noqa: ARG001 + model_id_normalized: str, # noqa: ARG001 + file_path: str, # noqa: ARG001 + target_dir: Path, # noqa: ARG001 + expected_size: int, # noqa: ARG001 + on_progress: object = None, # noqa: ARG001 + ) -> Path | None: + async with aiofiles.open(target_path, "wb") as f: + await f.write(b"0123456789") + return target_path + + monkeypatch.setattr(psd, "fetch_file_list_with_cache", fake_fetch) + monkeypatch.setattr(psd, "get_peer_file_status", fake_peer_status) + monkeypatch.setattr(psd, "resolve_model_dir", fake_resolve_dir) + monkeypatch.setattr(psd, "resolve_allow_patterns", fake_resolve_allow) + monkeypatch.setattr(psd, "download_file_from_peer", fake_download) + + downloader = PeerAwareShardDownloader(NoopShardDownloader(), offline=True) + shard = _make_shard(ModelId("test-org/model-a")) + + result = await downloader._try_peer_download( + shard, + peer_ip="10.0.0.1", + peer_port=52415, + model_id_normalized="test-org/model-a", + ) + assert result is not None, ( + "peer transfer must succeed when the only missing 'files' " + "are zero-byte markers; pre-fix the path returned success " + "without materializing them, so subsequent loads broke" + ) + assert await aios.path.exists(target_path), ( + "the canonical safetensor must still be present" + ) + # The crux of the regression test: zero-byte markers MUST be on disk. 
+ gitattributes = tmp_path / ".gitattributes" + empty_shim = tmp_path / "empty" / "__init__.py" + assert await aios.path.exists(gitattributes), ( + "zero-byte ``.gitattributes`` marker must be materialized on " + "disk after peer transfer; pre-fix it was silently skipped " + "and DownloadCompleted reported success on an incomplete dir" + ) + assert await aios.path.exists(empty_shim), ( + "zero-byte ``empty/__init__.py`` shim must exist after peer " + "transfer (parent dir must also be created)" + ) + # Both must literally be empty. + assert (await aios.stat(gitattributes)).st_size == 0 + assert (await aios.stat(empty_shim)).st_size == 0 + + async def test_zero_byte_files_marked_complete_in_progress_map( + self, monkeypatch: pytest.MonkeyPatch, tmp_path: Path + ) -> None: + """Codex P2 (PR #16 round-(N+13), peer_shard_downloader.py:407): + zero-byte files must be marked ``status="complete"`` in the + progress map AFTER materialization, otherwise the final + ``calculate_repo_progress`` call rolls them up as + ``status="not_started"`` and the overall repo status stays + non-complete -- so ``_download_progress_callback`` does not + publish ``DownloadCompleted`` immediately and the model is + stuck in ``DownloadOngoing`` until reconciliation runs. + + We exercise the same fixture as the materialization test, + but capture the *final* progress callback emission (the one + the coordinator turns into ``DownloadCompleted``) and + assert its ``status`` is ``"complete"`` and that every + per-file entry is also ``"complete"``. 
+ """ + from exo.download import peer_shard_downloader as psd + from exo.download.download_utils import RepoDownloadProgress + from exo.download.peer_download import PeerFileInfo + from exo.shared.types.worker.downloads import FileListEntry + + served = [ + FileListEntry(type="file", path="model.safetensors", size=10), + FileListEntry(type="file", path=".gitattributes", size=0), + FileListEntry(type="file", path="empty/__init__.py", size=0), + ] + + async def fake_fetch(*_args: object, **_kwargs: object) -> list[FileListEntry]: + return served + + async def fake_peer_status( + peer_host: str, # noqa: ARG001 + peer_port: int, # noqa: ARG001 + model_id_normalized: str, # noqa: ARG001 + timeout: float = 5.0, # noqa: ARG001 + ) -> list[PeerFileInfo] | None: + return [ + PeerFileInfo( + path="model.safetensors", size=10, complete=True, safe_bytes=10 + ) + ] + + async def fake_resolve_dir(_model_id: ModelId) -> Path: + return tmp_path + + async def fake_resolve_allow(_shard: ShardMetadata) -> list[str]: + return ["*"] + + target_path = tmp_path / "model.safetensors" + + async def fake_download( + peer_ip: str, # noqa: ARG001 + peer_port: int, # noqa: ARG001 + model_id_normalized: str, # noqa: ARG001 + file_path: str, # noqa: ARG001 + target_dir: Path, # noqa: ARG001 + expected_size: int, + on_progress: Callable[[int, int, bool], None] = lambda _a, _b, _c: None, + ) -> Path | None: + async with aiofiles.open(target_path, "wb") as f: + await f.write(b"0123456789") + # Match the production peer_download contract: emit the + # final rename-completed progress callback so the + # canonical-file's per-file progress entry transitions + # to ``status="complete"`` like it would in production. 
+ on_progress(expected_size, expected_size, True) + return target_path + + monkeypatch.setattr(psd, "fetch_file_list_with_cache", fake_fetch) + monkeypatch.setattr(psd, "get_peer_file_status", fake_peer_status) + monkeypatch.setattr(psd, "resolve_model_dir", fake_resolve_dir) + monkeypatch.setattr(psd, "resolve_allow_patterns", fake_resolve_allow) + monkeypatch.setattr(psd, "download_file_from_peer", fake_download) + + downloader = PeerAwareShardDownloader(NoopShardDownloader(), offline=True) + shard = _make_shard(ModelId("test-org/model-a")) + + # Capture the final progress emitted by the peer downloader + # so we can assert its rolled-up status. + captured: list[RepoDownloadProgress] = [] + + async def capture_progress( + _shard: ShardMetadata, progress: RepoDownloadProgress + ) -> None: + captured.append(progress) + + downloader._progress_callbacks.append(capture_progress) + + result = await downloader._try_peer_download( + shard, + peer_ip="10.0.0.1", + peer_port=52415, + model_id_normalized="test-org/model-a", + ) + assert result is not None + assert captured, ( + "peer downloader must emit at least one progress event " + "(the rolled-up final status); pre-fix the test never " + "got past this because the canonical file's per-byte " + "callback also triggers an emit" + ) + final = captured[-1] + assert final.status == "complete", ( + "rolled-up final repo progress must be ``complete`` once " + "every file (including zero-byte markers) is on disk; " + "pre-(N+13)-fix the zero-byte entries stayed at " + "``not_started`` so the rollup was non-complete and " + "DownloadCompleted was never published. 
" + f"final.status={final.status!r} " + f"per_file={[(p, e.status) for p, e in final.file_progress.items()]}" + ) + for marker in (".gitattributes", "empty/__init__.py"): + entry = final.file_progress.get(marker) + assert entry is not None, ( + f"file_progress must contain entry for {marker!r}; " + "pre-fix the seeded ``not_started`` entry was never " + "updated, so this assert succeeded but on the wrong " + "status -- this version of the assert covers both " + "regressions (entry presence and final status)" + ) + assert entry.status == "complete", ( + f"zero-byte marker {marker!r} must be marked complete " + f"in the progress map after materialization; " + f"pre-fix status was {entry.status!r} which causes " + f"calculate_repo_progress to roll up to non-complete" + ) + + async def test_unknown_size_file_aborts_peer_transfer_for_hf_fallback( + self, monkeypatch: pytest.MonkeyPatch, tmp_path: Path + ) -> None: + """Codex P1 (PR #16 round-(N+11), peer_shard_downloader.py:354): + ``FileListEntry(size=None)`` is NOT a zero-byte marker -- the + upstream ``fetch_file_list_with_cache`` returns ``size=None`` + for files discovered via the safetensors index whose size + wasn't in the HF API response (real weight shards). Pre-fix + the round-(N+10) materialize-as-empty path treated those as + empty markers and reported peer transfer success on a + corrupted snapshot. + + Post-fix, ``size is None`` aborts the peer transfer (returns + None) so the HF fallback gets a real download path. We + construct a file list with a real safetensor (size=10) and + an unknown-size weight shard (size=None) and assert the + peer transfer returns None *without* materializing the + unknown-size entry as an empty file. 
+ """ + from exo.download import peer_shard_downloader as psd + from exo.download.peer_download import PeerFileInfo + from exo.shared.types.worker.downloads import FileListEntry + + served = [ + FileListEntry(type="file", path="model.safetensors", size=10), + # Unknown size: real weight shard from safetensors index. + FileListEntry( + type="file", path="model-00002-of-00003.safetensors", size=None + ), + ] + + async def fake_fetch(*_args: object, **_kwargs: object) -> list[FileListEntry]: + return served + + async def fake_peer_status( + peer_host: str, # noqa: ARG001 + peer_port: int, # noqa: ARG001 + model_id_normalized: str, # noqa: ARG001 + timeout: float = 5.0, # noqa: ARG001 + ) -> list[PeerFileInfo] | None: + return [ + PeerFileInfo( + path="model.safetensors", size=10, complete=True, safe_bytes=10 + ), + PeerFileInfo( + path="model-00002-of-00003.safetensors", + size=999, + complete=True, + safe_bytes=999, + ), + ] + + async def fake_resolve_dir(_model_id: ModelId) -> Path: + return tmp_path + + async def fake_resolve_allow(_shard: ShardMetadata) -> list[str]: + return ["*"] + + download_called = anyio.Event() + + async def fake_download( + peer_ip: str, # noqa: ARG001 + peer_port: int, # noqa: ARG001 + model_id_normalized: str, # noqa: ARG001 + file_path: str, # noqa: ARG001 + target_dir: Path, # noqa: ARG001 + expected_size: int, # noqa: ARG001 + on_progress: object = None, # noqa: ARG001 + ) -> Path | None: + download_called.set() + return None + + monkeypatch.setattr(psd, "fetch_file_list_with_cache", fake_fetch) + monkeypatch.setattr(psd, "get_peer_file_status", fake_peer_status) + monkeypatch.setattr(psd, "resolve_model_dir", fake_resolve_dir) + monkeypatch.setattr(psd, "resolve_allow_patterns", fake_resolve_allow) + monkeypatch.setattr(psd, "download_file_from_peer", fake_download) + + downloader = PeerAwareShardDownloader(NoopShardDownloader(), offline=True) + shard = _make_shard(ModelId("test-org/model-a")) + + result = await 
downloader._try_peer_download( + shard, + peer_ip="10.0.0.1", + peer_port=52415, + model_id_normalized="test-org/model-a", + ) + assert result is None, ( + "peer transfer must abort (return None) when the file list " + "contains a size=None entry; HF fallback then takes over to " + "ensure the unknown-size weight is properly downloaded. " + "Pre-fix the size=None entry was lumped with size=0 markers " + "and materialized as empty, producing corrupted snapshots." + ) + # The unknown-size file must NOT have been created as empty + # by the marker-materialization path. + unknown_path = tmp_path / "model-00002-of-00003.safetensors" + assert not await aios.path.exists(unknown_path), ( + "size=None entries must NOT be materialized as empty marker " + "files -- they're real weights of unknown size, not markers" + ) + assert not download_called.is_set(), ( + "peer transfer should abort BEFORE issuing any download " + "call when a size=None entry is encountered" + ) + + async def test_zero_byte_files_not_created_when_canonical_transfer_fails( + self, monkeypatch: pytest.MonkeyPatch, tmp_path: Path + ) -> None: + """If the non-empty file transfer fails, the zero-byte markers + must NOT be created. Otherwise the local model dir would + contain orphan empty files masquerading as a partial download + and the HF fallback might skip them. 
+ """ + from exo.download import peer_shard_downloader as psd + from exo.download.peer_download import PeerFileInfo + from exo.shared.types.worker.downloads import FileListEntry + + served = [ + FileListEntry(type="file", path="model.safetensors", size=10), + FileListEntry(type="file", path=".gitattributes", size=0), + ] + + async def fake_fetch(*_args: object, **_kwargs: object) -> list[FileListEntry]: + return served + + async def fake_peer_status( + peer_host: str, # noqa: ARG001 + peer_port: int, # noqa: ARG001 + model_id_normalized: str, # noqa: ARG001 + timeout: float = 5.0, # noqa: ARG001 + ) -> list[PeerFileInfo] | None: + return [ + PeerFileInfo( + path="model.safetensors", size=10, complete=True, safe_bytes=10 + ) + ] + + async def fake_resolve_dir(_model_id: ModelId) -> Path: + return tmp_path + + async def fake_resolve_allow(_shard: ShardMetadata) -> list[str]: + return ["*"] + + async def failing_download( + peer_ip: str, # noqa: ARG001 + peer_port: int, # noqa: ARG001 + model_id_normalized: str, # noqa: ARG001 + file_path: str, # noqa: ARG001 + target_dir: Path, # noqa: ARG001 + expected_size: int, # noqa: ARG001 + on_progress: object = None, # noqa: ARG001 + ) -> Path | None: + return None + + monkeypatch.setattr(psd, "fetch_file_list_with_cache", fake_fetch) + monkeypatch.setattr(psd, "get_peer_file_status", fake_peer_status) + monkeypatch.setattr(psd, "resolve_model_dir", fake_resolve_dir) + monkeypatch.setattr(psd, "resolve_allow_patterns", fake_resolve_allow) + monkeypatch.setattr(psd, "download_file_from_peer", failing_download) + + downloader = PeerAwareShardDownloader(NoopShardDownloader(), offline=True) + shard = _make_shard(ModelId("test-org/model-a")) + + result = await downloader._try_peer_download( + shard, + peer_ip="10.0.0.1", + peer_port=52415, + model_id_normalized="test-org/model-a", + ) + assert result is None, ( + "peer transfer must report failure when the non-empty " + "canonical bytes never landed; the HF fallback then runs" + ) 
+ gitattributes = tmp_path / ".gitattributes" + assert not await aios.path.exists(gitattributes), ( + "zero-byte markers must NOT be created if the canonical " + "transfer failed -- otherwise the partial dir confuses the " + "HF fallback's already-downloaded probe" + ) + + +def _allocate_free_tcp_port() -> int: + """Bind ephemeral port 0 to grab a free TCP port; close before reuse. + + Used by lifecycle tests that want to verify a specific port is + released after server teardown -- we cannot bind 0 in the server + itself because the test needs a stable port to assert on. + """ + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe: + probe.bind(("127.0.0.1", 0)) + return cast(int, probe.getsockname()[1]) + + +class TestPeerFileServerLifecycle: + """Codex P2 (PR #16 round-(N+10), peer_file_server.py:56): the + coroutine returned by ``PeerFileServer.run()`` must stay alive + until cancelled, otherwise the parent task group considers the + server "done" the moment ``site.start()`` returns and never drives + cleanup -- the listening socket leaks until process exit, causing + ``EADDRINUSE`` on stop/restart in the same process (tests, + embedded runs, systemd-style restart loops). + """ + + async def test_run_blocks_until_cancelled(self, tmp_path: Path) -> None: + models_dir = tmp_path / "models" + await aios.makedirs(models_dir, exist_ok=True) + server = PeerFileServer(host="127.0.0.1", port=0, models_dirs=[models_dir]) + + run_completed = anyio.Event() + + async def _run_and_signal() -> None: + try: + await server.run() + finally: + run_completed.set() + + async with anyio.create_task_group() as tg: + tg.start_soon(_run_and_signal) + # Yield a few times so the server can boot. 
+ for _ in range(5): + await anyio.sleep(0.01) + assert not run_completed.is_set(), ( + "PeerFileServer.run must keep the coroutine alive after " + "site.start() so task-group cancellation can drive " + "teardown; pre-fix it returned immediately and the " + "listening socket leaked until process exit" + ) + tg.cancel_scope.cancel() + assert run_completed.is_set() + + async def test_listening_port_is_released_after_run_cancellation( + self, tmp_path: Path + ) -> None: + """End-to-end EADDRINUSE regression: pre-fix a stop/restart + in the same process raised ``OSError: [Errno 48] address + already in use`` because cleanup never ran. After the fix the + same port must be re-bindable immediately after cancellation. + """ + models_dir = tmp_path / "models" + await aios.makedirs(models_dir, exist_ok=True) + port = _allocate_free_tcp_port() + + server = PeerFileServer(host="127.0.0.1", port=port, models_dirs=[models_dir]) + + async with anyio.create_task_group() as tg: + tg.start_soon(server.run) + for _ in range(10): + await anyio.sleep(0.02) + async with aiohttp.ClientSession() as s: + try: + async with s.get( + f"http://127.0.0.1:{port}/health", + timeout=aiohttp.ClientTimeout(total=0.5), + ) as r: + if r.status == 200: + break + except (aiohttp.ClientError, TimeoutError): + continue + else: + raise AssertionError( + "PeerFileServer never started listening on the " + f"allocated port {port}" + ) + tg.cancel_scope.cancel() + + # Restart on the same port immediately. Pre-fix this raised + # EADDRINUSE because the prior listener was never closed. 
+ server2 = PeerFileServer(host="127.0.0.1", port=port, models_dirs=[models_dir]) + async with anyio.create_task_group() as tg2: + tg2.start_soon(server2.run) + await anyio.sleep(0.05) + async with ( + aiohttp.ClientSession() as s, + s.get( + f"http://127.0.0.1:{port}/health", + timeout=aiohttp.ClientTimeout(total=2.0), + ) as r, + ): + assert r.status == 200, ( + "server2 must come up cleanly on the recycled " + "port; pre-fix the prior server's socket " + "leaked and this raised EADDRINUSE" + ) + tg2.cancel_scope.cancel() diff --git a/src/exo/download/tests/test_peer_state.py b/src/exo/download/tests/test_peer_state.py new file mode 100644 index 0000000000..570692373d --- /dev/null +++ b/src/exo/download/tests/test_peer_state.py @@ -0,0 +1,142 @@ +"""Regression tests for ``exo.download.peer_state``. + +These exercise the topology-iteration ordering that decides whether a peer +is reachable over RDMA or merely via socket. The original implementation +returned on the first edge whose type happened to be visited first, which +mislabelled peers when ``out_edges`` yielded the socket edge before the +RDMA edge. We now scan all edges and prefer RDMA whenever any RDMA edge +exists for that peer. 
+""" + +from collections.abc import Iterable +from pathlib import Path +from typing import cast + +from exo.download.peer_state import discover_peers_for_model +from exo.shared.models.model_cards import ModelCard, ModelId, ModelTask +from exo.shared.topology import Topology +from exo.shared.types.common import NodeId +from exo.shared.types.memory import Memory +from exo.shared.types.multiaddr import Multiaddr +from exo.shared.types.state import State +from exo.shared.types.topology import ( + Connection, + RDMAConnection, + SocketConnection, +) +from exo.shared.types.worker.downloads import DownloadCompleted, DownloadProgress +from exo.shared.types.worker.shards import PipelineShardMetadata, ShardMetadata + +LOCAL = NodeId("aaaaaaaa-aaaa-4aaa-8aaa-aaaaaaaaaaaa") +PEER = NodeId("bbbbbbbb-bbbb-4bbb-8bbb-bbbbbbbbbbbb") +MODEL_ID = ModelId("test-org/test-model") +NORMALIZED = MODEL_ID.normalize() + + +def _make_shard() -> ShardMetadata: + return PipelineShardMetadata( + model_card=ModelCard( + model_id=MODEL_ID, + storage_size=Memory.from_mb(100), + n_layers=4, + hidden_size=64, + supports_tensor=False, + tasks=[ModelTask.TextGeneration], + ), + device_rank=0, + world_size=1, + start_layer=0, + end_layer=4, + n_layers=4, + ) + + +def _build_topology(edges: Iterable[Connection]) -> Topology: + topology = Topology() + topology.add_node(LOCAL) + topology.add_node(PEER) + for conn in edges: + topology.add_connection(conn) + return topology + + +def _state_with_completed_peer(topology: Topology) -> State: + completed = DownloadCompleted( + node_id=PEER, + shard_metadata=_make_shard(), + total=Memory.from_mb(100), + model_directory=str(Path("/fake/models/test-org--test-model")), + ) + return State( + downloads={PEER: [cast(DownloadProgress, completed)]}, + topology=topology, + ) + + +def _socket_edge() -> Connection: + return Connection( + source=LOCAL, + sink=PEER, + edge=SocketConnection( + sink_multiaddr=Multiaddr(address="/ip4/10.0.0.2/tcp/4001") + ), + ) + + +def 
_rdma_edge() -> Connection: + return Connection( + source=LOCAL, + sink=PEER, + edge=RDMAConnection(source_rdma_iface="bridge0", sink_rdma_iface="bridge0"), + ) + + +def test_peer_marked_rdma_when_socket_edge_inserted_first() -> None: + """If both an RDMA edge and a socket edge exist for the same peer, the + peer must be reported as RDMA *regardless of insertion order*. The + original implementation returned on the first edge it saw, so a socket + edge inserted before the RDMA edge silently downgraded a real RDMA peer + to ``socket`` and broke the "RDMA first" ordering used by the peer + downloader. + """ + topology = _build_topology([_socket_edge(), _rdma_edge()]) + state = _state_with_completed_peer(topology) + + peers = discover_peers_for_model(LOCAL, state, NORMALIZED, peer_download_port=52416) + + assert len(peers) == 1 + assert peers[0].connection_type == "rdma" + assert peers[0].ip == "10.0.0.2" + + +def test_peer_marked_rdma_when_rdma_edge_inserted_first() -> None: + topology = _build_topology([_rdma_edge(), _socket_edge()]) + state = _state_with_completed_peer(topology) + + peers = discover_peers_for_model(LOCAL, state, NORMALIZED, peer_download_port=52416) + + assert len(peers) == 1 + assert peers[0].connection_type == "rdma" + + +def test_peer_marked_socket_when_no_rdma_edge_exists() -> None: + topology = _build_topology([_socket_edge()]) + state = _state_with_completed_peer(topology) + + peers = discover_peers_for_model(LOCAL, state, NORMALIZED, peer_download_port=52416) + + assert len(peers) == 1 + assert peers[0].connection_type == "socket" + assert peers[0].ip == "10.0.0.2" + + +def test_peer_skipped_when_only_rdma_edge_has_no_socket_companion() -> None: + """An RDMA-only peer cannot be contacted over the peer-download HTTP + server, so we must omit it rather than fabricate a missing IP. 
+ """ + topology = _build_topology([_rdma_edge()]) + state = _state_with_completed_peer(topology) + + peers = discover_peers_for_model(LOCAL, state, NORMALIZED, peer_download_port=52416) + + assert peers == [] diff --git a/src/exo/main.py b/src/exo/main.py index 9edee42096..54749c0112 100644 --- a/src/exo/main.py +++ b/src/exo/main.py @@ -1,9 +1,11 @@ import argparse +import ipaddress import multiprocessing as mp import os import resource import signal import subprocess +import sys from dataclasses import dataclass, field from datetime import datetime, timezone from typing import Self @@ -16,10 +18,16 @@ from exo.api.main import API from exo.download.coordinator import DownloadCoordinator from exo.download.impl_shard_downloader import exo_shard_downloader +from exo.download.peer_file_server import PeerFileServer from exo.master.main import Master from exo.routing.event_router import EventRouter from exo.routing.router import Router, get_node_id_keypair -from exo.shared.constants import EXO_LOG +from exo.shared.constants import ( + EXO_LOG, + EXO_MODELS_DIRS, + EXO_MODELS_READ_ONLY_DIRS, + EXO_PEER_DOWNLOAD_PORT, +) from exo.shared.election import Election, ElectionResult from exo.shared.logging import logger_cleanup, logger_set_context, logger_setup from exo.shared.types.common import NodeId, SessionId @@ -44,11 +52,34 @@ class Node: node_id: NodeId offline: bool _api_port: int + _libp2p_port: int + _peer_download_port: int + peer_file_server: PeerFileServer | None = None _tg: TaskGroup = field(init=False, default_factory=TaskGroup) @classmethod async def create(cls, args: "Args") -> Self: - keypair = get_node_id_keypair() + # Codex P1 (PR #16 round-(N+3), main.py:74): scope the on-disk + # node-ID keypair by the *combination* of ports the operator + # has chosen, not just ``--peer-download-port``. 
The earlier + # peer-download-only scope leaked identity collisions when + # ``--no-downloads`` / ``--no-peer-download`` is set: that + # mode doesn't bind the peer file server, so two same-host + # processes can legitimately keep the default + # ``peer_download_port`` and would then load the same scoped + # keypair file -- producing identical ``NodeId``s and + # breaking election/routing's unique-NodeId invariants. + # + # Combined-port scoping is robust against every same-host + # multi-process configuration: at least one of the listening + # ports MUST differ between processes (libp2p, peer-download, + # api -- each is a distinct local socket bind), so the scope + # tuple differs whenever the actual configuration differs. + # Deployments with a fixed, non-zero ``--libp2p-port`` keep a stable + # filename (e.g. ``node_id.libp2p-4001.api-52415.peer-52416.keypair``) + # so identity persists across restarts; ``--libp2p-port 0`` scopes by PID instead. + process_scope = _node_id_keypair_scope(args) + keypair = get_node_id_keypair(process_scope=process_scope) node_id = NodeId(keypair.to_node_id()) session_id = SessionId(master_node_id=node_id, election_clock=0) router = Router.create( @@ -71,17 +102,8 @@ async def create(cls, args: "Args") -> Self: logger.info(f"Starting node {node_id}") - # Create DownloadCoordinator (unless --no-downloads) - if not args.no_downloads: - download_coordinator = DownloadCoordinator( - node_id, - exo_shard_downloader(offline=args.offline), - event_sender=event_router.sender(), - download_command_receiver=router.receiver(topics.DOWNLOAD_COMMANDS), - offline=args.offline, - ) - else: - download_coordinator = None + peer_file_server: PeerFileServer | None = None + peer_download_enabled = not args.no_peer_download and not args.no_downloads if args.spawn_api: api = API( @@ -103,10 +125,47 @@ async def create(cls, args: "Args") -> Self: command_sender=router.sender(topics.COMMANDS), download_command_sender=router.sender(topics.DOWNLOAD_COMMANDS), api_port=args.api_port, + # Each node now binds its
own peer-download listener on + # ``--peer-download-port`` (default ``EXO_PEER_DOWNLOAD_PORT``). + # The Worker uses this same value when discovering peers, + # so all nodes in a cluster MUST agree on it (typically + # via the shared ``EXO_PEER_DOWNLOAD_PORT`` env var). + # Pre-fix this was a single import-time module constant, + # making same-host multi-node setups impossible (Codex + # P2, PR #16 round 3). + peer_download_port=args.peer_download_port, ) else: worker = None + if peer_download_enabled: + # Serve from every configured model directory so peers can fetch + # any locally-resident shard regardless of which directory the + # downloader landed it in. ``EXO_MODELS_DIRS`` already includes + # ``EXO_DEFAULT_MODELS_DIR`` as its first entry; ``EXO_MODELS_READ_ONLY_DIRS`` + # captures pre-populated mounts (e.g. shared NFS caches) that + # ``select_download_dir_for_shard`` excludes from new writes but + # which other peers still benefit from being able to read. + peer_file_server = PeerFileServer( + host="0.0.0.0", + port=args.peer_download_port, + models_dirs=(*EXO_MODELS_DIRS, *EXO_MODELS_READ_ONLY_DIRS), + ) + + if not args.no_downloads: + download_coordinator: DownloadCoordinator | None = DownloadCoordinator( + node_id, + exo_shard_downloader( + offline=args.offline, + peer_download_enabled=peer_download_enabled, + ), + event_sender=event_router.sender(), + download_command_receiver=router.receiver(topics.DOWNLOAD_COMMANDS), + offline=args.offline, + ) + else: + download_coordinator = None + # We start every node with a master master = Master( node_id, @@ -144,6 +203,9 @@ async def create(cls, args: "Args") -> Self: node_id, args.offline, args.api_port, + args.libp2p_port, + args.peer_download_port, + peer_file_server, ) logger_set_context( node_id=node_id, role="master" if args.force_master else "node" @@ -161,6 +223,8 @@ async def run(self): tg.start_soon(self.router.run) tg.start_soon(self.event_router.run) tg.start_soon(self.election.run) + if 
self.peer_file_server: + tg.start_soon(self.peer_file_server.run) if self.download_coordinator: tg.start_soon(self.download_coordinator.run) if self.worker: @@ -169,6 +233,12 @@ async def run(self): tg.start_soon(self.master.run) if self.api: tg.start_soon(self.api.run) + if sys.platform == "darwin" and self._libp2p_port != 0: + tg.start_soon( + _darwin_mdns_broadcast_announcer, + self.node_id, + self._libp2p_port, + ) tg.start_soon(self._elect_loop) tg.start_soon(self._diagnostic_snapshot_loop) @@ -252,7 +322,10 @@ async def _elect_loop(self): await self.download_coordinator.shutdown() self.download_coordinator = DownloadCoordinator( self.node_id, - exo_shard_downloader(offline=self.offline), + exo_shard_downloader( + offline=self.offline, + peer_download_enabled=self.peer_file_server is not None, + ), event_sender=self.event_router.sender(), download_command_receiver=self.router.receiver( topics.DOWNLOAD_COMMANDS @@ -272,6 +345,7 @@ async def _elect_loop(self): topics.DOWNLOAD_COMMANDS ), api_port=self._api_port, + peer_download_port=self._peer_download_port, ) self._tg.start_soon(self.worker.run) if self.api: @@ -361,6 +435,119 @@ def _last_seen_ages(self, state: State) -> dict[str, float]: return ages +def _node_id_keypair_scope(args: "Args") -> str: + """Produce a stable per-process scope for the node-ID keypair file. + + Combines every listening port the operator could plausibly + distinguish between same-host processes: ``--libp2p-port``, + ``--api-port``, and ``--peer-download-port``. At least one of + these MUST differ between two processes that share a host (each + is a distinct local socket bind), so the resulting scope is + always unique per process while remaining stable across + restarts of the same configuration. 
+ + Used by :func:`get_node_id_keypair` to avoid two same-host + processes loading the same scoped keypair file when peer + download is disabled (which would otherwise let them collide + on the default ``peer_download_port`` since no socket is + actually being bound). See Codex P1 (PR #16 round-(N+3), + main.py:74). + + Codex P1 (PR #16 round-(N+8), main.py:457): when + ``--libp2p-port 0`` is set, the configured value is the literal + ``0`` even though each process actually binds a different + ephemeral port at runtime. Two same-host worker-only processes + (no API, no peer download) sharing the default + ``peer_download_port`` and ``api_port`` -- but each binding + ``libp2p_port=0`` -- would otherwise produce identical scope + strings ``"libp2p-0.api-...peer-..."`` and load the same + keypair file, breaking the unique-NodeId invariant. + Stability across restarts is impossible in this configuration + anyway (the OS hands out a different ephemeral port on every + bind), so fold in ``os.getpid()`` as a per-process + discriminator. The trade-off (ephemeral identity for + ephemeral ports) is the right semantic: the operator opted + into ephemeral binding by setting ``libp2p_port=0``. + """ + if args.libp2p_port == 0: + return ( + f"libp2p-pid-{os.getpid()}." 
+ f"api-{args.api_port}.peer-{args.peer_download_port}" + ) + return ( + f"libp2p-{args.libp2p_port}.api-{args.api_port}.peer-{args.peer_download_port}" + ) + + +def _darwin_en0_ip_address() -> str | None: + try: + return subprocess.check_output( + ["ipconfig", "getifaddr", "en0"], + text=True, + stderr=subprocess.DEVNULL, + ).strip() + except (OSError, subprocess.CalledProcessError): + return None + + +def _darwin_en0_broadcast_address(ip_address: str) -> str | None: + try: + subnet_mask = subprocess.check_output( + ["ipconfig", "getoption", "en0", "subnet_mask"], + text=True, + stderr=subprocess.DEVNULL, + ).strip() + interface = ipaddress.IPv4Interface(f"{ip_address}/{subnet_mask}") + return str(interface.network.broadcast_address) + except (OSError, ValueError, subprocess.CalledProcessError): + return None + + +async def _darwin_mdns_broadcast_announcer(node_id: NodeId, libp2p_port: int) -> None: + """Advertise this node through an ``exo.routing.mdns_announcer`` subprocess. + + Looks up the ``en0`` IPv4 address (returning early when there is + none), launches the announcer module in its own session with the + node id, address, libp2p port and -- when it can be derived -- the + subnet broadcast address, then polls every 60 seconds until the + subprocess exits. On exit or cancellation a still-running + subprocess is sent ``terminate()``, given up to 2 seconds to stop, + and killed if it is still alive afterwards. + """ + ip_address = _darwin_en0_ip_address() + if not ip_address: + logger.debug("Darwin mDNS broadcast announcer disabled: no en0 IPv4 address") + return + + broadcast_address = _darwin_en0_broadcast_address(ip_address) + logger.debug( + f"Darwin mDNS announcer advertising {node_id} at {ip_address}:{libp2p_port}" + ) + command = [ + sys.executable, + "-m", + "exo.routing.mdns_announcer", + "--node-id", + str(node_id), + "--ip-address", + ip_address, + "--libp2p-port", + str(libp2p_port), + ] + if broadcast_address is not None: + command.extend(["--broadcast-address", broadcast_address]) + process = subprocess.Popen( + command, + start_new_session=True, + stdout=subprocess.DEVNULL, + ) + try: + while process.poll() is None: + await anyio.sleep(60) + logger.debug( + f"Darwin mDNS announcer subprocess exited with {process.returncode}" + ) + finally: + if process.poll() is None: + process.terminate() + # Shield the grace period: this ``finally`` normally runs + # because the enclosing task group was cancelled, and an + # unshielded checkpoint would re-raise that cancellation + # immediately, skipping both the wait and the kill below. + with anyio.move_on_after(2, shield=True): + while process.poll() is None: + await anyio.sleep(0.1) + if process.poll() is None: + process.kill() + await anyio.sleep(0) + + def main(): args =
Args.parse() soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE) @@ -386,6 +573,13 @@ def main(): os.environ["EXO_NO_BATCH"] = "1" logger.info("Continuous batching disabled (--no-batch)") + # Set trust_remote_code override env var for runner subprocesses + if args.trust_remote_code: + os.environ["EXO_TRUST_REMOTE_CODE"] = "1" + logger.warning( + "--trust-remote-code enabled: models may execute arbitrary code during loading" + ) + # Set FAST_SYNCH override env var for runner subprocesses if args.fast_synch is True: os.environ["EXO_FAST_SYNCH"] = "true" @@ -430,11 +624,24 @@ class Args(FrozenModel): tb_only: bool = False no_worker: bool = False no_downloads: bool = False + no_peer_download: bool = False offline: bool = os.getenv("EXO_OFFLINE", "false").lower() == "true" no_batch: bool = False fast_synch: bool | None = None # None = auto, True = force on, False = force off bootstrap_peers: list[str] = [] libp2p_port: int + # Per-process listener port for peer-to-peer model file serving. + # Defaults to ``EXO_PEER_DOWNLOAD_PORT`` so existing single-node-per- + # host deployments keep working unchanged. Operators running + # multiple nodes on the same host MUST set this to a distinct value + # for each process; the cluster-wide convention is that every node + # exposes the same port, since peer discovery currently uses each + # node's local value as the assumed remote endpoint (see + # ``Worker._peer_download_port``). A future state-sync change can + # advertise per-node ports across the cluster -- tracked as a + # follow-up to Codex P2 (PR #16 round 3). 
+ peer_download_port: PositiveInt = EXO_PEER_DOWNLOAD_PORT + trust_remote_code: bool = False @classmethod def parse(cls) -> Self: @@ -481,6 +688,11 @@ def parse(cls) -> Self: action="store_true", help="Disable the download coordinator (node won't download models)", ) + parser.add_argument( + "--no-peer-download", + action="store_true", + help="Disable peer-to-peer model downloads (each node downloads from HuggingFace independently)", + ) parser.add_argument( "--offline", action="store_true", @@ -492,6 +704,11 @@ def parse(cls) -> Self: action="store_true", help="Disable continuous batching, use sequential generation", ) + parser.add_argument( + "--trust-remote-code", + action="store_true", + help="Allow models to execute custom code during tokenizer loading (security-sensitive, CLI-only)", + ) parser.add_argument( "--bootstrap-peers", type=lambda s: [p for p in s.split(",") if p], @@ -508,6 +725,22 @@ def parse(cls) -> Self: dest="libp2p_port", help="Fixed TCP port for libp2p to listen on (0 = OS-assigned).", ) + parser.add_argument( + "--peer-download-port", + type=int, + default=EXO_PEER_DOWNLOAD_PORT, + dest="peer_download_port", + help=( + "TCP port for peer-to-peer model file serving (default: " + "EXO_PEER_DOWNLOAD_PORT, currently 52416). Required to " + "differ between processes when running multiple nodes " + "on the same host; otherwise the second node's " + "PeerFileServer hits 'address already in use'. All " + "nodes in a cluster must use the same value (peer " + "discovery uses the local port as the assumed remote " + "port)." 
+ ), + ) fast_synch_group = parser.add_mutually_exclusive_group() fast_synch_group.add_argument( "--fast-synch", diff --git a/src/exo/master/main.py b/src/exo/master/main.py index 888c39e4c8..2e3a104fcf 100644 --- a/src/exo/master/main.py +++ b/src/exo/master/main.py @@ -392,6 +392,7 @@ async def _command_processor(self) -> None: self.state.node_memory, self.state.node_network, download_status=self.state.downloads, + node_rdma_ctl=self.state.node_rdma_ctl, ) transition_events = get_transition_events( self.state.instances, placement, self.state.tasks @@ -481,6 +482,20 @@ async def _command_processor(self) -> None: # These plan loops are the cracks showing in our event sourcing architecture - more things could be commands async def _plan(self) -> None: + # Codex P1 (PR #16 round-(N+9), master/main.py:486): the + # inactivity timeout MUST stay safely above ``NodeGatheredInfo`` + # cadence jitter -- 5s was too tight (any node that didn't + # publish telemetry within 5s, e.g. when fast probes are + # unavailable or delayed, would be marked timed out and have + # its instances deleted in the same _plan loop). Because + # this loop now ticks every second, normal jitter caused + # repeated false-positive ``NodeTimedOut`` events and + # unnecessary instance churn. Restore the upstream-safe + # 30s budget while keeping the 1s tick so the master still + # reacts quickly when a node *does* genuinely time out. 
+ node_inactivity_timeout = timedelta(seconds=30) + tick_interval_seconds = 1.0 + while True: # kill broken instances connected_node_ids = set(self.state.topology.list_nodes()) @@ -503,7 +518,7 @@ async def _plan(self) -> None: # time out dead nodes for node_id, time in self.state.last_seen.items(): now = datetime.now(tz=timezone.utc) - if now - time > timedelta(seconds=30): + if now - time > node_inactivity_timeout: impacted_instances = [ str(instance_id) for instance_id, instance in self.state.instances.items() @@ -520,7 +535,7 @@ async def _plan(self) -> None: ) await self.event_sender.send(NodeTimedOut(node_id=node_id)) - await anyio.sleep(10) + await anyio.sleep(tick_interval_seconds) async def _event_processor(self) -> None: with self.local_event_receiver as local_events: diff --git a/src/exo/master/placement.py b/src/exo/master/placement.py index f2d4066a9c..4cb0b7d646 100644 --- a/src/exo/master/placement.py +++ b/src/exo/master/placement.py @@ -36,6 +36,7 @@ MemoryUsage, NetworkInterfaceInfo, NodeNetworkInfo, + NodeRdmaCtlStatus, ) from exo.shared.types.tasks import Task, TaskId, TaskStatus from exo.shared.types.topology import SocketConnection @@ -136,6 +137,7 @@ def place_instance( allowed_nodes: set[NodeId] | None = None, allow_single_node_total_memory: bool = False, download_status: Mapping[NodeId, Sequence[DownloadProgress]] | None = None, + node_rdma_ctl: Mapping[NodeId, NodeRdmaCtlStatus] | None = None, ) -> dict[InstanceId, Instance]: sharding = command.sharding instance_meta = command.instance_meta @@ -263,9 +265,18 @@ def place_instance( ) smallest_cycles = get_smallest_cycles(cycles_with_sufficient_memory) + rdma_ctl_status = node_rdma_ctl or {} + + def _all_rdma_ctl_enabled(cycle: Cycle) -> bool: + return all( + ((status := rdma_ctl_status.get(node_id)) is not None and status.enabled) + for node_id in cycle + ) smallest_rdma_cycles = [ - cycle for cycle in smallest_cycles if topology.is_rdma_cycle(cycle) + cycle + for cycle in smallest_cycles + 
if topology.is_rdma_cycle(cycle) and _all_rdma_ctl_enabled(cycle) ] if instance_meta == InstanceMeta.MlxJaccl: diff --git a/src/exo/master/tests/test_placement.py b/src/exo/master/tests/test_placement.py index f102b6f2b7..fb2df5a1ae 100644 --- a/src/exo/master/tests/test_placement.py +++ b/src/exo/master/tests/test_placement.py @@ -26,6 +26,7 @@ MemoryUsage, NetworkInterfaceInfo, NodeNetworkInfo, + NodeRdmaCtlStatus, ) from exo.shared.types.tasks import TaskId, TaskStatus, TextGeneration from exo.shared.types.text_generation import ( @@ -569,8 +570,21 @@ def test_tensor_rdma_backend_connectivity_matrix( min_nodes=1, ) + node_rdma_ctl = { + node_a: NodeRdmaCtlStatus(enabled=True), + node_b: NodeRdmaCtlStatus(enabled=True), + node_c: NodeRdmaCtlStatus(enabled=True), + } + # act - placements = place_instance(cic, topology, {}, node_memory, node_network) + placements = place_instance( + cic, + topology, + {}, + node_memory, + node_network, + node_rdma_ctl=node_rdma_ctl, + ) # assert assert len(placements) == 1 @@ -611,7 +625,6 @@ def test_tensor_rdma_backend_connectivity_matrix( ip_part = coordinator.split(":")[0] assert len(ip_part.split(".")) == 4 - def test_qwen3_5_tensor_auto_upgrade_requires_opt_in( monkeypatch: pytest.MonkeyPatch, ) -> None: @@ -651,6 +664,10 @@ def test_qwen3_5_tensor_auto_upgrade_requires_opt_in( model_card=model_card, min_nodes=2, ) + node_rdma_ctl = { + large_node: NodeRdmaCtlStatus(enabled=True), + small_node: NodeRdmaCtlStatus(enabled=True), + } placements_without_opt_in = place_instance( command, @@ -664,6 +681,7 @@ def test_qwen3_5_tensor_auto_upgrade_requires_opt_in( large_node: create_jaccl_node_network("192.168.0.1"), small_node: create_jaccl_node_network("192.168.0.2"), }, + node_rdma_ctl=node_rdma_ctl, ) instance_without_opt_in = next(iter(placements_without_opt_in.values())) large_runner_without_opt_in = ( @@ -690,6 +708,7 @@ def test_qwen3_5_tensor_auto_upgrade_requires_opt_in( large_node: create_jaccl_node_network("192.168.0.1"), 
small_node: create_jaccl_node_network("192.168.0.2"), }, + node_rdma_ctl=node_rdma_ctl, ) instance = next(iter(placements.values())) @@ -1004,8 +1023,19 @@ def test_jaccl_placement_uses_advertised_lan_ip_for_rdma_coordinator( model_card=model_card, min_nodes=2, ) + node_rdma_ctl = { + node_a: NodeRdmaCtlStatus(enabled=True), + node_b: NodeRdmaCtlStatus(enabled=True), + } - placements = place_instance(command, topology, {}, node_memory, node_network) + placements = place_instance( + command, + topology, + {}, + node_memory, + node_network, + node_rdma_ctl=node_rdma_ctl, + ) instance = list(placements.values())[0] assert isinstance(instance, MlxJacclInstance) @@ -1493,6 +1523,129 @@ def test_placement_prefers_socket_reachable_rank_zero( assert shard.device_rank == 0 +def _build_three_node_rdma_topology() -> tuple[ + Topology, NodeId, NodeId, NodeId, dict[NodeId, NodeNetworkInfo] +]: + topology = Topology() + node_a = NodeId() + node_b = NodeId() + node_c = NodeId() + + ethernet_interface = NetworkInterfaceInfo(name="en0", ip_address="10.0.0.1") + ethernet_conn = SocketConnection( + sink_multiaddr=Multiaddr(address="/ip4/10.0.0.1/tcp/8000") + ) + node_network = { + node_a: NodeNetworkInfo(interfaces=[ethernet_interface]), + node_b: NodeNetworkInfo(interfaces=[ethernet_interface]), + node_c: NodeNetworkInfo(interfaces=[ethernet_interface]), + } + + for node_id in (node_a, node_b, node_c): + topology.add_node(node_id) + + for source, sink, iface in ( + (node_a, node_b, 3), + (node_b, node_a, 3), + (node_b, node_c, 4), + (node_c, node_b, 4), + (node_a, node_c, 5), + (node_c, node_a, 5), + ): + topology.add_connection( + Connection(source=source, sink=sink, edge=create_rdma_connection(iface)) + ) + + for source, sink in ( + (node_a, node_b), + (node_b, node_c), + (node_c, node_a), + (node_a, node_c), + (node_b, node_a), + (node_c, node_b), + ): + topology.add_connection( + Connection(source=source, sink=sink, edge=ethernet_conn) + ) + + return topology, node_a, node_b, 
node_c, node_network + + +def test_place_mlx_jaccl_rejects_when_a_node_has_rdma_ctl_disabled( + model_card: ModelCard, +) -> None: + model_card = model_card.model_copy( + update={"n_layers": 12, "storage_size": Memory.from_bytes(1500)} + ) + topology, node_a, node_b, node_c, node_network = _build_three_node_rdma_topology() + node_memory = { + node_a: create_node_memory(500), + node_b: create_node_memory(500), + node_c: create_node_memory(500), + } + node_rdma_ctl = { + node_a: NodeRdmaCtlStatus(enabled=True), + node_b: NodeRdmaCtlStatus(enabled=True), + node_c: NodeRdmaCtlStatus(enabled=False), + } + command = PlaceInstance( + sharding=Sharding.Tensor, + instance_meta=InstanceMeta.MlxJaccl, + command_id=CommandId(), + model_card=model_card, + min_nodes=3, + ) + + with pytest.raises( + ValueError, match="Requested RDMA \\(MlxJaccl\\) but no RDMA-connected cycles" + ): + place_instance( + command, + topology, + {}, + node_memory, + node_network, + node_rdma_ctl=node_rdma_ctl, + ) + + +def test_place_mlx_jaccl_rejects_when_node_rdma_ctl_missing( + model_card: ModelCard, +) -> None: + model_card = model_card.model_copy( + update={"n_layers": 12, "storage_size": Memory.from_bytes(1500)} + ) + topology, node_a, node_b, node_c, node_network = _build_three_node_rdma_topology() + node_memory = { + node_a: create_node_memory(500), + node_b: create_node_memory(500), + node_c: create_node_memory(500), + } + node_rdma_ctl = { + node_a: NodeRdmaCtlStatus(enabled=True), + node_b: NodeRdmaCtlStatus(enabled=True), + } + command = PlaceInstance( + sharding=Sharding.Tensor, + instance_meta=InstanceMeta.MlxJaccl, + command_id=CommandId(), + model_card=model_card, + min_nodes=3, + ) + + with pytest.raises( + ValueError, match="Requested RDMA \\(MlxJaccl\\) but no RDMA-connected cycles" + ): + place_instance( + command, + topology, + {}, + node_memory, + node_network, + node_rdma_ctl=node_rdma_ctl, + ) + + def _make_task( instance_id: InstanceId, status: TaskStatus = TaskStatus.Running, 
diff --git a/src/exo/routing/mdns_announcer.py b/src/exo/routing/mdns_announcer.py new file mode 100644 index 0000000000..cafd1d3acc --- /dev/null +++ b/src/exo/routing/mdns_announcer.py @@ -0,0 +1,97 @@ +import argparse +import contextlib +import random +import socket +import string +import struct +import sys +import time +from typing import final + + +def _dns_qname(name: bytes) -> bytes: + return b"".join(bytes([len(part)]) + part for part in name.split(b".")) + b"\0" + + +def _build_response_packet(node_id: str, ip_address: str, libp2p_port: int) -> bytes: + service_name = b"_p2p._udp.local" + peer_name = ( + "".join(random.choice(string.ascii_letters + string.digits) for _ in range(32)) + + "._p2p._udp.local" + ).encode() + txt_record = f"dnsaddr=/ip4/{ip_address}/tcp/{libp2p_port}/p2p/{node_id}".encode() + + peer_qname = _dns_qname(peer_name) + packet = bytearray() + packet += struct.pack("!HHHHHH", 0, 0x8400, 0, 1, 0, 1) + packet += _dns_qname(service_name) + packet += struct.pack("!HHI", 12, 1, 120) + packet += struct.pack("!H", len(peer_qname)) + packet += peer_qname + packet += peer_qname + packet += struct.pack("!HHI", 16, 1, 120) + packet += struct.pack("!H", len(txt_record) + 1) + packet += bytes([len(txt_record)]) + packet += txt_record + return bytes(packet) + + +@final +class Args(argparse.Namespace): + node_id: str + ip_address: str + libp2p_port: int + broadcast_address: str | None + count: int + + @staticmethod + def parse() -> "Args": + parser = argparse.ArgumentParser() + parser.add_argument("--node-id", required=True) + parser.add_argument("--ip-address", required=True) + parser.add_argument("--libp2p-port", required=True, type=int) + parser.add_argument("--broadcast-address") + parser.add_argument("--count", default=0, type=int) + return parser.parse_args(namespace=Args()) + + +def main() -> None: + args = Args.parse() + sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1) + 
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + with contextlib.suppress(OSError): + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1) + sock.bind((args.ip_address, 0)) + + sent_count = 0 + while True: + packet = _build_response_packet( + args.node_id, args.ip_address, args.libp2p_port + ) + errors: list[str] = [] + destinations: list[tuple[str, int]] = [] + if args.broadcast_address is not None: + destinations.append((args.broadcast_address, 5353)) + destinations.extend([("255.255.255.255", 5353), ("224.0.0.251", 5353)]) + sent = False + for destination in destinations: + try: + sock.sendto(packet, destination) + sent = True + except OSError as err: + errors.append(f"{destination}: {err}") + if not sent: + print( + f"mDNS announcer send failed: {'; '.join(errors)}", + file=sys.stderr, + flush=True, + ) + sent_count += 1 + if args.count > 0 and sent_count >= args.count: + return + time.sleep(1.0 if sent_count < 60 else 10.0) + + +if __name__ == "__main__": + main() diff --git a/src/exo/routing/router.py b/src/exo/routing/router.py index ebe0ea8d90..5e42639475 100644 --- a/src/exo/routing/router.py +++ b/src/exo/routing/router.py @@ -24,7 +24,7 @@ from filelock import FileLock from loguru import logger -from exo.shared.constants import EXO_NODE_ID_KEYPAIR +from exo.shared.constants import EXO_LEGACY_NODE_ID_KEYPAIR, EXO_NODE_ID_KEYPAIR from exo.utils.channels import Receiver, Sender, channel from exo.utils.pydantic_ext import FrozenModel from exo.utils.task_group import TaskGroup @@ -293,20 +293,102 @@ def _clear_publish_failures(self, topic: str) -> None: def get_node_id_keypair( path: str | bytes | PathLike[str] | PathLike[bytes] = EXO_NODE_ID_KEYPAIR, + legacy_path: str | bytes | PathLike[str] | PathLike[bytes] | None = ( + EXO_LEGACY_NODE_ID_KEYPAIR + ), + process_scope: int | str | None = None, ) -> Keypair: """ Obtains the :class:`Keypair` associated with this node-ID. Obtain the :class:`PeerId` by from it. 
- """ - # TODO(evan): bring back node id persistence once we figure out how to deal with duplicates - return Keypair.generate() - - def lock_path(path: str | bytes | PathLike[str] | PathLike[bytes]) -> Path: - return Path(str(path) + ".lock") - # operate with cross-process lock to avoid race conditions - with FileLock(lock_path(path)): - with open(path, "a+b") as f: # opens in append-mode => starts at EOF + Codex P1 (PR #16 round-(N+2), router.py:297): when ``process_scope`` + is provided, the on-disk keypair filename is suffixed with the + scope (typically the libp2p / peer-download port the caller has + chosen). This preserves *per-process* node identity isolation + when multiple exo processes run on the same host -- the new + same-host multi-node workflow added in this PR (distinct + peer-download ports per process) needs each process to have a + distinct ``NodeId`` so peer discovery's ``peer_node_id == + node_id`` self-skip and routing's unique-node-id assumptions + hold. Single-process deployments leave ``process_scope=None`` + and continue using the shared persistent keypair file. + + On first call after the upgrade, if the new ``path`` (config dir) + has no keypair yet but the legacy cache-dir ``legacy_path`` does, + the legacy file is moved to ``path`` so the node retains its + identity across the relocation. Migration is best-effort: if + moving fails (e.g. cross-device link errors on Linux when + ``XDG_*`` dirs span filesystems), the legacy bytes are copied + instead. Either way, the legacy file is removed once the new + location holds a valid keypair so subsequent calls do not need + to re-check. Codex P2 (PR #16 round-(N+2), router.py:322): the + migration is performed INSIDE the file lock so two concurrent + processes can't both pass the existence check and then race + each other into divergent in-memory vs. on-disk identities. 
+ Codex P1 (PR #16 round-(N+13), router.py:359): when callers + pass distinct ``process_scope`` values, the per-scope lock + above does NOT serialize legacy adoption across scopes, so a + second lock keyed on the (unscoped) legacy path is acquired + before invoking the migrator -- otherwise the cross-device + byte-copy fallback can produce duplicate ``NodeId``s. + """ + base_path = Path(str(path)) + resolved_path = ( + _scoped_keypair_path(base_path, process_scope) + if process_scope is not None + else base_path + ) + + # The legacy cache file pre-dates the per-process scoping change + # so it is intentionally NOT scope-suffixed. We migrate it as a + # one-shot identity adoption for whichever process happens to + # boot first; subsequent processes (with different scopes) will + # observe the legacy file already gone and start with fresh + # keypairs, which is exactly what per-process isolation requires. + resolved_legacy: Path | None = ( + Path(str(legacy_path)) if legacy_path is not None else None + ) + + def lock_path(p: str | bytes | PathLike[str] | PathLike[bytes]) -> Path: + return Path(str(p) + ".lock") + + resolved_path.parent.mkdir(parents=True, exist_ok=True) + + # operate with cross-process lock to avoid race conditions. + # The migration MUST run inside this lock so two processes that + # boot simultaneously can't both pass the migrator's existence + # check, race the keypair generation, and end up with the same + # on-disk file but divergent in-memory identities. + with FileLock(lock_path(resolved_path)): + if resolved_legacy is not None: + # Codex P1 (PR #16 round-(N+13), router.py:359): + # serialize legacy adoption across ALL ``process_scope`` + # values. The outer ``resolved_path`` lock is per-scope, + # so two same-host processes with different scopes + # acquire DIFFERENT lock files and can each enter + # ``_migrate_legacy_node_id_keypair`` concurrently. 
In + # the cross-device fallback path -- where ``replace()`` + # raises ``OSError`` and the migrator falls back to a + # ``read_bytes`` + ``write_bytes`` + ``unlink`` + # sequence -- both processes can read the same legacy + # keypair before either unlinks it, then each writes + # those bytes into its own scoped file. Result: two + # nodes claiming the same ``NodeId`` despite distinct + # scopes, breaking routing's unique-identity and + # election's tiebreaker invariants. A lock keyed on the + # legacy path (which is intentionally NOT scope-suffixed + # because it pre-dates scoping) serializes migration so + # exactly one scope wins legacy adoption and any + # concurrent peers observe the file already gone and + # generate fresh keypairs -- the documented "first + # process boots wins" semantic. Released immediately + # after migration so unrelated keypair I/O on other + # scopes isn't blocked on identity housekeeping. + with FileLock(lock_path(resolved_legacy)): + _migrate_legacy_node_id_keypair(resolved_path, resolved_legacy) + + with open(resolved_path, "a+b") as f: # opens in append-mode => starts at EOF # if non-zero EOF, then file exists => use to get node-ID if f.tell() != 0: f.seek(0) # go to start & read protobuf-encoded bytes @@ -318,7 +400,69 @@ def lock_path(path: str | bytes | PathLike[str] | PathLike[bytes]) -> Path: logger.warning(f"Encountered error when trying to get keypair: {e}") # if no valid credentials, create new ones and persist - with open(path, "w+b") as f: + with open(resolved_path, "w+b") as f: keypair = Keypair.generate() f.write(keypair.to_bytes()) return keypair + + +def _scoped_keypair_path(base: Path, scope: int | str) -> Path: + """Return ``base`` with the process scope inserted before the + suffix (e.g. ``node_id.keypair`` + scope ``52415`` -> + ``node_id.52415.keypair``). 
+ + We insert the scope as a stem-suffix rather than as a directory + so concurrent processes on the same host share the parent dir + (and the file lock's inode-level coordination still works for + legacy-migration safety) while their identity files remain + distinct. Scope is rendered with ``str()`` so callers can pass + a port number, a UUID, a hostname, etc. + """ + suffix = base.suffix or ".keypair" + stem = base.stem if base.suffix else base.name + return base.parent / f"{stem}.{scope}{suffix}" + + +def _migrate_legacy_node_id_keypair( + new_path: Path, + legacy_path: Path, +) -> None: + """One-shot migrator for the cache→config relocation of the + node-ID keypair (Codex P1 PR #16 round 5). + + Idempotent and best-effort: only acts when ``new_path`` is + absent and ``legacy_path`` exists. Falls back to byte copy if + ``rename`` fails (cross-device, permissions, etc.). On any + exception we log and bail -- the caller will then generate a + fresh keypair, which is suboptimal but better than crashing + startup over identity-file housekeeping. + """ + try: + if new_path.exists() or not legacy_path.exists(): + return + # Ensure the destination directory exists for either the + # ``replace`` (which silently no-ops on missing parent on some + # platforms but raises ``ENOENT`` on others) or the byte-copy + # fallback. ``get_node_id_keypair`` already creates this dir + # for the same reason; doing it again here keeps the migrator + # safely callable from tests in isolation. + new_path.parent.mkdir(parents=True, exist_ok=True) + try: + legacy_path.replace(new_path) + except OSError as rename_err: + logger.debug( + f"Cross-device rename of legacy keypair failed ({rename_err}); " + "falling back to byte copy." + ) + new_path.write_bytes(legacy_path.read_bytes()) + legacy_path.unlink(missing_ok=True) + logger.info( + f"Migrated node-ID keypair from legacy cache path {legacy_path} " + f"to persistent config path {new_path}." 
+ ) + except Exception as e: + logger.warning( + f"Failed to migrate legacy node-ID keypair from {legacy_path} " + f"to {new_path}: {e}. The node will generate a new identity; " + "manually copy the file if cluster membership matters." + ) diff --git a/src/exo/routing/tests/test_node_id_migration.py b/src/exo/routing/tests/test_node_id_migration.py new file mode 100644 index 0000000000..bfa0a8f8a9 --- /dev/null +++ b/src/exo/routing/tests/test_node_id_migration.py @@ -0,0 +1,535 @@ +"""Regression tests for the cache→config migration of the node-ID +keypair (Codex P1, PR #16 round 5). + +The keypair used to live under ``EXO_CACHE_HOME``, which is subject +to normal cache cleanup (e.g. ``trash ~/.cache/exo``) and would +silently regenerate a new node-ID. The fix relocates the keypair to +``EXO_CONFIG_HOME`` and migrates legacy files transparently. +""" + +from __future__ import annotations + +from pathlib import Path + +import pytest +from exo_pyo3_bindings import Keypair + +from exo.routing.router import ( + _migrate_legacy_node_id_keypair, # pyright: ignore[reportPrivateUsage] + get_node_id_keypair, +) + + +def test_legacy_keypair_is_migrated_to_new_location(tmp_path: Path) -> None: + """Legacy cache-dir keypair must be moved to the new config-dir + location and the legacy file removed -- so the node retains its + identity across the upgrade and a future cache wipe doesn't + resurrect a stale copy.""" + legacy_path = tmp_path / "cache" / "node_id.keypair" + new_path = tmp_path / "config" / "node_id.keypair" + legacy_path.parent.mkdir(parents=True) + + keypair = Keypair.generate() + legacy_bytes = keypair.to_bytes() + legacy_path.write_bytes(legacy_bytes) + + _migrate_legacy_node_id_keypair(new_path, legacy_path) + + assert new_path.exists(), "migration must place keypair at new location" + assert new_path.read_bytes() == legacy_bytes, ( + "migration must preserve the byte-for-byte keypair contents " + "so the node retains its peer ID" + ) + assert not 
legacy_path.exists(), ( + "migration must remove the legacy file once the new location " + "holds the keypair, otherwise a later cache wipe could " + "resurrect a now-stale copy" + ) + + +def test_migration_is_idempotent_when_new_location_already_present( + tmp_path: Path, +) -> None: + """If the new location already has a keypair, migration must be + a no-op even when a legacy file exists -- otherwise we'd + overwrite the (canonical) new keypair with a stale legacy one.""" + legacy_path = tmp_path / "cache" / "node_id.keypair" + new_path = tmp_path / "config" / "node_id.keypair" + legacy_path.parent.mkdir(parents=True) + new_path.parent.mkdir(parents=True) + + canonical = Keypair.generate().to_bytes() + legacy = Keypair.generate().to_bytes() + new_path.write_bytes(canonical) + legacy_path.write_bytes(legacy) + + _migrate_legacy_node_id_keypair(new_path, legacy_path) + + assert new_path.read_bytes() == canonical, ( + "migration must NOT overwrite an existing new-location keypair" + ) + # We deliberately leave the legacy file alone in this branch: + # touching it would surprise an operator who is intentionally + # keeping both copies during an upgrade window. 
+ assert legacy_path.exists() + + +def test_migration_skipped_when_no_legacy_file(tmp_path: Path) -> None: + """Fresh installs must not error when the legacy path is absent.""" + new_path = tmp_path / "config" / "node_id.keypair" + new_path.parent.mkdir(parents=True) + + _migrate_legacy_node_id_keypair(new_path, tmp_path / "missing.keypair") + + assert not new_path.exists() + + +def test_get_node_id_keypair_uses_migrated_legacy_keypair(tmp_path: Path) -> None: + """End-to-end: ``get_node_id_keypair`` must surface the legacy + keypair bytes when only the legacy path holds a valid file at + call time, completing the cache→config migration on first use.""" + legacy_path = tmp_path / "cache" / "node_id.keypair" + new_path = tmp_path / "config" / "node_id.keypair" + legacy_path.parent.mkdir(parents=True) + + keypair = Keypair.generate() + expected_bytes = keypair.to_bytes() + legacy_path.write_bytes(expected_bytes) + + loaded = get_node_id_keypair(path=new_path, legacy_path=legacy_path) + + assert loaded.to_bytes() == expected_bytes + assert new_path.exists() + assert not legacy_path.exists() + + +# --------------------------------------------------------------------------- +# Codex P1 (PR #16 round-(N+2), router.py:297): per-process scoping +# --------------------------------------------------------------------------- +# +# The new same-host multi-node workflow (per-process +# ``--peer-download-port``) requires distinct ``NodeId``s per +# process so peer-discovery's self-skip and routing's unique-NodeId +# invariants hold. ``get_node_id_keypair`` therefore accepts a +# ``process_scope`` argument that is folded into the on-disk +# filename. + + +def test_distinct_process_scopes_produce_distinct_keypairs(tmp_path: Path) -> None: + """Two processes that pass different scopes (e.g. 
distinct + peer-download ports) MUST end up with different keypair files + and different on-disk identities; otherwise two same-host + nodes would race on the same NodeId.""" + base_path = tmp_path / "config" / "node_id.keypair" + + keypair_a = get_node_id_keypair( + path=base_path, legacy_path=None, process_scope=52416 + ) + keypair_b = get_node_id_keypair( + path=base_path, legacy_path=None, process_scope=52417 + ) + + assert keypair_a.to_bytes() != keypair_b.to_bytes(), ( + "distinct process scopes must yield distinct keypairs so " + "same-host multi-node deployments don't share a NodeId" + ) + + scoped_a = base_path.parent / "node_id.52416.keypair" + scoped_b = base_path.parent / "node_id.52417.keypair" + assert scoped_a.exists() + assert scoped_b.exists() + assert scoped_a.read_bytes() != scoped_b.read_bytes() + + +def test_same_process_scope_is_stable_across_calls(tmp_path: Path) -> None: + """Per-process scoping must remain *persistent*: the same + process (same scope) must load the same keypair on subsequent + calls -- otherwise restart would silently churn NodeIds.""" + base_path = tmp_path / "config" / "node_id.keypair" + + first = get_node_id_keypair(path=base_path, legacy_path=None, process_scope=52416) + second = get_node_id_keypair(path=base_path, legacy_path=None, process_scope=52416) + + assert first.to_bytes() == second.to_bytes() + + +def test_migration_runs_inside_file_lock(tmp_path: Path) -> None: + """Codex P2 (PR #16 round-(N+2), router.py:322): the legacy + migration must execute *inside* ``FileLock`` so two processes + booting concurrently can't both pass the existence check, race + each other into divergent in-memory keypairs, and end up with + mismatched identities for the same on-disk file. + + We assert this structurally by hooking ``_migrate_legacy_node_id_keypair`` + and ``filelock.FileLock`` and verifying the lock is acquired + *before* the migrator is called. 
A pre-lock migration would + show ``migrate_called=True`` while the lock is still + ``unacquired``.""" + import exo.routing.router as router_mod + + legacy_path = tmp_path / "cache" / "node_id.keypair" + base_path = tmp_path / "config" / "node_id.keypair" + legacy_path.parent.mkdir(parents=True) + legacy_path.write_bytes(Keypair.generate().to_bytes()) + + lock_state: dict[str, bool] = {"acquired": False, "acquired_before_migrate": False} + + # We hook ``router_mod.FileLock`` (the symbol the production + # code dereferences) with a thin wrapper class. The wrapper + # delegates to the real ``FileLock`` instance but flips the + # ``acquired`` flag on entry, which the migrator hook below + # then snapshots. This keeps the type of ``FileLock`` intact + # while letting us observe acquire-vs-migrate ordering. + real_filelock = router_mod.FileLock + + class _ObservingFileLock: + def __init__(self, *args: object, **kwargs: object) -> None: + self._inner = real_filelock(*args, **kwargs) # pyright: ignore[reportArgumentType] + + def __enter__(self) -> object: + lock_state["acquired"] = True + return self._inner.__enter__() + + def __exit__(self, *exc: object) -> object: + return self._inner.__exit__(*exc) # pyright: ignore[reportArgumentType] + + original_migrate = router_mod._migrate_legacy_node_id_keypair # pyright: ignore[reportPrivateUsage] + + def _track_migrate(new_path: Path, legacy: Path) -> None: + lock_state["acquired_before_migrate"] = lock_state["acquired"] + original_migrate(new_path, legacy) + + router_mod.FileLock = _ObservingFileLock + router_mod._migrate_legacy_node_id_keypair = _track_migrate # pyright: ignore[reportPrivateUsage] + try: + _ = get_node_id_keypair(path=base_path, legacy_path=legacy_path) + finally: + router_mod.FileLock = real_filelock + router_mod._migrate_legacy_node_id_keypair = original_migrate # pyright: ignore[reportPrivateUsage] + + assert lock_state["acquired_before_migrate"] is True, ( + "legacy migration must run INSIDE the FileLock 
to prevent a " + "concurrent-startup race on the on-disk keypair" + ) + + +class TestNodeIdKeypairScope: + """Codex P1 (PR #16 round-(N+3), main.py:74): the node-ID keypair + scope MUST account for every distinguishable per-process port, + not just ``--peer-download-port``. With peer-download disabled + the operator can legitimately keep the default + ``peer_download_port`` (no socket bind), so the previous + peer-only scope let two same-host processes share an identity. + """ + + def _build_args( + self, + *, + libp2p_port: int = 0, + api_port: int = 52415, + peer_download_port: int = 52416, + no_downloads: bool = False, + no_peer_download: bool = False, + spawn_api: bool = False, + ): # noqa: ANN202 + from exo.main import Args + + return Args( + libp2p_port=libp2p_port, + api_port=api_port, + peer_download_port=peer_download_port, + no_downloads=no_downloads, + no_peer_download=no_peer_download, + spawn_api=spawn_api, + ) + + def test_distinct_libp2p_ports_yield_distinct_scopes(self) -> None: + from exo.main import ( + _node_id_keypair_scope, # pyright: ignore[reportPrivateUsage] + ) + + scope_a = _node_id_keypair_scope(self._build_args(libp2p_port=4001)) + scope_b = _node_id_keypair_scope(self._build_args(libp2p_port=4002)) + assert scope_a != scope_b + + def test_distinct_api_ports_yield_distinct_scopes(self) -> None: + from exo.main import ( + _node_id_keypair_scope, # pyright: ignore[reportPrivateUsage] + ) + + scope_a = _node_id_keypair_scope(self._build_args(api_port=52415)) + scope_b = _node_id_keypair_scope(self._build_args(api_port=52416)) + assert scope_a != scope_b + + def test_distinct_peer_download_ports_yield_distinct_scopes(self) -> None: + from exo.main import ( + _node_id_keypair_scope, # pyright: ignore[reportPrivateUsage] + ) + + scope_a = _node_id_keypair_scope(self._build_args(peer_download_port=52416)) + scope_b = _node_id_keypair_scope(self._build_args(peer_download_port=52417)) + assert scope_a != scope_b + + def 
test_disabled_peer_download_with_same_default_port_still_isolates( + self, + ) -> None: + """The original Codex P1 (round-(N+3)) regression: with + ``--no-peer-download`` two processes can both keep + ``peer_download_port=52416``. They MUST still get distinct + scopes when *some* other port differs (here, libp2p). + Pre-fix the scope was just ``peer_download_port`` and these + two configs collided on the same keypair.""" + from exo.main import ( + _node_id_keypair_scope, # pyright: ignore[reportPrivateUsage] + ) + + process_one = self._build_args( + libp2p_port=4001, + no_peer_download=True, + peer_download_port=52416, + ) + process_two = self._build_args( + libp2p_port=4002, + no_peer_download=True, + peer_download_port=52416, + ) + assert _node_id_keypair_scope(process_one) != _node_id_keypair_scope( + process_two + ) + + def test_identical_args_yield_identical_scope(self) -> None: + """Stability invariant: the same configuration on a single + process across restarts must hash to the same scope so the + node retains its identity across restarts.""" + from exo.main import ( + _node_id_keypair_scope, # pyright: ignore[reportPrivateUsage] + ) + + args = self._build_args( + libp2p_port=4001, api_port=52415, peer_download_port=52416 + ) + assert _node_id_keypair_scope(args) == _node_id_keypair_scope(args) + + def test_libp2p_port_zero_uses_pid_for_per_process_isolation(self) -> None: + """Codex P1 (PR #16 round-(N+8), main.py:457): with + ``--libp2p-port 0`` the configured port is the literal ``0`` + even though each process binds a different ephemeral port at + runtime. Without per-process discrimination two same-host + worker-only processes (no API, no peer download) sharing the + default ``peer_download_port`` and ``api_port`` would collide + on the same scoped keypair. 
The scope must therefore fold in + ``os.getpid()`` (or another guaranteed per-process + discriminator) when ``libp2p_port == 0``.""" + import os + + from exo.main import ( + _node_id_keypair_scope, # pyright: ignore[reportPrivateUsage] + ) + + scope = _node_id_keypair_scope( + self._build_args( + libp2p_port=0, + api_port=52415, + peer_download_port=52416, + no_peer_download=True, + spawn_api=False, + ) + ) + + assert f"pid-{os.getpid()}" in scope, ( + f"libp2p_port=0 must mix in os.getpid() to discriminate " + f"same-host processes binding ephemeral libp2p ports; " + f"got scope={scope!r}" + ) + + def test_libp2p_port_zero_in_two_processes_yield_distinct_scopes( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + """End-to-end: simulate two same-host processes both binding + ``libp2p_port=0`` and otherwise default ports. Pre-fix they + collided on a single keypair file; post-fix the scopes + differ because each carries its own PID.""" + import os + + from exo.main import ( + _node_id_keypair_scope, # pyright: ignore[reportPrivateUsage] + ) + + # Process A: real PID + scope_a = _node_id_keypair_scope( + self._build_args( + libp2p_port=0, + api_port=52415, + peer_download_port=52416, + no_peer_download=True, + spawn_api=False, + ) + ) + + # Process B: simulate a different PID via monkeypatch + real_pid = os.getpid() + monkeypatch.setattr(os, "getpid", lambda: real_pid + 1) + scope_b = _node_id_keypair_scope( + self._build_args( + libp2p_port=0, + api_port=52415, + peer_download_port=52416, + no_peer_download=True, + spawn_api=False, + ) + ) + + assert scope_a != scope_b, ( + "two same-host processes both binding libp2p_port=0 with " + "identical api/peer ports must produce distinct keypair " + "scopes; otherwise they load the same on-disk keypair " + "and collide on NodeId, breaking routing/election " + f"invariants. 
scope_a={scope_a!r} scope_b={scope_b!r}" + ) + + +def test_legacy_migration_serialized_across_process_scopes( + tmp_path: Path, monkeypatch: pytest.MonkeyPatch +) -> None: + """Codex P1 (PR #16 round-(N+13), router.py:359): legacy + adoption MUST be serialized across all ``process_scope`` values, + even when the per-scope ``resolved_path`` lock differs and the + cross-device byte-copy fallback path is taken inside + ``_migrate_legacy_node_id_keypair``. + + Pre-fix this test produces two identical scoped keypairs (both + matching the legacy bytes), simulating two same-host processes + racing legacy adoption: each acquires its own per-scope lock, + both fall through to the byte-copy branch, both read the same + legacy bytes, and both end up writing those bytes to their own + scoped file -- duplicate ``NodeId`` despite distinct scopes. + + Post-fix the migrator is wrapped in a second FileLock keyed on + the legacy path. The first scope wins adoption and unlinks the + legacy file; the second scope's migrator no-ops on the absent + legacy and generates a fresh keypair, so the two scopes diverge + as required by the per-process isolation invariant. + + We simulate the cross-device fallback by monkey-patching + ``Path.replace`` to raise ``OSError`` (the same trigger that + fires on Linux when ``XDG_*`` dirs span filesystems). The + serialization invariant is asserted by also blocking the byte + copy with a ``threading.Event`` so two threads must contend on + the legacy lock; only one thread should observe the legacy + file present at copy time. 
+ """ + import threading + + import exo.routing.router as router_mod + + legacy_path = tmp_path / "cache" / "node_id.keypair" + base_path = tmp_path / "config" / "node_id.keypair" + legacy_path.parent.mkdir(parents=True) + base_path.parent.mkdir(parents=True) + + legacy_bytes = Keypair.generate().to_bytes() + legacy_path.write_bytes(legacy_bytes) + + # Force the cross-device fallback so the migrator goes through + # the read_bytes/write_bytes/unlink sequence (the path Codex + # flagged as racy). + real_replace = Path.replace + + def _force_cross_device(self: Path, target: object) -> object: # noqa: ANN001 + if Path(self) == legacy_path: + raise OSError("simulated cross-device link error") + return real_replace(self, target) # pyright: ignore[reportArgumentType] + + monkeypatch.setattr(Path, "replace", _force_cross_device) + + # Pause inside the byte-copy branch so two threads pile up on + # the legacy lock while one thread holds it. Without the legacy + # lock both threads would observe the legacy file present at + # this point and both would proceed to write_bytes/unlink. 
+ in_copy = threading.Event() + release_copy = threading.Event() + real_write_bytes = Path.write_bytes + + def _slow_write_bytes(self: Path, data: bytes) -> int: + if self.parent == base_path.parent: + in_copy.set() + release_copy.wait(timeout=5.0) + return real_write_bytes(self, data) + + monkeypatch.setattr(Path, "write_bytes", _slow_write_bytes) + + keypairs: dict[int, Keypair] = {} + + def _run(scope: int) -> None: + keypairs[scope] = router_mod.get_node_id_keypair( + path=base_path, legacy_path=legacy_path, process_scope=scope + ) + + thread_a = threading.Thread(target=_run, args=(52416,), daemon=True) + thread_b = threading.Thread(target=_run, args=(52417,), daemon=True) + thread_a.start() + in_copy.wait(timeout=5.0) + # While thread_a is paused inside the byte copy holding the + # legacy lock, thread_b should be blocked on the legacy lock -- + # NOT racing through its own byte copy of the same legacy file. + thread_b.start() + # Give thread_b a moment to attempt acquiring the legacy lock + # so we can assert it did not slip through. + thread_b.join(timeout=0.2) + assert thread_b.is_alive(), ( + "second scope must be blocked on the legacy lock while the " + "first scope is mid-copy; if this fails, both scopes will " + "duplicate the legacy NodeId via the byte-copy race" + ) + release_copy.set() + thread_a.join(timeout=5.0) + thread_b.join(timeout=5.0) + assert not thread_a.is_alive() and not thread_b.is_alive() + + scope_a_bytes = keypairs[52416].to_bytes() + scope_b_bytes = keypairs[52417].to_bytes() + assert scope_a_bytes != scope_b_bytes, ( + "concurrent legacy adoption across distinct process_scope " + "values must NOT produce duplicate keypairs; the legacy " + "lock should let exactly one scope adopt the legacy bytes " + "while the other generates a fresh identity" + ) + # Exactly one scoped file should match the legacy bytes (the + # winner of adoption); the other was generated fresh. 
+ scoped_a = base_path.parent / "node_id.52416.keypair" + scoped_b = base_path.parent / "node_id.52417.keypair" + matches = sum( + 1 + for p in (scoped_a, scoped_b) + if p.exists() and p.read_bytes() == legacy_bytes + ) + assert matches == 1, ( + f"exactly one scope must have adopted the legacy bytes; " + f"matches={matches} indicates the cross-device race fired" + ) + assert not legacy_path.exists(), "legacy file must be unlinked after adoption" + + +def test_legacy_migration_adopts_into_scoped_path(tmp_path: Path) -> None: + """When a process passes a scope and a legacy unscoped keypair + exists, the legacy bytes must be adopted into the scoped path. + This is the upgrade-time behaviour: the first process to boot + after the upgrade keeps the operator's existing identity; later + processes (different scopes) start with fresh identities, which + is exactly what per-process isolation requires.""" + legacy_path = tmp_path / "cache" / "node_id.keypair" + base_path = tmp_path / "config" / "node_id.keypair" + legacy_path.parent.mkdir(parents=True) + + expected_bytes = Keypair.generate().to_bytes() + legacy_path.write_bytes(expected_bytes) + + loaded = get_node_id_keypair( + path=base_path, legacy_path=legacy_path, process_scope=52416 + ) + + scoped = base_path.parent / "node_id.52416.keypair" + assert loaded.to_bytes() == expected_bytes + assert scoped.exists(), "legacy bytes must land at the scoped path" + assert scoped.read_bytes() == expected_bytes + assert not legacy_path.exists() diff --git a/src/exo/shared/apply.py b/src/exo/shared/apply.py index ce3f503537..278dd660ac 100644 --- a/src/exo/shared/apply.py +++ b/src/exo/shared/apply.py @@ -65,6 +65,13 @@ ) +def _is_rdma_ctl_enabled( + node_id: NodeId, node_rdma_ctl: Mapping[NodeId, NodeRdmaCtlStatus] +) -> bool: + status = node_rdma_ctl.get(node_id) + return status is not None and status.enabled + + def event_apply(event: Event, state: State) -> State: """Apply an event to state.""" match event: @@ -415,6 +422,9 
@@ def apply_node_gathered_info(event: NodeGatheredInfo, state: State) -> State: for nid in state.node_thunderbolt for tb_ident in state.node_thunderbolt[nid].interfaces } + source_is_rdma_enabled = _is_rdma_ctl_enabled( + event.node_id, state.node_rdma_ctl + ) as_rdma_conns = [ Connection( source=event.node_id, @@ -427,6 +437,10 @@ def apply_node_gathered_info(event: NodeGatheredInfo, state: State) -> State: for tb_conn in info.conns if tb_conn.source_uuid in conn_map if tb_conn.sink_uuid in conn_map + if source_is_rdma_enabled + and _is_rdma_ctl_enabled( + conn_map[tb_conn.sink_uuid][0], state.node_rdma_ctl + ) ] topology.replace_all_out_rdma_connections(event.node_id, as_rdma_conns) case ThunderboltBridgeInfo(): @@ -450,6 +464,8 @@ def apply_node_gathered_info(event: NodeGatheredInfo, state: State) -> State: **state.node_rdma_ctl, event.node_id: NodeRdmaCtlStatus(enabled=info.enabled), } + if not info.enabled: + topology.remove_all_rdma_connections_touching(event.node_id) return state.model_copy(update=update) diff --git a/src/exo/shared/constants.py b/src/exo/shared/constants.py index d79354184b..89bc512cab 100644 --- a/src/exo/shared/constants.py +++ b/src/exo/shared/constants.py @@ -8,12 +8,12 @@ def _get_xdg_dir(env_var: str, fallback: str) -> Path: - """Get XDG directory, prioritising EXO_HOME environment variable if its set. On non-Linux platforms, default to ~/.exo.""" + """Get XDG directory, prioritising EXO_HOME environment variable if its set. On non-Linux platforms, default to ~/.exo. 
Cache home always prefers .cache/exo""" if _EXO_HOME_ENV is not None: return Path.home() / _EXO_HOME_ENV - if sys.platform != "linux": + if sys.platform != "linux" and env_var != "XDG_CACHE_HOME": return Path.home() / ".exo" xdg_value = os.environ.get(env_var, None) @@ -68,10 +68,19 @@ def _parse_colon_dirs(env_var: str) -> tuple[Path, ...]: # Log files (data/logs or cache) EXO_LOG_DIR = EXO_CACHE_HOME / "exo_log" EXO_LOG = EXO_LOG_DIR / "exo.log" -EXO_TEST_LOG = EXO_CACHE_HOME / "exo_test.log" -# Identity (config) +# Identity (config -- persistent across cache eviction). +# +# Codex P1 (PR #16 round 5): keeping the node-ID keypair under +# ``EXO_CACHE_HOME`` makes cluster identity vulnerable to normal +# cache cleanup, which causes nodes to come up with a new peer ID +# after a cache wipe and breaks the intended persistence of cluster +# membership / mDNS routes. Identity material lives under +# ``EXO_CONFIG_HOME`` instead. The legacy cache path is migrated +# on first use by ``get_node_id_keypair`` to preserve existing +# identity across the upgrade. 
EXO_NODE_ID_KEYPAIR = EXO_CONFIG_HOME / "node_id.keypair" +EXO_LEGACY_NODE_ID_KEYPAIR = EXO_CACHE_HOME / "node_id.keypair" EXO_CONFIG_FILE = EXO_CONFIG_HOME / "config.toml" # libp2p topics for event forwarding @@ -101,3 +110,6 @@ def _parse_colon_dirs(env_var: str) -> tuple[Path, ...]: EXO_MAX_CONCURRENT_REQUESTS = int(os.getenv("EXO_MAX_CONCURRENT_REQUESTS", "8")) EXO_MAX_INSTANCE_RETRIES = 5 + +# Peer-to-peer model download server port (one above default API port) +EXO_PEER_DOWNLOAD_PORT = int(os.getenv("EXO_PEER_DOWNLOAD_PORT", "52416")) diff --git a/src/exo/shared/tests/test_apply/test_apply_rdma_gating.py b/src/exo/shared/tests/test_apply/test_apply_rdma_gating.py new file mode 100644 index 0000000000..492e3fc5ea --- /dev/null +++ b/src/exo/shared/tests/test_apply/test_apply_rdma_gating.py @@ -0,0 +1,231 @@ +from datetime import datetime, timezone + +from exo.shared.apply import apply_node_gathered_info +from exo.shared.topology import Topology +from exo.shared.types.common import NodeId +from exo.shared.types.events import NodeGatheredInfo +from exo.shared.types.profiling import ( + NodeRdmaCtlStatus, + NodeThunderboltInfo, +) +from exo.shared.types.state import State +from exo.shared.types.thunderbolt import ThunderboltConnection, ThunderboltIdentifier +from exo.shared.types.topology import RDMAConnection +from exo.utils.info_gatherer.info_gatherer import ( + MacThunderboltConnections, + RdmaCtlStatus, +) + + +def _now() -> str: + return datetime.now(timezone.utc).isoformat() + + +def _make_state_with_thunderbolt_idents( + *node_ids_and_uuids: tuple[NodeId, str, str], + rdma_ctl: dict[NodeId, NodeRdmaCtlStatus] | None = None, +) -> State: + """Build a State with Thunderbolt identifiers per node so the apply MacThunderboltConnections + case can resolve uuid -> (node, iface).""" + node_thunderbolt = { + nid: NodeThunderboltInfo( + interfaces=[ThunderboltIdentifier(rdma_interface=iface, domain_uuid=uuid)] + ) + for nid, uuid, iface in node_ids_and_uuids + } + 
return State( + node_thunderbolt=node_thunderbolt, + node_rdma_ctl=rdma_ctl or {}, + ) + + +def _has_rdma_edge(topology: Topology, source: NodeId, sink: NodeId) -> bool: + return any( + isinstance(edge, RDMAConnection) + for edge in topology.get_all_connections_between(source, sink) + ) + + +def test_mac_thunderbolt_connections_emits_rdma_when_both_endpoints_enabled(): + node_a = NodeId() + node_b = NodeId() + state = _make_state_with_thunderbolt_idents( + (node_a, "uuid-a", "rdma_en1"), + (node_b, "uuid-b", "rdma_en1"), + rdma_ctl={ + node_a: NodeRdmaCtlStatus(enabled=True), + node_b: NodeRdmaCtlStatus(enabled=True), + }, + ) + + event = NodeGatheredInfo( + node_id=node_a, + when=_now(), + info=MacThunderboltConnections( + conns=[ThunderboltConnection(source_uuid="uuid-a", sink_uuid="uuid-b")] + ), + ) + + new_state = apply_node_gathered_info(event, state) + + assert _has_rdma_edge(new_state.topology, node_a, node_b) + + +def test_mac_thunderbolt_connections_skips_rdma_when_source_rdma_ctl_disabled(): + node_a = NodeId() + node_b = NodeId() + state = _make_state_with_thunderbolt_idents( + (node_a, "uuid-a", "rdma_en1"), + (node_b, "uuid-b", "rdma_en1"), + rdma_ctl={ + node_a: NodeRdmaCtlStatus(enabled=False), + node_b: NodeRdmaCtlStatus(enabled=True), + }, + ) + + event = NodeGatheredInfo( + node_id=node_a, + when=_now(), + info=MacThunderboltConnections( + conns=[ThunderboltConnection(source_uuid="uuid-a", sink_uuid="uuid-b")] + ), + ) + + new_state = apply_node_gathered_info(event, state) + + assert not _has_rdma_edge(new_state.topology, node_a, node_b) + + +def test_mac_thunderbolt_connections_skips_rdma_when_sink_rdma_ctl_disabled(): + node_a = NodeId() + node_b = NodeId() + state = _make_state_with_thunderbolt_idents( + (node_a, "uuid-a", "rdma_en1"), + (node_b, "uuid-b", "rdma_en1"), + rdma_ctl={ + node_a: NodeRdmaCtlStatus(enabled=True), + node_b: NodeRdmaCtlStatus(enabled=False), + }, + ) + + event = NodeGatheredInfo( + node_id=node_a, + when=_now(), + 
info=MacThunderboltConnections( + conns=[ThunderboltConnection(source_uuid="uuid-a", sink_uuid="uuid-b")] + ), + ) + + new_state = apply_node_gathered_info(event, state) + + assert not _has_rdma_edge(new_state.topology, node_a, node_b) + + +def test_mac_thunderbolt_connections_skips_rdma_when_rdma_ctl_status_missing(): + """Missing rdma_ctl status defaults to not-enabled — node is RDMA-incapable.""" + node_a = NodeId() + node_b = NodeId() + state = _make_state_with_thunderbolt_idents( + (node_a, "uuid-a", "rdma_en1"), + (node_b, "uuid-b", "rdma_en1"), + rdma_ctl={ + node_a: NodeRdmaCtlStatus(enabled=True), + # node_b intentionally absent + }, + ) + + event = NodeGatheredInfo( + node_id=node_a, + when=_now(), + info=MacThunderboltConnections( + conns=[ThunderboltConnection(source_uuid="uuid-a", sink_uuid="uuid-b")] + ), + ) + + new_state = apply_node_gathered_info(event, state) + + assert not _has_rdma_edge(new_state.topology, node_a, node_b) + + +def test_rdma_ctl_status_disabled_purges_existing_rdma_edges(): + """When a node reports rdma_ctl disabled, all RDMA edges touching it must be removed.""" + node_a = NodeId() + node_b = NodeId() + + # Start with both nodes RDMA-enabled and existing RDMA edges in the topology. 
+ state = _make_state_with_thunderbolt_idents( + (node_a, "uuid-a", "rdma_en1"), + (node_b, "uuid-b", "rdma_en1"), + rdma_ctl={ + node_a: NodeRdmaCtlStatus(enabled=True), + node_b: NodeRdmaCtlStatus(enabled=True), + }, + ) + state = apply_node_gathered_info( + NodeGatheredInfo( + node_id=node_a, + when=_now(), + info=MacThunderboltConnections( + conns=[ThunderboltConnection(source_uuid="uuid-a", sink_uuid="uuid-b")] + ), + ), + state, + ) + state = apply_node_gathered_info( + NodeGatheredInfo( + node_id=node_b, + when=_now(), + info=MacThunderboltConnections( + conns=[ThunderboltConnection(source_uuid="uuid-b", sink_uuid="uuid-a")] + ), + ), + state, + ) + assert _has_rdma_edge(state.topology, node_a, node_b) + assert _has_rdma_edge(state.topology, node_b, node_a) + + # Now node_a flips to rdma_ctl disabled — both directions of RDMA edge must drop. + state = apply_node_gathered_info( + NodeGatheredInfo( + node_id=node_a, when=_now(), info=RdmaCtlStatus(enabled=False) + ), + state, + ) + + assert not _has_rdma_edge(state.topology, node_a, node_b) + assert not _has_rdma_edge(state.topology, node_b, node_a) + assert state.node_rdma_ctl[node_a].enabled is False + + +def test_topology_remove_all_rdma_connections_touching_keeps_socket_edges(): + """Purging RDMA edges for a disabled node must not affect non-RDMA edges.""" + from exo.shared.types.multiaddr import Multiaddr + from exo.shared.types.topology import Connection, SocketConnection + + topology = Topology() + node_a = NodeId() + node_b = NodeId() + topology.add_node(node_a) + topology.add_node(node_b) + topology.add_connection( + Connection( + source=node_a, + sink=node_b, + edge=RDMAConnection( + source_rdma_iface="rdma_en1", sink_rdma_iface="rdma_en1" + ), + ) + ) + socket_edge = SocketConnection( + sink_multiaddr=Multiaddr(address="/ip4/10.0.0.1/tcp/8000") + ) + topology.add_connection(Connection(source=node_a, sink=node_b, edge=socket_edge)) + + topology.remove_all_rdma_connections_touching(node_a) + + assert 
not _has_rdma_edge(topology, node_a, node_b) + # Socket edge survives. + assert any( + isinstance(edge, SocketConnection) + for edge in topology.get_all_connections_between(node_a, node_b) + ) diff --git a/src/exo/shared/tests/test_xdg_paths.py b/src/exo/shared/tests/test_xdg_paths.py index f3b82ebffd..dce2c7d7c1 100644 --- a/src/exo/shared/tests/test_xdg_paths.py +++ b/src/exo/shared/tests/test_xdg_paths.py @@ -94,7 +94,27 @@ def test_macos_uses_traditional_paths(): home = Path.home() assert home / ".exo" == constants.EXO_CONFIG_HOME assert home / ".exo" == constants.EXO_DATA_HOME - assert home / ".exo" == constants.EXO_CACHE_HOME + assert home / ".cache" / "exo" == constants.EXO_CACHE_HOME + + +def test_exo_home_env(): + """Test that macOS uses traditional ~/.exo directory.""" + # Remove EXO_HOME to ensure we test the default behavior + env = {k: v for k, v in os.environ.items() if k != "EXO_HOME"} + env["EXO_HOME"] = "/exo" + with ( + mock.patch.dict(os.environ, env, clear=True), + mock.patch.object(sys, "platform", "darwin"), + ): + import importlib + + import exo.shared.constants as constants + + importlib.reload(constants) + + assert Path("/exo") == constants.EXO_CONFIG_HOME + assert Path("/exo") == constants.EXO_DATA_HOME + assert Path("/exo") == constants.EXO_CACHE_HOME def test_node_id_in_config_dir(): diff --git a/src/exo/shared/topology.py b/src/exo/shared/topology.py index 9d649a6f4e..5c12bc5077 100644 --- a/src/exo/shared/topology.py +++ b/src/exo/shared/topology.py @@ -169,6 +169,22 @@ def replace_all_out_rdma_connections( for conn in new_connections: self.add_connection(conn) + def remove_all_rdma_connections_touching(self, node_id: NodeId) -> None: + """Remove every incoming or outgoing RDMA edge touching node_id.""" + if node_id not in self._vertex_indices: + return + rx_idx = self._vertex_indices[node_id] + rdma_edge_idxs = [ + edge_idx + for edge_idx in ( + *self._graph.out_edge_indices(rx_idx), + *self._graph.in_edge_indices(rx_idx), + ) + if 
isinstance(self._graph.get_edge_data_by_index(edge_idx), RDMAConnection) + ] + for edge_idx in rdma_edge_idxs: + self._graph.remove_edge_from_index(edge_idx) + def remove_connection(self, conn: Connection) -> None: if ( conn.source not in self._vertex_indices diff --git a/src/exo/shared/types/commands.py b/src/exo/shared/types/commands.py index 62f73ac399..2088402127 100644 --- a/src/exo/shared/types/commands.py +++ b/src/exo/shared/types/commands.py @@ -68,9 +68,20 @@ class RequestEventLog(BaseCommand): since_idx: int +class PeerEndpoint(FrozenModel): + """A peer node that has (or is downloading) a model, with its network address.""" + + node_id: NodeId + ip: str + port: int + status: str = "complete" # "complete" or "ongoing" + connection_type: str = "socket" # "rdma" or "socket" + + class StartDownload(BaseCommand): target_node_id: NodeId shard_metadata: ShardMetadata + available_peers: list[PeerEndpoint] = Field(default_factory=list) class DeleteDownload(BaseCommand): diff --git a/src/exo/utils/keyed_backoff.py b/src/exo/utils/keyed_backoff.py index 4d7c9a66ed..a95fe5c5f7 100644 --- a/src/exo/utils/keyed_backoff.py +++ b/src/exo/utils/keyed_backoff.py @@ -29,6 +29,10 @@ def attempts(self, key: K) -> int: """Return the number of recorded attempts for a key.""" return self._attempts.get(key, 0) + def tracked_keys(self) -> set[K]: + """Return keys that currently have recorded backoff state.""" + return set(self._attempts) | set(self._last_time) + def reset(self, key: K) -> None: """Reset backoff state for a key (e.g., on success).""" self._attempts.pop(key, None) diff --git a/src/exo/utils/tests/test_keyed_backoff.py b/src/exo/utils/tests/test_keyed_backoff.py new file mode 100644 index 0000000000..b592a4fabd --- /dev/null +++ b/src/exo/utils/tests/test_keyed_backoff.py @@ -0,0 +1,13 @@ +from exo.utils.keyed_backoff import KeyedBackoff + + +def test_tracked_keys_reports_and_resets_backoff_state() -> None: + backoff = KeyedBackoff[str]() + + 
backoff.record_attempt("instance-a") + + assert backoff.tracked_keys() == {"instance-a"} + + backoff.reset("instance-a") + + assert backoff.tracked_keys() == set() diff --git a/src/exo/worker/engines/mlx/cache.py b/src/exo/worker/engines/mlx/cache.py index 7cdcc77fbe..84f852933f 100644 --- a/src/exo/worker/engines/mlx/cache.py +++ b/src/exo/worker/engines/mlx/cache.py @@ -557,18 +557,108 @@ def get_memory_used_percentage() -> float: return float(mem.percent / 100) +def _model_is_pipeline_parallel(model: Model) -> bool: + """True iff the model has pipeline-parallel layer wrappers installed. + + Only the PP path is safe to combine with QuantizedKVCache right now: + the single-node BatchGenerator code path in mlx-lm calls + ``_merge_caches`` on every step (even for a single in-flight request), + and QuantizedKVCache does not implement ``merge``. Attempting to use + a quantized cache in that path crashes with:: + + does not yet + support batching with history + + Detecting PP mode by layer type is cheap and avoids threading the + distributed group through every cache call site. 
+ """ + try: + from exo.worker.engines.mlx.auto_parallel import ( + PipelineFirstLayer, + PipelineLastLayer, + ) + except Exception: + return False + layers = getattr(model, "layers", None) + if layers is None: + return False + for layer in layers: # type: ignore[reportUnknownVariableType] + if isinstance(layer, (PipelineFirstLayer, PipelineLastLayer)): + return True + return False + + def make_kv_cache( model: Model, max_kv_size: int | None = None, keep: int = 0 ) -> KVCacheType: assert hasattr(model, "layers") if hasattr(model, "make_cache"): - logger.info("Using MLX LM's make cache") - return model.make_cache() # type: ignore + caches: list[ + KVCache | RotatingKVCache | QuantizedKVCache | ArraysCache | CacheList + ] = list(model.make_cache()) # pyright: ignore[reportUnknownArgumentType, reportUnknownMemberType] + # Apply the same single-node safeguard used in the + # ``make_cache``-less branch below: ``QuantizedKVCache`` + # cannot be combined with the single-node ``BatchGenerator`` + # path because mlx-lm calls ``_merge_caches`` on every step + # and ``QuantizedKVCache`` doesn't implement ``merge``. Models + # with ``make_cache()`` (e.g. Gemma3 with mixed attention + # layers) used to skip this guard and would crash at runtime + # with:: + # + # does not + # yet support batching with history + # + # Pipeline-parallel deployments use a different generation + # path that does support quantized caches, so we honor + # ``EXO_KV_CACHE_BITS`` only when the model has PP layer + # wrappers installed. + if KV_CACHE_BITS is not None and _model_is_pipeline_parallel(model): + # Honor KV_CACHE_BITS even when the model provides its own + # make_cache(). Replace plain KVCache entries with + # QuantizedKVCache; leave ArraysCache (DeltaNet/SSM) and other + # cache types alone since they don't support quantization. 
+ quantized = 0 + for i, c in enumerate(caches): + if isinstance(c, KVCache): + qc = QuantizedKVCache( + group_size=CACHE_GROUP_SIZE, bits=KV_CACHE_BITS + ) + qc.step = 16384 + caches[i] = qc + quantized += 1 + logger.info( + f"Using quantized KV cache " + f"(bits={KV_CACHE_BITS}, group_size={CACHE_GROUP_SIZE}) " + f"for {quantized}/{len(caches)} layers" + ) + else: + if KV_CACHE_BITS is not None: + logger.info( + f"EXO_KV_CACHE_BITS={KV_CACHE_BITS} ignored in single-node mode " + f"(QuantizedKVCache has no merge() support, " + f"required by BatchGenerator); using model.make_cache() unmodified" + ) + else: + logger.info("Using MLX LM's make cache") + # Increase KVCache step size to reduce Metal allocator + # fragmentation. Default step=256 causes a mx.concatenate + # expansion every prefill chunk; a larger step lets the cache + # pre-allocate and write in-place for most of the prefill. + for c in caches: + if isinstance(c, KVCache): + c.step = 16384 + return caches if max_kv_size is None: - if KV_CACHE_BITS is None: - logger.info("Using default KV cache") + if KV_CACHE_BITS is None or not _model_is_pipeline_parallel(model): + if KV_CACHE_BITS is not None: + logger.info( + f"EXO_KV_CACHE_BITS={KV_CACHE_BITS} ignored in single-node mode " + f"(QuantizedKVCache has no merge() support, required by BatchGenerator)" + ) + else: + logger.info("Using default KV cache") return [KVCache() for _ in model.layers] else: logger.info("Using quantized KV cache") diff --git a/src/exo/worker/engines/mlx/constants.py b/src/exo/worker/engines/mlx/constants.py index 86a663e424..c44e93e750 100644 --- a/src/exo/worker/engines/mlx/constants.py +++ b/src/exo/worker/engines/mlx/constants.py @@ -1,3 +1,5 @@ +import os + # TODO: Do we want so many constants? # I think we want a lot of these as parameters? 
@@ -9,9 +11,14 @@ KEEP_KV_SIZE: int | None = 1600 QUANTIZE_MODEL_MODE: str | None = "affine" CACHE_GROUP_SIZE: int = 64 -KV_CACHE_BITS: int | None = None +KV_CACHE_BITS: int | None = ( + int(os.environ["EXO_KV_CACHE_BITS"]) + if os.environ.get("EXO_KV_CACHE_BITS") + else None +) DEFAULT_TOP_LOGPROBS: int = 5 -# TODO: We should really make this opt-in, but Kimi requires trust_remote_code=True +# True for built-in models with known model cards; custom models added via API default to False +# and can be overridden with the --trust-remote-code CLI flag. TRUST_REMOTE_CODE: bool = True diff --git a/src/exo/worker/engines/mlx/utils_mlx.py b/src/exo/worker/engines/mlx/utils_mlx.py index 463ad1e91d..1eeb99d404 100644 --- a/src/exo/worker/engines/mlx/utils_mlx.py +++ b/src/exo/worker/engines/mlx/utils_mlx.py @@ -480,10 +480,14 @@ def shard_and_load( def get_tokenizer(model_path: Path, shard_metadata: ShardMetadata) -> TokenizerWrapper: """Load tokenizer for a model shard. Delegates to load_tokenizer_for_model_id.""" + trust_remote_code = ( + shard_metadata.model_card.trust_remote_code + or os.environ.get("EXO_TRUST_REMOTE_CODE") == "1" + ) return load_tokenizer_for_model_id( shard_metadata.model_card.model_id, model_path, - trust_remote_code=shard_metadata.model_card.trust_remote_code, + trust_remote_code=trust_remote_code, ) diff --git a/src/exo/worker/main.py b/src/exo/worker/main.py index b35f946aac..7b6ea75ce9 100644 --- a/src/exo/worker/main.py +++ b/src/exo/worker/main.py @@ -3,11 +3,12 @@ from datetime import datetime, timezone import anyio -from anyio import fail_after, to_thread +from anyio import fail_after, move_on_after, to_thread from loguru import logger from exo.api.types import ImageEditsTaskParams from exo.download.download_utils import is_read_only_model_dir, resolve_existing_model +from exo.download.peer_state import discover_peers_for_model from exo.shared.apply import apply from exo.shared.constants import EXO_MAX_INSTANCE_RETRIES from 
exo.shared.models.model_cards import ModelId, add_to_card_cache, delete_custom_card @@ -72,6 +73,7 @@ def __init__( command_sender: Sender[ForwarderCommand], download_command_sender: Sender[ForwarderDownloadCommand], api_port: int, + peer_download_port: int, ): self.node_id: NodeId = node_id self.event_receiver = event_receiver @@ -79,6 +81,14 @@ def __init__( self.command_sender = command_sender self.download_command_sender = download_command_sender self.api_port = api_port + # Codex P2 (PR #16 round 3): the peer-download listener port is + # now per-process configurable instead of a module-level + # constant. Use the local value when computing + # ``discover_peers_for_model`` results because peers in the + # current architecture all bind the same port (cluster-wide + # convention enforced via ``EXO_PEER_DOWNLOAD_PORT`` / + # ``--peer-download-port``). + self._peer_download_port = peer_download_port self.state: State = State() self.runners: dict[RunnerId, RunnerSupervisor] = {} @@ -109,6 +119,7 @@ async def run(self): tg.start_soon(self._forward_info, info_recv) tg.start_soon(self.plan_step) tg.start_soon(self._event_applier) + tg.start_soon(self._reconcile_instance_backoff) tg.start_soon(self._poll_connection_updates) finally: # Actual shutdown code - waits for all tasks to complete before executing. 
@@ -179,6 +190,17 @@ async def _event_applier(self): if isinstance(event, CustomModelCardDeleted): await delete_custom_card(event.model_id) + async def _reconcile_instance_backoff(self) -> None: + while True: + await anyio.sleep(1) + self._reconcile_instance_backoff_once() + + def _reconcile_instance_backoff_once(self) -> None: + live_instances = set(self.state.instances) + for instance_id in self._instance_backoff.tracked_keys(): + if instance_id not in live_instances: + self._instance_backoff.reset(instance_id) + async def plan_step(self): while True: await anyio.sleep(0.1) @@ -252,12 +274,19 @@ async def plan_step(self): ) ) else: + peers = discover_peers_for_model( + self.node_id, + self.state, + shard.model_card.model_id.normalize(), + self._peer_download_port, + ) await self.download_command_sender.send( ForwarderDownloadCommand( origin=self._system_id, command=StartDownload( target_node_id=self.node_id, shard_metadata=shard, + available_peers=peers, ), ) ) @@ -356,8 +385,16 @@ async def plan_step(self): await self._start_runner_task(task) async def shutdown(self): + self.event_sender.close() + self.command_sender.close() + self.download_command_sender.close() + for runner in self.runners.values(): + runner.shutdown() self._tg.cancel_tasks() - await self._stopped.wait() + with move_on_after(5) as scope: + await self._stopped.wait() + if scope.cancel_called: + logger.warning("Timed out waiting for Worker shutdown") async def _start_runner_task(self, task: Task): if (instance := self.state.instances.get(task.instance_id)) is not None: diff --git a/src/exo/worker/tests/unittests/test_worker_instance_backoff.py b/src/exo/worker/tests/unittests/test_worker_instance_backoff.py new file mode 100644 index 0000000000..b0052c1eb7 --- /dev/null +++ b/src/exo/worker/tests/unittests/test_worker_instance_backoff.py @@ -0,0 +1,36 @@ +# pyright: reportPrivateUsage=false + +from exo.shared.types.common import ModelId, NodeId +from exo.shared.types.state import State +from 
exo.shared.types.worker.instances import InstanceId, MlxRingInstance +from exo.shared.types.worker.runners import ShardAssignments +from exo.utils.keyed_backoff import KeyedBackoff +from exo.worker.main import Worker + + +def _make_instance(instance_id: InstanceId) -> MlxRingInstance: + return MlxRingInstance( + instance_id=instance_id, + shard_assignments=ShardAssignments( + model_id=ModelId("test-model"), + node_to_runner={}, + runner_to_shard={}, + ), + hosts_by_node={NodeId("node-1"): []}, + ephemeral_port=1, + ) + + +def test_worker_reconciles_instance_backoff_from_state() -> None: + live_instance_id = InstanceId("inst-live") + deleted_instance_id = InstanceId("inst-deleted") + worker = object.__new__(Worker) + worker.state = State(instances={live_instance_id: _make_instance(live_instance_id)}) + worker._instance_backoff = KeyedBackoff[InstanceId]() + worker._instance_backoff.record_attempt(live_instance_id) + worker._instance_backoff.record_attempt(deleted_instance_id) + + worker._reconcile_instance_backoff_once() + + assert worker._instance_backoff.attempts(live_instance_id) == 1 + assert worker._instance_backoff.attempts(deleted_instance_id) == 0