From f792bd5d520e3bcd2aeccd7b09a41676e1e5e366 Mon Sep 17 00:00:00 2001 From: Alex Cheema Date: Sun, 3 May 2026 01:06:39 +0100 Subject: [PATCH 01/11] feat: store input chunks in state --- src/exo/api/main.py | 14 +-- src/exo/shared/apply.py | 34 +++++++- .../test_apply/test_apply_input_chunks.py | 85 +++++++++++++++++++ src/exo/shared/types/state.py | 6 +- src/exo/worker/main.py | 73 +++++++--------- 5 files changed, 153 insertions(+), 59 deletions(-) create mode 100644 src/exo/shared/tests/test_apply/test_apply_input_chunks.py diff --git a/src/exo/api/main.py b/src/exo/api/main.py index 8fe0cfbecb..902bf3162b 100644 --- a/src/exo/api/main.py +++ b/src/exo/api/main.py @@ -254,7 +254,6 @@ def __init__( self.node_id: NodeId = node_id self.last_completed_election: int = 0 self.port = port - self._sent_image_hashes: set[str] = set() self.paused: bool = False self.paused_ev: anyio.Event = anyio.Event() @@ -304,7 +303,6 @@ def reset(self, result_clock: int, event_receiver: Receiver[IndexedEvent]): self.event_receiver.close() self.event_receiver = event_receiver self._tg.start_soon(self._apply_state) - self._sent_image_hashes = set() def unpause(self, result_clock: int): logger.info("Unpausing API") @@ -826,18 +824,8 @@ async def _send_text_generation_with_images( ) command = TextGeneration(task_params=task_params) - new_images: list[tuple[int, str]] = [] - for idx, (img, h) in enumerate(zip(images, hashes, strict=True)): - if h not in self._sent_image_hashes: - self._sent_image_hashes.add(h) - new_images.append((idx, img)) - - if not new_images: - await self._send(command) - return command - all_chunks: list[tuple[int, str]] = [] - for img_idx, img_data in new_images: + for img_idx, img_data in enumerate(images): for i in range(0, len(img_data), EXO_MAX_CHUNK_SIZE): all_chunks.append((img_idx, img_data[i : i + EXO_MAX_CHUNK_SIZE])) diff --git a/src/exo/shared/apply.py b/src/exo/shared/apply.py index 959f7765b9..b3ff361980 100644 --- a/src/exo/shared/apply.py +++ b/src/exo/shared/apply.py @@ -40,7 +40,14 @@ ThunderboltBridgeStatus, ) from exo.shared.types.state import State -from exo.shared.types.tasks import Task, TaskId, TaskStatus +from exo.shared.types.tasks import ( + ImageEdits, + ImageGeneration, + Task, + TaskId, + TaskStatus, + TextGeneration, +) from exo.shared.types.topology import Connection, RDMAConnection from exo.shared.types.worker.downloads import DownloadProgress from exo.shared.types.worker.instances import Instance, InstanceId @@ -72,7 +79,6 @@ def event_apply(event: Event, state: State) -> State: TestEvent() | ChunkGenerated() | TaskAcknowledged() - | InputChunkReceived() | TracesCollected() | TracesMerged() | CustomModelCardAdded() @@ -93,6 +99,8 @@ def event_apply(event: Event, state: State) -> State: return apply_runner_status_updated(event, state) case TaskCreated(): return apply_task_created(event, state) + case InputChunkReceived(): + return apply_input_chunk_received(event, state) case TaskDeleted(): return apply_task_deleted(event, state) case TaskFailed(): @@ -157,10 +165,32 @@ def apply_task_created(event: TaskCreated, state: State) -> State: return state.model_copy(update={"tasks": new_tasks}) +def apply_input_chunk_received(event: InputChunkReceived, state: State) -> State: + command_chunks = { + **state.input_chunks.get(event.command_id, {}), + event.chunk.chunk_index: event.chunk, + } + return state.model_copy( + update={ + "input_chunks": {**state.input_chunks, event.command_id: command_chunks} + } + ) + + def apply_task_deleted(event: TaskDeleted, state: State) -> State: + task = state.tasks.get(event.task_id) new_tasks: Mapping[TaskId, Task] = { tid: task for tid, task in state.tasks.items() if tid != event.task_id } + if isinstance(task, (TextGeneration, ImageGeneration, ImageEdits)): + new_input_chunks = { + command_id: chunks + for command_id, chunks in state.input_chunks.items() + if command_id != task.command_id + } + return state.model_copy( + update={"tasks": new_tasks, "input_chunks": new_input_chunks} + ) return state.model_copy(update={"tasks": new_tasks}) diff --git a/src/exo/shared/tests/test_apply/test_apply_input_chunks.py b/src/exo/shared/tests/test_apply/test_apply_input_chunks.py new file mode 100644 index 0000000000..d5d494432f --- /dev/null +++ b/src/exo/shared/tests/test_apply/test_apply_input_chunks.py @@ -0,0 +1,85 @@ +from exo.shared.apply import apply +from exo.shared.models.model_cards import ModelId +from exo.shared.types.chunks import InputImageChunk +from exo.shared.types.common import CommandId +from exo.shared.types.events import ( + IndexedEvent, + InputChunkReceived, + TaskCreated, + TaskDeleted, +) +from exo.shared.types.state import State +from exo.shared.types.tasks import TaskId, TaskStatus, TextGeneration +from exo.shared.types.text_generation import ( + InputMessage, + InputMessageContent, + TextGenerationTaskParams, +) +from exo.shared.types.worker.instances import InstanceId + + +def test_apply_input_chunk_received_stores_chunk_in_state() -> None: + command_id = CommandId("command-1") + chunk = InputImageChunk( + model=ModelId("mlx-community/test-model"), + command_id=command_id, + data="abc", + chunk_index=0, + total_chunks=1, + image_index=0, + ) + + state = apply( + State(), + IndexedEvent( + idx=0, + event=InputChunkReceived(command_id=command_id, chunk=chunk), + ), + ) + + assert state.input_chunks == {command_id: {0: chunk}} + + +def test_apply_task_deleted_removes_chunks_for_generation_command() -> None: + command_id = CommandId("command-1") + task_id = TaskId("task-1") + chunk = InputImageChunk( + model=ModelId("mlx-community/test-model"), + command_id=command_id, + data="abc", + chunk_index=0, + total_chunks=1, + image_index=0, + ) + task = TextGeneration( + task_id=task_id, + instance_id=InstanceId("instance-1"), + task_status=TaskStatus.Pending, + command_id=command_id, + task_params=TextGenerationTaskParams( + model=ModelId("mlx-community/test-model"), + input=[ + InputMessage(role="user", content=InputMessageContent("hello")), + ], + ), + ) + + state = State() + state = apply( + state, + IndexedEvent( + idx=0, + event=InputChunkReceived(command_id=command_id, chunk=chunk), + ), + ) + state = apply( + state, + IndexedEvent(idx=1, event=TaskCreated(task_id=task_id, task=task)), + ) + state = apply( + state, + IndexedEvent(idx=2, event=TaskDeleted(task_id=task_id)), + ) + + assert state.tasks == {} + assert state.input_chunks == {} diff --git a/src/exo/shared/types/state.py b/src/exo/shared/types/state.py index 6c976984c8..e0b39e7ff3 100644 --- a/src/exo/shared/types/state.py +++ b/src/exo/shared/types/state.py @@ -6,7 +6,8 @@ from pydantic.alias_generators import to_camel from exo.shared.topology import Topology, TopologySnapshot -from exo.shared.types.common import NodeId +from exo.shared.types.chunks import InputImageChunk +from exo.shared.types.common import CommandId, NodeId from exo.shared.types.instance_link import InstanceLink, InstanceLinkId from exo.shared.types.profiling import ( DiskUsage, @@ -45,6 +46,9 @@ class State(FrozenModel): runners: Mapping[RunnerId, RunnerStatus] = {} downloads: Mapping[NodeId, Sequence[DownloadProgress]] = {} tasks: Mapping[TaskId, Task] = {} + # Durable request input chunks for active image requests. Workers rebuild + # local image caches from this state instead of reading events directly. + input_chunks: Mapping[CommandId, Mapping[int, InputImageChunk]] = {} last_seen: Mapping[NodeId, datetime] = {} topology: Topology = Field(default_factory=Topology) last_event_applied_idx: int = Field(default=-1, ge=-1) diff --git a/src/exo/worker/main.py b/src/exo/worker/main.py index b35f946aac..45e33ad067 100644 --- a/src/exo/worker/main.py +++ b/src/exo/worker/main.py @@ -24,7 +24,6 @@ CustomModelCardDeleted, Event, IndexedEvent, - InputChunkReceived, InstanceDeleted, NodeDownloadProgress, NodeGatheredInfo, @@ -141,37 +140,6 @@ async def _event_applier(self): if isinstance(event, InstanceDeleted): self._instance_backoff.reset(event.instance_id) - # Buffer input image chunks for image editing - if isinstance(event, InputChunkReceived): - cmd_id = event.command_id - if cmd_id not in self.input_chunk_buffer: - self.input_chunk_buffer[cmd_id] = {} - self.input_chunk_counts[cmd_id] = event.chunk.total_chunks - - self.input_chunk_buffer[cmd_id][event.chunk.chunk_index] = ( - event.chunk - ) - - if ( - len(self.input_chunk_buffer[cmd_id]) - == self.input_chunk_counts[cmd_id] - ): - per_image: defaultdict[int, list[InputImageChunk]] = ( - defaultdict(list) - ) - for chunk in self.input_chunk_buffer[cmd_id].values(): - per_image[chunk.image_index].append(chunk) - for chunks_for_image in per_image.values(): - sorted_chunks = sorted( - chunks_for_image, key=lambda c: c.chunk_index - ) - img = Base64Image("".join(c.data for c in sorted_chunks)) - self.image_cache[ - Base64ImageHash( - hashlib.sha256(img.encode("ascii")).hexdigest() - ) - ] = img - if isinstance(event, CustomModelCardAdded): await event.model_card.save_to_custom_dir() add_to_card_cache(event.model_card) @@ -179,6 +147,35 @@ async def _event_applier(self): if isinstance(event, CustomModelCardDeleted): await delete_custom_card(event.model_id) + self._sync_input_views_from_state() + + def _sync_input_views_from_state(self) -> None: + self.input_chunk_buffer = { + command_id: dict(chunks) + for command_id, chunks in self.state.input_chunks.items() + } + self.input_chunk_counts = { + command_id: next(iter(chunks.values())).total_chunks + for command_id, chunks in self.input_chunk_buffer.items() + if chunks + } + + self.image_cache = {} + for command_id, chunks in self.input_chunk_buffer.items(): + expected_chunks = self.input_chunk_counts.get(command_id) + if expected_chunks is None or len(chunks) != expected_chunks: + continue + + per_image: defaultdict[int, list[InputImageChunk]] = defaultdict(list) + for chunk in chunks.values(): + per_image[chunk.image_index].append(chunk) + for chunks_for_image in per_image.values(): + sorted_chunks = sorted(chunks_for_image, key=lambda c: c.chunk_index) + image = Base64Image("".join(chunk.data for chunk in sorted_chunks)) + self.image_cache[ + Base64ImageHash(hashlib.sha256(image.encode("ascii")).hexdigest()) + ] = image + async def plan_step(self): while True: await anyio.sleep(0.1) @@ -189,7 +186,7 @@ async def plan_step(self): self.state.instances, self.state.runners, self.state.tasks, - self.input_chunk_buffer, + self.state.input_chunks, self.image_cache, self._instance_backoff, self._download_backoff, @@ -321,15 +318,9 @@ async def plan_step(self): advanced_params=task.task_params.advanced_params, ), ) - # Cleanup buffers - if cmd_id in self.input_chunk_buffer: - del self.input_chunk_buffer[cmd_id] - if cmd_id in self.input_chunk_counts: - del self.input_chunk_counts[cmd_id] await self._start_runner_task(modified_task) case TextGeneration() if task.task_params.image_hashes: - cmd_id = task.command_id resolved_images = [ self.image_cache[h] for _, h in sorted(task.task_params.image_hashes.items()) @@ -341,10 +332,6 @@ async def plan_step(self): ) } ) - if cmd_id in self.input_chunk_buffer: - del self.input_chunk_buffer[cmd_id] - if cmd_id in self.input_chunk_counts: - del self.input_chunk_counts[cmd_id] await self._start_runner_task(modified_task) case LoadModel(instance_id=instance_id): if (instance := self.state.instances.get(instance_id)) is not None: From f7bdef9f085212f7556f5a5f2620c846102d81fd Mon Sep 17 00:00:00 2001 From: Alex Cheema Date: Sun, 3 May 2026 01:09:09 +0100 Subject: [PATCH 02/11] feat: allow event router buffer fast-forward --- src/exo/routing/event_router.py | 12 +++++++---- src/exo/routing/tests/test_event_buffer.py | 25 ++++++++++++++++++++++ src/exo/utils/event_buffer.py | 12 +++++++++++ 3 files changed, 45 insertions(+), 4 deletions(-) diff --git a/src/exo/routing/event_router.py b/src/exo/routing/event_router.py index 4f99c15250..41d62518b5 100644 --- a/src/exo/routing/event_router.py +++ b/src/exo/routing/event_router.py @@ -80,6 +80,9 @@ def receiver(self) -> Receiver[IndexedEvent]: def shutdown(self) -> None: self._tg.cancel_tasks() + def set_buffer_start(self, idx: int) -> None: + self.event_buffer.fast_forward_to(idx) + async def _ingest(self, system_id: SystemId, recv: Receiver[Event]): idx = 0 with recv as events: @@ -95,7 +98,6 @@ async def _ingest(self, system_id: SystemId, recv: Receiver[Event]): self.out_for_delivery[event.event_id] = (anyio.current_time(), f_ev) async def _run_ext_in(self): - buf = OrderedBuffer[Event]() with self.external_inbound as events: async for event in events: if event.session != self.session_id: @@ -103,12 +105,12 @@ async def _run_ext_in(self): if event.origin != self.session_id.master_node_id: continue - buf.ingest(event.origin_idx, event.event) + self.event_buffer.ingest(event.origin_idx, event.event) event_id = event.event.event_id if event_id in self.out_for_delivery: self.out_for_delivery.pop(event_id) - drained = buf.drain_indexed() + drained = self.event_buffer.drain_indexed() if drained: self._nack_attempts = 0 if self._nack_cancel_scope: @@ -119,7 +121,9 @@ async def _run_ext_in(self): or self._nack_cancel_scope.cancel_called ): # Request the next index. - self._tg.start_soon(self._nack_request, buf.next_idx_to_release) + self._tg.start_soon( + self._nack_request, self.event_buffer.next_idx_to_release + ) continue for idx, event in drained: diff --git a/src/exo/routing/tests/test_event_buffer.py b/src/exo/routing/tests/test_event_buffer.py index 215f53e26f..dcf1d549cd 100644 --- a/src/exo/routing/tests/test_event_buffer.py +++ b/src/exo/routing/tests/test_event_buffer.py @@ -141,3 +141,28 @@ async def test_drain_and_ingest_with_new_sequence(buffer: OrderedBuffer[Event]): assert [e[0] for e in drained] == [2] assert buffer.next_idx_to_release == 3 assert 4 in buffer.store + + +@pytest.mark.asyncio +async def test_fast_forward_discards_buffered_stale_events( + buffer: OrderedBuffer[Event], +): + buffer.ingest(*make_indexed_event(0)) + buffer.ingest(*make_indexed_event(2)) + buffer.ingest(*make_indexed_event(4)) + + buffer.fast_forward_to(3) + + assert buffer.next_idx_to_release == 3 + assert set(buffer.store) == {4} + + +@pytest.mark.asyncio +async def test_fast_forward_only_moves_forward(buffer: OrderedBuffer[Event]): + buffer.ingest(*make_indexed_event(0)) + buffer.ingest(*make_indexed_event(1)) + buffer.drain() + + buffer.fast_forward_to(1) + + assert buffer.next_idx_to_release == 2 diff --git a/src/exo/utils/event_buffer.py b/src/exo/utils/event_buffer.py index 8fcf5fa282..31fc233d8e 100644 --- a/src/exo/utils/event_buffer.py +++ b/src/exo/utils/event_buffer.py @@ -47,6 +47,18 @@ def drain_indexed(self) -> list[tuple[int, T]]: logger.trace(f"Releasing event {ret}") return ret + def fast_forward_to(self, idx: int) -> None: + """Skip every event before idx. + + Snapshot restore uses this after applying state that already includes + events before idx. Any buffered or future event below idx is stale. + """ + if idx <= self.next_idx_to_release: + return + self.next_idx_to_release = idx + for stale_idx in [i for i in self.store if i < idx]: + del self.store[stale_idx] + class MultiSourceBuffer[SourceId, T]: """ From 0e6a56baee940484bfa76f129af98de555183907 Mon Sep 17 00:00:00 2001 From: Alex Cheema Date: Sun, 3 May 2026 01:12:31 +0100 Subject: [PATCH 03/11] feat: version state snapshots --- src/exo/shared/tests/test_state_serialization.py | 1 + src/exo/shared/types/state.py | 3 +++ 2 files changed, 4 insertions(+) diff --git a/src/exo/shared/tests/test_state_serialization.py b/src/exo/shared/tests/test_state_serialization.py index ec94156da3..533f179e85 100644 --- a/src/exo/shared/tests/test_state_serialization.py +++ b/src/exo/shared/tests/test_state_serialization.py @@ -25,6 +25,7 @@ def test_state_serialization_roundtrip() -> None: json_repr = state.model_dump_json() restored_state = State.model_validate_json(json_repr) + assert restored_state.schema_version == state.schema_version assert ( state.topology.to_snapshot().nodes == restored_state.topology.to_snapshot().nodes diff --git a/src/exo/shared/types/state.py b/src/exo/shared/types/state.py index e0b39e7ff3..21199e7af3 100644 --- a/src/exo/shared/types/state.py +++ b/src/exo/shared/types/state.py @@ -42,6 +42,9 @@ class State(FrozenModel): strict=True, arbitrary_types_allowed=True, ) + # Bump when a State change makes older snapshots unsafe to restore. + schema_version: int = Field(default=1, ge=1) + instances: Mapping[InstanceId, Instance] = {} runners: Mapping[RunnerId, RunnerStatus] = {} downloads: Mapping[NodeId, Sequence[DownloadProgress]] = {} From 343d5bc6d4f9b9745b00be386db4c30c6f0c917d Mon Sep 17 00:00:00 2001 From: Alex Cheema Date: Sun, 3 May 2026 01:13:48 +0100 Subject: [PATCH 04/11] feat: add snapshot receiver --- src/exo/routing/snapshot_receiver.py | 109 +++++++++++++ .../routing/tests/test_snapshot_receiver.py | 151 ++++++++++++++++++ src/exo/shared/types/snapshots.py | 54 +++++++ 3 files changed, 314 insertions(+) create mode 100644 src/exo/routing/snapshot_receiver.py create mode 100644 src/exo/routing/tests/test_snapshot_receiver.py create mode 100644 src/exo/shared/types/snapshots.py diff --git a/src/exo/routing/snapshot_receiver.py b/src/exo/routing/snapshot_receiver.py new file mode 100644 index 0000000000..1d4c1fbb5f --- /dev/null +++ b/src/exo/routing/snapshot_receiver.py @@ -0,0 +1,109 @@ +"""Reassembles a snapshot from a stream of `SnapshotChunk`s. + +A receiver belongs to one node; it ignores chunks addressed to other +requesters and chunks from prior sessions. Once a transfer's chunks have +all been collected and the SHA-256 checks out, the snapshot is decoded into +a `State`. Concurrent transfers (for the same requester) are tolerated: +each is keyed by `transfer_id`. +""" + +from __future__ import annotations + +import hashlib +from dataclasses import dataclass, field +from typing import final + +import zstandard +from loguru import logger + +from exo.shared.types.common import NodeId, SessionId +from exo.shared.types.snapshots import SnapshotChunk, SnapshotTransferId +from exo.shared.types.state import State + + +@final +@dataclass +class _Assembly: + """Partial state for one in-flight snapshot transfer.""" + + total_chunks: int + sha256_hex: str + schema_version: int + last_event_applied_idx: int + chunks: dict[int, bytes] = field(default_factory=dict) + + def is_complete(self) -> bool: + return len(self.chunks) == self.total_chunks + + def assemble(self) -> bytes: + return b"".join(self.chunks[i] for i in range(self.total_chunks)) + + +@dataclass +class ReceivedSnapshot: + last_event_applied_idx: int + state: State + + +class SnapshotReceiver: + """Filters and reassembles inbound chunks into a `ReceivedSnapshot`. + + Stateless w.r.t. delivery: callers feed `SnapshotChunk`s in via `ingest` + and check the return value for completion. + """ + + def __init__(self, my_node_id: NodeId, session_id: SessionId) -> None: + self._my_node_id = my_node_id + self._session_id = session_id + self._assemblies: dict[SnapshotTransferId, _Assembly] = {} + + def ingest(self, chunk: SnapshotChunk) -> ReceivedSnapshot | None: + """Absorb a chunk; return the snapshot once a transfer completes. + + Returns None for partial transfers, mismatched recipients, stale + sessions, version mismatches, or corrupt payloads. + """ + if chunk.requester_node_id != self._my_node_id: + return None + if chunk.session_id != self._session_id: + return None + + existing = self._assemblies.get(chunk.transfer_id) + if existing is None: + existing = _Assembly( + total_chunks=chunk.total_chunks, + sha256_hex=chunk.sha256_hex, + schema_version=chunk.schema_version, + last_event_applied_idx=chunk.last_event_applied_idx, + ) + self._assemblies[chunk.transfer_id] = existing + existing.chunks[chunk.chunk_index] = chunk.data + + if not existing.is_complete(): + return None + + # Transfer complete — finalise and remove from the in-flight map. + del self._assemblies[chunk.transfer_id] + body = existing.assemble() + if hashlib.sha256(body).hexdigest() != existing.sha256_hex: + logger.warning(f"Snapshot {chunk.transfer_id} failed checksum; discarding") + return None + try: + decompressed = zstandard.ZstdDecompressor().decompress(body) + state = State.model_validate_json(decompressed.decode("utf-8")) + except (zstandard.ZstdError, ValueError) as e: + logger.opt(exception=e).warning( + f"Snapshot {chunk.transfer_id} could not be decoded; discarding" + ) + return None + if state.schema_version != existing.schema_version: + # Should not happen — the master writes schema_version into both + # the chunk meta and the State payload — but treat it as corrupt. + logger.warning( + f"Snapshot {chunk.transfer_id} schema version mismatch " + f"(chunk={existing.schema_version}, state={state.schema_version})" + ) + return None + return ReceivedSnapshot( + last_event_applied_idx=existing.last_event_applied_idx, state=state + ) diff --git a/src/exo/routing/tests/test_snapshot_receiver.py b/src/exo/routing/tests/test_snapshot_receiver.py new file mode 100644 index 0000000000..8ee74e234b --- /dev/null +++ b/src/exo/routing/tests/test_snapshot_receiver.py @@ -0,0 +1,151 @@ +import hashlib + +import pytest +import zstandard + +from exo.routing.snapshot_receiver import SnapshotReceiver +from exo.shared.types.common import NodeId, SessionId +from exo.shared.types.snapshots import SnapshotChunk, SnapshotTransferId +from exo.shared.types.state import State + + +@pytest.fixture +def session_id() -> SessionId: + return SessionId(master_node_id=NodeId("master"), election_clock=0) + + +@pytest.fixture +def my_node() -> NodeId: + return NodeId("worker-1") + + +def _encode(state: State) -> bytes: + return zstandard.ZstdCompressor().compress(state.model_dump_json().encode("utf-8")) + + +def _make_chunks( + body: bytes, + *, + chunk_size: int, + requester_node_id: NodeId, + session_id: SessionId, + state: State, + transfer_id: SnapshotTransferId | None = None, +) -> list[SnapshotChunk]: + sha256 = hashlib.sha256(body).hexdigest() + transfer_id = transfer_id or SnapshotTransferId() + pieces = [body[i : i + chunk_size] for i in range(0, len(body), chunk_size)] or [ + b"" + ] + return [ + SnapshotChunk.from_data( + data=piece, + transfer_id=transfer_id, + requester_node_id=requester_node_id, + session_id=session_id, + schema_version=state.schema_version, + last_event_applied_idx=state.last_event_applied_idx, + chunk_index=i, + total_chunks=len(pieces), + sha256_hex=sha256, + ) + for i, piece in enumerate(pieces) + ] + + +def test_completes_on_full_transfer(my_node: NodeId, session_id: SessionId): + state = State(last_event_applied_idx=42) + chunks = _make_chunks( + _encode(state), + chunk_size=64, + requester_node_id=my_node, + session_id=session_id, + state=state, + ) + receiver = SnapshotReceiver(my_node, session_id) + + received = None + for chunk in chunks: + received = receiver.ingest(chunk) + assert received is not None + assert received.last_event_applied_idx == 42 + assert received.state.last_event_applied_idx == 42 + + +def test_handles_out_of_order_chunks(my_node: NodeId, session_id: SessionId): + state = State(last_event_applied_idx=99) + chunks = _make_chunks( + _encode(state), + chunk_size=32, + requester_node_id=my_node, + session_id=session_id, + state=state, + ) + receiver = SnapshotReceiver(my_node, session_id) + + # Reverse them. + received = None + for chunk in reversed(chunks): + received = receiver.ingest(chunk) + assert received is not None + assert received.last_event_applied_idx == 99 + + +def test_ignores_chunks_for_other_recipients(my_node: NodeId, session_id: SessionId): + state = State(last_event_applied_idx=1) + other = NodeId("worker-2") + chunks = _make_chunks( + _encode(state), + chunk_size=64, + requester_node_id=other, + session_id=session_id, + state=state, + ) + receiver = SnapshotReceiver(my_node, session_id) + for chunk in chunks: + assert receiver.ingest(chunk) is None + + +def test_ignores_chunks_from_stale_session(my_node: NodeId, session_id: SessionId): + state = State(last_event_applied_idx=1) + other_session = SessionId(master_node_id=NodeId("other-master"), election_clock=99) + chunks = _make_chunks( + _encode(state), + chunk_size=64, + requester_node_id=my_node, + session_id=other_session, + state=state, + ) + receiver = SnapshotReceiver(my_node, session_id) + for chunk in chunks: + assert receiver.ingest(chunk) is None + + +def test_discards_on_checksum_mismatch(my_node: NodeId, session_id: SessionId): + state = State(last_event_applied_idx=1) + chunks = _make_chunks( + _encode(state), + chunk_size=64, + requester_node_id=my_node, + session_id=session_id, + state=state, + ) + # Corrupt the last byte of the last chunk. + original = chunks[-1] + chunks[-1] = SnapshotChunk.from_data( + data=original.data + b"\x00garbage", + transfer_id=original.transfer_id, + requester_node_id=original.requester_node_id, + session_id=original.session_id, + schema_version=original.schema_version, + last_event_applied_idx=original.last_event_applied_idx, + chunk_index=original.chunk_index, + total_chunks=original.total_chunks, + sha256_hex=original.sha256_hex, + ) + + receiver = SnapshotReceiver(my_node, session_id) + received = None + for chunk in chunks: + received = receiver.ingest(chunk) + assert received is None diff --git a/src/exo/shared/types/snapshots.py b/src/exo/shared/types/snapshots.py new file mode 100644 index 0000000000..0ad098aefe --- /dev/null +++ b/src/exo/shared/types/snapshots.py @@ -0,0 +1,54 @@ +"""Wire types for snapshot transfer between master and a joining node. + +Snapshots can be tens of MB; the gossipsub message ceiling is around 1 MB. +We slice the compressed snapshot body into chunks and publish each chunk on +the SNAPSHOT_RESPONSES topic. The receiver collects chunks for its own +`requester_node_id`, validates the SHA-256 of the reassembled body, and +materialises the State. +""" + +import base64 + +from exo.shared.types.common import Id, NodeId, SessionId +from exo.utils.pydantic_ext import FrozenModel + + +class SnapshotTransferId(Id): + """Identifies a single snapshot transfer (one master response to one + `RequestSnapshot`). Distinct transfers may interleave; the id lets + receivers keep them apart.""" + + +class SnapshotChunk(FrozenModel): + """One slice of a snapshot in flight. + + `data_b64` carries a base64-encoded slice of the zstd-compressed JSON + dump of State. Concatenating the *decoded* bytes of all chunks for a + `transfer_id` in order of `chunk_index` yields the full compressed + body; `sha256_hex` is the SHA-256 of that decoded blob. + + We use base64 explicitly because the topic layer JSON-encodes messages, + and JSON can't carry raw binary. Helpers `from_data` / `data` keep the + base64 detail at the boundaries. + """ + + transfer_id: SnapshotTransferId + requester_node_id: NodeId + session_id: SessionId + schema_version: int + last_event_applied_idx: int + chunk_index: int + total_chunks: int + sha256_hex: str + data_b64: str + + @classmethod + def from_data(cls, *, data: bytes, **kwargs: object) -> "SnapshotChunk": + return cls(data_b64=base64.b64encode(data).decode("ascii"), **kwargs) # pyright: ignore[reportArgumentType] + + @property + def data(self) -> bytes: + return base64.b64decode(self.data_b64) + + +__all__ = ["SnapshotChunk", "SnapshotTransferId"] From c89caaf87a0ad55c13c98b7430fa374e03c864fa Mon Sep 17 00:00:00 2001 From: Alex Cheema Date: Sun, 3 May 2026 01:15:46 +0100 Subject: [PATCH 05/11] feat: add snapshot routing types --- src/exo/master/main.py | 5 +++ .../tests/test_snapshot_routing_types.py | 37 +++++++++++++++++++ src/exo/routing/topics.py | 4 ++ src/exo/shared/types/commands.py | 7 ++++ 4 files changed, 53 insertions(+) create mode 100644 src/exo/routing/tests/test_snapshot_routing_types.py diff --git a/src/exo/master/main.py b/src/exo/master/main.py index 044dfbf453..c0a638eeef 100644 --- a/src/exo/master/main.py +++ b/src/exo/master/main.py @@ -25,6 +25,7 @@ ImageGeneration, PlaceInstance, RequestEventLog, + RequestSnapshot, SendInputChunk, SetInstanceLink, TaskCancelled, @@ -447,6 +448,10 @@ async def _command_processor(self) -> None: start=command.since_idx, ): await self._send_event(IndexedEvent(idx=i, event=event)) + case RequestSnapshot(): + logger.info( + "Ignoring RequestSnapshot; snapshot serving is not wired yet" + ) for event in generated_events: await self.event_sender.send(event) except ValueError as e: diff --git a/src/exo/routing/tests/test_snapshot_routing_types.py b/src/exo/routing/tests/test_snapshot_routing_types.py new file mode 100644 index 0000000000..ad10fa10d7 --- /dev/null +++ b/src/exo/routing/tests/test_snapshot_routing_types.py @@ -0,0 +1,37 @@ +from exo.routing import topics +from exo.shared.types.commands import ForwarderCommand, RequestSnapshot +from exo.shared.types.common import NodeId, SessionId, SystemId +from exo.shared.types.snapshots import SnapshotChunk, SnapshotTransferId + + +def test_request_snapshot_round_trips_through_forwarder_command() -> None: + command = ForwarderCommand( + origin=SystemId("system-1"), + command=RequestSnapshot(requester_node_id=NodeId("worker-1")), + ) + + restored = ForwarderCommand.model_validate_json(command.model_dump_json()) + + assert isinstance(restored.command, RequestSnapshot) + assert restored.command.requester_node_id == NodeId("worker-1") + + +def test_snapshot_response_topic_round_trips_chunk() -> None: + chunk = SnapshotChunk.from_data( + data=b"snapshot-bytes", + transfer_id=SnapshotTransferId("transfer-1"), + requester_node_id=NodeId("worker-1"), + session_id=SessionId(master_node_id=NodeId("master"), election_clock=1), + schema_version=1, + last_event_applied_idx=42, + chunk_index=0, + total_chunks=1, + sha256_hex="unused", + ) + + restored = topics.SNAPSHOT_RESPONSES.deserialize( + topics.SNAPSHOT_RESPONSES.serialize(chunk) + ) + + assert restored == chunk + assert restored.data == b"snapshot-bytes" diff --git a/src/exo/routing/topics.py b/src/exo/routing/topics.py index 9776e54246..db2b832d1d 100644 --- a/src/exo/routing/topics.py +++ b/src/exo/routing/topics.py @@ -8,6 +8,7 @@ GlobalForwarderEvent, LocalForwarderEvent, ) +from exo.shared.types.snapshots import SnapshotChunk from exo.utils.pydantic_ext import FrozenModel @@ -49,3 +50,6 @@ def deserialize(self, b: bytes) -> T: DOWNLOAD_COMMANDS = TypedTopic( "download_commands", PublishPolicy.Always, ForwarderDownloadCommand ) +SNAPSHOT_RESPONSES = TypedTopic( + "snapshot_responses", PublishPolicy.Always, SnapshotChunk +) diff --git a/src/exo/shared/types/commands.py b/src/exo/shared/types/commands.py index 67d318b255..ce684f8487 100644 --- a/src/exo/shared/types/commands.py +++ b/src/exo/shared/types/commands.py @@ -67,6 +67,12 @@ class RequestEventLog(BaseCommand): since_idx: int +class RequestSnapshot(BaseCommand): + """Ask the current master to send a State snapshot to this node.""" + + requester_node_id: NodeId + + class StartDownload(BaseCommand): target_node_id: NodeId shard_metadata: ShardMetadata @@ -106,6 +112,7 @@ class DeleteInstanceLink(BaseCommand): Command = ( TestCommand | RequestEventLog + | RequestSnapshot | TextGeneration | ImageGeneration | ImageEdits From 24c56138e32583d7dfc26292bcbf0f845394616a Mon Sep 17 00:00:00 2001 From: Alex Cheema Date: Sun, 3 May 2026 01:18:29 +0100 Subject: [PATCH 06/11] feat: serve state snapshots from master --- src/exo/main.py | 5 +++ src/exo/master/main.py | 60 +++++++++++++++++++++++++++-- src/exo/master/tests/test_master.py | 54 ++++++++++++++++++++++++++ 3 files changed, 116 insertions(+), 3 deletions(-) diff --git a/src/exo/main.py b/src/exo/main.py index 30c54e292e..26fa9563bc 100644 --- a/src/exo/main.py +++ b/src/exo/main.py @@ -59,6 +59,7 @@ async def create(cls, args: "Args") -> Self: await router.register_topic(topics.ELECTION_MESSAGES) await router.register_topic(topics.CONNECTION_MESSAGES) await router.register_topic(topics.DOWNLOAD_COMMANDS) + await router.register_topic(topics.SNAPSHOT_RESPONSES) event_router = EventRouter( session_id, command_sender=router.sender(topics.COMMANDS), @@ -112,6 +113,7 @@ async def create(cls, args: "Args") -> Self: global_event_sender=router.sender(topics.GLOBAL_EVENTS), local_event_receiver=router.receiver(topics.LOCAL_EVENTS), command_receiver=router.receiver(topics.COMMANDS), + snapshot_chunk_sender=router.sender(topics.SNAPSHOT_RESPONSES), download_command_sender=router.sender(topics.DOWNLOAD_COMMANDS), ) @@ -210,6 +212,9 @@ async def _elect_loop(self): global_event_sender=self.router.sender(topics.GLOBAL_EVENTS), local_event_receiver=self.router.receiver(topics.LOCAL_EVENTS), command_receiver=self.router.receiver(topics.COMMANDS), + snapshot_chunk_sender=self.router.sender( + topics.SNAPSHOT_RESPONSES + ), download_command_sender=self.router.sender( topics.DOWNLOAD_COMMANDS ), diff --git a/src/exo/master/main.py b/src/exo/master/main.py index c0a638eeef..38d4ab1fae 100644 --- a/src/exo/master/main.py +++ b/src/exo/master/main.py @@ -1,6 +1,8 @@ +import hashlib from datetime import datetime, timedelta, timezone import anyio +from anyio import to_thread from loguru import logger from exo.master.placement import ( @@ -55,6 +57,7 @@ TracesMerged, ) from exo.shared.types.instance_link import InstanceLink +from exo.shared.types.snapshots import SnapshotChunk, SnapshotTransferId from exo.shared.types.state import State from exo.shared.types.tasks import ( ImageEdits as ImageEditsTask, @@ -75,6 +78,15 @@ from exo.utils.event_buffer import MultiSourceBuffer from exo.utils.task_group import TaskGroup +_SNAPSHOT_CHUNK_BYTES = 512 * 1024 +_MAX_EVENT_LOG_REPLAY_BATCH = 1000 + + +def _encode_state_for_transfer(state: State) -> bytes: + import zstandard + + return zstandard.ZstdCompressor().compress(state.model_dump_json().encode("utf-8")) + def _prefill_endpoint_for(state: State, decode_instance_id: InstanceId) -> str | None: decode = state.instances.get(decode_instance_id) @@ -126,6 +138,7 @@ def __init__( event_sender: Sender[Event], local_event_receiver: Receiver[LocalForwarderEvent], global_event_sender: Sender[GlobalForwarderEvent], + snapshot_chunk_sender: Sender[SnapshotChunk], download_command_sender: Sender[ForwarderDownloadCommand], ): self.node_id = node_id @@ -136,6 +149,7 @@ def __init__( self.command_receiver = command_receiver self.local_event_receiver = local_event_receiver self.global_event_sender = global_event_sender + self.snapshot_chunk_sender = snapshot_chunk_sender self.download_command_sender = download_command_sender self.event_sender = event_sender self._system_id = SystemId() @@ -156,6 +170,7 @@ async def run(self): self._event_log.close() self.global_event_sender.close() self.local_event_receiver.close() + self.snapshot_chunk_sender.close() self.command_receiver.close() async def shutdown(self): @@ -442,15 +457,18 @@ async def _command_processor(self) -> None: case RequestEventLog(): # We should just be able to send everything, since other buffers will ignore old messages # rate limit to 1000 at a time - end = min(command.since_idx + 1000, len(self._event_log)) + end = min( + command.since_idx + _MAX_EVENT_LOG_REPLAY_BATCH, + len(self._event_log), + ) for i, event in enumerate( self._event_log.read_range(command.since_idx, end), start=command.since_idx, ): await self._send_event(IndexedEvent(idx=i, event=event)) case RequestSnapshot(): - logger.info( - "Ignoring RequestSnapshot; snapshot serving is not wired yet" + self._tg.start_soon( + self._serve_snapshot, command.requester_node_id ) for event in generated_events: await self.event_sender.send(event) @@ -511,6 +529,42 @@ async def _event_processor(self) -> None: self._event_log.append(event) await self._send_event(indexed) + async def _serve_snapshot(self, requester_node_id: NodeId) -> None: + state = self.state + if state.last_event_applied_idx < 0: + logger.info( + f"RequestSnapshot from {requester_node_id} but master has no events yet" + ) + return + + body = await to_thread.run_sync(_encode_state_for_transfer, state) + sha256 = hashlib.sha256(body).hexdigest() + chunks = [ + body[i : i + _SNAPSHOT_CHUNK_BYTES] + for i in range(0, len(body), _SNAPSHOT_CHUNK_BYTES) + ] or [b""] + transfer_id = SnapshotTransferId() + + logger.info( + f"Serving snapshot to {requester_node_id}: " + f"idx={state.last_event_applied_idx}, " + f"{len(chunks)} chunk(s), {len(body)} bytes total" + ) + for index, chunk in enumerate(chunks): + await self.snapshot_chunk_sender.send( + SnapshotChunk.from_data( + data=chunk, + transfer_id=transfer_id, + requester_node_id=requester_node_id, + session_id=self.session_id, + schema_version=state.schema_version, + last_event_applied_idx=state.last_event_applied_idx, + chunk_index=index, + total_chunks=len(chunks), + sha256_hex=sha256, + ) + ) + # This function is re-entrant, take care! async def _send_event(self, event: IndexedEvent): # Convenience method since this line is ugly diff --git a/src/exo/master/tests/test_master.py b/src/exo/master/tests/test_master.py index c4a1cff0c0..4138d12adb 100644 --- a/src/exo/master/tests/test_master.py +++ b/src/exo/master/tests/test_master.py @@ -7,12 +7,14 @@ from exo.master.main import Master from exo.routing.router import get_node_id_keypair +from exo.routing.snapshot_receiver import SnapshotReceiver from exo.shared.models.model_cards import ModelCard, ModelTask from exo.shared.types.commands import ( CommandId, ForwarderCommand, ForwarderDownloadCommand, PlaceInstance, + RequestSnapshot, TextGeneration, ) from exo.shared.types.common import ModelId, NodeId, SessionId, SystemId @@ -29,6 +31,7 @@ from exo.shared.types.profiling import ( MemoryUsage, ) +from exo.shared.types.snapshots import SnapshotChunk from exo.shared.types.tasks import TaskStatus from exo.shared.types.tasks import TextGeneration as TextGenerationTask from exo.shared.types.text_generation import ( @@ -56,6 +59,7 @@ async def test_master(): local_event_sender, le_receiver = channel[LocalForwarderEvent]() fcds, _fcdr = channel[ForwarderDownloadCommand]() ev_send, ev_recv = channel[Event]() + snapshot_chunk_send, _snapshot_chunk_recv = channel[SnapshotChunk]() async def mock_event_router(): idx = 0 @@ -92,6 +96,7 @@ def _get_events() -> Sequence[IndexedEvent]: global_event_sender=ge_sender, local_event_receiver=le_receiver, command_receiver=co_receiver, + snapshot_chunk_sender=snapshot_chunk_send, download_command_sender=fcds, ) logger.info("run the master") @@ -229,3 +234,52 @@ def _get_events() -> Sequence[IndexedEvent]: ev_send.close() await master.shutdown() + + +@pytest.mark.asyncio +async def test_master_serves_snapshot_for_current_state(): + node_id = NodeId("master") + requester_node_id = NodeId("worker") + session_id = SessionId(master_node_id=node_id, election_clock=0) + + ge_sender, _global_event_receiver = channel[GlobalForwarderEvent]() + command_sender, command_receiver = channel[ForwarderCommand]() + _local_event_sender, local_event_receiver = channel[LocalForwarderEvent]() + download_command_sender, _download_command_receiver = channel[ + ForwarderDownloadCommand + ]() + event_sender, _event_receiver = channel[Event]() + snapshot_chunk_sender, snapshot_chunk_receiver = channel[SnapshotChunk]() + + master = Master( + node_id, + session_id, + event_sender=event_sender, + global_event_sender=ge_sender, + local_event_receiver=local_event_receiver, + command_receiver=command_receiver, + snapshot_chunk_sender=snapshot_chunk_sender, + download_command_sender=download_command_sender, + ) + master.state = master.state.model_copy(update={"last_event_applied_idx": 12}) + + receiver = SnapshotReceiver(requester_node_id, session_id) + async with anyio.create_task_group() as tg: + tg.start_soon(master.run) + await command_sender.send( + ForwarderCommand( + origin=SystemId("api"), + command=RequestSnapshot(requester_node_id=requester_node_id), + ) + ) + + received = None + while received is None: + chunk = await snapshot_chunk_receiver.receive() + received = receiver.ingest(chunk) + + assert received.last_event_applied_idx == 12 + assert received.state.last_event_applied_idx == 12 + + await master.shutdown() + tg.cancel_scope.cancel() From a2c2a0fc92c8d6be510d385b1aefc843c8717b59 Mon Sep 17 00:00:00 2001 From: Alex Cheema Date: Sun, 3 May 2026 01:35:43 +0100 Subject: [PATCH 07/11] feat: bootstrap worker state from snapshot --- src/exo/main.py | 8 ++ src/exo/worker/main.py | 62 ++++++++- .../test_worker_snapshot_bootstrap.py | 126 ++++++++++++++++++ 3 files changed, 190 insertions(+), 6 deletions(-) create mode 100644 src/exo/worker/tests/unittests/test_worker_snapshot_bootstrap.py diff --git a/src/exo/main.py b/src/exo/main.py index 26fa9563bc..24c598d095 100644 --- a/src/exo/main.py +++ b/src/exo/main.py @@ -96,8 +96,11 @@ async def create(cls, args: "Args") -> Self: if not args.no_worker: worker = Worker( node_id, + session_id, + event_router=event_router, event_receiver=event_router.receiver(), event_sender=event_router.sender(), + snapshot_chunk_receiver=router.receiver(topics.SNAPSHOT_RESPONSES), command_sender=router.sender(topics.COMMANDS), download_command_sender=router.sender(topics.DOWNLOAD_COMMANDS), api_port=args.api_port, @@ -251,8 +254,13 @@ async def _elect_loop(self): # TODO: add profiling etc to resource monitor self.worker = Worker( self.node_id, + result.session_id, + event_router=self.event_router, event_receiver=self.event_router.receiver(), event_sender=self.event_router.sender(), + snapshot_chunk_receiver=self.router.receiver( + topics.SNAPSHOT_RESPONSES + ), command_sender=self.router.sender(topics.COMMANDS), download_command_sender=self.router.sender( topics.DOWNLOAD_COMMANDS diff --git a/src/exo/worker/main.py b/src/exo/worker/main.py index 45e33ad067..254c27fead 100644 --- a/src/exo/worker/main.py +++ b/src/exo/worker/main.py @@ -8,6 +8,8 @@ from exo.api.types import ImageEditsTaskParams from exo.download.download_utils import is_read_only_model_dir, resolve_existing_model +from exo.routing.event_router import EventRouter +from exo.routing.snapshot_receiver import SnapshotReceiver from exo.shared.apply import apply from exo.shared.constants import EXO_MAX_INSTANCE_RETRIES from exo.shared.models.model_cards import ModelId, add_to_card_cache, delete_custom_card @@ -16,9 +18,10 @@ DeleteInstance, ForwarderCommand, ForwarderDownloadCommand, + RequestSnapshot, StartDownload, ) -from exo.shared.types.common import CommandId, NodeId, SystemId +from exo.shared.types.common import CommandId, NodeId, SessionId, SystemId from exo.shared.types.events import ( CustomModelCardAdded, CustomModelCardDeleted, @@ -33,6 +36,7 @@ TopologyEdgeDeleted, ) from exo.shared.types.multiaddr import Multiaddr +from exo.shared.types.snapshots import SnapshotChunk from exo.shared.types.state import State from exo.shared.types.tasks import ( CancelTask, @@ -58,14 +62,19 @@ from exo.worker.plan import plan from exo.worker.runner.supervisor import RunnerSupervisor +_SNAPSHOT_FETCH_TIMEOUT_SECONDS = 30 + class Worker: def __init__( self, node_id: NodeId, + session_id: SessionId, *, + event_router: EventRouter, event_receiver: Receiver[IndexedEvent], event_sender: Sender[Event], + snapshot_chunk_receiver: Receiver[SnapshotChunk], # This is for requesting updates. It doesn't need to be a general command sender right now, # but I think it's the correct way to be thinking about commands command_sender: Sender[ForwarderCommand], @@ -73,8 +82,11 @@ def __init__( api_port: int, ): self.node_id: NodeId = node_id + self.session_id: SessionId = session_id + self.event_router = event_router self.event_receiver = event_receiver self.event_sender = event_sender + self.snapshot_chunk_receiver = snapshot_chunk_receiver self.command_sender = command_sender self.download_command_sender = download_command_sender self.api_port = api_port @@ -104,21 +116,57 @@ async def run(self): try: async with self._tg as tg: - tg.start_soon(info_gatherer.run) - tg.start_soon(self._forward_info, info_recv) - tg.start_soon(self.plan_step) - tg.start_soon(self._event_applier) - tg.start_soon(self._poll_connection_updates) + tg.start_soon(self._bootstrap_then_run, info_gatherer, info_recv) finally: # Actual shutdown code - waits for all tasks to complete before executing. logger.info("Stopping Worker") self.event_sender.close() + self.snapshot_chunk_receiver.close() self.command_sender.close() self.download_command_sender.close() for runner in self.runners.values(): runner.shutdown() self._stopped.set() + async def _bootstrap_then_run( + self, info_gatherer: InfoGatherer, info_recv: Receiver[GatheredInfo] + ) -> None: + await self._fetch_snapshot() + self._sync_input_views_from_state() + self._tg.start_soon(info_gatherer.run) + self._tg.start_soon(self._forward_info, info_recv) + self._tg.start_soon(self.plan_step) + self._tg.start_soon(self._event_applier) + self._tg.start_soon(self._poll_connection_updates) + + async def _fetch_snapshot(self) -> None: + receiver = SnapshotReceiver(self.node_id, self.session_id) + await self.command_sender.send( + ForwarderCommand( + origin=self._system_id, + command=RequestSnapshot(requester_node_id=self.node_id), + ) + ) + + with anyio.move_on_after(_SNAPSHOT_FETCH_TIMEOUT_SECONDS): + with self.snapshot_chunk_receiver as chunks: + async for chunk in chunks: + received = receiver.ingest(chunk) + if received is None: + continue + self.state = received.state + self.event_router.set_buffer_start( + received.last_event_applied_idx + 1 + ) + logger.info( + f"Worker bootstrapped from snapshot at idx " + f"{received.last_event_applied_idx}" + ) + return + logger.info( + "No snapshot received before timeout; falling back to full event-log replay" + ) + async def _forward_info(self, recv: Receiver[GatheredInfo]): with recv as info_stream: async for info in info_stream: @@ -133,6 +181,8 @@ async def _forward_info(self, recv: Receiver[GatheredInfo]): async def _event_applier(self): with self.event_receiver as events: async for event in events: + if event.idx <= self.state.last_event_applied_idx: + continue # 2. for each event, apply it to the state self.state = apply(self.state, event=event) event = event.event diff --git a/src/exo/worker/tests/unittests/test_worker_snapshot_bootstrap.py b/src/exo/worker/tests/unittests/test_worker_snapshot_bootstrap.py new file mode 100644 index 0000000000..c44cf74d83 --- /dev/null +++ b/src/exo/worker/tests/unittests/test_worker_snapshot_bootstrap.py @@ -0,0 +1,126 @@ +# pyright: reportPrivateUsage=false + +import hashlib + +import anyio +import pytest +import zstandard + +from exo.routing.event_router import EventRouter +from exo.shared.types.commands import ( + ForwarderCommand, + ForwarderDownloadCommand, + RequestSnapshot, +) +from exo.shared.types.common import NodeId, SessionId +from exo.shared.types.events import ( + Event, + GlobalForwarderEvent, + IndexedEvent, + LocalForwarderEvent, + TestEvent, +) +from exo.shared.types.snapshots import SnapshotChunk, SnapshotTransferId +from exo.shared.types.state import State +from exo.utils.channels import Receiver, Sender, channel +from exo.worker.main import Worker + + +def _snapshot_chunk( + state: State, *, requester_node_id: NodeId, session_id: SessionId +) -> SnapshotChunk: + body = zstandard.ZstdCompressor().compress(state.model_dump_json().encode("utf-8")) + return SnapshotChunk.from_data( + data=body, + transfer_id=SnapshotTransferId("transfer-1"), + requester_node_id=requester_node_id, + session_id=session_id, + schema_version=state.schema_version, + last_event_applied_idx=state.last_event_applied_idx, + chunk_index=0, + total_chunks=1, + sha256_hex=hashlib.sha256(body).hexdigest(), + ) + + +def _worker( + node_id: NodeId, session_id: SessionId +) -> tuple[ + Worker, + EventRouter, + Receiver[ForwarderCommand], + Sender[SnapshotChunk], + Sender[IndexedEvent], +]: + router_command_sender, _router_command_receiver = channel[ForwarderCommand]() + _global_event_sender, global_event_receiver = channel[GlobalForwarderEvent]() + local_event_sender, _local_event_receiver = channel[LocalForwarderEvent]() + event_router = EventRouter( + session_id=session_id, + command_sender=router_command_sender, + external_inbound=global_event_receiver, + external_outbound=local_event_sender, + ) + + event_sender, event_receiver = channel[IndexedEvent]() + local_event_output_sender, _local_event_output_receiver = channel[Event]() + command_sender, command_receiver = channel[ForwarderCommand]() + download_command_sender, _download_command_receiver = channel[ + ForwarderDownloadCommand + ]() + snapshot_sender, snapshot_receiver = channel[SnapshotChunk]() + + worker = Worker( + node_id, + session_id, + event_router=event_router, + event_receiver=event_receiver, + event_sender=local_event_output_sender, + snapshot_chunk_receiver=snapshot_receiver, + command_sender=command_sender, + download_command_sender=download_command_sender, + api_port=52415, + ) + return worker, event_router, command_receiver, snapshot_sender, event_sender + + +@pytest.mark.asyncio +async def test_worker_fetch_snapshot_applies_state_and_fast_forwards_router() -> None: + node_id = NodeId("worker") + session_id = SessionId(master_node_id=NodeId("master"), election_clock=1) + worker, event_router, command_receiver, snapshot_sender, _event_sender = _worker( + node_id, session_id + ) + state = State(last_event_applied_idx=7) + + async with anyio.create_task_group() as tg: + tg.start_soon(worker._fetch_snapshot) + command = await command_receiver.receive() + assert isinstance(command.command, RequestSnapshot) + assert command.command.requester_node_id == node_id + + await snapshot_sender.send( + _snapshot_chunk(state, requester_node_id=node_id, session_id=session_id) + ) + + assert worker.state.last_event_applied_idx == 7 + assert event_router.event_buffer.next_idx_to_release == 8 + + +@pytest.mark.asyncio +async def test_worker_event_applier_ignores_events_covered_by_snapshot() -> None: + node_id = NodeId("worker") + session_id = SessionId(master_node_id=NodeId("master"), election_clock=1) + worker, _event_router, _command_receiver, _snapshot_sender, event_sender = _worker( + node_id, session_id + ) + worker.state = State(last_event_applied_idx=7) + + async with anyio.create_task_group() as tg: + tg.start_soon(worker._event_applier) + await event_sender.send(IndexedEvent(idx=7, event=TestEvent())) + await event_sender.send(IndexedEvent(idx=8, event=TestEvent())) + + while worker.state.last_event_applied_idx != 8: + await anyio.sleep(0.001) + tg.cancel_scope.cancel() From ea04549692d09625d87baca3a90c049fe413653b Mon Sep 17 00:00:00 2001 From: Alex Cheema Date: Sun, 3 May 2026 01:39:54 +0100 Subject: [PATCH 08/11] feat: bootstrap api state from snapshot --- src/exo/api/main.py | 66 ++++++++- .../api/tests/test_api_snapshot_bootstrap.py | 136 ++++++++++++++++++ src/exo/main.py | 11 +- 3 files changed, 208 insertions(+), 5 deletions(-) create mode 100644 src/exo/api/tests/test_api_snapshot_bootstrap.py diff --git a/src/exo/api/main.py b/src/exo/api/main.py index 902bf3162b..927b9bc3fa 100644 --- a/src/exo/api/main.py +++ b/src/exo/api/main.py @@ -121,6 +121,8 @@ ) from exo.master.image_store import ImageStore from exo.master.placement import place_instance as get_instance_placements +from exo.routing.event_router import EventRouter +from exo.routing.snapshot_receiver import SnapshotReceiver from exo.shared.apply import apply from exo.shared.constants import ( DASHBOARD_DIR, @@ -164,6 +166,7 @@ ImageEdits, ImageGeneration, PlaceInstance, + RequestSnapshot, SendInputChunk, SetInstanceLink, StartDownload, @@ -171,7 +174,7 @@ TaskFinished, TextGeneration, ) -from exo.shared.types.common import CommandId, Id, NodeId, SystemId +from exo.shared.types.common import CommandId, Id, NodeId, SessionId, SystemId from exo.shared.types.events import ( ChunkGenerated, Event, @@ -181,6 +184,7 @@ ) from exo.shared.types.instance_link import InstanceLink, InstanceLinkId from exo.shared.types.memory import Memory +from exo.shared.types.snapshots import SnapshotChunk from exo.shared.types.state import State from exo.shared.types.tasks import ( ImageEdits as ImageEditsTask, @@ -207,6 +211,8 @@ _API_EVENT_LOG_DIR = EXO_EVENT_LOG_DIR / "api" ONBOARDING_COMPLETE_FILE = EXO_CACHE_HOME / "onboarding_complete" +_SNAPSHOT_FETCH_TIMEOUT_SECONDS = 30 + def _format_to_content_type(image_format: Literal["png", "jpeg", "webp"] | None) -> str: return f"image/{image_format or 'png'}" @@ -236,9 +242,12 @@ class API: def __init__( self, node_id: NodeId, + session_id: SessionId, *, port: int, + event_router: EventRouter, event_receiver: Receiver[IndexedEvent], + snapshot_chunk_receiver: Receiver[SnapshotChunk], command_sender: Sender[ForwarderCommand], download_command_sender: Sender[ForwarderDownloadCommand], # This lets us pause the API if an election is running @@ -247,9 +256,12 @@ def __init__( self.state = State() self._event_log = DiskEventLog(_API_EVENT_LOG_DIR) self._system_id = SystemId() + self.session_id = session_id + self.event_router = event_router self.command_sender = command_sender self.download_command_sender = download_command_sender self.event_receiver = event_receiver + self.snapshot_chunk_receiver = snapshot_chunk_receiver self.election_receiver = election_receiver self.node_id: NodeId = node_id self.last_completed_election: int = 0 @@ -291,18 +303,29 @@ async def _log_requests( # pyright: ignore[reportUnusedFunction] self._image_store = ImageStore(EXO_IMAGE_CACHE_DIR) self._tg: TaskGroup = TaskGroup() - def reset(self, result_clock: int, event_receiver: Receiver[IndexedEvent]): + def reset( + self, + result_clock: int, + session_id: SessionId, + event_router: EventRouter, + event_receiver: Receiver[IndexedEvent], + snapshot_chunk_receiver: Receiver[SnapshotChunk], + ): logger.info("Resetting API State") self._event_log.close() self._event_log = DiskEventLog(_API_EVENT_LOG_DIR) self.state = State() self._system_id = SystemId() + self.session_id = session_id + self.event_router = event_router self._text_generation_queues = {} self._image_generation_queues = {} self.unpause(result_clock) self.event_receiver.close() self.event_receiver = event_receiver - self._tg.start_soon(self._apply_state) + self.snapshot_chunk_receiver.close() + self.snapshot_chunk_receiver = snapshot_chunk_receiver + self._tg.start_soon(self._bootstrap_then_apply_state) def unpause(self, result_clock: int): logger.info("Unpausing API") @@ -1836,7 +1859,7 @@ async def run(self): try: async with self._tg as tg: logger.info("Starting API") - tg.start_soon(self._apply_state) + tg.start_soon(self._bootstrap_then_apply_state) tg.start_soon(self._pause_on_new_election) tg.start_soon(self._cleanup_expired_images) print_startup_banner(self.port) @@ -1850,6 +1873,7 @@ async def run(self): self._event_log.close() self.command_sender.close() self.event_receiver.close() + self.snapshot_chunk_receiver.close() async def run_api(self, ev: anyio.Event): cfg = Config() @@ -1865,9 +1889,43 @@ async def run_api(self, ev: anyio.Event): shutdown_trigger=ev.wait, ) + async def _bootstrap_then_apply_state(self): + await self._fetch_snapshot() + await self._apply_state() + + async def _fetch_snapshot(self) -> None: + receiver = SnapshotReceiver(self.node_id, self.session_id) + await self.command_sender.send( + ForwarderCommand( + origin=self._system_id, + command=RequestSnapshot(requester_node_id=self.node_id), + ) + ) + + with anyio.move_on_after(_SNAPSHOT_FETCH_TIMEOUT_SECONDS): + with self.snapshot_chunk_receiver as chunks: + async for chunk in chunks: + received = receiver.ingest(chunk) + if received is None: + continue + self.state = received.state + self.event_router.set_buffer_start( + received.last_event_applied_idx + 1 + ) + logger.info( + f"API bootstrapped from snapshot at idx " + f"{received.last_event_applied_idx}" + ) + return + logger.info( + "API: no snapshot received before timeout; falling back to full event-log replay" + ) + async def _apply_state(self): with self.event_receiver as events: async for i_event in events: + if i_event.idx <= self.state.last_event_applied_idx: + continue self._event_log.append(i_event.event) self.state = apply(self.state, i_event) event = i_event.event diff --git a/src/exo/api/tests/test_api_snapshot_bootstrap.py b/src/exo/api/tests/test_api_snapshot_bootstrap.py new file mode 100644 index 0000000000..1223d3c3d0 --- /dev/null +++ b/src/exo/api/tests/test_api_snapshot_bootstrap.py @@ -0,0 +1,136 @@ +# pyright: reportPrivateUsage=false + +import hashlib + +import anyio +import pytest +import zstandard + +from exo.api.main import API +from exo.routing.event_router import EventRouter +from exo.shared.types.commands import ForwarderCommand, RequestSnapshot +from exo.shared.types.common import NodeId, SessionId, SystemId +from exo.shared.types.events import ( + Event, + GlobalForwarderEvent, + IndexedEvent, + LocalForwarderEvent, + TestEvent, +) +from exo.shared.types.snapshots import SnapshotChunk, SnapshotTransferId +from exo.shared.types.state import State +from exo.utils.channels import Receiver, Sender, channel + + +class _FakeEventLog: + def __init__(self) -> None: + self.appended: list[Event] = [] + + def append(self, event: Event) -> None: + self.appended.append(event) + + +def _snapshot_chunk( + state: State, *, requester_node_id: NodeId, session_id: SessionId +) -> SnapshotChunk: + body = zstandard.ZstdCompressor().compress(state.model_dump_json().encode("utf-8")) + return SnapshotChunk.from_data( + data=body, + transfer_id=SnapshotTransferId("transfer-1"), + requester_node_id=requester_node_id, + session_id=session_id, + schema_version=state.schema_version, + last_event_applied_idx=state.last_event_applied_idx, + chunk_index=0, + total_chunks=1, + sha256_hex=hashlib.sha256(body).hexdigest(), + ) + + +def _api( + node_id: NodeId, session_id: SessionId +) -> tuple[ + API, + EventRouter, + Receiver[ForwarderCommand], + Sender[SnapshotChunk], + Sender[IndexedEvent], + _FakeEventLog, +]: + router_command_sender, _router_command_receiver = channel[ForwarderCommand]() + _global_event_sender, global_event_receiver = channel[GlobalForwarderEvent]() + local_event_sender, _local_event_receiver = channel[LocalForwarderEvent]() + event_router = EventRouter( + session_id=session_id, + command_sender=router_command_sender, + external_inbound=global_event_receiver, + external_outbound=local_event_sender, + ) + + event_sender, event_receiver = channel[IndexedEvent]() + command_sender, command_receiver = channel[ForwarderCommand]() + snapshot_sender, snapshot_receiver = channel[SnapshotChunk]() + + api = object.__new__(API) + api.node_id = node_id + api.session_id = session_id + api.event_router = event_router + api.event_receiver = event_receiver + api.snapshot_chunk_receiver = snapshot_receiver + api.command_sender = command_sender + api._system_id = SystemId("api-system") + api.state = State() + event_log = _FakeEventLog() + api._event_log = event_log # pyright: ignore[reportAttributeAccessIssue] + api._image_generation_queues = {} + api._text_generation_queues = {} + return api, event_router, command_receiver, snapshot_sender, event_sender, event_log + + +@pytest.mark.asyncio +async def test_api_fetch_snapshot_applies_state_and_fast_forwards_router() -> None: + node_id = NodeId("api") + session_id = SessionId(master_node_id=NodeId("master"), election_clock=1) + api, event_router, command_receiver, snapshot_sender, _event_sender, _event_log = ( + _api(node_id, session_id) + ) + state = State(last_event_applied_idx=7) + + async with anyio.create_task_group() as tg: + tg.start_soon(api._fetch_snapshot) + command = await command_receiver.receive() + assert isinstance(command.command, RequestSnapshot) + assert command.command.requester_node_id == node_id + + await snapshot_sender.send( + _snapshot_chunk(state, requester_node_id=node_id, session_id=session_id) + ) + + assert api.state.last_event_applied_idx == 7 + assert event_router.event_buffer.next_idx_to_release == 8 + + +@pytest.mark.asyncio +async def test_api_apply_state_ignores_events_covered_by_snapshot() -> None: + node_id = NodeId("api") + session_id = SessionId(master_node_id=NodeId("master"), election_clock=1) + ( + api, + _event_router, + _command_receiver, + _snapshot_sender, + event_sender, + event_log, + ) = _api(node_id, session_id) + api.state = State(last_event_applied_idx=7) + + async with anyio.create_task_group() as tg: + tg.start_soon(api._apply_state) + await event_sender.send(IndexedEvent(idx=7, event=TestEvent())) + await event_sender.send(IndexedEvent(idx=8, event=TestEvent())) + + while api.state.last_event_applied_idx != 8: + await anyio.sleep(0.001) + tg.cancel_scope.cancel() + + assert len(event_log.appended) == 1 diff --git a/src/exo/main.py b/src/exo/main.py index 24c598d095..293e575572 100644 --- a/src/exo/main.py +++ b/src/exo/main.py @@ -84,8 +84,11 @@ async def create(cls, args: "Args") -> Self: if args.spawn_api: api = API( node_id, + session_id, port=args.api_port, + event_router=event_router, event_receiver=event_router.receiver(), + snapshot_chunk_receiver=router.receiver(topics.SNAPSHOT_RESPONSES), command_sender=router.sender(topics.COMMANDS), download_command_sender=router.sender(topics.DOWNLOAD_COMMANDS), election_receiver=router.receiver(topics.ELECTION_MESSAGES), @@ -269,7 +272,13 @@ async def _elect_loop(self): ) self._tg.start_soon(self.worker.run) if self.api: - self.api.reset(result.won_clock, self.event_router.receiver()) + self.api.reset( + result.won_clock, + result.session_id, + self.event_router, + self.event_router.receiver(), + self.router.receiver(topics.SNAPSHOT_RESPONSES), + ) self._tg.start_soon(self.event_router.run) else: if self.api: From e1a79a0918eb10901cfa2bdb8aa08aa6efc0ee54 Mon Sep 17 00:00:00 2001 From: Alex Cheema Date: Sun, 3 May 2026 01:42:27 +0100 Subject: [PATCH 09/11] feat: store custom model cards in state --- src/exo/shared/apply.py | 28 ++++++++++-- .../test_apply_custom_model_cards.py | 44 +++++++++++++++++++ src/exo/shared/types/state.py | 7 ++- 3 files changed, 75 insertions(+), 4 deletions(-) create mode 100644 src/exo/shared/tests/test_apply/test_apply_custom_model_cards.py diff --git a/src/exo/shared/apply.py b/src/exo/shared/apply.py index b3ff361980..3923abba17 100644 --- a/src/exo/shared/apply.py +++ b/src/exo/shared/apply.py @@ -4,7 +4,8 @@ from loguru import logger -from exo.shared.types.common import NodeId +from exo.shared.models.model_cards import ModelCard +from exo.shared.types.common import ModelId, NodeId from exo.shared.types.events import ( ChunkGenerated, CustomModelCardAdded, @@ -81,10 +82,12 @@ def event_apply(event: Event, state: State) -> State: | TaskAcknowledged() | TracesCollected() | TracesMerged() - | CustomModelCardAdded() - | CustomModelCardDeleted() ): # Pass-through events that don't modify state return state + case CustomModelCardAdded(): + return apply_custom_model_card_added(event, state) + case CustomModelCardDeleted(): + return apply_custom_model_card_deleted(event, state) case InstanceCreated(): return apply_instance_created(event, state) case InstanceDeleted(): @@ -477,3 +480,22 @@ def apply_topology_edge_deleted(event: TopologyEdgeDeleted, state: State) -> Sta topology.remove_connection(event.conn) # TODO: Clean up removing the reverse connection return state.model_copy(update={"topology": topology}) + + +def apply_custom_model_card_added(event: CustomModelCardAdded, state: State) -> State: + new_cards: Mapping[ModelId, ModelCard] = { + **state.custom_model_cards, + event.model_card.model_id: event.model_card, + } + return state.model_copy(update={"custom_model_cards": new_cards}) + + +def apply_custom_model_card_deleted( + event: CustomModelCardDeleted, state: State +) -> State: + new_cards: Mapping[ModelId, ModelCard] = { + model_id: card + for model_id, card in state.custom_model_cards.items() + if model_id != event.model_id + } + return state.model_copy(update={"custom_model_cards": new_cards}) diff --git a/src/exo/shared/tests/test_apply/test_apply_custom_model_cards.py b/src/exo/shared/tests/test_apply/test_apply_custom_model_cards.py new file mode 100644 index 0000000000..b5b3066c26 --- /dev/null +++ b/src/exo/shared/tests/test_apply/test_apply_custom_model_cards.py @@ -0,0 +1,44 @@ +from exo.shared.apply import apply +from exo.shared.models.model_cards import ModelCard, ModelTask +from exo.shared.types.common import ModelId +from exo.shared.types.events import ( + CustomModelCardAdded, + CustomModelCardDeleted, + IndexedEvent, +) +from exo.shared.types.memory import Memory +from exo.shared.types.state import State + + +def _model_card(model_id: ModelId) -> ModelCard: + return ModelCard( + model_id=model_id, + n_layers=1, + storage_size=Memory.from_bytes(1), + hidden_size=1, + supports_tensor=True, + tasks=[ModelTask.TextGeneration], + ) + + +def test_custom_model_card_added_is_reduced_into_state() -> None: + card = _model_card(ModelId("custom/model")) + + state = apply( + State(), + IndexedEvent(idx=0, event=CustomModelCardAdded(model_card=card)), + ) + + assert state.custom_model_cards == {card.model_id: card} + + +def test_custom_model_card_deleted_removes_card_from_state() -> None: + card = _model_card(ModelId("custom/model")) + state = State(custom_model_cards={card.model_id: card}, last_event_applied_idx=0) + + state = apply( + state, + IndexedEvent(idx=1, event=CustomModelCardDeleted(model_id=card.model_id)), + ) + + assert state.custom_model_cards == {} diff --git a/src/exo/shared/types/state.py b/src/exo/shared/types/state.py index 21199e7af3..e8d35ec1b9 100644 --- a/src/exo/shared/types/state.py +++ b/src/exo/shared/types/state.py @@ -5,9 +5,10 @@ from pydantic import ConfigDict, Field, field_serializer, field_validator from pydantic.alias_generators import to_camel +from exo.shared.models.model_cards import ModelCard from exo.shared.topology import Topology, TopologySnapshot from exo.shared.types.chunks import InputImageChunk -from exo.shared.types.common import CommandId, NodeId +from exo.shared.types.common import CommandId, ModelId, NodeId from exo.shared.types.instance_link import InstanceLink, InstanceLinkId from exo.shared.types.profiling import ( DiskUsage, @@ -72,6 +73,10 @@ class State(FrozenModel): instance_links: Mapping[InstanceLinkId, InstanceLink] = {} prefill_server_ports: Mapping[RunnerId, int] = {} + # User-added model cards. Workers can reconcile their on-disk custom card + # cache from this state after snapshot bootstrap. + custom_model_cards: Mapping[ModelId, ModelCard] = {} + @field_serializer("topology", mode="plain") def _encode_topology(self, value: Topology) -> TopologySnapshot: return value.to_snapshot() From dbc6286066fc81abcd3edea0c29e4ef927f97847 Mon Sep 17 00:00:00 2001 From: Alex Cheema Date: Sun, 3 May 2026 01:44:49 +0100 Subject: [PATCH 10/11] feat: reconcile custom model cards from state --- src/exo/worker/main.py | 36 ++++++--- .../unittests/test_worker_custom_cards.py | 78 +++++++++++++++++++ 2 files changed, 105 insertions(+), 9 deletions(-) create mode 100644 src/exo/worker/tests/unittests/test_worker_custom_cards.py diff --git a/src/exo/worker/main.py b/src/exo/worker/main.py index 254c27fead..3e2aa94673 100644 --- a/src/exo/worker/main.py +++ b/src/exo/worker/main.py @@ -12,7 +12,12 @@ from exo.routing.snapshot_receiver import SnapshotReceiver from exo.shared.apply import apply from exo.shared.constants import EXO_MAX_INSTANCE_RETRIES -from exo.shared.models.model_cards import ModelId, add_to_card_cache, delete_custom_card +from exo.shared.models.model_cards import ( + ModelCard, + ModelId, + add_to_card_cache, + delete_custom_card, +) from exo.shared.types.chunks import InputImageChunk from exo.shared.types.commands import ( DeleteInstance, @@ -23,8 +28,6 @@ ) from exo.shared.types.common import CommandId, NodeId, SessionId, SystemId from exo.shared.types.events import ( - CustomModelCardAdded, - CustomModelCardDeleted, Event, IndexedEvent, InstanceDeleted, @@ -106,6 +109,7 @@ def __init__( self._instance_backoff: KeyedBackoff[InstanceId] = KeyedBackoff( base=0.5, cap=10.0 ) + self._synced_custom_cards: dict[ModelId, ModelCard] = {} self._stopped: anyio.Event = anyio.Event() async def run(self): @@ -137,6 +141,7 @@ async def _bootstrap_then_run( self._tg.start_soon(self._forward_info, info_recv) self._tg.start_soon(self.plan_step) self._tg.start_soon(self._event_applier) + self._tg.start_soon(self._reconcile_custom_cards) self._tg.start_soon(self._poll_connection_updates) async def _fetch_snapshot(self) -> None: @@ -190,14 +195,27 @@ async def _event_applier(self): if isinstance(event, InstanceDeleted): self._instance_backoff.reset(event.instance_id) - if isinstance(event, CustomModelCardAdded): - await event.model_card.save_to_custom_dir() - add_to_card_cache(event.model_card) + self._sync_input_views_from_state() - if isinstance(event, CustomModelCardDeleted): - await delete_custom_card(event.model_id) + async def _reconcile_custom_cards(self) -> None: + while True: + await anyio.sleep(1) + await self._sync_custom_cards_from_state() - self._sync_input_views_from_state() + async def _sync_custom_cards_from_state(self) -> None: + target = dict(self.state.custom_model_cards) + for model_id, card in target.items(): + if self._synced_custom_cards.get(model_id) == card: + continue + await card.save_to_custom_dir() + add_to_card_cache(card) + self._synced_custom_cards[model_id] = card + + for model_id in list(self._synced_custom_cards): + if model_id in target: + continue + await delete_custom_card(model_id) + self._synced_custom_cards.pop(model_id, None) def _sync_input_views_from_state(self) -> None: self.input_chunk_buffer = { diff --git a/src/exo/worker/tests/unittests/test_worker_custom_cards.py b/src/exo/worker/tests/unittests/test_worker_custom_cards.py new file mode 100644 index 0000000000..1311329c64 --- /dev/null +++ b/src/exo/worker/tests/unittests/test_worker_custom_cards.py @@ -0,0 +1,78 @@ +# pyright: reportPrivateUsage=false + +import pytest + +import exo.worker.main as worker_main +from exo.shared.models.model_cards import ModelCard, ModelTask +from exo.shared.types.common import ModelId +from exo.shared.types.memory import Memory +from exo.shared.types.state import State +from exo.worker.main import Worker + + +def _model_card(model_id: ModelId) -> ModelCard: + return ModelCard( + model_id=model_id, + n_layers=1, + storage_size=Memory.from_bytes(1), + hidden_size=1, + supports_tensor=True, + tasks=[ModelTask.TextGeneration], + ) + + +def _worker() -> Worker: + worker = object.__new__(Worker) + worker.state = State() + worker._synced_custom_cards = {} + return worker + + +@pytest.mark.asyncio +async def test_worker_syncs_custom_cards_from_state( + monkeypatch: pytest.MonkeyPatch, +) -> None: + saved: list[ModelId] = [] + cached: list[ModelId] = [] + + async def save_to_custom_dir(card: ModelCard) -> None: + saved.append(card.model_id) + + def add_to_card_cache(card: ModelCard) -> None: + cached.append(card.model_id) + + monkeypatch.setattr(ModelCard, "save_to_custom_dir", save_to_custom_dir) + monkeypatch.setattr(worker_main, "add_to_card_cache", add_to_card_cache) + + card = _model_card(ModelId("custom/model")) + worker = _worker() + worker.state = State(custom_model_cards={card.model_id: card}) + + await worker._sync_custom_cards_from_state() + await worker._sync_custom_cards_from_state() + + assert saved == [card.model_id] + assert cached == [card.model_id] + assert worker._synced_custom_cards == {card.model_id: card} + + +@pytest.mark.asyncio +async def test_worker_deletes_custom_cards_missing_from_state( + monkeypatch: pytest.MonkeyPatch, +) -> None: + deleted: list[ModelId] = [] + + async def delete_custom_card(model_id: ModelId) -> bool: + deleted.append(model_id) + return True + + monkeypatch.setattr(worker_main, "delete_custom_card", delete_custom_card) + + card = _model_card(ModelId("custom/model")) + worker = _worker() + worker._synced_custom_cards = {card.model_id: card} + + await worker._sync_custom_cards_from_state() + + assert deleted == [card.model_id] + assert worker._synced_custom_cards == {} From 7a25b4186d763f39bec2abd7154333713087b382 Mon Sep 17 00:00:00 2001 From: Alex Cheema Date: Sun, 3 May 2026 01:58:47 +0100 Subject: [PATCH 11/11] feat: reconcile api streams from state --- src/exo/api/main.py | 52 +++++++--- .../test_instance_deleted_stream_cleanup.py | 98 ++++++++++++++----- 2 files changed, 110 insertions(+), 40 deletions(-) diff --git a/src/exo/api/main.py b/src/exo/api/main.py index 927b9bc3fa..9d3e9216a7 100644 --- a/src/exo/api/main.py +++ b/src/exo/api/main.py @@ -179,7 +179,6 @@ ChunkGenerated, Event, IndexedEvent, - InstanceDeleted, TracesMerged, ) from exo.shared.types.instance_link import InstanceLink, InstanceLinkId @@ -300,6 +299,7 @@ async def _log_requests( # pyright: ignore[reportUnusedFunction] self._image_generation_queues: dict[ CommandId, Sender[ImageChunk | ErrorChunk] ] = {} + self._observed_generation_commands: set[CommandId] = set() self._image_store = ImageStore(EXO_IMAGE_CACHE_DIR) self._tg: TaskGroup = TaskGroup() @@ -320,6 +320,7 @@ def reset( self.event_router = event_router self._text_generation_queues = {} self._image_generation_queues = {} + self._observed_generation_commands = set() self.unpause(result_clock) self.event_receiver.close() self.event_receiver = event_receiver @@ -1860,6 +1861,7 @@ async def run(self): async with self._tg as tg: logger.info("Starting API") tg.start_soon(self._bootstrap_then_apply_state) + tg.start_soon(self._reconcile_streams) tg.start_soon(self._pause_on_new_election) tg.start_soon(self._cleanup_expired_images) print_startup_banner(self.port) @@ -1947,23 +1949,47 @@ async def _apply_state(self): await queue.send(event.chunk) except (BrokenResourceError, ClosedResourceError): self._text_generation_queues.pop(event.command_id, None) - if isinstance(event, InstanceDeleted): - self._close_streams_for_instance(event.instance_id) if isinstance(event, TracesMerged): self._save_merged_trace(event) - def _close_streams_for_instance(self, instance_id: InstanceId) -> None: - """Close any active generation streams for commands running on the given instance.""" - for task in self.state.tasks.values(): - if task.instance_id != instance_id: - continue - if not isinstance( + async def _reconcile_streams(self) -> None: + while True: + await anyio.sleep(1) + self._reconcile_streams_once() + + def _reconcile_streams_once(self) -> None: + generation_tasks = [ + task + for task in self.state.tasks.values() + if isinstance( task, (TextGenerationTask, ImageGenerationTask, ImageEditsTask) - ): - continue - if sender := self._text_generation_queues.pop(task.command_id, None): + ) + ] + state_command_ids = {task.command_id for task in generation_tasks} + self._observed_generation_commands.update(state_command_ids) + + live_command_ids = { + task.command_id + for task in generation_tasks + if task.instance_id in self.state.instances + } + queued_command_ids = set(self._text_generation_queues) | set( + self._image_generation_queues + ) + stale_command_ids = ( + self._observed_generation_commands - live_command_ids + ) & queued_command_ids + self._close_streams_for_commands(stale_command_ids) + + self._observed_generation_commands = ( + self._observed_generation_commands & queued_command_ids + ) | state_command_ids + + def _close_streams_for_commands(self, command_ids: set[CommandId]) -> None: + for command_id in command_ids: + if sender := self._text_generation_queues.pop(command_id, None): sender.close() - if sender := self._image_generation_queues.pop(task.command_id, None): + if sender := self._image_generation_queues.pop(command_id, None): sender.close() def _save_merged_trace(self, event: TracesMerged) -> None: diff --git a/src/exo/api/tests/test_instance_deleted_stream_cleanup.py b/src/exo/api/tests/test_instance_deleted_stream_cleanup.py index fb03c5da66..3a18f74d75 100644 --- a/src/exo/api/tests/test_instance_deleted_stream_cleanup.py +++ b/src/exo/api/tests/test_instance_deleted_stream_cleanup.py @@ -1,11 +1,11 @@ # pyright: reportUnusedFunction=false, reportAny=false -"""Tests that InstanceDeleted events close active generation streams.""" +"""Tests that streaming queues reconcile against durable State.""" from unittest.mock import MagicMock from exo.api.main import API from exo.api.types import ImageGenerationTaskParams -from exo.shared.types.common import CommandId, ModelId +from exo.shared.types.common import CommandId, ModelId, NodeId from exo.shared.types.state import State from exo.shared.types.tasks import ImageGeneration, TextGeneration from exo.shared.types.text_generation import ( @@ -13,15 +13,16 @@ InputMessageContent, TextGenerationTaskParams, ) -from exo.shared.types.worker.instances import InstanceId +from exo.shared.types.worker.instances import InstanceId, MlxRingInstance +from exo.shared.types.worker.runners import ShardAssignments def _make_api_with_state(state: State) -> API: - """Create a minimal API instance with pre-set state.""" api = object.__new__(API) api.state = state api._text_generation_queues = {} # pyright: ignore[reportPrivateUsage] api._image_generation_queues = {} # pyright: ignore[reportPrivateUsage] + api._observed_generation_commands = set() # pyright: ignore[reportPrivateUsage] return api @@ -38,45 +39,90 @@ def _make_text_gen_task( ) -def test_close_streams_for_deleted_instance() -> None: - """Deleting an instance closes the text generation sender for commands on that instance.""" +def _make_instance(instance_id: InstanceId) -> MlxRingInstance: + return MlxRingInstance( + instance_id=instance_id, + shard_assignments=ShardAssignments( + model_id=ModelId("test-model"), + node_to_runner={}, + runner_to_shard={}, + ), + hosts_by_node={NodeId("node-1"): []}, + ephemeral_port=1, + ) + + +def test_reconcile_closes_stream_when_task_instance_is_missing() -> None: instance_id = InstanceId("inst-1") command_id = CommandId("cmd-1") task = _make_text_gen_task(instance_id, command_id) - - state = State(tasks={task.task_id: task}) - api = _make_api_with_state(state) + api = _make_api_with_state(State(tasks={task.task_id: task}, instances={})) sender = MagicMock() api._text_generation_queues[command_id] = sender # pyright: ignore[reportPrivateUsage] - api._close_streams_for_instance(instance_id) # pyright: ignore[reportPrivateUsage] + api._reconcile_streams_once() # pyright: ignore[reportPrivateUsage] sender.close.assert_called_once() assert command_id not in api._text_generation_queues # pyright: ignore[reportPrivateUsage] -def test_close_streams_ignores_unrelated_instances() -> None: - """Deleting an instance does NOT close streams for commands on other instances.""" - target_id = InstanceId("inst-delete") - other_id = InstanceId("inst-keep") - other_cmd = CommandId("cmd-keep") - other_task = _make_text_gen_task(other_id, other_cmd) +def test_reconcile_keeps_stream_for_live_task_instance() -> None: + instance_id = InstanceId("inst-live") + command_id = CommandId("cmd-live") + task = _make_text_gen_task(instance_id, command_id) + api = _make_api_with_state( + State( + tasks={task.task_id: task}, + instances={instance_id: _make_instance(instance_id)}, + ) + ) + + sender = MagicMock() + api._text_generation_queues[command_id] = sender # pyright: ignore[reportPrivateUsage] + + api._reconcile_streams_once() # pyright: ignore[reportPrivateUsage] + + sender.close.assert_not_called() + assert command_id in api._text_generation_queues # pyright: ignore[reportPrivateUsage] + - state = State(tasks={other_task.task_id: other_task}) - api = _make_api_with_state(state) +def test_reconcile_does_not_close_command_before_state_observes_it() -> None: + command_id = CommandId("cmd-not-created-yet") + api = _make_api_with_state(State()) sender = MagicMock() - api._text_generation_queues[other_cmd] = sender # pyright: ignore[reportPrivateUsage] + api._text_generation_queues[command_id] = sender # pyright: ignore[reportPrivateUsage] - api._close_streams_for_instance(target_id) # pyright: ignore[reportPrivateUsage] + api._reconcile_streams_once() # pyright: ignore[reportPrivateUsage] sender.close.assert_not_called() - assert other_cmd in api._text_generation_queues # pyright: ignore[reportPrivateUsage] + assert command_id in api._text_generation_queues # pyright: ignore[reportPrivateUsage] + + +def test_reconcile_closes_stream_after_observed_task_leaves_state() -> None: + instance_id = InstanceId("inst-live") + command_id = CommandId("cmd-deleted") + task = _make_text_gen_task(instance_id, command_id) + api = _make_api_with_state( + State( + tasks={task.task_id: task}, + instances={instance_id: _make_instance(instance_id)}, + ) + ) + + sender = MagicMock() + api._text_generation_queues[command_id] = sender # pyright: ignore[reportPrivateUsage] + api._reconcile_streams_once() # pyright: ignore[reportPrivateUsage] + api.state = State(instances={instance_id: _make_instance(instance_id)}) + api._reconcile_streams_once() # pyright: ignore[reportPrivateUsage] -def test_close_streams_for_deleted_instance_image_generation() -> None: - """Deleting an instance closes the image generation sender for commands on that instance.""" + sender.close.assert_called_once() + assert command_id not in api._text_generation_queues # pyright: ignore[reportPrivateUsage] + + +def test_reconcile_closes_image_stream_when_task_instance_is_missing() -> None: instance_id = InstanceId("inst-img") command_id = CommandId("cmd-img") task = ImageGeneration( @@ -84,14 +130,12 @@ def test_close_streams_for_deleted_instance_image_generation() -> None: command_id=command_id, task_params=ImageGenerationTaskParams(prompt="a cat", model="test-model"), ) - - state = State(tasks={task.task_id: task}) - api = _make_api_with_state(state) + api = _make_api_with_state(State(tasks={task.task_id: task}, instances={})) sender = MagicMock() api._image_generation_queues[command_id] = sender # pyright: ignore[reportPrivateUsage] - api._close_streams_for_instance(instance_id) # pyright: ignore[reportPrivateUsage] + api._reconcile_streams_once() # pyright: ignore[reportPrivateUsage] sender.close.assert_called_once() assert command_id not in api._image_generation_queues # pyright: ignore[reportPrivateUsage]