Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/exo/utils/keyed_backoff.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@ def attempts(self, key: K) -> int:
"""Return the number of recorded attempts for a key."""
return self._attempts.get(key, 0)

def tracked_keys(self) -> set[K]:
    """Return the keys that currently hold any backoff state.

    A key is included when it has an entry in ``_attempts``, in
    ``_last_time``, or in both. The result is a freshly built set, so the
    caller may mutate it without touching the internal mappings.
    """
    keys: set[K] = set(self._attempts)
    keys.update(self._last_time)
    return keys

def reset(self, key: K) -> None:
"""Reset backoff state for a key (e.g., on success)."""
self._attempts.pop(key, None)
Expand Down
13 changes: 13 additions & 0 deletions src/exo/utils/tests/test_keyed_backoff.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from exo.utils.keyed_backoff import KeyedBackoff


def test_tracked_keys_reports_and_resets_backoff_state() -> None:
    """A recorded key shows up in tracked_keys() and is gone after reset()."""
    key = "instance-a"
    tracker = KeyedBackoff[str]()

    tracker.record_attempt(key)
    assert tracker.tracked_keys() == {key}

    tracker.reset(key)
    assert tracker.tracked_keys() == set()
16 changes: 11 additions & 5 deletions src/exo/worker/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
from exo.shared.types.events import (
Event,
IndexedEvent,
InstanceDeleted,
NodeDownloadProgress,
NodeGatheredInfo,
TaskCreated,
Expand Down Expand Up @@ -141,6 +140,7 @@ async def _bootstrap_then_run(
self._tg.start_soon(self._forward_info, info_recv)
self._tg.start_soon(self.plan_step)
self._tg.start_soon(self._event_applier)
self._tg.start_soon(self._reconcile_instance_backoff)
self._tg.start_soon(self._reconcile_custom_cards)
self._tg.start_soon(self._poll_connection_updates)

Expand Down Expand Up @@ -190,12 +190,18 @@ async def _event_applier(self):
continue
# 2. for each event, apply it to the state
self.state = apply(self.state, event=event)
event = event.event
self._sync_input_views_from_state()

if isinstance(event, InstanceDeleted):
self._instance_backoff.reset(event.instance_id)
async def _reconcile_instance_backoff(self) -> None:
    """Run the instance-backoff reconciliation pass in an endless loop.

    Each iteration first sleeps, then performs a single pass via
    ``_reconcile_instance_backoff_once``. The loop has no exit condition of
    its own.
    """
    poll_interval_seconds = 1
    while True:
        await anyio.sleep(poll_interval_seconds)
        self._reconcile_instance_backoff_once()

self._sync_input_views_from_state()
def _reconcile_instance_backoff_once(self) -> None:
    """Reset backoff state for tracked instances missing from the state.

    Every key reported by ``self._instance_backoff.tracked_keys()`` that is
    not a key of ``self.state.instances`` is passed to ``reset``; keys that
    are still present are left untouched.
    """
    live_ids = set(self.state.instances)
    stale_ids = self._instance_backoff.tracked_keys() - live_ids
    for stale_id in stale_ids:
        self._instance_backoff.reset(stale_id)

async def _reconcile_custom_cards(self) -> None:
while True:
Expand Down
36 changes: 36 additions & 0 deletions src/exo/worker/tests/unittests/test_worker_instance_backoff.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# pyright: reportPrivateUsage=false

from exo.shared.types.common import ModelId, NodeId
from exo.shared.types.state import State
from exo.shared.types.worker.instances import InstanceId, MlxRingInstance
from exo.shared.types.worker.runners import ShardAssignments
from exo.utils.keyed_backoff import KeyedBackoff
from exo.worker.main import Worker


def _make_instance(instance_id: InstanceId) -> MlxRingInstance:
    """Build an MlxRingInstance fixture carrying the given instance ID.

    The shard assignments have empty runner mappings, and the single host
    entry for "node-1" is an empty list.
    """
    empty_assignments = ShardAssignments(
        model_id=ModelId("test-model"),
        node_to_runner={},
        runner_to_shard={},
    )
    return MlxRingInstance(
        instance_id=instance_id,
        shard_assignments=empty_assignments,
        hosts_by_node={NodeId("node-1"): []},
        ephemeral_port=1,
    )


def test_worker_reconciles_instance_backoff_from_state() -> None:
    """One reconcile pass keeps backoff for live instances and clears the rest."""
    live_id = InstanceId("inst-live")
    gone_id = InstanceId("inst-deleted")

    # object.__new__ skips Worker.__init__; only the two attributes the
    # reconcile pass reads are assigned by hand.
    worker = object.__new__(Worker)
    worker.state = State(instances={live_id: _make_instance(live_id)})
    worker._instance_backoff = KeyedBackoff[InstanceId]()
    for tracked_id in (live_id, gone_id):
        worker._instance_backoff.record_attempt(tracked_id)

    worker._reconcile_instance_backoff_once()

    assert worker._instance_backoff.attempts(live_id) == 1
    assert worker._instance_backoff.attempts(gone_id) == 0
Loading