From 4d70a624b9aea89097ab2b5caa5fa0a4876cf6fc Mon Sep 17 00:00:00 2001 From: Evan Date: Tue, 24 Feb 2026 11:49:42 +0000 Subject: [PATCH 01/29] persist node ids in .cache brings back EXO_CACHE_HOME as always ~/.cache/exo/, and store the node id in there. no random copies now! --- src/exo/routing/router.py | 2 -- src/exo/shared/constants.py | 7 +++---- src/exo/shared/tests/test_xdg_paths.py | 22 +++++++++++++++++++++- 3 files changed, 24 insertions(+), 7 deletions(-) diff --git a/src/exo/routing/router.py b/src/exo/routing/router.py index ebe0ea8d90..5b679fe192 100644 --- a/src/exo/routing/router.py +++ b/src/exo/routing/router.py @@ -298,8 +298,6 @@ def get_node_id_keypair( Obtains the :class:`Keypair` associated with this node-ID. Obtain the :class:`PeerId` by from it. """ - # TODO(evan): bring back node id persistence once we figure out how to deal with duplicates - return Keypair.generate() def lock_path(path: str | bytes | PathLike[str] | PathLike[bytes]) -> Path: return Path(str(path) + ".lock") diff --git a/src/exo/shared/constants.py b/src/exo/shared/constants.py index d79354184b..ad76a88ffb 100644 --- a/src/exo/shared/constants.py +++ b/src/exo/shared/constants.py @@ -8,12 +8,12 @@ def _get_xdg_dir(env_var: str, fallback: str) -> Path: - """Get XDG directory, prioritising EXO_HOME environment variable if its set. On non-Linux platforms, default to ~/.exo.""" + """Get XDG directory, prioritising EXO_HOME environment variable if its set. On non-Linux platforms, default to ~/.exo. Cache home always prefers .cache/exo""" if _EXO_HOME_ENV is not None: return Path.home() / _EXO_HOME_ENV - if sys.platform != "linux": + if sys.platform != "linux" and env_var != "XDG_CACHE_HOME": return Path.home() / ".exo" xdg_value = os.environ.get(env_var, None) @@ -68,10 +68,9 @@ def _parse_colon_dirs(env_var: str) -> tuple[Path, ...]: # Log files (data/logs or cache) EXO_LOG_DIR = EXO_CACHE_HOME / "exo_log" EXO_LOG = EXO_LOG_DIR / "exo.log" -EXO_TEST_LOG = EXO_CACHE_HOME / "exo_test.log" # Identity (config) -EXO_NODE_ID_KEYPAIR = EXO_CONFIG_HOME / "node_id.keypair" +EXO_NODE_ID_KEYPAIR = EXO_CACHE_HOME / "node_id.keypair" EXO_CONFIG_FILE = EXO_CONFIG_HOME / "config.toml" # libp2p topics for event forwarding diff --git a/src/exo/shared/tests/test_xdg_paths.py b/src/exo/shared/tests/test_xdg_paths.py index f3b82ebffd..dce2c7d7c1 100644 --- a/src/exo/shared/tests/test_xdg_paths.py +++ b/src/exo/shared/tests/test_xdg_paths.py @@ -94,7 +94,27 @@ def test_macos_uses_traditional_paths(): home = Path.home() assert home / ".exo" == constants.EXO_CONFIG_HOME assert home / ".exo" == constants.EXO_DATA_HOME - assert home / ".exo" == constants.EXO_CACHE_HOME + assert home / ".cache" / "exo" == constants.EXO_CACHE_HOME + + +def test_exo_home_env(): + """Test that macOS uses traditional ~/.exo directory.""" + # Remove EXO_HOME to ensure we test the default behavior + env = {k: v for k, v in os.environ.items() if k != "EXO_HOME"} + env["EXO_HOME"] = "/exo" + with ( + mock.patch.dict(os.environ, env, clear=True), + mock.patch.object(sys, "platform", "darwin"), + ): + import importlib + + import exo.shared.constants as constants + + importlib.reload(constants) + + assert Path("/exo") == constants.EXO_CONFIG_HOME + assert Path("/exo") == constants.EXO_DATA_HOME + assert Path("/exo") == constants.EXO_CACHE_HOME def test_node_id_in_config_dir(): From 0c813b4db0d060259a7f16459f1ab70b5fe0f287 Mon Sep 17 00:00:00 2001 From: Jordan Miller Date: Tue, 10 Mar 2026 21:33:04 -0500 Subject: [PATCH 02/29] feat: peer-to-peer model downloads over LAN MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit When multiple nodes need the same model, only one downloads from HuggingFace while others fetch it over the LAN — eliminating redundant internet downloads and cutting cluster startup time roughly in half. Architecture: - PeerFileServer: lightweight aiohttp server on each node (port 52416) that serves model files from local cache with Range request support - PeerAwareShardDownloader: wraps ResumableShardDownloader, checks if any peer already has the model before hitting HuggingFace - Streaming relay: followers can download from a peer while it's still downloading from HF, via .partial.meta companion files that track flushed byte boundaries - Graceful fallback: if peer transfer fails, falls back to HuggingFace with .partial resume support Key design decisions: - No new gossipsub messages — reuses existing NodeDownloadProgress events and topology for peer discovery and IP resolution - No leader election — first node to start becomes de facto seed - Backend-agnostic — works with MLX, tinygrad, PyTorch (any engine) - Network-agnostic — works over any LAN (Ethernet, WiFi, Thunderbolt) - Zero config — enabled by default, disable with --no-peer-download - Complementary to PR #1463 (MLX memory-to-memory transfer) Addresses: #1257, #721, #1606 Co-Authored-By: Claude Opus 4.6 --- src/exo/download/download_utils.py | 25 ++ src/exo/download/impl_shard_downloader.py | 13 +- src/exo/download/peer_download.py | 169 +++++++++++ src/exo/download/peer_file_server.py | 174 ++++++++++++ src/exo/download/peer_shard_downloader.py | 277 +++++++++++++++++++ src/exo/download/peer_state.py | 93 +++++++ src/exo/download/tests/test_peer_download.py | 265 ++++++++++++++++++ src/exo/main.py | 69 ++++- src/exo/shared/constants.py | 3 + 9 files changed, 1072 insertions(+), 16 deletions(-) create mode 100644 src/exo/download/peer_download.py create mode 100644 src/exo/download/peer_file_server.py create mode 100644 src/exo/download/peer_shard_downloader.py create mode 100644 src/exo/download/peer_state.py create mode 100644 src/exo/download/tests/test_peer_download.py diff --git a/src/exo/download/download_utils.py b/src/exo/download/download_utils.py index 277dd3a6d3..0da4542a36 100644 --- a/src/exo/download/download_utils.py +++ b/src/exo/download/download_utils.py @@ -1,5 +1,6 @@ import asyncio import hashlib +import json import os import random import shutil @@ -777,6 +778,9 @@ async def _download_file( ) as f: while chunk := await r.content.read(8 * 1024 * 1024): n_read = n_read + (await f.write(chunk)) + await f.flush() + # Write companion metadata for peer download streaming + await _write_partial_meta(partial_path, n_read, length, remote_hash) on_progress(n_read, length, False) final_hash = await calc_hash( @@ -792,10 +796,31 @@ async def _download_file( f"Downloaded file {target_dir / path} has hash {final_hash} but remote hash is {remote_hash}" ) await aios.rename(partial_path, target_dir / path) + # Clean up companion metadata file + meta_path = Path(f"{partial_path}.meta") + if await aios.path.exists(meta_path): + await aios.remove(meta_path) on_progress(length, length, True) return target_dir / path +async def _write_partial_meta( + partial_path: Path, safe_bytes: int, total: int, etag: str +) -> None: + """Write companion .partial.meta file for peer download streaming. + + This small JSON file tells the peer file server how many bytes of the + .partial file have been safely flushed to disk and are safe to serve. + """ + meta_path = Path(f"{partial_path}.meta") + meta = json.dumps({"safe_bytes": safe_bytes, "total": total, "etag": etag}) + # Write to temp then rename for atomicity + tmp_path = Path(f"{partial_path}.meta.tmp") + async with aiofiles.open(tmp_path, "w") as f: + await f.write(meta) + await aios.rename(tmp_path, meta_path) + + def calculate_repo_progress( shard: ShardMetadata, model_id: ModelId, diff --git a/src/exo/download/impl_shard_downloader.py b/src/exo/download/impl_shard_downloader.py index 14ec83b689..1cb304c66d 100644 --- a/src/exo/download/impl_shard_downloader.py +++ b/src/exo/download/impl_shard_downloader.py @@ -10,6 +10,8 @@ RepoDownloadProgress, download_shard, ) +from exo.download.peer_shard_downloader import PeerAwareShardDownloader +from exo.download.peer_state import PeerStateProvider from exo.download.shard_downloader import ShardDownloader from exo.shared.models.model_cards import ( ModelCard, @@ -25,11 +27,16 @@ def exo_shard_downloader( - max_parallel_downloads: int = 8, offline: bool = False + max_parallel_downloads: int = 8, + offline: bool = False, + peer_state_provider: PeerStateProvider | None = None, ) -> ShardDownloader: - return SingletonShardDownloader( - ResumableShardDownloader(max_parallel_downloads, offline=offline) + inner: ShardDownloader = ResumableShardDownloader( + max_parallel_downloads, offline=offline ) + if peer_state_provider is not None: + inner = PeerAwareShardDownloader(inner, peer_state_provider) + return SingletonShardDownloader(inner) async def build_base_shard(model_id: ModelId) -> ShardMetadata: diff --git a/src/exo/download/peer_download.py b/src/exo/download/peer_download.py new file mode 100644 index 0000000000..1fab3657d6 --- /dev/null +++ b/src/exo/download/peer_download.py @@ -0,0 +1,169 @@ +"""HTTP client for downloading model files from peer nodes. + +Instead of downloading from HuggingFace, nodes can fetch model files from +peers on the same LAN that already have them (or are still downloading them). +Falls back gracefully if the peer is unreachable or the transfer fails. +""" + +import asyncio +import json +from dataclasses import dataclass +from pathlib import Path +from typing import Callable + +import aiofiles +import aiofiles.os as aios +import aiohttp +from loguru import logger + + +@dataclass(frozen=True) +class PeerFileInfo: + """Status of a single file on a peer node.""" + + path: str + size: int + complete: bool + safe_bytes: int + + +async def get_peer_file_status( + peer_host: str, + peer_port: int, + model_id_normalized: str, + timeout: float = 5.0, +) -> list[PeerFileInfo] | None: + """Query a peer's file server for available files for a model. + + Returns None if the peer is unreachable. + """ + url = f"http://{peer_host}:{peer_port}/status/{model_id_normalized}" + try: + async with aiohttp.ClientSession( + timeout=aiohttp.ClientTimeout(total=timeout) + ) as session: + async with session.get(url) as r: + if r.status != 200: + return None + data = await r.json() + return [PeerFileInfo(**f) for f in data.get("files", [])] + except Exception as e: + logger.debug(f"Could not reach peer {peer_host}:{peer_port}: {e}") + return None + + +async def download_file_from_peer( + peer_host: str, + peer_port: int, + model_id_normalized: str, + file_path: str, + target_dir: Path, + expected_size: int, + on_progress: Callable[[int, int, bool], None] = lambda _a, _b, _c: None, + max_poll_attempts: int = 60, + poll_interval: float = 3.0, +) -> Path | None: + """Download a single file from a peer's file server. + + Supports streaming relay: if the peer is still downloading the file, + we fetch available bytes, wait, and poll for more until the file is + complete. + + Returns the final file path on success, or None on failure (caller + should fall back to HuggingFace). + """ + target_path = target_dir / file_path + partial_path = target_dir / f"{file_path}.partial" + + # Check if already complete locally + if await aios.path.exists(target_path): + local_size = (await aios.stat(target_path)).st_size + if local_size == expected_size: + on_progress(expected_size, expected_size, True) + return target_path + + await aios.makedirs((target_dir / file_path).parent, exist_ok=True) + + url = f"http://{peer_host}:{peer_port}/files/{model_id_normalized}/{file_path}" + n_read = 0 + + # Resume from existing partial + if await aios.path.exists(partial_path): + n_read = (await aios.stat(partial_path)).st_size + + poll_count = 0 + chunk_size = 8 * 1024 * 1024 # 8MB, matching HF download + + try: + while n_read < expected_size and poll_count < max_poll_attempts: + headers: dict[str, str] = {} + if n_read > 0: + headers["Range"] = f"bytes={n_read}-" + + got_bytes = False + async with aiohttp.ClientSession( + timeout=aiohttp.ClientTimeout(total=300, sock_read=60) + ) as session: + async with session.get(url, headers=headers) as r: + if r.status == 416: + # Range not satisfiable - peer doesn't have more yet + pass + elif r.status in (200, 206): + peer_complete = r.headers.get("X-Exo-Complete") == "true" + safe_bytes = int(r.headers.get("X-Exo-Safe-Bytes", "0")) + + async with aiofiles.open( + partial_path, "ab" if n_read > 0 else "wb" + ) as f: + while True: + chunk = await r.content.read(chunk_size) + if not chunk: + break + written = await f.write(chunk) + n_read += written + got_bytes = True + on_progress(n_read, expected_size, False) + elif r.status == 404: + logger.debug( + f"File {file_path} not found on peer {peer_host}" + ) + return None + else: + logger.warning( + f"Unexpected status {r.status} from peer {peer_host}" + ) + return None + + # Check if we're done + if n_read >= expected_size: + break + + # If we got no new bytes, the peer might still be downloading + if not got_bytes: + poll_count += 1 + logger.debug( + f"Waiting for peer {peer_host} to download more of {file_path} " + f"({n_read}/{expected_size}, poll {poll_count}/{max_poll_attempts})" + ) + await asyncio.sleep(poll_interval) + else: + # Got data, reset poll counter + poll_count = 0 + + if n_read < expected_size: + logger.warning( + f"Peer download incomplete for {file_path}: {n_read}/{expected_size}" + ) + return None + + # Rename partial to final + await aios.rename(partial_path, target_path) + on_progress(expected_size, expected_size, True) + logger.info( + f"Downloaded {file_path} from peer {peer_host} ({expected_size} bytes)" + ) + return target_path + + except Exception as e: + logger.warning(f"Peer download failed for {file_path} from {peer_host}: {e}") + return None diff --git a/src/exo/download/peer_file_server.py b/src/exo/download/peer_file_server.py new file mode 100644 index 0000000000..f36823ac27 --- /dev/null +++ b/src/exo/download/peer_file_server.py @@ -0,0 +1,174 @@ +"""Lightweight HTTP file server for peer-to-peer model downloads. + +Each exo node runs a PeerFileServer that serves model files from the local +cache directory. When one node finishes downloading a model from HuggingFace, +other nodes on the same LAN can fetch it directly over HTTP instead of +re-downloading from the internet. + +Supports serving in-progress downloads via .partial.meta files that track +how many bytes have been safely flushed to disk. +""" + +import json +from pathlib import Path + +import aiofiles +import aiofiles.os as aios +from aiohttp import web +from loguru import logger + + +class PeerFileServer: + """HTTP server that exposes local model files for peer download.""" + + def __init__(self, host: str, port: int, models_dir: Path) -> None: + self.host = host + self.port = port + self.models_dir = models_dir + self._app = web.Application() + self._app.router.add_get("/status/{model_id}", self._handle_status) + self._app.router.add_get( + "/files/{model_id}/{file_path:.+}", self._handle_file + ) + self._app.router.add_get("/health", self._handle_health) + self._runner: web.AppRunner | None = None + + async def run(self) -> None: + self._runner = web.AppRunner(self._app) + await self._runner.setup() + site = web.TCPSite(self._runner, self.host, self.port) + await site.start() + logger.info(f"PeerFileServer listening on {self.host}:{self.port}") + + async def shutdown(self) -> None: + if self._runner: + await self._runner.cleanup() + + async def _handle_health(self, request: web.Request) -> web.Response: + return web.json_response({"status": "ok"}) + + async def _handle_status(self, request: web.Request) -> web.Response: + """Return status of all files for a model (complete + in-progress).""" + model_id = request.match_info["model_id"] + model_dir = self.models_dir / model_id + + if not await aios.path.exists(model_dir): + return web.json_response({"files": []}) + + files = [] + for item in model_dir.iterdir(): + if item.is_dir() or item.name.endswith(".partial.meta"): + continue + + if item.name.endswith(".partial"): + # In-progress file - read meta for safe bytes + meta = await _read_partial_meta(item) + if meta: + files.append( + { + "path": item.name.removesuffix(".partial"), + "size": meta.get("total", 0), + "complete": False, + "safe_bytes": meta.get("safe_bytes", 0), + } + ) + else: + # Complete file + stat = await aios.stat(item) + files.append( + { + "path": item.name, + "size": stat.st_size, + "complete": True, + "safe_bytes": stat.st_size, + } + ) + + return web.json_response({"files": files}) + + async def _handle_file(self, request: web.Request) -> web.StreamResponse: + """Serve a model file with Range request support. + + For complete files: standard HTTP file serving. + For .partial files: serves only the safe byte range (flushed to disk). + """ + model_id = request.match_info["model_id"] + file_path = request.match_info["file_path"] + + model_dir = self.models_dir / model_id + complete_path = model_dir / file_path + partial_path = model_dir / f"{file_path}.partial" + + # Determine which file to serve and its safe size + if await aios.path.exists(complete_path): + serve_path = complete_path + file_size = (await aios.stat(complete_path)).st_size + safe_bytes = file_size + is_complete = True + elif await aios.path.exists(partial_path): + meta = await _read_partial_meta(partial_path) + if not meta or meta.get("safe_bytes", 0) == 0: + return web.Response(status=404, text="File not available yet") + serve_path = partial_path + file_size = meta.get("total", 0) + safe_bytes = meta["safe_bytes"] + is_complete = False + else: + return web.Response(status=404, text="File not found") + + # Parse Range header + range_header = request.headers.get("Range") + start = 0 + if range_header: + try: + range_spec = range_header.replace("bytes=", "") + start = int(range_spec.split("-")[0]) + except (ValueError, IndexError): + return web.Response(status=416, text="Invalid range") + + if start >= safe_bytes: + return web.Response(status=416, text="Range not satisfiable") + + end = safe_bytes # Serve up to safe boundary only + content_length = end - start + + response = web.StreamResponse( + status=206 if start > 0 else 200, + headers={ + "Content-Type": "application/octet-stream", + "Content-Length": str(content_length), + "Accept-Ranges": "bytes", + "Content-Range": f"bytes {start}-{end - 1}/{file_size}", + "X-Exo-Safe-Bytes": str(safe_bytes), + "X-Exo-Total-Size": str(file_size), + "X-Exo-Complete": "true" if is_complete else "false", + }, + ) + await response.prepare(request) + + chunk_size = 8 * 1024 * 1024 # 8MB chunks matching HF download + async with aiofiles.open(serve_path, "rb") as f: + await f.seek(start) + remaining = content_length + while remaining > 0: + to_read = min(chunk_size, remaining) + chunk = await f.read(to_read) + if not chunk: + break + await response.write(chunk) + remaining -= len(chunk) + + await response.write_eof() + return response + + +async def _read_partial_meta(partial_path: Path) -> dict | None: + """Read the .partial.meta companion file for a .partial download.""" + meta_path = Path(f"{partial_path}.meta") + if not await aios.path.exists(meta_path): + return None + try: + async with aiofiles.open(meta_path, "r") as f: + return json.loads(await f.read()) + except (json.JSONDecodeError, OSError): + return None diff --git a/src/exo/download/peer_shard_downloader.py b/src/exo/download/peer_shard_downloader.py new file mode 100644 index 0000000000..ee0c3404e2 --- /dev/null +++ b/src/exo/download/peer_shard_downloader.py @@ -0,0 +1,277 @@ +"""Peer-aware shard downloader that tries LAN peers before HuggingFace. + +Wraps an existing ShardDownloader and adds a peer-download step: before +hitting HuggingFace, check if any peer on the LAN already has the model +(or is downloading it) and fetch from them instead. Falls back to the +inner downloader (HF) if peer download fails. +""" + +import asyncio +import time +from collections.abc import Awaitable +from datetime import timedelta +from pathlib import Path +from typing import AsyncIterator, Callable + +from loguru import logger + +from exo.download.download_utils import ( + RepoDownloadProgress, + calculate_repo_progress, + ensure_models_dir, + fetch_file_list_with_cache, + is_image_model, + resolve_allow_patterns, +) +from exo.download.huggingface_utils import filter_repo_objects +from exo.download.peer_download import ( + download_file_from_peer, + get_peer_file_status, +) +from exo.download.peer_state import PeerStateProvider +from exo.download.shard_downloader import ShardDownloader +from exo.shared.types.memory import Memory +from exo.shared.types.worker.downloads import RepoFileDownloadProgress +from exo.shared.types.worker.shards import ShardMetadata + + +class PeerAwareShardDownloader(ShardDownloader): + """ShardDownloader that tries peer download before HuggingFace. + + Decorates an inner ShardDownloader (typically ResumableShardDownloader). + On ensure_shard(), checks if any cluster peer already has the model + and downloads from them over the LAN. Falls back to the inner + downloader if no peer has it or the peer transfer fails. + """ + + def __init__( + self, + inner: ShardDownloader, + peer_state_provider: PeerStateProvider, + ) -> None: + self._inner = inner + self._peer_state = peer_state_provider + self._progress_callbacks: list[ + Callable[[ShardMetadata, RepoDownloadProgress], Awaitable[None]] + ] = [] + + def on_progress( + self, + callback: Callable[[ShardMetadata, RepoDownloadProgress], Awaitable[None]], + ) -> None: + self._inner.on_progress(callback) + self._progress_callbacks.append(callback) + + async def ensure_shard( + self, shard: ShardMetadata, config_only: bool = False + ) -> Path: + if config_only: + return await self._inner.ensure_shard(shard, config_only=True) + + model_id = shard.model_card.model_id + normalized = model_id.normalize() + + # Check if any peer has this model + peers = self._peer_state.get_peers_for_model(normalized) + if not peers: + logger.debug(f"No peers have {model_id}, downloading from HuggingFace") + return await self._inner.ensure_shard(shard, config_only=False) + + # Try each peer (completed peers first) + for peer in peers: + logger.info( + f"Attempting peer download of {model_id} from " + f"{peer.ip} (status: {peer.status})" + ) + result = await self._try_peer_download( + shard, peer.ip, self._peer_state.peer_download_port, normalized + ) + if result is not None: + logger.info( + f"Successfully downloaded {model_id} from peer {peer.ip}" + ) + return result + logger.info( + f"Peer download from {peer.ip} failed, trying next peer or HuggingFace" + ) + + # All peers failed, fall back to HuggingFace + logger.info(f"All peer downloads failed for {model_id}, falling back to HuggingFace") + return await self._inner.ensure_shard(shard, config_only=False) + + async def _try_peer_download( + self, + shard: ShardMetadata, + peer_ip: str, + peer_port: int, + model_id_normalized: str, + ) -> Path | None: + """Attempt to download all model files from a single peer. + + Returns the model directory path on success, None on failure. + """ + # First, check what the peer has + peer_files = await get_peer_file_status( + peer_ip, peer_port, model_id_normalized + ) + if not peer_files: + return None + + peer_file_map = {f.path: f for f in peer_files} + + # Get the file list we need (same logic as download_shard) + revision = "main" + target_dir = await ensure_models_dir() / model_id_normalized + + try: + file_list = await fetch_file_list_with_cache( + shard.model_card.model_id, + revision, + recursive=True, + skip_internet=False, + ) + except Exception: + # Can't get file list - fall back + return None + + allow_patterns = await resolve_allow_patterns(shard) + filtered_file_list = list( + filter_repo_objects( + file_list, allow_patterns=allow_patterns, key=lambda x: x.path + ) + ) + + if is_image_model(shard): + filtered_file_list = [ + f + for f in filtered_file_list + if "/" in f.path or not f.path.endswith(".safetensors") + ] + + # Check the peer has all (or most) files we need + files_on_peer = 0 + for f in filtered_file_list: + if f.path in peer_file_map: + files_on_peer += 1 + + if files_on_peer == 0: + logger.debug(f"Peer has no files we need for {model_id_normalized}") + return None + + # Download from peer with progress tracking + all_start_time = time.time() + file_progress: dict[str, RepoFileDownloadProgress] = {} + semaphore = asyncio.Semaphore(8) + failed = False + + async def download_one(file_path: str, expected_size: int) -> bool: + """Download a single file from peer. Returns True on success.""" + + def on_file_progress( + curr_bytes: int, total_bytes: int, is_renamed: bool + ) -> None: + file_progress[file_path] = RepoFileDownloadProgress( + repo_id=str(shard.model_card.model_id), + repo_revision=revision, + file_path=file_path, + downloaded=Memory.from_bytes(curr_bytes), + downloaded_this_session=Memory.from_bytes(curr_bytes), + total=Memory.from_bytes(total_bytes), + speed=curr_bytes / max(time.time() - all_start_time, 0.1), + eta=timedelta( + seconds=(total_bytes - curr_bytes) + / max( + curr_bytes / max(time.time() - all_start_time, 0.1), + 0.1, + ) + ), + status="complete" if is_renamed else "in_progress", + start_time=all_start_time, + ) + # Fire progress callbacks + progress = calculate_repo_progress( + shard, + shard.model_card.model_id, + revision, + file_progress, + all_start_time, + ) + for cb in self._progress_callbacks: + asyncio.create_task(cb(shard, progress)) + + async with semaphore: + result = await download_file_from_peer( + peer_ip, + peer_port, + model_id_normalized, + file_path, + target_dir, + expected_size, + on_progress=on_file_progress, + ) + return result is not None + + # Initialize progress for all files + for f in filtered_file_list: + file_progress[f.path] = RepoFileDownloadProgress( + repo_id=str(shard.model_card.model_id), + repo_revision=revision, + file_path=f.path, + downloaded=Memory.from_bytes(0), + downloaded_this_session=Memory.from_bytes(0), + total=Memory.from_bytes(f.size or 0), + speed=0, + eta=timedelta(0), + status="not_started", + start_time=all_start_time, + ) + + # Download all files in parallel + tasks = [] + for f in filtered_file_list: + if f.size is None or f.size == 0: + continue + peer_info = peer_file_map.get(f.path) + if peer_info and peer_info.safe_bytes > 0: + tasks.append(download_one(f.path, f.size)) + else: + # Peer doesn't have this file yet - this means incomplete peer + # We could still try for the files it has, but for simplicity + # fail the whole peer download if any file is missing + failed = True + break + + if failed: + return None + + results = await asyncio.gather(*tasks, return_exceptions=True) + if any(isinstance(r, Exception) or r is False for r in results): + return None + + # Emit final progress + final_progress = calculate_repo_progress( + shard, + shard.model_card.model_id, + revision, + file_progress, + all_start_time, + ) + for cb in self._progress_callbacks: + await cb(shard, final_progress) + + # Return path (same as download_shard does) + gguf = next( + (f for f in filtered_file_list if f.path.endswith(".gguf")), None + ) + return (target_dir / gguf.path) if gguf else target_dir + + async def get_shard_download_status( + self, + ) -> AsyncIterator[tuple[Path, RepoDownloadProgress]]: + async for path, status in self._inner.get_shard_download_status(): + yield path, status + + async def get_shard_download_status_for_shard( + self, shard: ShardMetadata + ) -> RepoDownloadProgress: + return await self._inner.get_shard_download_status_for_shard(shard) diff --git a/src/exo/download/peer_state.py b/src/exo/download/peer_state.py new file mode 100644 index 0000000000..bc87c0ce47 --- /dev/null +++ b/src/exo/download/peer_state.py @@ -0,0 +1,93 @@ +"""Peer state provider for discovering which peers have which models. + +Reads from the shared State object (populated via gossipsub events) to +determine which peer nodes have completed or are in the process of +downloading a given model. Resolves peer IP addresses from the topology. +""" + +from dataclasses import dataclass +from typing import Callable, Literal + +from loguru import logger + +from exo.shared.types.common import NodeId +from exo.shared.types.state import State +from exo.shared.types.topology import SocketConnection +from exo.shared.types.worker.downloads import ( + DownloadCompleted, + DownloadOngoing, +) + + +@dataclass(frozen=True) +class PeerInfo: + """A peer that has (or is downloading) a model.""" + + node_id: NodeId + ip: str + status: Literal["complete", "ongoing"] + + +class PeerStateProvider: + """Provides information about which peers have which models. + + Reads from the Worker's shared State to find peers and resolve their + network addresses from the topology graph. + """ + + def __init__( + self, + node_id: NodeId, + state_accessor: Callable[[], State], + peer_download_port: int, + ) -> None: + self.node_id = node_id + self._state_accessor = state_accessor + self.peer_download_port = peer_download_port + + def get_peers_for_model(self, model_id: str) -> list[PeerInfo]: + """Find peers that have a specific model (complete or in-progress). + + Returns peers sorted by completeness (completed first, then ongoing). + Excludes self. + """ + state = self._state_accessor() + peers: list[PeerInfo] = [] + + # Check download status across all nodes + for peer_node_id, download_list in state.downloads.items(): + if peer_node_id == self.node_id: + continue + + for dl in download_list: + dl_model_id = dl.shard_metadata.model_card.model_id + if dl_model_id.normalize() != model_id: + continue + + if isinstance(dl, DownloadCompleted): + status: Literal["complete", "ongoing"] = "complete" + elif isinstance(dl, DownloadOngoing): + status = "ongoing" + else: + continue + + # Resolve IP from topology + ip = self._resolve_peer_ip(peer_node_id, state) + if ip: + peers.append(PeerInfo(node_id=peer_node_id, ip=ip, status=status)) + + # Sort: completed peers first + peers.sort(key=lambda p: 0 if p.status == "complete" else 1) + return peers + + def _resolve_peer_ip(self, peer_node_id: NodeId, state: State) -> str | None: + """Resolve a peer's IP address from the topology graph.""" + try: + for conn in state.topology.out_edges(self.node_id): + if conn.sink == peer_node_id and isinstance( + conn.edge, SocketConnection + ): + return conn.edge.sink_multiaddr.ip_address + except Exception as e: + logger.debug(f"Could not resolve IP for peer {peer_node_id}: {e}") + return None diff --git a/src/exo/download/tests/test_peer_download.py b/src/exo/download/tests/test_peer_download.py new file mode 100644 index 0000000000..04c49de786 --- /dev/null +++ b/src/exo/download/tests/test_peer_download.py @@ -0,0 +1,265 @@ +"""Tests for peer-to-peer model downloading.""" + +import asyncio +import json +from collections.abc import AsyncIterator +from pathlib import Path + +import aiofiles +import aiofiles.os as aios +import pytest + +from exo.download.peer_download import download_file_from_peer, get_peer_file_status +from exo.download.peer_file_server import PeerFileServer + + +@pytest.fixture +async def temp_models_dir(tmp_path: Path) -> AsyncIterator[Path]: + """Set up a temporary models directory for testing.""" + models_dir = tmp_path / "models" + await aios.makedirs(models_dir, exist_ok=True) + yield models_dir + + +@pytest.fixture +async def peer_server(temp_models_dir: Path) -> AsyncIterator[PeerFileServer]: + """Start a PeerFileServer on a random port for testing.""" + server = PeerFileServer(host="127.0.0.1", port=0, models_dir=temp_models_dir) + # Use port 0 to let OS assign a free port + from aiohttp import web + + server._runner = web.AppRunner(server._app) + await server._runner.setup() + site = web.TCPSite(server._runner, "127.0.0.1", 0) + await site.start() + # Get the actual port assigned + server.port = site._server.sockets[0].getsockname()[1] # type: ignore[union-attr] + yield server + await server.shutdown() + + +class TestPeerFileServer: + """Tests for the HTTP file server that serves model files to peers.""" + + async def test_health_check(self, peer_server: PeerFileServer) -> None: + """Health endpoint should return ok.""" + import aiohttp + + async with aiohttp.ClientSession() as session: + async with session.get( + f"http://127.0.0.1:{peer_server.port}/health" + ) as r: + assert r.status == 200 + data = await r.json() + assert data["status"] == "ok" + + async def test_status_empty_model(self, peer_server: PeerFileServer) -> None: + """Status for non-existent model should return empty file list.""" + files = await get_peer_file_status( + "127.0.0.1", peer_server.port, "nonexistent--model" + ) + assert files is not None + assert len(files) == 0 + + async def test_status_with_complete_file( + self, peer_server: PeerFileServer, temp_models_dir: Path + ) -> None: + """Status should report complete files correctly.""" + model_dir = temp_models_dir / "test--model" + await aios.makedirs(model_dir, exist_ok=True) + + # Create a complete test file + async with aiofiles.open(model_dir / "config.json", "wb") as f: + await f.write(b'{"test": true}') + + files = await get_peer_file_status( + "127.0.0.1", peer_server.port, "test--model" + ) + assert files is not None + assert len(files) == 1 + assert files[0].path == "config.json" + assert files[0].complete is True + assert files[0].safe_bytes == 14 + + async def test_status_with_partial_file( + self, peer_server: PeerFileServer, temp_models_dir: Path + ) -> None: + """Status should report partial files with safe byte count.""" + model_dir = temp_models_dir / "test--model" + await aios.makedirs(model_dir, exist_ok=True) + + # Create a partial file with metadata + partial_data = b"x" * 1024 + async with aiofiles.open(model_dir / "weights.safetensors.partial", "wb") as f: + await f.write(partial_data) + + meta = {"safe_bytes": 1024, "total": 4096, "etag": "abc123"} + async with aiofiles.open( + model_dir / "weights.safetensors.partial.meta", "w" + ) as f: + await f.write(json.dumps(meta)) + + files = await get_peer_file_status( + "127.0.0.1", peer_server.port, "test--model" + ) + assert files is not None + assert len(files) == 1 + assert files[0].path == "weights.safetensors" + assert files[0].complete is False + assert files[0].safe_bytes == 1024 + assert files[0].size == 4096 + + async def test_serve_complete_file( + self, peer_server: PeerFileServer, temp_models_dir: Path + ) -> None: + """Should serve a complete file with correct headers.""" + model_dir = temp_models_dir / "test--model" + await aios.makedirs(model_dir, exist_ok=True) + + content = b"hello world test content" + async with aiofiles.open(model_dir / "config.json", "wb") as f: + await f.write(content) + + import aiohttp + + async with aiohttp.ClientSession() as session: + async with session.get( + f"http://127.0.0.1:{peer_server.port}/files/test--model/config.json" + ) as r: + assert r.status == 200 + assert r.headers["X-Exo-Complete"] == "true" + body = await r.read() + assert body == content + + async def test_serve_with_range_request( + self, peer_server: PeerFileServer, temp_models_dir: Path + ) -> None: + """Should support Range requests for resume.""" + model_dir = temp_models_dir / "test--model" + await aios.makedirs(model_dir, exist_ok=True) + + content = b"0123456789abcdef" + async with aiofiles.open(model_dir / "weights.bin", "wb") as f: + await f.write(content) + + import aiohttp + + async with aiohttp.ClientSession() as session: + async with session.get( + f"http://127.0.0.1:{peer_server.port}/files/test--model/weights.bin", + headers={"Range": "bytes=8-"}, + ) as r: + assert r.status == 206 + body = await r.read() + assert body == b"89abcdef" + + async def test_file_not_found(self, peer_server: PeerFileServer) -> None: + """Should return 404 for missing files.""" + import aiohttp + + async with aiohttp.ClientSession() as session: + async with session.get( + f"http://127.0.0.1:{peer_server.port}/files/test--model/missing.bin" + ) as r: + assert r.status == 404 + + +class TestPeerDownloadClient: + """Tests for downloading files from a peer server.""" + + async def test_download_complete_file( + self, peer_server: PeerFileServer, temp_models_dir: Path, tmp_path: Path + ) -> None: + """Should download a complete file from peer.""" + # Set up source file on the peer server + model_dir = temp_models_dir / "test--model" + await aios.makedirs(model_dir, exist_ok=True) + + content = b"model weights data " * 100 + async with aiofiles.open(model_dir / "weights.bin", "wb") as f: + await f.write(content) + + # Download to a different directory + download_dir = tmp_path / "downloads" / "test--model" + await aios.makedirs(download_dir, exist_ok=True) + + progress_calls: list[tuple[int, int, bool]] = [] + + result = await download_file_from_peer( + "127.0.0.1", + peer_server.port, + "test--model", + "weights.bin", + download_dir, + len(content), + on_progress=lambda c, t, r: progress_calls.append((c, t, r)), + ) + + assert result is not None + assert result == download_dir / "weights.bin" + async with aiofiles.open(result, "rb") as f: + downloaded = await f.read() + assert downloaded == content + # Should have progress calls including final + assert len(progress_calls) > 0 + assert progress_calls[-1][2] is True # is_renamed + + async def test_download_returns_none_on_missing( + self, peer_server: PeerFileServer, tmp_path: Path + ) -> None: + """Should return None when file doesn't exist on peer.""" + download_dir = tmp_path / "downloads" / "test--model" + await aios.makedirs(download_dir, exist_ok=True) + + result = await download_file_from_peer( + "127.0.0.1", + peer_server.port, + "test--model", + "nonexistent.bin", + download_dir, + 1000, + ) + assert result is None + + async def test_download_returns_none_on_unreachable_peer( + self, tmp_path: Path + ) -> None: + """Should return None when peer is unreachable.""" + download_dir = tmp_path / "downloads" / "test--model" + await aios.makedirs(download_dir, exist_ok=True) + + result = await download_file_from_peer( + "127.0.0.1", + 19999, # Nobody listening + "test--model", + "weights.bin", + download_dir, + 1000, + ) + assert result is None + + async def test_skip_already_complete( + self, peer_server: PeerFileServer, temp_models_dir: Path, tmp_path: Path + ) -> None: + """Should skip download if file already exists locally with correct size.""" + model_dir = temp_models_dir / "test--model" + await aios.makedirs(model_dir, exist_ok=True) + + content = b"existing content" + # File already exists in target + download_dir = tmp_path / "downloads" / "test--model" + await aios.makedirs(download_dir, exist_ok=True) + async with aiofiles.open(download_dir / "config.json", "wb") as f: + await f.write(content) + + result = await download_file_from_peer( + "127.0.0.1", + peer_server.port, + "test--model", + "config.json", + download_dir, + len(content), + ) + + assert result is not None + assert result == download_dir / "config.json" diff --git a/src/exo/main.py b/src/exo/main.py index 9edee42096..b43b7cd0aa 100644 --- a/src/exo/main.py +++ b/src/exo/main.py @@ -16,10 +16,12 @@ from exo.api.main import API from exo.download.coordinator import DownloadCoordinator from exo.download.impl_shard_downloader import exo_shard_downloader +from exo.download.peer_file_server import PeerFileServer +from exo.download.peer_state import PeerStateProvider from exo.master.main import Master from exo.routing.event_router import EventRouter from exo.routing.router import Router, get_node_id_keypair -from exo.shared.constants import EXO_LOG +from exo.shared.constants import EXO_LOG, EXO_MODELS_DIR, EXO_PEER_DOWNLOAD_PORT from exo.shared.election import Election, ElectionResult from exo.shared.logging import logger_cleanup, logger_set_context, logger_setup from exo.shared.types.common import NodeId, SessionId @@ -44,6 +46,8 @@ class Node: node_id: NodeId offline: bool _api_port: int + peer_file_server: PeerFileServer | None = None + peer_state_provider: PeerStateProvider | None = None _tg: TaskGroup = field(init=False, default_factory=TaskGroup) @classmethod @@ -71,17 +75,12 @@ async def create(cls, args: "Args") -> Self: logger.info(f"Starting node {node_id}") - # Create DownloadCoordinator (unless --no-downloads) - if not args.no_downloads: - download_coordinator = DownloadCoordinator( - node_id, - exo_shard_downloader(offline=args.offline), - event_sender=event_router.sender(), - download_command_receiver=router.receiver(topics.DOWNLOAD_COMMANDS), - offline=args.offline, - ) - else: - download_coordinator = None + # Peer download components are created below, after Worker (needs state access) + peer_file_server: PeerFileServer | None = None + peer_state_provider: PeerStateProvider | None = None + + # DownloadCoordinator is also created below, after Worker + download_coordinator: DownloadCoordinator | None = None if args.spawn_api: api = API( @@ -107,6 +106,31 @@ async def create(cls, args: "Args") -> Self: else: worker = None + # Create peer download components and DownloadCoordinator + # (after Worker, since PeerStateProvider needs access to worker state) + if not args.no_downloads: + if not args.no_peer_download and worker is not None: + peer_file_server = PeerFileServer( + host="0.0.0.0", + port=EXO_PEER_DOWNLOAD_PORT, + models_dir=EXO_MODELS_DIR, + ) + peer_state_provider = PeerStateProvider( + node_id=node_id, + state_accessor=lambda: worker.state, + peer_download_port=EXO_PEER_DOWNLOAD_PORT, + ) + download_coordinator = DownloadCoordinator( + node_id, + exo_shard_downloader( + offline=args.offline, + peer_state_provider=peer_state_provider, + ), + event_sender=event_router.sender(), + download_command_receiver=router.receiver(topics.DOWNLOAD_COMMANDS), + offline=args.offline, + ) + # We start every node with a master master = Master( node_id, @@ -144,6 +168,8 @@ async def create(cls, args: "Args") -> Self: node_id, args.offline, args.api_port, + peer_file_server, + peer_state_provider, ) logger_set_context( node_id=node_id, role="master" if args.force_master else "node" @@ -161,6 +187,8 @@ async def run(self): tg.start_soon(self.router.run) tg.start_soon(self.event_router.run) tg.start_soon(self.election.run) + if self.peer_file_server: + tg.start_soon(self.peer_file_server.run) if self.download_coordinator: tg.start_soon(self.download_coordinator.run) if self.worker: @@ -252,7 +280,10 @@ async def _elect_loop(self): await self.download_coordinator.shutdown() self.download_coordinator = DownloadCoordinator( self.node_id, - exo_shard_downloader(offline=self.offline), + exo_shard_downloader( + offline=self.offline, + peer_state_provider=self.peer_state_provider, + ), event_sender=self.event_router.sender(), download_command_receiver=self.router.receiver( topics.DOWNLOAD_COMMANDS @@ -273,6 +304,12 @@ async def _elect_loop(self): ), api_port=self._api_port, ) + # Update peer state provider to reference the new worker + if self.peer_state_provider is not None: + new_worker = self.worker + self.peer_state_provider._state_accessor = ( + lambda: new_worker.state + ) self._tg.start_soon(self.worker.run) if self.api: self.api.reset(result.won_clock, self.event_router.receiver()) @@ -430,6 +467,7 @@ class Args(FrozenModel): tb_only: bool = False no_worker: bool = False no_downloads: bool = False + no_peer_download: bool = False offline: bool = os.getenv("EXO_OFFLINE", "false").lower() == "true" no_batch: bool = False fast_synch: bool | None = None # None = auto, True = force on, False = force off @@ -481,6 +519,11 @@ def parse(cls) -> Self: action="store_true", help="Disable the download coordinator (node won't download models)", ) + parser.add_argument( + "--no-peer-download", + action="store_true", + help="Disable peer-to-peer model downloads (each node downloads from HuggingFace independently)", + ) parser.add_argument( "--offline", action="store_true", diff --git a/src/exo/shared/constants.py b/src/exo/shared/constants.py index ad76a88ffb..f823195798 100644 --- a/src/exo/shared/constants.py +++ b/src/exo/shared/constants.py @@ -100,3 +100,6 @@ def _parse_colon_dirs(env_var: str) -> tuple[Path, ...]: EXO_MAX_CONCURRENT_REQUESTS = int(os.getenv("EXO_MAX_CONCURRENT_REQUESTS", "8")) EXO_MAX_INSTANCE_RETRIES = 5 + +# Peer-to-peer model download server port (one above default API port) +EXO_PEER_DOWNLOAD_PORT = int(os.getenv("EXO_PEER_DOWNLOAD_PORT", "52416")) From 4740f9d4fc948dffabc3ca00b0c386228098bf20 Mon Sep 17 00:00:00 2001 From: Jordan Miller Date: Thu, 12 Mar 2026 08:20:16 -0500 Subject: [PATCH 03/29] refactor: decouple peer discovery from worker state, add link prioritization MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Addresses review feedback from @rltakashige: 1. **Decouple from worker state**: Peer discovery is now a pure function called by the Worker (which owns the state) when emitting StartDownload. Peer endpoints are embedded in the StartDownload command as `available_peers`, so the DownloadCoordinator stays self-contained and has no dependency on Worker state. 2. **Link prioritization**: Peers are sorted by connection quality — RDMA/Thunderbolt connections rank higher than socket connections, and completed downloads rank higher than ongoing ones. Architecture change: - PeerStateProvider class → discover_peers_for_model() pure function - StartDownload command gains `available_peers: list[PeerEndpoint]` field - PeerEndpoint includes `connection_type` for prioritization - Worker computes peers at emit time → DownloadCoordinator receives them - main.py no longer has worker-state coupling (PeerStateProvider removed) Co-Authored-By: Claude Opus 4.6 --- src/exo/download/coordinator.py | 11 +- src/exo/download/impl_shard_downloader.py | 7 +- src/exo/download/peer_shard_downloader.py | 60 ++++---- src/exo/download/peer_state.py | 175 +++++++++++++--------- src/exo/main.py | 44 ++---- src/exo/shared/types/commands.py | 11 ++ src/exo/worker/main.py | 11 +- 7 files changed, 180 insertions(+), 139 deletions(-) diff --git a/src/exo/download/coordinator.py b/src/exo/download/coordinator.py index 1f5a99f8ef..b5539a59c3 100644 --- a/src/exo/download/coordinator.py +++ b/src/exo/download/coordinator.py @@ -15,6 +15,7 @@ map_repo_download_progress_to_download_progress_data, resolve_existing_model, ) +from exo.download.peer_shard_downloader import PeerAwareShardDownloader from exo.download.shard_downloader import ShardDownloader from exo.shared.constants import EXO_DEFAULT_MODELS_DIR, EXO_MODELS_READ_ONLY_DIRS from exo.shared.models.model_cards import ModelCard, ModelId, get_model_cards @@ -225,7 +226,15 @@ async def _command_processor(self) -> None: continue match cmd.command: - case StartDownload(shard_metadata=shard): + case StartDownload(shard_metadata=shard, available_peers=peers): + # Pass peer endpoints to the shard downloader if it supports it + if isinstance(self.shard_downloader, PeerAwareShardDownloader): + self.shard_downloader.set_available_peers(peers) + elif hasattr(self.shard_downloader, "shard_downloader") and isinstance( + self.shard_downloader.shard_downloader, PeerAwareShardDownloader # type: ignore[union-attr] + ): + # Unwrap SingletonShardDownloader + self.shard_downloader.shard_downloader.set_available_peers(peers) # type: ignore[union-attr] await self._start_download(shard) case DeleteDownload(model_id=model_id): await self._delete_download(model_id) diff --git a/src/exo/download/impl_shard_downloader.py b/src/exo/download/impl_shard_downloader.py index 1cb304c66d..4db0f1f36c 100644 --- a/src/exo/download/impl_shard_downloader.py +++ b/src/exo/download/impl_shard_downloader.py @@ -11,7 +11,6 @@ download_shard, ) from exo.download.peer_shard_downloader import PeerAwareShardDownloader -from exo.download.peer_state import PeerStateProvider from exo.download.shard_downloader import ShardDownloader from exo.shared.models.model_cards import ( ModelCard, @@ -29,13 +28,13 @@ def exo_shard_downloader( max_parallel_downloads: int = 8, offline: bool = False, - peer_state_provider: PeerStateProvider | None = None, + peer_download_enabled: bool = False, ) -> ShardDownloader: inner: ShardDownloader = ResumableShardDownloader( max_parallel_downloads, offline=offline ) - if peer_state_provider is not None: - inner = PeerAwareShardDownloader(inner, peer_state_provider) + if peer_download_enabled: + inner = PeerAwareShardDownloader(inner) return SingletonShardDownloader(inner) diff --git a/src/exo/download/peer_shard_downloader.py b/src/exo/download/peer_shard_downloader.py index ee0c3404e2..4b5a71db34 100644 --- a/src/exo/download/peer_shard_downloader.py +++ b/src/exo/download/peer_shard_downloader.py @@ -1,9 +1,12 @@ """Peer-aware shard downloader that tries LAN peers before HuggingFace. Wraps an existing ShardDownloader and adds a peer-download step: before -hitting HuggingFace, check if any peer on the LAN already has the model -(or is downloading it) and fetch from them instead. Falls back to the -inner downloader (HF) if peer download fails. +hitting HuggingFace, try peers provided in the available_peers list. +Falls back to the inner downloader (HF) if peer download fails. + +The peer list is computed by the Worker at command-emit time and passed +through the StartDownload command, keeping the download coordinator +decoupled from Worker state. """ import asyncio @@ -28,8 +31,8 @@ download_file_from_peer, get_peer_file_status, ) -from exo.download.peer_state import PeerStateProvider from exo.download.shard_downloader import ShardDownloader +from exo.shared.types.commands import PeerEndpoint from exo.shared.types.memory import Memory from exo.shared.types.worker.downloads import RepoFileDownloadProgress from exo.shared.types.worker.shards import ShardMetadata @@ -39,21 +42,26 @@ class PeerAwareShardDownloader(ShardDownloader): """ShardDownloader that tries peer download before HuggingFace. Decorates an inner ShardDownloader (typically ResumableShardDownloader). - On ensure_shard(), checks if any cluster peer already has the model - and downloads from them over the LAN. Falls back to the inner - downloader if no peer has it or the peer transfer fails. + On ensure_shard(), if available_peers were provided, tries downloading + from them over the LAN first. Falls back to the inner downloader if + no peer has it or the transfer fails. """ - def __init__( - self, - inner: ShardDownloader, - peer_state_provider: PeerStateProvider, - ) -> None: + def __init__(self, inner: ShardDownloader) -> None: self._inner = inner - self._peer_state = peer_state_provider self._progress_callbacks: list[ Callable[[ShardMetadata, RepoDownloadProgress], Awaitable[None]] ] = [] + # Peers are set per-download by the coordinator before calling ensure_shard + self._current_peers: list[PeerEndpoint] = [] + + def set_available_peers(self, peers: list[PeerEndpoint]) -> None: + """Set the peers to try for the next ensure_shard call. + + Called by DownloadCoordinator before triggering a download, based + on the peers embedded in the StartDownload command. + """ + self._current_peers = peers def on_progress( self, @@ -70,21 +78,21 @@ async def ensure_shard( model_id = shard.model_card.model_id normalized = model_id.normalize() + peers = self._current_peers + self._current_peers = [] # Reset after consumption - # Check if any peer has this model - peers = self._peer_state.get_peers_for_model(normalized) if not peers: - logger.debug(f"No peers have {model_id}, downloading from HuggingFace") + logger.debug(f"No peers available for {model_id}, downloading from HuggingFace") return await self._inner.ensure_shard(shard, config_only=False) - # Try each peer (completed peers first) + # Try each peer (already sorted by priority: RDMA first, completed first) for peer in peers: logger.info( f"Attempting peer download of {model_id} from " - f"{peer.ip} (status: {peer.status})" + f"{peer.ip}:{peer.port} (status: {peer.status}, link: {peer.connection_type})" ) result = await self._try_peer_download( - shard, peer.ip, self._peer_state.peer_download_port, normalized + shard, peer.ip, peer.port, normalized ) if result is not None: logger.info( @@ -131,7 +139,6 @@ async def _try_peer_download( skip_internet=False, ) except Exception: - # Can't get file list - fall back return None allow_patterns = await resolve_allow_patterns(shard) @@ -149,11 +156,7 @@ async def _try_peer_download( ] # Check the peer has all (or most) files we need - files_on_peer = 0 - for f in filtered_file_list: - if f.path in peer_file_map: - files_on_peer += 1 - + files_on_peer = sum(1 for f in filtered_file_list if f.path in peer_file_map) if files_on_peer == 0: logger.debug(f"Peer has no files we need for {model_id_normalized}") return None @@ -165,8 +168,6 @@ async def _try_peer_download( failed = False async def download_one(file_path: str, expected_size: int) -> bool: - """Download a single file from peer. Returns True on success.""" - def on_file_progress( curr_bytes: int, total_bytes: int, is_renamed: bool ) -> None: @@ -188,7 +189,6 @@ def on_file_progress( status="complete" if is_renamed else "in_progress", start_time=all_start_time, ) - # Fire progress callbacks progress = calculate_repo_progress( shard, shard.model_card.model_id, @@ -235,9 +235,6 @@ def on_file_progress( if peer_info and peer_info.safe_bytes > 0: tasks.append(download_one(f.path, f.size)) else: - # Peer doesn't have this file yet - this means incomplete peer - # We could still try for the files it has, but for simplicity - # fail the whole peer download if any file is missing failed = True break @@ -259,7 +256,6 @@ def on_file_progress( for cb in self._progress_callbacks: await cb(shard, final_progress) - # Return path (same as download_shard does) gguf = next( (f for f in filtered_file_list if f.path.endswith(".gguf")), None ) diff --git a/src/exo/download/peer_state.py b/src/exo/download/peer_state.py index bc87c0ce47..6f400b92a5 100644 --- a/src/exo/download/peer_state.py +++ b/src/exo/download/peer_state.py @@ -1,93 +1,126 @@ -"""Peer state provider for discovering which peers have which models. +"""Pure functions for discovering which peers have which models. -Reads from the shared State object (populated via gossipsub events) to -determine which peer nodes have completed or are in the process of -downloading a given model. Resolves peer IP addresses from the topology. +These functions are called by the Worker (which owns the State) to compute +peer availability at command-emit time. The results are embedded in the +StartDownload command so the download coordinator stays decoupled from +Worker state. """ -from dataclasses import dataclass -from typing import Callable, Literal - from loguru import logger +from exo.shared.types.commands import PeerEndpoint from exo.shared.types.common import NodeId from exo.shared.types.state import State -from exo.shared.types.topology import SocketConnection +from exo.shared.types.topology import RDMAConnection, SocketConnection from exo.shared.types.worker.downloads import ( DownloadCompleted, DownloadOngoing, ) -@dataclass(frozen=True) -class PeerInfo: - """A peer that has (or is downloading) a model.""" - - node_id: NodeId - ip: str - status: Literal["complete", "ongoing"] +def discover_peers_for_model( + node_id: NodeId, + state: State, + model_id_normalized: str, + peer_download_port: int, +) -> list[PeerEndpoint]: + """Find peers that have a specific model (complete or in-progress). + Called by the Worker when emitting a StartDownload command. Returns + peers sorted by priority: RDMA/Thunderbolt connections first, then + completed downloads before ongoing ones. -class PeerStateProvider: - """Provides information about which peers have which models. + Args: + node_id: This node's ID (excluded from results). + state: The global State object (owned by Worker). + model_id_normalized: Normalized model ID (e.g. "org--model"). + peer_download_port: Port where peers run their PeerFileServer. - Reads from the Worker's shared State to find peers and resolve their - network addresses from the topology graph. + Returns: + List of PeerEndpoint sorted by connection quality and completeness. """ + peers: list[PeerEndpoint] = [] - def __init__( - self, - node_id: NodeId, - state_accessor: Callable[[], State], - peer_download_port: int, - ) -> None: - self.node_id = node_id - self._state_accessor = state_accessor - self.peer_download_port = peer_download_port - - def get_peers_for_model(self, model_id: str) -> list[PeerInfo]: - """Find peers that have a specific model (complete or in-progress). - - Returns peers sorted by completeness (completed first, then ongoing). - Excludes self. - """ - state = self._state_accessor() - peers: list[PeerInfo] = [] - - # Check download status across all nodes - for peer_node_id, download_list in state.downloads.items(): - if peer_node_id == self.node_id: - continue + for peer_node_id, download_list in state.downloads.items(): + if peer_node_id == node_id: + continue - for dl in download_list: - dl_model_id = dl.shard_metadata.model_card.model_id - if dl_model_id.normalize() != model_id: - continue + for dl in download_list: + dl_model_id = dl.shard_metadata.model_card.model_id + if dl_model_id.normalize() != model_id_normalized: + continue - if isinstance(dl, DownloadCompleted): - status: Literal["complete", "ongoing"] = "complete" - elif isinstance(dl, DownloadOngoing): - status = "ongoing" - else: - continue + if isinstance(dl, DownloadCompleted): + status = "complete" + elif isinstance(dl, DownloadOngoing): + status = "ongoing" + else: + continue - # Resolve IP from topology - ip = self._resolve_peer_ip(peer_node_id, state) + # Resolve IP and connection type from topology + endpoint = _resolve_peer_endpoint( + node_id, peer_node_id, state, peer_download_port, status + ) + if endpoint: + peers.append(endpoint) + + # Sort by priority: + # 1. RDMA/Thunderbolt connections first (lower latency, higher bandwidth) + # 2. Completed downloads before ongoing ones + peers.sort( + key=lambda p: ( + 0 if p.connection_type == "rdma" else 1, + 0 if p.status == "complete" else 1, + ) + ) + return peers + + +def _resolve_peer_endpoint( + node_id: NodeId, + peer_node_id: NodeId, + state: State, + peer_download_port: int, + status: str, +) -> PeerEndpoint | None: + """Resolve a peer's IP address and connection type from the topology.""" + try: + # Check for RDMA connections first (highest priority) + for conn in state.topology.out_edges(node_id): + if conn.sink != peer_node_id: + continue + if isinstance(conn.edge, RDMAConnection): + # RDMA peer — still need IP from a socket connection + ip = _find_socket_ip(node_id, peer_node_id, state) if ip: - peers.append(PeerInfo(node_id=peer_node_id, ip=ip, status=status)) - - # Sort: completed peers first - peers.sort(key=lambda p: 0 if p.status == "complete" else 1) - return peers - - def _resolve_peer_ip(self, peer_node_id: NodeId, state: State) -> str | None: - """Resolve a peer's IP address from the topology graph.""" - try: - for conn in state.topology.out_edges(self.node_id): - if conn.sink == peer_node_id and isinstance( - conn.edge, SocketConnection - ): - return conn.edge.sink_multiaddr.ip_address - except Exception as e: - logger.debug(f"Could not resolve IP for peer {peer_node_id}: {e}") - return None + return PeerEndpoint( + node_id=peer_node_id, + ip=ip, + port=peer_download_port, + status=status, + connection_type="rdma", + ) + elif isinstance(conn.edge, SocketConnection): + return PeerEndpoint( + node_id=peer_node_id, + ip=conn.edge.sink_multiaddr.ip_address, + port=peer_download_port, + status=status, + connection_type="socket", + ) + except Exception as e: + logger.debug(f"Could not resolve endpoint for peer {peer_node_id}: {e}") + return None + + +def _find_socket_ip( + node_id: NodeId, peer_node_id: NodeId, state: State +) -> str | None: + """Find a socket connection IP for a peer (used as fallback for RDMA peers).""" + try: + for conn in state.topology.out_edges(node_id): + if conn.sink == peer_node_id and isinstance(conn.edge, SocketConnection): + return conn.edge.sink_multiaddr.ip_address + except Exception: + pass + return None diff --git a/src/exo/main.py b/src/exo/main.py index b43b7cd0aa..861afac01b 100644 --- a/src/exo/main.py +++ b/src/exo/main.py @@ -17,7 +17,6 @@ from exo.download.coordinator import DownloadCoordinator from exo.download.impl_shard_downloader import exo_shard_downloader from exo.download.peer_file_server import PeerFileServer -from exo.download.peer_state import PeerStateProvider from exo.master.main import Master from exo.routing.event_router import EventRouter from exo.routing.router import Router, get_node_id_keypair @@ -47,7 +46,6 @@ class Node: offline: bool _api_port: int peer_file_server: PeerFileServer | None = None - peer_state_provider: PeerStateProvider | None = None _tg: TaskGroup = field(init=False, default_factory=TaskGroup) @classmethod @@ -75,12 +73,8 @@ async def create(cls, args: "Args") -> Self: logger.info(f"Starting node {node_id}") - # Peer download components are created below, after Worker (needs state access) peer_file_server: PeerFileServer | None = None - peer_state_provider: PeerStateProvider | None = None - - # DownloadCoordinator is also created below, after Worker - download_coordinator: DownloadCoordinator | None = None + peer_download_enabled = not args.no_peer_download and not args.no_downloads if args.spawn_api: api = API( @@ -106,30 +100,27 @@ async def create(cls, args: "Args") -> Self: else: worker = None - # Create peer download components and DownloadCoordinator - # (after Worker, since PeerStateProvider needs access to worker state) + # Create peer file server and download coordinator + if peer_download_enabled: + peer_file_server = PeerFileServer( + host="0.0.0.0", + port=EXO_PEER_DOWNLOAD_PORT, + models_dir=EXO_MODELS_DIR, + ) + if not args.no_downloads: - if not args.no_peer_download and worker is not None: - peer_file_server = PeerFileServer( - host="0.0.0.0", - port=EXO_PEER_DOWNLOAD_PORT, - models_dir=EXO_MODELS_DIR, - ) - peer_state_provider = PeerStateProvider( - node_id=node_id, - state_accessor=lambda: worker.state, - peer_download_port=EXO_PEER_DOWNLOAD_PORT, - ) - download_coordinator = DownloadCoordinator( + download_coordinator: DownloadCoordinator | None = DownloadCoordinator( node_id, exo_shard_downloader( offline=args.offline, - peer_state_provider=peer_state_provider, + peer_download_enabled=peer_download_enabled, ), event_sender=event_router.sender(), download_command_receiver=router.receiver(topics.DOWNLOAD_COMMANDS), offline=args.offline, ) + else: + download_coordinator = None # We start every node with a master master = Master( @@ -169,7 +160,6 @@ async def create(cls, args: "Args") -> Self: args.offline, args.api_port, peer_file_server, - peer_state_provider, ) logger_set_context( node_id=node_id, role="master" if args.force_master else "node" @@ -282,7 +272,7 @@ async def _elect_loop(self): self.node_id, exo_shard_downloader( offline=self.offline, - peer_state_provider=self.peer_state_provider, + peer_download_enabled=self.peer_file_server is not None, ), event_sender=self.event_router.sender(), download_command_receiver=self.router.receiver( @@ -304,12 +294,6 @@ async def _elect_loop(self): ), api_port=self._api_port, ) - # Update peer state provider to reference the new worker - if self.peer_state_provider is not None: - new_worker = self.worker - self.peer_state_provider._state_accessor = ( - lambda: new_worker.state - ) self._tg.start_soon(self.worker.run) if self.api: self.api.reset(result.won_clock, self.event_router.receiver()) diff --git a/src/exo/shared/types/commands.py b/src/exo/shared/types/commands.py index 62f73ac399..c05002f231 100644 --- a/src/exo/shared/types/commands.py +++ b/src/exo/shared/types/commands.py @@ -68,9 +68,20 @@ class RequestEventLog(BaseCommand): since_idx: int +class PeerEndpoint(CamelCaseModel): + """A peer node that has (or is downloading) a model, with its network address.""" + + node_id: NodeId + ip: str + port: int + status: str = "complete" # "complete" or "ongoing" + connection_type: str = "socket" # "rdma" or "socket" + + class StartDownload(BaseCommand): target_node_id: NodeId shard_metadata: ShardMetadata + available_peers: list[PeerEndpoint] = Field(default_factory=list) class DeleteDownload(BaseCommand): diff --git a/src/exo/worker/main.py b/src/exo/worker/main.py index b35f946aac..12054fc303 100644 --- a/src/exo/worker/main.py +++ b/src/exo/worker/main.py @@ -8,8 +8,9 @@ from exo.api.types import ImageEditsTaskParams from exo.download.download_utils import is_read_only_model_dir, resolve_existing_model +from exo.download.peer_state import discover_peers_for_model from exo.shared.apply import apply -from exo.shared.constants import EXO_MAX_INSTANCE_RETRIES +from exo.shared.constants import EXO_MAX_INSTANCE_RETRIES, EXO_PEER_DOWNLOAD_PORT from exo.shared.models.model_cards import ModelId, add_to_card_cache, delete_custom_card from exo.shared.types.chunks import InputImageChunk from exo.shared.types.commands import ( @@ -252,12 +253,20 @@ async def plan_step(self): ) ) else: + # Discover peers that already have this model + peers = discover_peers_for_model( + self.node_id, + self.state, + shard.model_card.model_id.normalize(), + EXO_PEER_DOWNLOAD_PORT, + ) await self.download_command_sender.send( ForwarderDownloadCommand( origin=self._system_id, command=StartDownload( target_node_id=self.node_id, shard_metadata=shard, + available_peers=peers, ), ) ) From 66e0c35a7347f1ce9f8994aae3e9498fb0d35291 Mon Sep 17 00:00:00 2001 From: Adam Durham Date: Sun, 26 Apr 2026 12:30:20 -0500 Subject: [PATCH 04/29] feat: EXO_KV_CACHE_BITS env var + step=16384 to keep QuantizedKVCache usable at long context MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Two small changes that pair in practice: 1. **Env-var override for `KV_CACHE_BITS`.** The constant in `worker/engines/mlx/constants.py` already supports None / int but is compile-time only (the maintainer comment above the constants block asks "Do we want so many constants? I think we want a lot of these as parameters?"). Reading `EXO_KV_CACHE_BITS` at module import lets operators flip 4-bit KV on a deployment without editing source. 2. **`step = 16384` on QuantizedKVCache when `KV_CACHE_BITS` is set.** `QuantizedKVCache`'s default step (256) forces a `mx.concatenate` expansion every 256 tokens during prefill — at 50K context that is ~195 reallocations, which fragments Metal memory and OOMs well before the nominal cache size limit. Matching the existing `KVCache` step=16384 lets the cache pre-allocate and write in-place. Also extends to the `model.make_cache()` branch so models that ship their own cache layout (e.g. DeepSeek-V3 with mixed KV + DeltaNet caches) honor `KV_CACHE_BITS` by replacing only their plain `KVCache` entries with `QuantizedKVCache`, leaving `ArraysCache` and other types untouched. Together these unblock 4-bit KV at long context: ~100 KB/tok → ~48 KB/tok, roughly doubling the context ceiling on a fixed memory budget. Default behavior is unchanged when the env var is unset. Co-Authored-By: Claude Opus 4.7 (1M context) --- src/exo/worker/engines/mlx/cache.py | 34 +++++++++++++++++++++++-- src/exo/worker/engines/mlx/constants.py | 8 +++++- 2 files changed, 39 insertions(+), 3 deletions(-) diff --git a/src/exo/worker/engines/mlx/cache.py b/src/exo/worker/engines/mlx/cache.py index 7cdcc77fbe..0cf4fbf92e 100644 --- a/src/exo/worker/engines/mlx/cache.py +++ b/src/exo/worker/engines/mlx/cache.py @@ -563,8 +563,38 @@ def make_kv_cache( assert hasattr(model, "layers") if hasattr(model, "make_cache"): - logger.info("Using MLX LM's make cache") - return model.make_cache() # type: ignore + caches: list[ + KVCache | RotatingKVCache | QuantizedKVCache | ArraysCache | CacheList + ] = list(model.make_cache()) # pyright: ignore[reportUnknownArgumentType, reportUnknownMemberType] + if KV_CACHE_BITS is not None: + # Honor KV_CACHE_BITS even when the model provides its own + # make_cache(). Replace plain KVCache entries with + # QuantizedKVCache; leave ArraysCache (DeltaNet/SSM) and other + # cache types alone since they don't support quantization. + quantized = 0 + for i, c in enumerate(caches): + if isinstance(c, KVCache): + qc = QuantizedKVCache( + group_size=CACHE_GROUP_SIZE, bits=KV_CACHE_BITS + ) + qc.step = 16384 + caches[i] = qc + quantized += 1 + logger.info( + f"Using quantized KV cache " + f"(bits={KV_CACHE_BITS}, group_size={CACHE_GROUP_SIZE}) " + f"for {quantized}/{len(caches)} layers" + ) + else: + logger.info("Using MLX LM's make cache") + # Increase KVCache step size to reduce Metal allocator + # fragmentation. Default step=256 causes a mx.concatenate + # expansion every prefill chunk; a larger step lets the cache + # pre-allocate and write in-place for most of the prefill. + for c in caches: + if isinstance(c, KVCache): + c.step = 16384 + return caches if max_kv_size is None: if KV_CACHE_BITS is None: diff --git a/src/exo/worker/engines/mlx/constants.py b/src/exo/worker/engines/mlx/constants.py index 86a663e424..2d2805fe8c 100644 --- a/src/exo/worker/engines/mlx/constants.py +++ b/src/exo/worker/engines/mlx/constants.py @@ -1,3 +1,5 @@ +import os + # TODO: Do we want so many constants? # I think we want a lot of these as parameters? @@ -9,7 +11,11 @@ KEEP_KV_SIZE: int | None = 1600 QUANTIZE_MODEL_MODE: str | None = "affine" CACHE_GROUP_SIZE: int = 64 -KV_CACHE_BITS: int | None = None +KV_CACHE_BITS: int | None = ( + int(os.environ["EXO_KV_CACHE_BITS"]) + if os.environ.get("EXO_KV_CACHE_BITS") + else None +) DEFAULT_TOP_LOGPROBS: int = 5 From 8996e7d0757912cf3dc2ef1bccc563863b9f28f8 Mon Sep 17 00:00:00 2001 From: Adam Durham Date: Sun, 26 Apr 2026 13:18:58 -0500 Subject: [PATCH 05/29] fix: skip KV cache quantization in single-node BatchGenerator mode MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Single-node inference crashed with: ValueError: does not yet support batching with history mlx-lm's BatchGenerator calls _merge_caches on every step — even when there's only one prompt in flight — and that helper requires every layer's cache to implement .merge(). QuantizedKVCache has no merge implementation, so any single-node inference with EXO_KV_CACHE_BITS set crashes on the first real request. The PP mode stayed working because it runs in pipeline-parallel mode which uses a different inference path that doesn't go through _merge_caches. Fix: only build QuantizedKVCache when the model is actually running in PP mode (detected by PipelineFirstLayer/PipelineLastLayer wrappers). Single-node falls back to vanilla KVCache, and logs that EXO_KV_CACHE_BITS is being ignored so the operator can see what's happening. Co-Authored-By: Claude Opus 4.6 (1M context) --- src/exo/worker/engines/mlx/cache.py | 41 +++++++++++++++++++++++++++-- 1 file changed, 39 insertions(+), 2 deletions(-) diff --git a/src/exo/worker/engines/mlx/cache.py b/src/exo/worker/engines/mlx/cache.py index 0cf4fbf92e..53cd274874 100644 --- a/src/exo/worker/engines/mlx/cache.py +++ b/src/exo/worker/engines/mlx/cache.py @@ -557,6 +557,37 @@ def get_memory_used_percentage() -> float: return float(mem.percent / 100) +def _model_is_pipeline_parallel(model: Model) -> bool: + """True iff the model has pipeline-parallel layer wrappers installed. + + Only the PP path is safe to combine with QuantizedKVCache right now: + the single-node BatchGenerator code path in mlx-lm calls + ``_merge_caches`` on every step (even for a single in-flight request), + and QuantizedKVCache does not implement ``merge``. Attempting to use + a quantized cache in that path crashes with:: + + does not yet + support batching with history + + Detecting PP mode by layer type is cheap and avoids threading the + distributed group through every cache call site. + """ + try: + from exo.worker.engines.mlx.auto_parallel import ( + PipelineFirstLayer, + PipelineLastLayer, + ) + except Exception: + return False + layers = getattr(model, "layers", None) + if layers is None: + return False + for layer in layers: # type: ignore[reportUnknownVariableType] + if isinstance(layer, (PipelineFirstLayer, PipelineLastLayer)): + return True + return False + + def make_kv_cache( model: Model, max_kv_size: int | None = None, keep: int = 0 ) -> KVCacheType: @@ -597,8 +628,14 @@ def make_kv_cache( return caches if max_kv_size is None: - if KV_CACHE_BITS is None: - logger.info("Using default KV cache") + if KV_CACHE_BITS is None or not _model_is_pipeline_parallel(model): + if KV_CACHE_BITS is not None: + logger.info( + f"EXO_KV_CACHE_BITS={KV_CACHE_BITS} ignored in single-node mode " + f"(QuantizedKVCache has no merge() support, required by BatchGenerator)" + ) + else: + logger.info("Using default KV cache") return [KVCache() for _ in model.layers] else: logger.info("Using quantized KV cache") From 1cc0be2e97cc12cb6a4e9d96877bcb2f2e6e1559 Mon Sep 17 00:00:00 2001 From: Alex Cheema Date: Mon, 23 Feb 2026 13:17:52 -0800 Subject: [PATCH 06/29] feat: add --trust-remote-code CLI flag for custom model tokenizers Some custom models (e.g. Kimi) require trust_remote_code=True to load their tokenizers. This adds an opt-in CLI flag that sets an env var read by runner subprocesses, following the same pattern as --fast-synch. The flag is intentionally CLI-only (not API-accessible) to prevent remote code execution attacks via the API. Also changes the default TRUST_REMOTE_CODE constant from True to False, making remote code execution fully opt-in. Co-Authored-By: Claude Opus 4.6 --- src/exo/main.py | 13 +++++++++++++ src/exo/worker/engines/mlx/constants.py | 4 ++-- src/exo/worker/engines/mlx/utils_mlx.py | 6 +++++- 3 files changed, 20 insertions(+), 3 deletions(-) diff --git a/src/exo/main.py b/src/exo/main.py index 861afac01b..fcea355d2c 100644 --- a/src/exo/main.py +++ b/src/exo/main.py @@ -407,6 +407,13 @@ def main(): os.environ["EXO_NO_BATCH"] = "1" logger.info("Continuous batching disabled (--no-batch)") + # Set trust_remote_code override env var for runner subprocesses + if args.trust_remote_code: + os.environ["EXO_TRUST_REMOTE_CODE"] = "1" + logger.warning( + "--trust-remote-code enabled: models may execute arbitrary code during loading" + ) + # Set FAST_SYNCH override env var for runner subprocesses if args.fast_synch is True: os.environ["EXO_FAST_SYNCH"] = "true" @@ -457,6 +464,7 @@ class Args(FrozenModel): fast_synch: bool | None = None # None = auto, True = force on, False = force off bootstrap_peers: list[str] = [] libp2p_port: int + trust_remote_code: bool = False @classmethod def parse(cls) -> Self: @@ -519,6 +527,11 @@ def parse(cls) -> Self: action="store_true", help="Disable continuous batching, use sequential generation", ) + parser.add_argument( + "--trust-remote-code", + action="store_true", + help="Allow models to execute custom code during tokenizer loading (security-sensitive, CLI-only)", + ) parser.add_argument( "--bootstrap-peers", type=lambda s: [p for p in s.split(",") if p], diff --git a/src/exo/worker/engines/mlx/constants.py b/src/exo/worker/engines/mlx/constants.py index 2d2805fe8c..16504884fc 100644 --- a/src/exo/worker/engines/mlx/constants.py +++ b/src/exo/worker/engines/mlx/constants.py @@ -19,5 +19,5 @@ DEFAULT_TOP_LOGPROBS: int = 5 -# TODO: We should really make this opt-in, but Kimi requires trust_remote_code=True -TRUST_REMOTE_CODE: bool = True +# Opt-in via --trust-remote-code CLI flag; default is False for security. +TRUST_REMOTE_CODE: bool = False diff --git a/src/exo/worker/engines/mlx/utils_mlx.py b/src/exo/worker/engines/mlx/utils_mlx.py index 463ad1e91d..1eeb99d404 100644 --- a/src/exo/worker/engines/mlx/utils_mlx.py +++ b/src/exo/worker/engines/mlx/utils_mlx.py @@ -480,10 +480,14 @@ def shard_and_load( def get_tokenizer(model_path: Path, shard_metadata: ShardMetadata) -> TokenizerWrapper: """Load tokenizer for a model shard. Delegates to load_tokenizer_for_model_id.""" + trust_remote_code = ( + shard_metadata.model_card.trust_remote_code + or os.environ.get("EXO_TRUST_REMOTE_CODE") == "1" + ) return load_tokenizer_for_model_id( shard_metadata.model_card.model_id, model_path, - trust_remote_code=shard_metadata.model_card.trust_remote_code, + trust_remote_code=trust_remote_code, ) From 392bc6cf3ccfa60909b4b9ed3c998be0ccdcb474 Mon Sep 17 00:00:00 2001 From: Alex Cheema Date: Mon, 23 Feb 2026 13:21:56 -0800 Subject: [PATCH 07/29] fix: keep TRUST_REMOTE_CODE=True for built-in models The constant is the default for built-in models with known model cards, which are trusted. Custom models added via API already default to trust_remote_code=False in ModelCard.fetch_from_hf(). The CLI flag overrides custom models only. Co-Authored-By: Claude Opus 4.6 --- src/exo/worker/engines/mlx/constants.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/exo/worker/engines/mlx/constants.py b/src/exo/worker/engines/mlx/constants.py index 16504884fc..c44e93e750 100644 --- a/src/exo/worker/engines/mlx/constants.py +++ b/src/exo/worker/engines/mlx/constants.py @@ -19,5 +19,6 @@ DEFAULT_TOP_LOGPROBS: int = 5 -# Opt-in via --trust-remote-code CLI flag; default is False for security. -TRUST_REMOTE_CODE: bool = False +# True for built-in models with known model cards; custom models added via API default to False +# and can be overridden with the --trust-remote-code CLI flag. +TRUST_REMOTE_CODE: bool = True From a8ea158d011b749c5692b03dd1f2c501492c14df Mon Sep 17 00:00:00 2001 From: jw-wcv <101585096+jw-wcv@users.noreply.github.com> Date: Wed, 6 May 2026 23:39:22 -0700 Subject: [PATCH 08/29] Reconcile worker instance backoff from state --- src/exo/utils/keyed_backoff.py | 4 +++ src/exo/utils/tests/test_keyed_backoff.py | 13 +++++++ src/exo/worker/main.py | 12 +++++++ .../unittests/test_worker_instance_backoff.py | 36 +++++++++++++++++++ 4 files changed, 65 insertions(+) create mode 100644 src/exo/utils/tests/test_keyed_backoff.py create mode 100644 src/exo/worker/tests/unittests/test_worker_instance_backoff.py diff --git a/src/exo/utils/keyed_backoff.py b/src/exo/utils/keyed_backoff.py index 4d7c9a66ed..a95fe5c5f7 100644 --- a/src/exo/utils/keyed_backoff.py +++ b/src/exo/utils/keyed_backoff.py @@ -29,6 +29,10 @@ def attempts(self, key: K) -> int: """Return the number of recorded attempts for a key.""" return self._attempts.get(key, 0) + def tracked_keys(self) -> set[K]: + """Return keys that currently have recorded backoff state.""" + return set(self._attempts) | set(self._last_time) + def reset(self, key: K) -> None: """Reset backoff state for a key (e.g., on success).""" self._attempts.pop(key, None) diff --git a/src/exo/utils/tests/test_keyed_backoff.py b/src/exo/utils/tests/test_keyed_backoff.py new file mode 100644 index 0000000000..b592a4fabd --- /dev/null +++ b/src/exo/utils/tests/test_keyed_backoff.py @@ -0,0 +1,13 @@ +from exo.utils.keyed_backoff import KeyedBackoff + + +def test_tracked_keys_reports_and_resets_backoff_state() -> None: + backoff = KeyedBackoff[str]() + + backoff.record_attempt("instance-a") + + assert backoff.tracked_keys() == {"instance-a"} + + backoff.reset("instance-a") + + assert backoff.tracked_keys() == set() diff --git a/src/exo/worker/main.py b/src/exo/worker/main.py index 12054fc303..20c861e007 100644 --- a/src/exo/worker/main.py +++ b/src/exo/worker/main.py @@ -110,6 +110,7 @@ async def run(self): tg.start_soon(self._forward_info, info_recv) tg.start_soon(self.plan_step) tg.start_soon(self._event_applier) + tg.start_soon(self._reconcile_instance_backoff) tg.start_soon(self._poll_connection_updates) finally: # Actual shutdown code - waits for all tasks to complete before executing. @@ -180,6 +181,17 @@ async def _event_applier(self): if isinstance(event, CustomModelCardDeleted): await delete_custom_card(event.model_id) + async def _reconcile_instance_backoff(self) -> None: + while True: + await anyio.sleep(1) + self._reconcile_instance_backoff_once() + + def _reconcile_instance_backoff_once(self) -> None: + live_instances = set(self.state.instances) + for instance_id in self._instance_backoff.tracked_keys(): + if instance_id not in live_instances: + self._instance_backoff.reset(instance_id) + async def plan_step(self): while True: await anyio.sleep(0.1) diff --git a/src/exo/worker/tests/unittests/test_worker_instance_backoff.py b/src/exo/worker/tests/unittests/test_worker_instance_backoff.py new file mode 100644 index 0000000000..b0052c1eb7 --- /dev/null +++ b/src/exo/worker/tests/unittests/test_worker_instance_backoff.py @@ -0,0 +1,36 @@ +# pyright: reportPrivateUsage=false + +from exo.shared.types.common import ModelId, NodeId +from exo.shared.types.state import State +from exo.shared.types.worker.instances import InstanceId, MlxRingInstance +from exo.shared.types.worker.runners import ShardAssignments +from exo.utils.keyed_backoff import KeyedBackoff +from exo.worker.main import Worker + + +def _make_instance(instance_id: InstanceId) -> MlxRingInstance: + return MlxRingInstance( + instance_id=instance_id, + shard_assignments=ShardAssignments( + model_id=ModelId("test-model"), + node_to_runner={}, + runner_to_shard={}, + ), + hosts_by_node={NodeId("node-1"): []}, + ephemeral_port=1, + ) + + +def test_worker_reconciles_instance_backoff_from_state() -> None: + live_instance_id = InstanceId("inst-live") + deleted_instance_id = InstanceId("inst-deleted") + worker = object.__new__(Worker) + worker.state = State(instances={live_instance_id: _make_instance(live_instance_id)}) + worker._instance_backoff = KeyedBackoff[InstanceId]() + worker._instance_backoff.record_attempt(live_instance_id) + worker._instance_backoff.record_attempt(deleted_instance_id) + + worker._reconcile_instance_backoff_once() + + assert worker._instance_backoff.attempts(live_instance_id) == 1 + assert worker._instance_backoff.attempts(deleted_instance_id) == 0 From 0955b1cd18289130f96255403d60a9adb8d62577 Mon Sep 17 00:00:00 2001 From: jw-wcv <101585096+jw-wcv@users.noreply.github.com> Date: Wed, 6 May 2026 23:39:42 -0700 Subject: [PATCH 09/29] Tune cluster liveness polling cadence --- src/exo/master/main.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/exo/master/main.py b/src/exo/master/main.py index 888c39e4c8..50c23fa5c9 100644 --- a/src/exo/master/main.py +++ b/src/exo/master/main.py @@ -481,6 +481,9 @@ async def _command_processor(self) -> None: # These plan loops are the cracks showing in our event sourcing architecture - more things could be commands async def _plan(self) -> None: + node_inactivity_timeout = timedelta(seconds=5) + tick_interval_seconds = 1.0 + while True: # kill broken instances connected_node_ids = set(self.state.topology.list_nodes()) @@ -503,7 +506,7 @@ async def _plan(self) -> None: # time out dead nodes for node_id, time in self.state.last_seen.items(): now = datetime.now(tz=timezone.utc) - if now - time > timedelta(seconds=30): + if now - time > node_inactivity_timeout: impacted_instances = [ str(instance_id) for instance_id, instance in self.state.instances.items() @@ -520,7 +523,7 @@ async def _plan(self) -> None: ) await self.event_sender.send(NodeTimedOut(node_id=node_id)) - await anyio.sleep(10) + await anyio.sleep(tick_interval_seconds) async def _event_processor(self) -> None: with self.local_event_receiver as local_events: From d18c00ba06f19dadc5163b415d249e61462117fd Mon Sep 17 00:00:00 2001 From: jw-wcv <101585096+jw-wcv@users.noreply.github.com> Date: Wed, 6 May 2026 23:42:03 -0700 Subject: [PATCH 10/29] Gate RDMA placement on rdma_ctl state --- src/exo/api/main.py | 2 + src/exo/master/main.py | 1 + src/exo/master/placement.py | 13 +- src/exo/master/tests/test_placement.py | 140 ++++++++++- src/exo/shared/apply.py | 16 ++ .../test_apply/test_apply_rdma_gating.py | 231 ++++++++++++++++++ src/exo/shared/topology.py | 16 ++ 7 files changed, 416 insertions(+), 3 deletions(-) create mode 100644 src/exo/shared/tests/test_apply/test_apply_rdma_gating.py diff --git a/src/exo/api/main.py b/src/exo/api/main.py index e7b61bf0fb..4ac9f605e2 100644 --- a/src/exo/api/main.py +++ b/src/exo/api/main.py @@ -730,6 +730,7 @@ async def get_placement( topology=self.state.topology, current_instances=self.state.instances, download_status=self.state.downloads, + node_rdma_ctl=self.state.node_rdma_ctl, ) except ValueError as exc: raise HTTPException(status_code=400, detail=str(exc)) from exc @@ -794,6 +795,7 @@ async def get_placement_previews( allowed_nodes=allowed_nodes, allow_single_node_total_memory=allowed_nodes is not None, download_status=self.state.downloads, + node_rdma_ctl=self.state.node_rdma_ctl, ) except ValueError as exc: if (model_card.model_id, sharding, instance_meta, 0) not in seen: diff --git a/src/exo/master/main.py b/src/exo/master/main.py index 50c23fa5c9..83d2871fb2 100644 --- a/src/exo/master/main.py +++ b/src/exo/master/main.py @@ -392,6 +392,7 @@ async def _command_processor(self) -> None: self.state.node_memory, self.state.node_network, download_status=self.state.downloads, + node_rdma_ctl=self.state.node_rdma_ctl, ) transition_events = get_transition_events( self.state.instances, placement, self.state.tasks diff --git a/src/exo/master/placement.py b/src/exo/master/placement.py index f2d4066a9c..4cb0b7d646 100644 --- a/src/exo/master/placement.py +++ b/src/exo/master/placement.py @@ -36,6 +36,7 @@ MemoryUsage, NetworkInterfaceInfo, NodeNetworkInfo, + NodeRdmaCtlStatus, ) from exo.shared.types.tasks import Task, TaskId, TaskStatus from exo.shared.types.topology import SocketConnection @@ -136,6 +137,7 @@ def place_instance( allowed_nodes: set[NodeId] | None = None, allow_single_node_total_memory: bool = False, download_status: Mapping[NodeId, Sequence[DownloadProgress]] | None = None, + node_rdma_ctl: Mapping[NodeId, NodeRdmaCtlStatus] | None = None, ) -> dict[InstanceId, Instance]: sharding = command.sharding instance_meta = command.instance_meta @@ -263,9 +265,18 @@ def place_instance( ) smallest_cycles = get_smallest_cycles(cycles_with_sufficient_memory) + rdma_ctl_status = node_rdma_ctl or {} + + def _all_rdma_ctl_enabled(cycle: Cycle) -> bool: + return all( + ((status := rdma_ctl_status.get(node_id)) is not None and status.enabled) + for node_id in cycle + ) smallest_rdma_cycles = [ - cycle for cycle in smallest_cycles if topology.is_rdma_cycle(cycle) + cycle + for cycle in smallest_cycles + if topology.is_rdma_cycle(cycle) and _all_rdma_ctl_enabled(cycle) ] if instance_meta == InstanceMeta.MlxJaccl: diff --git a/src/exo/master/tests/test_placement.py b/src/exo/master/tests/test_placement.py index f102b6f2b7..cb3aaba72c 100644 --- a/src/exo/master/tests/test_placement.py +++ b/src/exo/master/tests/test_placement.py @@ -26,6 +26,7 @@ MemoryUsage, NetworkInterfaceInfo, NodeNetworkInfo, + NodeRdmaCtlStatus, ) from exo.shared.types.tasks import TaskId, TaskStatus, TextGeneration from exo.shared.types.text_generation import ( @@ -569,8 +570,21 @@ def test_tensor_rdma_backend_connectivity_matrix( min_nodes=1, ) + node_rdma_ctl = { + node_a: NodeRdmaCtlStatus(enabled=True), + node_b: NodeRdmaCtlStatus(enabled=True), + node_c: NodeRdmaCtlStatus(enabled=True), + } + # act - placements = place_instance(cic, topology, {}, node_memory, node_network) + placements = place_instance( + cic, + topology, + {}, + node_memory, + node_network, + node_rdma_ctl=node_rdma_ctl, + ) # assert assert len(placements) == 1 @@ -611,7 +625,6 @@ def test_tensor_rdma_backend_connectivity_matrix( ip_part = coordinator.split(":")[0] assert len(ip_part.split(".")) == 4 - def test_qwen3_5_tensor_auto_upgrade_requires_opt_in( monkeypatch: pytest.MonkeyPatch, ) -> None: @@ -1493,6 +1506,129 @@ def test_placement_prefers_socket_reachable_rank_zero( assert shard.device_rank == 0 +def _build_three_node_rdma_topology() -> tuple[ + Topology, NodeId, NodeId, NodeId, dict[NodeId, NodeNetworkInfo] +]: + topology = Topology() + node_a = NodeId() + node_b = NodeId() + node_c = NodeId() + + ethernet_interface = NetworkInterfaceInfo(name="en0", ip_address="10.0.0.1") + ethernet_conn = SocketConnection( + sink_multiaddr=Multiaddr(address="/ip4/10.0.0.1/tcp/8000") + ) + node_network = { + node_a: NodeNetworkInfo(interfaces=[ethernet_interface]), + node_b: NodeNetworkInfo(interfaces=[ethernet_interface]), + node_c: NodeNetworkInfo(interfaces=[ethernet_interface]), + } + + for node_id in (node_a, node_b, node_c): + topology.add_node(node_id) + + for source, sink, iface in ( + (node_a, node_b, 3), + (node_b, node_a, 3), + (node_b, node_c, 4), + (node_c, node_b, 4), + (node_a, node_c, 5), + (node_c, node_a, 5), + ): + topology.add_connection( + Connection(source=source, sink=sink, edge=create_rdma_connection(iface)) + ) + + for source, sink in ( + (node_a, node_b), + (node_b, node_c), + (node_c, node_a), + (node_a, node_c), + (node_b, node_a), + (node_c, node_b), + ): + topology.add_connection( + Connection(source=source, sink=sink, edge=ethernet_conn) + ) + + return topology, node_a, node_b, node_c, node_network + + +def test_place_mlx_jaccl_rejects_when_a_node_has_rdma_ctl_disabled( + model_card: ModelCard, +) -> None: + model_card = model_card.model_copy( + update={"n_layers": 12, "storage_size": Memory.from_bytes(1500)} + ) + topology, node_a, node_b, node_c, node_network = _build_three_node_rdma_topology() + node_memory = { + node_a: create_node_memory(500), + node_b: create_node_memory(500), + node_c: create_node_memory(500), + } + node_rdma_ctl = { + node_a: NodeRdmaCtlStatus(enabled=True), + node_b: NodeRdmaCtlStatus(enabled=True), + node_c: NodeRdmaCtlStatus(enabled=False), + } + command = PlaceInstance( + sharding=Sharding.Tensor, + instance_meta=InstanceMeta.MlxJaccl, + command_id=CommandId(), + model_card=model_card, + min_nodes=3, + ) + + with pytest.raises( + ValueError, match="Requested RDMA \\(MlxJaccl\\) but no RDMA-connected cycles" + ): + place_instance( + command, + topology, + {}, + node_memory, + node_network, + node_rdma_ctl=node_rdma_ctl, + ) + + +def test_place_mlx_jaccl_rejects_when_node_rdma_ctl_missing( + model_card: ModelCard, +) -> None: + model_card = model_card.model_copy( + update={"n_layers": 12, "storage_size": Memory.from_bytes(1500)} + ) + topology, node_a, node_b, node_c, node_network = _build_three_node_rdma_topology() + node_memory = { + node_a: create_node_memory(500), + node_b: create_node_memory(500), + node_c: create_node_memory(500), + } + node_rdma_ctl = { + node_a: NodeRdmaCtlStatus(enabled=True), + node_b: NodeRdmaCtlStatus(enabled=True), + } + command = PlaceInstance( + sharding=Sharding.Tensor, + instance_meta=InstanceMeta.MlxJaccl, + command_id=CommandId(), + model_card=model_card, + min_nodes=3, + ) + + with pytest.raises( + ValueError, match="Requested RDMA \\(MlxJaccl\\) but no RDMA-connected cycles" + ): + place_instance( + command, + topology, + {}, + node_memory, + node_network, + node_rdma_ctl=node_rdma_ctl, + ) + + def _make_task( instance_id: InstanceId, status: TaskStatus = TaskStatus.Running, diff --git a/src/exo/shared/apply.py b/src/exo/shared/apply.py index ce3f503537..278dd660ac 100644 --- a/src/exo/shared/apply.py +++ b/src/exo/shared/apply.py @@ -65,6 +65,13 @@ ) +def _is_rdma_ctl_enabled( + node_id: NodeId, node_rdma_ctl: Mapping[NodeId, NodeRdmaCtlStatus] +) -> bool: + status = node_rdma_ctl.get(node_id) + return status is not None and status.enabled + + def event_apply(event: Event, state: State) -> State: """Apply an event to state.""" match event: @@ -415,6 +422,9 @@ def apply_node_gathered_info(event: NodeGatheredInfo, state: State) -> State: for nid in state.node_thunderbolt for tb_ident in state.node_thunderbolt[nid].interfaces } + source_is_rdma_enabled = _is_rdma_ctl_enabled( + event.node_id, state.node_rdma_ctl + ) as_rdma_conns = [ Connection( source=event.node_id, @@ -427,6 +437,10 @@ def apply_node_gathered_info(event: NodeGatheredInfo, state: State) -> State: for tb_conn in info.conns if tb_conn.source_uuid in conn_map if tb_conn.sink_uuid in conn_map + if source_is_rdma_enabled + and _is_rdma_ctl_enabled( + conn_map[tb_conn.sink_uuid][0], state.node_rdma_ctl + ) ] topology.replace_all_out_rdma_connections(event.node_id, as_rdma_conns) case ThunderboltBridgeInfo(): @@ -450,6 +464,8 @@ def apply_node_gathered_info(event: NodeGatheredInfo, state: State) -> State: **state.node_rdma_ctl, event.node_id: NodeRdmaCtlStatus(enabled=info.enabled), } + if not info.enabled: + topology.remove_all_rdma_connections_touching(event.node_id) return state.model_copy(update=update) diff --git a/src/exo/shared/tests/test_apply/test_apply_rdma_gating.py b/src/exo/shared/tests/test_apply/test_apply_rdma_gating.py new file mode 100644 index 0000000000..492e3fc5ea --- /dev/null +++ b/src/exo/shared/tests/test_apply/test_apply_rdma_gating.py @@ -0,0 +1,231 @@ +from datetime import datetime, timezone + +from exo.shared.apply import apply_node_gathered_info +from exo.shared.topology import Topology +from exo.shared.types.common import NodeId +from exo.shared.types.events import NodeGatheredInfo +from exo.shared.types.profiling import ( + NodeRdmaCtlStatus, + NodeThunderboltInfo, +) +from exo.shared.types.state import State +from exo.shared.types.thunderbolt import ThunderboltConnection, ThunderboltIdentifier +from exo.shared.types.topology import RDMAConnection +from exo.utils.info_gatherer.info_gatherer import ( + MacThunderboltConnections, + RdmaCtlStatus, +) + + +def _now() -> str: + return datetime.now(timezone.utc).isoformat() + + +def _make_state_with_thunderbolt_idents( + *node_ids_and_uuids: tuple[NodeId, str, str], + rdma_ctl: dict[NodeId, NodeRdmaCtlStatus] | None = None, +) -> State: + """Build a State with Thunderbolt identifiers per node so the apply MacThunderboltConnections + case can resolve uuid -> (node, iface).""" + node_thunderbolt = { + nid: NodeThunderboltInfo( + interfaces=[ThunderboltIdentifier(rdma_interface=iface, domain_uuid=uuid)] + ) + for nid, uuid, iface in node_ids_and_uuids + } + return State( + node_thunderbolt=node_thunderbolt, + node_rdma_ctl=rdma_ctl or {}, + ) + + +def _has_rdma_edge(topology: Topology, source: NodeId, sink: NodeId) -> bool: + return any( + isinstance(edge, RDMAConnection) + for edge in topology.get_all_connections_between(source, sink) + ) + + +def test_mac_thunderbolt_connections_emits_rdma_when_both_endpoints_enabled(): + node_a = NodeId() + node_b = NodeId() + state = _make_state_with_thunderbolt_idents( + (node_a, "uuid-a", "rdma_en1"), + (node_b, "uuid-b", "rdma_en1"), + rdma_ctl={ + node_a: NodeRdmaCtlStatus(enabled=True), + node_b: NodeRdmaCtlStatus(enabled=True), + }, + ) + + event = NodeGatheredInfo( + node_id=node_a, + when=_now(), + info=MacThunderboltConnections( + conns=[ThunderboltConnection(source_uuid="uuid-a", sink_uuid="uuid-b")] + ), + ) + + new_state = apply_node_gathered_info(event, state) + + assert _has_rdma_edge(new_state.topology, node_a, node_b) + + +def test_mac_thunderbolt_connections_skips_rdma_when_source_rdma_ctl_disabled(): + node_a = NodeId() + node_b = NodeId() + state = _make_state_with_thunderbolt_idents( + (node_a, "uuid-a", "rdma_en1"), + (node_b, "uuid-b", "rdma_en1"), + rdma_ctl={ + node_a: NodeRdmaCtlStatus(enabled=False), + node_b: NodeRdmaCtlStatus(enabled=True), + }, + ) + + event = NodeGatheredInfo( + node_id=node_a, + when=_now(), + info=MacThunderboltConnections( + conns=[ThunderboltConnection(source_uuid="uuid-a", sink_uuid="uuid-b")] + ), + ) + + new_state = apply_node_gathered_info(event, state) + + assert not _has_rdma_edge(new_state.topology, node_a, node_b) + + +def test_mac_thunderbolt_connections_skips_rdma_when_sink_rdma_ctl_disabled(): + node_a = NodeId() + node_b = NodeId() + state = _make_state_with_thunderbolt_idents( + (node_a, "uuid-a", "rdma_en1"), + (node_b, "uuid-b", "rdma_en1"), + rdma_ctl={ + node_a: NodeRdmaCtlStatus(enabled=True), + node_b: NodeRdmaCtlStatus(enabled=False), + }, + ) + + event = NodeGatheredInfo( + node_id=node_a, + when=_now(), + info=MacThunderboltConnections( + conns=[ThunderboltConnection(source_uuid="uuid-a", sink_uuid="uuid-b")] + ), + ) + + new_state = apply_node_gathered_info(event, state) + + assert not _has_rdma_edge(new_state.topology, node_a, node_b) + + +def test_mac_thunderbolt_connections_skips_rdma_when_rdma_ctl_status_missing(): + """Missing rdma_ctl status defaults to not-enabled — node is RDMA-incapable.""" + node_a = NodeId() + node_b = NodeId() + state = _make_state_with_thunderbolt_idents( + (node_a, "uuid-a", "rdma_en1"), + (node_b, "uuid-b", "rdma_en1"), + rdma_ctl={ + node_a: NodeRdmaCtlStatus(enabled=True), + # node_b intentionally absent + }, + ) + + event = NodeGatheredInfo( + node_id=node_a, + when=_now(), + info=MacThunderboltConnections( + conns=[ThunderboltConnection(source_uuid="uuid-a", sink_uuid="uuid-b")] + ), + ) + + new_state = apply_node_gathered_info(event, state) + + assert not _has_rdma_edge(new_state.topology, node_a, node_b) + + +def test_rdma_ctl_status_disabled_purges_existing_rdma_edges(): + """When a node reports rdma_ctl disabled, all RDMA edges touching it must be removed.""" + node_a = NodeId() + node_b = NodeId() + + # Start with both nodes RDMA-enabled and existing RDMA edges in the topology. + state = _make_state_with_thunderbolt_idents( + (node_a, "uuid-a", "rdma_en1"), + (node_b, "uuid-b", "rdma_en1"), + rdma_ctl={ + node_a: NodeRdmaCtlStatus(enabled=True), + node_b: NodeRdmaCtlStatus(enabled=True), + }, + ) + state = apply_node_gathered_info( + NodeGatheredInfo( + node_id=node_a, + when=_now(), + info=MacThunderboltConnections( + conns=[ThunderboltConnection(source_uuid="uuid-a", sink_uuid="uuid-b")] + ), + ), + state, + ) + state = apply_node_gathered_info( + NodeGatheredInfo( + node_id=node_b, + when=_now(), + info=MacThunderboltConnections( + conns=[ThunderboltConnection(source_uuid="uuid-b", sink_uuid="uuid-a")] + ), + ), + state, + ) + assert _has_rdma_edge(state.topology, node_a, node_b) + assert _has_rdma_edge(state.topology, node_b, node_a) + + # Now node_a flips to rdma_ctl disabled — both directions of RDMA edge must drop. + state = apply_node_gathered_info( + NodeGatheredInfo( + node_id=node_a, when=_now(), info=RdmaCtlStatus(enabled=False) + ), + state, + ) + + assert not _has_rdma_edge(state.topology, node_a, node_b) + assert not _has_rdma_edge(state.topology, node_b, node_a) + assert state.node_rdma_ctl[node_a].enabled is False + + +def test_topology_remove_all_rdma_connections_touching_keeps_socket_edges(): + """Purging RDMA edges for a disabled node must not affect non-RDMA edges.""" + from exo.shared.types.multiaddr import Multiaddr + from exo.shared.types.topology import Connection, SocketConnection + + topology = Topology() + node_a = NodeId() + node_b = NodeId() + topology.add_node(node_a) + topology.add_node(node_b) + topology.add_connection( + Connection( + source=node_a, + sink=node_b, + edge=RDMAConnection( + source_rdma_iface="rdma_en1", sink_rdma_iface="rdma_en1" + ), + ) + ) + socket_edge = SocketConnection( + sink_multiaddr=Multiaddr(address="/ip4/10.0.0.1/tcp/8000") + ) + topology.add_connection(Connection(source=node_a, sink=node_b, edge=socket_edge)) + + topology.remove_all_rdma_connections_touching(node_a) + + assert not _has_rdma_edge(topology, node_a, node_b) + # Socket edge survives. + assert any( + isinstance(edge, SocketConnection) + for edge in topology.get_all_connections_between(node_a, node_b) + ) diff --git a/src/exo/shared/topology.py b/src/exo/shared/topology.py index 9d649a6f4e..5c12bc5077 100644 --- a/src/exo/shared/topology.py +++ b/src/exo/shared/topology.py @@ -169,6 +169,22 @@ def replace_all_out_rdma_connections( for conn in new_connections: self.add_connection(conn) + def remove_all_rdma_connections_touching(self, node_id: NodeId) -> None: + """Remove every incoming or outgoing RDMA edge touching node_id.""" + if node_id not in self._vertex_indices: + return + rx_idx = self._vertex_indices[node_id] + rdma_edge_idxs = [ + edge_idx + for edge_idx in ( + *self._graph.out_edge_indices(rx_idx), + *self._graph.in_edge_indices(rx_idx), + ) + if isinstance(self._graph.get_edge_data_by_index(edge_idx), RDMAConnection) + ] + for edge_idx in rdma_edge_idxs: + self._graph.remove_edge_from_index(edge_idx) + def remove_connection(self, conn: Connection) -> None: if ( conn.source not in self._vertex_indices From 47b0ccb7a9220931272beb13e82919c5afe1aa1c Mon Sep 17 00:00:00 2001 From: jw-wcv <101585096+jw-wcv@users.noreply.github.com> Date: Wed, 6 May 2026 23:44:19 -0700 Subject: [PATCH 11/29] Fix upstream port compatibility issues --- src/exo/download/peer_download.py | 68 +++++++++----------- src/exo/download/tests/test_peer_download.py | 53 +++++++-------- src/exo/master/tests/test_placement.py | 19 +++++- src/exo/shared/types/commands.py | 2 +- 4 files changed, 74 insertions(+), 68 deletions(-) diff --git a/src/exo/download/peer_download.py b/src/exo/download/peer_download.py index 1fab3657d6..a2c3326510 100644 --- a/src/exo/download/peer_download.py +++ b/src/exo/download/peer_download.py @@ -6,7 +6,6 @@ """ import asyncio -import json from dataclasses import dataclass from pathlib import Path from typing import Callable @@ -41,12 +40,11 @@ async def get_peer_file_status( try: async with aiohttp.ClientSession( timeout=aiohttp.ClientTimeout(total=timeout) - ) as session: - async with session.get(url) as r: - if r.status != 200: - return None - data = await r.json() - return [PeerFileInfo(**f) for f in data.get("files", [])] + ) as session, session.get(url) as r: + if r.status != 200: + return None + data = await r.json() + return [PeerFileInfo(**f) for f in data.get("files", [])] except Exception as e: logger.debug(f"Could not reach peer {peer_host}:{peer_port}: {e}") return None @@ -103,36 +101,32 @@ async def download_file_from_peer( got_bytes = False async with aiohttp.ClientSession( timeout=aiohttp.ClientTimeout(total=300, sock_read=60) - ) as session: - async with session.get(url, headers=headers) as r: - if r.status == 416: - # Range not satisfiable - peer doesn't have more yet - pass - elif r.status in (200, 206): - peer_complete = r.headers.get("X-Exo-Complete") == "true" - safe_bytes = int(r.headers.get("X-Exo-Safe-Bytes", "0")) - - async with aiofiles.open( - partial_path, "ab" if n_read > 0 else "wb" - ) as f: - while True: - chunk = await r.content.read(chunk_size) - if not chunk: - break - written = await f.write(chunk) - n_read += written - got_bytes = True - on_progress(n_read, expected_size, False) - elif r.status == 404: - logger.debug( - f"File {file_path} not found on peer {peer_host}" - ) - return None - else: - logger.warning( - f"Unexpected status {r.status} from peer {peer_host}" - ) - return None + ) as session, session.get(url, headers=headers) as r: + if r.status == 416: + # Range not satisfiable - peer doesn't have more yet + pass + elif r.status in (200, 206): + async with aiofiles.open( + partial_path, "ab" if n_read > 0 else "wb" + ) as f: + while True: + chunk = await r.content.read(chunk_size) + if not chunk: + break + written = await f.write(chunk) + n_read += written + got_bytes = True + on_progress(n_read, expected_size, False) + elif r.status == 404: + logger.debug( + f"File {file_path} not found on peer {peer_host}" + ) + return None + else: + logger.warning( + f"Unexpected status {r.status} from peer {peer_host}" + ) + return None # Check if we're done if n_read >= expected_size: diff --git a/src/exo/download/tests/test_peer_download.py b/src/exo/download/tests/test_peer_download.py index 04c49de786..396fcd5095 100644 --- a/src/exo/download/tests/test_peer_download.py +++ b/src/exo/download/tests/test_peer_download.py @@ -1,6 +1,5 @@ """Tests for peer-to-peer model downloading.""" -import asyncio import json from collections.abc import AsyncIterator from pathlib import Path @@ -45,13 +44,12 @@ async def test_health_check(self, peer_server: PeerFileServer) -> None: """Health endpoint should return ok.""" import aiohttp - async with aiohttp.ClientSession() as session: - async with session.get( - f"http://127.0.0.1:{peer_server.port}/health" - ) as r: - assert r.status == 200 - data = await r.json() - assert data["status"] == "ok" + async with aiohttp.ClientSession() as session, session.get( + f"http://127.0.0.1:{peer_server.port}/health" + ) as r: + assert r.status == 200 + data = await r.json() + assert data["status"] == "ok" async def test_status_empty_model(self, peer_server: PeerFileServer) -> None: """Status for non-existent model should return empty file list.""" @@ -122,14 +120,13 @@ async def test_serve_complete_file( import aiohttp - async with aiohttp.ClientSession() as session: - async with session.get( - f"http://127.0.0.1:{peer_server.port}/files/test--model/config.json" - ) as r: - assert r.status == 200 - assert r.headers["X-Exo-Complete"] == "true" - body = await r.read() - assert body == content + async with aiohttp.ClientSession() as session, session.get( + f"http://127.0.0.1:{peer_server.port}/files/test--model/config.json" + ) as r: + assert r.status == 200 + assert r.headers["X-Exo-Complete"] == "true" + body = await r.read() + assert body == content async def test_serve_with_range_request( self, peer_server: PeerFileServer, temp_models_dir: Path @@ -144,24 +141,22 @@ async def test_serve_with_range_request( import aiohttp - async with aiohttp.ClientSession() as session: - async with session.get( - f"http://127.0.0.1:{peer_server.port}/files/test--model/weights.bin", - headers={"Range": "bytes=8-"}, - ) as r: - assert r.status == 206 - body = await r.read() - assert body == b"89abcdef" + async with aiohttp.ClientSession() as session, session.get( + f"http://127.0.0.1:{peer_server.port}/files/test--model/weights.bin", + headers={"Range": "bytes=8-"}, + ) as r: + assert r.status == 206 + body = await r.read() + assert body == b"89abcdef" async def test_file_not_found(self, peer_server: PeerFileServer) -> None: """Should return 404 for missing files.""" import aiohttp - async with aiohttp.ClientSession() as session: - async with session.get( - f"http://127.0.0.1:{peer_server.port}/files/test--model/missing.bin" - ) as r: - assert r.status == 404 + async with aiohttp.ClientSession() as session, session.get( + f"http://127.0.0.1:{peer_server.port}/files/test--model/missing.bin" + ) as r: + assert r.status == 404 class TestPeerDownloadClient: diff --git a/src/exo/master/tests/test_placement.py b/src/exo/master/tests/test_placement.py index cb3aaba72c..fb2df5a1ae 100644 --- a/src/exo/master/tests/test_placement.py +++ b/src/exo/master/tests/test_placement.py @@ -664,6 +664,10 @@ def test_qwen3_5_tensor_auto_upgrade_requires_opt_in( model_card=model_card, min_nodes=2, ) + node_rdma_ctl = { + large_node: NodeRdmaCtlStatus(enabled=True), + small_node: NodeRdmaCtlStatus(enabled=True), + } placements_without_opt_in = place_instance( command, @@ -677,6 +681,7 @@ def test_qwen3_5_tensor_auto_upgrade_requires_opt_in( large_node: create_jaccl_node_network("192.168.0.1"), small_node: create_jaccl_node_network("192.168.0.2"), }, + node_rdma_ctl=node_rdma_ctl, ) instance_without_opt_in = next(iter(placements_without_opt_in.values())) large_runner_without_opt_in = ( @@ -703,6 +708,7 @@ def test_qwen3_5_tensor_auto_upgrade_requires_opt_in( large_node: create_jaccl_node_network("192.168.0.1"), small_node: create_jaccl_node_network("192.168.0.2"), }, + node_rdma_ctl=node_rdma_ctl, ) instance = next(iter(placements.values())) @@ -1017,8 +1023,19 @@ def test_jaccl_placement_uses_advertised_lan_ip_for_rdma_coordinator( model_card=model_card, min_nodes=2, ) + node_rdma_ctl = { + node_a: NodeRdmaCtlStatus(enabled=True), + node_b: NodeRdmaCtlStatus(enabled=True), + } - placements = place_instance(command, topology, {}, node_memory, node_network) + placements = place_instance( + command, + topology, + {}, + node_memory, + node_network, + node_rdma_ctl=node_rdma_ctl, + ) instance = list(placements.values())[0] assert isinstance(instance, MlxJacclInstance) diff --git a/src/exo/shared/types/commands.py b/src/exo/shared/types/commands.py index c05002f231..2088402127 100644 --- a/src/exo/shared/types/commands.py +++ b/src/exo/shared/types/commands.py @@ -68,7 +68,7 @@ class RequestEventLog(BaseCommand): since_idx: int -class PeerEndpoint(CamelCaseModel): +class PeerEndpoint(FrozenModel): """A peer node that has (or is downloading) a model, with its network address.""" node_id: NodeId From 34455fd6adc682d7d4bf749c806316197bb47953 Mon Sep 17 00:00:00 2001 From: jw-wcv <101585096+jw-wcv@users.noreply.github.com> Date: Wed, 6 May 2026 23:46:19 -0700 Subject: [PATCH 12/29] Harden peer download port for current typing --- src/exo/download/peer_download.py | 30 +++++++++++++++-- src/exo/download/peer_file_server.py | 34 +++++++++++++++----- src/exo/download/peer_shard_downloader.py | 24 +++++++++----- src/exo/download/peer_state.py | 2 +- src/exo/download/tests/test_peer_download.py | 4 ++- 5 files changed, 73 insertions(+), 21 deletions(-) diff --git a/src/exo/download/peer_download.py b/src/exo/download/peer_download.py index a2c3326510..0800a6fd18 100644 --- a/src/exo/download/peer_download.py +++ b/src/exo/download/peer_download.py @@ -8,7 +8,7 @@ import asyncio from dataclasses import dataclass from pathlib import Path -from typing import Callable +from typing import Callable, cast import aiofiles import aiofiles.os as aios @@ -26,6 +26,10 @@ class PeerFileInfo: safe_bytes: int +def _as_int(value: object) -> int: + return value if isinstance(value, int) else 0 + + async def get_peer_file_status( peer_host: str, peer_port: int, @@ -43,8 +47,28 @@ async def get_peer_file_status( ) as session, session.get(url) as r: if r.status != 200: return None - data = await r.json() - return [PeerFileInfo(**f) for f in data.get("files", [])] + data = cast(dict[str, object], await r.json()) + files = data.get("files", []) + if not isinstance(files, list): + return [] + raw_files = cast(list[object], files) + out: list[PeerFileInfo] = [] + required = {"path", "size", "complete", "safe_bytes"} + for raw_file in raw_files: + if not isinstance(raw_file, dict): + continue + file_info = cast(dict[str, object], raw_file) + if not required.issubset(file_info): + continue + out.append( + PeerFileInfo( + path=str(file_info["path"]), + size=_as_int(file_info["size"]), + complete=bool(file_info["complete"]), + safe_bytes=_as_int(file_info["safe_bytes"]), + ) + ) + return out except Exception as e: logger.debug(f"Could not reach peer {peer_host}:{peer_port}: {e}") return None diff --git a/src/exo/download/peer_file_server.py b/src/exo/download/peer_file_server.py index f36823ac27..25d602ddfd 100644 --- a/src/exo/download/peer_file_server.py +++ b/src/exo/download/peer_file_server.py @@ -11,12 +11,15 @@ import json from pathlib import Path +from typing import TypeAlias, cast import aiofiles import aiofiles.os as aios from aiohttp import web from loguru import logger +PartialMeta: TypeAlias = dict[str, int | str] + class PeerFileServer: """HTTP server that exposes local model files for peer download.""" @@ -55,7 +58,7 @@ async def _handle_status(self, request: web.Request) -> web.Response: if not await aios.path.exists(model_dir): return web.json_response({"files": []}) - files = [] + files: list[dict[str, object]] = [] for item in model_dir.iterdir(): if item.is_dir() or item.name.endswith(".partial.meta"): continue @@ -64,12 +67,14 @@ async def _handle_status(self, request: web.Request) -> web.Response: # In-progress file - read meta for safe bytes meta = await _read_partial_meta(item) if meta: + total = _meta_int(meta, "total") + safe_bytes = _meta_int(meta, "safe_bytes") files.append( { "path": item.name.removesuffix(".partial"), - "size": meta.get("total", 0), + "size": total, "complete": False, - "safe_bytes": meta.get("safe_bytes", 0), + "safe_bytes": safe_bytes, } ) else: @@ -107,11 +112,11 @@ async def _handle_file(self, request: web.Request) -> web.StreamResponse: is_complete = True elif await aios.path.exists(partial_path): meta = await _read_partial_meta(partial_path) - if not meta or meta.get("safe_bytes", 0) == 0: + if not meta or _meta_int(meta, "safe_bytes") == 0: return web.Response(status=404, text="File not available yet") serve_path = partial_path - file_size = meta.get("total", 0) - safe_bytes = meta["safe_bytes"] + file_size = _meta_int(meta, "total") + safe_bytes = _meta_int(meta, "safe_bytes") is_complete = False else: return web.Response(status=404, text="File not found") @@ -162,13 +167,26 @@ async def _handle_file(self, request: web.Request) -> web.StreamResponse: return response -async def _read_partial_meta(partial_path: Path) -> dict | None: +def _meta_int(meta: PartialMeta, key: str) -> int: + value = meta.get(key, 0) + return value if isinstance(value, int) else 0 + + +async def _read_partial_meta(partial_path: Path) -> PartialMeta | None: """Read the .partial.meta companion file for a .partial download.""" meta_path = Path(f"{partial_path}.meta") if not await aios.path.exists(meta_path): return None try: async with aiofiles.open(meta_path, "r") as f: - return json.loads(await f.read()) + data = cast(object, json.loads(await f.read())) + if not isinstance(data, dict): + return None + raw_meta = cast(dict[object, object], data) + return { + str(key): value + for key, value in raw_meta.items() + if isinstance(value, (int, str)) + } except (json.JSONDecodeError, OSError): return None diff --git a/src/exo/download/peer_shard_downloader.py b/src/exo/download/peer_shard_downloader.py index 4b5a71db34..88fcbb5333 100644 --- a/src/exo/download/peer_shard_downloader.py +++ b/src/exo/download/peer_shard_downloader.py @@ -11,20 +11,20 @@ import asyncio import time -from collections.abc import Awaitable +from collections.abc import Awaitable, Coroutine from datetime import timedelta from pathlib import Path -from typing import AsyncIterator, Callable +from typing import Any, AsyncIterator, Callable from loguru import logger from exo.download.download_utils import ( RepoDownloadProgress, calculate_repo_progress, - ensure_models_dir, fetch_file_list_with_cache, is_image_model, resolve_allow_patterns, + resolve_model_dir, ) from exo.download.huggingface_utils import filter_repo_objects from exo.download.peer_download import ( @@ -34,10 +34,18 @@ from exo.download.shard_downloader import ShardDownloader from exo.shared.types.commands import PeerEndpoint from exo.shared.types.memory import Memory -from exo.shared.types.worker.downloads import RepoFileDownloadProgress +from exo.shared.types.worker.downloads import FileListEntry, RepoFileDownloadProgress from exo.shared.types.worker.shards import ShardMetadata +async def _run_progress_callback( + callback: Callable[[ShardMetadata, RepoDownloadProgress], Awaitable[None]], + shard: ShardMetadata, + progress: RepoDownloadProgress, +) -> None: + await callback(shard, progress) + + class PeerAwareShardDownloader(ShardDownloader): """ShardDownloader that tries peer download before HuggingFace. @@ -129,7 +137,7 @@ async def _try_peer_download( # Get the file list we need (same logic as download_shard) revision = "main" - target_dir = await ensure_models_dir() / model_id_normalized + target_dir = await resolve_model_dir(shard.model_card.model_id) try: file_list = await fetch_file_list_with_cache( @@ -142,7 +150,7 @@ async def _try_peer_download( return None allow_patterns = await resolve_allow_patterns(shard) - filtered_file_list = list( + filtered_file_list: list[FileListEntry] = list( filter_repo_objects( file_list, allow_patterns=allow_patterns, key=lambda x: x.path ) @@ -197,7 +205,7 @@ def on_file_progress( all_start_time, ) for cb in self._progress_callbacks: - asyncio.create_task(cb(shard, progress)) + asyncio.create_task(_run_progress_callback(cb, shard, progress)) async with semaphore: result = await download_file_from_peer( @@ -227,7 +235,7 @@ def on_file_progress( ) # Download all files in parallel - tasks = [] + tasks: list[Coroutine[Any, Any, bool]] = [] for f in filtered_file_list: if f.size is None or f.size == 0: continue diff --git a/src/exo/download/peer_state.py b/src/exo/download/peer_state.py index 6f400b92a5..bfcabc58ed 100644 --- a/src/exo/download/peer_state.py +++ b/src/exo/download/peer_state.py @@ -100,7 +100,7 @@ def _resolve_peer_endpoint( status=status, connection_type="rdma", ) - elif isinstance(conn.edge, SocketConnection): + else: return PeerEndpoint( node_id=peer_node_id, ip=conn.edge.sink_multiaddr.ip_address, diff --git a/src/exo/download/tests/test_peer_download.py b/src/exo/download/tests/test_peer_download.py index 396fcd5095..4f8e140379 100644 --- a/src/exo/download/tests/test_peer_download.py +++ b/src/exo/download/tests/test_peer_download.py @@ -1,8 +1,10 @@ """Tests for peer-to-peer model downloading.""" +# pyright: reportPrivateUsage=false import json from collections.abc import AsyncIterator from pathlib import Path +from typing import cast import aiofiles import aiofiles.os as aios @@ -48,7 +50,7 @@ async def test_health_check(self, peer_server: PeerFileServer) -> None: f"http://127.0.0.1:{peer_server.port}/health" ) as r: assert r.status == 200 - data = await r.json() + data = cast(dict[str, object], await r.json()) assert data["status"] == "ok" async def test_status_empty_model(self, peer_server: PeerFileServer) -> None: From 1fe31b56db27024e63cc37c608a054442be347d7 Mon Sep 17 00:00:00 2001 From: jw-wcv <101585096+jw-wcv@users.noreply.github.com> Date: Wed, 6 May 2026 23:48:34 -0700 Subject: [PATCH 13/29] Use current model directory for peer file server --- src/exo/main.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/exo/main.py b/src/exo/main.py index fcea355d2c..51e66339ec 100644 --- a/src/exo/main.py +++ b/src/exo/main.py @@ -20,7 +20,7 @@ from exo.master.main import Master from exo.routing.event_router import EventRouter from exo.routing.router import Router, get_node_id_keypair -from exo.shared.constants import EXO_LOG, EXO_MODELS_DIR, EXO_PEER_DOWNLOAD_PORT +from exo.shared.constants import EXO_DEFAULT_MODELS_DIR, EXO_LOG, EXO_PEER_DOWNLOAD_PORT from exo.shared.election import Election, ElectionResult from exo.shared.logging import logger_cleanup, logger_set_context, logger_setup from exo.shared.types.common import NodeId, SessionId @@ -105,7 +105,7 @@ async def create(cls, args: "Args") -> Self: peer_file_server = PeerFileServer( host="0.0.0.0", port=EXO_PEER_DOWNLOAD_PORT, - models_dir=EXO_MODELS_DIR, + models_dir=EXO_DEFAULT_MODELS_DIR, ) if not args.no_downloads: From a8dcc3239734b8b199bce40540326a29696261bc Mon Sep 17 00:00:00 2001 From: jw-wcv <101585096+jw-wcv@users.noreply.github.com> Date: Thu, 7 May 2026 00:32:53 -0700 Subject: [PATCH 14/29] fix(download): harden peer file serving X-Orchestraitor-Task: exo-upstream-pr-import X-Orchestraitor-Plan: import-useful-upstream-prs X-Agent-Platform: codex --- src/exo/download/coordinator.py | 4 +- src/exo/download/peer_file_server.py | 37 +++-- src/exo/download/peer_shard_downloader.py | 34 ++++- src/exo/download/tests/test_peer_download.py | 142 +++++++++++++++++++ 4 files changed, 199 insertions(+), 18 deletions(-) diff --git a/src/exo/download/coordinator.py b/src/exo/download/coordinator.py index b5539a59c3..e566df1640 100644 --- a/src/exo/download/coordinator.py +++ b/src/exo/download/coordinator.py @@ -229,12 +229,12 @@ async def _command_processor(self) -> None: case StartDownload(shard_metadata=shard, available_peers=peers): # Pass peer endpoints to the shard downloader if it supports it if isinstance(self.shard_downloader, PeerAwareShardDownloader): - self.shard_downloader.set_available_peers(peers) + self.shard_downloader.set_available_peers(shard, peers) elif hasattr(self.shard_downloader, "shard_downloader") and isinstance( self.shard_downloader.shard_downloader, PeerAwareShardDownloader # type: ignore[union-attr] ): # Unwrap SingletonShardDownloader - self.shard_downloader.shard_downloader.set_available_peers(peers) # type: ignore[union-attr] + self.shard_downloader.shard_downloader.set_available_peers(shard, peers) # type: ignore[union-attr] await self._start_download(shard) case DeleteDownload(model_id=model_id): await self._delete_download(model_id) diff --git a/src/exo/download/peer_file_server.py b/src/exo/download/peer_file_server.py index 25d602ddfd..2cc64be549 100644 --- a/src/exo/download/peer_file_server.py +++ b/src/exo/download/peer_file_server.py @@ -53,17 +53,22 @@ async def _handle_health(self, request: web.Request) -> web.Response: async def _handle_status(self, request: web.Request) -> web.Response: """Return status of all files for a model (complete + in-progress).""" model_id = request.match_info["model_id"] - model_dir = self.models_dir / model_id + model_dir = _resolve_child(self.models_dir, model_id) + if model_dir is None: + return web.Response(status=404, text="Model not found") if not await aios.path.exists(model_dir): return web.json_response({"files": []}) files: list[dict[str, object]] = [] - for item in model_dir.iterdir(): - if item.is_dir() or item.name.endswith(".partial.meta"): + for item in model_dir.rglob("*"): + relative_path = item.relative_to(model_dir).as_posix() + if item.is_dir() or relative_path.endswith(".partial.meta"): + continue + if _resolve_child(model_dir, relative_path) is None: continue - if item.name.endswith(".partial"): + if relative_path.endswith(".partial"): # In-progress file - read meta for safe bytes meta = await _read_partial_meta(item) if meta: @@ -71,7 +76,7 @@ async def _handle_status(self, request: web.Request) -> web.Response: safe_bytes = _meta_int(meta, "safe_bytes") files.append( { - "path": item.name.removesuffix(".partial"), + "path": relative_path.removesuffix(".partial"), "size": total, "complete": False, "safe_bytes": safe_bytes, @@ -82,7 +87,7 @@ async def _handle_status(self, request: web.Request) -> web.Response: stat = await aios.stat(item) files.append( { - "path": item.name, + "path": relative_path, "size": stat.st_size, "complete": True, "safe_bytes": stat.st_size, @@ -100,9 +105,14 @@ async def _handle_file(self, request: web.Request) -> web.StreamResponse: model_id = request.match_info["model_id"] file_path = request.match_info["file_path"] - model_dir = self.models_dir / model_id - complete_path = model_dir / file_path - partial_path = model_dir / f"{file_path}.partial" + model_dir = _resolve_child(self.models_dir, model_id) + if model_dir is None: + return web.Response(status=404, text="Model not found") + + complete_path = _resolve_child(model_dir, file_path) + partial_path = _resolve_child(model_dir, f"{file_path}.partial") + if complete_path is None or partial_path is None: + return web.Response(status=404, text="File not found") # Determine which file to serve and its safe size if await aios.path.exists(complete_path): @@ -167,6 +177,15 @@ async def _handle_file(self, request: web.Request) -> web.StreamResponse: return response +def _resolve_child(root: Path, relative_path: str) -> Path | None: + """Resolve relative_path under root, rejecting path traversal.""" + resolved_root = root.resolve(strict=False) + resolved_path = (resolved_root / relative_path).resolve(strict=False) + if resolved_root in resolved_path.parents: + return resolved_path + return None + + def _meta_int(meta: PartialMeta, key: str) -> int: value = meta.get(key, 0) return value if isinstance(value, int) else 0 diff --git a/src/exo/download/peer_shard_downloader.py b/src/exo/download/peer_shard_downloader.py index 88fcbb5333..f816d40619 100644 --- a/src/exo/download/peer_shard_downloader.py +++ b/src/exo/download/peer_shard_downloader.py @@ -11,6 +11,7 @@ import asyncio import time +from collections import defaultdict, deque from collections.abc import Awaitable, Coroutine from datetime import timedelta from pathlib import Path @@ -37,6 +38,8 @@ from exo.shared.types.worker.downloads import FileListEntry, RepoFileDownloadProgress from exo.shared.types.worker.shards import ShardMetadata +ShardPeerKey = str + async def _run_progress_callback( callback: Callable[[ShardMetadata, RepoDownloadProgress], Awaitable[None]], @@ -60,16 +63,20 @@ def __init__(self, inner: ShardDownloader) -> None: self._progress_callbacks: list[ Callable[[ShardMetadata, RepoDownloadProgress], Awaitable[None]] ] = [] - # Peers are set per-download by the coordinator before calling ensure_shard - self._current_peers: list[PeerEndpoint] = [] + # Peers are set per-download by the coordinator before calling ensure_shard. + self._peers_by_shard: defaultdict[ + ShardPeerKey, deque[list[PeerEndpoint]] + ] = defaultdict(deque) - def set_available_peers(self, peers: list[PeerEndpoint]) -> None: - """Set the peers to try for the next ensure_shard call. + def set_available_peers( + self, shard: ShardMetadata, peers: list[PeerEndpoint] + ) -> None: + """Set the peers to try for a specific ensure_shard call. Called by DownloadCoordinator before triggering a download, based on the peers embedded in the StartDownload command. """ - self._current_peers = peers + self._peers_by_shard[_peer_key(shard)].append(list(peers)) def on_progress( self, @@ -86,8 +93,7 @@ async def ensure_shard( model_id = shard.model_card.model_id normalized = model_id.normalize() - peers = self._current_peers - self._current_peers = [] # Reset after consumption + peers = self._pop_available_peers(shard) if not peers: logger.debug(f"No peers available for {model_id}, downloading from HuggingFace") @@ -279,3 +285,17 @@ async def get_shard_download_status_for_shard( self, shard: ShardMetadata ) -> RepoDownloadProgress: return await self._inner.get_shard_download_status_for_shard(shard) + + def _pop_available_peers(self, shard: ShardMetadata) -> list[PeerEndpoint]: + key = _peer_key(shard) + queue = self._peers_by_shard.get(key) + if not queue: + return [] + peers = queue.popleft() + if not queue: + del self._peers_by_shard[key] + return peers + + +def _peer_key(shard: ShardMetadata) -> ShardPeerKey: + return shard.model_dump_json() diff --git a/src/exo/download/tests/test_peer_download.py b/src/exo/download/tests/test_peer_download.py index 4f8e140379..a3084b90e4 100644 --- a/src/exo/download/tests/test_peer_download.py +++ b/src/exo/download/tests/test_peer_download.py @@ -12,6 +12,13 @@ from exo.download.peer_download import download_file_from_peer, get_peer_file_status from exo.download.peer_file_server import PeerFileServer +from exo.download.peer_shard_downloader import PeerAwareShardDownloader +from exo.download.shard_downloader import NoopShardDownloader +from exo.shared.models.model_cards import ModelCard, ModelId, ModelTask +from exo.shared.types.commands import PeerEndpoint +from exo.shared.types.common import NodeId +from exo.shared.types.memory import Memory +from exo.shared.types.worker.shards import PipelineShardMetadata, ShardMetadata @pytest.fixture @@ -39,6 +46,24 @@ async def peer_server(temp_models_dir: Path) -> AsyncIterator[PeerFileServer]: await server.shutdown() +def _make_shard(model_id: ModelId) -> ShardMetadata: + return PipelineShardMetadata( + model_card=ModelCard( + model_id=model_id, + storage_size=Memory.from_mb(100), + n_layers=28, + hidden_size=1024, + supports_tensor=False, + tasks=[ModelTask.TextGeneration], + ), + device_rank=0, + world_size=1, + start_layer=0, + end_layer=28, + n_layers=28, + ) + + class TestPeerFileServer: """Tests for the HTTP file server that serves model files to peers.""" @@ -109,6 +134,32 @@ async def test_status_with_partial_file( assert files[0].safe_bytes == 1024 assert files[0].size == 4096 + async def test_status_includes_nested_files( + self, peer_server: PeerFileServer, temp_models_dir: Path + ) -> None: + """Status should report nested complete and partial files.""" + model_dir = temp_models_dir / "test--model" + nested_dir = model_dir / "snapshots" / "abc123" + await aios.makedirs(nested_dir, exist_ok=True) + + async with aiofiles.open(nested_dir / "config.json", "wb") as f: + await f.write(b"{}") + async with aiofiles.open(nested_dir / "model.safetensors.partial", "wb") as f: + await f.write(b"x" * 512) + async with aiofiles.open( + nested_dir / "model.safetensors.partial.meta", "w" + ) as f: + await f.write(json.dumps({"safe_bytes": 512, "total": 2048})) + + files = await get_peer_file_status( + "127.0.0.1", peer_server.port, "test--model" + ) + assert files is not None + by_path = {file.path: file for file in files} + assert by_path["snapshots/abc123/config.json"].complete is True + assert by_path["snapshots/abc123/model.safetensors"].complete is False + assert by_path["snapshots/abc123/model.safetensors"].safe_bytes == 512 + async def test_serve_complete_file( self, peer_server: PeerFileServer, temp_models_dir: Path ) -> None: @@ -130,6 +181,48 @@ async def test_serve_complete_file( body = await r.read() assert body == content + async def test_serve_nested_file( + self, peer_server: PeerFileServer, temp_models_dir: Path + ) -> None: + """Should serve a complete nested file with correct headers.""" + model_dir = temp_models_dir / "test--model" + nested_dir = model_dir / "snapshots" / "abc123" + await aios.makedirs(nested_dir, exist_ok=True) + + content = b"nested content" + async with aiofiles.open(nested_dir / "config.json", "wb") as f: + await f.write(content) + + import aiohttp + + async with aiohttp.ClientSession() as session, session.get( + f"http://127.0.0.1:{peer_server.port}/files/test--model/" + "snapshots/abc123/config.json" + ) as r: + assert r.status == 200 + body = await r.read() + assert body == content + + async def test_rejects_path_traversal( + self, peer_server: PeerFileServer, temp_models_dir: Path + ) -> None: + """Should not serve files outside the requested model directory.""" + model_dir = temp_models_dir / "test--model" + await aios.makedirs(model_dir, exist_ok=True) + + outside_file = temp_models_dir / "outside.txt" + async with aiofiles.open(outside_file, "wb") as f: + await f.write(b"outside") + + import aiohttp + + async with aiohttp.ClientSession() as session, session.get( + f"http://127.0.0.1:{peer_server.port}/files/test--model/" + "%2E%2E/outside.txt" + ) as r: + assert r.status == 404 + assert await r.text() != "outside" + async def test_serve_with_range_request( self, peer_server: PeerFileServer, temp_models_dir: Path ) -> None: @@ -260,3 +353,52 @@ async def test_skip_already_complete( assert result is not None assert result == download_dir / "config.json" + + +class TestPeerAwareShardDownloader: + """Tests for peer selection handoff into peer-aware downloads.""" + + def test_peers_are_queued_per_shard(self) -> None: + """Concurrent downloads should not overwrite each other's peer list.""" + downloader = PeerAwareShardDownloader(NoopShardDownloader()) + shard_a = _make_shard(ModelId("test-org/model-a")) + shard_b = _make_shard(ModelId("test-org/model-b")) + peer_a = PeerEndpoint( + node_id=NodeId("aaaaaaaa-aaaa-4aaa-8aaa-aaaaaaaaaaaa"), + ip="10.0.0.1", + port=52415, + ) + peer_b = PeerEndpoint( + node_id=NodeId("bbbbbbbb-bbbb-4bbb-8bbb-bbbbbbbbbbbb"), + ip="10.0.0.2", + port=52415, + ) + + downloader.set_available_peers(shard_a, [peer_a]) + downloader.set_available_peers(shard_b, [peer_b]) + + assert downloader._pop_available_peers(shard_b) == [peer_b] + assert downloader._pop_available_peers(shard_a) == [peer_a] + assert downloader._pop_available_peers(shard_a) == [] + + def test_peers_for_same_shard_are_not_overwritten(self) -> None: + """Repeated commands for one shard should be consumed FIFO.""" + downloader = PeerAwareShardDownloader(NoopShardDownloader()) + shard = _make_shard(ModelId("test-org/model-a")) + peer_a = PeerEndpoint( + node_id=NodeId("aaaaaaaa-aaaa-4aaa-8aaa-aaaaaaaaaaaa"), + ip="10.0.0.1", + port=52415, + ) + peer_b = PeerEndpoint( + node_id=NodeId("bbbbbbbb-bbbb-4bbb-8bbb-bbbbbbbbbbbb"), + ip="10.0.0.2", + port=52415, + ) + + downloader.set_available_peers(shard, [peer_a]) + downloader.set_available_peers(shard, [peer_b]) + + assert downloader._pop_available_peers(shard) == [peer_a] + assert downloader._pop_available_peers(shard) == [peer_b] + assert downloader._pop_available_peers(shard) == [] From 70e502dd21b3b520a803c4cb33689838563c8a43 Mon Sep 17 00:00:00 2001 From: Alex Cheema Date: Fri, 8 May 2026 21:15:04 -0700 Subject: [PATCH 15/29] fix: make darwin mdns discovery reliable Cherry-picked from upstream 701838aa (Alex Cheema, exo-explore/exo). Resolves trivial dataclass conflict in src/exo/main.py: combined `peer_file_server: PeerFileServer | None = None` (this branch's peer-to-peer download field) with `_libp2p_port: int` (mDNS commit's new field). Both fields kept; constructor positional args reordered to match. Type-checked clean (basedpyright 0 errors). Ruff clean. --- src/exo/main.py | 81 ++++++++++++++++++++++++++ src/exo/routing/mdns_announcer.py | 97 +++++++++++++++++++++++++++++++ src/exo/worker/main.py | 12 +++- 3 files changed, 188 insertions(+), 2 deletions(-) create mode 100644 src/exo/routing/mdns_announcer.py diff --git a/src/exo/main.py b/src/exo/main.py index 51e66339ec..9e0969ab82 100644 --- a/src/exo/main.py +++ b/src/exo/main.py @@ -1,9 +1,11 @@ import argparse +import ipaddress import multiprocessing as mp import os import resource import signal import subprocess +import sys from dataclasses import dataclass, field from datetime import datetime, timezone from typing import Self @@ -45,6 +47,7 @@ class Node: node_id: NodeId offline: bool _api_port: int + _libp2p_port: int peer_file_server: PeerFileServer | None = None _tg: TaskGroup = field(init=False, default_factory=TaskGroup) @@ -159,6 +162,7 @@ async def create(cls, args: "Args") -> Self: node_id, args.offline, args.api_port, + args.libp2p_port, peer_file_server, ) logger_set_context( @@ -187,6 +191,12 @@ async def run(self): tg.start_soon(self.master.run) if self.api: tg.start_soon(self.api.run) + if sys.platform == "darwin" and self._libp2p_port != 0: + tg.start_soon( + _darwin_mdns_broadcast_announcer, + self.node_id, + self._libp2p_port, + ) tg.start_soon(self._elect_loop) tg.start_soon(self._diagnostic_snapshot_loop) @@ -382,6 +392,77 @@ def _last_seen_ages(self, state: State) -> dict[str, float]: return ages +def _darwin_en0_ip_address() -> str | None: + try: + return subprocess.check_output( + ["ipconfig", "getifaddr", "en0"], + text=True, + stderr=subprocess.DEVNULL, + ).strip() + except (OSError, subprocess.CalledProcessError): + return None + + +def _darwin_en0_broadcast_address(ip_address: str) -> str | None: + try: + subnet_mask = subprocess.check_output( + ["ipconfig", "getoption", "en0", "subnet_mask"], + text=True, + stderr=subprocess.DEVNULL, + ).strip() + interface = ipaddress.IPv4Interface(f"{ip_address}/{subnet_mask}") + return str(interface.network.broadcast_address) + except (OSError, ValueError, subprocess.CalledProcessError): + return None + + +async def _darwin_mdns_broadcast_announcer( + node_id: NodeId, libp2p_port: int +) -> None: + ip_address = _darwin_en0_ip_address() + if not ip_address: + logger.debug("Darwin mDNS broadcast announcer disabled: no en0 IPv4 address") + return + + broadcast_address = _darwin_en0_broadcast_address(ip_address) + logger.debug( + f"Darwin mDNS announcer advertising {node_id} at {ip_address}:{libp2p_port}" + ) + command = [ + sys.executable, + "-m", + "exo.routing.mdns_announcer", + "--node-id", + str(node_id), + "--ip-address", + ip_address, + "--libp2p-port", + str(libp2p_port), + ] + if broadcast_address is not None: + command.extend(["--broadcast-address", broadcast_address]) + process = subprocess.Popen( + command, + start_new_session=True, + stdout=subprocess.DEVNULL, + ) + try: + while process.poll() is None: + await anyio.sleep(60) + logger.debug( + f"Darwin mDNS announcer subprocess exited with {process.returncode}" + ) + finally: + if process.poll() is None: + process.terminate() + with anyio.move_on_after(2): + while process.poll() is None: + await anyio.sleep(0.1) + if process.poll() is None: + process.kill() + await anyio.sleep(0) + + def main(): args = Args.parse() soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE) diff --git a/src/exo/routing/mdns_announcer.py b/src/exo/routing/mdns_announcer.py new file mode 100644 index 0000000000..cafd1d3acc --- /dev/null +++ b/src/exo/routing/mdns_announcer.py @@ -0,0 +1,97 @@ +import argparse +import contextlib +import random +import socket +import string +import struct +import sys +import time +from typing import final + + +def _dns_qname(name: bytes) -> bytes: + return b"".join(bytes([len(part)]) + part for part in name.split(b".")) + b"\0" + + +def _build_response_packet(node_id: str, ip_address: str, libp2p_port: int) -> bytes: + service_name = b"_p2p._udp.local" + peer_name = ( + "".join(random.choice(string.ascii_letters + string.digits) for _ in range(32)) + + "._p2p._udp.local" + ).encode() + txt_record = f"dnsaddr=/ip4/{ip_address}/tcp/{libp2p_port}/p2p/{node_id}".encode() + + peer_qname = _dns_qname(peer_name) + packet = bytearray() + packet += struct.pack("!HHHHHH", 0, 0x8400, 0, 1, 0, 1) + packet += _dns_qname(service_name) + packet += struct.pack("!HHI", 12, 1, 120) + packet += struct.pack("!H", len(peer_qname)) + packet += peer_qname + packet += peer_qname + packet += struct.pack("!HHI", 16, 1, 120) + packet += struct.pack("!H", len(txt_record) + 1) + packet += bytes([len(txt_record)]) + packet += txt_record + return bytes(packet) + + +@final +class Args(argparse.Namespace): + node_id: str + ip_address: str + libp2p_port: int + broadcast_address: str | None + count: int + + @staticmethod + def parse() -> "Args": + parser = argparse.ArgumentParser() + parser.add_argument("--node-id", required=True) + parser.add_argument("--ip-address", required=True) + parser.add_argument("--libp2p-port", required=True, type=int) + parser.add_argument("--broadcast-address") + parser.add_argument("--count", default=0, type=int) + return parser.parse_args(namespace=Args()) + + +def main() -> None: + args = Args.parse() + sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + with contextlib.suppress(OSError): + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1) + sock.bind((args.ip_address, 0)) + + sent_count = 0 + while True: + packet = _build_response_packet( + args.node_id, args.ip_address, args.libp2p_port + ) + errors: list[str] = [] + destinations: list[tuple[str, int]] = [] + if args.broadcast_address is not None: + destinations.append((args.broadcast_address, 5353)) + destinations.extend([("255.255.255.255", 5353), ("224.0.0.251", 5353)]) + sent = False + for destination in destinations: + try: + sock.sendto(packet, destination) + sent = True + except OSError as err: + errors.append(f"{destination}: {err}") + if not sent: + print( + f"mDNS announcer send failed: {'; '.join(errors)}", + file=sys.stderr, + flush=True, + ) + sent_count += 1 + if args.count > 0 and sent_count >= args.count: + return + time.sleep(1.0 if sent_count < 60 else 10.0) + + +if __name__ == "__main__": + main() diff --git a/src/exo/worker/main.py b/src/exo/worker/main.py index 20c861e007..9e2a07010a 100644 --- a/src/exo/worker/main.py +++ b/src/exo/worker/main.py @@ -3,7 +3,7 @@ from datetime import datetime, timezone import anyio -from anyio import fail_after, to_thread +from anyio import fail_after, move_on_after, to_thread from loguru import logger from exo.api.types import ImageEditsTaskParams @@ -377,8 +377,16 @@ async def plan_step(self): await self._start_runner_task(task) async def shutdown(self): + self.event_sender.close() + self.command_sender.close() + self.download_command_sender.close() + for runner in self.runners.values(): + runner.shutdown() self._tg.cancel_tasks() - await self._stopped.wait() + with move_on_after(5) as scope: + await self._stopped.wait() + if scope.cancel_called: + logger.warning("Timed out waiting for Worker shutdown") async def _start_runner_task(self, task: Task): if (instance := self.state.instances.get(task.instance_id)) is not None: From 3f217861710765bb91bd0070a226ca17481dd134 Mon Sep 17 00:00:00 2001 From: jw-wcv <101585096+jw-wcv@users.noreply.github.com> Date: Fri, 8 May 2026 22:21:49 -0700 Subject: [PATCH 16/29] Guard quantized cache + integrity-check peer downloads Two fixes for PR #16 round 2 (Codex): 1. P1: Skip quantized caches for non-PP make_cache models. ``make_kv_cache`` had a single-node safeguard (``_model_is_pipeline_parallel``) on the make_cache-LESS branch, but the make_cache-aware branch above it unconditionally replaced ``KVCache`` entries with ``QuantizedKVCache`` whenever ``EXO_KV_CACHE_BITS`` was set. Models that expose ``make_cache()`` (e.g. Gemma3 with mixed attention layers) and run single-node would therefore crash at runtime with:: does not yet support batching with history because mlx-lm's single-node ``BatchGenerator`` calls ``_merge_caches`` on every step and ``QuantizedKVCache`` doesn't implement ``merge``. Fix: apply the same ``_model_is_pipeline_parallel`` guard to the make_cache branch and emit the same warning when the env var is ignored. 2. P2: Verify peer file integrity before marking download complete. ``download_one`` in ``peer_shard_downloader.py`` marked peer downloads successful as soon as ``n_read == expected_size``, with no content-integrity check. A peer serving wrong bytes with the right length (stale/corrupt/malicious) was therefore silently accepted as model data, causing hard-to-diagnose inference failures or bad outputs. Fix: after the peer download completes, fetch the authoritative etag/hash from HuggingFace via ``file_meta()`` and validate the downloaded file via ``calc_hash``. On mismatch the file is removed and the caller falls back to direct HF download. Trusts HF as canonical source rather than peer-advertised hash to defend against malicious peers that lie about both content and hash. ``file_meta`` adds one HEAD round-trip per file; ``fetch_file_list_with_cache`` already requires HF connectivity at this code path so the network requirement is not new. The semantics now match the direct HuggingFace download path which has done identical validation since import. --- src/exo/download/peer_shard_downloader.py | 70 ++++++++++++++++++++++- src/exo/worker/engines/mlx/cache.py | 27 ++++++++- 2 files changed, 93 insertions(+), 4 deletions(-) diff --git a/src/exo/download/peer_shard_downloader.py b/src/exo/download/peer_shard_downloader.py index f816d40619..8a75e9b420 100644 --- a/src/exo/download/peer_shard_downloader.py +++ b/src/exo/download/peer_shard_downloader.py @@ -15,14 +15,17 @@ from collections.abc import Awaitable, Coroutine from datetime import timedelta from pathlib import Path -from typing import Any, AsyncIterator, Callable +from typing import Any, AsyncIterator, Callable, Literal +import aiofiles.os as aios from loguru import logger from exo.download.download_utils import ( RepoDownloadProgress, + calc_hash, calculate_repo_progress, fetch_file_list_with_cache, + file_meta, is_image_model, resolve_allow_patterns, resolve_model_dir, @@ -223,7 +226,70 @@ def on_file_progress( expected_size, on_progress=on_file_progress, ) - return result is not None + if result is None: + return False + # Codex flagged (P2, PR #16 round 2) that peer downloads + # were marked successful as soon as ``n_read == + # expected_size``, with no content-integrity check. A + # peer serving wrong bytes with the right length + # (stale/corrupt/malicious) would otherwise be + # silently accepted as model data, causing + # hard-to-diagnose inference failures. + # + # Validate against HuggingFace's authoritative hash: + # we already need internet for ``fetch_file_list_with_cache`` + # (line 149), so the extra ``file_meta()`` HEAD is + # cheap. Trusting a hash advertised by the peer would + # leave us vulnerable to a malicious peer that lies + # about both bytes and hash; HF is the canonical + # source. + # + # On mismatch the partial-or-renamed file is removed + # so the caller's HF fallback (``self._inner.ensure_shard``) + # starts from a clean slate. + try: + _expected_size, expected_etag = await file_meta( + shard.model_card.model_id, revision, file_path + ) + except Exception as exc: + # If we can't reach HF for metadata, the file + # might still be valid -- but we can't prove it. + # Fall back to HF download where the same call + # would have happened anyway. + logger.warning( + f"Peer download integrity-check failed: could not " + f"fetch HF metadata for {file_path}: {exc}; " + f"discarding peer-downloaded copy" + ) + try: + await aios.remove(result) + except Exception as cleanup_exc: + logger.debug( + f"Could not remove unverified peer download " + f"{result}: {cleanup_exc}" + ) + return False + + hash_type: Literal["sha1", "sha256"] = ( + "sha256" if len(expected_etag) == 64 else "sha1" + ) + final_hash = await calc_hash(result, hash_type=hash_type) + if final_hash != expected_etag: + logger.warning( + f"Peer-downloaded {file_path} from {peer_ip} has " + f"hash {final_hash} but HF authoritative hash is " + f"{expected_etag} ({hash_type}); discarding and " + f"falling back to HF" + ) + try: + await aios.remove(result) + except Exception as exc: + logger.error( + f"Failed to remove corrupt peer download " + f"{result}: {exc}" + ) + return False + return True # Initialize progress for all files for f in filtered_file_list: diff --git a/src/exo/worker/engines/mlx/cache.py b/src/exo/worker/engines/mlx/cache.py index 53cd274874..84f852933f 100644 --- a/src/exo/worker/engines/mlx/cache.py +++ b/src/exo/worker/engines/mlx/cache.py @@ -597,7 +597,23 @@ def make_kv_cache( caches: list[ KVCache | RotatingKVCache | QuantizedKVCache | ArraysCache | CacheList ] = list(model.make_cache()) # pyright: ignore[reportUnknownArgumentType, reportUnknownMemberType] - if KV_CACHE_BITS is not None: + # Apply the same single-node safeguard used in the + # ``make_cache``-less branch below: ``QuantizedKVCache`` + # cannot be combined with the single-node ``BatchGenerator`` + # path because mlx-lm calls ``_merge_caches`` on every step + # and ``QuantizedKVCache`` doesn't implement ``merge``. Models + # with ``make_cache()`` (e.g. Gemma3 with mixed attention + # layers) used to skip this guard and would crash at runtime + # with:: + # + # does not + # yet support batching with history + # + # Pipeline-parallel deployments use a different generation + # path that does support quantized caches, so we honor + # ``EXO_KV_CACHE_BITS`` only when the model has PP layer + # wrappers installed. + if KV_CACHE_BITS is not None and _model_is_pipeline_parallel(model): # Honor KV_CACHE_BITS even when the model provides its own # make_cache(). Replace plain KVCache entries with # QuantizedKVCache; leave ArraysCache (DeltaNet/SSM) and other @@ -617,7 +633,14 @@ def make_kv_cache( f"for {quantized}/{len(caches)} layers" ) else: - logger.info("Using MLX LM's make cache") + if KV_CACHE_BITS is not None: + logger.info( + f"EXO_KV_CACHE_BITS={KV_CACHE_BITS} ignored in single-node mode " + f"(QuantizedKVCache has no merge() support, " + f"required by BatchGenerator); using model.make_cache() unmodified" + ) + else: + logger.info("Using MLX LM's make cache") # Increase KVCache step size to reduce Metal allocator # fragmentation. Default step=256 causes a mx.concatenate # expansion every prefill chunk; a larger step lets the cache From 3cf6bbf110cccdbd7e8adfe02b798cdcc1c1664e Mon Sep 17 00:00:00 2001 From: jw-wcv <101585096+jw-wcv@users.noreply.github.com> Date: Fri, 8 May 2026 22:42:18 -0700 Subject: [PATCH 17/29] Mirror download_shard ignore_patterns and offline flag in peer path Address Codex round-2 P1s on PR #16. P1: Reuse ``ignore_patterns`` when selecting peer download files - ``download_shard`` (download_utils.py:983) excludes ``original/*`` and ``metal/*`` because HuggingFace never downloads them. The peer path applied ``allow_patterns`` only, so any repo containing those paths (e.g. Llama 3.x) had a required-files list that included files the peer never had locally; the strict ``peer_info missing => fail`` check then aborted the entire transfer and forced a HF fallback for every download. Pass the same ``ignore_patterns=["original/*", "metal/*"]`` into ``filter_repo_objects`` to match selection. P1: Honor offline mode in peer file-list fetch - ``_try_peer_download`` hard-coded ``skip_internet=False`` when calling ``fetch_file_list_with_cache``, so offline-configured nodes still reached out to HuggingFace before downloading from a LAN peer. On cold/offline nodes without a cached file list this raised, the ``except: return None`` exited early, and the peer download could not even start. Add an ``offline`` parameter to ``PeerAwareShardDownloader.__init__`` (defaulting to ``False``) and thread the existing ``DownloadCoordinator.offline`` value through ``exo_shard_downloader`` so the peer file-list fetch honors the same offline contract as ``ResumableShardDownloader``. Tests - ``test_offline_flag_defaults_to_false`` / ``..._propagates`` cover the constructor wiring directly. - ``test_try_peer_download_passes_offline_to_fetch_file_list`` patches the import binding and asserts ``skip_internet=True`` is forwarded when the downloader is constructed with ``offline=True``. - ``test_try_peer_download_filters_ignore_patterns`` records every ``filter_repo_objects`` call and asserts that the peer path requested ``ignore_patterns=["original/*", "metal/*"]`` -- the exact set ``download_shard`` uses. --- src/exo/download/impl_shard_downloader.py | 2 +- src/exo/download/peer_shard_downloader.py | 33 ++- src/exo/download/tests/test_peer_download.py | 213 ++++++++++++++++++- 3 files changed, 242 insertions(+), 6 deletions(-) diff --git a/src/exo/download/impl_shard_downloader.py b/src/exo/download/impl_shard_downloader.py index 4db0f1f36c..4312da368d 100644 --- a/src/exo/download/impl_shard_downloader.py +++ b/src/exo/download/impl_shard_downloader.py @@ -34,7 +34,7 @@ def exo_shard_downloader( max_parallel_downloads, offline=offline ) if peer_download_enabled: - inner = PeerAwareShardDownloader(inner) + inner = PeerAwareShardDownloader(inner, offline=offline) return SingletonShardDownloader(inner) diff --git a/src/exo/download/peer_shard_downloader.py b/src/exo/download/peer_shard_downloader.py index 8a75e9b420..7b6b4db24f 100644 --- a/src/exo/download/peer_shard_downloader.py +++ b/src/exo/download/peer_shard_downloader.py @@ -61,8 +61,17 @@ class PeerAwareShardDownloader(ShardDownloader): no peer has it or the transfer fails. """ - def __init__(self, inner: ShardDownloader) -> None: + def __init__(self, inner: ShardDownloader, offline: bool = False) -> None: self._inner = inner + # ``offline`` mirrors ``ResumableShardDownloader.offline`` and is + # forwarded to ``fetch_file_list_with_cache`` so that a node + # configured for offline operation never reaches out to + # HuggingFace before attempting a peer download. Pre-fix the + # peer path hard-coded ``skip_internet=False`` and would raise + # on cold/offline nodes that lacked a cached file list, ending + # the peer attempt before it could even start. Codex flagged + # this as a P1 (PR #16 round 2). + self._offline = offline self._progress_callbacks: list[ Callable[[ShardMetadata, RepoDownloadProgress], Awaitable[None]] ] = [] @@ -153,15 +162,33 @@ async def _try_peer_download( shard.model_card.model_id, revision, recursive=True, - skip_internet=False, + # Honor the coordinator's offline setting so a cold + # offline node can still satisfy a peer download from + # the LAN without reaching out to HuggingFace for the + # initial file-list fetch (Codex P1, PR #16 round 2). + skip_internet=self._offline, ) except Exception: return None allow_patterns = await resolve_allow_patterns(shard) + # Mirror ``download_shard``'s selection logic exactly: it filters + # by ``allow_patterns`` AND ``ignore_patterns`` before deciding + # which files to fetch. Pre-fix the peer path applied + # ``allow_patterns`` only and missed the ignore set, so for any + # repo containing ``original/*`` or ``metal/*`` (e.g. Llama 3.x + # repos) the peer would not have those files locally, and the + # later strict ``peer_info`` missing => fail check would abort + # the whole peer transfer and force a HuggingFace fallback for + # every download (Codex P1, PR #16 round 2). Keep this list in + # sync with ``download_shard`` (download_utils.py:983). + ignore_patterns = ["original/*", "metal/*"] filtered_file_list: list[FileListEntry] = list( filter_repo_objects( - file_list, allow_patterns=allow_patterns, key=lambda x: x.path + file_list, + allow_patterns=allow_patterns, + ignore_patterns=ignore_patterns, + key=lambda x: x.path, ) ) diff --git a/src/exo/download/tests/test_peer_download.py b/src/exo/download/tests/test_peer_download.py index a3084b90e4..7a252a463e 100644 --- a/src/exo/download/tests/test_peer_download.py +++ b/src/exo/download/tests/test_peer_download.py @@ -2,9 +2,9 @@ # pyright: reportPrivateUsage=false import json -from collections.abc import AsyncIterator +from collections.abc import AsyncIterator, Generator, Iterable from pathlib import Path -from typing import cast +from typing import Callable, cast import aiofiles import aiofiles.os as aios @@ -402,3 +402,212 @@ def test_peers_for_same_shard_are_not_overwritten(self) -> None: assert downloader._pop_available_peers(shard) == [peer_a] assert downloader._pop_available_peers(shard) == [peer_b] assert downloader._pop_available_peers(shard) == [] + + +class TestPeerSelectionRespectsOfflineAndIgnorePatterns: + """Codex P1s on PR #16 round 2: peer selection must mirror + ``download_shard``'s logic exactly (``ignore_patterns`` for + ``original/*`` / ``metal/*``) and must propagate the coordinator's + offline mode into ``fetch_file_list_with_cache`` so a cold offline + node can still complete a peer download without reaching out to + HuggingFace for the initial file list. + """ + + def test_offline_flag_defaults_to_false(self) -> None: + downloader = PeerAwareShardDownloader(NoopShardDownloader()) + assert downloader._offline is False + + def test_offline_flag_propagates(self) -> None: + downloader = PeerAwareShardDownloader( + NoopShardDownloader(), offline=True + ) + assert downloader._offline is True + + async def test_try_peer_download_passes_offline_to_fetch_file_list( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + """``_try_peer_download`` must thread ``self._offline`` into + ``fetch_file_list_with_cache`` instead of always passing + ``skip_internet=False``. We capture the kwargs by patching + the import binding inside ``peer_shard_downloader``. + """ + from exo.download import peer_shard_downloader as psd + from exo.download.peer_download import PeerFileInfo + from exo.shared.types.worker.downloads import FileListEntry + + captured: dict[str, object] = {} + + async def fake_fetch( + *args: object, **kwargs: object + ) -> list[FileListEntry]: + captured["args"] = args + captured["kwargs"] = kwargs + # Empty list -> no required files -> ``failed`` short- + # circuit -> we get out cleanly with the call kwargs + # captured. + return [] + + async def fake_peer_status( + peer_host: str, + peer_port: int, + model_id_normalized: str, + timeout: float = 5.0, + ) -> list[PeerFileInfo] | None: + return [ + PeerFileInfo( + path="model-00001-of-00002.safetensors", + size=10, + complete=True, + safe_bytes=10, + ) + ] + + async def fake_resolve_dir(model_id: ModelId) -> Path: + return Path("/tmp/fake-model") + + async def fake_resolve_allow(shard: ShardMetadata) -> list[str]: + return ["*.safetensors"] + + monkeypatch.setattr(psd, "fetch_file_list_with_cache", fake_fetch) + monkeypatch.setattr(psd, "get_peer_file_status", fake_peer_status) + monkeypatch.setattr(psd, "resolve_model_dir", fake_resolve_dir) + monkeypatch.setattr(psd, "resolve_allow_patterns", fake_resolve_allow) + + downloader = PeerAwareShardDownloader( + NoopShardDownloader(), offline=True + ) + shard = _make_shard(ModelId("test-org/model-a")) + + result = await downloader._try_peer_download( + shard, + peer_ip="10.0.0.1", + peer_port=52415, + model_id_normalized="test-org/model-a", + ) + # Empty file list short-circuits to ``failed`` path and returns + # None, but that's beside the point -- we just need the kwargs. + assert result is None + assert captured["kwargs"] == { + "recursive": True, + "skip_internet": True, + }, ( + "skip_internet must reflect downloader.offline (got " + f"{captured['kwargs']!r})" + ) + + async def test_try_peer_download_filters_ignore_patterns( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + """Files under ``original/*`` and ``metal/*`` are excluded by + ``download_shard``; the peer path must skip them too. Pre-fix + the peer path filtered only ``allow_patterns``, leaving these + in the required-files list. The peer doesn't have them + locally (HF never downloads them), the strict + ``peer_info missing => fail`` check fired, and every download + fell back to HuggingFace. + """ + from exo.download import peer_shard_downloader as psd + from exo.download.peer_download import PeerFileInfo + from exo.shared.types.worker.downloads import FileListEntry + + served = [ + FileListEntry( + type="file", + path="model-00001-of-00002.safetensors", + size=100, + ), + FileListEntry(type="file", path="config.json", size=10), + # These two should NOT show up on the peer's required-files + # list once the fix lands. Pre-fix they did, the peer didn't + # have them, and the whole transfer fell back to HF. + FileListEntry( + type="file", path="original/consolidated.00.pth", size=999 + ), + FileListEntry(type="file", path="metal/dist.bin", size=999), + ] + + async def fake_fetch( + *_args: object, **_kwargs: object + ) -> list[FileListEntry]: + return served + + # The peer reports ONLY the canonical files, exactly the shape + # production peers are in (HF never downloaded ``original/*`` or + # ``metal/*`` for them either). + peer_paths = ("model-00001-of-00002.safetensors", "config.json") + + async def fake_peer_status( + peer_host: str, + peer_port: int, + model_id_normalized: str, + timeout: float = 5.0, + ) -> list[PeerFileInfo] | None: + return [ + PeerFileInfo( + path=p, size=100, complete=True, safe_bytes=100 + ) + for p in peer_paths + ] + + async def fake_resolve_dir(model_id: ModelId) -> Path: + return Path("/tmp/fake-model") + + async def fake_resolve_allow(shard: ShardMetadata) -> list[str]: + # Match the production allow set permissively; the legacy + # bug was that ``allow_patterns`` admitted ``original/*`` / + # ``metal/*`` whenever the repo allow-list was loose. + return ["*"] + + async def fake_download( + peer_ip: str, + peer_port: int, + model_id_normalized: str, + file_path: str, + target_dir: Path, + expected_size: int, + on_progress: object = None, + ) -> Path | None: + return None + + captured_kwargs: list[object] = [] + real_filter = psd.filter_repo_objects + + def recording_filter( + items: Iterable[FileListEntry], + *, + allow_patterns: list[str] | str | None = None, + ignore_patterns: list[str] | str | None = None, + key: Callable[[FileListEntry], str] | None = None, + ) -> Generator[FileListEntry, None, None]: + captured_kwargs.append(ignore_patterns) + yield from real_filter( + items, + allow_patterns=allow_patterns, + ignore_patterns=ignore_patterns, + key=key, + ) + + monkeypatch.setattr(psd, "fetch_file_list_with_cache", fake_fetch) + monkeypatch.setattr(psd, "get_peer_file_status", fake_peer_status) + monkeypatch.setattr(psd, "resolve_model_dir", fake_resolve_dir) + monkeypatch.setattr(psd, "resolve_allow_patterns", fake_resolve_allow) + monkeypatch.setattr(psd, "download_file_from_peer", fake_download) + monkeypatch.setattr(psd, "filter_repo_objects", recording_filter) + + downloader = PeerAwareShardDownloader(NoopShardDownloader()) + shard = _make_shard(ModelId("test-org/model-a")) + + await downloader._try_peer_download( + shard, + peer_ip="10.0.0.1", + peer_port=52415, + model_id_normalized="test-org/model-a", + ) + + assert captured_kwargs == [["original/*", "metal/*"]], ( + "peer download must apply the same ``ignore_patterns`` set " + "as ``download_shard`` (download_utils.py:983) so peers " + "that don't have ``original/*`` / ``metal/*`` aren't " + "incorrectly judged incomplete; got " + f"{captured_kwargs!r}" + ) From 7ef79bc38e4c5fe9711fcb2b23007c26896ec33f Mon Sep 17 00:00:00 2001 From: jw-wcv <101585096+jw-wcv@users.noreply.github.com> Date: Fri, 8 May 2026 23:13:56 -0700 Subject: [PATCH 18/29] PR #16 R3: skip HF integrity check in offline mode + per-process peer port Codex P1 (round 3): peer downloads were calling ``file_meta()`` for every file even when the coordinator ran with ``--offline`` / ``EXO_OFFLINE=true``. Any failure to reach HF (the entire point of offline mode) was treated as an integrity-check failure, the peer-fetched bytes were deleted, and the cold offline node was left with no path to complete model sync. When the downloader is in offline mode we now trust the LAN peer's bytes (size already enforced by ``download_file_from_peer``) and skip the HF canonical-hash call entirely. Online mode still validates against HF. Codex P2 (round 3): the peer-download listener was hard-coded to a single module-level constant, so a same-host multi-node deployment crashed on the second process with ``address already in use``. Add a ``--peer-download-port`` CLI flag (default ``EXO_PEER_DOWNLOAD_PORT``) and thread it through ``Node`` -> ``Worker`` and ``PeerFileServer``, replacing the import-time constant. Cluster-wide convention: every node uses the same value (peer discovery still uses the local value as the assumed remote port). Cross-node port advertisement via state is a documented follow-up. Tests: - ``TestPeerDownloadIntegrityCheckRespectsOfflineMode`` covers both modes: offline must not call ``file_meta`` and must keep the bytes, online must still call ``file_meta``. --- src/exo/download/peer_shard_downloader.py | 46 +-- src/exo/download/tests/test_peer_download.py | 279 +++++++++++++++---- src/exo/main.py | 46 ++- src/exo/worker/main.py | 14 +- 4 files changed, 308 insertions(+), 77 deletions(-) diff --git a/src/exo/download/peer_shard_downloader.py b/src/exo/download/peer_shard_downloader.py index 7b6b4db24f..85b282f252 100644 --- a/src/exo/download/peer_shard_downloader.py +++ b/src/exo/download/peer_shard_downloader.py @@ -76,9 +76,9 @@ def __init__(self, inner: ShardDownloader, offline: bool = False) -> None: Callable[[ShardMetadata, RepoDownloadProgress], Awaitable[None]] ] = [] # Peers are set per-download by the coordinator before calling ensure_shard. - self._peers_by_shard: defaultdict[ - ShardPeerKey, deque[list[PeerEndpoint]] - ] = defaultdict(deque) + self._peers_by_shard: defaultdict[ShardPeerKey, deque[list[PeerEndpoint]]] = ( + defaultdict(deque) + ) def set_available_peers( self, shard: ShardMetadata, peers: list[PeerEndpoint] @@ -108,7 +108,9 @@ async def ensure_shard( peers = self._pop_available_peers(shard) if not peers: - logger.debug(f"No peers available for {model_id}, downloading from HuggingFace") + logger.debug( + f"No peers available for {model_id}, downloading from HuggingFace" + ) return await self._inner.ensure_shard(shard, config_only=False) # Try each peer (already sorted by priority: RDMA first, completed first) @@ -121,16 +123,16 @@ async def ensure_shard( shard, peer.ip, peer.port, normalized ) if result is not None: - logger.info( - f"Successfully downloaded {model_id} from peer {peer.ip}" - ) + logger.info(f"Successfully downloaded {model_id} from peer {peer.ip}") return result logger.info( f"Peer download from {peer.ip} failed, trying next peer or HuggingFace" ) # All peers failed, fall back to HuggingFace - logger.info(f"All peer downloads failed for {model_id}, falling back to HuggingFace") + logger.info( + f"All peer downloads failed for {model_id}, falling back to HuggingFace" + ) return await self._inner.ensure_shard(shard, config_only=False) async def _try_peer_download( @@ -145,9 +147,7 @@ async def _try_peer_download( Returns the model directory path on success, None on failure. """ # First, check what the peer has - peer_files = await get_peer_file_status( - peer_ip, peer_port, model_id_normalized - ) + peer_files = await get_peer_file_status(peer_ip, peer_port, model_id_normalized) if not peer_files: return None @@ -255,6 +255,21 @@ def on_file_progress( ) if result is None: return False + # Offline / air-gapped deployments have explicitly opted + # out of contacting HuggingFace. Codex flagged (P1, PR + # #16 round 3) that calling ``file_meta`` here silently + # broke peer transfers in offline mode: any exception + # (e.g. DNS failure, blocked egress) was treated as + # integrity-check failure and the peer copy was + # deleted, leaving the cold node with no path to + # complete model sync. When the operator runs with + # ``--offline``/``EXO_OFFLINE=true`` we trust the LAN + # peer's bytes (size already enforced by + # ``download_file_from_peer``) and skip the HF + # canonical-hash check entirely. + if self._offline: + return True + # Codex flagged (P2, PR #16 round 2) that peer downloads # were marked successful as soon as ``n_read == # expected_size``, with no content-integrity check. A @@ -265,7 +280,7 @@ def on_file_progress( # # Validate against HuggingFace's authoritative hash: # we already need internet for ``fetch_file_list_with_cache`` - # (line 149), so the extra ``file_meta()`` HEAD is + # in online mode, so the extra ``file_meta()`` HEAD is # cheap. Trusting a hash advertised by the peer would # leave us vulnerable to a malicious peer that lies # about both bytes and hash; HF is the canonical @@ -312,8 +327,7 @@ def on_file_progress( await aios.remove(result) except Exception as exc: logger.error( - f"Failed to remove corrupt peer download " - f"{result}: {exc}" + f"Failed to remove corrupt peer download {result}: {exc}" ) return False return True @@ -363,9 +377,7 @@ def on_file_progress( for cb in self._progress_callbacks: await cb(shard, final_progress) - gguf = next( - (f for f in filtered_file_list if f.path.endswith(".gguf")), None - ) + gguf = next((f for f in filtered_file_list if f.path.endswith(".gguf")), None) return (target_dir / gguf.path) if gguf else target_dir async def get_shard_download_status( diff --git a/src/exo/download/tests/test_peer_download.py b/src/exo/download/tests/test_peer_download.py index 7a252a463e..6f0d2dc9b7 100644 --- a/src/exo/download/tests/test_peer_download.py +++ b/src/exo/download/tests/test_peer_download.py @@ -71,9 +71,10 @@ async def test_health_check(self, peer_server: PeerFileServer) -> None: """Health endpoint should return ok.""" import aiohttp - async with aiohttp.ClientSession() as session, session.get( - f"http://127.0.0.1:{peer_server.port}/health" - ) as r: + async with ( + aiohttp.ClientSession() as session, + session.get(f"http://127.0.0.1:{peer_server.port}/health") as r, + ): assert r.status == 200 data = cast(dict[str, object], await r.json()) assert data["status"] == "ok" @@ -97,9 +98,7 @@ async def test_status_with_complete_file( async with aiofiles.open(model_dir / "config.json", "wb") as f: await f.write(b'{"test": true}') - files = await get_peer_file_status( - "127.0.0.1", peer_server.port, "test--model" - ) + files = await get_peer_file_status("127.0.0.1", peer_server.port, "test--model") assert files is not None assert len(files) == 1 assert files[0].path == "config.json" @@ -124,9 +123,7 @@ async def test_status_with_partial_file( ) as f: await f.write(json.dumps(meta)) - files = await get_peer_file_status( - "127.0.0.1", peer_server.port, "test--model" - ) + files = await get_peer_file_status("127.0.0.1", peer_server.port, "test--model") assert files is not None assert len(files) == 1 assert files[0].path == "weights.safetensors" @@ -151,9 +148,7 @@ async def test_status_includes_nested_files( ) as f: await f.write(json.dumps({"safe_bytes": 512, "total": 2048})) - files = await get_peer_file_status( - "127.0.0.1", peer_server.port, "test--model" - ) + files = await get_peer_file_status("127.0.0.1", peer_server.port, "test--model") assert files is not None by_path = {file.path: file for file in files} assert by_path["snapshots/abc123/config.json"].complete is True @@ -173,9 +168,12 @@ async def test_serve_complete_file( import aiohttp - async with aiohttp.ClientSession() as session, session.get( - f"http://127.0.0.1:{peer_server.port}/files/test--model/config.json" - ) as r: + async with ( + aiohttp.ClientSession() as session, + session.get( + f"http://127.0.0.1:{peer_server.port}/files/test--model/config.json" + ) as r, + ): assert r.status == 200 assert r.headers["X-Exo-Complete"] == "true" body = await r.read() @@ -195,10 +193,13 @@ async def test_serve_nested_file( import aiohttp - async with aiohttp.ClientSession() as session, session.get( - f"http://127.0.0.1:{peer_server.port}/files/test--model/" - "snapshots/abc123/config.json" - ) as r: + async with ( + aiohttp.ClientSession() as session, + session.get( + f"http://127.0.0.1:{peer_server.port}/files/test--model/" + "snapshots/abc123/config.json" + ) as r, + ): assert r.status == 200 body = await r.read() assert body == content @@ -216,10 +217,13 @@ async def test_rejects_path_traversal( import aiohttp - async with aiohttp.ClientSession() as session, session.get( - f"http://127.0.0.1:{peer_server.port}/files/test--model/" - "%2E%2E/outside.txt" - ) as r: + async with ( + aiohttp.ClientSession() as session, + session.get( + f"http://127.0.0.1:{peer_server.port}/files/test--model/" + "%2E%2E/outside.txt" + ) as r, + ): assert r.status == 404 assert await r.text() != "outside" @@ -236,10 +240,13 @@ async def test_serve_with_range_request( import aiohttp - async with aiohttp.ClientSession() as session, session.get( - f"http://127.0.0.1:{peer_server.port}/files/test--model/weights.bin", - headers={"Range": "bytes=8-"}, - ) as r: + async with ( + aiohttp.ClientSession() as session, + session.get( + f"http://127.0.0.1:{peer_server.port}/files/test--model/weights.bin", + headers={"Range": "bytes=8-"}, + ) as r, + ): assert r.status == 206 body = await r.read() assert body == b"89abcdef" @@ -248,9 +255,12 @@ async def test_file_not_found(self, peer_server: PeerFileServer) -> None: """Should return 404 for missing files.""" import aiohttp - async with aiohttp.ClientSession() as session, session.get( - f"http://127.0.0.1:{peer_server.port}/files/test--model/missing.bin" - ) as r: + async with ( + aiohttp.ClientSession() as session, + session.get( + f"http://127.0.0.1:{peer_server.port}/files/test--model/missing.bin" + ) as r, + ): assert r.status == 404 @@ -418,9 +428,7 @@ def test_offline_flag_defaults_to_false(self) -> None: assert downloader._offline is False def test_offline_flag_propagates(self) -> None: - downloader = PeerAwareShardDownloader( - NoopShardDownloader(), offline=True - ) + downloader = PeerAwareShardDownloader(NoopShardDownloader(), offline=True) assert downloader._offline is True async def test_try_peer_download_passes_offline_to_fetch_file_list( @@ -437,9 +445,7 @@ async def test_try_peer_download_passes_offline_to_fetch_file_list( captured: dict[str, object] = {} - async def fake_fetch( - *args: object, **kwargs: object - ) -> list[FileListEntry]: + async def fake_fetch(*args: object, **kwargs: object) -> list[FileListEntry]: captured["args"] = args captured["kwargs"] = kwargs # Empty list -> no required files -> ``failed`` short- @@ -473,9 +479,7 @@ async def fake_resolve_allow(shard: ShardMetadata) -> list[str]: monkeypatch.setattr(psd, "resolve_model_dir", fake_resolve_dir) monkeypatch.setattr(psd, "resolve_allow_patterns", fake_resolve_allow) - downloader = PeerAwareShardDownloader( - NoopShardDownloader(), offline=True - ) + downloader = PeerAwareShardDownloader(NoopShardDownloader(), offline=True) shard = _make_shard(ModelId("test-org/model-a")) result = await downloader._try_peer_download( @@ -490,10 +494,7 @@ async def fake_resolve_allow(shard: ShardMetadata) -> list[str]: assert captured["kwargs"] == { "recursive": True, "skip_internet": True, - }, ( - "skip_internet must reflect downloader.offline (got " - f"{captured['kwargs']!r})" - ) + }, f"skip_internet must reflect downloader.offline (got {captured['kwargs']!r})" async def test_try_peer_download_filters_ignore_patterns( self, monkeypatch: pytest.MonkeyPatch @@ -520,15 +521,11 @@ async def test_try_peer_download_filters_ignore_patterns( # These two should NOT show up on the peer's required-files # list once the fix lands. Pre-fix they did, the peer didn't # have them, and the whole transfer fell back to HF. - FileListEntry( - type="file", path="original/consolidated.00.pth", size=999 - ), + FileListEntry(type="file", path="original/consolidated.00.pth", size=999), FileListEntry(type="file", path="metal/dist.bin", size=999), ] - async def fake_fetch( - *_args: object, **_kwargs: object - ) -> list[FileListEntry]: + async def fake_fetch(*_args: object, **_kwargs: object) -> list[FileListEntry]: return served # The peer reports ONLY the canonical files, exactly the shape @@ -543,9 +540,7 @@ async def fake_peer_status( timeout: float = 5.0, ) -> list[PeerFileInfo] | None: return [ - PeerFileInfo( - path=p, size=100, complete=True, safe_bytes=100 - ) + PeerFileInfo(path=p, size=100, complete=True, safe_bytes=100) for p in peer_paths ] @@ -611,3 +606,183 @@ def recording_filter( "incorrectly judged incomplete; got " f"{captured_kwargs!r}" ) + + +class TestPeerDownloadIntegrityCheckRespectsOfflineMode: + """Codex P1 on PR #16 round 3: ``_try_peer_download`` was calling + ``file_meta(...)`` against HuggingFace for every file, even when the + coordinator was started with ``--offline`` / ``EXO_OFFLINE=true``. + Any failure to reach HF (the entire point of offline mode) was + treated as an integrity-check failure, the peer-fetched bytes were + deleted, and the cold node was left with no path to complete model + sync. The fix: when the downloader is in offline mode, trust the + LAN peer's bytes and skip the HF metadata call entirely. + """ + + async def test_offline_mode_skips_file_meta_and_keeps_peer_bytes( + self, monkeypatch: pytest.MonkeyPatch, tmp_path: Path + ) -> None: + from exo.download import peer_shard_downloader as psd + from exo.download.peer_download import PeerFileInfo + from exo.shared.types.worker.downloads import FileListEntry + + async def fake_fetch(*_args: object, **_kwargs: object) -> list[FileListEntry]: + return [ + FileListEntry( + type="file", + path="model.safetensors", + size=10, + ), + ] + + async def fake_peer_status( + peer_host: str, + peer_port: int, + model_id_normalized: str, + timeout: float = 5.0, + ) -> list[PeerFileInfo] | None: + return [ + PeerFileInfo( + path="model.safetensors", + size=10, + complete=True, + safe_bytes=10, + ) + ] + + async def fake_resolve_dir(model_id: ModelId) -> Path: + return tmp_path + + async def fake_resolve_allow(shard: ShardMetadata) -> list[str]: + return ["*"] + + target_path = tmp_path / "model.safetensors" + + async def fake_download( + peer_ip: str, + peer_port: int, + model_id_normalized: str, + file_path: str, + target_dir: Path, + expected_size: int, + on_progress: object = None, + ) -> Path | None: + async with aiofiles.open(target_path, "wb") as f: + await f.write(b"0123456789") + return target_path + + async def file_meta_should_not_be_called( + *_args: object, **_kwargs: object + ) -> tuple[int, str]: + raise AssertionError( + "file_meta must not be called in offline mode -- the " + "operator opted into trusting LAN peers" + ) + + monkeypatch.setattr(psd, "fetch_file_list_with_cache", fake_fetch) + monkeypatch.setattr(psd, "get_peer_file_status", fake_peer_status) + monkeypatch.setattr(psd, "resolve_model_dir", fake_resolve_dir) + monkeypatch.setattr(psd, "resolve_allow_patterns", fake_resolve_allow) + monkeypatch.setattr(psd, "download_file_from_peer", fake_download) + monkeypatch.setattr(psd, "file_meta", file_meta_should_not_be_called) + + downloader = PeerAwareShardDownloader(NoopShardDownloader(), offline=True) + shard = _make_shard(ModelId("test-org/model-a")) + + result = await downloader._try_peer_download( + shard, + peer_ip="10.0.0.1", + peer_port=52415, + model_id_normalized="test-org/model-a", + ) + assert result is not None, ( + "offline peer download must succeed without consulting HF; " + "got None which means the integrity check fired and the " + "peer bytes were discarded" + ) + assert await aios.path.exists(target_path), ( + "peer-downloaded file must be retained when offline mode " + "skips the HF integrity check" + ) + + async def test_online_mode_still_calls_file_meta( + self, monkeypatch: pytest.MonkeyPatch, tmp_path: Path + ) -> None: + from exo.download import peer_shard_downloader as psd + from exo.download.peer_download import PeerFileInfo + from exo.shared.types.worker.downloads import FileListEntry + + async def fake_fetch(*_args: object, **_kwargs: object) -> list[FileListEntry]: + return [ + FileListEntry( + type="file", + path="model.safetensors", + size=10, + ), + ] + + async def fake_peer_status( + peer_host: str, + peer_port: int, + model_id_normalized: str, + timeout: float = 5.0, + ) -> list[PeerFileInfo] | None: + return [ + PeerFileInfo( + path="model.safetensors", + size=10, + complete=True, + safe_bytes=10, + ) + ] + + async def fake_resolve_dir(model_id: ModelId) -> Path: + return tmp_path + + async def fake_resolve_allow(shard: ShardMetadata) -> list[str]: + return ["*"] + + target_path = tmp_path / "model.safetensors" + + async def fake_download( + peer_ip: str, + peer_port: int, + model_id_normalized: str, + file_path: str, + target_dir: Path, + expected_size: int, + on_progress: object = None, + ) -> Path | None: + async with aiofiles.open(target_path, "wb") as f: + await f.write(b"0123456789") + return target_path + + meta_calls: list[tuple[object, ...]] = [] + + async def recording_meta(*args: object, **_kwargs: object) -> tuple[int, str]: + meta_calls.append(args) + # Return mismatched etag -> downloader will discard. + return (10, "deadbeef" * 5) + + monkeypatch.setattr(psd, "fetch_file_list_with_cache", fake_fetch) + monkeypatch.setattr(psd, "get_peer_file_status", fake_peer_status) + monkeypatch.setattr(psd, "resolve_model_dir", fake_resolve_dir) + monkeypatch.setattr(psd, "resolve_allow_patterns", fake_resolve_allow) + monkeypatch.setattr(psd, "download_file_from_peer", fake_download) + monkeypatch.setattr(psd, "file_meta", recording_meta) + + downloader = PeerAwareShardDownloader(NoopShardDownloader(), offline=False) + shard = _make_shard(ModelId("test-org/model-a")) + + await downloader._try_peer_download( + shard, + peer_ip="10.0.0.1", + peer_port=52415, + model_id_normalized="test-org/model-a", + ) + + assert len(meta_calls) == 1, ( + "online mode must continue calling file_meta to validate " + "peer-downloaded bytes against HF's authoritative hash; " + f"got meta_calls={meta_calls!r}" + ) diff --git a/src/exo/main.py b/src/exo/main.py index 9e0969ab82..24e176e4a0 100644 --- a/src/exo/main.py +++ b/src/exo/main.py @@ -48,6 +48,7 @@ class Node: offline: bool _api_port: int _libp2p_port: int + _peer_download_port: int peer_file_server: PeerFileServer | None = None _tg: TaskGroup = field(init=False, default_factory=TaskGroup) @@ -99,15 +100,23 @@ async def create(cls, args: "Args") -> Self: command_sender=router.sender(topics.COMMANDS), download_command_sender=router.sender(topics.DOWNLOAD_COMMANDS), api_port=args.api_port, + # Each node now binds its own peer-download listener on + # ``--peer-download-port`` (default ``EXO_PEER_DOWNLOAD_PORT``). + # The Worker uses this same value when discovering peers, + # so all nodes in a cluster MUST agree on it (typically + # via the shared ``EXO_PEER_DOWNLOAD_PORT`` env var). + # Pre-fix this was a single import-time module constant, + # making same-host multi-node setups impossible (Codex + # P2, PR #16 round 3). + peer_download_port=args.peer_download_port, ) else: worker = None - # Create peer file server and download coordinator if peer_download_enabled: peer_file_server = PeerFileServer( host="0.0.0.0", - port=EXO_PEER_DOWNLOAD_PORT, + port=args.peer_download_port, models_dir=EXO_DEFAULT_MODELS_DIR, ) @@ -163,6 +172,7 @@ async def create(cls, args: "Args") -> Self: args.offline, args.api_port, args.libp2p_port, + args.peer_download_port, peer_file_server, ) logger_set_context( @@ -303,6 +313,7 @@ async def _elect_loop(self): topics.DOWNLOAD_COMMANDS ), api_port=self._api_port, + peer_download_port=self._peer_download_port, ) self._tg.start_soon(self.worker.run) if self.api: @@ -416,9 +427,7 @@ def _darwin_en0_broadcast_address(ip_address: str) -> str | None: return None -async def _darwin_mdns_broadcast_announcer( - node_id: NodeId, libp2p_port: int -) -> None: +async def _darwin_mdns_broadcast_announcer(node_id: NodeId, libp2p_port: int) -> None: ip_address = _darwin_en0_ip_address() if not ip_address: logger.debug("Darwin mDNS broadcast announcer disabled: no en0 IPv4 address") @@ -545,6 +554,17 @@ class Args(FrozenModel): fast_synch: bool | None = None # None = auto, True = force on, False = force off bootstrap_peers: list[str] = [] libp2p_port: int + # Per-process listener port for peer-to-peer model file serving. + # Defaults to ``EXO_PEER_DOWNLOAD_PORT`` so existing single-node-per- + # host deployments keep working unchanged. Operators running + # multiple nodes on the same host MUST set this to a distinct value + # for each process; the cluster-wide convention is that every node + # exposes the same port, since peer discovery currently uses each + # node's local value as the assumed remote endpoint (see + # ``Worker._peer_download_port``). A future state-sync change can + # advertise per-node ports across the cluster -- tracked as a + # follow-up to Codex P2 (PR #16 round 3). + peer_download_port: PositiveInt = EXO_PEER_DOWNLOAD_PORT trust_remote_code: bool = False @classmethod @@ -629,6 +649,22 @@ def parse(cls) -> Self: dest="libp2p_port", help="Fixed TCP port for libp2p to listen on (0 = OS-assigned).", ) + parser.add_argument( + "--peer-download-port", + type=int, + default=EXO_PEER_DOWNLOAD_PORT, + dest="peer_download_port", + help=( + "TCP port for peer-to-peer model file serving (default: " + "EXO_PEER_DOWNLOAD_PORT, currently 52416). Required to " + "differ between processes when running multiple nodes " + "on the same host; otherwise the second node's " + "PeerFileServer hits 'address already in use'. All " + "nodes in a cluster must use the same value (peer " + "discovery uses the local port as the assumed remote " + "port)." + ), + ) fast_synch_group = parser.add_mutually_exclusive_group() fast_synch_group.add_argument( "--fast-synch", diff --git a/src/exo/worker/main.py b/src/exo/worker/main.py index 9e2a07010a..7b6ea75ce9 100644 --- a/src/exo/worker/main.py +++ b/src/exo/worker/main.py @@ -10,7 +10,7 @@ from exo.download.download_utils import is_read_only_model_dir, resolve_existing_model from exo.download.peer_state import discover_peers_for_model from exo.shared.apply import apply -from exo.shared.constants import EXO_MAX_INSTANCE_RETRIES, EXO_PEER_DOWNLOAD_PORT +from exo.shared.constants import EXO_MAX_INSTANCE_RETRIES from exo.shared.models.model_cards import ModelId, add_to_card_cache, delete_custom_card from exo.shared.types.chunks import InputImageChunk from exo.shared.types.commands import ( @@ -73,6 +73,7 @@ def __init__( command_sender: Sender[ForwarderCommand], download_command_sender: Sender[ForwarderDownloadCommand], api_port: int, + peer_download_port: int, ): self.node_id: NodeId = node_id self.event_receiver = event_receiver @@ -80,6 +81,14 @@ def __init__( self.command_sender = command_sender self.download_command_sender = download_command_sender self.api_port = api_port + # Codex P2 (PR #16 round 3): the peer-download listener port is + # now per-process configurable instead of a module-level + # constant. Use the local value when computing + # ``discover_peers_for_model`` results because peers in the + # current architecture all bind the same port (cluster-wide + # convention enforced via ``EXO_PEER_DOWNLOAD_PORT`` / + # ``--peer-download-port``). + self._peer_download_port = peer_download_port self.state: State = State() self.runners: dict[RunnerId, RunnerSupervisor] = {} @@ -265,12 +274,11 @@ async def plan_step(self): ) ) else: - # Discover peers that already have this model peers = discover_peers_for_model( self.node_id, self.state, shard.model_card.model_id.normalize(), - EXO_PEER_DOWNLOAD_PORT, + self._peer_download_port, ) await self.download_command_sender.send( ForwarderDownloadCommand( From 714c6103fddf7b38b426725a2aa2f1847fa1d0eb Mon Sep 17 00:00:00 2001 From: jw-wcv <101585096+jw-wcv@users.noreply.github.com> Date: Fri, 8 May 2026 23:54:12 -0700 Subject: [PATCH 19/29] Pick RDMA edges and serve from every model dir for peer downloads This addresses two Codex P2 findings on PR #16. `_resolve_peer_endpoint` returned on the *first* topology edge it visited for a peer, so when ``out_edges`` happened to yield the SocketConnection before the RDMAConnection (insertion order is not a stable contract on the topology graph), the peer was silently labelled ``socket`` and lost its RDMA priority in the peer ordering. The lookup now scans all edges for a peer, prefers RDMA whenever any RDMA edge exists (using the companion socket address for the actual TCP connect), and only falls back to ``socket`` when no RDMA edge is present. The unused `_find_socket_ip` helper that was meant as a fallback for the RDMA branch is removed -- its behaviour is folded into the new scan. `PeerFileServer` was hard-wired to ``EXO_DEFAULT_MODELS_DIR``, but ``select_download_dir_for_shard`` legitimately writes shards across ``EXO_MODELS_DIRS`` (custom paths, low-disk fallback) and we also have shards available in ``EXO_MODELS_READ_ONLY_DIRS`` mounts. Pre-fix, any model that landed outside the default directory was invisible to /status and /files, so peers always fell back to HuggingFace and the new peer path was a no-op for valid multi-directory deployments. The server now takes ``models_dirs`` (a sequence) and probes every configured root in caller-specified priority, checking each candidate against ``_resolve_child`` for path-traversal safety. Adds regression coverage: * `test_peer_state.py` covers RDMA-first ordering for both edge insertion orders and the socket-only / RDMA-only edge cases. * `TestPeerFileServerMultipleDirectories` covers serving from a secondary writable directory and a read-only mount, plus the constructor's empty-list rejection. --- src/exo/download/peer_file_server.py | 52 +++++-- src/exo/download/peer_state.py | 77 +++++----- src/exo/download/tests/test_peer_download.py | 74 +++++++++- src/exo/download/tests/test_peer_state.py | 142 +++++++++++++++++++ src/exo/main.py | 16 ++- 5 files changed, 308 insertions(+), 53 deletions(-) create mode 100644 src/exo/download/tests/test_peer_state.py diff --git a/src/exo/download/peer_file_server.py b/src/exo/download/peer_file_server.py index 2cc64be549..a674c85390 100644 --- a/src/exo/download/peer_file_server.py +++ b/src/exo/download/peer_file_server.py @@ -1,15 +1,25 @@ """Lightweight HTTP file server for peer-to-peer model downloads. -Each exo node runs a PeerFileServer that serves model files from the local -cache directory. When one node finishes downloading a model from HuggingFace, -other nodes on the same LAN can fetch it directly over HTTP instead of +Each exo node runs a PeerFileServer that serves model files from its local +caches. When one node finishes downloading a model from HuggingFace, other +nodes on the same LAN can fetch it directly over HTTP instead of re-downloading from the internet. Supports serving in-progress downloads via .partial.meta files that track how many bytes have been safely flushed to disk. + +The server is given the *full* set of directories the local node may store +models in (the writable ``EXO_MODELS_DIRS`` plus any read-only mounts under +``EXO_MODELS_READ_ONLY_DIRS``) so that peers can fetch any locally-resident +model regardless of which directory the downloader picked. Restricting the +server to a single hard-coded directory would silently disable the peer +download path whenever ``select_download_dir_for_shard`` placed the model +in a non-default directory (custom path, low-disk fallback, or a read-only +mount). """ import json +from collections.abc import Sequence from pathlib import Path from typing import TypeAlias, cast @@ -24,15 +34,17 @@ class PeerFileServer: """HTTP server that exposes local model files for peer download.""" - def __init__(self, host: str, port: int, models_dir: Path) -> None: + def __init__(self, host: str, port: int, models_dirs: Sequence[Path]) -> None: + if not models_dirs: + raise ValueError("PeerFileServer requires at least one models directory") self.host = host self.port = port - self.models_dir = models_dir + # Preserve caller order so callers can prefer writable dirs over + # read-only dirs without us re-sorting them. + self.models_dirs: tuple[Path, ...] = tuple(models_dirs) self._app = web.Application() self._app.router.add_get("/status/{model_id}", self._handle_status) - self._app.router.add_get( - "/files/{model_id}/{file_path:.+}", self._handle_file - ) + self._app.router.add_get("/files/{model_id}/{file_path:.+}", self._handle_file) self._app.router.add_get("/health", self._handle_health) self._runner: web.AppRunner | None = None @@ -53,11 +65,9 @@ async def _handle_health(self, request: web.Request) -> web.Response: async def _handle_status(self, request: web.Request) -> web.Response: """Return status of all files for a model (complete + in-progress).""" model_id = request.match_info["model_id"] - model_dir = _resolve_child(self.models_dir, model_id) + model_dir = await self._locate_model_dir(model_id) if model_dir is None: - return web.Response(status=404, text="Model not found") - - if not await aios.path.exists(model_dir): + # No matching directory containing this model. return web.json_response({"files": []}) files: list[dict[str, object]] = [] @@ -105,7 +115,7 @@ async def _handle_file(self, request: web.Request) -> web.StreamResponse: model_id = request.match_info["model_id"] file_path = request.match_info["file_path"] - model_dir = _resolve_child(self.models_dir, model_id) + model_dir = await self._locate_model_dir(model_id) if model_dir is None: return web.Response(status=404, text="Model not found") @@ -176,6 +186,22 @@ async def _handle_file(self, request: web.Request) -> web.StreamResponse: await response.write_eof() return response + async def _locate_model_dir(self, model_id: str) -> Path | None: + """Return the first configured directory that contains ``model_id``. + + Each candidate root is path-traversal-checked independently before we + probe the filesystem. We prefer the first directory in ``models_dirs`` + that has a matching subdirectory; this preserves caller-specified + priority (e.g. writable before read-only) without re-sorting. + """ + for root in self.models_dirs: + candidate = _resolve_child(root, model_id) + if candidate is None: + continue + if await aios.path.exists(candidate): + return candidate + return None + def _resolve_child(root: Path, relative_path: str) -> Path | None: """Resolve relative_path under root, rejecting path traversal.""" diff --git a/src/exo/download/peer_state.py b/src/exo/download/peer_state.py index bfcabc58ed..8696a50b9f 100644 --- a/src/exo/download/peer_state.py +++ b/src/exo/download/peer_state.py @@ -83,44 +83,47 @@ def _resolve_peer_endpoint( peer_download_port: int, status: str, ) -> PeerEndpoint | None: - """Resolve a peer's IP address and connection type from the topology.""" + """Resolve a peer's IP address and connection type from the topology. + + Iteration order over ``out_edges`` is not guaranteed to surface RDMA + edges before socket edges, so we scan the full edge set once: any + RDMA edge wins (we use the peer's socket address for the actual TCP + connect since RDMA edges don't carry routable IPs), and only when no + RDMA edge exists do we fall back to the socket endpoint. Returning + on the first non-RDMA hit would otherwise mislabel peers as + ``socket`` whenever the socket edge happens to be visited first. + """ try: - # Check for RDMA connections first (highest priority) - for conn in state.topology.out_edges(node_id): - if conn.sink != peer_node_id: - continue - if isinstance(conn.edge, RDMAConnection): - # RDMA peer — still need IP from a socket connection - ip = _find_socket_ip(node_id, peer_node_id, state) - if ip: - return PeerEndpoint( - node_id=peer_node_id, - ip=ip, - port=peer_download_port, - status=status, - connection_type="rdma", - ) - else: - return PeerEndpoint( - node_id=peer_node_id, - ip=conn.edge.sink_multiaddr.ip_address, - port=peer_download_port, - status=status, - connection_type="socket", - ) + edges = [ + conn + for conn in state.topology.out_edges(node_id) + if conn.sink == peer_node_id + ] + has_rdma = any(isinstance(conn.edge, RDMAConnection) for conn in edges) + socket_ip = next( + ( + conn.edge.sink_multiaddr.ip_address + for conn in edges + if isinstance(conn.edge, SocketConnection) + ), + None, + ) + if has_rdma and socket_ip: + return PeerEndpoint( + node_id=peer_node_id, + ip=socket_ip, + port=peer_download_port, + status=status, + connection_type="rdma", + ) + if socket_ip: + return PeerEndpoint( + node_id=peer_node_id, + ip=socket_ip, + port=peer_download_port, + status=status, + connection_type="socket", + ) except Exception as e: logger.debug(f"Could not resolve endpoint for peer {peer_node_id}: {e}") return None - - -def _find_socket_ip( - node_id: NodeId, peer_node_id: NodeId, state: State -) -> str | None: - """Find a socket connection IP for a peer (used as fallback for RDMA peers).""" - try: - for conn in state.topology.out_edges(node_id): - if conn.sink == peer_node_id and isinstance(conn.edge, SocketConnection): - return conn.edge.sink_multiaddr.ip_address - except Exception: - pass - return None diff --git a/src/exo/download/tests/test_peer_download.py b/src/exo/download/tests/test_peer_download.py index 6f0d2dc9b7..0a774ec365 100644 --- a/src/exo/download/tests/test_peer_download.py +++ b/src/exo/download/tests/test_peer_download.py @@ -32,7 +32,7 @@ async def temp_models_dir(tmp_path: Path) -> AsyncIterator[Path]: @pytest.fixture async def peer_server(temp_models_dir: Path) -> AsyncIterator[PeerFileServer]: """Start a PeerFileServer on a random port for testing.""" - server = PeerFileServer(host="127.0.0.1", port=0, models_dir=temp_models_dir) + server = PeerFileServer(host="127.0.0.1", port=0, models_dirs=[temp_models_dir]) # Use port 0 to let OS assign a free port from aiohttp import web @@ -264,6 +264,78 @@ async def test_file_not_found(self, peer_server: PeerFileServer) -> None: assert r.status == 404 +class TestPeerFileServerMultipleDirectories: + """The peer file server must look for the model in *every* configured + models directory. Otherwise a node that lands a model in a non-default + writable directory (custom path, low-disk fallback, or read-only mount) + would silently fail to advertise it to peers and force them back onto + HuggingFace -- defeating the whole peer download path. + """ + + async def test_serves_model_from_secondary_writable_dir( + self, tmp_path: Path + ) -> None: + primary = tmp_path / "primary" + secondary = tmp_path / "secondary" + await aios.makedirs(primary, exist_ok=True) + await aios.makedirs(secondary, exist_ok=True) + + model_dir = secondary / "test--model" + await aios.makedirs(model_dir, exist_ok=True) + async with aiofiles.open(model_dir / "config.json", "wb") as f: + await f.write(b'{"hello":"world"}') + + server = PeerFileServer( + host="127.0.0.1", port=0, models_dirs=[primary, secondary] + ) + + from aiohttp import web + + server._runner = web.AppRunner(server._app) + await server._runner.setup() + site = web.TCPSite(server._runner, "127.0.0.1", 0) + await site.start() + port_int: int = cast(int, site._server.sockets[0].getsockname()[1]) # type: ignore[union-attr] + server.port = port_int + try: + files = await get_peer_file_status("127.0.0.1", port_int, "test--model") + assert files is not None + assert {f.path for f in files} == {"config.json"} + finally: + await server.shutdown() + + async def test_serves_model_from_read_only_mount(self, tmp_path: Path) -> None: + writable = tmp_path / "writable" + read_only = tmp_path / "ro_mount" + await aios.makedirs(writable, exist_ok=True) + await aios.makedirs(read_only / "ro--model", exist_ok=True) + async with aiofiles.open(read_only / "ro--model" / "config.json", "wb") as f: + await f.write(b"{}") + + server = PeerFileServer( + host="127.0.0.1", port=0, models_dirs=[writable, read_only] + ) + + from aiohttp import web + + server._runner = web.AppRunner(server._app) + await server._runner.setup() + site = web.TCPSite(server._runner, "127.0.0.1", 0) + await site.start() + port_int: int = cast(int, site._server.sockets[0].getsockname()[1]) # type: ignore[union-attr] + server.port = port_int + try: + files = await get_peer_file_status("127.0.0.1", port_int, "ro--model") + assert files is not None + assert {f.path for f in files} == {"config.json"} + finally: + await server.shutdown() + + async def test_constructor_rejects_empty_directory_list(self) -> None: + with pytest.raises(ValueError, match="at least one models directory"): + PeerFileServer(host="127.0.0.1", port=0, models_dirs=[]) + + class TestPeerDownloadClient: """Tests for downloading files from a peer server.""" diff --git a/src/exo/download/tests/test_peer_state.py b/src/exo/download/tests/test_peer_state.py new file mode 100644 index 0000000000..570692373d --- /dev/null +++ b/src/exo/download/tests/test_peer_state.py @@ -0,0 +1,142 @@ +"""Regression tests for ``exo.download.peer_state``. + +These exercise the topology-iteration ordering that decides whether a peer +is reachable over RDMA or merely via socket. The original implementation +returned on the first edge whose type happened to be visited first, which +mislabelled peers when ``out_edges`` yielded the socket edge before the +RDMA edge. We now scan all edges and prefer RDMA whenever any RDMA edge +exists for that peer. +""" + +from collections.abc import Iterable +from pathlib import Path +from typing import cast + +from exo.download.peer_state import discover_peers_for_model +from exo.shared.models.model_cards import ModelCard, ModelId, ModelTask +from exo.shared.topology import Topology +from exo.shared.types.common import NodeId +from exo.shared.types.memory import Memory +from exo.shared.types.multiaddr import Multiaddr +from exo.shared.types.state import State +from exo.shared.types.topology import ( + Connection, + RDMAConnection, + SocketConnection, +) +from exo.shared.types.worker.downloads import DownloadCompleted, DownloadProgress +from exo.shared.types.worker.shards import PipelineShardMetadata, ShardMetadata + +LOCAL = NodeId("aaaaaaaa-aaaa-4aaa-8aaa-aaaaaaaaaaaa") +PEER = NodeId("bbbbbbbb-bbbb-4bbb-8bbb-bbbbbbbbbbbb") +MODEL_ID = ModelId("test-org/test-model") +NORMALIZED = MODEL_ID.normalize() + + +def _make_shard() -> ShardMetadata: + return PipelineShardMetadata( + model_card=ModelCard( + model_id=MODEL_ID, + storage_size=Memory.from_mb(100), + n_layers=4, + hidden_size=64, + supports_tensor=False, + tasks=[ModelTask.TextGeneration], + ), + device_rank=0, + world_size=1, + start_layer=0, + end_layer=4, + n_layers=4, + ) + + +def _build_topology(edges: Iterable[Connection]) -> Topology: + topology = Topology() + topology.add_node(LOCAL) + topology.add_node(PEER) + for conn in edges: + topology.add_connection(conn) + return topology + + +def _state_with_completed_peer(topology: Topology) -> State: + completed = DownloadCompleted( + node_id=PEER, + shard_metadata=_make_shard(), + total=Memory.from_mb(100), + model_directory=str(Path("/fake/models/test-org--test-model")), + ) + return State( + downloads={PEER: [cast(DownloadProgress, completed)]}, + topology=topology, + ) + + +def _socket_edge() -> Connection: + return Connection( + source=LOCAL, + sink=PEER, + edge=SocketConnection( + sink_multiaddr=Multiaddr(address="/ip4/10.0.0.2/tcp/4001") + ), + ) + + +def _rdma_edge() -> Connection: + return Connection( + source=LOCAL, + sink=PEER, + edge=RDMAConnection(source_rdma_iface="bridge0", sink_rdma_iface="bridge0"), + ) + + +def test_peer_marked_rdma_when_socket_edge_inserted_first() -> None: + """If both an RDMA edge and a socket edge exist for the same peer, the + peer must be reported as RDMA *regardless of insertion order*. The + original implementation returned on the first edge it saw, so a socket + edge inserted before the RDMA edge silently downgraded a real RDMA peer + to ``socket`` and broke the "RDMA first" ordering used by the peer + downloader. + """ + topology = _build_topology([_socket_edge(), _rdma_edge()]) + state = _state_with_completed_peer(topology) + + peers = discover_peers_for_model(LOCAL, state, NORMALIZED, peer_download_port=52416) + + assert len(peers) == 1 + assert peers[0].connection_type == "rdma" + assert peers[0].ip == "10.0.0.2" + + +def test_peer_marked_rdma_when_rdma_edge_inserted_first() -> None: + topology = _build_topology([_rdma_edge(), _socket_edge()]) + state = _state_with_completed_peer(topology) + + peers = discover_peers_for_model(LOCAL, state, NORMALIZED, peer_download_port=52416) + + assert len(peers) == 1 + assert peers[0].connection_type == "rdma" + + +def test_peer_marked_socket_when_no_rdma_edge_exists() -> None: + topology = _build_topology([_socket_edge()]) + state = _state_with_completed_peer(topology) + + peers = discover_peers_for_model(LOCAL, state, NORMALIZED, peer_download_port=52416) + + assert len(peers) == 1 + assert peers[0].connection_type == "socket" + assert peers[0].ip == "10.0.0.2" + + +def test_peer_skipped_when_only_rdma_edge_has_no_socket_companion() -> None: + """An RDMA-only peer cannot be contacted over the peer-download HTTP + server, so we must omit it rather than fabricate a missing IP. + """ + topology = _build_topology([_rdma_edge()]) + state = _state_with_completed_peer(topology) + + peers = discover_peers_for_model(LOCAL, state, NORMALIZED, peer_download_port=52416) + + assert peers == [] diff --git a/src/exo/main.py b/src/exo/main.py index 24e176e4a0..70ad935642 100644 --- a/src/exo/main.py +++ b/src/exo/main.py @@ -22,7 +22,12 @@ from exo.master.main import Master from exo.routing.event_router import EventRouter from exo.routing.router import Router, get_node_id_keypair -from exo.shared.constants import EXO_DEFAULT_MODELS_DIR, EXO_LOG, EXO_PEER_DOWNLOAD_PORT +from exo.shared.constants import ( + EXO_LOG, + EXO_MODELS_DIRS, + EXO_MODELS_READ_ONLY_DIRS, + EXO_PEER_DOWNLOAD_PORT, +) from exo.shared.election import Election, ElectionResult from exo.shared.logging import logger_cleanup, logger_set_context, logger_setup from exo.shared.types.common import NodeId, SessionId @@ -114,10 +119,17 @@ async def create(cls, args: "Args") -> Self: worker = None if peer_download_enabled: + # Serve from every configured model directory so peers can fetch + # any locally-resident shard regardless of which directory the + # downloader landed it in. ``EXO_MODELS_DIRS`` already includes + # ``EXO_DEFAULT_MODELS_DIR`` as its first entry; ``EXO_MODELS_READ_ONLY_DIRS`` + # captures pre-populated mounts (e.g. shared NFS caches) that + # ``select_download_dir_for_shard`` excludes from new writes but + # which other peers still benefit from being able to read. peer_file_server = PeerFileServer( host="0.0.0.0", port=args.peer_download_port, - models_dir=EXO_DEFAULT_MODELS_DIR, + models_dirs=(*EXO_MODELS_DIRS, *EXO_MODELS_READ_ONLY_DIRS), ) if not args.no_downloads: From 09420bd27af143f653de1e76d9126db5412b73fb Mon Sep 17 00:00:00 2001 From: jw-wcv <101585096+jw-wcv@users.noreply.github.com> Date: Sat, 9 May 2026 01:11:17 -0700 Subject: [PATCH 20/29] Reject oversized peer partials; relocate node-ID keypair to config dir Two Codex round-(N+1) P1 fixes for PR #16: 1. peer_download.py:120 -- "Reject oversized stale partials before peer resume". The resume loop ran ``while n_read < expected_size``, so any pre-existing ``.partial`` larger than ``expected_size`` skipped the loop entirely and the final ``rename`` accepted the bad bytes as the canonical download. In offline mode (where peer download integrity check is intentionally skipped) this would permanently poison the model cache. Discard the stale oversized partial up front and restart from zero on this peer. 2. constants.py:73 -- "Keep node identity keypair out of cache storage". ``EXO_NODE_ID_KEYPAIR`` was rooted at ``EXO_CACHE_HOME``, which is subject to normal cache cleanup (e.g. ``trash ~/.cache/exo``); a wipe silently regenerated a fresh peer ID and broke cluster membership / mDNS routes. Move it to ``EXO_CONFIG_HOME`` (matching the existing ``test_node_id_in_config_dir`` invariant) and add a one-shot transparent migrator in ``router._migrate_legacy_node_id_keypair`` so existing nodes retain identity after the upgrade. Regression tests: * ``test_oversized_stale_partial_is_discarded_and_retransferred`` -- pre-fix would rename the junk ``.partial`` as the downloaded weights file; post-fix the file is re-fetched and matches the canonical bytes. * ``test_legacy_keypair_is_migrated_to_new_location``, ``test_migration_is_idempotent_when_new_location_already_present``, ``test_migration_skipped_when_no_legacy_file``, ``test_get_node_id_keypair_uses_migrated_legacy_keypair``. X-Orchestraitor-Plan: ecosystem_repo_standardization_aeee88ba X-Agent-Platform: cursor-claude-opus-4.7 --- src/exo/download/peer_download.py | 47 ++++++-- src/exo/download/tests/test_peer_download.py | 51 +++++++++ src/exo/routing/router.py | 76 +++++++++++-- .../routing/tests/test_node_id_migration.py | 102 ++++++++++++++++++ src/exo/shared/constants.py | 14 ++- 5 files changed, 271 insertions(+), 19 deletions(-) create mode 100644 src/exo/routing/tests/test_node_id_migration.py diff --git a/src/exo/download/peer_download.py b/src/exo/download/peer_download.py index 0800a6fd18..67ecbb7810 100644 --- a/src/exo/download/peer_download.py +++ b/src/exo/download/peer_download.py @@ -42,9 +42,12 @@ async def get_peer_file_status( """ url = f"http://{peer_host}:{peer_port}/status/{model_id_normalized}" try: - async with aiohttp.ClientSession( - timeout=aiohttp.ClientTimeout(total=timeout) - ) as session, session.get(url) as r: + async with ( + aiohttp.ClientSession( + timeout=aiohttp.ClientTimeout(total=timeout) + ) as session, + session.get(url) as r, + ): if r.status != 200: return None data = cast(dict[str, object], await r.json()) @@ -109,9 +112,30 @@ async def download_file_from_peer( url = f"http://{peer_host}:{peer_port}/files/{model_id_normalized}/{file_path}" n_read = 0 - # Resume from existing partial + # Resume from existing partial. + # + # Codex P1 (PR #16 round 5): a stale ``.partial`` left over from a + # previous run can be larger than ``expected_size`` (e.g. the peer + # was serving the wrong revision, the on-disk file was truncated + # to a different blob, or the user manually replaced it). In that + # case ``n_read >= expected_size`` skips the resume loop entirely + # and we'd then ``rename`` a too-large file as the "successful" + # result. With offline mode we explicitly skip hash verification, + # so the bad bytes would never get caught downstream and would + # poison the model cache. Fail fast: drop the stale partial and + # restart from zero on this peer. if await aios.path.exists(partial_path): - n_read = (await aios.stat(partial_path)).st_size + existing_size = (await aios.stat(partial_path)).st_size + if existing_size > expected_size: + logger.warning( + f"Discarding stale oversized peer partial for {file_path} " + f"({existing_size} > expected {expected_size}); " + "restarting download from zero" + ) + await aios.remove(partial_path) + n_read = 0 + else: + n_read = existing_size poll_count = 0 chunk_size = 8 * 1024 * 1024 # 8MB, matching HF download @@ -123,9 +147,12 @@ async def download_file_from_peer( headers["Range"] = f"bytes={n_read}-" got_bytes = False - async with aiohttp.ClientSession( - timeout=aiohttp.ClientTimeout(total=300, sock_read=60) - ) as session, session.get(url, headers=headers) as r: + async with ( + aiohttp.ClientSession( + timeout=aiohttp.ClientTimeout(total=300, sock_read=60) + ) as session, + session.get(url, headers=headers) as r, + ): if r.status == 416: # Range not satisfiable - peer doesn't have more yet pass @@ -142,9 +169,7 @@ async def download_file_from_peer( got_bytes = True on_progress(n_read, expected_size, False) elif r.status == 404: - logger.debug( - f"File {file_path} not found on peer {peer_host}" - ) + logger.debug(f"File {file_path} not found on peer {peer_host}") return None else: logger.warning( diff --git a/src/exo/download/tests/test_peer_download.py b/src/exo/download/tests/test_peer_download.py index 0a774ec365..f7046c87d2 100644 --- a/src/exo/download/tests/test_peer_download.py +++ b/src/exo/download/tests/test_peer_download.py @@ -410,6 +410,57 @@ async def test_download_returns_none_on_unreachable_peer( ) assert result is None + async def test_oversized_stale_partial_is_discarded_and_retransferred( + self, peer_server: PeerFileServer, temp_models_dir: Path, tmp_path: Path + ) -> None: + """Codex P1 (PR #16 round 5): a stale ``.partial`` larger than + ``expected_size`` left over from a previous run must be + rejected, NOT silently renamed as the successful download. + + Pre-fix the resume loop ran ``while n_read < expected_size``, + so an oversized partial skipped the loop entirely and the + final ``rename`` accepted bad bytes. In offline mode (where + hash verification is intentionally skipped) this would + permanently poison the model cache without any warning. + Post-fix the oversized partial is discarded and the file is + re-fetched from the peer. + """ + model_dir = temp_models_dir / "test--model" + await aios.makedirs(model_dir, exist_ok=True) + canonical = b"the canonical model weights" + async with aiofiles.open(model_dir / "weights.bin", "wb") as f: + await f.write(canonical) + + download_dir = tmp_path / "downloads" / "test--model" + await aios.makedirs(download_dir, exist_ok=True) + # Stale partial from a "previous run" -- bigger than the + # canonical file and full of junk bytes. Pre-fix, this would + # be the file that ended up renamed as ``weights.bin``. + stale_partial = download_dir / "weights.bin.partial" + stale_bytes = b"\xde\xad\xbe\xef" * (len(canonical) * 2) + async with aiofiles.open(stale_partial, "wb") as f: + await f.write(stale_bytes) + assert (await aios.stat(stale_partial)).st_size > len(canonical) + + result = await download_file_from_peer( + "127.0.0.1", + peer_server.port, + "test--model", + "weights.bin", + download_dir, + len(canonical), + ) + + assert result is not None + assert result == download_dir / "weights.bin" + async with aiofiles.open(result, "rb") as f: + downloaded = await f.read() + assert downloaded == canonical, ( + "stale oversized partial must NOT be accepted as the " + "downloaded file; the fix must redownload from the peer" + ) + assert not stale_partial.exists() + async def test_skip_already_complete( self, peer_server: PeerFileServer, temp_models_dir: Path, tmp_path: Path ) -> None: diff --git a/src/exo/routing/router.py b/src/exo/routing/router.py index 5b679fe192..70423a88e6 100644 --- a/src/exo/routing/router.py +++ b/src/exo/routing/router.py @@ -24,7 +24,7 @@ from filelock import FileLock from loguru import logger -from exo.shared.constants import EXO_NODE_ID_KEYPAIR +from exo.shared.constants import EXO_LEGACY_NODE_ID_KEYPAIR, EXO_NODE_ID_KEYPAIR from exo.utils.channels import Receiver, Sender, channel from exo.utils.pydantic_ext import FrozenModel from exo.utils.task_group import TaskGroup @@ -293,18 +293,37 @@ def _clear_publish_failures(self, topic: str) -> None: def get_node_id_keypair( path: str | bytes | PathLike[str] | PathLike[bytes] = EXO_NODE_ID_KEYPAIR, + legacy_path: str | bytes | PathLike[str] | PathLike[bytes] | None = ( + EXO_LEGACY_NODE_ID_KEYPAIR + ), ) -> Keypair: """ Obtains the :class:`Keypair` associated with this node-ID. Obtain the :class:`PeerId` by from it. + + On first call after the upgrade, if the new ``path`` (config dir) + has no keypair yet but the legacy cache-dir ``legacy_path`` does, + the legacy file is moved to ``path`` so the node retains its + identity across the relocation. Migration is best-effort: if + moving fails (e.g. cross-device link errors on Linux when + ``XDG_*`` dirs span filesystems), the legacy bytes are copied + instead. Either way, the legacy file is removed once the new + location holds a valid keypair so subsequent calls do not need + to re-check. """ + resolved_path = Path(str(path)) + + def lock_path(p: str | bytes | PathLike[str] | PathLike[bytes]) -> Path: + return Path(str(p) + ".lock") + + resolved_path.parent.mkdir(parents=True, exist_ok=True) - def lock_path(path: str | bytes | PathLike[str] | PathLike[bytes]) -> Path: - return Path(str(path) + ".lock") + if legacy_path is not None: + _migrate_legacy_node_id_keypair(resolved_path, Path(str(legacy_path))) # operate with cross-process lock to avoid race conditions - with FileLock(lock_path(path)): - with open(path, "a+b") as f: # opens in append-mode => starts at EOF + with FileLock(lock_path(resolved_path)): + with open(resolved_path, "a+b") as f: # opens in append-mode => starts at EOF # if non-zero EOF, then file exists => use to get node-ID if f.tell() != 0: f.seek(0) # go to start & read protobuf-encoded bytes @@ -316,7 +335,52 @@ def lock_path(path: str | bytes | PathLike[str] | PathLike[bytes]) -> Path: logger.warning(f"Encountered error when trying to get keypair: {e}") # if no valid credentials, create new ones and persist - with open(path, "w+b") as f: + with open(resolved_path, "w+b") as f: keypair = Keypair.generate() f.write(keypair.to_bytes()) return keypair + + +def _migrate_legacy_node_id_keypair( + new_path: Path, + legacy_path: Path, +) -> None: + """One-shot migrator for the cache→config relocation of the + node-ID keypair (Codex P1 PR #16 round 5). + + Idempotent and best-effort: only acts when ``new_path`` is + absent and ``legacy_path`` exists. Falls back to byte copy if + ``rename`` fails (cross-device, permissions, etc.). On any + exception we log and bail -- the caller will then generate a + fresh keypair, which is suboptimal but better than crashing + startup over identity-file housekeeping. + """ + try: + if new_path.exists() or not legacy_path.exists(): + return + # Ensure the destination directory exists for either the + # ``replace`` (which silently no-ops on missing parent on some + # platforms but raises ``ENOENT`` on others) or the byte-copy + # fallback. ``get_node_id_keypair`` already creates this dir + # for the same reason; doing it again here keeps the migrator + # safely callable from tests in isolation. + new_path.parent.mkdir(parents=True, exist_ok=True) + try: + legacy_path.replace(new_path) + except OSError as rename_err: + logger.debug( + f"Cross-device rename of legacy keypair failed ({rename_err}); " + "falling back to byte copy." + ) + new_path.write_bytes(legacy_path.read_bytes()) + legacy_path.unlink(missing_ok=True) + logger.info( + f"Migrated node-ID keypair from legacy cache path {legacy_path} " + f"to persistent config path {new_path}." + ) + except Exception as e: + logger.warning( + f"Failed to migrate legacy node-ID keypair from {legacy_path} " + f"to {new_path}: {e}. The node will generate a new identity; " + "manually copy the file if cluster membership matters." + ) diff --git a/src/exo/routing/tests/test_node_id_migration.py b/src/exo/routing/tests/test_node_id_migration.py new file mode 100644 index 0000000000..2eff238b4d --- /dev/null +++ b/src/exo/routing/tests/test_node_id_migration.py @@ -0,0 +1,102 @@ +"""Regression tests for the cache→config migration of the node-ID +keypair (Codex P1, PR #16 round 5). + +The keypair used to live under ``EXO_CACHE_HOME``, which is subject +to normal cache cleanup (e.g. ``trash ~/.cache/exo``) and would +silently regenerate a new node-ID. The fix relocates the keypair to +``EXO_CONFIG_HOME`` and migrates legacy files transparently. +""" + +from __future__ import annotations + +from pathlib import Path + +from exo_pyo3_bindings import Keypair + +from exo.routing.router import ( + _migrate_legacy_node_id_keypair, # pyright: ignore[reportPrivateUsage] + get_node_id_keypair, +) + + +def test_legacy_keypair_is_migrated_to_new_location(tmp_path: Path) -> None: + """Legacy cache-dir keypair must be moved to the new config-dir + location and the legacy file removed -- so the node retains its + identity across the upgrade and a future cache wipe doesn't + resurrect a stale copy.""" + legacy_path = tmp_path / "cache" / "node_id.keypair" + new_path = tmp_path / "config" / "node_id.keypair" + legacy_path.parent.mkdir(parents=True) + + keypair = Keypair.generate() + legacy_bytes = keypair.to_bytes() + legacy_path.write_bytes(legacy_bytes) + + _migrate_legacy_node_id_keypair(new_path, legacy_path) + + assert new_path.exists(), "migration must place keypair at new location" + assert new_path.read_bytes() == legacy_bytes, ( + "migration must preserve the byte-for-byte keypair contents " + "so the node retains its peer ID" + ) + assert not legacy_path.exists(), ( + "migration must remove the legacy file once the new location " + "holds the keypair, otherwise a later cache wipe could " + "resurrect a now-stale copy" + ) + + +def test_migration_is_idempotent_when_new_location_already_present( + tmp_path: Path, +) -> None: + """If the new location already has a keypair, migration must be + a no-op even when a legacy file exists -- otherwise we'd + overwrite the (canonical) new keypair with a stale legacy one.""" + legacy_path = tmp_path / "cache" / "node_id.keypair" + new_path = tmp_path / "config" / "node_id.keypair" + legacy_path.parent.mkdir(parents=True) + new_path.parent.mkdir(parents=True) + + canonical = Keypair.generate().to_bytes() + legacy = Keypair.generate().to_bytes() + new_path.write_bytes(canonical) + legacy_path.write_bytes(legacy) + + _migrate_legacy_node_id_keypair(new_path, legacy_path) + + assert new_path.read_bytes() == canonical, ( + "migration must NOT overwrite an existing new-location keypair" + ) + # We deliberately leave the legacy file alone in this branch: + # touching it would surprise an operator who is intentionally + # keeping both copies during an upgrade window. + assert legacy_path.exists() + + +def test_migration_skipped_when_no_legacy_file(tmp_path: Path) -> None: + """Fresh installs must not error when the legacy path is absent.""" + new_path = tmp_path / "config" / "node_id.keypair" + new_path.parent.mkdir(parents=True) + + _migrate_legacy_node_id_keypair(new_path, tmp_path / "missing.keypair") + + assert not new_path.exists() + + +def test_get_node_id_keypair_uses_migrated_legacy_keypair(tmp_path: Path) -> None: + """End-to-end: ``get_node_id_keypair`` must surface the legacy + keypair bytes when only the legacy path holds a valid file at + call time, completing the cache→config migration on first use.""" + legacy_path = tmp_path / "cache" / "node_id.keypair" + new_path = tmp_path / "config" / "node_id.keypair" + legacy_path.parent.mkdir(parents=True) + + keypair = Keypair.generate() + expected_bytes = keypair.to_bytes() + legacy_path.write_bytes(expected_bytes) + + loaded = get_node_id_keypair(path=new_path, legacy_path=legacy_path) + + assert loaded.to_bytes() == expected_bytes + assert new_path.exists() + assert not legacy_path.exists() diff --git a/src/exo/shared/constants.py b/src/exo/shared/constants.py index f823195798..89bc512cab 100644 --- a/src/exo/shared/constants.py +++ b/src/exo/shared/constants.py @@ -69,8 +69,18 @@ def _parse_colon_dirs(env_var: str) -> tuple[Path, ...]: EXO_LOG_DIR = EXO_CACHE_HOME / "exo_log" EXO_LOG = EXO_LOG_DIR / "exo.log" -# Identity (config) -EXO_NODE_ID_KEYPAIR = EXO_CACHE_HOME / "node_id.keypair" +# Identity (config -- persistent across cache eviction). +# +# Codex P1 (PR #16 round 5): keeping the node-ID keypair under +# ``EXO_CACHE_HOME`` makes cluster identity vulnerable to normal +# cache cleanup, which causes nodes to come up with a new peer ID +# after a cache wipe and breaks the intended persistence of cluster +# membership / mDNS routes. Identity material lives under +# ``EXO_CONFIG_HOME`` instead. The legacy cache path is migrated +# on first use by ``get_node_id_keypair`` to preserve existing +# identity across the upgrade. +EXO_NODE_ID_KEYPAIR = EXO_CONFIG_HOME / "node_id.keypair" +EXO_LEGACY_NODE_ID_KEYPAIR = EXO_CACHE_HOME / "node_id.keypair" EXO_CONFIG_FILE = EXO_CONFIG_HOME / "config.toml" # libp2p topics for event forwarding From 755cce7541eacf747e3745cf7ffa8337fa8055cd Mon Sep 17 00:00:00 2001 From: jw-wcv <101585096+jw-wcv@users.noreply.github.com> Date: Sat, 9 May 2026 01:55:21 -0700 Subject: [PATCH 21/29] Scope node-ID keypair per process and migrate inside the file lock Codex P1/P2 (PR #16 round-(N+2), router.py:297, router.py:322): the same-host multi-node workflow this PR introduces (distinct peer-download ports per process) requires distinct NodeIds per process so peer-discovery's self-skip and routing's unique-NodeId invariants hold. Concurrent startups also raced on the legacy cache->config keypair migration because it ran before FileLock. - Add ``process_scope`` parameter to ``get_node_id_keypair`` that is folded into the on-disk filename (``node_id..keypair``). Single-process deployments default to None and keep the existing shared file; ``main.py`` passes ``args.peer_download_port`` so multi-process same-host runs land on distinct identities. - Move the legacy migration call inside the ``FileLock`` so two processes can't both pass the existence check and race into divergent in-memory vs. on-disk identities. - Legacy file remains unscoped: the first process to migrate adopts the operator's existing identity; later processes (other scopes) start with fresh keypairs, which is exactly what per-process isolation requires. - Add regression tests covering distinct scopes producing distinct keypairs, scope stability across calls, scoped legacy adoption, and a structural check that the migration runs inside the FileLock. --- src/exo/main.py | 14 +- src/exo/routing/router.py | 64 ++++++++- .../routing/tests/test_node_id_migration.py | 135 ++++++++++++++++++ 3 files changed, 206 insertions(+), 7 deletions(-) diff --git a/src/exo/main.py b/src/exo/main.py index 70ad935642..7922b5664e 100644 --- a/src/exo/main.py +++ b/src/exo/main.py @@ -59,7 +59,19 @@ class Node: @classmethod async def create(cls, args: "Args") -> Self: - keypair = get_node_id_keypair() + # Codex P1 (PR #16 round-(N+2), router.py:297): scope the + # on-disk node-ID keypair by ``--peer-download-port``. That + # port already MUST differ between processes on the same + # host (see the ``--peer-download-port`` help text), making + # it the natural per-process disambiguator. Single-process + # deployments use the default port and therefore land on a + # stable scoped filename, preserving identity across + # restarts; multi-process same-host deployments get + # distinct keypair files (and therefore distinct + # ``NodeId``s) so peer-discovery's ``peer_node_id == + # node_id`` self-skip and routing's unique-NodeId + # invariants continue to hold. + keypair = get_node_id_keypair(process_scope=args.peer_download_port) node_id = NodeId(keypair.to_node_id()) session_id = SessionId(master_node_id=node_id, election_clock=0) router = Router.create( diff --git a/src/exo/routing/router.py b/src/exo/routing/router.py index 70423a88e6..1319290f4d 100644 --- a/src/exo/routing/router.py +++ b/src/exo/routing/router.py @@ -296,11 +296,24 @@ def get_node_id_keypair( legacy_path: str | bytes | PathLike[str] | PathLike[bytes] | None = ( EXO_LEGACY_NODE_ID_KEYPAIR ), + process_scope: int | str | None = None, ) -> Keypair: """ Obtains the :class:`Keypair` associated with this node-ID. Obtain the :class:`PeerId` by from it. + Codex P1 (PR #16 round-(N+2), router.py:297): when ``process_scope`` + is provided, the on-disk keypair filename is suffixed with the + scope (typically the libp2p / peer-download port the caller has + chosen). This preserves *per-process* node identity isolation + when multiple exo processes run on the same host -- the new + same-host multi-node workflow added in this PR (distinct + peer-download ports per process) needs each process to have a + distinct ``NodeId`` so peer discovery's ``peer_node_id == + node_id`` self-skip and routing's unique-node-id assumptions + hold. Single-process deployments leave ``process_scope=None`` + and continue using the shared persistent keypair file. + On first call after the upgrade, if the new ``path`` (config dir) has no keypair yet but the legacy cache-dir ``legacy_path`` does, the legacy file is moved to ``path`` so the node retains its @@ -309,20 +322,42 @@ def get_node_id_keypair( ``XDG_*`` dirs span filesystems), the legacy bytes are copied instead. Either way, the legacy file is removed once the new location holds a valid keypair so subsequent calls do not need - to re-check. + to re-check. Codex P2 (PR #16 round-(N+2), router.py:322): the + migration is performed INSIDE the file lock so two concurrent + processes can't both pass the existence check and then race + each other into divergent in-memory vs. on-disk identities. """ - resolved_path = Path(str(path)) + base_path = Path(str(path)) + resolved_path = ( + _scoped_keypair_path(base_path, process_scope) + if process_scope is not None + else base_path + ) + + # The legacy cache file pre-dates the per-process scoping change + # so it is intentionally NOT scope-suffixed. We migrate it as a + # one-shot identity adoption for whichever process happens to + # boot first; subsequent processes (with different scopes) will + # observe the legacy file already gone and start with fresh + # keypairs, which is exactly what per-process isolation requires. + resolved_legacy: Path | None = ( + Path(str(legacy_path)) if legacy_path is not None else None + ) def lock_path(p: str | bytes | PathLike[str] | PathLike[bytes]) -> Path: return Path(str(p) + ".lock") resolved_path.parent.mkdir(parents=True, exist_ok=True) - if legacy_path is not None: - _migrate_legacy_node_id_keypair(resolved_path, Path(str(legacy_path))) - - # operate with cross-process lock to avoid race conditions + # operate with cross-process lock to avoid race conditions. + # The migration MUST run inside this lock so two processes that + # boot simultaneously can't both pass the migrator's existence + # check, race the keypair generation, and end up with the same + # on-disk file but divergent in-memory identities. with FileLock(lock_path(resolved_path)): + if resolved_legacy is not None: + _migrate_legacy_node_id_keypair(resolved_path, resolved_legacy) + with open(resolved_path, "a+b") as f: # opens in append-mode => starts at EOF # if non-zero EOF, then file exists => use to get node-ID if f.tell() != 0: @@ -341,6 +376,23 @@ def lock_path(p: str | bytes | PathLike[str] | PathLike[bytes]) -> Path: return keypair +def _scoped_keypair_path(base: Path, scope: int | str) -> Path: + """Return ``base`` with the process scope inserted before the + suffix (e.g. ``node_id.keypair`` + scope ``52415`` -> + ``node_id.52415.keypair``). + + We insert the scope as a stem-suffix rather than as a directory + so concurrent processes on the same host share the parent dir + (and the file lock's inode-level coordination still works for + legacy-migration safety) while their identity files remain + distinct. Scope is rendered with ``str()`` so callers can pass + a port number, a UUID, a hostname, etc. + """ + suffix = base.suffix or ".keypair" + stem = base.stem if base.suffix else base.name + return base.parent / f"{stem}.{scope}{suffix}" + + def _migrate_legacy_node_id_keypair( new_path: Path, legacy_path: Path, diff --git a/src/exo/routing/tests/test_node_id_migration.py b/src/exo/routing/tests/test_node_id_migration.py index 2eff238b4d..b1d88736e0 100644 --- a/src/exo/routing/tests/test_node_id_migration.py +++ b/src/exo/routing/tests/test_node_id_migration.py @@ -100,3 +100,138 @@ def test_get_node_id_keypair_uses_migrated_legacy_keypair(tmp_path: Path) -> Non assert loaded.to_bytes() == expected_bytes assert new_path.exists() assert not legacy_path.exists() + + +# --------------------------------------------------------------------------- +# Codex P1 (PR #16 round-(N+2), router.py:297): per-process scoping +# --------------------------------------------------------------------------- +# +# The new same-host multi-node workflow (per-process +# ``--peer-download-port``) requires distinct ``NodeId``s per +# process so peer-discovery's self-skip and routing's unique-NodeId +# invariants hold. ``get_node_id_keypair`` therefore accepts a +# ``process_scope`` argument that is folded into the on-disk +# filename. + + +def test_distinct_process_scopes_produce_distinct_keypairs(tmp_path: Path) -> None: + """Two processes that pass different scopes (e.g. distinct + peer-download ports) MUST end up with different keypair files + and different on-disk identities; otherwise two same-host + nodes would race on the same NodeId.""" + base_path = tmp_path / "config" / "node_id.keypair" + + keypair_a = get_node_id_keypair( + path=base_path, legacy_path=None, process_scope=52416 + ) + keypair_b = get_node_id_keypair( + path=base_path, legacy_path=None, process_scope=52417 + ) + + assert keypair_a.to_bytes() != keypair_b.to_bytes(), ( + "distinct process scopes must yield distinct keypairs so " + "same-host multi-node deployments don't share a NodeId" + ) + + scoped_a = base_path.parent / "node_id.52416.keypair" + scoped_b = base_path.parent / "node_id.52417.keypair" + assert scoped_a.exists() + assert scoped_b.exists() + assert scoped_a.read_bytes() != scoped_b.read_bytes() + + +def test_same_process_scope_is_stable_across_calls(tmp_path: Path) -> None: + """Per-process scoping must remain *persistent*: the same + process (same scope) must load the same keypair on subsequent + calls -- otherwise restart would silently churn NodeIds.""" + base_path = tmp_path / "config" / "node_id.keypair" + + first = get_node_id_keypair(path=base_path, legacy_path=None, process_scope=52416) + second = get_node_id_keypair(path=base_path, legacy_path=None, process_scope=52416) + + assert first.to_bytes() == second.to_bytes() + + +def test_migration_runs_inside_file_lock(tmp_path: Path) -> None: + """Codex P2 (PR #16 round-(N+2), router.py:322): the legacy + migration must execute *inside* ``FileLock`` so two processes + booting concurrently can't both pass the existence check, race + each other into divergent in-memory keypairs, and end up with + mismatched identities for the same on-disk file. + + We assert this structurally by hooking ``_migrate_legacy_node_id_keypair`` + and ``filelock.FileLock`` and verifying the lock is acquired + *before* the migrator is called. A pre-lock migration would + show ``migrate_called=True`` while the lock is still + ``unacquired``.""" + import exo.routing.router as router_mod + + legacy_path = tmp_path / "cache" / "node_id.keypair" + base_path = tmp_path / "config" / "node_id.keypair" + legacy_path.parent.mkdir(parents=True) + legacy_path.write_bytes(Keypair.generate().to_bytes()) + + lock_state: dict[str, bool] = {"acquired": False, "acquired_before_migrate": False} + + # We hook ``router_mod.FileLock`` (the symbol the production + # code dereferences) with a thin wrapper class. The wrapper + # delegates to the real ``FileLock`` instance but flips the + # ``acquired`` flag on entry, which the migrator hook below + # then snapshots. This keeps the type of ``FileLock`` intact + # while letting us observe acquire-vs-migrate ordering. + real_filelock = router_mod.FileLock + + class _ObservingFileLock: + def __init__(self, *args: object, **kwargs: object) -> None: + self._inner = real_filelock(*args, **kwargs) # pyright: ignore[reportArgumentType] + + def __enter__(self) -> object: + lock_state["acquired"] = True + return self._inner.__enter__() + + def __exit__(self, *exc: object) -> object: + return self._inner.__exit__(*exc) # pyright: ignore[reportArgumentType] + + original_migrate = router_mod._migrate_legacy_node_id_keypair # pyright: ignore[reportPrivateUsage] + + def _track_migrate(new_path: Path, legacy: Path) -> None: + lock_state["acquired_before_migrate"] = lock_state["acquired"] + original_migrate(new_path, legacy) + + router_mod.FileLock = _ObservingFileLock + router_mod._migrate_legacy_node_id_keypair = _track_migrate # pyright: ignore[reportPrivateUsage] + try: + _ = get_node_id_keypair(path=base_path, legacy_path=legacy_path) + finally: + router_mod.FileLock = real_filelock + router_mod._migrate_legacy_node_id_keypair = original_migrate # pyright: ignore[reportPrivateUsage] + + assert lock_state["acquired_before_migrate"] is True, ( + "legacy migration must run INSIDE the FileLock to prevent a " + "concurrent-startup race on the on-disk keypair" + ) + + +def test_legacy_migration_adopts_into_scoped_path(tmp_path: Path) -> None: + """When a process passes a scope and a legacy unscoped keypair + exists, the legacy bytes must be adopted into the scoped path. + This is the upgrade-time behaviour: the first process to boot + after the upgrade keeps the operator's existing identity; later + processes (different scopes) start with fresh identities, which + is exactly what per-process isolation requires.""" + legacy_path = tmp_path / "cache" / "node_id.keypair" + base_path = tmp_path / "config" / "node_id.keypair" + legacy_path.parent.mkdir(parents=True) + + expected_bytes = Keypair.generate().to_bytes() + legacy_path.write_bytes(expected_bytes) + + loaded = get_node_id_keypair( + path=base_path, legacy_path=legacy_path, process_scope=52416 + ) + + scoped = base_path.parent / "node_id.52416.keypair" + assert loaded.to_bytes() == expected_bytes + assert scoped.exists(), "legacy bytes must land at the scoped path" + assert scoped.read_bytes() == expected_bytes + assert not legacy_path.exists() From 5cc987e1c44213a24fac596c907ce4778361e339 Mon Sep 17 00:00:00 2001 From: jw-wcv <101585096+jw-wcv@users.noreply.github.com> Date: Sat, 9 May 2026 02:31:56 -0700 Subject: [PATCH 22/29] Combine listening ports for keypair scope; restart on 200-on-resume Codex P1 (PR #16 round-(N+3), main.py:74): the previous scope used ``args.peer_download_port`` only. With ``--no-downloads`` / ``--no-peer-download`` the peer file server doesn't bind, so two same-host processes can both keep the default ``peer_download_port`` and would then load the same scoped keypair file -- producing identical ``NodeId``s and breaking election/routing's unique-NodeId invariants. The new ``_node_id_keypair_scope`` helper combines libp2p, api, and peer-download ports: at least one of those MUST differ between two same-host processes (each is a distinct local socket bind), so the resulting scope is always per-process unique while remaining stable across restarts of the same configuration. Codex P1 (PR #16 round-(N+3), peer_download.py:162): on resume ``download_file_from_peer`` sends a ``Range`` header but accepted HTTP 200 and appended to the existing partial. A non-compliant peer server is allowed to ignore Range and return full content with 200, which would duplicate bytes, push ``n_read`` past ``expected_size``, and -- because offline mode skips hash verification -- silently poison the model cache by renaming the oversized file as success. Now we treat 200-on-resume as a restart: discard the partial, reset ``n_read = 0``, and the next loop iteration re-fetches from zero. Add regression tests: - ``TestNodeIdKeypairScope`` covers the per-process scope helper: distinct libp2p / api / peer-download ports each yield distinct scopes; the same args yield the same scope; and the original bug (same default peer_download_port with peer-download disabled) is now isolated by libp2p_port differences. - ``test_resume_with_200_response_discards_partial_and_restarts`` stands up a tiny aiohttp server that always returns 200 (even for ranged requests), primes a partial file, and asserts the client discards the partial, restarts from zero, and lands the canonical bytes matching ``expected_size``. --- src/exo/download/peer_download.py | 19 ++++ src/exo/download/tests/test_peer_download.py | 86 ++++++++++++++++ src/exo/main.py | 57 ++++++++--- .../routing/tests/test_node_id_migration.py | 98 +++++++++++++++++++ 4 files changed, 247 insertions(+), 13 deletions(-) diff --git a/src/exo/download/peer_download.py b/src/exo/download/peer_download.py index 67ecbb7810..95949797ca 100644 --- a/src/exo/download/peer_download.py +++ b/src/exo/download/peer_download.py @@ -147,6 +147,7 @@ async def download_file_from_peer( headers["Range"] = f"bytes={n_read}-" got_bytes = False + range_was_requested = n_read > 0 async with ( aiohttp.ClientSession( timeout=aiohttp.ClientTimeout(total=300, sock_read=60) @@ -156,6 +157,24 @@ async def download_file_from_peer( if r.status == 416: # Range not satisfiable - peer doesn't have more yet pass + elif range_was_requested and r.status == 200: + # Codex P1 (PR #16 round-(N+3), peer_download.py:162): + # we sent a ``Range`` header (we have a partial), but + # the peer ignored it and returned full content with + # 200. Appending the body would duplicate the + # already-downloaded prefix, push ``n_read`` past + # ``expected_size``, and -- because offline mode + # skips hash verification -- silently poison the + # model file. Drop the partial and restart from + # zero on the next loop iteration so the next + # request gets fresh, intact bytes. + logger.warning( + f"Peer {peer_host} ignored Range header for " + f"{file_path} (returned 200 instead of 206); " + "discarding partial and restarting from zero" + ) + await aios.remove(partial_path) + n_read = 0 elif r.status in (200, 206): async with aiofiles.open( partial_path, "ab" if n_read > 0 else "wb" diff --git a/src/exo/download/tests/test_peer_download.py b/src/exo/download/tests/test_peer_download.py index f7046c87d2..ba34d046ed 100644 --- a/src/exo/download/tests/test_peer_download.py +++ b/src/exo/download/tests/test_peer_download.py @@ -461,6 +461,92 @@ async def test_oversized_stale_partial_is_discarded_and_retransferred( ) assert not stale_partial.exists() + async def test_resume_with_200_response_discards_partial_and_restarts( + self, tmp_path: Path + ) -> None: + """Codex P1 (PR #16 round-(N+3), peer_download.py:162): when + the client resumes a download (``n_read > 0``) it sends a + ``Range`` header, but a non-compliant server is permitted to + ignore it and return full content with HTTP 200 instead of + 206. Pre-fix the client appended the full body to the + partial, pushing ``n_read`` past ``expected_size`` and + renaming the oversized file as the "successful" download. + In offline mode hash verification is intentionally skipped, + so the bad bytes silently poisoned the model cache. + + We stand up a tiny aiohttp server that returns full content + with 200 even when ``Range`` is set, prime a partial file, + and assert the client discards the partial, restarts from + zero, and lands the canonical bytes (matching ``expected_size``). + """ + from aiohttp import web + + canonical = b"the canonical model weights" + + async def handler(request: web.Request) -> web.Response: + # Always return full content with HTTP 200, ignoring any + # ``Range`` header. This simulates the non-compliant + # peer server the codex finding flagged. + del request + return web.Response(body=canonical, status=200) + + app = web.Application() + # Path must match the client's URL template: + # ``http://host:port/files//`` + _ = app.router.add_get("/files/test/weights.bin", handler) + runner = web.AppRunner(app) + await runner.setup() + site = web.TCPSite(runner, "127.0.0.1", 0) + await site.start() + try: + # Mirror the ``peer_server`` fixture: ``aiohttp.web.TCPSite`` + # surfaces the kernel-assigned port through its private + # ``_server.sockets`` attribute. The module-level + # ``reportPrivateUsage=false`` and ``type: ignore`` here + # match the existing fixture's access pattern. + port: int = cast( + int, + site._server.sockets[0].getsockname()[1], # type: ignore[union-attr] + ) + + download_dir = tmp_path / "downloads" / "test" + await aios.makedirs(download_dir, exist_ok=True) + # Prime a stale partial with bogus content to force the + # resume codepath (Range header) on the first attempt. + partial_path = download_dir / "weights.bin.partial" + stale_prefix = b"\xff" * (len(canonical) // 2) + async with aiofiles.open(partial_path, "wb") as f: + await f.write(stale_prefix) + assert (await aios.stat(partial_path)).st_size > 0 + + result = await download_file_from_peer( + "127.0.0.1", + port, + "test", + "weights.bin", + download_dir, + len(canonical), + ) + + assert result is not None, ( + "the client should ultimately succeed by discarding the " + "stale partial and restarting from zero on the second " + "request" + ) + assert result == download_dir / "weights.bin" + async with aiofiles.open(result, "rb") as f: + downloaded = await f.read() + assert downloaded == canonical, ( + "200-on-resume must trigger a partial restart; the final " + "file must be the canonical bytes, not a duplicate-prefix " + "concatenation" + ) + assert not partial_path.exists(), ( + "successful download must remove the partial path" + ) + finally: + await runner.cleanup() + async def test_skip_already_complete( self, peer_server: PeerFileServer, temp_models_dir: Path, tmp_path: Path ) -> None: diff --git a/src/exo/main.py b/src/exo/main.py index 7922b5664e..d598f4a24a 100644 --- a/src/exo/main.py +++ b/src/exo/main.py @@ -59,19 +59,27 @@ class Node: @classmethod async def create(cls, args: "Args") -> Self: - # Codex P1 (PR #16 round-(N+2), router.py:297): scope the - # on-disk node-ID keypair by ``--peer-download-port``. That - # port already MUST differ between processes on the same - # host (see the ``--peer-download-port`` help text), making - # it the natural per-process disambiguator. Single-process - # deployments use the default port and therefore land on a - # stable scoped filename, preserving identity across - # restarts; multi-process same-host deployments get - # distinct keypair files (and therefore distinct - # ``NodeId``s) so peer-discovery's ``peer_node_id == - # node_id`` self-skip and routing's unique-NodeId - # invariants continue to hold. - keypair = get_node_id_keypair(process_scope=args.peer_download_port) + # Codex P1 (PR #16 round-(N+3), main.py:74): scope the on-disk + # node-ID keypair by the *combination* of ports the operator + # has chosen, not just ``--peer-download-port``. The earlier + # peer-download-only scope leaked identity collisions when + # ``--no-downloads`` / ``--no-peer-download`` is set: that + # mode doesn't bind the peer file server, so two same-host + # processes can legitimately keep the default + # ``peer_download_port`` and would then load the same scoped + # keypair file -- producing identical ``NodeId``s and + # breaking election/routing's unique-NodeId invariants. + # + # Combined-port scoping is robust against every same-host + # multi-process configuration: at least one of the listening + # ports MUST differ between processes (libp2p, peer-download, + # api -- each is a distinct local socket bind), so the scope + # tuple differs whenever the actual configuration differs. + # Single-process deployments on default ports keep a stable + # filename (e.g. ``node_id.libp2p-0.api-52415.peer-52416.keypair``) + # so identity persists across restarts. + process_scope = _node_id_keypair_scope(args) + keypair = get_node_id_keypair(process_scope=process_scope) node_id = NodeId(keypair.to_node_id()) session_id = SessionId(master_node_id=node_id, election_clock=0) router = Router.create( @@ -427,6 +435,29 @@ def _last_seen_ages(self, state: State) -> dict[str, float]: return ages +def _node_id_keypair_scope(args: "Args") -> str: + """Produce a stable per-process scope for the node-ID keypair file. + + Combines every listening port the operator could plausibly + distinguish between same-host processes: ``--libp2p-port``, + ``--api-port``, and ``--peer-download-port``. At least one of + these MUST differ between two processes that share a host (each + is a distinct local socket bind), so the resulting scope is + always unique per process while remaining stable across + restarts of the same configuration. + + Used by :func:`get_node_id_keypair` to avoid two same-host + processes loading the same scoped keypair file when peer + download is disabled (which would otherwise let them collide + on the default ``peer_download_port`` since no socket is + actually being bound). See Codex P1 (PR #16 round-(N+3), + main.py:74). + """ + return ( + f"libp2p-{args.libp2p_port}.api-{args.api_port}.peer-{args.peer_download_port}" + ) + + def _darwin_en0_ip_address() -> str | None: try: return subprocess.check_output( diff --git a/src/exo/routing/tests/test_node_id_migration.py b/src/exo/routing/tests/test_node_id_migration.py index b1d88736e0..3994b5616c 100644 --- a/src/exo/routing/tests/test_node_id_migration.py +++ b/src/exo/routing/tests/test_node_id_migration.py @@ -212,6 +212,104 @@ def _track_migrate(new_path: Path, legacy: Path) -> None: ) +class TestNodeIdKeypairScope: + """Codex P1 (PR #16 round-(N+3), main.py:74): the node-ID keypair + scope MUST account for every distinguishable per-process port, + not just ``--peer-download-port``. With peer-download disabled + the operator can legitimately keep the default + ``peer_download_port`` (no socket bind), so the previous + peer-only scope let two same-host processes share an identity. + """ + + def _build_args( + self, + *, + libp2p_port: int = 0, + api_port: int = 52415, + peer_download_port: int = 52416, + no_downloads: bool = False, + no_peer_download: bool = False, + spawn_api: bool = False, + ): # noqa: ANN202 + from exo.main import Args + + return Args( + libp2p_port=libp2p_port, + api_port=api_port, + peer_download_port=peer_download_port, + no_downloads=no_downloads, + no_peer_download=no_peer_download, + spawn_api=spawn_api, + ) + + def test_distinct_libp2p_ports_yield_distinct_scopes(self) -> None: + from exo.main import ( + _node_id_keypair_scope, # pyright: ignore[reportPrivateUsage] + ) + + scope_a = _node_id_keypair_scope(self._build_args(libp2p_port=4001)) + scope_b = _node_id_keypair_scope(self._build_args(libp2p_port=4002)) + assert scope_a != scope_b + + def test_distinct_api_ports_yield_distinct_scopes(self) -> None: + from exo.main import ( + _node_id_keypair_scope, # pyright: ignore[reportPrivateUsage] + ) + + scope_a = _node_id_keypair_scope(self._build_args(api_port=52415)) + scope_b = _node_id_keypair_scope(self._build_args(api_port=52416)) + assert scope_a != scope_b + + def test_distinct_peer_download_ports_yield_distinct_scopes(self) -> None: + from exo.main import ( + _node_id_keypair_scope, # pyright: ignore[reportPrivateUsage] + ) + + scope_a = _node_id_keypair_scope(self._build_args(peer_download_port=52416)) + scope_b = _node_id_keypair_scope(self._build_args(peer_download_port=52417)) + assert scope_a != scope_b + + def test_disabled_peer_download_with_same_default_port_still_isolates( + self, + ) -> None: + """The original Codex P1 (round-(N+3)) regression: with + ``--no-peer-download`` two processes can both keep + ``peer_download_port=52416``. They MUST still get distinct + scopes when *some* other port differs (here, libp2p). + Pre-fix the scope was just ``peer_download_port`` and these + two configs collided on the same keypair.""" + from exo.main import ( + _node_id_keypair_scope, # pyright: ignore[reportPrivateUsage] + ) + + process_one = self._build_args( + libp2p_port=4001, + no_peer_download=True, + peer_download_port=52416, + ) + process_two = self._build_args( + libp2p_port=4002, + no_peer_download=True, + peer_download_port=52416, + ) + assert _node_id_keypair_scope(process_one) != _node_id_keypair_scope( + process_two + ) + + def test_identical_args_yield_identical_scope(self) -> None: + """Stability invariant: the same configuration on a single + process across restarts must hash to the same scope so the + node retains its identity across restarts.""" + from exo.main import ( + _node_id_keypair_scope, # pyright: ignore[reportPrivateUsage] + ) + + args = self._build_args( + libp2p_port=4001, api_port=52415, peer_download_port=52416 + ) + assert _node_id_keypair_scope(args) == _node_id_keypair_scope(args) + + def test_legacy_migration_adopts_into_scoped_path(tmp_path: Path) -> None: """When a process passes a scope and a legacy unscoped keypair exists, the legacy bytes must be adopted into the scoped path. From fe348e98b68e1afc2f9a712d74888554b3c60665 Mon Sep 17 00:00:00 2001 From: jw-wcv <101585096+jw-wcv@users.noreply.github.com> Date: Sat, 9 May 2026 05:03:32 -0700 Subject: [PATCH 23/29] PR #16 R(N+8) P1: address libp2p-port=0 scope collision and oversized peer responses main.py: when --libp2p-port 0 is set, the configured value is the literal 0 even though each process binds a different ephemeral port at runtime. Two same-host worker-only processes (no API, no peer download) sharing the default api/peer ports would otherwise produce identical scope strings and load the same on-disk keypair file, breaking the unique-NodeId invariant. Fold os.getpid() into the scope when libp2p_port == 0; the trade-off (ephemeral identity for ephemeral ports) is the right semantic since the operator opted into ephemeral binding by setting libp2p_port=0. peer_download.py: bound the inner read by 'expected_size - n_read' and treat any extra bytes as a peer protocol violation. Pre-fix the loop kept appending until EOF and only checked n_read < expected_size afterwards, so an oversized response (peer serving a stale/wrong blob) was accepted as success and renamed into the cache. In offline mode hash verification is skipped, so this silently poisoned local weights. New tests: - test_libp2p_port_zero_uses_pid_for_per_process_isolation: verifies the scope contains 'pid-' when libp2p_port=0. - test_libp2p_port_zero_in_two_processes_yield_distinct_scopes: monkeypatches os.getpid to simulate two same-host processes both binding libp2p_port=0 with identical api/peer ports and asserts the scopes diverge. - test_oversized_peer_response_is_rejected_and_restarted: stands up a bad aiohttp peer that always serves canonical+'POISONED' bytes and asserts the client never lands the trailing junk in the cache. --- src/exo/download/peer_download.py | 42 ++++++- src/exo/download/tests/test_peer_download.py | 107 ++++++++++++++++++ src/exo/main.py | 21 ++++ .../routing/tests/test_node_id_migration.py | 78 +++++++++++++ 4 files changed, 247 insertions(+), 1 deletion(-) diff --git a/src/exo/download/peer_download.py b/src/exo/download/peer_download.py index 95949797ca..793a572ae4 100644 --- a/src/exo/download/peer_download.py +++ b/src/exo/download/peer_download.py @@ -176,17 +176,57 @@ async def download_file_from_peer( await aios.remove(partial_path) n_read = 0 elif r.status in (200, 206): + # Codex P1 (PR #16 round-(N+8), peer_download.py:187): + # bound the inner read by ``expected_size - n_read`` + # and treat any extra bytes as a peer protocol + # violation. Pre-fix the loop kept appending until + # EOF and only checked ``n_read < expected_size`` + # afterward, so an oversized response (peer + # serving a stale/wrong blob) was accepted as + # success and renamed into the model cache. In + # offline mode hash verification is skipped, so + # this silently poisoned local weights. Now we + # cap each chunk at the remaining budget and bail + # out the moment a peer tries to send extra data. + oversized_response = False async with aiofiles.open( partial_path, "ab" if n_read > 0 else "wb" ) as f: while True: - chunk = await r.content.read(chunk_size) + remaining = expected_size - n_read + if remaining <= 0: + # We have everything we need. Read one + # more byte to detect peer + # over-supplying; if the stream isn't + # EOF, the peer is sending more bytes + # than ``expected_size`` claims. + tail = await r.content.read(1) + if tail: + oversized_response = True + break + chunk = await r.content.read(min(chunk_size, remaining)) if not chunk: break written = await f.write(chunk) n_read += written got_bytes = True on_progress(n_read, expected_size, False) + if oversized_response: + # Discard the partial: we cannot trust any + # bytes from a peer that violates the + # advertised file size, especially in + # offline mode where hash verification is + # skipped. Restart from zero on the next + # iteration so a fresh request gets a + # well-bounded response. + logger.warning( + f"Peer {peer_host} returned oversized response for " + f"{file_path} (advertised {expected_size} bytes, " + "stream still had data when budget was exhausted); " + "discarding partial and restarting from zero" + ) + await aios.remove(partial_path) + n_read = 0 elif r.status == 404: logger.debug(f"File {file_path} not found on peer {peer_host}") return None diff --git a/src/exo/download/tests/test_peer_download.py b/src/exo/download/tests/test_peer_download.py index ba34d046ed..81f6b4909c 100644 --- a/src/exo/download/tests/test_peer_download.py +++ b/src/exo/download/tests/test_peer_download.py @@ -547,6 +547,113 @@ async def handler(request: web.Request) -> web.Response: finally: await runner.cleanup() + async def test_oversized_peer_response_is_rejected_and_restarted( + self, tmp_path: Path + ) -> None: + """Codex P1 (PR #16 round-(N+8), peer_download.py:187): the + download loop used to keep appending bytes until EOF and only + check ``n_read < expected_size`` afterwards. A non-compliant + peer that serves *more* bytes than the advertised + ``expected_size`` would push ``n_read`` past it, the file + would be renamed as a successful download, and -- because + offline mode skips hash verification -- silently poison the + model cache. + + We stand up a tiny aiohttp server that always returns + ``len(canonical) + 8`` bytes regardless of how much was + requested. Pre-fix this would land a corrupt file in the + cache. Post-fix the client must discard each oversized + response and never end up with a final file containing extra + bytes.""" + from aiohttp import web + + canonical = b"the canonical model weights" + # The payload the bad peer always serves: the canonical + # bytes plus extra trailing bytes the peer claimed wouldn't + # exist. This is the attack/bug the fix guards against. + oversized_payload = canonical + b"POISONED" + request_count = 0 + max_requests = 4 # keep test fast: client retries a few times + + async def handler(request: web.Request) -> web.Response: + nonlocal request_count + request_count += 1 + del request + if request_count > max_requests: + # Surface a definitive failure if the client keeps + # hammering the bad peer; that means the fix + # regressed and we'd otherwise hang. + return web.Response(body=b"", status=500) + return web.Response(body=oversized_payload, status=200) + + app = web.Application() + _ = app.router.add_get("/files/test/weights.bin", handler) + runner = web.AppRunner(app) + await runner.setup() + site = web.TCPSite(runner, "127.0.0.1", 0) + await site.start() + try: + port: int = cast( + int, + site._server.sockets[0].getsockname()[1], # type: ignore[union-attr] + ) + + download_dir = tmp_path / "downloads" / "test" + await aios.makedirs(download_dir, exist_ok=True) + + result = await download_file_from_peer( + "127.0.0.1", + port, + "test", + "weights.bin", + download_dir, + len(canonical), + ) + + # The bad peer never serves a well-bounded response, so + # the client cannot complete. The contract is "no + # corrupt data lands in the cache". We tolerate either + # outcome: + # 1. ``result is None`` (client gave up after retries); or + # 2. ``result == canonical`` (a future improvement + # where we keep the canonical-prefix bytes after + # stripping the over-supply). + # The forbidden outcome is the final file containing + # the trailing "POISONED" bytes. + partial_path = download_dir / "weights.bin.partial" + target_path = download_dir / "weights.bin" + + if result is not None: + async with aiofiles.open(result, "rb") as f: + downloaded = await f.read() + assert downloaded == canonical, ( + "if the client claims success, the final file MUST " + "be exactly the canonical bytes; oversized peer " + "responses must never land trailing junk in the " + f"cache. got len={len(downloaded)} bytes: {downloaded!r}" + ) + # In the giving-up branch, neither file should remain + # poisoned. The partial is removed every time we detect + # over-supply, and we never rename to ``target_path`` + # without a clean-budgeted final write. + if target_path.exists(): + async with aiofiles.open(target_path, "rb") as f: + final = await f.read() + assert final == canonical, ( + f"target path was renamed but contains " + f"{len(final)} bytes (expected {len(canonical)}); " + "oversized response made it into the cache" + ) + if partial_path.exists(): + size = (await aios.stat(partial_path)).st_size + assert size <= len(canonical), ( + f"partial path retains {size} bytes after " + f"oversized response (expected <= {len(canonical)}); " + "over-supply must be discarded, not preserved" + ) + finally: + await runner.cleanup() + async def test_skip_already_complete( self, peer_server: PeerFileServer, temp_models_dir: Path, tmp_path: Path ) -> None: diff --git a/src/exo/main.py b/src/exo/main.py index d598f4a24a..54749c0112 100644 --- a/src/exo/main.py +++ b/src/exo/main.py @@ -452,7 +452,28 @@ def _node_id_keypair_scope(args: "Args") -> str: on the default ``peer_download_port`` since no socket is actually being bound). See Codex P1 (PR #16 round-(N+3), main.py:74). + + Codex P1 (PR #16 round-(N+8), main.py:457): when + ``--libp2p-port 0`` is set, the configured value is the literal + ``0`` even though each process actually binds a different + ephemeral port at runtime. Two same-host worker-only processes + (no API, no peer download) sharing the default + ``peer_download_port`` and ``api_port`` -- but each binding + ``libp2p_port=0`` -- would otherwise produce identical scope + strings ``"libp2p-0.api-...peer-..."`` and load the same + keypair file, breaking the unique-NodeId invariant. + Stability across restarts is impossible in this configuration + anyway (the OS hands out a different ephemeral port on every + bind), so fold in ``os.getpid()`` as a per-process + discriminator. The trade-off (ephemeral identity for + ephemeral ports) is the right semantic: the operator opted + into ephemeral binding by setting ``libp2p_port=0``. """ + if args.libp2p_port == 0: + return ( + f"libp2p-pid-{os.getpid()}." + f"api-{args.api_port}.peer-{args.peer_download_port}" + ) return ( f"libp2p-{args.libp2p_port}.api-{args.api_port}.peer-{args.peer_download_port}" ) diff --git a/src/exo/routing/tests/test_node_id_migration.py b/src/exo/routing/tests/test_node_id_migration.py index 3994b5616c..bf3860719b 100644 --- a/src/exo/routing/tests/test_node_id_migration.py +++ b/src/exo/routing/tests/test_node_id_migration.py @@ -11,6 +11,7 @@ from pathlib import Path +import pytest from exo_pyo3_bindings import Keypair from exo.routing.router import ( @@ -309,6 +310,83 @@ def test_identical_args_yield_identical_scope(self) -> None: ) assert _node_id_keypair_scope(args) == _node_id_keypair_scope(args) + def test_libp2p_port_zero_uses_pid_for_per_process_isolation(self) -> None: + """Codex P1 (PR #16 round-(N+8), main.py:457): with + ``--libp2p-port 0`` the configured port is the literal ``0`` + even though each process binds a different ephemeral port at + runtime. Without per-process discrimination two same-host + worker-only processes (no API, no peer download) sharing the + default ``peer_download_port`` and ``api_port`` would collide + on the same scoped keypair. The scope must therefore fold in + ``os.getpid()`` (or another guaranteed per-process + discriminator) when ``libp2p_port == 0``.""" + import os + + from exo.main import ( + _node_id_keypair_scope, # pyright: ignore[reportPrivateUsage] + ) + + scope = _node_id_keypair_scope( + self._build_args( + libp2p_port=0, + api_port=52415, + peer_download_port=52416, + no_peer_download=True, + spawn_api=False, + ) + ) + + assert f"pid-{os.getpid()}" in scope, ( + f"libp2p_port=0 must mix in os.getpid() to discriminate " + f"same-host processes binding ephemeral libp2p ports; " + f"got scope={scope!r}" + ) + + def test_libp2p_port_zero_in_two_processes_yield_distinct_scopes( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + """End-to-end: simulate two same-host processes both binding + ``libp2p_port=0`` and otherwise default ports. Pre-fix they + collided on a single keypair file; post-fix the scopes + differ because each carries its own PID.""" + import os + + from exo.main import ( + _node_id_keypair_scope, # pyright: ignore[reportPrivateUsage] + ) + + # Process A: real PID + scope_a = _node_id_keypair_scope( + self._build_args( + libp2p_port=0, + api_port=52415, + peer_download_port=52416, + no_peer_download=True, + spawn_api=False, + ) + ) + + # Process B: simulate a different PID via monkeypatch + real_pid = os.getpid() + monkeypatch.setattr(os, "getpid", lambda: real_pid + 1) + scope_b = _node_id_keypair_scope( + self._build_args( + libp2p_port=0, + api_port=52415, + peer_download_port=52416, + no_peer_download=True, + spawn_api=False, + ) + ) + + assert scope_a != scope_b, ( + "two same-host processes both binding libp2p_port=0 with " + "identical api/peer ports must produce distinct keypair " + "scopes; otherwise they load the same on-disk keypair " + "and collide on NodeId, breaking routing/election " + f"invariants. scope_a={scope_a!r} scope_b={scope_b!r}" + ) + def test_legacy_migration_adopts_into_scoped_path(tmp_path: Path) -> None: """When a process passes a scope and a legacy unscoped keypair From 78031c4d1aabd3b7e90a543fe682ff2e451b4212 Mon Sep 17 00:00:00 2001 From: jw-wcv <101585096+jw-wcv@users.noreply.github.com> Date: Sat, 9 May 2026 05:39:36 -0700 Subject: [PATCH 24/29] PR #16 R(N+9) P1+P2: restore safer node timeout and search all model roots master/main.py P1: revert the 5s node_inactivity_timeout to 30s. Pre-fix the 5s window was too tight: any node that didn't publish NodeGatheredInfo within 5s (e.g. when fast probes are unavailable or delayed) was marked timed out and had its instances deleted in the same _plan loop. Because this loop now ticks every second, normal telemetry jitter caused repeated false-positive NodeTimedOut events and unnecessary instance churn. The 1s tick stays so the master reacts quickly when a node *does* genuinely time out. peer_file_server.py P2: search every configured root before selecting the model dir to serve. Pre-fix _locate_model_dir returned the first root that *contained* the model directory regardless of completeness. When an earlier writable root held a partial download and a later read-only mount held a complete copy, /status and /files only saw the partial tree -- peers thought the node had no canonical copy and fell back to HuggingFace despite a complete local copy on a different mount. New behavior: - /status unions across all matching roots; for duplicate filenames, complete files dominate partials, larger partials dominate smaller. - /files prefers the root holding a complete copy of the requested file; falls back to the largest partial when no root has it complete; returns 404 only when every root truly lacks the file. Added _locate_all_model_dirs helper that returns every root holding the model in priority order (writable before read-only). New tests: - test_status_unions_partial_in_first_root_with_complete_in_second: verifies /status surfaces the complete file from a later root and marks it as complete (not partial). - test_files_serves_complete_copy_when_first_root_has_only_partial: end-to-end via aiohttp client, verifies /files returns 200 with canonical bytes and X-Exo-Complete=true. --- src/exo/download/peer_file_server.py | 197 ++++++++++++++----- src/exo/download/tests/test_peer_download.py | 131 ++++++++++++ src/exo/master/main.py | 13 +- 3 files changed, 290 insertions(+), 51 deletions(-) diff --git a/src/exo/download/peer_file_server.py b/src/exo/download/peer_file_server.py index a674c85390..ab3816802a 100644 --- a/src/exo/download/peer_file_server.py +++ b/src/exo/download/peer_file_server.py @@ -63,77 +63,140 @@ async def _handle_health(self, request: web.Request) -> web.Response: return web.json_response({"status": "ok"}) async def _handle_status(self, request: web.Request) -> web.Response: - """Return status of all files for a model (complete + in-progress).""" + """Return status of all files for a model (complete + in-progress). + + Codex P2 (PR #16 round-(N+9), peer_file_server.py:201): when + a model's contents are split across multiple configured + roots (e.g. an earlier writable cache holds a partial copy + and a later read-only mount holds the full canonical copy), + report the union across every root that contains the model. + For files that appear in more than one root we keep the + most-complete entry (complete > larger partial) so peers see + the true 'most progressed' version of the file. The earlier + single-root behaviour caused the peer downloader to + miss-report missing files and silently fall back to + HuggingFace even when this node had a complete copy + elsewhere on disk. + """ model_id = request.match_info["model_id"] - model_dir = await self._locate_model_dir(model_id) - if model_dir is None: - # No matching directory containing this model. + model_dirs = await self._locate_all_model_dirs(model_id) + if not model_dirs: return web.json_response({"files": []}) - files: list[dict[str, object]] = [] - for item in model_dir.rglob("*"): - relative_path = item.relative_to(model_dir).as_posix() - if item.is_dir() or relative_path.endswith(".partial.meta"): - continue - if _resolve_child(model_dir, relative_path) is None: - continue - - if relative_path.endswith(".partial"): - # In-progress file - read meta for safe bytes - meta = await _read_partial_meta(item) - if meta: - total = _meta_int(meta, "total") - safe_bytes = _meta_int(meta, "safe_bytes") - files.append( + # path -> entry; complete files dominate partials; larger + # partials dominate smaller ones when no complete is found. + merged: dict[str, dict[str, object]] = {} + + def merge(entry: dict[str, object]) -> None: + path = cast(str, entry["path"]) + existing = merged.get(path) + if existing is None: + merged[path] = entry + return + existing_complete = bool(existing["complete"]) + new_complete = bool(entry["complete"]) + new_partial_is_more_complete = ( + not new_complete + and not existing_complete + and cast(int, entry["safe_bytes"]) + > cast(int, existing["safe_bytes"]) + ) + if (new_complete and not existing_complete) or ( + new_partial_is_more_complete + ): + merged[path] = entry + # complete-vs-complete: keep the first (sizes equal by + # construction, callers only need one entry). + + for model_dir in model_dirs: + for item in model_dir.rglob("*"): + relative_path = item.relative_to(model_dir).as_posix() + if item.is_dir() or relative_path.endswith(".partial.meta"): + continue + if _resolve_child(model_dir, relative_path) is None: + continue + + if relative_path.endswith(".partial"): + meta = await _read_partial_meta(item) + if meta: + total = _meta_int(meta, "total") + safe_bytes = _meta_int(meta, "safe_bytes") + merge( + { + "path": relative_path.removesuffix(".partial"), + "size": total, + "complete": False, + "safe_bytes": safe_bytes, + } + ) + else: + stat = await aios.stat(item) + merge( { - "path": relative_path.removesuffix(".partial"), - "size": total, - "complete": False, - "safe_bytes": safe_bytes, + "path": relative_path, + "size": stat.st_size, + "complete": True, + "safe_bytes": stat.st_size, } ) - else: - # Complete file - stat = await aios.stat(item) - files.append( - { - "path": relative_path, - "size": stat.st_size, - "complete": True, - "safe_bytes": stat.st_size, - } - ) - - return web.json_response({"files": files}) + + return web.json_response({"files": list(merged.values())}) async def _handle_file(self, request: web.Request) -> web.StreamResponse: """Serve a model file with Range request support. For complete files: standard HTTP file serving. For .partial files: serves only the safe byte range (flushed to disk). + + Codex P2 (PR #16 round-(N+9), peer_file_server.py:201): when + a model's contents are split across multiple roots, prefer + the root holding a *complete* copy of the requested file + over the first root that merely contains the model + directory. Fall back to a partial copy only if no root has + the file complete. Pre-fix the server returned 404 for + files that lived in a later root, forcing peers to fall + back to HuggingFace despite a complete local copy. """ model_id = request.match_info["model_id"] file_path = request.match_info["file_path"] - model_dir = await self._locate_model_dir(model_id) - if model_dir is None: + model_dirs = await self._locate_all_model_dirs(model_id) + if not model_dirs: return web.Response(status=404, text="Model not found") - complete_path = _resolve_child(model_dir, file_path) - partial_path = _resolve_child(model_dir, f"{file_path}.partial") - if complete_path is None or partial_path is None: - return web.Response(status=404, text="File not found") + complete_hit: Path | None = None + best_partial: tuple[Path, PartialMeta] | None = None + + for model_dir in model_dirs: + complete_candidate = _resolve_child(model_dir, file_path) + partial_candidate = _resolve_child(model_dir, f"{file_path}.partial") + if complete_candidate is None or partial_candidate is None: + continue + if complete_hit is None and await aios.path.exists(complete_candidate): + complete_hit = complete_candidate + # Complete copy in the first matching root wins; we + # don't need to scan the rest for this file. + break + if await aios.path.exists(partial_candidate): + meta = await _read_partial_meta(partial_candidate) + if ( + meta + and _meta_int(meta, "safe_bytes") > 0 + and ( + best_partial is None + or _meta_int(meta, "safe_bytes") + > _meta_int(best_partial[1], "safe_bytes") + ) + ): + best_partial = (partial_candidate, meta) - # Determine which file to serve and its safe size - if await aios.path.exists(complete_path): - serve_path = complete_path - file_size = (await aios.stat(complete_path)).st_size + if complete_hit is not None: + serve_path = complete_hit + file_size = (await aios.stat(complete_hit)).st_size safe_bytes = file_size is_complete = True - elif await aios.path.exists(partial_path): - meta = await _read_partial_meta(partial_path) - if not meta or _meta_int(meta, "safe_bytes") == 0: - return web.Response(status=404, text="File not available yet") + elif best_partial is not None: + partial_path, meta = best_partial serve_path = partial_path file_size = _meta_int(meta, "total") safe_bytes = _meta_int(meta, "safe_bytes") @@ -193,6 +256,12 @@ async def _locate_model_dir(self, model_id: str) -> Path | None: probe the filesystem. We prefer the first directory in ``models_dirs`` that has a matching subdirectory; this preserves caller-specified priority (e.g. writable before read-only) without re-sorting. + + Note: callers that need to merge contents across multiple + roots should use :meth:`_locate_all_model_dirs` instead. That + helper exists to address Codex P2 (PR #16 round-(N+9), + peer_file_server.py:201) where an earlier incomplete root + masked a later complete copy. """ for root in self.models_dirs: candidate = _resolve_child(root, model_id) @@ -202,6 +271,34 @@ async def _locate_model_dir(self, model_id: str) -> Path | None: return candidate return None + async def _locate_all_model_dirs(self, model_id: str) -> list[Path]: + """Return every configured directory that contains ``model_id``. + + Roots are returned in the same priority order as + ``self.models_dirs`` (writable before read-only) so callers + can short-circuit to the first complete copy. Each candidate + root is path-traversal-checked independently before we probe + the filesystem. + + Codex P2 (PR #16 round-(N+9), peer_file_server.py:201): + ``_locate_model_dir`` returned the first root that *contained* + the model directory regardless of completeness. When an + earlier writable root held a partial download and a later + read-only mount held a complete copy, ``/status`` and + ``/files`` only saw the partial tree -- peers thought the + node had no canonical copy and fell back to HuggingFace. + Callers that merge across roots use this helper to scan + every match. + """ + matches: list[Path] = [] + for root in self.models_dirs: + candidate = _resolve_child(root, model_id) + if candidate is None: + continue + if await aios.path.exists(candidate): + matches.append(candidate) + return matches + def _resolve_child(root: Path, relative_path: str) -> Path | None: """Resolve relative_path under root, rejecting path traversal.""" diff --git a/src/exo/download/tests/test_peer_download.py b/src/exo/download/tests/test_peer_download.py index 81f6b4909c..f591a16d01 100644 --- a/src/exo/download/tests/test_peer_download.py +++ b/src/exo/download/tests/test_peer_download.py @@ -8,6 +8,7 @@ import aiofiles import aiofiles.os as aios +import aiohttp import pytest from exo.download.peer_download import download_file_from_peer, get_peer_file_status @@ -335,6 +336,136 @@ async def test_constructor_rejects_empty_directory_list(self) -> None: with pytest.raises(ValueError, match="at least one models directory"): PeerFileServer(host="127.0.0.1", port=0, models_dirs=[]) + async def test_status_unions_partial_in_first_root_with_complete_in_second( + self, tmp_path: Path + ) -> None: + """Codex P2 (PR #16 round-(N+9), peer_file_server.py:201): if + an earlier root has a stale/incomplete model directory and a + later root has a complete copy, ``/status`` must surface the + complete file -- otherwise peers see the file as missing and + fall back to HuggingFace despite the local node having a + canonical copy on a different mount. + """ + from aiohttp import web + + first = tmp_path / "first" + second = tmp_path / "second" + await aios.makedirs(first / "test--model", exist_ok=True) + await aios.makedirs(second / "test--model", exist_ok=True) + + # First root has only a partial of weights.bin (incomplete). + partial_path = first / "test--model" / "weights.bin.partial" + canonical = b"the canonical model weights" + async with aiofiles.open(partial_path, "wb") as f: + await f.write(canonical[: len(canonical) // 2]) + # Companion meta marking 50% safe. + meta_path = first / "test--model" / "weights.bin.partial.meta" + async with aiofiles.open(meta_path, "w") as f: + await f.write( + json.dumps( + { + "total": len(canonical), + "safe_bytes": len(canonical) // 2, + } + ) + ) + + # Second root has the full canonical file (complete). + async with aiofiles.open(second / "test--model" / "weights.bin", "wb") as f: + await f.write(canonical) + + server = PeerFileServer( + host="127.0.0.1", port=0, models_dirs=[first, second] + ) + server._runner = web.AppRunner(server._app) + await server._runner.setup() + site = web.TCPSite(server._runner, "127.0.0.1", 0) + await site.start() + port_int: int = cast(int, site._server.sockets[0].getsockname()[1]) # type: ignore[union-attr] + server.port = port_int + try: + files = await get_peer_file_status("127.0.0.1", port_int, "test--model") + assert files is not None + file_map = {f.path: f for f in files} + assert "weights.bin" in file_map, ( + "complete copy in the second root must surface in /status; " + "got files={file_map.keys()}" + ) + assert file_map["weights.bin"].complete is True, ( + "complete copy in the second root must dominate the " + "partial in the first root; otherwise peers will fall " + "back to HuggingFace" + ) + assert file_map["weights.bin"].size == len(canonical) + finally: + await server.shutdown() + + async def test_files_serves_complete_copy_when_first_root_has_only_partial( + self, tmp_path: Path + ) -> None: + """End-to-end: ``/files/`` must select the root holding + the complete file even when an earlier root has only a + partial. Pre-fix the server returned 404 (or served the + smaller partial via the partial-bytes path) when a complete + file lived in a later root, forcing peers to fall back to + HuggingFace. + """ + from aiohttp import web + + first = tmp_path / "first" + second = tmp_path / "second" + await aios.makedirs(first / "test--model", exist_ok=True) + await aios.makedirs(second / "test--model", exist_ok=True) + + canonical = b"complete-canonical-bytes" + # First root has partial (with valid meta). + partial_path = first / "test--model" / "weights.bin.partial" + async with aiofiles.open(partial_path, "wb") as f: + await f.write(canonical[: len(canonical) // 2]) + meta_path = first / "test--model" / "weights.bin.partial.meta" + async with aiofiles.open(meta_path, "w") as f: + await f.write( + json.dumps( + { + "total": len(canonical), + "safe_bytes": len(canonical) // 2, + } + ) + ) + # Second root has the complete file. + async with aiofiles.open(second / "test--model" / "weights.bin", "wb") as f: + await f.write(canonical) + + server = PeerFileServer( + host="127.0.0.1", port=0, models_dirs=[first, second] + ) + server._runner = web.AppRunner(server._app) + await server._runner.setup() + site = web.TCPSite(server._runner, "127.0.0.1", 0) + await site.start() + port_int: int = cast(int, site._server.sockets[0].getsockname()[1]) # type: ignore[union-attr] + server.port = port_int + try: + url = f"http://127.0.0.1:{port_int}/files/test--model/weights.bin" + async with ( + aiohttp.ClientSession() as session, + session.get(url) as r, + ): + assert r.status == 200, ( + f"expected 200 from /files when complete copy exists in " + f"a later root; got {r.status}" + ) + body = await r.read() + assert body == canonical, ( + f"expected canonical bytes from later root; got " + f"{len(body)} bytes (expected {len(canonical)})" + ) + # Sanity: X-Exo-Complete header should mark this as a + # complete serving (not a partial-bytes fragment). + assert r.headers.get("X-Exo-Complete") == "true" + finally: + await server.shutdown() + class TestPeerDownloadClient: """Tests for downloading files from a peer server.""" diff --git a/src/exo/master/main.py b/src/exo/master/main.py index 83d2871fb2..2e3a104fcf 100644 --- a/src/exo/master/main.py +++ b/src/exo/master/main.py @@ -482,7 +482,18 @@ async def _command_processor(self) -> None: # These plan loops are the cracks showing in our event sourcing architecture - more things could be commands async def _plan(self) -> None: - node_inactivity_timeout = timedelta(seconds=5) + # Codex P1 (PR #16 round-(N+9), master/main.py:486): the + # inactivity timeout MUST stay safely above ``NodeGatheredInfo`` + # cadence jitter -- 5s was too tight (any node that didn't + # publish telemetry within 5s, e.g. when fast probes are + # unavailable or delayed, would be marked timed out and have + # its instances deleted in the same _plan loop). Because + # this loop now ticks every second, normal jitter caused + # repeated false-positive ``NodeTimedOut`` events and + # unnecessary instance churn. Restore the upstream-safe + # 30s budget while keeping the 1s tick so the master still + # reacts quickly when a node *does* genuinely time out. + node_inactivity_timeout = timedelta(seconds=30) tick_interval_seconds = 1.0 while True: From 9336bf25c61aaf8d3bbb84689aa20874dd315f5c Mon Sep 17 00:00:00 2001 From: jw-wcv <101585096+jw-wcv@users.noreply.github.com> Date: Sat, 9 May 2026 06:33:43 -0700 Subject: [PATCH 25/29] PR #16 R(N+10) P2: keep peer file server task alive for cleanup PeerFileServer.run() returned immediately after site.start(), so the task spawned by Node.run() (tg.start_soon(self.peer_file_server.run)) completed on the first event-loop tick. The parent task group considered the server 'done' the moment the listener bound, so when the node was cancelled there was no live coroutine for the task group to drive teardown -- the aiohttp listener kept its TCP socket open until process exit. That manifested as 'OSError: [Errno 48] address already in use' whenever a node was stopped/restarted in the same process (tests, embedded runs, systemd-style restart loops). Make run() block on anyio.sleep_forever() after starting the listener and run runner.cleanup() in a shielded finally on cancellation. The shield prevents the cancellation from killing cleanup itself (which would re-introduce the leak); the cast on self._runner placates the type-checker without weakening the runtime guard against double-drive when an external shutdown() call has already torn things down. Add lifecycle tests that verify (1) run() does not exit on its own after site.start(), and (2) the listening port is reusable immediately after task-group cancellation. Pre-fix the second test fails with EADDRINUSE; post-fix it passes. --- src/exo/download/peer_file_server.py | 52 ++++++++- src/exo/download/tests/test_peer_download.py | 115 ++++++++++++++++++- 2 files changed, 156 insertions(+), 11 deletions(-) diff --git a/src/exo/download/peer_file_server.py b/src/exo/download/peer_file_server.py index ab3816802a..7591ae1699 100644 --- a/src/exo/download/peer_file_server.py +++ b/src/exo/download/peer_file_server.py @@ -25,6 +25,7 @@ import aiofiles import aiofiles.os as aios +import anyio from aiohttp import web from loguru import logger @@ -49,15 +50,57 @@ def __init__(self, host: str, port: int, models_dirs: Sequence[Path]) -> None: self._runner: web.AppRunner | None = None async def run(self) -> None: - self._runner = web.AppRunner(self._app) - await self._runner.setup() - site = web.TCPSite(self._runner, self.host, self.port) + """Start the peer file server and keep the task alive until cancelled. + + Codex P2 (PR #16 round-(N+10), peer_file_server.py:56): pre-fix + ``run()`` returned immediately after ``site.start()``, so the + task spawned by ``Node.run()`` (``tg.start_soon(self.peer_file_server.run)``) + completed on the first event-loop tick and the parent task + group considered the server "done". When the node was + cancelled, there was no live coroutine for the task group to + cancel, so the aiohttp listener kept its TCP socket open + until process exit. That manifested as + ``OSError: [Errno 48] address already in use`` whenever a + node was stopped/restarted in the same process (commonly in + tests, embedded runs, or systemd-style restart loops). + + The fix keeps the coroutine alive via ``anyio.sleep_forever`` + and runs ``self._runner.cleanup()`` in a shielded ``finally`` + block on cancellation, so the listener is reliably released + before the task group considers the server torn down. + """ + runner = web.AppRunner(self._app) + self._runner = runner + await runner.setup() + site = web.TCPSite(runner, self.host, self.port) await site.start() logger.info(f"PeerFileServer listening on {self.host}:{self.port}") + try: + await anyio.sleep_forever() + finally: + # Shield cleanup from the cancellation that woke us so + # ``aiohttp`` can drain in-flight responses and release + # the listening socket before this task is considered + # complete. Without the shield the cleanup itself is + # cancelled immediately, which leaves the socket bound + # and reproduces the original ``EADDRINUSE`` symptom. + with anyio.CancelScope(shield=True): + # Re-read self._runner so an external ``shutdown()`` + # call (e.g. from a separate code path) doesn't drive + # cleanup twice. ``cast`` because the type-checker has + # narrowed ``self._runner`` to ``AppRunner`` from the + # assignment above; an external mutation could still + # have set it to ``None``. + live_runner = cast(web.AppRunner | None, self._runner) + if live_runner is not None: + self._runner = None + await live_runner.cleanup() + logger.info(f"PeerFileServer on {self.host}:{self.port} stopped") async def shutdown(self) -> None: if self._runner: await self._runner.cleanup() + self._runner = None async def _handle_health(self, request: web.Request) -> web.Response: return web.json_response({"status": "ok"}) @@ -98,8 +141,7 @@ def merge(entry: dict[str, object]) -> None: new_partial_is_more_complete = ( not new_complete and not existing_complete - and cast(int, entry["safe_bytes"]) - > cast(int, existing["safe_bytes"]) + and cast(int, entry["safe_bytes"]) > cast(int, existing["safe_bytes"]) ) if (new_complete and not existing_complete) or ( new_partial_is_more_complete diff --git a/src/exo/download/tests/test_peer_download.py b/src/exo/download/tests/test_peer_download.py index f591a16d01..5dfaf649a5 100644 --- a/src/exo/download/tests/test_peer_download.py +++ b/src/exo/download/tests/test_peer_download.py @@ -2,6 +2,7 @@ # pyright: reportPrivateUsage=false import json +import socket from collections.abc import AsyncIterator, Generator, Iterable from pathlib import Path from typing import Callable, cast @@ -9,6 +10,7 @@ import aiofiles import aiofiles.os as aios import aiohttp +import anyio import pytest from exo.download.peer_download import download_file_from_peer, get_peer_file_status @@ -374,9 +376,7 @@ async def test_status_unions_partial_in_first_root_with_complete_in_second( async with aiofiles.open(second / "test--model" / "weights.bin", "wb") as f: await f.write(canonical) - server = PeerFileServer( - host="127.0.0.1", port=0, models_dirs=[first, second] - ) + server = PeerFileServer(host="127.0.0.1", port=0, models_dirs=[first, second]) server._runner = web.AppRunner(server._app) await server._runner.setup() site = web.TCPSite(server._runner, "127.0.0.1", 0) @@ -436,9 +436,7 @@ async def test_files_serves_complete_copy_when_first_root_has_only_partial( async with aiofiles.open(second / "test--model" / "weights.bin", "wb") as f: await f.write(canonical) - server = PeerFileServer( - host="127.0.0.1", port=0, models_dirs=[first, second] - ) + server = PeerFileServer(host="127.0.0.1", port=0, models_dirs=[first, second]) server._runner = web.AppRunner(server._app) await server._runner.setup() site = web.TCPSite(server._runner, "127.0.0.1", 0) @@ -1233,3 +1231,108 @@ async def recording_meta(*args: object, **_kwargs: object) -> tuple[int, str]: "peer-downloaded bytes against HF's authoritative hash; " f"got meta_calls={meta_calls!r}" ) + + +def _allocate_free_tcp_port() -> int: + """Bind ephemeral port 0 to grab a free TCP port; close before reuse. + + Used by lifecycle tests that want to verify a specific port is + released after server teardown -- we cannot bind 0 in the server + itself because the test needs a stable port to assert on. + """ + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe: + probe.bind(("127.0.0.1", 0)) + return cast(int, probe.getsockname()[1]) + + +class TestPeerFileServerLifecycle: + """Codex P2 (PR #16 round-(N+10), peer_file_server.py:56): the + coroutine returned by ``PeerFileServer.run()`` must stay alive + until cancelled, otherwise the parent task group considers the + server "done" the moment ``site.start()`` returns and never drives + cleanup -- the listening socket leaks until process exit, causing + ``EADDRINUSE`` on stop/restart in the same process (tests, + embedded runs, systemd-style restart loops). + """ + + async def test_run_blocks_until_cancelled(self, tmp_path: Path) -> None: + models_dir = tmp_path / "models" + await aios.makedirs(models_dir, exist_ok=True) + server = PeerFileServer(host="127.0.0.1", port=0, models_dirs=[models_dir]) + + run_completed = anyio.Event() + + async def _run_and_signal() -> None: + try: + await server.run() + finally: + run_completed.set() + + async with anyio.create_task_group() as tg: + tg.start_soon(_run_and_signal) + # Yield a few times so the server can boot. + for _ in range(5): + await anyio.sleep(0.01) + assert not run_completed.is_set(), ( + "PeerFileServer.run must keep the coroutine alive after " + "site.start() so task-group cancellation can drive " + "teardown; pre-fix it returned immediately and the " + "listening socket leaked until process exit" + ) + tg.cancel_scope.cancel() + assert run_completed.is_set() + + async def test_listening_port_is_released_after_run_cancellation( + self, tmp_path: Path + ) -> None: + """End-to-end EADDRINUSE regression: pre-fix a stop/restart + in the same process raised ``OSError: [Errno 48] address + already in use`` because cleanup never ran. After the fix the + same port must be re-bindable immediately after cancellation. + """ + models_dir = tmp_path / "models" + await aios.makedirs(models_dir, exist_ok=True) + port = _allocate_free_tcp_port() + + server = PeerFileServer(host="127.0.0.1", port=port, models_dirs=[models_dir]) + + async with anyio.create_task_group() as tg: + tg.start_soon(server.run) + for _ in range(10): + await anyio.sleep(0.02) + async with aiohttp.ClientSession() as s: + try: + async with s.get( + f"http://127.0.0.1:{port}/health", + timeout=aiohttp.ClientTimeout(total=0.5), + ) as r: + if r.status == 200: + break + except (aiohttp.ClientError, TimeoutError): + continue + else: + raise AssertionError( + "PeerFileServer never started listening on the " + f"allocated port {port}" + ) + tg.cancel_scope.cancel() + + # Restart on the same port immediately. Pre-fix this raised + # EADDRINUSE because the prior listener was never closed. + server2 = PeerFileServer(host="127.0.0.1", port=port, models_dirs=[models_dir]) + async with anyio.create_task_group() as tg2: + tg2.start_soon(server2.run) + await anyio.sleep(0.05) + async with ( + aiohttp.ClientSession() as s, + s.get( + f"http://127.0.0.1:{port}/health", + timeout=aiohttp.ClientTimeout(total=2.0), + ) as r, + ): + assert r.status == 200, ( + "server2 must come up cleanly on the recycled " + "port; pre-fix the prior server's socket " + "leaked and this raised EADDRINUSE" + ) + tg2.cancel_scope.cancel() From effabd9e62f510ee970ddb66757f83b01def69e8 Mon Sep 17 00:00:00 2001 From: jw-wcv <101585096+jw-wcv@users.noreply.github.com> Date: Sat, 9 May 2026 07:23:29 -0700 Subject: [PATCH 26/29] PR #16 R(N+10) P2: materialize zero-byte marker files in peer transfer The peer transfer path skipped every file whose declared size was 0 (e.g. .gitattributes markers, empty __init__.py shims, empty config sentinels), so the local snapshot diverged from the filtered file list HF would have produced. DownloadCompleted was published with an incomplete model directory and downstream loaders that probe for those marker files (chat-template adapters, processor configs that expect an empty sentinel) failed in ways that didn't point back at the peer step. After the canonical (non-empty) peer transfers succeed, materialize each zero-byte marker as a local empty file using aiofiles in append mode (so a resumed-from-partial marker isn't truncated). Marker materialization is intentionally deferred until the canonical transfer succeeds: a partial peer transfer must NOT leave behind orphan empty files that masquerade as a complete download and confuse the HF fallback's already-downloaded probe. If marker creation itself fails (filesystem permissions, etc.) we fall back to HF for the full snapshot integrity guarantee. Add two regression tests: - test_zero_byte_marker_files_materialized_after_peer_transfer: asserts both root- and nested-zero-byte markers land on disk with size 0 after a successful peer transfer. - test_zero_byte_files_not_created_when_canonical_transfer_fails: asserts a failing canonical transfer leaves the markers absent so the HF fallback starts from a clean directory state. --- src/exo/download/peer_shard_downloader.py | 40 +++- src/exo/download/tests/test_peer_download.py | 186 +++++++++++++++++++ 2 files changed, 225 insertions(+), 1 deletion(-) diff --git a/src/exo/download/peer_shard_downloader.py b/src/exo/download/peer_shard_downloader.py index 85b282f252..e37b3d02bb 100644 --- a/src/exo/download/peer_shard_downloader.py +++ b/src/exo/download/peer_shard_downloader.py @@ -17,6 +17,7 @@ from pathlib import Path from typing import Any, AsyncIterator, Callable, Literal +import aiofiles import aiofiles.os as aios from loguru import logger @@ -347,10 +348,26 @@ def on_file_progress( start_time=all_start_time, ) - # Download all files in parallel + # Codex P2 (PR #16 round-(N+10), peer_shard_downloader.py:354): + # zero-byte files (e.g. ``.gitattributes`` markers, empty + # ``__init__.py`` shims) MUST still be materialized so the + # local snapshot mirrors the filtered file list HF would + # have produced. Pre-fix the peer path silently skipped any + # file with ``size in (None, 0)`` and reported success, so + # ``DownloadCompleted`` was published with an incomplete + # local model directory -- subsequent loads that touched + # those marker files (model loaders, processors that probe + # for ``chat_template.json``, etc.) would then fail in ways + # that don't point back at the peer step. + zero_byte_files: list[str] = [] tasks: list[Coroutine[Any, Any, bool]] = [] for f in filtered_file_list: if f.size is None or f.size == 0: + # Defer the local touch until after we know the rest + # of the peer transfer succeeded; a partial peer + # transfer should not leave behind orphan empty + # marker files that masquerade as a complete download. + zero_byte_files.append(f.path) continue peer_info = peer_file_map.get(f.path) if peer_info and peer_info.safe_bytes > 0: @@ -366,6 +383,27 @@ def on_file_progress( if any(isinstance(r, Exception) or r is False for r in results): return None + for marker_path in zero_byte_files: + full_path = target_dir / marker_path + try: + await aios.makedirs(full_path.parent, exist_ok=True) + # ``aios.path.exists`` first to avoid an unnecessary + # touch (and the corresponding mtime bump) when + # resume-from-partial finds the marker already on + # disk. ``aios.open`` in append mode is the safest + # way to materialize the empty file without + # truncating an already-present marker. + if not await aios.path.exists(full_path): + async with aiofiles.open(full_path, mode="a"): + pass + except Exception as exc: + logger.warning( + f"Could not materialize zero-byte marker file " + f"{full_path} after peer transfer: {exc}; " + f"falling back to HF for full snapshot integrity" + ) + return None + # Emit final progress final_progress = calculate_repo_progress( shard, diff --git a/src/exo/download/tests/test_peer_download.py b/src/exo/download/tests/test_peer_download.py index 5dfaf649a5..2c7445bfac 100644 --- a/src/exo/download/tests/test_peer_download.py +++ b/src/exo/download/tests/test_peer_download.py @@ -1233,6 +1233,192 @@ async def recording_meta(*args: object, **_kwargs: object) -> tuple[int, str]: ) +class TestPeerDownloadZeroByteFiles: + """Codex P2 (PR #16 round-(N+10), peer_shard_downloader.py:354): + The peer transfer path skipped every file whose declared size was + 0 (e.g. ``.gitattributes`` markers, empty ``__init__.py`` shims), + so DownloadCompleted was published with an incomplete local + snapshot. Loaders that probe for those marker files at runtime + (chat-template adapters, processor configs that expect an empty + sentinel) then failed in ways that didn't point back at the peer + step. The fix materializes the zero-byte files locally after the + rest of the peer transfer succeeds. + """ + + async def test_zero_byte_marker_files_materialized_after_peer_transfer( + self, monkeypatch: pytest.MonkeyPatch, tmp_path: Path + ) -> None: + """A repo containing canonical bytes plus an empty marker file + must end the peer transfer with BOTH on disk -- the marker is + a zero-byte file that pre-fix was silently dropped. + """ + from exo.download import peer_shard_downloader as psd + from exo.download.peer_download import PeerFileInfo + from exo.shared.types.worker.downloads import FileListEntry + + served = [ + FileListEntry(type="file", path="model.safetensors", size=10), + # Zero-byte sentinel; pre-fix the peer path silently + # skipped this and the local snapshot was incomplete. + FileListEntry(type="file", path=".gitattributes", size=0), + # Empty shim that loaders sometimes probe for. + FileListEntry(type="file", path="empty/__init__.py", size=0), + ] + + async def fake_fetch(*_args: object, **_kwargs: object) -> list[FileListEntry]: + return served + + async def fake_peer_status( + peer_host: str, # noqa: ARG001 + peer_port: int, # noqa: ARG001 + model_id_normalized: str, # noqa: ARG001 + timeout: float = 5.0, # noqa: ARG001 + ) -> list[PeerFileInfo] | None: + # The peer reports only the canonical bytes (mirrors + # production peers; HF-shard listings do not include + # zero-byte markers either). + return [ + PeerFileInfo( + path="model.safetensors", size=10, complete=True, safe_bytes=10 + ) + ] + + async def fake_resolve_dir(_model_id: ModelId) -> Path: + return tmp_path + + async def fake_resolve_allow(_shard: ShardMetadata) -> list[str]: + return ["*"] + + target_path = tmp_path / "model.safetensors" + + async def fake_download( + peer_ip: str, # noqa: ARG001 + peer_port: int, # noqa: ARG001 + model_id_normalized: str, # noqa: ARG001 + file_path: str, # noqa: ARG001 + target_dir: Path, # noqa: ARG001 + expected_size: int, # noqa: ARG001 + on_progress: object = None, # noqa: ARG001 + ) -> Path | None: + async with aiofiles.open(target_path, "wb") as f: + await f.write(b"0123456789") + return target_path + + monkeypatch.setattr(psd, "fetch_file_list_with_cache", fake_fetch) + monkeypatch.setattr(psd, "get_peer_file_status", fake_peer_status) + monkeypatch.setattr(psd, "resolve_model_dir", fake_resolve_dir) + monkeypatch.setattr(psd, "resolve_allow_patterns", fake_resolve_allow) + monkeypatch.setattr(psd, "download_file_from_peer", fake_download) + + downloader = PeerAwareShardDownloader(NoopShardDownloader(), offline=True) + shard = _make_shard(ModelId("test-org/model-a")) + + result = await downloader._try_peer_download( + shard, + peer_ip="10.0.0.1", + peer_port=52415, + model_id_normalized="test-org/model-a", + ) + assert result is not None, ( + "peer transfer must succeed when the only missing 'files' " + "are zero-byte markers; pre-fix the path returned success " + "without materializing them, so subsequent loads broke" + ) + assert await aios.path.exists(target_path), ( + "the canonical safetensor must still be present" + ) + # The crux of the regression test: zero-byte markers MUST be on disk. + gitattributes = tmp_path / ".gitattributes" + empty_shim = tmp_path / "empty" / "__init__.py" + assert await aios.path.exists(gitattributes), ( + "zero-byte ``.gitattributes`` marker must be materialized on " + "disk after peer transfer; pre-fix it was silently skipped " + "and DownloadCompleted reported success on an incomplete dir" + ) + assert await aios.path.exists(empty_shim), ( + "zero-byte ``empty/__init__.py`` shim must exist after peer " + "transfer (parent dir must also be created)" + ) + # Both must literally be empty. + assert (await aios.stat(gitattributes)).st_size == 0 + assert (await aios.stat(empty_shim)).st_size == 0 + + async def test_zero_byte_files_not_created_when_canonical_transfer_fails( + self, monkeypatch: pytest.MonkeyPatch, tmp_path: Path + ) -> None: + """If the non-empty file transfer fails, the zero-byte markers + must NOT be created. Otherwise the local model dir would + contain orphan empty files masquerading as a partial download + and the HF fallback might skip them. + """ + from exo.download import peer_shard_downloader as psd + from exo.download.peer_download import PeerFileInfo + from exo.shared.types.worker.downloads import FileListEntry + + served = [ + FileListEntry(type="file", path="model.safetensors", size=10), + FileListEntry(type="file", path=".gitattributes", size=0), + ] + + async def fake_fetch(*_args: object, **_kwargs: object) -> list[FileListEntry]: + return served + + async def fake_peer_status( + peer_host: str, # noqa: ARG001 + peer_port: int, # noqa: ARG001 + model_id_normalized: str, # noqa: ARG001 + timeout: float = 5.0, # noqa: ARG001 + ) -> list[PeerFileInfo] | None: + return [ + PeerFileInfo( + path="model.safetensors", size=10, complete=True, safe_bytes=10 + ) + ] + + async def fake_resolve_dir(_model_id: ModelId) -> Path: + return tmp_path + + async def fake_resolve_allow(_shard: ShardMetadata) -> list[str]: + return ["*"] + + async def failing_download( + peer_ip: str, # noqa: ARG001 + peer_port: int, # noqa: ARG001 + model_id_normalized: str, # noqa: ARG001 + file_path: str, # noqa: ARG001 + target_dir: Path, # noqa: ARG001 + expected_size: int, # noqa: ARG001 + on_progress: object = None, # noqa: ARG001 + ) -> Path | None: + return None + + monkeypatch.setattr(psd, "fetch_file_list_with_cache", fake_fetch) + monkeypatch.setattr(psd, "get_peer_file_status", fake_peer_status) + monkeypatch.setattr(psd, "resolve_model_dir", fake_resolve_dir) + monkeypatch.setattr(psd, "resolve_allow_patterns", fake_resolve_allow) + monkeypatch.setattr(psd, "download_file_from_peer", failing_download) + + downloader = PeerAwareShardDownloader(NoopShardDownloader(), offline=True) + shard = _make_shard(ModelId("test-org/model-a")) + + result = await downloader._try_peer_download( + shard, + peer_ip="10.0.0.1", + peer_port=52415, + model_id_normalized="test-org/model-a", + ) + assert result is None, ( + "peer transfer must report failure when the non-empty " + "canonical bytes never landed; the HF fallback then runs" + ) + gitattributes = tmp_path / ".gitattributes" + assert not await aios.path.exists(gitattributes), ( + "zero-byte markers must NOT be created if the canonical " + "transfer failed -- otherwise the partial dir confuses the " + "HF fallback's already-downloaded probe" + ) + + def _allocate_free_tcp_port() -> int: """Bind ephemeral port 0 to grab a free TCP port; close before reuse. From 363e695c867c15699b34bd9ec4e89f3a74021b88 Mon Sep 17 00:00:00 2001 From: jw-wcv <101585096+jw-wcv@users.noreply.github.com> Date: Sat, 9 May 2026 07:56:18 -0700 Subject: [PATCH 27/29] PR #16 R(N+11) P1: distinguish unknown-size from zero-byte in peer transfer Round-(N+10)'s fix lumped FileListEntry(size=None) together with size=0 markers and materialized both as empty local files. But fetch_file_list_with_cache returns size=None for files discovered via the safetensors index whose size wasn't in the HF API response -- those are *real weight shards*, not markers. Treating them as empty produced 'DownloadCompleted' snapshots with corrupted weights that failed only at load/inference time, far from the peer step that caused the corruption. Split the two cases: - size == 0 stays the marker materialization path (touch on disk). - size is None aborts the peer transfer with a logged warning so the HF fallback gets a real download path. A pre-pass over filtered_file_list detects size=None and missing peer_info BEFORE any download_one coroutines are constructed, so the early bail-out can't leak un-awaited coroutines. Add a regression test that builds a file list with one canonical real-size file plus one size=None weight shard and asserts the peer transfer aborts (returns None), the unknown-size file is NOT created locally, and download_file_from_peer is never called. --- src/exo/download/peer_shard_downloader.py | 38 ++++++- src/exo/download/tests/test_peer_download.py | 106 +++++++++++++++++++ 2 files changed, 143 insertions(+), 1 deletion(-) diff --git a/src/exo/download/peer_shard_downloader.py b/src/exo/download/peer_shard_downloader.py index e37b3d02bb..d92a407dec 100644 --- a/src/exo/download/peer_shard_downloader.py +++ b/src/exo/download/peer_shard_downloader.py @@ -359,10 +359,46 @@ def on_file_progress( # those marker files (model loaders, processors that probe # for ``chat_template.json``, etc.) would then fail in ways # that don't point back at the peer step. + # + # Codex P1 (PR #16 round-(N+11), peer_shard_downloader.py:354): + # ``size is None`` is *not* the same as ``size == 0``. + # ``fetch_file_list_with_cache`` returns ``FileListEntry(size=None)`` + # for files discovered via the safetensors index (e.g. weight + # shards whose size is not in the HF API response). Pre-fix + # the previous round lumped ``None`` together with literal + # zero and materialized those weight files as empty, + # producing a "DownloadCompleted" snapshot with corrupted / + # incomplete weights that failed only at load/inference + # time. Split the cases: ``== 0`` is materialized as an + # empty marker; ``is None`` aborts the peer transfer and + # forces the HF fallback so the file gets a real download + # path. + # + # Pre-pass: detect bail-out conditions before constructing any + # ``download_one`` coroutines so we don't leak un-awaited + # coroutines on the unknown-size or missing-peer-info paths. + for f in filtered_file_list: + if f.size is None: + logger.info( + f"Peer transfer for {model_id_normalized} aborted: " + f"unknown-size entry {f.path!r} (size=None) cannot " + f"be safely transferred over peer; falling back to HF" + ) + return None + if f.size == 0: + continue + peer_info = peer_file_map.get(f.path) + if not peer_info or peer_info.safe_bytes <= 0: + # Real-size file the peer doesn't have => abort transfer. + return None + zero_byte_files: list[str] = [] tasks: list[Coroutine[Any, Any, bool]] = [] for f in filtered_file_list: - if f.size is None or f.size == 0: + if f.size is None: + # Pre-pass already bailed; safety net for type-narrowing. + return None + if f.size == 0: # Defer the local touch until after we know the rest # of the peer transfer succeeded; a partial peer # transfer should not leave behind orphan empty diff --git a/src/exo/download/tests/test_peer_download.py b/src/exo/download/tests/test_peer_download.py index 2c7445bfac..5cdd17b656 100644 --- a/src/exo/download/tests/test_peer_download.py +++ b/src/exo/download/tests/test_peer_download.py @@ -1343,6 +1343,112 @@ async def fake_download( assert (await aios.stat(gitattributes)).st_size == 0 assert (await aios.stat(empty_shim)).st_size == 0 + async def test_unknown_size_file_aborts_peer_transfer_for_hf_fallback( + self, monkeypatch: pytest.MonkeyPatch, tmp_path: Path + ) -> None: + """Codex P1 (PR #16 round-(N+11), peer_shard_downloader.py:354): + ``FileListEntry(size=None)`` is NOT a zero-byte marker -- the + upstream ``fetch_file_list_with_cache`` returns ``size=None`` + for files discovered via the safetensors index whose size + wasn't in the HF API response (real weight shards). Pre-fix + the round-(N+10) materialize-as-empty path treated those as + empty markers and reported peer transfer success on a + corrupted snapshot. + + Post-fix, ``size is None`` aborts the peer transfer (returns + None) so the HF fallback gets a real download path. We + construct a file list with a real safetensor (size=10) and + an unknown-size weight shard (size=None) and assert the + peer transfer returns None *without* materializing the + unknown-size entry as an empty file. + """ + from exo.download import peer_shard_downloader as psd + from exo.download.peer_download import PeerFileInfo + from exo.shared.types.worker.downloads import FileListEntry + + served = [ + FileListEntry(type="file", path="model.safetensors", size=10), + # Unknown size: real weight shard from safetensors index. + FileListEntry( + type="file", path="model-00002-of-00003.safetensors", size=None + ), + ] + + async def fake_fetch(*_args: object, **_kwargs: object) -> list[FileListEntry]: + return served + + async def fake_peer_status( + peer_host: str, # noqa: ARG001 + peer_port: int, # noqa: ARG001 + model_id_normalized: str, # noqa: ARG001 + timeout: float = 5.0, # noqa: ARG001 + ) -> list[PeerFileInfo] | None: + return [ + PeerFileInfo( + path="model.safetensors", size=10, complete=True, safe_bytes=10 + ), + PeerFileInfo( + path="model-00002-of-00003.safetensors", + size=999, + complete=True, + safe_bytes=999, + ), + ] + + async def fake_resolve_dir(_model_id: ModelId) -> Path: + return tmp_path + + async def fake_resolve_allow(_shard: ShardMetadata) -> list[str]: + return ["*"] + + download_called = anyio.Event() + + async def fake_download( + peer_ip: str, # noqa: ARG001 + peer_port: int, # noqa: ARG001 + model_id_normalized: str, # noqa: ARG001 + file_path: str, # noqa: ARG001 + target_dir: Path, # noqa: ARG001 + expected_size: int, # noqa: ARG001 + on_progress: object = None, # noqa: ARG001 + ) -> Path | None: + download_called.set() + return None + + monkeypatch.setattr(psd, "fetch_file_list_with_cache", fake_fetch) + monkeypatch.setattr(psd, "get_peer_file_status", fake_peer_status) + monkeypatch.setattr(psd, "resolve_model_dir", fake_resolve_dir) + monkeypatch.setattr(psd, "resolve_allow_patterns", fake_resolve_allow) + monkeypatch.setattr(psd, "download_file_from_peer", fake_download) + + downloader = PeerAwareShardDownloader(NoopShardDownloader(), offline=True) + shard = _make_shard(ModelId("test-org/model-a")) + + result = await downloader._try_peer_download( + shard, + peer_ip="10.0.0.1", + peer_port=52415, + model_id_normalized="test-org/model-a", + ) + assert result is None, ( + "peer transfer must abort (return None) when the file list " + "contains a size=None entry; HF fallback then takes over to " + "ensure the unknown-size weight is properly downloaded. " + "Pre-fix the size=None entry was lumped with size=0 markers " + "and materialized as empty, producing corrupted snapshots." + ) + # The unknown-size file must NOT have been created as empty + # by the marker-materialization path. + unknown_path = tmp_path / "model-00002-of-00003.safetensors" + assert not await aios.path.exists(unknown_path), ( + "size=None entries must NOT be materialized as empty marker " + "files -- they're real weights of unknown size, not markers" + ) + assert not download_called.is_set(), ( + "peer transfer should abort BEFORE issuing any download " + "call when a size=None entry is encountered" + ) + async def test_zero_byte_files_not_created_when_canonical_transfer_fails( self, monkeypatch: pytest.MonkeyPatch, tmp_path: Path ) -> None: From 6fcc7437f8504eb84d9cfff67f5d5719eb566dc4 Mon Sep 17 00:00:00 2001 From: jw-wcv <101585096+jw-wcv@users.noreply.github.com> Date: Sat, 9 May 2026 09:49:40 -0700 Subject: [PATCH 28/29] PR #16 R(N+13) P1: serialize legacy keypair adoption across scopes Codex flagged a P1 finding on PR #16 head 8a9bc7ccb7 at ``src/exo/routing/router.py:359``: > ``get_node_id_keypair()`` locks on ``resolved_path``, so two > same-host processes with different ``process_scope`` values > acquire different lock files and can run legacy migration > concurrently. In the cross-device fallback path (``replace()`` > fails, then copy bytes), both processes can read the same legacy > keypair before it is unlinked and each write it to its own > scoped file, producing duplicate node identities despite > different scopes. This breaks the unique-NodeId assumptions in > routing/election during concurrent startup when cache/config > live on different filesystems. Root cause: ``_scoped_keypair_path`` (R(N+2)) intentionally folds the process scope into the on-disk filename so each process gets a distinct keypair, and the existing FileLock at ``router.py:357`` is keyed on that scoped path. Result: two concurrent same-host processes with distinct scopes acquire DIFFERENT lock files and never serialize against each other -- so both can enter ``_migrate_legacy_node_id_keypair`` concurrently before either has a chance to unlink the legacy file. ``replace()`` on the legacy path is atomic (only one wins), but the cross-device fallback (the ``OSError`` branch on Linux when ``XDG_*`` dirs span filesystems) reads-then-writes-then-unlinks, and that read-then- write window is wide enough for a second scope to read the same legacy bytes and copy them into its own scoped file. Two scoped keypairs end up holding identical bytes, so two same-host nodes boot with the same ``NodeId`` -- this is exactly the routing / election unique-identity assumption Codex called out. Fix: wrap ``_migrate_legacy_node_id_keypair`` in a second ``FileLock`` keyed on the **legacy** path. The legacy path is intentionally NOT scope-suffixed (it pre-dates per-process scoping), so the legacy-keyed lock is the single global serialization point shared across every scope. Lock ordering is ``resolved_path`` (outer) -> ``resolved_legacy`` (inner): no deadlock is possible because the legacy lock is only acquired while holding the per-scope lock and is released before keypair I/O resumes, so unrelated scopes' keypair I/O isn't blocked on identity housekeeping. Documented behaviour preserved: the docstring's "first process boots wins" semantic is now actually enforced -- the winner of the legacy lock unlinks the legacy file, the loser's migrator no-ops on the absent legacy and proceeds to generate a fresh keypair as the per-process isolation invariant requires. Regression: ``test_legacy_migration_serialized_across_process_scopes``. Forces the cross-device fallback by monkey-patching ``Path.replace`` to raise ``OSError`` on the legacy path, then pauses inside the byte-copy ``write_bytes`` for thread A while thread B starts up with a distinct scope. Pre-fix both threads slip through their per-scope locks and end up with identical scoped keypairs (the test's ``scope_a_bytes != scope_b_bytes`` assertion fails with identical byte strings -- verified by running the test against the pre-fix code via ``git stash``); post-fix the legacy lock blocks thread B until thread A finishes adoption, and exactly one scope ends up holding the legacy bytes while the other generates a fresh identity. --- src/exo/routing/router.py | 32 ++++- .../routing/tests/test_node_id_migration.py | 122 ++++++++++++++++++ 2 files changed, 153 insertions(+), 1 deletion(-) diff --git a/src/exo/routing/router.py b/src/exo/routing/router.py index 1319290f4d..5e42639475 100644 --- a/src/exo/routing/router.py +++ b/src/exo/routing/router.py @@ -326,6 +326,12 @@ def get_node_id_keypair( migration is performed INSIDE the file lock so two concurrent processes can't both pass the existence check and then race each other into divergent in-memory vs. on-disk identities. + Codex P1 (PR #16 round-(N+13), router.py:359): when callers + pass distinct ``process_scope`` values, the per-scope lock + above does NOT serialize legacy adoption across scopes, so a + second lock keyed on the (unscoped) legacy path is acquired + before invoking the migrator -- otherwise the cross-device + byte-copy fallback can produce duplicate ``NodeId``s. """ base_path = Path(str(path)) resolved_path = ( @@ -356,7 +362,31 @@ def lock_path(p: str | bytes | PathLike[str] | PathLike[bytes]) -> Path: # on-disk file but divergent in-memory identities. with FileLock(lock_path(resolved_path)): if resolved_legacy is not None: - _migrate_legacy_node_id_keypair(resolved_path, resolved_legacy) + # Codex P1 (PR #16 round-(N+13), router.py:359): + # serialize legacy adoption across ALL ``process_scope`` + # values. The outer ``resolved_path`` lock is per-scope, + # so two same-host processes with different scopes + # acquire DIFFERENT lock files and can each enter + # ``_migrate_legacy_node_id_keypair`` concurrently. In + # the cross-device fallback path -- where ``replace()`` + # raises ``OSError`` and the migrator falls back to a + # ``read_bytes`` + ``write_bytes`` + ``unlink`` + # sequence -- both processes can read the same legacy + # keypair before either unlinks it, then each writes + # those bytes into its own scoped file. Result: two + # nodes claiming the same ``NodeId`` despite distinct + # scopes, breaking routing's unique-identity and + # election's tiebreaker invariants. A lock keyed on the + # legacy path (which is intentionally NOT scope-suffixed + # because it pre-dates scoping) serializes migration so + # exactly one scope wins legacy adoption and any + # concurrent peers observe the file already gone and + # generate fresh keypairs -- the documented "first + # process boots wins" semantic. Released immediately + # after migration so unrelated keypair I/O on other + # scopes isn't blocked on identity housekeeping. + with FileLock(lock_path(resolved_legacy)): + _migrate_legacy_node_id_keypair(resolved_path, resolved_legacy) with open(resolved_path, "a+b") as f: # opens in append-mode => starts at EOF # if non-zero EOF, then file exists => use to get node-ID diff --git a/src/exo/routing/tests/test_node_id_migration.py b/src/exo/routing/tests/test_node_id_migration.py index bf3860719b..bfa0a8f8a9 100644 --- a/src/exo/routing/tests/test_node_id_migration.py +++ b/src/exo/routing/tests/test_node_id_migration.py @@ -388,6 +388,128 @@ def test_libp2p_port_zero_in_two_processes_yield_distinct_scopes( ) +def test_legacy_migration_serialized_across_process_scopes( + tmp_path: Path, monkeypatch: pytest.MonkeyPatch +) -> None: + """Codex P1 (PR #16 round-(N+13), router.py:359): legacy + adoption MUST be serialized across all ``process_scope`` values, + even when the per-scope ``resolved_path`` lock differs and the + cross-device byte-copy fallback path is taken inside + ``_migrate_legacy_node_id_keypair``. + + Pre-fix this test produces two identical scoped keypairs (both + matching the legacy bytes), simulating two same-host processes + racing legacy adoption: each acquires its own per-scope lock, + both fall through to the byte-copy branch, both read the same + legacy bytes, and both end up writing those bytes to their own + scoped file -- duplicate ``NodeId`` despite distinct scopes. + + Post-fix the migrator is wrapped in a second FileLock keyed on + the legacy path. The first scope wins adoption and unlinks the + legacy file; the second scope's migrator no-ops on the absent + legacy and generates a fresh keypair, so the two scopes diverge + as required by the per-process isolation invariant. + + We simulate the cross-device fallback by monkey-patching + ``Path.replace`` to raise ``OSError`` (the same trigger that + fires on Linux when ``XDG_*`` dirs span filesystems). The + serialization invariant is asserted by also blocking the byte + copy with a ``threading.Event`` so two threads must contend on + the legacy lock; only one thread should observe the legacy + file present at copy time. + """ + import threading + + import exo.routing.router as router_mod + + legacy_path = tmp_path / "cache" / "node_id.keypair" + base_path = tmp_path / "config" / "node_id.keypair" + legacy_path.parent.mkdir(parents=True) + base_path.parent.mkdir(parents=True) + + legacy_bytes = Keypair.generate().to_bytes() + legacy_path.write_bytes(legacy_bytes) + + # Force the cross-device fallback so the migrator goes through + # the read_bytes/write_bytes/unlink sequence (the path Codex + # flagged as racy). + real_replace = Path.replace + + def _force_cross_device(self: Path, target: object) -> object: # noqa: ANN001 + if Path(self) == legacy_path: + raise OSError("simulated cross-device link error") + return real_replace(self, target) # pyright: ignore[reportArgumentType] + + monkeypatch.setattr(Path, "replace", _force_cross_device) + + # Pause inside the byte-copy branch so two threads pile up on + # the legacy lock while one thread holds it. Without the legacy + # lock both threads would observe the legacy file present at + # this point and both would proceed to write_bytes/unlink. + in_copy = threading.Event() + release_copy = threading.Event() + real_write_bytes = Path.write_bytes + + def _slow_write_bytes(self: Path, data: bytes) -> int: + if self.parent == base_path.parent: + in_copy.set() + release_copy.wait(timeout=5.0) + return real_write_bytes(self, data) + + monkeypatch.setattr(Path, "write_bytes", _slow_write_bytes) + + keypairs: dict[int, Keypair] = {} + + def _run(scope: int) -> None: + keypairs[scope] = router_mod.get_node_id_keypair( + path=base_path, legacy_path=legacy_path, process_scope=scope + ) + + thread_a = threading.Thread(target=_run, args=(52416,), daemon=True) + thread_b = threading.Thread(target=_run, args=(52417,), daemon=True) + thread_a.start() + in_copy.wait(timeout=5.0) + # While thread_a is paused inside the byte copy holding the + # legacy lock, thread_b should be blocked on the legacy lock -- + # NOT racing through its own byte copy of the same legacy file. + thread_b.start() + # Give thread_b a moment to attempt acquiring the legacy lock + # so we can assert it did not slip through. + thread_b.join(timeout=0.2) + assert thread_b.is_alive(), ( + "second scope must be blocked on the legacy lock while the " + "first scope is mid-copy; if this fails, both scopes will " + "duplicate the legacy NodeId via the byte-copy race" + ) + release_copy.set() + thread_a.join(timeout=5.0) + thread_b.join(timeout=5.0) + assert not thread_a.is_alive() and not thread_b.is_alive() + + scope_a_bytes = keypairs[52416].to_bytes() + scope_b_bytes = keypairs[52417].to_bytes() + assert scope_a_bytes != scope_b_bytes, ( + "concurrent legacy adoption across distinct process_scope " + "values must NOT produce duplicate keypairs; the legacy " + "lock should let exactly one scope adopt the legacy bytes " + "while the other generates a fresh identity" + ) + # Exactly one scoped file should match the legacy bytes (the + # winner of adoption); the other was generated fresh. + scoped_a = base_path.parent / "node_id.52416.keypair" + scoped_b = base_path.parent / "node_id.52417.keypair" + matches = sum( + 1 + for p in (scoped_a, scoped_b) + if p.exists() and p.read_bytes() == legacy_bytes + ) + assert matches == 1, ( + f"exactly one scope must have adopted the legacy bytes; " + f"matches={matches} indicates the cross-device race fired" + ) + assert not legacy_path.exists(), "legacy file must be unlinked after adoption" + + def test_legacy_migration_adopts_into_scoped_path(tmp_path: Path) -> None: """When a process passes a scope and a legacy unscoped keypair exists, the legacy bytes must be adopted into the scoped path. From 19d6ce179e587195b1249232a8b8698fe70c4dcd Mon Sep 17 00:00:00 2001 From: jw-wcv <101585096+jw-wcv@users.noreply.github.com> Date: Sat, 9 May 2026 10:34:36 -0700 Subject: [PATCH 29/29] PR #16 R(N+14) P2: mark zero-byte peer files complete in progress map Codex flagged a P2 finding on PR #16 head a1374b4f1d at ``src/exo/download/peer_shard_downloader.py:407``: > When a repo includes zero-byte files, this branch skips > ``download_one`` and later materializes marker files on disk, > but it never updates ``file_progress`` for those paths. As a > result, ``calculate_repo_progress()`` can leave the overall > status as ``not_started`` even after all bytes are present, so > ``_download_progress_callback`` never emits ``DownloadCompleted`` > immediately and the model can remain stuck in ``DownloadOngoing`` > until the periodic reconciliation loop runs. Root cause: the round-(N+10) zero-byte materialization path explicitly skips ``download_one`` for zero-byte files, but ``download_one`` is the SOLE writer of the per-file ``status="complete"`` transition (via its inner ``on_file_progress`` callback). The seeded entry for a zero-byte file at line 338 defaults to ``status="not_started"`` and stays that way through the materialization step, so the final ``calculate_repo_progress`` rollup -- which uses per-file statuses to derive the overall repo status -- emits ``RepoDownloadProgress(status="not_started")`` even though every file is on disk. ``_download_progress_callback`` in the ``DownloadCoordinator`` only publishes ``DownloadCompleted`` on ``status="complete"``, so the model's effective state stays at ``DownloadOngoing`` until the periodic ``_emit_existing_download_progress`` reconciliation loop notices the on-disk snapshot and force-promotes it. That delay can silently break test fixtures that expect synchronous completion (the bench harness, the API's polling ``StartDownload`` -> ``DownloadCompleted`` waiter) and forces an unnecessary HF re-validation step for offline / air-gapped users who restart between transfer and reconciliation. Fix: after materializing each zero-byte marker on disk, replace its seeded ``not_started`` entry in ``file_progress`` with a fully-complete ``RepoFileDownloadProgress(status="complete")``. ``RepoFileDownloadProgress`` is frozen so we replace the dict slot rather than mutating in place. The pattern mirrors the regular file completion path in ``download_one``'s ``on_file_progress`` callback (``status="complete" if is_renamed else "in_progress"`` resolves to ``complete`` for the final emission), preserving the documented progress-rollup invariant. Regression: ``test_zero_byte_files_marked_complete_in_progress_map`` exercises the zero-byte materialization fixture from the round-(N+10) test and additionally captures the final progress callback emission. The canonical safetensor's progress callback is now invoked by the test fake (matching the production ``download_file_from_peer`` contract), so the canonical entry's status flips to ``complete`` correctly. Pre-fix the rolled-up status is ``not_started`` because the zero-byte entries never transition; post-fix every per-file entry is ``complete`` and the rollup status is ``complete`` -- the regression Codex called out is locked in. Verified the test fails on the pre-fix coordinator via ``git stash``. --- src/exo/download/peer_shard_downloader.py | 30 +++++ src/exo/download/tests/test_peer_download.py | 129 +++++++++++++++++++ 2 files changed, 159 insertions(+) diff --git a/src/exo/download/peer_shard_downloader.py b/src/exo/download/peer_shard_downloader.py index d92a407dec..5d4e127d4c 100644 --- a/src/exo/download/peer_shard_downloader.py +++ b/src/exo/download/peer_shard_downloader.py @@ -439,6 +439,36 @@ def on_file_progress( f"falling back to HF for full snapshot integrity" ) return None + # Codex P2 (PR #16 round-(N+13), peer_shard_downloader.py:407): + # ``download_one`` -> ``on_file_progress`` is the only + # writer of the per-file ``status="complete"`` marker; + # the zero-byte branch never invokes it (there are no + # bytes to stream), so the file_progress entry seeded + # at line 338 stays at ``status="not_started"``. The + # final ``calculate_repo_progress`` call below then + # rolls those up into a non-``complete`` overall status, + # which means ``_download_progress_callback`` does NOT + # publish ``DownloadCompleted`` -- the model gets stuck + # in ``DownloadOngoing`` until the periodic + # reconciliation loop in ``DownloadCoordinator`` notices + # the on-disk snapshot and force-updates the status. + # Mirror the regular file completion path by overwriting + # the seeded entry with a fully-complete one once the + # marker is on disk. ``RepoFileDownloadProgress`` is + # frozen, so we replace the dict slot rather than + # mutating the existing instance. + file_progress[marker_path] = RepoFileDownloadProgress( + repo_id=str(shard.model_card.model_id), + repo_revision=revision, + file_path=marker_path, + downloaded=Memory.from_bytes(0), + downloaded_this_session=Memory.from_bytes(0), + total=Memory.from_bytes(0), + speed=0, + eta=timedelta(0), + status="complete", + start_time=all_start_time, + ) # Emit final progress final_progress = calculate_repo_progress( diff --git a/src/exo/download/tests/test_peer_download.py b/src/exo/download/tests/test_peer_download.py index 5cdd17b656..3a753eea48 100644 --- a/src/exo/download/tests/test_peer_download.py +++ b/src/exo/download/tests/test_peer_download.py @@ -1343,6 +1343,135 @@ async def fake_download( assert (await aios.stat(gitattributes)).st_size == 0 assert (await aios.stat(empty_shim)).st_size == 0 + async def test_zero_byte_files_marked_complete_in_progress_map( + self, monkeypatch: pytest.MonkeyPatch, tmp_path: Path + ) -> None: + """Codex P2 (PR #16 round-(N+13), peer_shard_downloader.py:407): + zero-byte files must be marked ``status="complete"`` in the + progress map AFTER materialization, otherwise the final + ``calculate_repo_progress`` call rolls them up as + ``status="not_started"`` and the overall repo status stays + non-complete -- so ``_download_progress_callback`` does not + publish ``DownloadCompleted`` immediately and the model is + stuck in ``DownloadOngoing`` until reconciliation runs. + + We exercise the same fixture as the materialization test, + but capture the *final* progress callback emission (the one + the coordinator turns into ``DownloadCompleted``) and + assert its ``status`` is ``"complete"`` and that every + per-file entry is also ``"complete"``. + """ + from exo.download import peer_shard_downloader as psd + from exo.download.download_utils import RepoDownloadProgress + from exo.download.peer_download import PeerFileInfo + from exo.shared.types.worker.downloads import FileListEntry + + served = [ + FileListEntry(type="file", path="model.safetensors", size=10), + FileListEntry(type="file", path=".gitattributes", size=0), + FileListEntry(type="file", path="empty/__init__.py", size=0), + ] + + async def fake_fetch(*_args: object, **_kwargs: object) -> list[FileListEntry]: + return served + + async def fake_peer_status( + peer_host: str, # noqa: ARG001 + peer_port: int, # noqa: ARG001 + model_id_normalized: str, # noqa: ARG001 + timeout: float = 5.0, # noqa: ARG001 + ) -> list[PeerFileInfo] | None: + return [ + PeerFileInfo( + path="model.safetensors", size=10, complete=True, safe_bytes=10 + ) + ] + + async def fake_resolve_dir(_model_id: ModelId) -> Path: + return tmp_path + + async def fake_resolve_allow(_shard: ShardMetadata) -> list[str]: + return ["*"] + + target_path = tmp_path / "model.safetensors" + + async def fake_download( + peer_ip: str, # noqa: ARG001 + peer_port: int, # noqa: ARG001 + model_id_normalized: str, # noqa: ARG001 + file_path: str, # noqa: ARG001 + target_dir: Path, # noqa: ARG001 + expected_size: int, + on_progress: Callable[[int, int, bool], None] = lambda _a, _b, _c: None, + ) -> Path | None: + async with aiofiles.open(target_path, "wb") as f: + await f.write(b"0123456789") + # Match the production peer_download contract: emit the + # final rename-completed progress callback so the + # canonical-file's per-file progress entry transitions + # to ``status="complete"`` like it would in production. + on_progress(expected_size, expected_size, True) + return target_path + + monkeypatch.setattr(psd, "fetch_file_list_with_cache", fake_fetch) + monkeypatch.setattr(psd, "get_peer_file_status", fake_peer_status) + monkeypatch.setattr(psd, "resolve_model_dir", fake_resolve_dir) + monkeypatch.setattr(psd, "resolve_allow_patterns", fake_resolve_allow) + monkeypatch.setattr(psd, "download_file_from_peer", fake_download) + + downloader = PeerAwareShardDownloader(NoopShardDownloader(), offline=True) + shard = _make_shard(ModelId("test-org/model-a")) + + # Capture the final progress emitted by the peer downloader + # so we can assert its rolled-up status. + captured: list[RepoDownloadProgress] = [] + + async def capture_progress( + _shard: ShardMetadata, progress: RepoDownloadProgress + ) -> None: + captured.append(progress) + + downloader._progress_callbacks.append(capture_progress) + + result = await downloader._try_peer_download( + shard, + peer_ip="10.0.0.1", + peer_port=52415, + model_id_normalized="test-org/model-a", + ) + assert result is not None + assert captured, ( + "peer downloader must emit at least one progress event " + "(the rolled-up final status); pre-fix the test never " + "got past this because the canonical file's per-byte " + "callback also triggers an emit" + ) + final = captured[-1] + assert final.status == "complete", ( + "rolled-up final repo progress must be ``complete`` once " + "every file (including zero-byte markers) is on disk; " + "pre-(N+13)-fix the zero-byte entries stayed at " + "``not_started`` so the rollup was non-complete and " + "DownloadCompleted was never published. " + f"final.status={final.status!r} " + f"per_file={[(p, e.status) for p, e in final.file_progress.items()]}" + ) + for marker in (".gitattributes", "empty/__init__.py"): + entry = final.file_progress.get(marker) + assert entry is not None, ( + f"file_progress must contain entry for {marker!r}; " + "pre-fix the seeded ``not_started`` entry was never " + "updated, so this assert succeeded but on the wrong " + "status -- this version of the assert covers both " + "regressions (entry presence and final status)" + ) + assert entry.status == "complete", ( + f"zero-byte marker {marker!r} must be marked complete " + f"in the progress map after materialization; " + f"pre-fix status was {entry.status!r} which causes " + f"calculate_repo_progress to roll up to non-complete" + ) + async def test_unknown_size_file_aborts_peer_transfer_for_hf_fallback( self, monkeypatch: pytest.MonkeyPatch, tmp_path: Path ) -> None: