Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 37 additions & 2 deletions src/exo/api/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -1010,7 +1010,28 @@ async def _validate_model_has_instance(self, model_id: ModelId) -> ModelId:
)
return model_id

def stream_events(self) -> StreamingResponse:
async def _validate_image_model(self, model: ModelId) -> ModelId:
    """Resolve ``model`` via its model card and require a matching instance.

    The card is loaded with ``ModelCard.load`` and the card's ``model_id``
    is used as the resolved ID. The resolved ID is returned as soon as one
    entry of ``self.state.instances`` has shard assignments for it.

    Raises:
        HTTPException: 404 when no instance matches the resolved ID. The
            download notification is triggered before raising.
    """
    card = await ModelCard.load(model)
    resolved_model = card.model_id

    # Early return on the first instance serving this model.
    for running in self.state.instances.values():
        if running.shard_assignments.model_id == resolved_model:
            return resolved_model

    # Nothing is serving the model: notify first, then report 404.
    await self._trigger_notify_user_to_download_model(resolved_model)
    raise HTTPException(
        status_code=404, detail=f"No instance found for model {resolved_model}"
    )

def stream_events(
self,
since: int = Query(default=0, ge=0),
limit: int | None = Query(default=None, ge=0),
) -> StreamingResponse:
def _generate_json_array(events: Iterable[Event]) -> Iterable[str]:
yield "["
first = True
Expand All @@ -1021,9 +1042,23 @@ def _generate_json_array(events: Iterable[Event]) -> Iterable[str]:
yield event.model_dump_json()
yield "]"

log_count = len(self._event_log)
if since == 0 and limit is None:
# Backward-compatible path: full ledger dump (matches pre-cursor behavior).
return StreamingResponse(
_generate_json_array(self._event_log.read_all()),
media_type="application/json",
headers={"X-EXO-Last-Idx": str(log_count)},
)
end = log_count if limit is None else min(since + limit, log_count)
if since >= end:
events_iter: Iterable[Event] = iter(())
else:
events_iter = self._event_log.read_range(since, end)
return StreamingResponse(
_generate_json_array(self._event_log.read_all()),
_generate_json_array(events_iter),
media_type="application/json",
headers={"X-EXO-Last-Idx": str(end)},
)

async def get_image(self, image_id: str) -> FileResponse:
Expand Down
37 changes: 37 additions & 0 deletions src/exo/api/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
"""Pytest configuration for API tests.

Stubs the exo_rs Rust extension so API tests can run without a compiled
binary. The stub provides empty placeholder classes for symbols that are
imported at module level by exo.routing, but not exercised by these tests.
"""

import sys
import types
from unittest.mock import MagicMock

# Only install the stub if the real extension cannot be imported.
#
# Checking ``"exo_rs" in sys.modules`` is not sufficient: that only tells us
# whether something has *already imported* the module. An installed but
# not-yet-imported extension would be shadowed by the stub for the rest of
# the process. Probing with a real import keeps the genuine module whenever
# it exists.
try:
    import exo_rs  # noqa: F401  (probe only; the name itself is not used here)
except ImportError:
    _stub = types.ModuleType("exo_rs")

    # Symbols imported by exo.routing.connection_message
    class _FromSwarm:
        class Connection:
            peer_id: str = ""
            connected: bool = False

    _stub.FromSwarm = _FromSwarm  # type: ignore[attr-defined]

    # Symbols imported by exo.routing.router
    _stub.AllQueuesFullError = type("AllQueuesFullError", (Exception,), {})  # type: ignore[attr-defined]
    _stub.MessageTooLargeError = type("MessageTooLargeError", (Exception,), {})  # type: ignore[attr-defined]
    _stub.NoPeersSubscribedToTopicError = type(
        "NoPeersSubscribedToTopicError", (Exception,), {}
    )  # type: ignore[attr-defined]
    _stub.Keypair = MagicMock  # type: ignore[attr-defined]
    _stub.NetworkingHandle = MagicMock  # type: ignore[attr-defined]

    # Symbols imported by exo.main
    _stub.Pidfile = MagicMock  # type: ignore[attr-defined]
    _stub.PidfileError = type("PidfileError", (Exception,), {})  # type: ignore[attr-defined]

    # Register the placeholder so later ``import exo_rs`` statements resolve.
    sys.modules["exo_rs"] = _stub
130 changes: 130 additions & 0 deletions src/exo/api/tests/test_stream_events.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
# pyright: reportUnusedFunction=false, reportAny=false
"""Tests for the GET /events endpoint cursor support.

The `stream_events` handler in exo/api/main.py supports two query parameters:
- since: int (default 0) — start index, inclusive
- limit: int | None (default None) — max events to return

The response sets `X-EXO-Last-Idx` to the upper bound consumed, allowing
clients to chain reads without a separate /state round-trip.
"""

from pathlib import Path
from typing import Any
from unittest.mock import AsyncMock

from fastapi import FastAPI
from fastapi.testclient import TestClient

