From c0d4fd7fa7a4770eff7877e19dd1c678bbae8a86 Mon Sep 17 00:00:00 2001
From: Alex Cheema
Date: Sun, 3 May 2026 02:01:08 +0100
Subject: [PATCH] feat: reconcile worker backoff from state

Stop resetting per-instance backoff inside the worker's event applier
when an InstanceDeleted event is seen. Instead, start a task that wakes
once per second and resets every tracked backoff key that is no longer
present in state.instances.

KeyedBackoff gains tracked_keys(), which returns the union of the keys
held in _attempts and _last_time, so the worker can enumerate the keys
it has to reconcile.

Tests cover tracked_keys() before and after reset(), and a single
reconciliation pass over one live and one deleted instance.
---
NOTE(review): this copy of the patch was rebuilt from a version whose
line breaks and indentation had been collapsed. Leading whitespace on
the context lines of src/exo/worker/main.py is a best guess -- confirm
against the tree, or apply with `git apply --ignore-whitespace`.

 src/exo/utils/keyed_backoff.py                |  4 +++
 src/exo/utils/tests/test_keyed_backoff.py     | 13 +++++++
 src/exo/worker/main.py                        | 16 ++++++---
 .../unittests/test_worker_instance_backoff.py | 36 +++++++++++++++++++
 4 files changed, 64 insertions(+), 5 deletions(-)
 create mode 100644 src/exo/utils/tests/test_keyed_backoff.py
 create mode 100644 src/exo/worker/tests/unittests/test_worker_instance_backoff.py

diff --git a/src/exo/utils/keyed_backoff.py b/src/exo/utils/keyed_backoff.py
index 4d7c9a66ed..a95fe5c5f7 100644
--- a/src/exo/utils/keyed_backoff.py
+++ b/src/exo/utils/keyed_backoff.py
@@ -29,6 +29,10 @@ def attempts(self, key: K) -> int:
         """Return the number of recorded attempts for a key."""
         return self._attempts.get(key, 0)
 
+    def tracked_keys(self) -> set[K]:
+        """Return keys that currently have recorded backoff state."""
+        return set(self._attempts) | set(self._last_time)
+
     def reset(self, key: K) -> None:
         """Reset backoff state for a key (e.g., on success)."""
         self._attempts.pop(key, None)
diff --git a/src/exo/utils/tests/test_keyed_backoff.py b/src/exo/utils/tests/test_keyed_backoff.py
new file mode 100644
index 0000000000..b592a4fabd
--- /dev/null
+++ b/src/exo/utils/tests/test_keyed_backoff.py
@@ -0,0 +1,13 @@
+from exo.utils.keyed_backoff import KeyedBackoff
+
+
+def test_tracked_keys_reports_and_resets_backoff_state() -> None:
+    backoff = KeyedBackoff[str]()
+
+    backoff.record_attempt("instance-a")
+
+    assert backoff.tracked_keys() == {"instance-a"}
+
+    backoff.reset("instance-a")
+
+    assert backoff.tracked_keys() == set()
diff --git a/src/exo/worker/main.py b/src/exo/worker/main.py
index 3e2aa94673..f63b7b5ec1 100644
--- a/src/exo/worker/main.py
+++ b/src/exo/worker/main.py
@@ -30,7 +30,6 @@
 from exo.shared.types.events import (
     Event,
     IndexedEvent,
-    InstanceDeleted,
     NodeDownloadProgress,
     NodeGatheredInfo,
     TaskCreated,
@@ -141,6 +140,7 @@ async def _bootstrap_then_run(
         self._tg.start_soon(self._forward_info, info_recv)
         self._tg.start_soon(self.plan_step)
         self._tg.start_soon(self._event_applier)
+        self._tg.start_soon(self._reconcile_instance_backoff)
         self._tg.start_soon(self._reconcile_custom_cards)
         self._tg.start_soon(self._poll_connection_updates)
 
@@ -190,12 +190,18 @@ async def _event_applier(self):
                     continue
                 # 2. for each event, apply it to the state
                 self.state = apply(self.state, event=event)
-                event = event.event
+                self._sync_input_views_from_state()
 
-                if isinstance(event, InstanceDeleted):
-                    self._instance_backoff.reset(event.instance_id)
+    async def _reconcile_instance_backoff(self) -> None:
+        while True:
+            await anyio.sleep(1)
+            self._reconcile_instance_backoff_once()
 
-                self._sync_input_views_from_state()
+    def _reconcile_instance_backoff_once(self) -> None:
+        live_instances = set(self.state.instances)
+        for instance_id in self._instance_backoff.tracked_keys():
+            if instance_id not in live_instances:
+                self._instance_backoff.reset(instance_id)
 
     async def _reconcile_custom_cards(self) -> None:
         while True:
diff --git a/src/exo/worker/tests/unittests/test_worker_instance_backoff.py b/src/exo/worker/tests/unittests/test_worker_instance_backoff.py
new file mode 100644
index 0000000000..b0052c1eb7
--- /dev/null
+++ b/src/exo/worker/tests/unittests/test_worker_instance_backoff.py
@@ -0,0 +1,36 @@
+# pyright: reportPrivateUsage=false
+
+from exo.shared.types.common import ModelId, NodeId
+from exo.shared.types.state import State
+from exo.shared.types.worker.instances import InstanceId, MlxRingInstance
+from exo.shared.types.worker.runners import ShardAssignments
+from exo.utils.keyed_backoff import KeyedBackoff
+from exo.worker.main import Worker
+
+
+def _make_instance(instance_id: InstanceId) -> MlxRingInstance:
+    return MlxRingInstance(
+        instance_id=instance_id,
+        shard_assignments=ShardAssignments(
+            model_id=ModelId("test-model"),
+            node_to_runner={},
+            runner_to_shard={},
+        ),
+        hosts_by_node={NodeId("node-1"): []},
+        ephemeral_port=1,
+    )
+
+
+def test_worker_reconciles_instance_backoff_from_state() -> None:
+    live_instance_id = InstanceId("inst-live")
+    deleted_instance_id = InstanceId("inst-deleted")
+    worker = object.__new__(Worker)
+    worker.state = State(instances={live_instance_id: _make_instance(live_instance_id)})
+    worker._instance_backoff = KeyedBackoff[InstanceId]()
+    worker._instance_backoff.record_attempt(live_instance_id)
+    worker._instance_backoff.record_attempt(deleted_instance_id)
+
+    worker._reconcile_instance_backoff_once()
+
+    assert worker._instance_backoff.attempts(live_instance_id) == 1
+    assert worker._instance_backoff.attempts(deleted_instance_id) == 0