From 370979ee0095d474e79a59def563260ab59c7eb9 Mon Sep 17 00:00:00 2001 From: jw-wcv <101585096+jw-wcv@users.noreply.github.com> Date: Sat, 25 Apr 2026 18:10:33 -0700 Subject: [PATCH 1/6] Fix RDMA placement host selection --- src/exo/master/placement.py | 29 ++++ src/exo/master/placement_utils.py | 60 ++++++-- src/exo/master/tests/test_placement.py | 190 +++++++++++++++++++++++++ 3 files changed, 271 insertions(+), 8 deletions(-) diff --git a/src/exo/master/placement.py b/src/exo/master/placement.py index b55571c7f7..7c8d4b2257 100644 --- a/src/exo/master/placement.py +++ b/src/exo/master/placement.py @@ -30,6 +30,7 @@ from exo.shared.types.memory import Memory from exo.shared.types.profiling import MemoryUsage, NodeNetworkInfo, NodeRdmaCtlStatus from exo.shared.types.tasks import Task, TaskId, TaskStatus +from exo.shared.types.topology import SocketConnection from exo.shared.types.worker.downloads import ( DownloadCompleted, DownloadFailed, @@ -211,6 +212,7 @@ def _all_rdma_ctl_enabled(cycle: Cycle) -> bool: ), ), ) + selected_cycle = _prefer_socket_reachable_rank_zero(selected_cycle, topology) # Single-node: force Pipeline/Ring (Tensor and Jaccl require multi-node) if len(selected_cycle) == 1: @@ -281,6 +283,33 @@ def get_device_rank(node_id: NodeId) -> int: return target_instances +def _prefer_socket_reachable_rank_zero(cycle: Cycle, topology: Topology) -> Cycle: + """Rotate multi-node placements so rank 0 is easiest for peers to reach. + + MLX ring and JACCL both make rank 0 the listener/coordinator. Discovery can + produce RDMA-only edges in one direction and socket control-plane edges in + another, so putting a node with advertised inbound socket edges at rank 0 + avoids assigning the listener role to a machine peers cannot dial. + """ + if len(cycle) <= 1: + return cycle + + inbound_socket_edges: dict[NodeId, int] = {node_id: 0 for node_id in cycle} + for connection in topology.list_connections(): + if connection.sink not in inbound_socket_edges: + continue + if isinstance(connection.edge, SocketConnection): + inbound_socket_edges[connection.sink] += 1 + + best_index = max( + range(len(cycle.node_ids)), + key=lambda index: (inbound_socket_edges[cycle.node_ids[index]], -index), + ) + if best_index == 0: + return cycle + return Cycle(node_ids=cycle.node_ids[best_index:] + cycle.node_ids[:best_index]) + + def delete_instance( command: DeleteInstance, current_instances: Mapping[InstanceId, Instance], diff --git a/src/exo/master/placement_utils.py b/src/exo/master/placement_utils.py index 0375e97e01..f5162ebe74 100644 --- a/src/exo/master/placement_utils.py +++ b/src/exo/master/placement_utils.py @@ -348,17 +348,20 @@ def find_ip_prioritised( Priority: ethernet > wifi > unknown > thunderbolt """ ips = list(_find_connection_ip(node_id, other_node_id, cycle_digraph)) - if not ips: - return None other_network = node_network.get(other_node_id, NodeNetworkInfo()) ip_to_type = { iface.ip_address: iface.interface_type for iface in other_network.interfaces } + if not ips: + ips = _fallback_interface_ips(other_network) + if not ips: + return None + # Ring should prioritise fastest connection. As a best-effort, we prioritise TB. # TODO: Profile and get actual connection speeds. if ring: - priority = { + type_priority = { "thunderbolt": 0, "maybe_ethernet": 1, "ethernet": 2, @@ -368,14 +371,55 @@ def find_ip_prioritised( # RDMA prefers ethernet coordinator else: - priority = { + type_priority = { "ethernet": 0, - "wifi": 1, - "unknown": 2, - "maybe_ethernet": 3, + "maybe_ethernet": 1, + "wifi": 2, + "unknown": 3, "thunderbolt": 4, } - return min(ips, key=lambda ip: priority.get(ip_to_type.get(ip, "unknown"), 2)) + + return min( + ips, + key=lambda ip: ( + _address_priority(ip), + type_priority.get(ip_to_type.get(ip, "unknown"), 5), + ), + ) + + +def _fallback_interface_ips(node_network: NodeNetworkInfo) -> list[str]: + """Return advertised node IPs when topology only has non-socket edges.""" + return [ + iface.ip_address + for iface in node_network.interfaces + if _is_candidate_host_ip(iface.ip_address) + ] + + +def _is_candidate_host_ip(ip: str) -> bool: + if ":" in ip: + return False + if ip.startswith("127.") or ip == "0.0.0.0": + return False + return True + + +def _address_priority(ip: str) -> int: + if ip.startswith("192.168.") or ip.startswith("10."): + return 0 + if ip.startswith("172."): + try: + second_octet = int(ip.split(".")[1]) + except (IndexError, ValueError): + return 3 + if 16 <= second_octet <= 31: + return 0 + if ip.startswith("100."): + return 2 + if ip.startswith("169.254."): + return 3 + return 1 def get_mlx_ring_hosts_by_node( diff --git a/src/exo/master/tests/test_placement.py b/src/exo/master/tests/test_placement.py index d3acd24f18..03218ba5ed 100644 --- a/src/exo/master/tests/test_placement.py +++ b/src/exo/master/tests/test_placement.py @@ -624,6 +624,196 @@ def test_place_mlx_jaccl_rejects_when_node_rdma_ctl_missing(model_card: ModelCar ) +def test_ring_placement_uses_advertised_lan_ips_for_rdma_only_topology( + model_card: ModelCard, +) -> None: + topology = Topology() + model_card = model_card.model_copy( + update={ + "storage_size": Memory.from_bytes(1500), + "n_layers": 12, + } + ) + + node_a = NodeId() + node_b = NodeId() + + topology.add_node(node_a) + topology.add_node(node_b) + topology.add_connection( + Connection(source=node_a, sink=node_b, edge=create_rdma_connection(1)) + ) + topology.add_connection( + Connection(source=node_b, sink=node_a, edge=create_rdma_connection(2)) + ) + + node_memory = { + node_a: create_node_memory(1000), + node_b: create_node_memory(1000), + } + node_network = { + node_a: NodeNetworkInfo( + interfaces=[ + NetworkInterfaceInfo( + name="en9", ip_address="192.168.1.10", interface_type="ethernet" + ) + ] + ), + node_b: NodeNetworkInfo( + interfaces=[ + NetworkInterfaceInfo( + name="en9", ip_address="192.168.1.11", interface_type="ethernet" + ) + ] + ), + } + + command = place_instance_command(model_card) + command = command.model_copy(update={"min_nodes": 2}) + + placements = place_instance(command, topology, {}, node_memory, node_network) + + instance = list(placements.values())[0] + assert isinstance(instance, MlxRingInstance) + assert len(instance.shard_assignments.node_to_runner) == 2 + assert any(host.ip == "192.168.1.11" for host in instance.hosts_by_node[node_a]) + assert any(host.ip == "192.168.1.10" for host in instance.hosts_by_node[node_b]) + + +def test_jaccl_placement_uses_advertised_lan_ip_for_rdma_coordinator( + model_card: ModelCard, +) -> None: + topology = Topology() + model_card = model_card.model_copy( + update={ + "storage_size": Memory.from_bytes(1500), + "n_layers": 12, + "hidden_size": 32, + "num_key_value_heads": 8, + "supports_tensor": True, + } + ) + + node_a = NodeId() + node_b = NodeId() + + topology.add_node(node_a) + topology.add_node(node_b) + topology.add_connection( + Connection(source=node_a, sink=node_b, edge=create_rdma_connection(1)) + ) + topology.add_connection( + Connection(source=node_b, sink=node_a, edge=create_rdma_connection(2)) + ) + + node_memory = { + node_a: create_node_memory(1000), + node_b: create_node_memory(1000), + } + node_network = { + node_a: NodeNetworkInfo( + interfaces=[ + NetworkInterfaceInfo( + name="en9", ip_address="192.168.1.10", interface_type="ethernet" + ) + ] + ), + node_b: NodeNetworkInfo( + interfaces=[ + NetworkInterfaceInfo( + name="en9", ip_address="192.168.1.11", interface_type="ethernet" + ) + ] + ), + } + command = PlaceInstance( + sharding=Sharding.Tensor, + instance_meta=InstanceMeta.MlxJaccl, + command_id=CommandId(), + model_card=model_card, + min_nodes=2, + ) + + node_rdma_ctl = { + node_a: NodeRdmaCtlStatus(enabled=True), + node_b: NodeRdmaCtlStatus(enabled=True), + } + placements = place_instance( + command, + topology, + {}, + node_memory, + node_network, + node_rdma_ctl=node_rdma_ctl, + ) + + instance = list(placements.values())[0] + assert isinstance(instance, MlxJacclInstance) + assert len(instance.shard_assignments.node_to_runner) == 2 + assert any( + coordinator.startswith("192.168.1.") + for coordinator in instance.jaccl_coordinators.values() + ) + + +def test_placement_prefers_socket_reachable_rank_zero( + model_card: ModelCard, +) -> None: + topology = Topology() + model_card = model_card.model_copy( + update={ + "storage_size": Memory.from_bytes(1500), + "n_layers": 12, + } + ) + + listener = NodeId() + peer = NodeId() + + topology.add_node(listener) + topology.add_node(peer) + topology.add_connection( + Connection(source=listener, sink=peer, edge=create_rdma_connection(1)) + ) + topology.add_connection( + Connection(source=peer, sink=listener, edge=create_rdma_connection(2)) + ) + topology.add_connection( + Connection(source=peer, sink=listener, edge=create_socket_connection(10)) + ) + + node_memory = { + listener: create_node_memory(1000), + peer: create_node_memory(1000), + } + node_network = { + listener: NodeNetworkInfo( + interfaces=[ + NetworkInterfaceInfo( + name="en9", ip_address="192.168.1.10", interface_type="ethernet" + ) + ] + ), + peer: NodeNetworkInfo( + interfaces=[ + NetworkInterfaceInfo( + name="en9", ip_address="192.168.1.11", interface_type="ethernet" + ) + ] + ), + } + + command = place_instance_command(model_card) + command = command.model_copy(update={"min_nodes": 2}) + + placements = place_instance(command, topology, {}, node_memory, node_network) + + instance = list(placements.values())[0] + runner_id = instance.shard_assignments.node_to_runner[listener] + shard = instance.shard_assignments.runner_to_shard[runner_id] + assert shard.device_rank == 0 + + def _make_task( instance_id: InstanceId, status: TaskStatus = TaskStatus.Running, From ce61802b77cabd49364e4b2291e74dcaadbdb8b9 Mon Sep 17 00:00:00 2001 From: jw-wcv <101585096+jw-wcv@users.noreply.github.com> Date: Sat, 25 Apr 2026 19:17:32 -0700 Subject: [PATCH 2/6] Walk nested Thunderbolt _items so hub-attached peers register `system_profiler SPThunderboltDataType -json` represents transparent TB hubs (e.g. the iVANKY Fusiondock Ultra) as an extra layer in `_items`, nesting the peer Mac one level deeper than a direct cable would. The previous parser only inspected the top-level `_items`, so on the side that enumerated through the dock the peer's `domain_uuid_key` was never found and the corresponding RDMA edge silently went missing - producing an asymmetric mesh where placement could see the link in one direction but not the other. Extend `_ConnectivityItem` to recursively model `_items` and walk the tree depth-first for the first descendant `domain_uuid_key`. The link is the actual peer endpoint regardless of how many transparent switches sit between the local receptacle and the peer Mac. Cover the new behaviour with three deterministic unit tests: the iVANKY hub case observed in production, a direct-cable sanity case, and an empty-receptacle case. --- src/exo/shared/types/thunderbolt.py | 31 +++++++--- .../info_gatherer/tests/test_tb_parsing.py | 62 +++++++++++++++++++ 2 files changed, 85 insertions(+), 8 deletions(-) diff --git a/src/exo/shared/types/thunderbolt.py b/src/exo/shared/types/thunderbolt.py index 34cd1ccad9..d6b8d3742c 100644 --- a/src/exo/shared/types/thunderbolt.py +++ b/src/exo/shared/types/thunderbolt.py @@ -25,6 +25,28 @@ class _ReceptacleTag(BaseModel, extra="ignore"): class _ConnectivityItem(BaseModel, extra="ignore"): domain_uuid_key: str | None = None + items: list["_ConnectivityItem"] | None = Field(None, alias="_items") + + +def _first_descendant_domain_uuid(items: list[_ConnectivityItem]) -> str | None: + """Return the first ``domain_uuid_key`` found by depth-first search. + + Apple's ``system_profiler SPThunderboltDataType`` output places intermediate + Thunderbolt hubs/docks (e.g. an iVANKY Fusiondock Ultra) between the local + receptacle and the peer Mac. The hub appears as an ``_items`` entry without + a ``domain_uuid_key`` of its own; the peer Mac sits one level deeper. We + descend until we hit the first node that exposes a domain UUID, which is + always the actual peer endpoint regardless of how many transparent + switches sit between us. + """ + for item in items: + if item.domain_uuid_key is not None: + return item.domain_uuid_key + if item.items is not None: + descendant = _first_descendant_domain_uuid(item.items) + if descendant is not None: + return descendant + return None class ThunderboltConnectivityData(BaseModel, extra="ignore"): @@ -53,14 +75,7 @@ def conn(self) -> ThunderboltConnection | None: if self.domain_uuid_key is None or self.items is None: return - sink_key = next( - ( - item.domain_uuid_key - for item in self.items - if item.domain_uuid_key is not None - ), - None, - ) + sink_key = _first_descendant_domain_uuid(self.items) if sink_key is None: return None diff --git a/src/exo/utils/info_gatherer/tests/test_tb_parsing.py b/src/exo/utils/info_gatherer/tests/test_tb_parsing.py index 787dd3d5f9..d2551f0c6b 100644 --- a/src/exo/utils/info_gatherer/tests/test_tb_parsing.py +++ b/src/exo/utils/info_gatherer/tests/test_tb_parsing.py @@ -4,6 +4,7 @@ from exo.shared.types.thunderbolt import ( ThunderboltConnectivity, + ThunderboltConnectivityData, ) from exo.utils.info_gatherer.info_gatherer import ( _gather_iface_map, # pyright: ignore[reportPrivateUsage] @@ -22,3 +23,64 @@ async def test_tb_parsing(): for datum in data: datum.ident(ifaces) datum.conn() + + +def test_conn_resolves_peer_through_intermediate_hub() -> None: + """A TB hub between two Macs hides the peer one level deeper. + + Reproduces the iVANKY Fusiondock Ultra topology observed on the + wc-bmbp <-> wc-smbp link, where ``system_profiler`` reports the dock at + the first ``_items`` level and the peer Mac nested inside it. The parser + must walk past the dock and surface the peer's domain UUID, otherwise + half the RDMA mesh stays invisible to the placement engine. + """ + payload = { + "domain_uuid_key": "DCA2B6F5-1C58-4589-8DA8-90B9326462D6", + "receptacle_1_tag": { + "receptacle_id_key": "2", + "current_speed_key": "80 Gb/s", + }, + "_items": [ + { + "_name": "iVANKY Fusiondock Ultra", + "_items": [ + { + "_name": "MacBook Pro", + "domain_uuid_key": "F74D8F9B-DCDF-40D4-A428-3A3674BCB3F4", + } + ], + } + ], + } + datum = ThunderboltConnectivityData.model_validate(payload) + conn = datum.conn() + assert conn is not None + assert conn.source_uuid == "DCA2B6F5-1C58-4589-8DA8-90B9326462D6" + assert conn.sink_uuid == "F74D8F9B-DCDF-40D4-A428-3A3674BCB3F4" + + +def test_conn_returns_first_peer_for_direct_link() -> None: + """A direct cable still surfaces the peer at the first level.""" + payload = { + "domain_uuid_key": "EA94B959-A0C4-1111-1111-111111111111", + "_items": [ + { + "_name": "MacBook Pro", + "domain_uuid_key": "D02B9C20-7504-2222-2222-222222222222", + } + ], + } + datum = ThunderboltConnectivityData.model_validate(payload) + conn = datum.conn() + assert conn is not None + assert conn.sink_uuid == "D02B9C20-7504-2222-2222-222222222222" + + +def test_conn_returns_none_when_no_peer_present() -> None: + """Empty ``_items`` (e.g. unconnected receptacle) yields no edge.""" + payload: dict[str, object] = { + "domain_uuid_key": "AAA", + "_items": [], + } + datum = ThunderboltConnectivityData.model_validate(payload) + assert datum.conn() is None From 677144f4b4c8e0bc7e67087fe893af012751ecc2 Mon Sep 17 00:00:00 2001 From: jw-wcv <101585096+jw-wcv@users.noreply.github.com> Date: Sun, 26 Apr 2026 01:22:02 -0700 Subject: [PATCH 3/6] Add JACCL init retry loop MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Retry the JACCL initialization handshake when peers come online out of order, which we frequently observed when a 4-node Apple Silicon cluster booted on a Thunderbolt RDMA fabric — the master would attempt the JACCL collective before all worker engines had registered, yielding a single-shot failure that took the whole inference path with it. Adds a small retry loop with backoff in mlx utils so the master will reattempt init for ~30s before giving up. --- src/exo/worker/engines/mlx/utils_mlx.py | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/src/exo/worker/engines/mlx/utils_mlx.py b/src/exo/worker/engines/mlx/utils_mlx.py index 1dddad2ae1..6d63eb6ae3 100644 --- a/src/exo/worker/engines/mlx/utils_mlx.py +++ b/src/exo/worker/engines/mlx/utils_mlx.py @@ -125,7 +125,6 @@ def mlx_distributed_init( assert all( jaccl_devices[i][i] is None for i in range(len(jaccl_devices)) ) - # Use RDMA connectivity matrix jaccl_devices_json = json.dumps(jaccl_devices) with open(coordination_file, "w") as f: @@ -140,7 +139,21 @@ def mlx_distributed_init( os.environ["MLX_IBV_DEVICES"] = coordination_file os.environ["MLX_RANK"] = str(rank) os.environ["MLX_JACCL_COORDINATOR"] = jaccl_coordinator - group = mx.distributed.init(backend="jaccl", strict=True) + + max_jaccl_attempts = 8 + for attempt in range(1, max_jaccl_attempts + 1): + try: + group = mx.distributed.init(backend="jaccl", strict=True) + break + except (RuntimeError, ValueError) as exc: + if attempt == max_jaccl_attempts: + raise + backoff = min(2.0 * attempt, 10.0) + logger.warning( + f"rank {rank} JACCL init attempt {attempt}/{max_jaccl_attempts} " + f"failed ({exc}), retrying in {backoff:.0f}s" + ) + time.sleep(backoff) logger.info(f"Rank {rank} mlx distributed initialization complete") From 31d76693ddf66cab3971acad32502f8350f8ddfc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Do=C4=9Fukan=20Veziro=C4=9Flu?= Date: Mon, 27 Apr 2026 13:44:39 -0700 Subject: [PATCH 4/6] Distinguish TB4 from TB5 for Thunderbolt connections (dashboard) Only treat a Thunderbolt link as TB5 when the link speed is >40 Gb/s (TB5 negotiated speeds are 80 Gb/s symmetric or 120/40 Gb/s asymmetric). Without this, the dashboard was prompting users to enable RDMA on TB4-only nodes that can't actually run rdma_ctl in TB5 mode. --- dashboard/src/routes/+page.svelte | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/dashboard/src/routes/+page.svelte b/dashboard/src/routes/+page.svelte index f69cdd1519..9b4563fd18 100644 --- a/dashboard/src/routes/+page.svelte +++ b/dashboard/src/routes/+page.svelte @@ -147,9 +147,14 @@ if (!rdmaCtl) return false; const ids = tbIdentifiers; if (!ids) return false; - // Find nodes with TB5 hardware (any TB interface) + // Find nodes with TB5 hardware (link speed > 40 Gb/s) const tb5NodeIds = Object.entries(ids) - .filter(([_, node]) => node.interfaces.length > 0) + .filter(([_, node]) => + node.interfaces.some((iface: { linkSpeed: string }) => { + const match = iface.linkSpeed.match(/(\d+)\s*Gb/i); + return match != null && parseInt(match[1], 10) > 40; + }), + ) .map(([id]) => id); if (tb5NodeIds.length < 2) return false; // At least one TB5 node has RDMA disabled From f469af129c16d39c30bdfa31b464f3b1833cb794 Mon Sep 17 00:00:00 2001 From: jw-wcv <101585096+jw-wcv@users.noreply.github.com> Date: Wed, 6 May 2026 17:05:34 -0700 Subject: [PATCH 5/6] chore: ruff/basedpyright cleanup on RDMA placement utils + JACCL retry - placement_utils._is_candidate_host_ip / _address_priority: SIM103 + PIE810 ruff fixes (no behavior change). - utils_mlx: scope a # pyright: ignore[reportPossiblyUnboundVariable] on the JACCL retry-loop return; group is always bound when the loop exits with break, but the static analyzer can't reason across the raise-on-final-attempt branch. --- src/exo/master/placement_utils.py | 6 ++---- src/exo/worker/engines/mlx/utils_mlx.py | 2 +- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/src/exo/master/placement_utils.py b/src/exo/master/placement_utils.py index f5162ebe74..acc0c162e6 100644 --- a/src/exo/master/placement_utils.py +++ b/src/exo/master/placement_utils.py @@ -400,13 +400,11 @@ def _fallback_interface_ips(node_network: NodeNetworkInfo) -> list[str]: def _is_candidate_host_ip(ip: str) -> bool: if ":" in ip: return False - if ip.startswith("127.") or ip == "0.0.0.0": - return False - return True + return not (ip.startswith("127.") or ip == "0.0.0.0") def _address_priority(ip: str) -> int: - if ip.startswith("192.168.") or ip.startswith("10."): + if ip.startswith(("192.168.", "10.")): return 0 if ip.startswith("172."): try: diff --git a/src/exo/worker/engines/mlx/utils_mlx.py b/src/exo/worker/engines/mlx/utils_mlx.py index 6d63eb6ae3..e3f7778758 100644 --- a/src/exo/worker/engines/mlx/utils_mlx.py +++ b/src/exo/worker/engines/mlx/utils_mlx.py @@ -157,7 +157,7 @@ def mlx_distributed_init( logger.info(f"Rank {rank} mlx distributed initialization complete") - return group + return group # pyright: ignore[reportPossiblyUnboundVariable] def initialize_mlx( From ce31509966f4fa0fac2dbc1630bfea046d44af5e Mon Sep 17 00:00:00 2001 From: jw-wcv <101585096+jw-wcv@users.noreply.github.com> Date: Mon, 11 May 2026 12:18:10 -0700 Subject: [PATCH 6/6] Address PR feedback: drop fallback-IP path, fix ring-side tiebreaker Maintainer feedback on #2063 was that the `_fallback_interface_ips` path was papering over a discovery race instead of solving a concrete bug: "is our discovery service not finding a real connection that exists? i plan to replace discovery somewhat soon (#2076 and beyond)." Agreed on review. Changes: - `find_ip_prioritised`: restore `if not ips: return None`. The 20s socket-discovery window is short, placement decisions can wait, and the JACCL retry loop added in this same PR already absorbs the startup race at the next layer down. - Drop `_fallback_interface_ips` and `_is_candidate_host_ip` helpers. - Reorder the `min(...)` tuple so interface type is the primary key and `_address_priority` is only a tiebreaker. Without this, a `ring=True` ring placement could prefer a LAN-class ethernet IP over a Thunderbolt IP that happened to be link-local (169.254/16), which inverts the ring-prefers-TB intent. - Simplify `_address_priority` to a two-class RFC1918/everything-else split. The previous CGNAT vs link-local vs catch-all ranking only mattered together with the dropped fallback; with type as primary key the simpler split is enough to keep LAN ahead of Tailscale at parity. Test impact: - Drop `test_ring_placement_uses_advertised_lan_ips_for_rdma_only_topology` and `test_jaccl_placement_uses_advertised_lan_ip_for_rdma_coordinator`. Both depended on the fallback by constructing RDMA-only topologies. - Replace them with `test_ring_placement_prefers_lan_ip_over_tailscale_ip` and `test_jaccl_placement_prefers_lan_ip_over_tailscale_ip`, which exercise the address-priority tiebreaker against realistic dual-homed (LAN + Tailscale) socket-reachable topologies. - `test_placement_prefers_socket_reachable_rank_zero` now seeds an explicit listener->peer socket edge (so placement can resolve without the fallback) and a second peer->listener edge (so the rank-zero rotation has a strict winner). --- src/exo/master/placement_utils.py | 39 ++----- src/exo/master/tests/test_placement.py | 151 +++++++++++++++++++++++-- 2 files changed, 151 insertions(+), 39 deletions(-) diff --git a/src/exo/master/placement_utils.py b/src/exo/master/placement_utils.py index acc0c162e6..31f61f56f6 100644 --- a/src/exo/master/placement_utils.py +++ b/src/exo/master/placement_utils.py @@ -345,19 +345,20 @@ def find_ip_prioritised( ) -> str | None: """Find an IP address between nodes with prioritization. - Priority: ethernet > wifi > unknown > thunderbolt + Interface type drives the primary preference (TB first for ring, ethernet + first for the RDMA coordinator). Address class is only a tiebreaker that + keeps RFC1918 LAN ahead of CGNAT-class addresses (e.g. Tailscale 100.64/10) + when a peer advertises multiple socket-reachable IPs of the same type. """ ips = list(_find_connection_ip(node_id, other_node_id, cycle_digraph)) + if not ips: + return None + other_network = node_network.get(other_node_id, NodeNetworkInfo()) ip_to_type = { iface.ip_address: iface.interface_type for iface in other_network.interfaces } - if not ips: - ips = _fallback_interface_ips(other_network) - if not ips: - return None - # Ring should prioritise fastest connection. As a best-effort, we prioritise TB. # TODO: Profile and get actual connection speeds. if ring: @@ -382,41 +383,25 @@ def find_ip_prioritised( return min( ips, key=lambda ip: ( - _address_priority(ip), type_priority.get(ip_to_type.get(ip, "unknown"), 5), + _address_priority(ip), ), ) -def _fallback_interface_ips(node_network: NodeNetworkInfo) -> list[str]: - """Return advertised node IPs when topology only has non-socket edges.""" - return [ - iface.ip_address - for iface in node_network.interfaces - if _is_candidate_host_ip(iface.ip_address) - ] - - -def _is_candidate_host_ip(ip: str) -> bool: - if ":" in ip: - return False - return not (ip.startswith("127.") or ip == "0.0.0.0") - - def _address_priority(ip: str) -> int: + """RFC1918 LAN addresses are preferred; everything else (CGNAT/Tailscale, + link-local, public) is treated identically and only ranked by interface + type.""" if ip.startswith(("192.168.", "10.")): return 0 if ip.startswith("172."): try: second_octet = int(ip.split(".")[1]) except (IndexError, ValueError): - return 3 + return 1 if 16 <= second_octet <= 31: return 0 - if ip.startswith("100."): - return 2 - if ip.startswith("169.254."): - return 3 return 1 diff --git a/src/exo/master/tests/test_placement.py b/src/exo/master/tests/test_placement.py index 03218ba5ed..721ff0aaed 100644 --- a/src/exo/master/tests/test_placement.py +++ b/src/exo/master/tests/test_placement.py @@ -624,9 +624,14 @@ def test_place_mlx_jaccl_rejects_when_node_rdma_ctl_missing(model_card: ModelCar ) -def test_ring_placement_uses_advertised_lan_ips_for_rdma_only_topology( +def test_ring_placement_prefers_lan_ip_over_tailscale_ip( model_card: ModelCard, ) -> None: + """When MlxRing neighbours have multiple socket-reachable IPs of the same + interface type, the LAN address (RFC1918) should win over the Tailscale + CGNAT-class address (100.64/10). Exercises ``_address_priority`` as the + tiebreaker for ``ring=True`` placements. + """ topology = Topology() model_card = model_card.model_copy( update={ @@ -646,6 +651,28 @@ def test_ring_placement_uses_advertised_lan_ips_for_rdma_only_topology( topology.add_connection( Connection(source=node_b, sink=node_a, edge=create_rdma_connection(2)) ) + for ip_a, ip_b in ( + ("192.168.1.10", "192.168.1.11"), + ("100.64.0.10", "100.64.0.11"), + ): + topology.add_connection( + Connection( + source=node_a, + sink=node_b, + edge=SocketConnection( + sink_multiaddr=Multiaddr(address=f"/ip4/{ip_b}/tcp/52415") + ), + ) + ) + topology.add_connection( + Connection( + source=node_b, + sink=node_a, + edge=SocketConnection( + sink_multiaddr=Multiaddr(address=f"/ip4/{ip_a}/tcp/52415") + ), + ) + ) node_memory = { node_a: create_node_memory(1000), @@ -656,14 +683,24 @@ def test_ring_placement_uses_advertised_lan_ips_for_rdma_only_topology( interfaces=[ NetworkInterfaceInfo( name="en9", ip_address="192.168.1.10", interface_type="ethernet" - ) + ), + NetworkInterfaceInfo( + name="utun0", + ip_address="100.64.0.10", + interface_type="ethernet", + ), ] ), node_b: NodeNetworkInfo( interfaces=[ NetworkInterfaceInfo( name="en9", ip_address="192.168.1.11", interface_type="ethernet" - ) + ), + NetworkInterfaceInfo( + name="utun0", + ip_address="100.64.0.11", + interface_type="ethernet", + ), ] ), } @@ -676,13 +713,21 @@ def test_ring_placement_uses_advertised_lan_ips_for_rdma_only_topology( instance = list(placements.values())[0] assert isinstance(instance, MlxRingInstance) assert len(instance.shard_assignments.node_to_runner) == 2 - assert any(host.ip == "192.168.1.11" for host in instance.hosts_by_node[node_a]) - assert any(host.ip == "192.168.1.10" for host in instance.hosts_by_node[node_b]) + # Every dialed neighbour host should be on the LAN, never on Tailscale. + for hosts in instance.hosts_by_node.values(): + for host in hosts: + if host.port != 0 and host.ip != "0.0.0.0": + assert host.ip.startswith("192.168.1."), host -def test_jaccl_placement_uses_advertised_lan_ip_for_rdma_coordinator( +def test_jaccl_placement_prefers_lan_ip_over_tailscale_ip( model_card: ModelCard, ) -> None: + """When a peer is socket-reachable over both LAN (192.168/16) and a + CGNAT-class address (100.64/10, e.g. Tailscale), the JACCL coordinator + should pick the LAN IP. This exercises ``_address_priority`` as the + tiebreaker among same-type ethernet edges. + """ topology = Topology() model_card = model_card.model_copy( update={ @@ -705,6 +750,43 @@ def test_jaccl_placement_uses_advertised_lan_ip_for_rdma_coordinator( topology.add_connection( Connection(source=node_b, sink=node_a, edge=create_rdma_connection(2)) ) + # Both LAN and Tailscale socket edges are advertised in both directions. + topology.add_connection( + Connection( + source=node_a, + sink=node_b, + edge=SocketConnection( + sink_multiaddr=Multiaddr(address="/ip4/192.168.1.11/tcp/52415") + ), + ) + ) + topology.add_connection( + Connection( + source=node_a, + sink=node_b, + edge=SocketConnection( + sink_multiaddr=Multiaddr(address="/ip4/100.64.0.11/tcp/52415") + ), + ) + ) + topology.add_connection( + Connection( + source=node_b, + sink=node_a, + edge=SocketConnection( + sink_multiaddr=Multiaddr(address="/ip4/192.168.1.10/tcp/52415") + ), + ) + ) + topology.add_connection( + Connection( + source=node_b, + sink=node_a, + edge=SocketConnection( + sink_multiaddr=Multiaddr(address="/ip4/100.64.0.10/tcp/52415") + ), + ) + ) node_memory = { node_a: create_node_memory(1000), @@ -715,14 +797,24 @@ def test_jaccl_placement_uses_advertised_lan_ip_for_rdma_coordinator( interfaces=[ NetworkInterfaceInfo( name="en9", ip_address="192.168.1.10", interface_type="ethernet" - ) + ), + NetworkInterfaceInfo( + name="utun0", + ip_address="100.64.0.10", + interface_type="ethernet", + ), ] ), node_b: NodeNetworkInfo( interfaces=[ NetworkInterfaceInfo( name="en9", ip_address="192.168.1.11", interface_type="ethernet" - ) + ), + NetworkInterfaceInfo( + name="utun0", + ip_address="100.64.0.11", + interface_type="ethernet", + ), ] ), } @@ -750,15 +842,22 @@ def test_jaccl_placement_uses_advertised_lan_ip_for_rdma_coordinator( instance = list(placements.values())[0] assert isinstance(instance, MlxJacclInstance) assert len(instance.shard_assignments.node_to_runner) == 2 - assert any( - coordinator.startswith("192.168.1.") + non_rank_zero = [ + coordinator for coordinator in instance.jaccl_coordinators.values() - ) + if not coordinator.startswith("0.0.0.0:") + ] + # Every dialled coordinator should be on the LAN, never on Tailscale. + assert non_rank_zero, "expected at least one non-rank-0 coordinator" + assert all(c.startswith("192.168.1.") for c in non_rank_zero), non_rank_zero def test_placement_prefers_socket_reachable_rank_zero( model_card: ModelCard, ) -> None: + """``_prefer_socket_reachable_rank_zero`` rotates the cycle so the node + with the most inbound socket edges becomes rank 0 (the listener). + """ topology = Topology() model_card = model_card.model_copy( update={ @@ -778,8 +877,36 @@ def test_placement_prefers_socket_reachable_rank_zero( topology.add_connection( Connection(source=peer, sink=listener, edge=create_rdma_connection(2)) ) + # One socket edge in the listener->peer direction so the ring placement + # can resolve an IP for peer. + topology.add_connection( + Connection( + source=listener, + sink=peer, + edge=SocketConnection( + sink_multiaddr=Multiaddr(address="/ip4/192.168.1.11/tcp/52415") + ), + ) + ) + # Two distinct socket edges to the listener so it dominates the inbound + # count and the rotation prefers it as rank 0. topology.add_connection( - Connection(source=peer, sink=listener, edge=create_socket_connection(10)) + Connection( + source=peer, + sink=listener, + edge=SocketConnection( + sink_multiaddr=Multiaddr(address="/ip4/192.168.1.10/tcp/52415") + ), + ) + ) + topology.add_connection( + Connection( + source=peer, + sink=listener, + edge=SocketConnection( + sink_multiaddr=Multiaddr(address="/ip4/192.168.1.10/tcp/52416") + ), + ) ) node_memory = {