from exo.api.main import API
from exo.shared.types.events import TestEvent
from exo.utils.disk_event_log import DiskEventLog


def _make_api(log_dir: Path, n_events: int) -> Any:
    """Build a bare ``API`` object that serves only ``GET /events``.

    ``API.__init__`` is bypassed with ``object.__new__``; only the
    attributes assigned below are set. The event log is a ``DiskEventLog``
    rooted at ``log_dir`` and pre-filled with ``n_events`` ``TestEvent``
    records.
    """
    fastapi_app = FastAPI()

    bare_api = object.__new__(API)
    bare_api.app = fastapi_app
    bare_api._send = AsyncMock()  # pyright: ignore[reportPrivateUsage]
    bare_api._setup_exception_handlers()  # pyright: ignore[reportPrivateUsage]

    event_log = DiskEventLog(log_dir)
    for _ in range(n_events):
        event_log.append(TestEvent())
    bare_api._event_log = event_log  # pyright: ignore[reportPrivateUsage]

    # Mount the single route under test.
    register_route = fastapi_app.get("/events")
    register_route(bare_api.stream_events)
    return bare_api


def test_stream_events_full_dump_backward_compatible(tmp_path: Path) -> None:
    """A request without query params returns all 5 events and header "5"."""
    client = TestClient(_make_api(tmp_path / "log_full", n_events=5).app)

    response = client.get("/events")

    assert response.status_code == 200
    events: list[dict[str, Any]] = response.json()
    assert len(events) == 5
    assert response.headers["X-EXO-Last-Idx"] == "5"


def test_stream_events_with_since_and_limit(tmp_path: Path) -> None:
    """since=3, limit=4 on a 10-event log yields 4 events and header "7"."""
    client = TestClient(_make_api(tmp_path / "log_cursor", n_events=10).app)

    query = {"since": 3, "limit": 4}
    response = client.get("/events", params=query)

    assert response.status_code == 200
    events: list[dict[str, Any]] = response.json()
    assert len(events) == 4
    assert response.headers["X-EXO-Last-Idx"] == "7"


def test_stream_events_since_only_reads_to_end(tmp_path: Path) -> None:
    """since=5 with no limit on an 8-event log yields 3 events and header "8"."""
    client = TestClient(_make_api(tmp_path / "log_tail", n_events=8).app)

    query = {"since": 5}
    response = client.get("/events", params=query)

    assert response.status_code == 200
    events: list[dict[str, Any]] = response.json()
    assert len(events) == 3
    assert response.headers["X-EXO-Last-Idx"] == "8"


def test_stream_events_since_beyond_count_returns_empty(tmp_path: Path) -> None:
    """since=99 on a 4-event log yields an empty array and header "4"."""
    client = TestClient(_make_api(tmp_path / "log_overshoot", n_events=4).app)

    query = {"since": 99}
    response = client.get("/events", params=query)

    assert response.status_code == 200
    assert response.json() == []
    # With no limit the upper bound is the log count (4), even though the
    # requested start lies past it.
    assert response.headers["X-EXO-Last-Idx"] == "4"


def test_stream_events_limit_larger_than_remaining(tmp_path: Path) -> None:
    """since=7, limit=100 on a 10-event log yields 3 events and header "10"."""
    client = TestClient(_make_api(tmp_path / "log_clamp", n_events=10).app)

    query = {"since": 7, "limit": 100}
    response = client.get("/events", params=query)

    assert response.status_code == 200
    events: list[dict[str, Any]] = response.json()
    assert len(events) == 3
    assert response.headers["X-EXO-Last-Idx"] == "10"


def test_stream_events_negative_since_rejected(tmp_path: Path) -> None:
    """since=-1 is answered with HTTP 422."""
    client = TestClient(_make_api(tmp_path / "log_neg", n_events=3).app)

    query = {"since": -1}
    response = client.get("/events", params=query)

    assert response.status_code == 422


def test_stream_events_chained_cursor_reads(tmp_path: Path) -> None:
    """Feeding the returned header back as ``since`` reads the remaining events.

    On a 6-event log, the first read (since=0, limit=4) returns 4 events and
    header "4"; the follow-up read from that cursor returns the other 2 and
    header "6".
    """
    client = TestClient(_make_api(tmp_path / "log_chain", n_events=6).app)

    # First page: bounded read from the start.
    head = client.get("/events", params={"since": 0, "limit": 4})
    assert head.status_code == 200
    next_since = int(head.headers["X-EXO-Last-Idx"])
    assert next_since == 4
    head_events: list[dict[str, Any]] = head.json()
    assert len(head_events) == 4

    # Second page: resume from the cursor the server handed back.
    tail = client.get("/events", params={"since": next_since})
    assert tail.status_code == 200
    tail_events: list[dict[str, Any]] = tail.json()
    assert len(tail_events) == 2
    assert tail.headers["X-EXO-Last-Idx"] == "6"

    # Together the two pages account for every event in the log.
    assert len(head_events) + len(tail_events) == 6