Skip to content
132 changes: 102 additions & 30 deletions src/exo/api/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,8 @@
)
from exo.master.image_store import ImageStore
from exo.master.placement import place_instance as get_instance_placements
from exo.routing.event_router import EventRouter
from exo.routing.snapshot_receiver import SnapshotReceiver
from exo.shared.apply import apply
from exo.shared.constants import (
DASHBOARD_DIR,
Expand Down Expand Up @@ -164,23 +166,24 @@
ImageEdits,
ImageGeneration,
PlaceInstance,
RequestSnapshot,
SendInputChunk,
SetInstanceLink,
StartDownload,
TaskCancelled,
TaskFinished,
TextGeneration,
)
from exo.shared.types.common import CommandId, Id, NodeId, SystemId
from exo.shared.types.common import CommandId, Id, NodeId, SessionId, SystemId
from exo.shared.types.events import (
ChunkGenerated,
Event,
IndexedEvent,
InstanceDeleted,
TracesMerged,
)
from exo.shared.types.instance_link import InstanceLink, InstanceLinkId
from exo.shared.types.memory import Memory
from exo.shared.types.snapshots import SnapshotChunk
from exo.shared.types.state import State
from exo.shared.types.tasks import (
ImageEdits as ImageEditsTask,
Expand All @@ -207,6 +210,8 @@
# Location of the API's disk event log; handed to DiskEventLog in this file.
_API_EVENT_LOG_DIR = EXO_EVENT_LOG_DIR / "api"
# NOTE(review): presumably a marker file recording that onboarding finished —
# confirm against the code that reads it.
ONBOARDING_COMPLETE_FILE = EXO_CACHE_HOME / "onboarding_complete"

# Seconds `_fetch_snapshot` waits for snapshot chunks (via anyio.move_on_after)
# before giving up on the snapshot.
_SNAPSHOT_FETCH_TIMEOUT_SECONDS = 30


def _format_to_content_type(image_format: Literal["png", "jpeg", "webp"] | None) -> str:
    """Return the ``image/<format>`` content type for *image_format*.

    A missing (falsy) format is treated as PNG.
    """
    subtype = image_format if image_format else "png"
    return f"image/{subtype}"
Expand Down Expand Up @@ -236,9 +241,12 @@ class API:
def __init__(
self,
node_id: NodeId,
session_id: SessionId,
*,
port: int,
event_router: EventRouter,
event_receiver: Receiver[IndexedEvent],
snapshot_chunk_receiver: Receiver[SnapshotChunk],
command_sender: Sender[ForwarderCommand],
download_command_sender: Sender[ForwarderDownloadCommand],
# This lets us pause the API if an election is running
Expand All @@ -247,14 +255,16 @@ def __init__(
self.state = State()
self._event_log = DiskEventLog(_API_EVENT_LOG_DIR)
self._system_id = SystemId()
self.session_id = session_id
self.event_router = event_router
self.command_sender = command_sender
self.download_command_sender = download_command_sender
self.event_receiver = event_receiver
self.snapshot_chunk_receiver = snapshot_chunk_receiver
self.election_receiver = election_receiver
self.node_id: NodeId = node_id
self.last_completed_election: int = 0
self.port = port
self._sent_image_hashes: set[str] = set()

self.paused: bool = False
self.paused_ev: anyio.Event = anyio.Event()
Expand Down Expand Up @@ -289,22 +299,34 @@ async def _log_requests( # pyright: ignore[reportUnusedFunction]
self._image_generation_queues: dict[
CommandId, Sender[ImageChunk | ErrorChunk]
] = {}
self._observed_generation_commands: set[CommandId] = set()
self._image_store = ImageStore(EXO_IMAGE_CACHE_DIR)
self._tg: TaskGroup = TaskGroup()

def reset(
    self,
    result_clock: int,
    session_id: SessionId,
    event_router: EventRouter,
    event_receiver: Receiver[IndexedEvent],
    snapshot_chunk_receiver: Receiver[SnapshotChunk],
):
    """Discard per-session state and rebind the API to a new session.

    Args:
        result_clock: Forwarded to ``unpause``.
        session_id: Stored as the API's current session id.
        event_router: Replaces the current event router.
        event_receiver: Replaces the current indexed-event receiver.
        snapshot_chunk_receiver: Replaces the current snapshot-chunk receiver.

    The disk event log is closed and reopened, ``state`` and ``_system_id``
    are recreated, and the generation queue maps and the observed-command
    set are replaced with empty containers. The previous receivers are
    closed before the new ones are installed, and a fresh
    ``_bootstrap_then_apply_state`` task is scheduled on the task group.
    """
    logger.info("Resetting API State")

    # Start a new event log for the new session.
    self._event_log.close()
    self._event_log = DiskEventLog(_API_EVENT_LOG_DIR)

    self.state = State()
    self._system_id = SystemId()
    self.session_id = session_id
    self.event_router = event_router

    self._text_generation_queues = {}
    self._image_generation_queues = {}
    self._observed_generation_commands = set()

    self.unpause(result_clock)

    # Close the old receivers before swapping in the new ones.
    self.event_receiver.close()
    self.event_receiver = event_receiver
    self.snapshot_chunk_receiver.close()
    self.snapshot_chunk_receiver = snapshot_chunk_receiver

    # Exactly one consumer is started for the new receivers: snapshot
    # bootstrap first, then event application.
    self._tg.start_soon(self._bootstrap_then_apply_state)

def unpause(self, result_clock: int):
logger.info("Unpausing API")
Expand Down Expand Up @@ -826,18 +848,8 @@ async def _send_text_generation_with_images(
)
command = TextGeneration(task_params=task_params)

new_images: list[tuple[int, str]] = []
for idx, (img, h) in enumerate(zip(images, hashes, strict=True)):
if h not in self._sent_image_hashes:
self._sent_image_hashes.add(h)
new_images.append((idx, img))

if not new_images:
await self._send(command)
return command

all_chunks: list[tuple[int, str]] = []
for img_idx, img_data in new_images:
for img_idx, img_data in enumerate(images):
for i in range(0, len(img_data), EXO_MAX_CHUNK_SIZE):
all_chunks.append((img_idx, img_data[i : i + EXO_MAX_CHUNK_SIZE]))

Expand Down Expand Up @@ -1848,7 +1860,8 @@ async def run(self):
try:
async with self._tg as tg:
logger.info("Starting API")
tg.start_soon(self._apply_state)
tg.start_soon(self._bootstrap_then_apply_state)
tg.start_soon(self._reconcile_streams)
tg.start_soon(self._pause_on_new_election)
tg.start_soon(self._cleanup_expired_images)
print_startup_banner(self.port)
Expand All @@ -1862,6 +1875,7 @@ async def run(self):
self._event_log.close()
self.command_sender.close()
self.event_receiver.close()
self.snapshot_chunk_receiver.close()

async def run_api(self, ev: anyio.Event):
cfg = Config()
Expand All @@ -1877,9 +1891,43 @@ async def run_api(self, ev: anyio.Event):
shutdown_trigger=ev.wait,
)

async def _bootstrap_then_apply_state(self):
    """Attempt the snapshot bootstrap, then run the event-application loop."""
    # The snapshot fetch must finish (or give up) before any event is applied.
    await self._fetch_snapshot()
    await self._apply_state()

async def _fetch_snapshot(self) -> None:
    """Request a state snapshot and adopt it if one arrives in time.

    Sends a ``RequestSnapshot`` command naming this node as requester, then
    feeds incoming chunks to a ``SnapshotReceiver`` until ``ingest`` returns
    a non-``None`` result. That result's state replaces ``self.state`` and
    the event router's buffer start is set to the index right after the
    snapshot's last applied event.

    If the wait is cut short by the timeout, or the chunk stream ends first,
    ``self.state`` is left untouched and a fallback message is logged.
    """
    assembler = SnapshotReceiver(self.node_id, self.session_id)
    request = ForwarderCommand(
        origin=self._system_id,
        command=RequestSnapshot(requester_node_id=self.node_id),
    )
    await self.command_sender.send(request)

    with (
        anyio.move_on_after(_SNAPSHOT_FETCH_TIMEOUT_SECONDS),
        self.snapshot_chunk_receiver as chunks,
    ):
        async for chunk in chunks:
            received = assembler.ingest(chunk)
            if received is not None:
                self.state = received.state
                self.event_router.set_buffer_start(
                    received.last_event_applied_idx + 1
                )
                logger.info(
                    f"API bootstrapped from snapshot at idx "
                    f"{received.last_event_applied_idx}"
                )
                return

    logger.info(
        "API: no snapshot received before timeout; falling back to full event-log replay"
    )

async def _apply_state(self):
with self.event_receiver as events:
async for i_event in events:
if i_event.idx <= self.state.last_event_applied_idx:
continue
self._event_log.append(i_event.event)
self.state = apply(self.state, i_event)
event = i_event.event
Expand All @@ -1901,23 +1949,47 @@ async def _apply_state(self):
await queue.send(event.chunk)
except (BrokenResourceError, ClosedResourceError):
self._text_generation_queues.pop(event.command_id, None)
if isinstance(event, InstanceDeleted):
self._close_streams_for_instance(event.instance_id)
if isinstance(event, TracesMerged):
self._save_merged_trace(event)

def _close_streams_for_instance(self, instance_id: InstanceId) -> None:
"""Close any active generation streams for commands running on the given instance."""
for task in self.state.tasks.values():
if task.instance_id != instance_id:
continue
if not isinstance(
async def _reconcile_streams(self) -> None:
    """Run ``_reconcile_streams_once`` forever, sleeping one second before each pass."""
    pause_seconds = 1
    while True:
        await anyio.sleep(pause_seconds)
        self._reconcile_streams_once()

def _reconcile_streams_once(self) -> None:
    """Close generation streams whose command no longer has a live task.

    A queued command is considered stale when it has been seen in state at
    some point (now or on an earlier pass) but currently has no generation
    task whose instance is present in ``self.state.instances``. Queued
    commands that have never appeared in state are left alone.
    """
    generation_tasks = [
        task
        for task in self.state.tasks.values()
        if isinstance(
            task, (TextGenerationTask, ImageGenerationTask, ImageEditsTask)
        )
    ]
    state_command_ids = {task.command_id for task in generation_tasks}
    # Remember every command seen in state so that its later disappearance
    # can be told apart from a command that has not shown up yet.
    self._observed_generation_commands.update(state_command_ids)

    live_command_ids = {
        task.command_id
        for task in generation_tasks
        if task.instance_id in self.state.instances
    }
    # Snapshot of the queued ids taken before any stream is closed.
    queued_command_ids = set(self._text_generation_queues) | set(
        self._image_generation_queues
    )
    stale_command_ids = (
        self._observed_generation_commands - live_command_ids
    ) & queued_command_ids
    self._close_streams_for_commands(stale_command_ids)

    # Keep only observations that still matter: commands that were queued at
    # the start of this pass, plus everything currently present in state.
    self._observed_generation_commands = (
        self._observed_generation_commands & queued_command_ids
    ) | state_command_ids

def _close_streams_for_commands(self, command_ids: set[CommandId]) -> None:
    """Remove and close the text and image stream senders of each command.

    Args:
        command_ids: Commands whose queue entries should be dropped. Ids
            with no entry in either queue map are ignored.
    """
    for command_id in command_ids:
        # Text queue first, then image queue, for every command.
        for queues in (self._text_generation_queues, self._image_generation_queues):
            if sender := queues.pop(command_id, None):
                sender.close()

def _save_merged_trace(self, event: TracesMerged) -> None:
Expand Down
136 changes: 136 additions & 0 deletions src/exo/api/tests/test_api_snapshot_bootstrap.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
# pyright: reportPrivateUsage=false

import hashlib

import anyio
import pytest
import zstandard

from exo.api.main import API
from exo.routing.event_router import EventRouter
from exo.shared.types.commands import ForwarderCommand, RequestSnapshot
from exo.shared.types.common import NodeId, SessionId, SystemId
from exo.shared.types.events import (
Event,
GlobalForwarderEvent,
IndexedEvent,
LocalForwarderEvent,
TestEvent,
)
from exo.shared.types.snapshots import SnapshotChunk, SnapshotTransferId
from exo.shared.types.state import State
from exo.utils.channels import Receiver, Sender, channel


class _FakeEventLog:
    """In-memory stand-in for the API's event log that records appended events."""

    def __init__(self) -> None:
        # Events in the order they were appended.
        self.appended: list[Event] = []

    def append(self, event: Event) -> None:
        """Record *event* at the end of ``appended``."""
        self.appended += [event]


def _snapshot_chunk(
    state: State, *, requester_node_id: NodeId, session_id: SessionId
) -> SnapshotChunk:
    """Wrap *state* in a single snapshot chunk (index 0 of 1).

    The payload is the zstd-compressed UTF-8 JSON dump of *state*; the
    chunk's SHA-256 is computed over that compressed payload.
    """
    serialized = state.model_dump_json().encode("utf-8")
    payload = zstandard.ZstdCompressor().compress(serialized)
    digest = hashlib.sha256(payload).hexdigest()
    return SnapshotChunk.from_data(
        data=payload,
        transfer_id=SnapshotTransferId("transfer-1"),
        requester_node_id=requester_node_id,
        session_id=session_id,
        schema_version=state.schema_version,
        last_event_applied_idx=state.last_event_applied_idx,
        chunk_index=0,
        total_chunks=1,
        sha256_hex=digest,
    )


def _api(
    node_id: NodeId, session_id: SessionId
) -> tuple[
    API,
    EventRouter,
    Receiver[ForwarderCommand],
    Sender[SnapshotChunk],
    Sender[IndexedEvent],
    _FakeEventLog,
]:
    """Build an ``API`` wired to in-memory channels, bypassing ``API.__init__``.

    Returns the API, its event router, the receiving end of the API's command
    channel, the sending ends of its snapshot-chunk and indexed-event
    channels, and the fake event log installed on it.
    """
    # Channels consumed by the router; the unused ends stay bound to locals.
    router_cmd_tx, _router_cmd_rx = channel[ForwarderCommand]()
    _global_evt_tx, global_evt_rx = channel[GlobalForwarderEvent]()
    local_evt_tx, _local_evt_rx = channel[LocalForwarderEvent]()
    router = EventRouter(
        session_id=session_id,
        command_sender=router_cmd_tx,
        external_inbound=global_evt_rx,
        external_outbound=local_evt_tx,
    )

    # Channels the API itself reads from / writes to.
    indexed_tx, indexed_rx = channel[IndexedEvent]()
    cmd_tx, cmd_rx = channel[ForwarderCommand]()
    chunk_tx, chunk_rx = channel[SnapshotChunk]()

    fake_log = _FakeEventLog()

    # Allocate without __init__ and set only the attributes the tests touch.
    api = object.__new__(API)
    api.node_id = node_id
    api.session_id = session_id
    api.event_router = router
    api.event_receiver = indexed_rx
    api.snapshot_chunk_receiver = chunk_rx
    api.command_sender = cmd_tx
    api._system_id = SystemId("api-system")
    api.state = State()
    api._event_log = fake_log  # pyright: ignore[reportAttributeAccessIssue]
    api._image_generation_queues = {}
    api._text_generation_queues = {}
    return api, router, cmd_rx, chunk_tx, indexed_tx, fake_log


@pytest.mark.asyncio
async def test_api_fetch_snapshot_applies_state_and_fast_forwards_router() -> None:
    """``_fetch_snapshot`` requests a snapshot, adopts it, and advances the router.

    After a single-chunk snapshot at idx 7 is delivered, the API state must
    report idx 7 and the router's buffer must release from idx 8.
    """
    node_id = NodeId("api")
    session_id = SessionId(master_node_id=NodeId("master"), election_clock=1)
    api, event_router, command_receiver, snapshot_sender, _event_sender, _event_log = (
        _api(node_id, session_id)
    )
    state = State(last_event_applied_idx=7)

    # Fail fast instead of hanging if the request is never sent or the
    # snapshot is never ingested.
    with anyio.fail_after(5):
        async with anyio.create_task_group() as tg:
            tg.start_soon(api._fetch_snapshot)
            command = await command_receiver.receive()
            assert isinstance(command.command, RequestSnapshot)
            assert command.command.requester_node_id == node_id

            await snapshot_sender.send(
                _snapshot_chunk(state, requester_node_id=node_id, session_id=session_id)
            )

    assert api.state.last_event_applied_idx == 7
    assert event_router.event_buffer.next_idx_to_release == 8


@pytest.mark.asyncio
async def test_api_apply_state_ignores_events_covered_by_snapshot() -> None:
    """``_apply_state`` skips events at or below the state's last applied idx.

    With the state preloaded at idx 7, the event at idx 7 must be ignored
    and only the event at idx 8 appended to the event log.
    """
    node_id = NodeId("api")
    session_id = SessionId(master_node_id=NodeId("master"), election_clock=1)
    (
        api,
        _event_router,
        _command_receiver,
        _snapshot_sender,
        event_sender,
        event_log,
    ) = _api(node_id, session_id)
    api.state = State(last_event_applied_idx=7)

    # Bound the polling loop so a regression fails the test rather than
    # spinning forever.
    with anyio.fail_after(5):
        async with anyio.create_task_group() as tg:
            tg.start_soon(api._apply_state)
            await event_sender.send(IndexedEvent(idx=7, event=TestEvent()))
            await event_sender.send(IndexedEvent(idx=8, event=TestEvent()))

            while api.state.last_event_applied_idx != 8:
                await anyio.sleep(0.001)
            # _apply_state loops until its receiver closes; stop it explicitly.
            tg.cancel_scope.cancel()

    assert len(event_log.appended) == 1
Loading
Loading