Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions dashboard/src/routes/+page.svelte
Original file line number Diff line number Diff line change
Expand Up @@ -147,9 +147,14 @@
if (!rdmaCtl) return false;
const ids = tbIdentifiers;
if (!ids) return false;
// Find nodes with TB5 hardware (any TB interface)
// Find nodes with TB5 hardware (link speed > 40 Gb/s)
const tb5NodeIds = Object.entries(ids)
.filter(([_, node]) => node.interfaces.length > 0)
.filter(([_, node]) =>
node.interfaces.some((iface: { linkSpeed: string }) => {
const match = iface.linkSpeed.match(/(\d+)\s*Gb/i);
return match != null && parseInt(match[1], 10) > 40;
}),
)
.map(([id]) => id);
if (tb5NodeIds.length < 2) return false;
// At least one TB5 node has RDMA disabled
Expand Down
29 changes: 29 additions & 0 deletions src/exo/master/placement.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from exo.shared.types.memory import Memory
from exo.shared.types.profiling import MemoryUsage, NodeNetworkInfo, NodeRdmaCtlStatus
from exo.shared.types.tasks import Task, TaskId, TaskStatus
from exo.shared.types.topology import SocketConnection
from exo.shared.types.worker.downloads import (
DownloadCompleted,
DownloadFailed,
Expand Down Expand Up @@ -241,6 +242,7 @@ def _all_rdma_ctl_enabled(cycle: Cycle) -> bool:
),
),
)
selected_cycle = _prefer_socket_reachable_rank_zero(selected_cycle, topology)

# Single-node: force Pipeline/Ring (Tensor and Jaccl require multi-node)
if len(selected_cycle) == 1:
Expand Down Expand Up @@ -311,6 +313,33 @@ def get_device_rank(node_id: NodeId) -> int:
return target_instances


def _prefer_socket_reachable_rank_zero(cycle: Cycle, topology: Topology) -> Cycle:
    """Rotate a multi-node cycle so its most socket-reachable node comes first.

    Rank 0 acts as the listener/coordinator for both MLX ring and JACCL, while
    discovery may report RDMA-only edges in one direction and socket
    control-plane edges in the other. Leading the cycle with the node that has
    the most advertised inbound socket edges keeps the listener role away from
    a machine its peers cannot dial.

    Args:
        cycle: Candidate placement; its first node is the one that ends up in
            the rank-0 position.
        topology: Provides the connections that are inspected.

    Returns:
        ``cycle`` itself when it has at most one node or when its first node
        already has the highest inbound socket count; otherwise a new
        ``Cycle`` holding the same nodes rotated so that the winner is first.
        When several nodes share the highest count, the one earliest in the
        current order wins.
    """
    if len(cycle) <= 1:
        return cycle

    # Per cycle member, tally the socket connections that terminate at it.
    # Connections whose sink is outside the cycle are ignored.
    socket_in_degree: dict[NodeId, int] = dict.fromkeys(cycle, 0)
    for conn in topology.list_connections():
        target = conn.sink
        if target in socket_in_degree and isinstance(conn.edge, SocketConnection):
            socket_in_degree[target] += 1

    order = cycle.node_ids
    # Minimising (-count, position) selects the largest count and, on a tie,
    # the smallest position.
    leader = min(
        range(len(order)),
        key=lambda pos: (-socket_in_degree[order[pos]], pos),
    )
    if leader == 0:
        return cycle
    return Cycle(node_ids=order[leader:] + order[:leader])


def delete_instance(
command: DeleteInstance,
current_instances: Mapping[InstanceId, Instance],
Expand Down
41 changes: 34 additions & 7 deletions src/exo/master/placement_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,11 +345,15 @@ def find_ip_prioritised(
) -> str | None:
"""Find an IP address between nodes with prioritization.

Priority: ethernet > wifi > unknown > thunderbolt
Interface type drives the primary preference (TB first for ring, ethernet
first for the RDMA coordinator). Address class is only a tiebreaker that
keeps RFC1918 LAN ahead of CGNAT-class addresses (e.g. Tailscale 100.64/10)
when a peer advertises multiple socket-reachable IPs of the same type.
"""
ips = list(_find_connection_ip(node_id, other_node_id, cycle_digraph))
if not ips:
return None

other_network = node_network.get(other_node_id, NodeNetworkInfo())
ip_to_type = {
iface.ip_address: iface.interface_type for iface in other_network.interfaces
Expand All @@ -358,7 +362,7 @@ def find_ip_prioritised(
# Ring should prioritise fastest connection. As a best-effort, we prioritise TB.
# TODO: Profile and get actual connection speeds.
if ring:
priority = {
type_priority = {
"thunderbolt": 0,
"maybe_ethernet": 1,
"ethernet": 2,
Expand All @@ -368,14 +372,37 @@ def find_ip_prioritised(

# RDMA prefers ethernet coordinator
else:
priority = {
type_priority = {
"ethernet": 0,
"wifi": 1,
"unknown": 2,
"maybe_ethernet": 3,
"maybe_ethernet": 1,
"wifi": 2,
"unknown": 3,
"thunderbolt": 4,
}
return min(ips, key=lambda ip: priority.get(ip_to_type.get(ip, "unknown"), 2))

return min(
ips,
key=lambda ip: (
type_priority.get(ip_to_type.get(ip, "unknown"), 5),
_address_priority(ip),
),
)


def _address_priority(ip: str) -> int:
"""RFC1918 LAN addresses are preferred; everything else (CGNAT/Tailscale,
link-local, public) is treated identically and only ranked by interface
type."""
if ip.startswith(("192.168.", "10.")):
return 0
if ip.startswith("172."):
try:
second_octet = int(ip.split(".")[1])
except (IndexError, ValueError):
return 1
if 16 <= second_octet <= 31:
return 0
return 1


def get_mlx_ring_hosts_by_node(
Expand Down
Loading