From 370979ee0095d474e79a59def563260ab59c7eb9 Mon Sep 17 00:00:00 2001 From: jw-wcv <101585096+jw-wcv@users.noreply.github.com> Date: Sat, 25 Apr 2026 18:10:33 -0700 Subject: [PATCH 1/9] Fix RDMA placement host selection --- src/exo/master/placement.py | 29 ++++ src/exo/master/placement_utils.py | 60 ++++++-- src/exo/master/tests/test_placement.py | 190 +++++++++++++++++++++++++ 3 files changed, 271 insertions(+), 8 deletions(-) diff --git a/src/exo/master/placement.py b/src/exo/master/placement.py index b55571c7f7..7c8d4b2257 100644 --- a/src/exo/master/placement.py +++ b/src/exo/master/placement.py @@ -30,6 +30,7 @@ from exo.shared.types.memory import Memory from exo.shared.types.profiling import MemoryUsage, NodeNetworkInfo, NodeRdmaCtlStatus from exo.shared.types.tasks import Task, TaskId, TaskStatus +from exo.shared.types.topology import SocketConnection from exo.shared.types.worker.downloads import ( DownloadCompleted, DownloadFailed, @@ -211,6 +212,7 @@ def _all_rdma_ctl_enabled(cycle: Cycle) -> bool: ), ), ) + selected_cycle = _prefer_socket_reachable_rank_zero(selected_cycle, topology) # Single-node: force Pipeline/Ring (Tensor and Jaccl require multi-node) if len(selected_cycle) == 1: @@ -281,6 +283,33 @@ def get_device_rank(node_id: NodeId) -> int: return target_instances +def _prefer_socket_reachable_rank_zero(cycle: Cycle, topology: Topology) -> Cycle: + """Rotate multi-node placements so rank 0 is easiest for peers to reach. + + MLX ring and JACCL both make rank 0 the listener/coordinator. Discovery can + produce RDMA-only edges in one direction and socket control-plane edges in + another, so putting a node with advertised inbound socket edges at rank 0 + avoids assigning the listener role to a machine peers cannot dial. 
+ """ + if len(cycle) <= 1: + return cycle + + inbound_socket_edges: dict[NodeId, int] = {node_id: 0 for node_id in cycle} + for connection in topology.list_connections(): + if connection.sink not in inbound_socket_edges: + continue + if isinstance(connection.edge, SocketConnection): + inbound_socket_edges[connection.sink] += 1 + + best_index = max( + range(len(cycle.node_ids)), + key=lambda index: (inbound_socket_edges[cycle.node_ids[index]], -index), + ) + if best_index == 0: + return cycle + return Cycle(node_ids=cycle.node_ids[best_index:] + cycle.node_ids[:best_index]) + + def delete_instance( command: DeleteInstance, current_instances: Mapping[InstanceId, Instance], diff --git a/src/exo/master/placement_utils.py b/src/exo/master/placement_utils.py index 0375e97e01..f5162ebe74 100644 --- a/src/exo/master/placement_utils.py +++ b/src/exo/master/placement_utils.py @@ -348,17 +348,20 @@ def find_ip_prioritised( Priority: ethernet > wifi > unknown > thunderbolt """ ips = list(_find_connection_ip(node_id, other_node_id, cycle_digraph)) - if not ips: - return None other_network = node_network.get(other_node_id, NodeNetworkInfo()) ip_to_type = { iface.ip_address: iface.interface_type for iface in other_network.interfaces } + if not ips: + ips = _fallback_interface_ips(other_network) + if not ips: + return None + # Ring should prioritise fastest connection. As a best-effort, we prioritise TB. # TODO: Profile and get actual connection speeds. 
if ring: - priority = { + type_priority = { "thunderbolt": 0, "maybe_ethernet": 1, "ethernet": 2, @@ -368,14 +371,55 @@ def find_ip_prioritised( # RDMA prefers ethernet coordinator else: - priority = { + type_priority = { "ethernet": 0, - "wifi": 1, - "unknown": 2, - "maybe_ethernet": 3, + "maybe_ethernet": 1, + "wifi": 2, + "unknown": 3, "thunderbolt": 4, } - return min(ips, key=lambda ip: priority.get(ip_to_type.get(ip, "unknown"), 2)) + + return min( + ips, + key=lambda ip: ( + _address_priority(ip), + type_priority.get(ip_to_type.get(ip, "unknown"), 5), + ), + ) + + +def _fallback_interface_ips(node_network: NodeNetworkInfo) -> list[str]: + """Return advertised node IPs when topology only has non-socket edges.""" + return [ + iface.ip_address + for iface in node_network.interfaces + if _is_candidate_host_ip(iface.ip_address) + ] + + +def _is_candidate_host_ip(ip: str) -> bool: + if ":" in ip: + return False + if ip.startswith("127.") or ip == "0.0.0.0": + return False + return True + + +def _address_priority(ip: str) -> int: + if ip.startswith("192.168.") or ip.startswith("10."): + return 0 + if ip.startswith("172."): + try: + second_octet = int(ip.split(".")[1]) + except (IndexError, ValueError): + return 3 + if 16 <= second_octet <= 31: + return 0 + if ip.startswith("100."): + return 2 + if ip.startswith("169.254."): + return 3 + return 1 def get_mlx_ring_hosts_by_node( diff --git a/src/exo/master/tests/test_placement.py b/src/exo/master/tests/test_placement.py index d3acd24f18..03218ba5ed 100644 --- a/src/exo/master/tests/test_placement.py +++ b/src/exo/master/tests/test_placement.py @@ -624,6 +624,196 @@ def test_place_mlx_jaccl_rejects_when_node_rdma_ctl_missing(model_card: ModelCar ) +def test_ring_placement_uses_advertised_lan_ips_for_rdma_only_topology( + model_card: ModelCard, +) -> None: + topology = Topology() + model_card = model_card.model_copy( + update={ + "storage_size": Memory.from_bytes(1500), + "n_layers": 12, + } + ) + + node_a = NodeId() + 
node_b = NodeId() + + topology.add_node(node_a) + topology.add_node(node_b) + topology.add_connection( + Connection(source=node_a, sink=node_b, edge=create_rdma_connection(1)) + ) + topology.add_connection( + Connection(source=node_b, sink=node_a, edge=create_rdma_connection(2)) + ) + + node_memory = { + node_a: create_node_memory(1000), + node_b: create_node_memory(1000), + } + node_network = { + node_a: NodeNetworkInfo( + interfaces=[ + NetworkInterfaceInfo( + name="en9", ip_address="192.168.1.10", interface_type="ethernet" + ) + ] + ), + node_b: NodeNetworkInfo( + interfaces=[ + NetworkInterfaceInfo( + name="en9", ip_address="192.168.1.11", interface_type="ethernet" + ) + ] + ), + } + + command = place_instance_command(model_card) + command = command.model_copy(update={"min_nodes": 2}) + + placements = place_instance(command, topology, {}, node_memory, node_network) + + instance = list(placements.values())[0] + assert isinstance(instance, MlxRingInstance) + assert len(instance.shard_assignments.node_to_runner) == 2 + assert any(host.ip == "192.168.1.11" for host in instance.hosts_by_node[node_a]) + assert any(host.ip == "192.168.1.10" for host in instance.hosts_by_node[node_b]) + + +def test_jaccl_placement_uses_advertised_lan_ip_for_rdma_coordinator( + model_card: ModelCard, +) -> None: + topology = Topology() + model_card = model_card.model_copy( + update={ + "storage_size": Memory.from_bytes(1500), + "n_layers": 12, + "hidden_size": 32, + "num_key_value_heads": 8, + "supports_tensor": True, + } + ) + + node_a = NodeId() + node_b = NodeId() + + topology.add_node(node_a) + topology.add_node(node_b) + topology.add_connection( + Connection(source=node_a, sink=node_b, edge=create_rdma_connection(1)) + ) + topology.add_connection( + Connection(source=node_b, sink=node_a, edge=create_rdma_connection(2)) + ) + + node_memory = { + node_a: create_node_memory(1000), + node_b: create_node_memory(1000), + } + node_network = { + node_a: NodeNetworkInfo( + interfaces=[ + 
NetworkInterfaceInfo( + name="en9", ip_address="192.168.1.10", interface_type="ethernet" + ) + ] + ), + node_b: NodeNetworkInfo( + interfaces=[ + NetworkInterfaceInfo( + name="en9", ip_address="192.168.1.11", interface_type="ethernet" + ) + ] + ), + } + command = PlaceInstance( + sharding=Sharding.Tensor, + instance_meta=InstanceMeta.MlxJaccl, + command_id=CommandId(), + model_card=model_card, + min_nodes=2, + ) + + node_rdma_ctl = { + node_a: NodeRdmaCtlStatus(enabled=True), + node_b: NodeRdmaCtlStatus(enabled=True), + } + placements = place_instance( + command, + topology, + {}, + node_memory, + node_network, + node_rdma_ctl=node_rdma_ctl, + ) + + instance = list(placements.values())[0] + assert isinstance(instance, MlxJacclInstance) + assert len(instance.shard_assignments.node_to_runner) == 2 + assert any( + coordinator.startswith("192.168.1.") + for coordinator in instance.jaccl_coordinators.values() + ) + + +def test_placement_prefers_socket_reachable_rank_zero( + model_card: ModelCard, +) -> None: + topology = Topology() + model_card = model_card.model_copy( + update={ + "storage_size": Memory.from_bytes(1500), + "n_layers": 12, + } + ) + + listener = NodeId() + peer = NodeId() + + topology.add_node(listener) + topology.add_node(peer) + topology.add_connection( + Connection(source=listener, sink=peer, edge=create_rdma_connection(1)) + ) + topology.add_connection( + Connection(source=peer, sink=listener, edge=create_rdma_connection(2)) + ) + topology.add_connection( + Connection(source=peer, sink=listener, edge=create_socket_connection(10)) + ) + + node_memory = { + listener: create_node_memory(1000), + peer: create_node_memory(1000), + } + node_network = { + listener: NodeNetworkInfo( + interfaces=[ + NetworkInterfaceInfo( + name="en9", ip_address="192.168.1.10", interface_type="ethernet" + ) + ] + ), + peer: NodeNetworkInfo( + interfaces=[ + NetworkInterfaceInfo( + name="en9", ip_address="192.168.1.11", interface_type="ethernet" + ) + ] + ), + } + + command 
= place_instance_command(model_card) + command = command.model_copy(update={"min_nodes": 2}) + + placements = place_instance(command, topology, {}, node_memory, node_network) + + instance = list(placements.values())[0] + runner_id = instance.shard_assignments.node_to_runner[listener] + shard = instance.shard_assignments.runner_to_shard[runner_id] + assert shard.device_rank == 0 + + def _make_task( instance_id: InstanceId, status: TaskStatus = TaskStatus.Running, From ce61802b77cabd49364e4b2291e74dcaadbdb8b9 Mon Sep 17 00:00:00 2001 From: jw-wcv <101585096+jw-wcv@users.noreply.github.com> Date: Sat, 25 Apr 2026 19:17:32 -0700 Subject: [PATCH 2/9] Walk nested Thunderbolt _items so hub-attached peers register `system_profiler SPThunderboltDataType -json` represents transparent TB hubs (e.g. the iVANKY Fusiondock Ultra) as an extra layer in `_items`, nesting the peer Mac one level deeper than a direct cable would. The previous parser only inspected the top-level `_items`, so on the side that enumerated through the dock the peer's `domain_uuid_key` was never found and the corresponding RDMA edge silently went missing - producing an asymmetric mesh where placement could see the link in one direction but not the other. Extend `_ConnectivityItem` to recursively model `_items` and walk the tree depth-first for the first descendant `domain_uuid_key`. The link is the actual peer endpoint regardless of how many transparent switches sit between the local receptacle and the peer Mac. Cover the new behaviour with three deterministic unit tests: the iVANKY hub case observed in production, a direct-cable sanity case, and an empty-receptacle case. 
--- src/exo/shared/types/thunderbolt.py | 31 +++++++--- .../info_gatherer/tests/test_tb_parsing.py | 62 +++++++++++++++++++ 2 files changed, 85 insertions(+), 8 deletions(-) diff --git a/src/exo/shared/types/thunderbolt.py b/src/exo/shared/types/thunderbolt.py index 34cd1ccad9..d6b8d3742c 100644 --- a/src/exo/shared/types/thunderbolt.py +++ b/src/exo/shared/types/thunderbolt.py @@ -25,6 +25,28 @@ class _ReceptacleTag(BaseModel, extra="ignore"): class _ConnectivityItem(BaseModel, extra="ignore"): domain_uuid_key: str | None = None + items: list["_ConnectivityItem"] | None = Field(None, alias="_items") + + +def _first_descendant_domain_uuid(items: list[_ConnectivityItem]) -> str | None: + """Return the first ``domain_uuid_key`` found by depth-first search. + + Apple's ``system_profiler SPThunderboltDataType`` output places intermediate + Thunderbolt hubs/docks (e.g. an iVANKY Fusiondock Ultra) between the local + receptacle and the peer Mac. The hub appears as an ``_items`` entry without + a ``domain_uuid_key`` of its own; the peer Mac sits one level deeper. We + descend until we hit the first node that exposes a domain UUID, which is + always the actual peer endpoint regardless of how many transparent + switches sit between us. 
+ """ + for item in items: + if item.domain_uuid_key is not None: + return item.domain_uuid_key + if item.items is not None: + descendant = _first_descendant_domain_uuid(item.items) + if descendant is not None: + return descendant + return None class ThunderboltConnectivityData(BaseModel, extra="ignore"): @@ -53,14 +75,7 @@ def conn(self) -> ThunderboltConnection | None: if self.domain_uuid_key is None or self.items is None: return - sink_key = next( - ( - item.domain_uuid_key - for item in self.items - if item.domain_uuid_key is not None - ), - None, - ) + sink_key = _first_descendant_domain_uuid(self.items) if sink_key is None: return None diff --git a/src/exo/utils/info_gatherer/tests/test_tb_parsing.py b/src/exo/utils/info_gatherer/tests/test_tb_parsing.py index 787dd3d5f9..d2551f0c6b 100644 --- a/src/exo/utils/info_gatherer/tests/test_tb_parsing.py +++ b/src/exo/utils/info_gatherer/tests/test_tb_parsing.py @@ -4,6 +4,7 @@ from exo.shared.types.thunderbolt import ( ThunderboltConnectivity, + ThunderboltConnectivityData, ) from exo.utils.info_gatherer.info_gatherer import ( _gather_iface_map, # pyright: ignore[reportPrivateUsage] @@ -22,3 +23,64 @@ async def test_tb_parsing(): for datum in data: datum.ident(ifaces) datum.conn() + + +def test_conn_resolves_peer_through_intermediate_hub() -> None: + """A TB hub between two Macs hides the peer one level deeper. + + Reproduces the iVANKY Fusiondock Ultra topology observed on the + wc-bmbp <-> wc-smbp link, where ``system_profiler`` reports the dock at + the first ``_items`` level and the peer Mac nested inside it. The parser + must walk past the dock and surface the peer's domain UUID, otherwise + half the RDMA mesh stays invisible to the placement engine. 
+ """ + payload = { + "domain_uuid_key": "DCA2B6F5-1C58-4589-8DA8-90B9326462D6", + "receptacle_1_tag": { + "receptacle_id_key": "2", + "current_speed_key": "80 Gb/s", + }, + "_items": [ + { + "_name": "iVANKY Fusiondock Ultra", + "_items": [ + { + "_name": "MacBook Pro", + "domain_uuid_key": "F74D8F9B-DCDF-40D4-A428-3A3674BCB3F4", + } + ], + } + ], + } + datum = ThunderboltConnectivityData.model_validate(payload) + conn = datum.conn() + assert conn is not None + assert conn.source_uuid == "DCA2B6F5-1C58-4589-8DA8-90B9326462D6" + assert conn.sink_uuid == "F74D8F9B-DCDF-40D4-A428-3A3674BCB3F4" + + +def test_conn_returns_first_peer_for_direct_link() -> None: + """A direct cable still surfaces the peer at the first level.""" + payload = { + "domain_uuid_key": "EA94B959-A0C4-1111-1111-111111111111", + "_items": [ + { + "_name": "MacBook Pro", + "domain_uuid_key": "D02B9C20-7504-2222-2222-222222222222", + } + ], + } + datum = ThunderboltConnectivityData.model_validate(payload) + conn = datum.conn() + assert conn is not None + assert conn.sink_uuid == "D02B9C20-7504-2222-2222-222222222222" + + +def test_conn_returns_none_when_no_peer_present() -> None: + """Empty ``_items`` (e.g. 
unconnected receptacle) yields no edge.""" + payload: dict[str, object] = { + "domain_uuid_key": "AAA", + "_items": [], + } + datum = ThunderboltConnectivityData.model_validate(payload) + assert datum.conn() is None From 677144f4b4c8e0bc7e67087fe893af012751ecc2 Mon Sep 17 00:00:00 2001 From: jw-wcv <101585096+jw-wcv@users.noreply.github.com> Date: Sun, 26 Apr 2026 01:22:02 -0700 Subject: [PATCH 3/9] Add JACCL init retry loop MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Retry the JACCL initialization handshake when peers come online out of order, which we frequently observed when a 4-node Apple Silicon cluster booted on a Thunderbolt RDMA fabric — the master would attempt the JACCL collective before all worker engines had registered, yielding a single-shot failure that took the whole inference path with it. Adds a small retry loop with backoff in mlx utils so the master will reattempt init up to 8 times (~50s of total backoff) before giving up. --- src/exo/worker/engines/mlx/utils_mlx.py | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/src/exo/worker/engines/mlx/utils_mlx.py b/src/exo/worker/engines/mlx/utils_mlx.py index 1dddad2ae1..6d63eb6ae3 100644 --- a/src/exo/worker/engines/mlx/utils_mlx.py +++ b/src/exo/worker/engines/mlx/utils_mlx.py @@ -125,7 +125,6 @@ def mlx_distributed_init( assert all( jaccl_devices[i][i] is None for i in range(len(jaccl_devices)) ) - # Use RDMA connectivity matrix jaccl_devices_json = json.dumps(jaccl_devices) with open(coordination_file, "w") as f: @@ -140,7 +139,21 @@ def mlx_distributed_init( os.environ["MLX_IBV_DEVICES"] = coordination_file os.environ["MLX_RANK"] = str(rank) os.environ["MLX_JACCL_COORDINATOR"] = jaccl_coordinator - group = mx.distributed.init(backend="jaccl", strict=True) + + max_jaccl_attempts = 8 + for attempt in range(1, max_jaccl_attempts + 1): + try: + group = mx.distributed.init(backend="jaccl", strict=True) + break + except (RuntimeError, ValueError) as 
exc: + if attempt == max_jaccl_attempts: + raise + backoff = min(2.0 * attempt, 10.0) + logger.warning( + f"rank {rank} JACCL init attempt {attempt}/{max_jaccl_attempts} " + f"failed ({exc}), retrying in {backoff:.0f}s" + ) + time.sleep(backoff) logger.info(f"Rank {rank} mlx distributed initialization complete") From 31d76693ddf66cab3971acad32502f8350f8ddfc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Do=C4=9Fukan=20Veziro=C4=9Flu?= Date: Mon, 27 Apr 2026 13:44:39 -0700 Subject: [PATCH 4/9] Distinguish TB4 from TB5 for Thunderbolt connections (dashboard) Only treat a Thunderbolt link as TB5 when the link speed is >40 Gb/s (TB5 negotiated speeds are 80 Gb/s symmetric or 120/40 Gb/s asymmetric). Without this, the dashboard was prompting users to enable RDMA on TB4-only nodes that can't actually run rdma_ctl in TB5 mode. --- dashboard/src/routes/+page.svelte | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/dashboard/src/routes/+page.svelte b/dashboard/src/routes/+page.svelte index f69cdd1519..9b4563fd18 100644 --- a/dashboard/src/routes/+page.svelte +++ b/dashboard/src/routes/+page.svelte @@ -147,9 +147,14 @@ if (!rdmaCtl) return false; const ids = tbIdentifiers; if (!ids) return false; - // Find nodes with TB5 hardware (any TB interface) + // Find nodes with TB5 hardware (link speed > 40 Gb/s) const tb5NodeIds = Object.entries(ids) - .filter(([_, node]) => node.interfaces.length > 0) + .filter(([_, node]) => + node.interfaces.some((iface: { linkSpeed: string }) => { + const match = iface.linkSpeed.match(/(\d+)\s*Gb/i); + return match != null && parseInt(match[1], 10) > 40; + }), + ) .map(([id]) => id); if (tb5NodeIds.length < 2) return false; // At least one TB5 node has RDMA disabled From f469af129c16d39c30bdfa31b464f3b1833cb794 Mon Sep 17 00:00:00 2001 From: jw-wcv <101585096+jw-wcv@users.noreply.github.com> Date: Wed, 6 May 2026 17:05:34 -0700 Subject: [PATCH 5/9] chore: ruff/basedpyright cleanup on RDMA placement utils + JACCL retry - 
placement_utils._is_candidate_host_ip / _address_priority: SIM103 + PIE810 ruff fixes (no behavior change). - utils_mlx: scope a # pyright: ignore[reportPossiblyUnboundVariable] on the JACCL retry-loop return; group is always bound when the loop exits with break, but the static analyzer can't reason across the raise-on-final-attempt branch. --- src/exo/master/placement_utils.py | 6 ++---- src/exo/worker/engines/mlx/utils_mlx.py | 2 +- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/src/exo/master/placement_utils.py b/src/exo/master/placement_utils.py index f5162ebe74..acc0c162e6 100644 --- a/src/exo/master/placement_utils.py +++ b/src/exo/master/placement_utils.py @@ -400,13 +400,11 @@ def _fallback_interface_ips(node_network: NodeNetworkInfo) -> list[str]: def _is_candidate_host_ip(ip: str) -> bool: if ":" in ip: return False - if ip.startswith("127.") or ip == "0.0.0.0": - return False - return True + return not (ip.startswith("127.") or ip == "0.0.0.0") def _address_priority(ip: str) -> int: - if ip.startswith("192.168.") or ip.startswith("10."): + if ip.startswith(("192.168.", "10.")): return 0 if ip.startswith("172."): try: diff --git a/src/exo/worker/engines/mlx/utils_mlx.py b/src/exo/worker/engines/mlx/utils_mlx.py index 6d63eb6ae3..e3f7778758 100644 --- a/src/exo/worker/engines/mlx/utils_mlx.py +++ b/src/exo/worker/engines/mlx/utils_mlx.py @@ -157,7 +157,7 @@ def mlx_distributed_init( logger.info(f"Rank {rank} mlx distributed initialization complete") - return group + return group # pyright: ignore[reportPossiblyUnboundVariable] def initialize_mlx( From 47277f5cfed00dd49feb27c1aeb0e326f737f3fb Mon Sep 17 00:00:00 2001 From: jw-wcv <101585096+jw-wcv@users.noreply.github.com> Date: Tue, 28 Apr 2026 10:48:16 -0700 Subject: [PATCH 6/9] Add asymmetric tensor parallel integration --- src/exo/api/main.py | 33 +- src/exo/master/placement.py | 129 ++++-- src/exo/master/placement_utils.py | 78 ++++ src/exo/master/tests/test_placement.py | 191 ++++++++- 
src/exo/shared/types/worker/shards.py | 31 +- .../worker/engines/mlx/asymmetric_parallel.py | 376 ++++++++++++++++++ src/exo/worker/engines/mlx/utils_mlx.py | 28 ++ .../test_mlx/test_asymmetric_parallel.py | 119 ++++++ 8 files changed, 949 insertions(+), 36 deletions(-) create mode 100644 src/exo/worker/engines/mlx/asymmetric_parallel.py create mode 100644 src/exo/worker/tests/unittests/test_mlx/test_asymmetric_parallel.py diff --git a/src/exo/api/main.py b/src/exo/api/main.py index 4fb6d2d3b0..3afa18de20 100644 --- a/src/exo/api/main.py +++ b/src/exo/api/main.py @@ -195,7 +195,7 @@ ) from exo.shared.types.worker.downloads import DownloadCompleted from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta -from exo.shared.types.worker.shards import Sharding +from exo.shared.types.worker.shards import AsymmetricTensorShardMetadata, Sharding from exo.utils.banner import print_startup_banner from exo.utils.channels import Receiver, Sender, channel from exo.utils.disk_event_log import DiskEventLog @@ -587,11 +587,32 @@ async def get_placement_previews( memory_delta_by_node: dict[str, int] = {} if placement_node_ids: total_bytes = model_card.storage_size.in_bytes - per_node = total_bytes // len(placement_node_ids) - remainder = total_bytes % len(placement_node_ids) - for index, node_id in enumerate(sorted(placement_node_ids, key=str)): - extra = 1 if index < remainder else 0 - memory_delta_by_node[str(node_id)] = per_node + extra + asymmetric_shards: dict[NodeId, AsymmetricTensorShardMetadata] = {} + for ( + node_id, + runner_id, + ) in shard_assignments.node_to_runner.items(): + shard_metadata = shard_assignments.runner_to_shard[runner_id] + if isinstance(shard_metadata, AsymmetricTensorShardMetadata): + asymmetric_shards[node_id] = shard_metadata + if asymmetric_shards: + for node_id, shard_metadata in asymmetric_shards.items(): + rank_weight_fraction = ( + shard_metadata.ratio + if shard_metadata.device_rank == 0 + else 1.0 - 
shard_metadata.ratio + ) + memory_delta_by_node[str(node_id)] = int( + total_bytes * rank_weight_fraction + ) + else: + per_node = total_bytes // len(placement_node_ids) + remainder = total_bytes % len(placement_node_ids) + for index, node_id in enumerate( + sorted(placement_node_ids, key=str) + ): + extra = 1 if index < remainder else 0 + memory_delta_by_node[str(node_id)] = per_node + extra if ( model_card.model_id, diff --git a/src/exo/master/placement.py b/src/exo/master/placement.py index 7c8d4b2257..bc5ed0a606 100644 --- a/src/exo/master/placement.py +++ b/src/exo/master/placement.py @@ -1,7 +1,10 @@ from collections.abc import Mapping from copy import deepcopy +from os import environ from typing import Sequence +from loguru import logger + from exo.master.placement_utils import ( Cycle, filter_cycles_by_memory, @@ -11,7 +14,7 @@ get_shard_assignments, get_smallest_cycles, ) -from exo.shared.models.model_cards import ModelId +from exo.shared.models.model_cards import ModelCard, ModelId from exo.shared.topology import Topology from exo.shared.types.commands import ( CancelDownload, @@ -48,6 +51,26 @@ from exo.shared.types.worker.shards import Sharding from exo.utils.ports import random_ephemeral_port +ASYMMETRIC_TENSOR_AUTO_UPGRADE_ENV = "EXO_ENABLE_ASYMMETRIC_TP_AUTO_UPGRADE" + + +def _supports_asymmetric_tensor_parallel(model_card: ModelCard) -> bool: + model_id = model_card.model_id.lower() + base_model = model_card.base_model.lower() + return ( + base_model.startswith("qwen3.5") + or "qwen3.5" in model_id + or "qwen-3.5" in model_id + ) + + +def _asymmetric_tensor_auto_upgrade_enabled() -> bool: + return environ.get(ASYMMETRIC_TENSOR_AUTO_UPGRADE_ENV, "").lower() in { + "1", + "true", + "yes", + } + def add_instance_to_placements( command: CreateInstance, @@ -64,7 +87,7 @@ def _get_node_download_fraction( model_id: ModelId, download_status: Mapping[NodeId, Sequence[DownloadProgress]], ) -> float: - """Return the download fraction (0.0–1.0) for a model on a 
given node.""" + """Return the download fraction (0.0-1.0) for a model on a given node.""" for progress in download_status.get(node_id, []): if progress.shard_metadata.model_card.model_id != model_id: continue @@ -108,6 +131,8 @@ def place_instance( download_status: Mapping[NodeId, Sequence[DownloadProgress]] | None = None, node_rdma_ctl: Mapping[NodeId, NodeRdmaCtlStatus] | None = None, ) -> dict[InstanceId, Instance]: + sharding = command.sharding + instance_meta = command.instance_meta cycles = topology.get_cycles() candidate_cycles = list(filter(lambda it: len(it) >= command.min_nodes, cycles)) @@ -124,38 +149,82 @@ def place_instance( if len(cycles_with_sufficient_memory) == 0: raise ValueError("No cycles found with sufficient memory") - if command.sharding == Sharding.Tensor: + if sharding == Sharding.AsymmetricTensor and not _supports_asymmetric_tensor_parallel( + command.model_card + ): + raise ValueError( + f"Asymmetric tensor parallelism is not yet supported for " + f"model '{command.model_card.model_id}'. Supported: Qwen3.5." + ) + + if sharding in (Sharding.Tensor, Sharding.AsymmetricTensor): if not command.model_card.supports_tensor: raise ValueError( f"Requested Tensor sharding but this model does not support tensor parallelism: {command.model_card.model_id}" ) - # TODO: the condition here for tensor parallel is not correct, but it works good enough for now. - # DeepSeek V4 is MQA (num_key_value_heads=1) but its sharding strategy - # head-parallelises wq_b/wo_a and shards MoE experts instead of splitting - # KV heads, so the kv-head divisibility check doesn't apply. - is_deepseek_v4 = command.model_card.base_model.startswith("DeepSeek V4") - kv_heads = command.model_card.num_key_value_heads + if sharding == Sharding.Tensor: + # TODO: the condition here for tensor parallel is not correct, but it works good enough for now. 
+ # DeepSeek V4 is MQA (num_key_value_heads=1) but its sharding strategy + # head-parallelises wq_b/wo_a and shards MoE experts instead of splitting + # KV heads, so the kv-head divisibility check doesn't apply. + is_deepseek_v4 = command.model_card.base_model.startswith("DeepSeek V4") + kv_heads = command.model_card.num_key_value_heads + cycles_with_sufficient_memory = [ + cycle + for cycle in cycles_with_sufficient_memory + if command.model_card.hidden_size % len(cycle) == 0 + and (is_deepseek_v4 or kv_heads is None or kv_heads % len(cycle) == 0) + ] + if not cycles_with_sufficient_memory: + raise ValueError( + f"No tensor sharding found for model with " + f"hidden_size={command.model_card.hidden_size}" + f"{f', num_key_value_heads={kv_heads}' if kv_heads is not None else ''}" + f" across candidate cycles" + ) + + # Auto-upgrade to AsymmetricTensor when equal TP won't fit on + # the smallest node but asymmetric split would. + if ( + _asymmetric_tensor_auto_upgrade_enabled() + and _supports_asymmetric_tensor_parallel(command.model_card) + ): + for cycle in cycles_with_sufficient_memory: + equal_share = command.model_card.storage_size.in_bytes / len(cycle) + min_node_mem = min( + node_memory[nid].ram_available.in_bytes for nid in cycle + ) + if equal_share > min_node_mem * 0.9: + # Equal split too tight; try asymmetric. + total_mem = sum( + node_memory[nid].ram_available.in_bytes for nid in cycle + ) + if command.model_card.storage_size.in_bytes < total_mem * 0.85: + logger.info( + "Equal tensor split won't fit on smallest node " + f"({min_node_mem / 1e9:.0f}GB available, " + f"needs {equal_share / 1e9:.0f}GB). " + "Auto-upgrading to AsymmetricTensor." 
+ ) + sharding = Sharding.AsymmetricTensor + break + if sharding == Sharding.AsymmetricTensor: cycles_with_sufficient_memory = [ - cycle - for cycle in cycles_with_sufficient_memory - if command.model_card.hidden_size % len(cycle) == 0 - and (is_deepseek_v4 or kv_heads is None or kv_heads % len(cycle) == 0) + cycle for cycle in cycles_with_sufficient_memory if len(cycle) == 2 ] if not cycles_with_sufficient_memory: raise ValueError( - f"No tensor sharding found for model with " - f"hidden_size={command.model_card.hidden_size}" - f"{f', num_key_value_heads={kv_heads}' if kv_heads is not None else ''}" - f" across candidate cycles" + "Asymmetric tensor parallelism currently requires exactly 2 nodes" ) - if command.sharding == Sharding.Pipeline and command.model_card.model_id == ModelId( + + if sharding == Sharding.Pipeline and command.model_card.model_id == ModelId( "mlx-community/DeepSeek-V3.1-8bit" ): raise ValueError( "Pipeline parallelism is not supported for DeepSeek V3.1 (8-bit)" ) if ( - command.sharding == Sharding.Pipeline + sharding == Sharding.Pipeline and command.model_card.base_model.startswith("Gemma 4") ): cycles_with_sufficient_memory = [ @@ -182,7 +251,7 @@ def _all_rdma_ctl_enabled(cycle: Cycle) -> bool: if topology.is_rdma_cycle(cycle) and _all_rdma_ctl_enabled(cycle) ] - if command.instance_meta == InstanceMeta.MlxJaccl: + if instance_meta == InstanceMeta.MlxJaccl: if not smallest_rdma_cycles: raise ValueError( "Requested RDMA (MlxJaccl) but no RDMA-connected cycles available" @@ -213,18 +282,22 @@ def _all_rdma_ctl_enabled(cycle: Cycle) -> bool: ), ) selected_cycle = _prefer_socket_reachable_rank_zero(selected_cycle, topology) + if sharding == Sharding.AsymmetricTensor: + selected_cycle = Cycle( + node_ids=sorted( + selected_cycle.node_ids, + key=lambda node_id: node_memory[node_id].ram_available.in_bytes, + reverse=True, + ) + ) # Single-node: force Pipeline/Ring (Tensor and Jaccl require multi-node) if len(selected_cycle) == 1: - command = 
command.model_copy( - update={ - "instance_meta": InstanceMeta.MlxRing, - "sharding": Sharding.Pipeline, - } - ) + instance_meta = InstanceMeta.MlxRing + sharding = Sharding.Pipeline shard_assignments = get_shard_assignments( - command.model_card, selected_cycle, command.sharding, node_memory + command.model_card, selected_cycle, sharding, node_memory ) cycle_digraph: Topology = topology.get_subgraph_from_nodes(selected_cycle.node_ids) @@ -232,7 +305,7 @@ def _all_rdma_ctl_enabled(cycle: Cycle) -> bool: instance_id = InstanceId() target_instances = dict(deepcopy(current_instances)) - match command.instance_meta: + match instance_meta: case InstanceMeta.MlxJaccl: # TODO(evan): shard assignments should contain information about ranks, this is ugly def get_device_rank(node_id: NodeId) -> int: diff --git a/src/exo/master/placement_utils.py b/src/exo/master/placement_utils.py index acc0c162e6..45f1bb983c 100644 --- a/src/exo/master/placement_utils.py +++ b/src/exo/master/placement_utils.py @@ -10,6 +10,7 @@ from exo.shared.types.topology import Cycle, RDMAConnection, SocketConnection from exo.shared.types.worker.runners import RunnerId, ShardAssignments from exo.shared.types.worker.shards import ( + AsymmetricTensorShardMetadata, CfgShardMetadata, PipelineShardMetadata, Sharding, @@ -273,6 +274,77 @@ def get_shard_assignments_for_tensor_parallel( return shard_assignments +def get_shard_assignments_for_asymmetric_tensor_parallel( + model_card: ModelCard, + cycle: Cycle, + node_memory: Mapping[NodeId, MemoryUsage], +) -> ShardAssignments: + """Create shard assignments for asymmetric tensor parallelism. + + Each node gets a ratio of weights proportional to its available memory. + All nodes compute every layer simultaneously. 
+ """ + total_layers = model_card.n_layers + world_size = len(cycle) + + sorted_nodes = sorted( + cycle, + key=lambda node_id: node_memory[node_id].ram_available.in_bytes, + reverse=True, + ) + + # Compute memory fractions with the largest-memory node fixed as rank 0. + total_available = sum( + node_memory[node_id].ram_available.in_bytes for node_id in sorted_nodes + ) + memory_fractions = [ + node_memory[node_id].ram_available.in_bytes / total_available + for node_id in sorted_nodes + ] + + from exo.worker.engines.mlx.asymmetric_parallel import find_valid_ratios + + ratios = find_valid_ratios( + memory_fractions=memory_fractions, + hidden_size=model_card.hidden_size, + num_attention_heads=model_card.hidden_size // 128, + num_key_value_heads=model_card.num_key_value_heads or 2, + ) + if ratios is None: + raise ValueError( + f"No valid asymmetric ratio found for hidden_size={model_card.hidden_size}" + ) + + runner_to_shard: dict[RunnerId, ShardMetadata] = {} + node_to_runner: dict[NodeId, RunnerId] = {} + rank_zero_ratio = ratios[0] + + for i, node_id in enumerate(sorted_nodes): + shard = AsymmetricTensorShardMetadata( + model_card=model_card, + device_rank=i, + world_size=world_size, + start_layer=0, + end_layer=total_layers, + n_layers=total_layers, + ratio=rank_zero_ratio, + ) + runner_id = RunnerId() + runner_to_shard[runner_id] = shard + node_to_runner[node_id] = runner_id + + logger.info( + f"Asymmetric TP: ratios={[f'{r:.0%}' for r in ratios]} " + f"across {world_size} nodes" + ) + + return ShardAssignments( + model_id=model_card.model_id, + runner_to_shard=runner_to_shard, + node_to_runner=node_to_runner, + ) + + def get_shard_assignments( model_card: ModelCard, cycle: Cycle, @@ -291,6 +363,12 @@ def get_shard_assignments( model_card=model_card, cycle=cycle, ) + case Sharding.AsymmetricTensor: + return get_shard_assignments_for_asymmetric_tensor_parallel( + model_card=model_card, + cycle=cycle, + node_memory=node_memory, + ) def get_mlx_jaccl_devices_matrix( 
diff --git a/src/exo/master/tests/test_placement.py b/src/exo/master/tests/test_placement.py index 03218ba5ed..33097b6748 100644 --- a/src/exo/master/tests/test_placement.py +++ b/src/exo/master/tests/test_placement.py @@ -47,7 +47,11 @@ MlxRingInstance, ) from exo.shared.types.worker.runners import ShardAssignments -from exo.shared.types.worker.shards import PipelineShardMetadata, Sharding +from exo.shared.types.worker.shards import ( + AsymmetricTensorShardMetadata, + PipelineShardMetadata, + Sharding, +) @pytest.fixture @@ -499,6 +503,191 @@ def test_tensor_rdma_backend_connectivity_matrix( assert len(ip_part.split(".")) == 4 +def test_qwen3_5_tensor_auto_upgrade_requires_opt_in( + monkeypatch: pytest.MonkeyPatch, +) -> None: + topology = Topology() + large_node = NodeId() + small_node = NodeId() + topology.add_node(large_node) + topology.add_node(small_node) + topology.add_connection( + Connection(source=large_node, sink=small_node, edge=create_rdma_connection(1)) + ) + topology.add_connection( + Connection(source=small_node, sink=large_node, edge=create_rdma_connection(2)) + ) + topology.add_connection( + Connection(source=large_node, sink=small_node, edge=create_socket_connection(1)) + ) + topology.add_connection( + Connection(source=small_node, sink=large_node, edge=create_socket_connection(2)) + ) + + model_card = ModelCard( + model_id=ModelId("mlx-community/Qwen3.5-72B-8bit"), + storage_size=Memory.from_bytes(130_648_036_320), + n_layers=48, + hidden_size=3072, + num_key_value_heads=8, + supports_tensor=True, + tasks=[ModelTask.TextGeneration], + family="qwen", + base_model="Qwen3.5 72B", + ) + command = PlaceInstance( + sharding=Sharding.Tensor, + instance_meta=InstanceMeta.MlxJaccl, + command_id=CommandId(), + model_card=model_card, + min_nodes=2, + ) + + node_rdma_ctl = { + large_node: NodeRdmaCtlStatus(enabled=True), + small_node: NodeRdmaCtlStatus(enabled=True), + } + placements_without_opt_in = place_instance( + command, + topology, + {}, + { + 
large_node: create_node_memory(128_000_000_000), + small_node: create_node_memory(48_000_000_000), + }, + { + large_node: create_node_network(), + small_node: create_node_network(), + }, + node_rdma_ctl=node_rdma_ctl, + ) + instance_without_opt_in = next(iter(placements_without_opt_in.values())) + large_runner_without_opt_in = instance_without_opt_in.shard_assignments.node_to_runner[ + large_node + ] + large_shard_without_opt_in = ( + instance_without_opt_in.shard_assignments.runner_to_shard[ + large_runner_without_opt_in + ] + ) + assert not isinstance(large_shard_without_opt_in, AsymmetricTensorShardMetadata) + + monkeypatch.setenv("EXO_ENABLE_ASYMMETRIC_TP_AUTO_UPGRADE", "1") + + placements = place_instance( + command, + topology, + {}, + { + large_node: create_node_memory(128_000_000_000), + small_node: create_node_memory(48_000_000_000), + }, + { + large_node: create_node_network(), + small_node: create_node_network(), + }, + node_rdma_ctl=node_rdma_ctl, + ) + + instance = next(iter(placements.values())) + large_runner = instance.shard_assignments.node_to_runner[large_node] + small_runner = instance.shard_assignments.node_to_runner[small_node] + large_shard = instance.shard_assignments.runner_to_shard[large_runner] + small_shard = instance.shard_assignments.runner_to_shard[small_runner] + + assert isinstance(large_shard, AsymmetricTensorShardMetadata) + assert isinstance(small_shard, AsymmetricTensorShardMetadata) + assert large_shard.device_rank == 0 + assert small_shard.device_rank == 1 + assert large_shard.ratio == small_shard.ratio == 0.75 + + +def test_asymmetric_tensor_rejects_qwen3_5_with_unsplittable_kv_heads() -> None: + topology = Topology() + large_node = NodeId() + small_node = NodeId() + topology.add_node(large_node) + topology.add_node(small_node) + topology.add_connection( + Connection(source=large_node, sink=small_node, edge=create_socket_connection(1)) + ) + topology.add_connection( + Connection(source=small_node, sink=large_node, 
edge=create_socket_connection(2)) + ) + + model_card = ModelCard( + model_id=ModelId("mlx-community/Qwen3.5-122B-A10B-8bit"), + storage_size=Memory.from_bytes(130_648_036_320), + n_layers=48, + hidden_size=3072, + num_key_value_heads=2, + supports_tensor=True, + tasks=[ModelTask.TextGeneration], + family="qwen", + base_model="Qwen3.5 122B A10B", + ) + command = PlaceInstance( + sharding=Sharding.AsymmetricTensor, + instance_meta=InstanceMeta.MlxRing, + command_id=CommandId(), + model_card=model_card, + min_nodes=2, + ) + + with pytest.raises(ValueError, match="No valid asymmetric ratio"): + place_instance( + command, + topology, + {}, + { + large_node: create_node_memory(128_000_000_000), + small_node: create_node_memory(48_000_000_000), + }, + { + large_node: create_node_network(), + small_node: create_node_network(), + }, + ) + + +def test_asymmetric_tensor_rejects_unsupported_model_family( + model_card: ModelCard, +) -> None: + topology = Topology() + node_id_a = NodeId() + node_id_b = NodeId() + topology.add_node(node_id_a) + topology.add_node(node_id_b) + topology.add_connection( + Connection(source=node_id_a, sink=node_id_b, edge=create_socket_connection(1)) + ) + topology.add_connection( + Connection(source=node_id_b, sink=node_id_a, edge=create_socket_connection(2)) + ) + command = PlaceInstance( + sharding=Sharding.AsymmetricTensor, + instance_meta=InstanceMeta.MlxRing, + command_id=CommandId(), + model_card=model_card, + min_nodes=2, + ) + + with pytest.raises(ValueError, match="Supported: Qwen3.5"): + place_instance( + command, + topology, + {}, + { + node_id_a: create_node_memory(2_000_000), + node_id_b: create_node_memory(2_000_000), + }, + { + node_id_a: create_node_network(), + node_id_b: create_node_network(), + }, + ) + + def _build_three_node_rdma_topology() -> tuple[ Topology, NodeId, NodeId, NodeId, dict[NodeId, NodeNetworkInfo] ]: diff --git a/src/exo/shared/types/worker/shards.py b/src/exo/shared/types/worker/shards.py index 
59a6c54eb0..112f6377a7 100644 --- a/src/exo/shared/types/worker/shards.py +++ b/src/exo/shared/types/worker/shards.py @@ -9,6 +9,7 @@ class Sharding(str, Enum): Tensor = "Tensor" + AsymmetricTensor = "AsymmetricTensor" Pipeline = "Pipeline" @@ -79,6 +80,34 @@ class TensorShardMetadata(BaseShardMetadata): pass +@final +class AsymmetricTensorShardMetadata(BaseShardMetadata): + """ + Asymmetric tensor parallelism shard metadata. + + Unlike standard tensor parallelism which splits weights 50/50 (or equally + across N nodes), asymmetric TP splits weights proportionally to each node's + available memory. This enables heterogeneous clusters (e.g. 128GB + 48GB) + to run models using tensor parallelism where equal splits wouldn't fit. + + Each node holds a different fraction of each weight tensor, but ALL nodes + compute every layer simultaneously. The all_sum reduction still works + correctly because (x_a @ W_a^T) + (x_b @ W_b^T) = x @ W^T regardless + of how W is partitioned. + """ + + ratio: float = Field( + ge=0.0, + le=1.0, + description="Split point for rank 0, shared across all ranks. " + "e.g. 0.75 means rank 0 gets the first 75% and rank 1 gets the last 25%. " + "Every rank stores the same value so all workers agree on the split.", + ) + + ShardMetadata: TypeAlias = ( - PipelineShardMetadata | CfgShardMetadata | TensorShardMetadata + PipelineShardMetadata + | CfgShardMetadata + | TensorShardMetadata + | AsymmetricTensorShardMetadata ) diff --git a/src/exo/worker/engines/mlx/asymmetric_parallel.py b/src/exo/worker/engines/mlx/asymmetric_parallel.py new file mode 100644 index 0000000000..f8fea2be74 --- /dev/null +++ b/src/exo/worker/engines/mlx/asymmetric_parallel.py @@ -0,0 +1,376 @@ +""" +Asymmetric Tensor Parallelism for heterogeneous clusters. + +When nodes have different amounts of RAM, standard 50/50 tensor parallelism +fails because the smaller node can't hold half the weights. Asymmetric TP +splits each weight tensor proportionally to available memory (e.g. 
75/25) +so both nodes compute every layer simultaneously. + +Mathematical correctness: + Column parallel: y = x @ [W_a; W_b]^T = [x @ W_a^T, x @ W_b^T] + Row parallel: y = x_a @ W_a^T + x_b @ W_b^T = x @ W^T (via all_sum) + Both hold regardless of the split ratio. + +Usage: + asymmetric_tensor_auto_parallel(model, group, ratios=[0.75, 0.25]) +""" +# pyright: reportAny=false, reportUnknownMemberType=false, reportUnknownArgumentType=false, reportUnknownVariableType=false + +from __future__ import annotations + +from collections.abc import Generator +from typing import Any + +import mlx.core as mx +import mlx.nn as nn +from mlx.nn.layers.distributed import sum_gradients +from mlx_lm.models.qwen3_5 import DecoderLayer as Qwen3_5DecoderLayer +from mlx_lm.models.qwen3_5 import GatedDeltaNet +from mlx_lm.models.qwen3_5 import SparseMoeBlock as Qwen3_5SparseMoeBlock +from mlx_lm.models.qwen3_next import Qwen3NextAttention as Attention +from mlx_lm.models.qwen3_next import Qwen3NextMLP as Qwen3NextMLP +from mlx_lm.models.qwen3_next import Qwen3NextSparseMoeBlock as SparseMoeBlock + +from exo.shared.types.worker.runner_response import ModelLoadingResponse + +try: + from exo.shared.logging import logger +except ImportError: + import logging + + logger = logging.getLogger(__name__) + + +def find_valid_ratios( + memory_fractions: list[float], + hidden_size: int, + num_attention_heads: int, + num_key_value_heads: int, + num_experts: int = 0, + moe_intermediate_size: int = 0, + linear_num_value_heads: int = 0, + linear_num_key_heads: int = 0, + quantization_group_size: int = 64, +) -> list[float] | None: + """ + Find valid split ratios for asymmetric TP given model dimensions and memory fractions. + + A valid ratio must produce integer dimensions for all split tensors, + and all split dimensions must be divisible by the quantization group size. + + Returns a list of ratios (one per node) that sum to 1.0, or None if no valid + ratio exists. Currently supports 2 nodes only. 
+ """ + if len(memory_fractions) != 2: + logger.warning("Asymmetric TP currently only supports 2 nodes") + return None + + # Key dimensions that must split cleanly + key_dims = [ + num_attention_heads, + num_key_value_heads, + hidden_size, + ] + if linear_num_value_heads > 0: + key_dims.extend([linear_num_value_heads, linear_num_key_heads]) + if num_experts > 0 and moe_intermediate_size > 0: + key_dims.append(moe_intermediate_size) + + target_ratio = memory_fractions[0] + + # Try ratios of the form n/d where d is a power of 2 or common denominator + # that produces clean splits. Test denominators 2..32. + best_ratio = None + best_distance = float("inf") + + for denom in [2, 4, 8, 16, 32]: + for numer in range(1, denom): + ratio = numer / denom + if ratio <= 0.5 or ratio > 0.95: + continue + + # Check all dimensions split cleanly + valid = True + for dim in key_dims: + # dim * ratio must be EXACTLY integer (for head counts) + exact = dim * ratio + if exact != int(exact): + valid = False + break + a = int(exact) + b = dim - a + if a <= 0 or b <= 0: + valid = False + break + # For quantized weights, split dims must be divisible by 8 + if dim > quantization_group_size and (a % 8 != 0 or b % 8 != 0): + valid = False + break + + if valid: + distance = abs(ratio - target_ratio) + if distance < best_distance: + best_distance = distance + best_ratio = ratio + + if best_ratio is None: + return None + + return [best_ratio, 1.0 - best_ratio] + + +def _split_at(tensor: mx.array, axis: int, ratio: float) -> tuple[mx.array, mx.array]: + """Split tensor at ratio point along axis.""" + sp = int(tensor.shape[axis] * ratio) + parts = mx.split(tensor, [sp], axis=axis) + return mx.contiguous(parts[0]), mx.contiguous(parts[1]) + + +def _my_shard(tensor: mx.array, axis: int, rank: int, ratio: float) -> mx.array: + """Get rank's portion of an asymmetric split.""" + parts = _split_at(tensor, axis, ratio) + return parts[0] if rank == 0 else parts[1] + + +def _shard_quantized_ats( + layer: 
Any, + axis: int, + rank: int, + ratio: float, + segments: list[int] | None = None, +) -> None: + """Shard quantized linear all-to-sharded (output dim split).""" + if segments is not None: + w: mx.array = layer.weight + seg_parts_w = mx.split(w, segments, axis=axis) + my_w_parts = [_my_shard(p, axis, rank, ratio) for p in seg_parts_w] + layer.weight = mx.contiguous(mx.concatenate(my_w_parts, axis=axis)) + for attr in ["scales", "biases"]: + t: mx.array | None = getattr(layer, attr, None) + if t is None: + continue + t_seg = [int(s * t.shape[axis] / w.shape[axis]) for s in segments] + t_parts = mx.split(t, t_seg, axis=axis) + my_parts = [_my_shard(p, axis, rank, ratio) for p in t_parts] + setattr(layer, attr, mx.contiguous(mx.concatenate(my_parts, axis=axis))) + else: + for attr in ["weight", "scales", "biases"]: + t_val: mx.array | None = getattr(layer, attr, None) + if t_val is None: + continue + setattr(layer, attr, _my_shard(t_val, axis, rank, ratio)) + + +def _shard_quantized_sta(layer: Any, rank: int, ratio: float) -> None: + """Shard quantized linear sharded-to-all (input dim, axis -1).""" + for attr in ["weight", "scales", "biases"]: + t: mx.array | None = getattr(layer, attr, None) + if t is None: + continue + setattr(layer, attr, _my_shard(t, -1, rank, ratio)) + + +def _shard_gated_delta_net( + gdn: GatedDeltaNet, rank: int, ratio: float, group: mx.distributed.Group +) -> None: + """Asymmetric shard for GatedDeltaNet (linear attention) layers.""" + kd = gdn.key_dim + _shard_quantized_ats(gdn.in_proj_qkv, 0, rank, ratio, segments=[kd, 2 * kd]) + _shard_quantized_ats(gdn.in_proj_z, 0, rank, ratio) + _shard_quantized_ats(gdn.in_proj_b, 0, rank, ratio) + _shard_quantized_ats(gdn.in_proj_a, 0, rank, ratio) + _shard_quantized_sta(gdn.out_proj, rank, ratio) + + # conv1d: segmented split along channel dim + conv_w = gdn.conv1d.weight + seg_parts = mx.split(conv_w, [kd, 2 * kd], axis=0) + my_parts = [_my_shard(p, 0, rank, ratio) for p in seg_parts] + 
gdn.conv1d.weight = mx.contiguous(mx.concatenate(my_parts, axis=0)) + + gdn.dt_bias = _my_shard(gdn.dt_bias, 0, rank, ratio) + gdn.A_log = _my_shard(gdn.A_log, 0, rank, ratio) + + r = ratio if rank == 0 else (1 - ratio) + gdn.num_k_heads = int(gdn.num_k_heads * r) + gdn.num_v_heads = int(gdn.num_v_heads * r) + gdn.key_dim = int(gdn.key_dim * r) + gdn.value_dim = int(gdn.value_dim * r) + gdn.conv_dim = int(gdn.conv_dim * r) + gdn.conv1d.groups = gdn.conv_dim + gdn.sharding_group = group + + +# Patching must happen at the class level since nn.Module.__call__ ignores instance overrides +_attention_class_patched: set[type] = set() + + +def _patch_attention_class(attn_cls: type) -> None: + """Patch an attention class to add all_sum when _asymmetric_tp_group is set.""" + if attn_cls in _attention_class_patched: + return + + original_call = attn_cls.__call__ + + def patched_call( + self: nn.Module, + x: mx.array, + mask: mx.array | None = None, + cache: object | None = None, + ) -> mx.array: + result = original_call(self, x, mask=mask, cache=cache) + grp = getattr(self, "_asymmetric_tp_group", None) + if grp is not None: + result = mx.distributed.all_sum(result, group=grp) + return result + + attn_cls.__call__ = patched_call + _attention_class_patched.add(attn_cls) + + +def _shard_attention( + attn: Attention, rank: int, ratio: float, group: mx.distributed.Group +) -> None: + """Asymmetric shard for self-attention layers.""" + _patch_attention_class(type(attn)) + _shard_quantized_ats(attn.q_proj, 0, rank, ratio) + _shard_quantized_ats(attn.k_proj, 0, rank, ratio) + _shard_quantized_ats(attn.v_proj, 0, rank, ratio) + _shard_quantized_sta(attn.o_proj, rank, ratio) + + r = ratio if rank == 0 else (1 - ratio) + attn.num_attention_heads = int(attn.num_attention_heads * r) + attn.num_key_value_heads = int(attn.num_key_value_heads * r) + attn._asymmetric_tp_group = group + + + +class AsymmetricShardedMoE(nn.Module): + def __init__(self, layer: SparseMoeBlock | 
Qwen3_5SparseMoeBlock): + super().__init__() + self.original_layer = layer + self.sharding_group: mx.distributed.Group | None = None + + def __call__(self, x: mx.array) -> mx.array: + if self.sharding_group is not None: + x = sum_gradients(self.sharding_group)(x) + y = self.original_layer(x) + if self.sharding_group is not None: + y = mx.distributed.all_sum(y, group=self.sharding_group) + return y + + +def _shard_sparse_moe( + moe: SparseMoeBlock | Qwen3_5SparseMoeBlock, + rank: int, + ratio: float, + group: mx.distributed.Group, +) -> AsymmetricShardedMoE: + """Asymmetric shard for SparseMoeBlock (MoE layers).""" + # switch_mlp: split expert intermediate dims (axis 1 for 3D expert weights) + _shard_quantized_ats(moe.switch_mlp.gate_proj, 1, rank, ratio) + _shard_quantized_ats(moe.switch_mlp.up_proj, 1, rank, ratio) + _shard_quantized_sta(moe.switch_mlp.down_proj, rank, ratio) + + # shared_expert: standard MLP split + _shard_quantized_ats(moe.shared_expert.gate_proj, 0, rank, ratio) + _shard_quantized_ats(moe.shared_expert.up_proj, 0, rank, ratio) + _shard_quantized_sta(moe.shared_expert.down_proj, rank, ratio) + + sharded_moe = AsymmetricShardedMoE(moe) + sharded_moe.sharding_group = group + return sharded_moe + + +_mlp_class_patched: set[type] = set() + + +def _patch_mlp_class(mlp_cls: type) -> None: + """Patch a dense MLP class to add all_sum when _asymmetric_tp_group is set.""" + if mlp_cls in _mlp_class_patched: + return + + original_call = mlp_cls.__call__ + + def patched_call(self: nn.Module, x: mx.array) -> mx.array: + result = original_call(self, x) + grp = getattr(self, "_asymmetric_tp_group", None) + if grp is not None: + result = mx.distributed.all_sum(result, group=grp) + return result + + mlp_cls.__call__ = patched_call + _mlp_class_patched.add(mlp_cls) + + +def _shard_dense_mlp( + mlp: Qwen3NextMLP, rank: int, ratio: float, group: mx.distributed.Group +) -> None: + """Asymmetric shard for dense (non-MoE) MLP layers.""" + _patch_mlp_class(type(mlp)) + 
_shard_quantized_ats(mlp.gate_proj, 0, rank, ratio) + _shard_quantized_ats(mlp.up_proj, 0, rank, ratio) + _shard_quantized_sta(mlp.down_proj, rank, ratio) + mlp._asymmetric_tp_group = group + + +def asymmetric_tensor_auto_parallel( + model: nn.Module, + group: mx.distributed.Group, + ratios: list[float], +) -> Generator[ModelLoadingResponse, None, nn.Module]: + """ + Apply asymmetric tensor parallelism to a model. + + Args: + model: The model to parallelize (must have .layers property) + group: MLX distributed group + ratios: Per-rank weight fractions, e.g. [0.75, 0.25] for 2 nodes. + ratios[group.rank()] is this node's fraction. + + Returns: + The model with asymmetric sharding applied. + """ + rank = group.rank() + ratio = ratios[0] # ratio for rank 0; rank 1 gets 1-ratio + + # Get the inner model's layers + inner = model + for attr in ["language_model", "model"]: + candidate = getattr(inner, attr, None) + if candidate is not None and hasattr(candidate, "layers"): + inner = candidate + + layers: list[Any] = inner.layers if hasattr(inner, "layers") else model.layers + + total = len(layers) + for layer_index, layer in enumerate(layers): + if isinstance(layer, Qwen3_5DecoderLayer): + # Qwen3.5 hybrid: linear_attn or self_attn per layer + if layer.is_linear: + _shard_gated_delta_net(layer.linear_attn, rank, ratio, group) + else: + _shard_attention(layer.self_attn, rank, ratio, group) + + mlp = layer.mlp + if isinstance(mlp, (SparseMoeBlock, Qwen3_5SparseMoeBlock)): + dict.__setitem__( + layer, + "mlp", + _shard_sparse_moe(mlp, rank, ratio, group), + ) + else: + _shard_dense_mlp(mlp, rank, ratio, group) + else: + raise ValueError( + f"Asymmetric TP does not yet support layer type {type(layer).__name__}. " + f"Currently supported: Qwen3.5 (GatedDeltaNet + Attention + MoE). " + f"Contributions for other architectures welcome." 
+ ) + mx.eval(layer) + yield ModelLoadingResponse(layers_loaded=layer_index, total=total) + + logger.info( + f"Asymmetric TP applied: rank {rank} gets " + f"{ratios[rank] * 100:.0f}% of each weight tensor" + ) + return model diff --git a/src/exo/worker/engines/mlx/utils_mlx.py b/src/exo/worker/engines/mlx/utils_mlx.py index e3f7778758..ada5d9cc03 100644 --- a/src/exo/worker/engines/mlx/utils_mlx.py +++ b/src/exo/worker/engines/mlx/utils_mlx.py @@ -53,11 +53,15 @@ ) from exo.shared.types.worker.runner_response import ModelLoadingResponse from exo.shared.types.worker.shards import ( + AsymmetricTensorShardMetadata, CfgShardMetadata, PipelineShardMetadata, ShardMetadata, TensorShardMetadata, ) +from exo.worker.engines.mlx.asymmetric_parallel import ( + asymmetric_tensor_auto_parallel, +) from exo.worker.engines.mlx.auto_parallel import ( get_inner_model, get_layers, @@ -69,6 +73,19 @@ def get_weights_size(model_shard_meta: ShardMetadata) -> Memory: + if isinstance(model_shard_meta, AsymmetricTensorShardMetadata): + rank_weight_fraction = ( + model_shard_meta.ratio + if model_shard_meta.device_rank == 0 + else 1.0 - model_shard_meta.ratio + ) + return Memory.from_float_kb( + (model_shard_meta.end_layer - model_shard_meta.start_layer) + / model_shard_meta.n_layers + * model_shard_meta.model_card.storage_size.in_kb + * rank_weight_fraction + ) + return Memory.from_float_kb( (model_shard_meta.end_layer - model_shard_meta.start_layer) / model_shard_meta.n_layers @@ -100,6 +117,7 @@ def mlx_distributed_init( coordination_file = str( Path(tmpdir) / f"hosts_{bound_instance.instance.instance_id}_{rank}.json" ) + group: mx.distributed.Group | None = None # TODO: singleton instances match bound_instance.instance: case MlxRingInstance(hosts_by_node=hosts_by_node, ephemeral_port=_): @@ -156,6 +174,8 @@ def mlx_distributed_init( time.sleep(backoff) logger.info(f"Rank {rank} mlx distributed initialization complete") + if group is None: + raise RuntimeError("MLX distributed 
initialization did not return a group") return group # pyright: ignore[reportPossiblyUnboundVariable] @@ -277,6 +297,14 @@ def shard_and_load( case TensorShardMetadata(): logger.info(f"loading model from {model_path} with tensor parallelism") model = yield from tensor_auto_parallel(model, group) + case AsymmetricTensorShardMetadata(): + rank_zero_ratio = shard_metadata.ratio + ratios_list = [rank_zero_ratio, 1.0 - rank_zero_ratio] + logger.info( + f"loading model from {model_path} with asymmetric tensor parallelism " + f"(ratios={[f'{r:.0%}' for r in ratios_list]})" + ) + model = yield from asymmetric_tensor_auto_parallel(model, group, ratios_list) case PipelineShardMetadata(): logger.info(f"loading model from {model_path} with pipeline parallelism") model = yield from pipeline_auto_parallel(model, group, shard_metadata) diff --git a/src/exo/worker/tests/unittests/test_mlx/test_asymmetric_parallel.py b/src/exo/worker/tests/unittests/test_mlx/test_asymmetric_parallel.py new file mode 100644 index 0000000000..6736cddd04 --- /dev/null +++ b/src/exo/worker/tests/unittests/test_mlx/test_asymmetric_parallel.py @@ -0,0 +1,119 @@ +"""Tests for asymmetric tensor parallelism ratio finding and sharding.""" + +class TestFindValidRatios: + """Test the ratio solver that finds valid asymmetric split points.""" + + def test_qwen3_5_full_attention_dimensions_with_divisible_kv_heads(self) -> None: + from exo.worker.engines.mlx.asymmetric_parallel import find_valid_ratios + + ratios = find_valid_ratios( + memory_fractions=[0.73, 0.27], + hidden_size=3072, + num_attention_heads=32, + num_key_value_heads=8, + linear_num_value_heads=64, + linear_num_key_heads=16, + moe_intermediate_size=1024, + num_experts=256, + ) + assert ratios is not None + assert len(ratios) == 2 + assert abs(ratios[0] + ratios[1] - 1.0) < 1e-10 + # All head counts must be exact integers after split + assert 32 * ratios[0] == int(32 * ratios[0]) # attention heads + assert 8 * ratios[0] == int(8 * ratios[0]) # KV 
heads + assert 64 * ratios[0] == int(64 * ratios[0]) # value heads + assert 16 * ratios[0] == int(16 * ratios[0]) # key heads + + def test_rejects_qwen3_5_122b_two_kv_heads_for_asymmetric_attention(self) -> None: + from exo.worker.engines.mlx.asymmetric_parallel import find_valid_ratios + + ratios = find_valid_ratios( + memory_fractions=[0.73, 0.27], + hidden_size=3072, + num_attention_heads=32, + num_key_value_heads=2, + linear_num_value_heads=64, + linear_num_key_heads=16, + moe_intermediate_size=1024, + num_experts=256, + ) + + assert ratios is None + + def test_llama_70b_dimensions(self) -> None: + from exo.worker.engines.mlx.asymmetric_parallel import find_valid_ratios + + ratios = find_valid_ratios( + memory_fractions=[0.73, 0.27], + hidden_size=8192, + num_attention_heads=64, + num_key_value_heads=8, + ) + assert ratios is not None + assert 64 * ratios[0] == int(64 * ratios[0]) + + def test_nemotron_120b_dimensions(self) -> None: + from exo.worker.engines.mlx.asymmetric_parallel import find_valid_ratios + + ratios = find_valid_ratios( + memory_fractions=[0.73, 0.27], + hidden_size=4096, + num_attention_heads=32, + num_key_value_heads=8, + ) + assert ratios is not None + assert 32 * ratios[0] == int(32 * ratios[0]) + + def test_rejects_impossible_dimensions(self) -> None: + """Prime-number head count with no valid fractional split.""" + from exo.worker.engines.mlx.asymmetric_parallel import find_valid_ratios + + ratios = find_valid_ratios( + memory_fractions=[0.73, 0.27], + hidden_size=3072, + num_attention_heads=7, # prime: cannot split into 2 integer parts > 0.5 + num_key_value_heads=2, + ) + assert ratios is None + + def test_only_two_nodes_supported(self) -> None: + from exo.worker.engines.mlx.asymmetric_parallel import find_valid_ratios + + ratios = find_valid_ratios( + memory_fractions=[0.5, 0.25, 0.25], + hidden_size=4096, + num_attention_heads=32, + num_key_value_heads=8, + ) + assert ratios is None + + def test_ratio_closer_to_target(self) -> None: + 
"""Ratio should be the closest valid one to the memory fraction.""" + from exo.worker.engines.mlx.asymmetric_parallel import find_valid_ratios + + # With 80% target, 0.8125 (13/16) is closer than 0.75 (12/16) + ratios = find_valid_ratios( + memory_fractions=[0.80, 0.20], + hidden_size=3072, + num_attention_heads=32, + num_key_value_heads=16, + linear_num_value_heads=64, + linear_num_key_heads=16, + ) + assert ratios is not None + assert abs(ratios[0] - 0.80) < abs(0.75 - 0.80) + + def test_equal_memory_returns_near_symmetric(self) -> None: + """When memory is roughly equal, ratio should be close to 0.5.""" + from exo.worker.engines.mlx.asymmetric_parallel import find_valid_ratios + + ratios = find_valid_ratios( + memory_fractions=[0.50, 0.50], + hidden_size=4096, + num_attention_heads=32, + num_key_value_heads=8, + ) + # Finder searches > 0.5, so it may find a near-symmetric split + if ratios is not None: + assert ratios[0] < 0.7 # should be close to 0.5 From d3ce0ec301bd5ec4831dfd65a880e8ca30cf42db Mon Sep 17 00:00:00 2001 From: jw-wcv <101585096+jw-wcv@users.noreply.github.com> Date: Tue, 28 Apr 2026 12:37:21 -0700 Subject: [PATCH 7/9] Label asymmetric tensor dashboard instances --- dashboard/src/lib/components/ChatSidebar.svelte | 2 ++ dashboard/src/lib/stores/app.svelte.ts | 2 ++ dashboard/src/routes/+page.svelte | 2 ++ 3 files changed, 6 insertions(+) diff --git a/dashboard/src/lib/components/ChatSidebar.svelte b/dashboard/src/lib/components/ChatSidebar.svelte index 1e8c975fb9..97b7af7745 100644 --- a/dashboard/src/lib/components/ChatSidebar.svelte +++ b/dashboard/src/lib/components/ChatSidebar.svelte @@ -213,6 +213,8 @@ const [shardTag] = getTaggedValue(firstShardWrapped); if (shardTag === "PipelineShardMetadata") sharding = "Pipeline"; else if (shardTag === "TensorShardMetadata") sharding = "Tensor"; + else if (shardTag === "AsymmetricTensorShardMetadata") + sharding = "Asymmetric Tensor"; else if (shardTag === "PrefillDecodeShardMetadata") sharding = 
"Prefill/Decode"; } diff --git a/dashboard/src/lib/stores/app.svelte.ts b/dashboard/src/lib/stores/app.svelte.ts index 5b28b3f020..e69cc43d29 100644 --- a/dashboard/src/lib/stores/app.svelte.ts +++ b/dashboard/src/lib/stores/app.svelte.ts @@ -945,6 +945,8 @@ class AppStore { const [shardTag] = this.getTaggedValue(firstShardWrapped); if (shardTag === "PipelineShardMetadata") sharding = "Pipeline"; else if (shardTag === "TensorShardMetadata") sharding = "Tensor"; + else if (shardTag === "AsymmetricTensorShardMetadata") + sharding = "Asymmetric Tensor"; else if (shardTag === "PrefillDecodeShardMetadata") sharding = "Prefill/Decode"; } diff --git a/dashboard/src/routes/+page.svelte b/dashboard/src/routes/+page.svelte index 9b4563fd18..acc026c81f 100644 --- a/dashboard/src/routes/+page.svelte +++ b/dashboard/src/routes/+page.svelte @@ -2084,6 +2084,8 @@ const [shardTag] = getTagged(firstShardWrapped); if (shardTag === "PipelineShardMetadata") sharding = "Pipeline"; else if (shardTag === "TensorShardMetadata") sharding = "Tensor"; + else if (shardTag === "AsymmetricTensorShardMetadata") + sharding = "Asymmetric Tensor"; else if (shardTag === "PrefillDecodeShardMetadata") sharding = "Prefill/Decode"; } From 089c16f6cb353f7522f5daa220a3505d17750f1a Mon Sep 17 00:00:00 2001 From: jw-wcv <101585096+jw-wcv@users.noreply.github.com> Date: Tue, 28 Apr 2026 12:47:31 -0700 Subject: [PATCH 8/9] Show asymmetric model share in topology --- .../src/lib/components/TopologyGraph.svelte | 110 +++++++++++++++++- 1 file changed, 109 insertions(+), 1 deletion(-) diff --git a/dashboard/src/lib/components/TopologyGraph.svelte b/dashboard/src/lib/components/TopologyGraph.svelte index 0d0d3c3d08..4cf4d8a648 100644 --- a/dashboard/src/lib/components/TopologyGraph.svelte +++ b/dashboard/src/lib/components/TopologyGraph.svelte @@ -5,6 +5,7 @@ topologyData, isTopologyMinimized, debugMode, + instances, nodeThunderboltBridge, nodeRdmaCtl, nodeIdentities, @@ -31,11 +32,69 @@ const isMinimized = 
$derived(isTopologyMinimized()); const data = $derived(topologyData()); + const instanceData = $derived(instances()); const debugEnabled = $derived(debugMode()); const tbBridgeData = $derived(nodeThunderboltBridge()); const rdmaCtlData = $derived(nodeRdmaCtl()); const identitiesData = $derived(nodeIdentities()); + function getTaggedValue(value: unknown): [string | null, unknown] { + if (!value || typeof value !== "object") return [null, null]; + const record = value as Record; + const keys = Object.keys(record); + if (keys.length !== 1) return [null, value]; + return [keys[0], record[keys[0]]]; + } + + function getAsymmetricModelShareByNode(): Map { + const shares = new Map(); + + for (const instanceWrapped of Object.values(instanceData)) { + const [, instance] = getTaggedValue(instanceWrapped); + if (!instance || typeof instance !== "object") continue; + + const shardAssignments = ( + instance as { + shardAssignments?: { + nodeToRunner?: Record; + runnerToShard?: Record; + }; + } + ).shardAssignments; + if (!shardAssignments?.nodeToRunner || !shardAssignments.runnerToShard) { + continue; + } + + for (const [nodeId, runnerId] of Object.entries( + shardAssignments.nodeToRunner, + )) { + const shardWrapped = shardAssignments.runnerToShard[runnerId]; + const [shardTag, shardValue] = getTaggedValue(shardWrapped); + if ( + shardTag !== "AsymmetricTensorShardMetadata" || + !shardValue || + typeof shardValue !== "object" + ) { + continue; + } + + const shard = shardValue as { + deviceRank?: number; + ratio?: number; + worldSize?: number; + }; + if (shard.worldSize !== 2 || typeof shard.ratio !== "number") { + continue; + } + + const share = shard.deviceRank === 0 ? 
shard.ratio : 1 - shard.ratio; + shares.set(nodeId, share); + } + } + + return shares; + } + function getNodeLabel(nodeId: string): string { const node = data?.nodes?.[nodeId]; return node?.friendly_name || nodeId.slice(0, 8); @@ -166,6 +225,7 @@ const nodes = data.nodes || {}; const edges = data.edges || []; const nodeIds = Object.keys(nodes); + const asymmetricModelShareByNode = getAsymmetricModelShareByNode(); const rect = svgContainer.getBoundingClientRect(); const width = rect.width; @@ -562,6 +622,7 @@ const isFilteredOut = filteredNodes.size > 0 && !filteredNodes.has(nodeInfo.id); const isHovered = hoveredNodeId === nodeInfo.id && !isInFilter; + const asymmetricShare = asymmetricModelShareByNode.get(nodeInfo.id); // Holographic wireframe colors - bright yellow for filter, subtle yellow for hover, grey for filtered out const wireColor = isInFilter @@ -1137,12 +1198,59 @@ .text(` (${ramUsagePercent.toFixed(0)}%)`); } + if (asymmetricShare !== undefined) { + const sharePercent = Math.round(asymmetricShare * 100); + const badgeY = + nodeInfo.y + + iconBaseHeight / 2 + + (showFullLabels ? 34 : showCompactLabels ? 24 : 22); + const badgeText = `${sharePercent}% MODEL`; + const badgeFontSize = showFullLabels ? 10 : showCompactLabels ? 7 : 7; + const badgeWidth = Math.max( + showFullLabels ? 70 : 52, + badgeText.length * badgeFontSize * 0.62 + 12, + ); + const badgeHeight = showFullLabels ? 
16 : 12; + + nodeG + .append("rect") + .attr("x", nodeInfo.x - badgeWidth / 2) + .attr("y", badgeY - badgeHeight / 2) + .attr("width", badgeWidth) + .attr("height", badgeHeight) + .attr("rx", badgeHeight / 2) + .attr("fill", "rgba(255,215,0,0.14)") + .attr("stroke", "rgba(255,215,0,0.55)") + .attr("stroke-width", 1); + + nodeG + .append("text") + .attr("x", nodeInfo.x) + .attr("y", badgeY + 0.5) + .attr("text-anchor", "middle") + .attr("dominant-baseline", "middle") + .attr("fill", "rgba(255,215,0,0.95)") + .attr("font-size", badgeFontSize) + .attr("font-weight", "700") + .attr("font-family", "SF Mono, Monaco, monospace") + .attr("letter-spacing", "0.04em") + .text(badgeText); + } + // Debug mode: Show TB bridge and RDMA status if (debugEnabled) { let debugLabelY = nodeInfo.y + iconBaseHeight / 2 + - (showFullLabels ? 32 : showCompactLabels ? 26 : 22); + (asymmetricShare !== undefined + ? showFullLabels + ? 52 + : 38 + : showFullLabels + ? 32 + : showCompactLabels + ? 26 + : 22); const debugFontSize = showFullLabels ? 9 : 7; const debugLineHeight = showFullLabels ? 11 : 9; From 201ade31def606eea685c9decef130e6e53143b4 Mon Sep 17 00:00:00 2001 From: jw-wcv <101585096+jw-wcv@users.noreply.github.com> Date: Tue, 28 Apr 2026 12:55:25 -0700 Subject: [PATCH 9/9] Preserve reachable asymmetric rank zero Keep asymmetric tensor rank 0 constrained to the largest socket-reachable node and avoid auto-upgrading tensor placements that cannot satisfy the two-node asymmetric constraints. 
--- src/exo/master/placement.py | 72 +++++++++-- src/exo/master/placement_utils.py | 15 +-- src/exo/master/tests/test_placement.py | 122 +++++++++++++++++- .../worker/engines/mlx/asymmetric_parallel.py | 1 - src/exo/worker/engines/mlx/utils_mlx.py | 6 +- .../test_mlx/test_asymmetric_parallel.py | 1 + 6 files changed, 190 insertions(+), 27 deletions(-) diff --git a/src/exo/master/placement.py b/src/exo/master/placement.py index bc5ed0a606..3337cd7626 100644 --- a/src/exo/master/placement.py +++ b/src/exo/master/placement.py @@ -149,8 +149,9 @@ def place_instance( if len(cycles_with_sufficient_memory) == 0: raise ValueError("No cycles found with sufficient memory") - if sharding == Sharding.AsymmetricTensor and not _supports_asymmetric_tensor_parallel( - command.model_card + if ( + sharding == Sharding.AsymmetricTensor + and not _supports_asymmetric_tensor_parallel(command.model_card) ): raise ValueError( f"Asymmetric tensor parallelism is not yet supported for " @@ -190,6 +191,8 @@ def place_instance( and _supports_asymmetric_tensor_parallel(command.model_card) ): for cycle in cycles_with_sufficient_memory: + if len(cycle) != 2: + continue equal_share = command.model_card.storage_size.in_bytes / len(cycle) min_node_mem = min( node_memory[nid].ram_available.in_bytes for nid in cycle @@ -212,9 +215,19 @@ def place_instance( cycles_with_sufficient_memory = [ cycle for cycle in cycles_with_sufficient_memory if len(cycle) == 2 ] + cycles_with_sufficient_memory = [ + cycle + for cycle in cycles_with_sufficient_memory + if _asymmetric_tensor_rank_zero_is_socket_reachable( + cycle=cycle, + node_memory=node_memory, + topology=topology, + ) + ] if not cycles_with_sufficient_memory: raise ValueError( - "Asymmetric tensor parallelism currently requires exactly 2 nodes" + "Asymmetric tensor parallelism currently requires exactly 2 nodes " + "with the largest-memory rank-0 node socket-reachable" ) if sharding == Sharding.Pipeline and command.model_card.model_id == ModelId( @@ 
-223,9 +236,8 @@ def place_instance( raise ValueError( "Pipeline parallelism is not supported for DeepSeek V3.1 (8-bit)" ) - if ( - sharding == Sharding.Pipeline - and command.model_card.base_model.startswith("Gemma 4") + if sharding == Sharding.Pipeline and command.model_card.base_model.startswith( + "Gemma 4" ): cycles_with_sufficient_memory = [ cycle for cycle in cycles_with_sufficient_memory if len(cycle) == 1 @@ -283,12 +295,10 @@ def _all_rdma_ctl_enabled(cycle: Cycle) -> bool: ) selected_cycle = _prefer_socket_reachable_rank_zero(selected_cycle, topology) if sharding == Sharding.AsymmetricTensor: - selected_cycle = Cycle( - node_ids=sorted( - selected_cycle.node_ids, - key=lambda node_id: node_memory[node_id].ram_available.in_bytes, - reverse=True, - ) + selected_cycle = _order_asymmetric_tensor_cycle( + cycle=selected_cycle, + node_memory=node_memory, + topology=topology, ) # Single-node: force Pipeline/Ring (Tensor and Jaccl require multi-node) @@ -383,6 +393,44 @@ def _prefer_socket_reachable_rank_zero(cycle: Cycle, topology: Topology) -> Cycl return Cycle(node_ids=cycle.node_ids[best_index:] + cycle.node_ids[:best_index]) +def _order_asymmetric_tensor_cycle( + cycle: Cycle, + node_memory: Mapping[NodeId, MemoryUsage], + topology: Topology, +) -> Cycle: + """Order an asymmetric TP cycle with the largest reachable node as rank 0.""" + ordered_cycle = Cycle( + node_ids=sorted( + cycle.node_ids, + key=lambda node_id: node_memory[node_id].ram_available.in_bytes, + reverse=True, + ) + ) + preferred_cycle = _prefer_socket_reachable_rank_zero(ordered_cycle, topology) + if preferred_cycle.node_ids[0] != ordered_cycle.node_ids[0]: + raise ValueError( + "Asymmetric tensor parallelism requires the largest-memory rank-0 " + "node to be socket-reachable" + ) + return ordered_cycle + + +def _asymmetric_tensor_rank_zero_is_socket_reachable( + cycle: Cycle, + node_memory: Mapping[NodeId, MemoryUsage], + topology: Topology, +) -> bool: + try: + 
_order_asymmetric_tensor_cycle( + cycle=cycle, + node_memory=node_memory, + topology=topology, + ) + except ValueError: + return False + return True + + def delete_instance( command: DeleteInstance, current_instances: Mapping[InstanceId, Instance], diff --git a/src/exo/master/placement_utils.py b/src/exo/master/placement_utils.py index 45f1bb983c..370e20fae9 100644 --- a/src/exo/master/placement_utils.py +++ b/src/exo/master/placement_utils.py @@ -287,19 +287,16 @@ def get_shard_assignments_for_asymmetric_tensor_parallel( total_layers = model_card.n_layers world_size = len(cycle) - sorted_nodes = sorted( - cycle, - key=lambda node_id: node_memory[node_id].ram_available.in_bytes, - reverse=True, - ) + ordered_nodes = list(cycle) - # Compute memory fractions with the largest-memory node fixed as rank 0. + # The placement layer orders the cycle so rank 0 is both the largest-memory + # node and socket-reachable for distributed initialization. total_available = sum( - node_memory[node_id].ram_available.in_bytes for node_id in sorted_nodes + node_memory[node_id].ram_available.in_bytes for node_id in ordered_nodes ) memory_fractions = [ node_memory[node_id].ram_available.in_bytes / total_available - for node_id in sorted_nodes + for node_id in ordered_nodes ] from exo.worker.engines.mlx.asymmetric_parallel import find_valid_ratios @@ -319,7 +316,7 @@ def get_shard_assignments_for_asymmetric_tensor_parallel( node_to_runner: dict[NodeId, RunnerId] = {} rank_zero_ratio = ratios[0] - for i, node_id in enumerate(sorted_nodes): + for i, node_id in enumerate(ordered_nodes): shard = AsymmetricTensorShardMetadata( model_card=model_card, device_rank=i, diff --git a/src/exo/master/tests/test_placement.py b/src/exo/master/tests/test_placement.py index 33097b6748..2856bf246c 100644 --- a/src/exo/master/tests/test_placement.py +++ b/src/exo/master/tests/test_placement.py @@ -562,9 +562,9 @@ def test_qwen3_5_tensor_auto_upgrade_requires_opt_in( node_rdma_ctl=node_rdma_ctl, ) 
instance_without_opt_in = next(iter(placements_without_opt_in.values())) - large_runner_without_opt_in = instance_without_opt_in.shard_assignments.node_to_runner[ - large_node - ] + large_runner_without_opt_in = ( + instance_without_opt_in.shard_assignments.node_to_runner[large_node] + ) large_shard_without_opt_in = ( instance_without_opt_in.shard_assignments.runner_to_shard[ large_runner_without_opt_in @@ -602,6 +602,122 @@ def test_qwen3_5_tensor_auto_upgrade_requires_opt_in( assert large_shard.ratio == small_shard.ratio == 0.75 +def test_qwen3_5_tensor_auto_upgrade_ignores_non_two_node_cycles( + monkeypatch: pytest.MonkeyPatch, +) -> None: + topology = Topology() + node_id_a = NodeId() + node_id_b = NodeId() + node_id_c = NodeId() + topology.add_node(node_id_a) + topology.add_node(node_id_b) + topology.add_node(node_id_c) + topology.add_connection( + Connection(source=node_id_a, sink=node_id_b, edge=create_socket_connection(1)) + ) + topology.add_connection( + Connection(source=node_id_b, sink=node_id_c, edge=create_socket_connection(2)) + ) + topology.add_connection( + Connection(source=node_id_c, sink=node_id_a, edge=create_socket_connection(3)) + ) + + model_card = ModelCard( + model_id=ModelId("mlx-community/Qwen3.5-72B-8bit"), + storage_size=Memory.from_bytes(140_000_000_000), + n_layers=48, + hidden_size=3072, + num_key_value_heads=6, + supports_tensor=True, + tasks=[ModelTask.TextGeneration], + family="qwen", + base_model="Qwen3.5 72B", + ) + command = PlaceInstance( + sharding=Sharding.Tensor, + instance_meta=InstanceMeta.MlxRing, + command_id=CommandId(), + model_card=model_card, + min_nodes=3, + ) + + monkeypatch.setenv("EXO_ENABLE_ASYMMETRIC_TP_AUTO_UPGRADE", "1") + + placements = place_instance( + command, + topology, + {}, + { + node_id_a: create_node_memory(128_000_000_000), + node_id_b: create_node_memory(128_000_000_000), + node_id_c: create_node_memory(48_000_000_000), + }, + { + node_id_a: create_node_network(), + node_id_b: 
create_node_network(), + node_id_c: create_node_network(), + }, + ) + + instance = next(iter(placements.values())) + assert len(instance.shard_assignments.node_to_runner) == 3 + assert all( + not isinstance(shard, AsymmetricTensorShardMetadata) + for shard in instance.shard_assignments.runner_to_shard.values() + ) + + +def test_asymmetric_tensor_rejects_unreachable_largest_rank_zero() -> None: + topology = Topology() + large_node = NodeId() + small_node = NodeId() + topology.add_node(large_node) + topology.add_node(small_node) + topology.add_connection( + Connection(source=large_node, sink=small_node, edge=create_rdma_connection(1)) + ) + topology.add_connection( + Connection(source=small_node, sink=large_node, edge=create_rdma_connection(2)) + ) + topology.add_connection( + Connection(source=large_node, sink=small_node, edge=create_socket_connection(3)) + ) + + model_card = ModelCard( + model_id=ModelId("mlx-community/Qwen3.5-72B-8bit"), + storage_size=Memory.from_bytes(130_648_036_320), + n_layers=48, + hidden_size=3072, + num_key_value_heads=8, + supports_tensor=True, + tasks=[ModelTask.TextGeneration], + family="qwen", + base_model="Qwen3.5 72B", + ) + command = PlaceInstance( + sharding=Sharding.AsymmetricTensor, + instance_meta=InstanceMeta.MlxRing, + command_id=CommandId(), + model_card=model_card, + min_nodes=2, + ) + + with pytest.raises(ValueError, match="rank-0 node socket-reachable"): + place_instance( + command, + topology, + {}, + { + large_node: create_node_memory(128_000_000_000), + small_node: create_node_memory(48_000_000_000), + }, + { + large_node: create_node_network(), + small_node: create_node_network(), + }, + ) + + def test_asymmetric_tensor_rejects_qwen3_5_with_unsplittable_kv_heads() -> None: topology = Topology() large_node = NodeId() diff --git a/src/exo/worker/engines/mlx/asymmetric_parallel.py b/src/exo/worker/engines/mlx/asymmetric_parallel.py index f8fea2be74..7f142955b5 100644 --- a/src/exo/worker/engines/mlx/asymmetric_parallel.py 
+++ b/src/exo/worker/engines/mlx/asymmetric_parallel.py @@ -243,7 +243,6 @@ def _shard_attention( attn._asymmetric_tp_group = group - class AsymmetricShardedMoE(nn.Module): def __init__(self, layer: SparseMoeBlock | Qwen3_5SparseMoeBlock): super().__init__() diff --git a/src/exo/worker/engines/mlx/utils_mlx.py b/src/exo/worker/engines/mlx/utils_mlx.py index ada5d9cc03..db9143ec32 100644 --- a/src/exo/worker/engines/mlx/utils_mlx.py +++ b/src/exo/worker/engines/mlx/utils_mlx.py @@ -177,7 +177,7 @@ def mlx_distributed_init( if group is None: raise RuntimeError("MLX distributed initialization did not return a group") - return group # pyright: ignore[reportPossiblyUnboundVariable] + return group def initialize_mlx( @@ -304,7 +304,9 @@ def shard_and_load( f"loading model from {model_path} with asymmetric tensor parallelism " f"(ratios={[f'{r:.0%}' for r in ratios_list]})" ) - model = yield from asymmetric_tensor_auto_parallel(model, group, ratios_list) + model = yield from asymmetric_tensor_auto_parallel( + model, group, ratios_list + ) case PipelineShardMetadata(): logger.info(f"loading model from {model_path} with pipeline parallelism") model = yield from pipeline_auto_parallel(model, group, shard_metadata) diff --git a/src/exo/worker/tests/unittests/test_mlx/test_asymmetric_parallel.py b/src/exo/worker/tests/unittests/test_mlx/test_asymmetric_parallel.py index 6736cddd04..bb3b89dd93 100644 --- a/src/exo/worker/tests/unittests/test_mlx/test_asymmetric_parallel.py +++ b/src/exo/worker/tests/unittests/test_mlx/test_asymmetric_parallel.py @@ -1,5 +1,6 @@ """Tests for asymmetric tensor parallelism ratio finding and sharding.""" + class TestFindValidRatios: """Test the ratio solver that finds valid asymmetric split points."""