diff --git a/dashboard/src/lib/components/ChatSidebar.svelte b/dashboard/src/lib/components/ChatSidebar.svelte index 1e8c975fb9..97b7af7745 100644 --- a/dashboard/src/lib/components/ChatSidebar.svelte +++ b/dashboard/src/lib/components/ChatSidebar.svelte @@ -213,6 +213,8 @@ const [shardTag] = getTaggedValue(firstShardWrapped); if (shardTag === "PipelineShardMetadata") sharding = "Pipeline"; else if (shardTag === "TensorShardMetadata") sharding = "Tensor"; + else if (shardTag === "AsymmetricTensorShardMetadata") + sharding = "Asymmetric Tensor"; else if (shardTag === "PrefillDecodeShardMetadata") sharding = "Prefill/Decode"; } diff --git a/dashboard/src/lib/components/TopologyGraph.svelte b/dashboard/src/lib/components/TopologyGraph.svelte index 0d0d3c3d08..4cf4d8a648 100644 --- a/dashboard/src/lib/components/TopologyGraph.svelte +++ b/dashboard/src/lib/components/TopologyGraph.svelte @@ -5,6 +5,7 @@ topologyData, isTopologyMinimized, debugMode, + instances, nodeThunderboltBridge, nodeRdmaCtl, nodeIdentities, @@ -31,11 +32,69 @@ const isMinimized = $derived(isTopologyMinimized()); const data = $derived(topologyData()); + const instanceData = $derived(instances()); const debugEnabled = $derived(debugMode()); const tbBridgeData = $derived(nodeThunderboltBridge()); const rdmaCtlData = $derived(nodeRdmaCtl()); const identitiesData = $derived(nodeIdentities()); + function getTaggedValue(value: unknown): [string | null, unknown] { + if (!value || typeof value !== "object") return [null, null]; + const record = value as Record; + const keys = Object.keys(record); + if (keys.length !== 1) return [null, value]; + return [keys[0], record[keys[0]]]; + } + + function getAsymmetricModelShareByNode(): Map { + const shares = new Map(); + + for (const instanceWrapped of Object.values(instanceData)) { + const [, instance] = getTaggedValue(instanceWrapped); + if (!instance || typeof instance !== "object") continue; + + const shardAssignments = ( + instance as { + shardAssignments?: { + nodeToRunner?: Record; + runnerToShard?: Record; + }; + } + ).shardAssignments; + if (!shardAssignments?.nodeToRunner || !shardAssignments.runnerToShard) { + continue; + } + + for (const [nodeId, runnerId] of Object.entries( + shardAssignments.nodeToRunner, + )) { + const shardWrapped = shardAssignments.runnerToShard[runnerId]; + const [shardTag, shardValue] = getTaggedValue(shardWrapped); + if ( + shardTag !== "AsymmetricTensorShardMetadata" || + !shardValue || + typeof shardValue !== "object" + ) { + continue; + } + + const shard = shardValue as { + deviceRank?: number; + ratio?: number; + worldSize?: number; + }; + if (shard.worldSize !== 2 || typeof shard.ratio !== "number") { + continue; + } + + const share = shard.deviceRank === 0 ? shard.ratio : 1 - shard.ratio; + shares.set(nodeId, share); + } + } + + return shares; + } + function getNodeLabel(nodeId: string): string { const node = data?.nodes?.[nodeId]; return node?.friendly_name || nodeId.slice(0, 8); @@ -166,6 +225,7 @@ const nodes = data.nodes || {}; const edges = data.edges || []; const nodeIds = Object.keys(nodes); + const asymmetricModelShareByNode = getAsymmetricModelShareByNode(); const rect = svgContainer.getBoundingClientRect(); const width = rect.width; @@ -562,6 +622,7 @@ const isFilteredOut = filteredNodes.size > 0 && !filteredNodes.has(nodeInfo.id); const isHovered = hoveredNodeId === nodeInfo.id && !isInFilter; + const asymmetricShare = asymmetricModelShareByNode.get(nodeInfo.id); // Holographic wireframe colors - bright yellow for filter, subtle yellow for hover, grey for filtered out const wireColor = isInFilter @@ -1137,12 +1198,59 @@ .text(` (${ramUsagePercent.toFixed(0)}%)`); } + if (asymmetricShare !== undefined) { + const sharePercent = Math.round(asymmetricShare * 100); + const badgeY = + nodeInfo.y + + iconBaseHeight / 2 + + (showFullLabels ? 34 : showCompactLabels ? 24 : 22); + const badgeText = `${sharePercent}% MODEL`; + const badgeFontSize = showFullLabels ? 10 : showCompactLabels ? 7 : 7; + const badgeWidth = Math.max( + showFullLabels ? 70 : 52, + badgeText.length * badgeFontSize * 0.62 + 12, + ); + const badgeHeight = showFullLabels ? 16 : 12; + + nodeG + .append("rect") + .attr("x", nodeInfo.x - badgeWidth / 2) + .attr("y", badgeY - badgeHeight / 2) + .attr("width", badgeWidth) + .attr("height", badgeHeight) + .attr("rx", badgeHeight / 2) + .attr("fill", "rgba(255,215,0,0.14)") + .attr("stroke", "rgba(255,215,0,0.55)") + .attr("stroke-width", 1); + + nodeG + .append("text") + .attr("x", nodeInfo.x) + .attr("y", badgeY + 0.5) + .attr("text-anchor", "middle") + .attr("dominant-baseline", "middle") + .attr("fill", "rgba(255,215,0,0.95)") + .attr("font-size", badgeFontSize) + .attr("font-weight", "700") + .attr("font-family", "SF Mono, Monaco, monospace") + .attr("letter-spacing", "0.04em") + .text(badgeText); + } + // Debug mode: Show TB bridge and RDMA status if (debugEnabled) { let debugLabelY = nodeInfo.y + iconBaseHeight / 2 + - (showFullLabels ? 32 : showCompactLabels ? 26 : 22); + (asymmetricShare !== undefined + ? showFullLabels + ? 52 + : 38 + : showFullLabels + ? 32 + : showCompactLabels + ? 26 + : 22); const debugFontSize = showFullLabels ? 9 : 7; const debugLineHeight = showFullLabels ? 11 : 9; diff --git a/dashboard/src/lib/stores/app.svelte.ts b/dashboard/src/lib/stores/app.svelte.ts index 5b28b3f020..e69cc43d29 100644 --- a/dashboard/src/lib/stores/app.svelte.ts +++ b/dashboard/src/lib/stores/app.svelte.ts @@ -945,6 +945,8 @@ class AppStore { const [shardTag] = this.getTaggedValue(firstShardWrapped); if (shardTag === "PipelineShardMetadata") sharding = "Pipeline"; else if (shardTag === "TensorShardMetadata") sharding = "Tensor"; + else if (shardTag === "AsymmetricTensorShardMetadata") + sharding = "Asymmetric Tensor"; else if (shardTag === "PrefillDecodeShardMetadata") sharding = "Prefill/Decode"; } diff --git a/dashboard/src/routes/+page.svelte b/dashboard/src/routes/+page.svelte index f69cdd1519..acc026c81f 100644 --- a/dashboard/src/routes/+page.svelte +++ b/dashboard/src/routes/+page.svelte @@ -147,9 +147,14 @@ if (!rdmaCtl) return false; const ids = tbIdentifiers; if (!ids) return false; - // Find nodes with TB5 hardware (any TB interface) + // Find nodes with TB5 hardware (link speed > 40 Gb/s) const tb5NodeIds = Object.entries(ids) - .filter(([_, node]) => node.interfaces.length > 0) + .filter(([_, node]) => + node.interfaces.some((iface: { linkSpeed: string }) => { + const match = iface.linkSpeed.match(/(\d+)\s*Gb/i); + return match != null && parseInt(match[1], 10) > 40; + }), + ) .map(([id]) => id); if (tb5NodeIds.length < 2) return false; // At least one TB5 node has RDMA disabled @@ -2079,6 +2084,8 @@ const [shardTag] = getTagged(firstShardWrapped); if (shardTag === "PipelineShardMetadata") sharding = "Pipeline"; else if (shardTag === "TensorShardMetadata") sharding = "Tensor"; + else if (shardTag === "AsymmetricTensorShardMetadata") + sharding = "Asymmetric Tensor"; else if (shardTag === "PrefillDecodeShardMetadata") sharding = "Prefill/Decode"; } diff --git a/src/exo/api/main.py b/src/exo/api/main.py index 4fb6d2d3b0..3afa18de20 100644 --- a/src/exo/api/main.py +++ b/src/exo/api/main.py @@ -195,7 +195,7 @@ ) from exo.shared.types.worker.downloads import DownloadCompleted from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta -from exo.shared.types.worker.shards import Sharding +from exo.shared.types.worker.shards import AsymmetricTensorShardMetadata, Sharding from exo.utils.banner import print_startup_banner from exo.utils.channels import Receiver, Sender, channel from exo.utils.disk_event_log import DiskEventLog @@ -587,11 +587,32 @@ async def get_placement_previews( memory_delta_by_node: dict[str, int] = {} if placement_node_ids: total_bytes = model_card.storage_size.in_bytes - per_node = total_bytes // len(placement_node_ids) - remainder = total_bytes % len(placement_node_ids) - for index, node_id in enumerate(sorted(placement_node_ids, key=str)): - extra = 1 if index < remainder else 0 - memory_delta_by_node[str(node_id)] = per_node + extra + asymmetric_shards: dict[NodeId, AsymmetricTensorShardMetadata] = {} + for ( + node_id, + runner_id, + ) in shard_assignments.node_to_runner.items(): + shard_metadata = shard_assignments.runner_to_shard[runner_id] + if isinstance(shard_metadata, AsymmetricTensorShardMetadata): + asymmetric_shards[node_id] = shard_metadata + if asymmetric_shards: + for node_id, shard_metadata in asymmetric_shards.items(): + rank_weight_fraction = ( + shard_metadata.ratio + if shard_metadata.device_rank == 0 + else 1.0 - shard_metadata.ratio + ) + memory_delta_by_node[str(node_id)] = int( + total_bytes * rank_weight_fraction + ) + else: + per_node = total_bytes // len(placement_node_ids) + remainder = total_bytes % len(placement_node_ids) + for index, node_id in enumerate( + sorted(placement_node_ids, key=str) + ): + extra = 1 if index < remainder else 0 + memory_delta_by_node[str(node_id)] = per_node + extra if ( model_card.model_id, diff --git a/src/exo/master/placement.py b/src/exo/master/placement.py index b55571c7f7..3337cd7626 100644 --- a/src/exo/master/placement.py +++ b/src/exo/master/placement.py @@ -1,7 +1,10 @@ from collections.abc import Mapping from copy import deepcopy +from os import environ from typing import Sequence +from loguru import logger + from exo.master.placement_utils import ( Cycle, filter_cycles_by_memory, @@ -11,7 +14,7 @@ get_shard_assignments, get_smallest_cycles, ) -from exo.shared.models.model_cards import ModelId +from exo.shared.models.model_cards import ModelCard, ModelId from exo.shared.topology import Topology from exo.shared.types.commands import ( CancelDownload, @@ -30,6 +33,7 @@ from exo.shared.types.memory import Memory from exo.shared.types.profiling import MemoryUsage, NodeNetworkInfo, NodeRdmaCtlStatus from exo.shared.types.tasks import Task, TaskId, TaskStatus +from exo.shared.types.topology import SocketConnection from exo.shared.types.worker.downloads import ( DownloadCompleted, DownloadFailed, @@ -47,6 +51,26 @@ from exo.shared.types.worker.shards import Sharding from exo.utils.ports import random_ephemeral_port +ASYMMETRIC_TENSOR_AUTO_UPGRADE_ENV = "EXO_ENABLE_ASYMMETRIC_TP_AUTO_UPGRADE" + + +def _supports_asymmetric_tensor_parallel(model_card: ModelCard) -> bool: + model_id = model_card.model_id.lower() + base_model = model_card.base_model.lower() + return ( + base_model.startswith("qwen3.5") + or "qwen3.5" in model_id + or "qwen-3.5" in model_id + ) + + +def _asymmetric_tensor_auto_upgrade_enabled() -> bool: + return environ.get(ASYMMETRIC_TENSOR_AUTO_UPGRADE_ENV, "").lower() in { + "1", + "true", + "yes", + } + def add_instance_to_placements( command: CreateInstance, @@ -63,7 +87,7 @@ def _get_node_download_fraction( model_id: ModelId, download_status: Mapping[NodeId, Sequence[DownloadProgress]], ) -> float: - """Return the download fraction (0.0–1.0) for a model on a given node.""" + """Return the download fraction (0.0-1.0) for a model on a given node.""" for progress in download_status.get(node_id, []): if progress.shard_metadata.model_card.model_id != model_id: continue @@ -107,6 +131,8 @@ def place_instance( download_status: Mapping[NodeId, Sequence[DownloadProgress]] | None = None, node_rdma_ctl: Mapping[NodeId, NodeRdmaCtlStatus] | None = None, ) -> dict[InstanceId, Instance]: + sharding = command.sharding + instance_meta = command.instance_meta cycles = topology.get_cycles() candidate_cycles = list(filter(lambda it: len(it) >= command.min_nodes, cycles)) @@ -123,39 +149,95 @@ def place_instance( if len(cycles_with_sufficient_memory) == 0: raise ValueError("No cycles found with sufficient memory") - if command.sharding == Sharding.Tensor: + if ( + sharding == Sharding.AsymmetricTensor + and not _supports_asymmetric_tensor_parallel(command.model_card) + ): + raise ValueError( + f"Asymmetric tensor parallelism is not yet supported for " + f"model '{command.model_card.model_id}'. Supported: Qwen3.5." + ) + + if sharding in (Sharding.Tensor, Sharding.AsymmetricTensor): if not command.model_card.supports_tensor: raise ValueError( f"Requested Tensor sharding but this model does not support tensor parallelism: {command.model_card.model_id}" ) - # TODO: the condition here for tensor parallel is not correct, but it works good enough for now. - # DeepSeek V4 is MQA (num_key_value_heads=1) but its sharding strategy - # head-parallelises wq_b/wo_a and shards MoE experts instead of splitting - # KV heads, so the kv-head divisibility check doesn't apply. - is_deepseek_v4 = command.model_card.base_model.startswith("DeepSeek V4") - kv_heads = command.model_card.num_key_value_heads + if sharding == Sharding.Tensor: + # TODO: the condition here for tensor parallel is not correct, but it works good enough for now. + # DeepSeek V4 is MQA (num_key_value_heads=1) but its sharding strategy + # head-parallelises wq_b/wo_a and shards MoE experts instead of splitting + # KV heads, so the kv-head divisibility check doesn't apply. + is_deepseek_v4 = command.model_card.base_model.startswith("DeepSeek V4") + kv_heads = command.model_card.num_key_value_heads + cycles_with_sufficient_memory = [ + cycle + for cycle in cycles_with_sufficient_memory + if command.model_card.hidden_size % len(cycle) == 0 + and (is_deepseek_v4 or kv_heads is None or kv_heads % len(cycle) == 0) + ] + if not cycles_with_sufficient_memory: + raise ValueError( + f"No tensor sharding found for model with " + f"hidden_size={command.model_card.hidden_size}" + f"{f', num_key_value_heads={kv_heads}' if kv_heads is not None else ''}" + f" across candidate cycles" + ) + + # Auto-upgrade to AsymmetricTensor when equal TP won't fit on + # the smallest node but asymmetric split would. + if ( + _asymmetric_tensor_auto_upgrade_enabled() + and _supports_asymmetric_tensor_parallel(command.model_card) + ): + for cycle in cycles_with_sufficient_memory: + if len(cycle) != 2: + continue + equal_share = command.model_card.storage_size.in_bytes / len(cycle) + min_node_mem = min( + node_memory[nid].ram_available.in_bytes for nid in cycle + ) + if equal_share > min_node_mem * 0.9: + # Equal split too tight; try asymmetric. + total_mem = sum( + node_memory[nid].ram_available.in_bytes for nid in cycle + ) + if command.model_card.storage_size.in_bytes < total_mem * 0.85: + logger.info( + "Equal tensor split won't fit on smallest node " + f"({min_node_mem / 1e9:.0f}GB available, " + f"needs {equal_share / 1e9:.0f}GB). " + "Auto-upgrading to AsymmetricTensor." + ) + sharding = Sharding.AsymmetricTensor + break + if sharding == Sharding.AsymmetricTensor: + cycles_with_sufficient_memory = [ + cycle for cycle in cycles_with_sufficient_memory if len(cycle) == 2 + ] cycles_with_sufficient_memory = [ cycle for cycle in cycles_with_sufficient_memory - if command.model_card.hidden_size % len(cycle) == 0 - and (is_deepseek_v4 or kv_heads is None or kv_heads % len(cycle) == 0) + if _asymmetric_tensor_rank_zero_is_socket_reachable( + cycle=cycle, + node_memory=node_memory, + topology=topology, + ) ] if not cycles_with_sufficient_memory: raise ValueError( - f"No tensor sharding found for model with " - f"hidden_size={command.model_card.hidden_size}" - f"{f', num_key_value_heads={kv_heads}' if kv_heads is not None else ''}" - f" across candidate cycles" + "Asymmetric tensor parallelism currently requires exactly 2 nodes " + "with the largest-memory rank-0 node socket-reachable" ) - if command.sharding == Sharding.Pipeline and command.model_card.model_id == ModelId( + + if sharding == Sharding.Pipeline and command.model_card.model_id == ModelId( "mlx-community/DeepSeek-V3.1-8bit" ): raise ValueError( "Pipeline parallelism is not supported for DeepSeek V3.1 (8-bit)" ) - if ( - command.sharding == Sharding.Pipeline - and command.model_card.base_model.startswith("Gemma 4") + if sharding == Sharding.Pipeline and command.model_card.base_model.startswith( + "Gemma 4" ): cycles_with_sufficient_memory = [ cycle for cycle in cycles_with_sufficient_memory if len(cycle) == 1 @@ -181,7 +263,7 @@ def _all_rdma_ctl_enabled(cycle: Cycle) -> bool: if topology.is_rdma_cycle(cycle) and _all_rdma_ctl_enabled(cycle) ] - if command.instance_meta == InstanceMeta.MlxJaccl: + if instance_meta == InstanceMeta.MlxJaccl: if not smallest_rdma_cycles: raise ValueError( "Requested RDMA (MlxJaccl) but no RDMA-connected cycles available" @@ -211,18 +293,21 @@ def _all_rdma_ctl_enabled(cycle: Cycle) -> bool: ), ), ) + selected_cycle = _prefer_socket_reachable_rank_zero(selected_cycle, topology) + if sharding == Sharding.AsymmetricTensor: + selected_cycle = _order_asymmetric_tensor_cycle( + cycle=selected_cycle, + node_memory=node_memory, + topology=topology, + ) # Single-node: force Pipeline/Ring (Tensor and Jaccl require multi-node) if len(selected_cycle) == 1: - command = command.model_copy( - update={ - "instance_meta": InstanceMeta.MlxRing, - "sharding": Sharding.Pipeline, - } - ) + instance_meta = InstanceMeta.MlxRing + sharding = Sharding.Pipeline shard_assignments = get_shard_assignments( - command.model_card, selected_cycle, command.sharding, node_memory + command.model_card, selected_cycle, sharding, node_memory ) cycle_digraph: Topology = topology.get_subgraph_from_nodes(selected_cycle.node_ids) @@ -230,7 +315,7 @@ def _all_rdma_ctl_enabled(cycle: Cycle) -> bool: instance_id = InstanceId() target_instances = dict(deepcopy(current_instances)) - match command.instance_meta: + match instance_meta: case InstanceMeta.MlxJaccl: # TODO(evan): shard assignments should contain information about ranks, this is ugly def get_device_rank(node_id: NodeId) -> int: @@ -281,6 +366,71 @@ def get_device_rank(node_id: NodeId) -> int: return target_instances +def _prefer_socket_reachable_rank_zero(cycle: Cycle, topology: Topology) -> Cycle: + """Rotate multi-node placements so rank 0 is easiest for peers to reach. + + MLX ring and JACCL both make rank 0 the listener/coordinator. Discovery can + produce RDMA-only edges in one direction and socket control-plane edges in + another, so putting a node with advertised inbound socket edges at rank 0 + avoids assigning the listener role to a machine peers cannot dial. + """ + if len(cycle) <= 1: + return cycle + + inbound_socket_edges: dict[NodeId, int] = {node_id: 0 for node_id in cycle} + for connection in topology.list_connections(): + if connection.sink not in inbound_socket_edges: + continue + if isinstance(connection.edge, SocketConnection): + inbound_socket_edges[connection.sink] += 1 + + best_index = max( + range(len(cycle.node_ids)), + key=lambda index: (inbound_socket_edges[cycle.node_ids[index]], -index), + ) + if best_index == 0: + return cycle + return Cycle(node_ids=cycle.node_ids[best_index:] + cycle.node_ids[:best_index]) + + +def _order_asymmetric_tensor_cycle( + cycle: Cycle, + node_memory: Mapping[NodeId, MemoryUsage], + topology: Topology, +) -> Cycle: + """Order an asymmetric TP cycle with the largest reachable node as rank 0.""" + ordered_cycle = Cycle( + node_ids=sorted( + cycle.node_ids, + key=lambda node_id: node_memory[node_id].ram_available.in_bytes, + reverse=True, + ) + ) + preferred_cycle = _prefer_socket_reachable_rank_zero(ordered_cycle, topology) + if preferred_cycle.node_ids[0] != ordered_cycle.node_ids[0]: + raise ValueError( + "Asymmetric tensor parallelism requires the largest-memory rank-0 " + "node to be socket-reachable" + ) + return ordered_cycle + + +def _asymmetric_tensor_rank_zero_is_socket_reachable( + cycle: Cycle, + node_memory: Mapping[NodeId, MemoryUsage], + topology: Topology, +) -> bool: + try: + _order_asymmetric_tensor_cycle( + cycle=cycle, + node_memory=node_memory, + topology=topology, + ) + except ValueError: + return False + return True + + def delete_instance( command: DeleteInstance, current_instances: Mapping[InstanceId, Instance], diff --git a/src/exo/master/placement_utils.py b/src/exo/master/placement_utils.py index 0375e97e01..370e20fae9 100644 --- a/src/exo/master/placement_utils.py +++ b/src/exo/master/placement_utils.py @@ -10,6 +10,7 @@ from exo.shared.types.topology import Cycle, RDMAConnection, SocketConnection from exo.shared.types.worker.runners import RunnerId, ShardAssignments from exo.shared.types.worker.shards import ( + AsymmetricTensorShardMetadata, CfgShardMetadata, PipelineShardMetadata, Sharding, @@ -273,6 +274,74 @@ def get_shard_assignments_for_tensor_parallel( return shard_assignments +def get_shard_assignments_for_asymmetric_tensor_parallel( + model_card: ModelCard, + cycle: Cycle, + node_memory: Mapping[NodeId, MemoryUsage], +) -> ShardAssignments: + """Create shard assignments for asymmetric tensor parallelism. + + Each node gets a ratio of weights proportional to its available memory. + All nodes compute every layer simultaneously. + """ + total_layers = model_card.n_layers + world_size = len(cycle) + + ordered_nodes = list(cycle) + + # The placement layer orders the cycle so rank 0 is both the largest-memory + # node and socket-reachable for distributed initialization. + total_available = sum( + node_memory[node_id].ram_available.in_bytes for node_id in ordered_nodes + ) + memory_fractions = [ + node_memory[node_id].ram_available.in_bytes / total_available + for node_id in ordered_nodes + ] + + from exo.worker.engines.mlx.asymmetric_parallel import find_valid_ratios + + ratios = find_valid_ratios( + memory_fractions=memory_fractions, + hidden_size=model_card.hidden_size, + num_attention_heads=model_card.hidden_size // 128, + num_key_value_heads=model_card.num_key_value_heads or 2, + ) + if ratios is None: + raise ValueError( + f"No valid asymmetric ratio found for hidden_size={model_card.hidden_size}" + ) + + runner_to_shard: dict[RunnerId, ShardMetadata] = {} + node_to_runner: dict[NodeId, RunnerId] = {} + rank_zero_ratio = ratios[0] + + for i, node_id in enumerate(ordered_nodes): + shard = AsymmetricTensorShardMetadata( + model_card=model_card, + device_rank=i, + world_size=world_size, + start_layer=0, + end_layer=total_layers, + n_layers=total_layers, + ratio=rank_zero_ratio, + ) + runner_id = RunnerId() + runner_to_shard[runner_id] = shard + node_to_runner[node_id] = runner_id + + logger.info( + f"Asymmetric TP: ratios={[f'{r:.0%}' for r in ratios]} " + f"across {world_size} nodes" + ) + + return ShardAssignments( + model_id=model_card.model_id, + runner_to_shard=runner_to_shard, + node_to_runner=node_to_runner, + ) + + def get_shard_assignments( model_card: ModelCard, cycle: Cycle, @@ -291,6 +360,12 @@ def get_shard_assignments( model_card=model_card, cycle=cycle, ) + case Sharding.AsymmetricTensor: + return get_shard_assignments_for_asymmetric_tensor_parallel( + model_card=model_card, + cycle=cycle, + node_memory=node_memory, + ) def get_mlx_jaccl_devices_matrix( @@ -348,17 +423,20 @@ def find_ip_prioritised( Priority: ethernet > wifi > unknown > thunderbolt """ ips = list(_find_connection_ip(node_id, other_node_id, cycle_digraph)) - if not ips: - return None other_network = node_network.get(other_node_id, NodeNetworkInfo()) ip_to_type = { iface.ip_address: iface.interface_type for iface in other_network.interfaces } + if not ips: + ips = _fallback_interface_ips(other_network) + if not ips: + return None + # Ring should prioritise fastest connection. As a best-effort, we prioritise TB. # TODO: Profile and get actual connection speeds. if ring: - priority = { + type_priority = { "thunderbolt": 0, "maybe_ethernet": 1, "ethernet": 2, @@ -368,14 +446,53 @@ def find_ip_prioritised( # RDMA prefers ethernet coordinator else: - priority = { + type_priority = { "ethernet": 0, - "wifi": 1, - "unknown": 2, - "maybe_ethernet": 3, + "maybe_ethernet": 1, + "wifi": 2, + "unknown": 3, "thunderbolt": 4, } - return min(ips, key=lambda ip: priority.get(ip_to_type.get(ip, "unknown"), 2)) + + return min( + ips, + key=lambda ip: ( + _address_priority(ip), + type_priority.get(ip_to_type.get(ip, "unknown"), 5), + ), + ) + + +def _fallback_interface_ips(node_network: NodeNetworkInfo) -> list[str]: + """Return advertised node IPs when topology only has non-socket edges.""" + return [ + iface.ip_address + for iface in node_network.interfaces + if _is_candidate_host_ip(iface.ip_address) + ] + + +def _is_candidate_host_ip(ip: str) -> bool: + if ":" in ip: + return False + return not (ip.startswith("127.") or ip == "0.0.0.0") + + +def _address_priority(ip: str) -> int: + if ip.startswith(("192.168.", "10.")): + return 0 + if ip.startswith("172."): + try: + second_octet = int(ip.split(".")[1]) + except (IndexError, ValueError): + return 3 + if 16 <= second_octet <= 31: + return 0 + if ip.startswith("100."): + return 2 + if ip.startswith("169.254."): + return 3 + return 1 def get_mlx_ring_hosts_by_node( diff --git a/src/exo/master/tests/test_placement.py b/src/exo/master/tests/test_placement.py index d3acd24f18..2856bf246c 100644 --- a/src/exo/master/tests/test_placement.py +++ b/src/exo/master/tests/test_placement.py @@ -47,7 +47,11 @@ MlxRingInstance, ) from exo.shared.types.worker.runners import ShardAssignments -from exo.shared.types.worker.shards import PipelineShardMetadata, Sharding +from exo.shared.types.worker.shards import ( + AsymmetricTensorShardMetadata, + PipelineShardMetadata, + Sharding, +) @pytest.fixture @@ -499,6 +503,307 @@ def test_tensor_rdma_backend_connectivity_matrix( assert len(ip_part.split(".")) == 4 +def test_qwen3_5_tensor_auto_upgrade_requires_opt_in( + monkeypatch: pytest.MonkeyPatch, +) -> None: + topology = Topology() + large_node = NodeId() + small_node = NodeId() + topology.add_node(large_node) + topology.add_node(small_node) + topology.add_connection( + Connection(source=large_node, sink=small_node, edge=create_rdma_connection(1)) + ) + topology.add_connection( + Connection(source=small_node, sink=large_node, edge=create_rdma_connection(2)) + ) + topology.add_connection( + Connection(source=large_node, sink=small_node, edge=create_socket_connection(1)) + ) + topology.add_connection( + Connection(source=small_node, sink=large_node, edge=create_socket_connection(2)) + ) + + model_card = ModelCard( + model_id=ModelId("mlx-community/Qwen3.5-72B-8bit"), + storage_size=Memory.from_bytes(130_648_036_320), + n_layers=48, + hidden_size=3072, + num_key_value_heads=8, + supports_tensor=True, + tasks=[ModelTask.TextGeneration], + family="qwen", + base_model="Qwen3.5 72B", + ) + command = PlaceInstance( + sharding=Sharding.Tensor, + instance_meta=InstanceMeta.MlxJaccl, + command_id=CommandId(), + model_card=model_card, + min_nodes=2, + ) + + node_rdma_ctl = { + large_node: NodeRdmaCtlStatus(enabled=True), + small_node: NodeRdmaCtlStatus(enabled=True), + } + placements_without_opt_in = place_instance( + command, + topology, + {}, + { + large_node: create_node_memory(128_000_000_000), + small_node: create_node_memory(48_000_000_000), + }, + { + large_node: create_node_network(), + small_node: create_node_network(), + }, + node_rdma_ctl=node_rdma_ctl, + ) + instance_without_opt_in = next(iter(placements_without_opt_in.values())) + large_runner_without_opt_in = ( + instance_without_opt_in.shard_assignments.node_to_runner[large_node] + ) + large_shard_without_opt_in = ( + instance_without_opt_in.shard_assignments.runner_to_shard[ + large_runner_without_opt_in + ] + ) + assert not isinstance(large_shard_without_opt_in, AsymmetricTensorShardMetadata) + + monkeypatch.setenv("EXO_ENABLE_ASYMMETRIC_TP_AUTO_UPGRADE", "1") + + placements = place_instance( + command, + topology, + {}, + { + large_node: create_node_memory(128_000_000_000), + small_node: create_node_memory(48_000_000_000), + }, + { + large_node: create_node_network(), + small_node: create_node_network(), + }, + node_rdma_ctl=node_rdma_ctl, + ) + + instance = next(iter(placements.values())) + large_runner = instance.shard_assignments.node_to_runner[large_node] + small_runner = instance.shard_assignments.node_to_runner[small_node] + large_shard = instance.shard_assignments.runner_to_shard[large_runner] + small_shard = instance.shard_assignments.runner_to_shard[small_runner] + + assert isinstance(large_shard, AsymmetricTensorShardMetadata) + assert isinstance(small_shard, AsymmetricTensorShardMetadata) + assert large_shard.device_rank == 0 + assert small_shard.device_rank == 1 + assert large_shard.ratio == small_shard.ratio == 0.75 + + +def test_qwen3_5_tensor_auto_upgrade_ignores_non_two_node_cycles( + monkeypatch: pytest.MonkeyPatch, +) -> None: + topology = Topology() + node_id_a = NodeId() + node_id_b = NodeId() + node_id_c = NodeId() + topology.add_node(node_id_a) + topology.add_node(node_id_b) + topology.add_node(node_id_c) + topology.add_connection( + Connection(source=node_id_a, sink=node_id_b, edge=create_socket_connection(1)) + ) + topology.add_connection( + Connection(source=node_id_b, sink=node_id_c, edge=create_socket_connection(2)) + ) + topology.add_connection( + Connection(source=node_id_c, sink=node_id_a, edge=create_socket_connection(3)) + ) + + model_card = ModelCard( + model_id=ModelId("mlx-community/Qwen3.5-72B-8bit"), + storage_size=Memory.from_bytes(140_000_000_000), + n_layers=48, + hidden_size=3072, + num_key_value_heads=6, + supports_tensor=True, + tasks=[ModelTask.TextGeneration], + family="qwen", + base_model="Qwen3.5 72B", + ) + command = PlaceInstance( + sharding=Sharding.Tensor, + instance_meta=InstanceMeta.MlxRing, + command_id=CommandId(), + model_card=model_card, + min_nodes=3, + ) + + monkeypatch.setenv("EXO_ENABLE_ASYMMETRIC_TP_AUTO_UPGRADE", "1") + + placements = place_instance( + command, + topology, + {}, + { + node_id_a: create_node_memory(128_000_000_000), + node_id_b: create_node_memory(128_000_000_000), + node_id_c: create_node_memory(48_000_000_000), + }, + { + node_id_a: create_node_network(), + node_id_b: create_node_network(), + node_id_c: create_node_network(), + }, + ) + + instance = next(iter(placements.values())) + assert len(instance.shard_assignments.node_to_runner) == 3 + assert all( + not isinstance(shard, AsymmetricTensorShardMetadata) + for shard in instance.shard_assignments.runner_to_shard.values() + ) + + +def test_asymmetric_tensor_rejects_unreachable_largest_rank_zero() -> None: + topology = Topology() + large_node = NodeId() + small_node = NodeId() + topology.add_node(large_node) + topology.add_node(small_node) + topology.add_connection( + Connection(source=large_node, sink=small_node, edge=create_rdma_connection(1)) + ) + topology.add_connection( + Connection(source=small_node, sink=large_node, edge=create_rdma_connection(2)) + ) + topology.add_connection( + Connection(source=large_node, sink=small_node, edge=create_socket_connection(3)) + ) + + model_card = ModelCard( + model_id=ModelId("mlx-community/Qwen3.5-72B-8bit"), + storage_size=Memory.from_bytes(130_648_036_320), + n_layers=48, + hidden_size=3072, + num_key_value_heads=8, + supports_tensor=True, + tasks=[ModelTask.TextGeneration], + family="qwen", + base_model="Qwen3.5 72B", + ) + command = PlaceInstance( + sharding=Sharding.AsymmetricTensor, + instance_meta=InstanceMeta.MlxRing, + command_id=CommandId(), + model_card=model_card, + min_nodes=2, + ) + + with pytest.raises(ValueError, match="rank-0 node socket-reachable"): + place_instance( + command, + topology, + {}, + { + large_node: create_node_memory(128_000_000_000), + small_node: create_node_memory(48_000_000_000), + }, + { + large_node: create_node_network(), + small_node: create_node_network(), + }, + ) + + +def test_asymmetric_tensor_rejects_qwen3_5_with_unsplittable_kv_heads() -> None: + topology = Topology() + large_node = NodeId() + small_node = NodeId() + topology.add_node(large_node) + topology.add_node(small_node) + topology.add_connection( + Connection(source=large_node, sink=small_node, edge=create_socket_connection(1)) + ) + topology.add_connection( + Connection(source=small_node, sink=large_node, edge=create_socket_connection(2)) + ) + + model_card = ModelCard( + model_id=ModelId("mlx-community/Qwen3.5-122B-A10B-8bit"), + storage_size=Memory.from_bytes(130_648_036_320), + n_layers=48, + hidden_size=3072, + num_key_value_heads=2, + supports_tensor=True, + tasks=[ModelTask.TextGeneration], + family="qwen", + base_model="Qwen3.5 122B A10B", + ) + command = PlaceInstance( + sharding=Sharding.AsymmetricTensor, + instance_meta=InstanceMeta.MlxRing, + command_id=CommandId(), + model_card=model_card, + min_nodes=2, + ) + + with pytest.raises(ValueError, match="No valid asymmetric ratio"): + place_instance( + command, + topology, + {}, + { + large_node: create_node_memory(128_000_000_000), + small_node: create_node_memory(48_000_000_000), + }, + { + large_node: create_node_network(), + small_node: create_node_network(), + }, + ) + + +def test_asymmetric_tensor_rejects_unsupported_model_family( + model_card: ModelCard, +) -> None: + topology = Topology() + node_id_a = NodeId() + node_id_b = NodeId() + topology.add_node(node_id_a) + topology.add_node(node_id_b) + topology.add_connection( + Connection(source=node_id_a, sink=node_id_b, edge=create_socket_connection(1)) + ) + topology.add_connection( + Connection(source=node_id_b, sink=node_id_a, edge=create_socket_connection(2)) + ) + command = PlaceInstance( + sharding=Sharding.AsymmetricTensor, + instance_meta=InstanceMeta.MlxRing, + command_id=CommandId(), + model_card=model_card, + min_nodes=2, + ) + + with pytest.raises(ValueError, match="Supported: Qwen3.5"): + place_instance( + command, + topology, + {}, + { + node_id_a: create_node_memory(2_000_000), + node_id_b: create_node_memory(2_000_000), + }, + { + node_id_a: create_node_network(), + node_id_b: create_node_network(), + }, + ) + + def _build_three_node_rdma_topology() -> tuple[ Topology, NodeId, NodeId, NodeId, dict[NodeId, NodeNetworkInfo] ]: @@ -624,6 +929,196 @@ def test_place_mlx_jaccl_rejects_when_node_rdma_ctl_missing(model_card: ModelCar ) +def test_ring_placement_uses_advertised_lan_ips_for_rdma_only_topology( + model_card: ModelCard, +) -> None: + topology = Topology() + model_card = model_card.model_copy( + update={ + "storage_size": Memory.from_bytes(1500), + "n_layers": 12, + } + ) + + node_a = NodeId() + node_b = NodeId() + + topology.add_node(node_a) + topology.add_node(node_b) + topology.add_connection( + Connection(source=node_a, sink=node_b, edge=create_rdma_connection(1)) + ) + topology.add_connection( + Connection(source=node_b, sink=node_a, edge=create_rdma_connection(2)) + ) + + node_memory = { + node_a: create_node_memory(1000), + node_b: create_node_memory(1000), + } + node_network = { + node_a: NodeNetworkInfo( + interfaces=[ + NetworkInterfaceInfo( + name="en9", ip_address="192.168.1.10", interface_type="ethernet" + ) + ] + ), + node_b: NodeNetworkInfo( + interfaces=[ + NetworkInterfaceInfo( + name="en9", ip_address="192.168.1.11", interface_type="ethernet" + ) + ] + ), + } + + command = place_instance_command(model_card) + command = command.model_copy(update={"min_nodes": 2}) + + placements = place_instance(command, topology, {}, node_memory, node_network) + + instance = list(placements.values())[0] + assert isinstance(instance, MlxRingInstance) + assert len(instance.shard_assignments.node_to_runner) == 2 + assert any(host.ip == "192.168.1.11" for host in instance.hosts_by_node[node_a]) + assert any(host.ip == "192.168.1.10" for host in instance.hosts_by_node[node_b]) + + +def test_jaccl_placement_uses_advertised_lan_ip_for_rdma_coordinator( + model_card: ModelCard, +) -> None: + topology = Topology() + model_card = model_card.model_copy( + update={ + "storage_size": Memory.from_bytes(1500), + "n_layers": 12, + "hidden_size": 32, + "num_key_value_heads": 8, + "supports_tensor": True, + } + ) + + node_a = NodeId() + node_b = NodeId() + + topology.add_node(node_a) + topology.add_node(node_b) + topology.add_connection( + Connection(source=node_a, sink=node_b, edge=create_rdma_connection(1)) + ) + topology.add_connection( + Connection(source=node_b, sink=node_a, edge=create_rdma_connection(2)) + ) + + node_memory = { + node_a: create_node_memory(1000), + node_b: create_node_memory(1000), + } + node_network = { + node_a: NodeNetworkInfo( + interfaces=[ + NetworkInterfaceInfo( + name="en9", ip_address="192.168.1.10", interface_type="ethernet" + ) + ] + ), + node_b: NodeNetworkInfo( + interfaces=[ + NetworkInterfaceInfo( + name="en9", ip_address="192.168.1.11", interface_type="ethernet" + ) + ] + ), + } + command = PlaceInstance( + sharding=Sharding.Tensor, + instance_meta=InstanceMeta.MlxJaccl, + command_id=CommandId(), + model_card=model_card, + min_nodes=2, + ) + + node_rdma_ctl = { + node_a: NodeRdmaCtlStatus(enabled=True), + node_b: NodeRdmaCtlStatus(enabled=True), + } + placements = place_instance( + command, + topology, + {}, + node_memory, + node_network, + node_rdma_ctl=node_rdma_ctl, + ) + + instance = list(placements.values())[0] + assert isinstance(instance, MlxJacclInstance) + assert len(instance.shard_assignments.node_to_runner) == 2 + assert any( + coordinator.startswith("192.168.1.") + for coordinator in instance.jaccl_coordinators.values() + ) + + +def test_placement_prefers_socket_reachable_rank_zero( + model_card: ModelCard, +) -> None: + topology = Topology() + model_card = model_card.model_copy( + update={ + "storage_size": Memory.from_bytes(1500), + "n_layers": 12, + } + ) + + listener = NodeId() + peer = NodeId() + + topology.add_node(listener) + topology.add_node(peer) + topology.add_connection( + Connection(source=listener, sink=peer, edge=create_rdma_connection(1)) + ) + topology.add_connection( + Connection(source=peer, sink=listener, edge=create_rdma_connection(2)) + ) + topology.add_connection( + Connection(source=peer, sink=listener, edge=create_socket_connection(10)) + ) + + node_memory = { + listener: create_node_memory(1000), + peer: create_node_memory(1000), + } + node_network = { + listener: NodeNetworkInfo( + interfaces=[ + NetworkInterfaceInfo( + name="en9", ip_address="192.168.1.10", interface_type="ethernet" + ) + ] + ), + peer: NodeNetworkInfo( + interfaces=[ + NetworkInterfaceInfo( + name="en9", ip_address="192.168.1.11", interface_type="ethernet" + ) + ] + ), + } + + command = place_instance_command(model_card) + command = command.model_copy(update={"min_nodes": 2}) + + placements = place_instance(command, topology, {}, node_memory, node_network) + + instance = list(placements.values())[0] + runner_id = instance.shard_assignments.node_to_runner[listener] + shard = instance.shard_assignments.runner_to_shard[runner_id] + assert shard.device_rank == 0 + + def _make_task( instance_id: InstanceId, status: TaskStatus = TaskStatus.Running, diff --git a/src/exo/shared/types/thunderbolt.py b/src/exo/shared/types/thunderbolt.py index 34cd1ccad9..d6b8d3742c 100644 --- a/src/exo/shared/types/thunderbolt.py +++ b/src/exo/shared/types/thunderbolt.py @@ -25,6 +25,28 @@ class _ReceptacleTag(BaseModel, extra="ignore"): class _ConnectivityItem(BaseModel, extra="ignore"): domain_uuid_key: str | None = None + items: list["_ConnectivityItem"] | None = Field(None, alias="_items") + + +def _first_descendant_domain_uuid(items: list[_ConnectivityItem]) -> str | None: + """Return the first ``domain_uuid_key`` found by depth-first search. + + Apple's ``system_profiler SPThunderboltDataType`` output places intermediate + Thunderbolt hubs/docks (e.g. an iVANKY Fusiondock Ultra) between the local + receptacle and the peer Mac. The hub appears as an ``_items`` entry without + a ``domain_uuid_key`` of its own; the peer Mac sits one level deeper. We + descend until we hit the first node that exposes a domain UUID, which is + always the actual peer endpoint regardless of how many transparent + switches sit between us. + """ + for item in items: + if item.domain_uuid_key is not None: + return item.domain_uuid_key + if item.items is not None: + descendant = _first_descendant_domain_uuid(item.items) + if descendant is not None: + return descendant + return None class ThunderboltConnectivityData(BaseModel, extra="ignore"): @@ -53,14 +75,7 @@ def conn(self) -> ThunderboltConnection | None: if self.domain_uuid_key is None or self.items is None: return - sink_key = next( - ( - item.domain_uuid_key - for item in self.items - if item.domain_uuid_key is not None - ), - None, - ) + sink_key = _first_descendant_domain_uuid(self.items) if sink_key is None: return None diff --git a/src/exo/shared/types/worker/shards.py b/src/exo/shared/types/worker/shards.py index 59a6c54eb0..112f6377a7 100644 --- a/src/exo/shared/types/worker/shards.py +++ b/src/exo/shared/types/worker/shards.py @@ -9,6 +9,7 @@ class Sharding(str, Enum): Tensor = "Tensor" + AsymmetricTensor = "AsymmetricTensor" Pipeline = "Pipeline" @@ -79,6 +80,34 @@ class TensorShardMetadata(BaseShardMetadata): pass +@final +class AsymmetricTensorShardMetadata(BaseShardMetadata): + """ + Asymmetric tensor parallelism shard metadata. + + Unlike standard tensor parallelism which splits weights 50/50 (or equally + across N nodes), asymmetric TP splits weights proportionally to each node's + available memory. This enables heterogeneous clusters (e.g. 128GB + 48GB) + to run models using tensor parallelism where equal splits wouldn't fit. + + Each node holds a different fraction of each weight tensor, but ALL nodes + compute every layer simultaneously. The all_sum reduction still works + correctly because (x_a @ W_a^T) + (x_b @ W_b^T) = x @ W^T regardless + of how W is partitioned. + """ + + ratio: float = Field( + ge=0.0, + le=1.0, + description="Split point for rank 0, shared across all ranks. " + "e.g. 0.75 means rank 0 gets the first 75% and rank 1 gets the last 25%. " + "Every rank stores the same value so all workers agree on the split.", + ) + + ShardMetadata: TypeAlias = ( - PipelineShardMetadata | CfgShardMetadata | TensorShardMetadata + PipelineShardMetadata + | CfgShardMetadata + | TensorShardMetadata + | AsymmetricTensorShardMetadata ) diff --git a/src/exo/utils/info_gatherer/tests/test_tb_parsing.py b/src/exo/utils/info_gatherer/tests/test_tb_parsing.py index 787dd3d5f9..d2551f0c6b 100644 --- a/src/exo/utils/info_gatherer/tests/test_tb_parsing.py +++ b/src/exo/utils/info_gatherer/tests/test_tb_parsing.py @@ -4,6 +4,7 @@ from exo.shared.types.thunderbolt import ( ThunderboltConnectivity, + ThunderboltConnectivityData, ) from exo.utils.info_gatherer.info_gatherer import ( _gather_iface_map, # pyright: ignore[reportPrivateUsage] @@ -22,3 +23,64 @@ async def test_tb_parsing(): for datum in data: datum.ident(ifaces) datum.conn() + + +def test_conn_resolves_peer_through_intermediate_hub() -> None: + """A TB hub between two Macs hides the peer one level deeper. + + Reproduces the iVANKY Fusiondock Ultra topology observed on the + wc-bmbp <-> wc-smbp link, where ``system_profiler`` reports the dock at + the first ``_items`` level and the peer Mac nested inside it. The parser + must walk past the dock and surface the peer's domain UUID, otherwise + half the RDMA mesh stays invisible to the placement engine. + """ + payload = { + "domain_uuid_key": "DCA2B6F5-1C58-4589-8DA8-90B9326462D6", + "receptacle_1_tag": { + "receptacle_id_key": "2", + "current_speed_key": "80 Gb/s", + }, + "_items": [ + { + "_name": "iVANKY Fusiondock Ultra", + "_items": [ + { + "_name": "MacBook Pro", + "domain_uuid_key": "F74D8F9B-DCDF-40D4-A428-3A3674BCB3F4", + } + ], + } + ], + } + datum = ThunderboltConnectivityData.model_validate(payload) + conn = datum.conn() + assert conn is not None + assert conn.source_uuid == "DCA2B6F5-1C58-4589-8DA8-90B9326462D6" + assert conn.sink_uuid == "F74D8F9B-DCDF-40D4-A428-3A3674BCB3F4" + + +def test_conn_returns_first_peer_for_direct_link() -> None: + """A direct cable still surfaces the peer at the first level.""" + payload = { + "domain_uuid_key": "EA94B959-A0C4-1111-1111-111111111111", + "_items": [ + { + "_name": "MacBook Pro", + "domain_uuid_key": "D02B9C20-7504-2222-2222-222222222222", + } + ], + } + datum = ThunderboltConnectivityData.model_validate(payload) + conn = datum.conn() + assert conn is not None + assert conn.sink_uuid == "D02B9C20-7504-2222-2222-222222222222" + + +def test_conn_returns_none_when_no_peer_present() -> None: + """Empty ``_items`` (e.g. unconnected receptacle) yields no edge.""" + payload: dict[str, object] = { + "domain_uuid_key": "AAA", + "_items": [], + } + datum = ThunderboltConnectivityData.model_validate(payload) + assert datum.conn() is None diff --git a/src/exo/worker/engines/mlx/asymmetric_parallel.py b/src/exo/worker/engines/mlx/asymmetric_parallel.py new file mode 100644 index 0000000000..7f142955b5 --- /dev/null +++ b/src/exo/worker/engines/mlx/asymmetric_parallel.py @@ -0,0 +1,375 @@ +""" +Asymmetric Tensor Parallelism for heterogeneous clusters. + +When nodes have different amounts of RAM, standard 50/50 tensor parallelism +fails because the smaller node can't hold half the weights. Asymmetric TP +splits each weight tensor proportionally to available memory (e.g. 75/25) +so both nodes compute every layer simultaneously. + +Mathematical correctness: + Column parallel: y = x @ [W_a; W_b]^T = [x @ W_a^T, x @ W_b^T] + Row parallel: y = x_a @ W_a^T + x_b @ W_b^T = x @ W^T (via all_sum) + Both hold regardless of the split ratio. + +Usage: + asymmetric_tensor_auto_parallel(model, group, ratios=[0.75, 0.25]) +""" +# pyright: reportAny=false, reportUnknownMemberType=false, reportUnknownArgumentType=false, reportUnknownVariableType=false + +from __future__ import annotations + +from collections.abc import Generator +from typing import Any + +import mlx.core as mx +import mlx.nn as nn +from mlx.nn.layers.distributed import sum_gradients +from mlx_lm.models.qwen3_5 import DecoderLayer as Qwen3_5DecoderLayer +from mlx_lm.models.qwen3_5 import GatedDeltaNet +from mlx_lm.models.qwen3_5 import SparseMoeBlock as Qwen3_5SparseMoeBlock +from mlx_lm.models.qwen3_next import Qwen3NextAttention as Attention +from mlx_lm.models.qwen3_next import Qwen3NextMLP as Qwen3NextMLP +from mlx_lm.models.qwen3_next import Qwen3NextSparseMoeBlock as SparseMoeBlock + +from exo.shared.types.worker.runner_response import ModelLoadingResponse + +try: + from exo.shared.logging import logger +except ImportError: + import logging + + logger = logging.getLogger(__name__) + + +def find_valid_ratios( + memory_fractions: list[float], + hidden_size: int, + num_attention_heads: int, + num_key_value_heads: int, + num_experts: int = 0, + moe_intermediate_size: int = 0, + linear_num_value_heads: int = 0, + linear_num_key_heads: int = 0, + quantization_group_size: int = 64, +) -> list[float] | None: + """ + Find valid split ratios for asymmetric TP given model dimensions and memory fractions. + + A valid ratio must produce integer dimensions for all split tensors, + and all split dimensions must be divisible by the quantization group size. + + Returns a list of ratios (one per node) that sum to 1.0, or None if no valid + ratio exists. Currently supports 2 nodes only. + """ + if len(memory_fractions) != 2: + logger.warning("Asymmetric TP currently only supports 2 nodes") + return None + + # Key dimensions that must split cleanly + key_dims = [ + num_attention_heads, + num_key_value_heads, + hidden_size, + ] + if linear_num_value_heads > 0: + key_dims.extend([linear_num_value_heads, linear_num_key_heads]) + if num_experts > 0 and moe_intermediate_size > 0: + key_dims.append(moe_intermediate_size) + + target_ratio = memory_fractions[0] + + # Try ratios of the form n/d where d is a power of 2 or common denominator + # that produces clean splits. Test denominators 2..32. + best_ratio = None + best_distance = float("inf") + + for denom in [2, 4, 8, 16, 32]: + for numer in range(1, denom): + ratio = numer / denom + if ratio <= 0.5 or ratio > 0.95: + continue + + # Check all dimensions split cleanly + valid = True + for dim in key_dims: + # dim * ratio must be EXACTLY integer (for head counts) + exact = dim * ratio + if exact != int(exact): + valid = False + break + a = int(exact) + b = dim - a + if a <= 0 or b <= 0: + valid = False + break + # For quantized weights, split dims must be divisible by 8 + if dim > quantization_group_size and (a % 8 != 0 or b % 8 != 0): + valid = False + break + + if valid: + distance = abs(ratio - target_ratio) + if distance < best_distance: + best_distance = distance + best_ratio = ratio + + if best_ratio is None: + return None + + return [best_ratio, 1.0 - best_ratio] + + +def _split_at(tensor: mx.array, axis: int, ratio: float) -> tuple[mx.array, mx.array]: + """Split tensor at ratio point along axis.""" + sp = int(tensor.shape[axis] * ratio) + parts = mx.split(tensor, [sp], axis=axis) + return mx.contiguous(parts[0]), mx.contiguous(parts[1]) + + +def _my_shard(tensor: mx.array, axis: int, rank: int, ratio: float) -> mx.array: + """Get rank's portion of an asymmetric split.""" + parts = _split_at(tensor, axis, ratio) + return parts[0] if rank == 0 else parts[1] + + +def _shard_quantized_ats( + layer: Any, + axis: int, + rank: int, + ratio: float, + segments: list[int] | None = None, +) -> None: + """Shard quantized linear all-to-sharded (output dim split).""" + if segments is not None: + w: mx.array = layer.weight + seg_parts_w = mx.split(w, segments, axis=axis) + my_w_parts = [_my_shard(p, axis, rank, ratio) for p in seg_parts_w] + layer.weight = mx.contiguous(mx.concatenate(my_w_parts, axis=axis)) + for attr in ["scales", "biases"]: + t: mx.array | None = getattr(layer, attr, None) + if t is None: + continue + t_seg = [int(s * t.shape[axis] / w.shape[axis]) for s in segments] + t_parts = mx.split(t, t_seg, axis=axis) + my_parts = [_my_shard(p, axis, rank, ratio) for p in t_parts] + setattr(layer, attr, mx.contiguous(mx.concatenate(my_parts, axis=axis))) + else: + for attr in ["weight", "scales", "biases"]: + t_val: mx.array | None = getattr(layer, attr, None) + if t_val is None: + continue + setattr(layer, attr, _my_shard(t_val, axis, rank, ratio)) + + +def _shard_quantized_sta(layer: Any, rank: int, ratio: float) -> None: + """Shard quantized linear sharded-to-all (input dim, axis -1).""" + for attr in ["weight", "scales", "biases"]: + t: mx.array | None = getattr(layer, attr, None) + if t is None: + continue + setattr(layer, attr, _my_shard(t, -1, rank, ratio)) + + +def _shard_gated_delta_net( + gdn: GatedDeltaNet, rank: int, ratio: float, group: mx.distributed.Group +) -> None: + """Asymmetric shard for GatedDeltaNet (linear attention) layers.""" + kd = gdn.key_dim + _shard_quantized_ats(gdn.in_proj_qkv, 0, rank, ratio, segments=[kd, 2 * kd]) + _shard_quantized_ats(gdn.in_proj_z, 0, rank, ratio) + _shard_quantized_ats(gdn.in_proj_b, 0, rank, ratio) + _shard_quantized_ats(gdn.in_proj_a, 0, rank, ratio) + _shard_quantized_sta(gdn.out_proj, rank, ratio) + + # conv1d: segmented split along channel dim + conv_w = gdn.conv1d.weight + seg_parts = mx.split(conv_w, [kd, 2 * kd], axis=0) + my_parts = [_my_shard(p, 0, rank, ratio) for p in seg_parts] + gdn.conv1d.weight = mx.contiguous(mx.concatenate(my_parts, axis=0)) + + gdn.dt_bias = _my_shard(gdn.dt_bias, 0, rank, ratio) + gdn.A_log = _my_shard(gdn.A_log, 0, rank, ratio) + + r = ratio if rank == 0 else (1 - ratio) + gdn.num_k_heads = int(gdn.num_k_heads * r) + gdn.num_v_heads = int(gdn.num_v_heads * r) + gdn.key_dim = int(gdn.key_dim * r) + gdn.value_dim = int(gdn.value_dim * r) + gdn.conv_dim = int(gdn.conv_dim * r) + gdn.conv1d.groups = gdn.conv_dim + gdn.sharding_group = group + + +# Patching must happen at the class level since nn.Module.__call__ ignores instance overrides +_attention_class_patched: set[type] = set() + + +def _patch_attention_class(attn_cls: type) -> None: + """Patch an attention class to add all_sum when _asymmetric_tp_group is set.""" + if attn_cls in _attention_class_patched: + return + + original_call = attn_cls.__call__ + + def patched_call( + self: nn.Module, + x: mx.array, + mask: mx.array | None = None, + cache: object | None = None, + ) -> mx.array: + result = original_call(self, x, mask=mask, cache=cache) + grp = getattr(self, "_asymmetric_tp_group", None) + if grp is not None: + result = mx.distributed.all_sum(result, group=grp) + return result + + attn_cls.__call__ = patched_call + _attention_class_patched.add(attn_cls) + + +def _shard_attention( + attn: Attention, rank: int, ratio: float, group: mx.distributed.Group +) -> None: + """Asymmetric shard for self-attention layers.""" + _patch_attention_class(type(attn)) + _shard_quantized_ats(attn.q_proj, 0, rank, ratio) + _shard_quantized_ats(attn.k_proj, 0, rank, ratio) + _shard_quantized_ats(attn.v_proj, 0, rank, ratio) + _shard_quantized_sta(attn.o_proj, rank, ratio) + + r = ratio if rank == 0 else (1 - ratio) + attn.num_attention_heads = int(attn.num_attention_heads * r) + attn.num_key_value_heads = int(attn.num_key_value_heads * r) + attn._asymmetric_tp_group = group + + +class AsymmetricShardedMoE(nn.Module): + def __init__(self, layer: SparseMoeBlock | Qwen3_5SparseMoeBlock): + super().__init__() + self.original_layer = layer + self.sharding_group: mx.distributed.Group | None = None + + def __call__(self, x: mx.array) -> mx.array: + if self.sharding_group is not None: + x = sum_gradients(self.sharding_group)(x) + y = self.original_layer(x) + if self.sharding_group is not None: + y = mx.distributed.all_sum(y, group=self.sharding_group) + return y + + +def _shard_sparse_moe( + moe: SparseMoeBlock | Qwen3_5SparseMoeBlock, + rank: int, + ratio: float, + group: mx.distributed.Group, +) -> AsymmetricShardedMoE: + """Asymmetric shard for SparseMoeBlock (MoE layers).""" + # switch_mlp: split expert intermediate dims (axis 1 for 3D expert weights) + _shard_quantized_ats(moe.switch_mlp.gate_proj, 1, rank, ratio) + _shard_quantized_ats(moe.switch_mlp.up_proj, 1, rank, ratio) + _shard_quantized_sta(moe.switch_mlp.down_proj, rank, ratio) + + # shared_expert: standard MLP split + _shard_quantized_ats(moe.shared_expert.gate_proj, 0, rank, ratio) + _shard_quantized_ats(moe.shared_expert.up_proj, 0, rank, ratio) + _shard_quantized_sta(moe.shared_expert.down_proj, rank, ratio) + + sharded_moe = AsymmetricShardedMoE(moe) + sharded_moe.sharding_group = group + return sharded_moe + + +_mlp_class_patched: set[type] = set() + + +def _patch_mlp_class(mlp_cls: type) -> None: + """Patch a dense MLP class to add all_sum when _asymmetric_tp_group is set.""" + if mlp_cls in _mlp_class_patched: + return + + original_call = mlp_cls.__call__ + + def patched_call(self: nn.Module, x: mx.array) -> mx.array: + result = original_call(self, x) + grp = getattr(self, "_asymmetric_tp_group", None) + if grp is not None: + result = mx.distributed.all_sum(result, group=grp) + return result + + mlp_cls.__call__ = patched_call + _mlp_class_patched.add(mlp_cls) + + +def _shard_dense_mlp( + mlp: Qwen3NextMLP, rank: int, ratio: float, group: mx.distributed.Group +) -> None: + """Asymmetric shard for dense (non-MoE) MLP layers.""" + _patch_mlp_class(type(mlp)) + _shard_quantized_ats(mlp.gate_proj, 0, rank, ratio) + _shard_quantized_ats(mlp.up_proj, 0, rank, ratio) + _shard_quantized_sta(mlp.down_proj, rank, ratio) + mlp._asymmetric_tp_group = group + + +def asymmetric_tensor_auto_parallel( + model: nn.Module, + group: mx.distributed.Group, + ratios: list[float], +) -> Generator[ModelLoadingResponse, None, nn.Module]: + """ + Apply asymmetric tensor parallelism to a model. + + Args: + model: The model to parallelize (must have .layers property) + group: MLX distributed group + ratios: Per-rank weight fractions, e.g. [0.75, 0.25] for 2 nodes. + ratios[group.rank()] is this node's fraction. + + Returns: + The model with asymmetric sharding applied. + """ + rank = group.rank() + ratio = ratios[0] # ratio for rank 0; rank 1 gets 1-ratio + + # Get the inner model's layers + inner = model + for attr in ["language_model", "model"]: + candidate = getattr(inner, attr, None) + if candidate is not None and hasattr(candidate, "layers"): + inner = candidate + + layers: list[Any] = inner.layers if hasattr(inner, "layers") else model.layers + + total = len(layers) + for layer_index, layer in enumerate(layers): + if isinstance(layer, Qwen3_5DecoderLayer): + # Qwen3.5 hybrid: linear_attn or self_attn per layer + if layer.is_linear: + _shard_gated_delta_net(layer.linear_attn, rank, ratio, group) + else: + _shard_attention(layer.self_attn, rank, ratio, group) + + mlp = layer.mlp + if isinstance(mlp, (SparseMoeBlock, Qwen3_5SparseMoeBlock)): + dict.__setitem__( + layer, + "mlp", + _shard_sparse_moe(mlp, rank, ratio, group), + ) + else: + _shard_dense_mlp(mlp, rank, ratio, group) + else: + raise ValueError( + f"Asymmetric TP does not yet support layer type {type(layer).__name__}. " + f"Currently supported: Qwen3.5 (GatedDeltaNet + Attention + MoE). " + f"Contributions for other architectures welcome." + ) + mx.eval(layer) + yield ModelLoadingResponse(layers_loaded=layer_index, total=total) + + logger.info( + f"Asymmetric TP applied: rank {rank} gets " + f"{ratios[rank] * 100:.0f}% of each weight tensor" + ) + return model diff --git a/src/exo/worker/engines/mlx/utils_mlx.py b/src/exo/worker/engines/mlx/utils_mlx.py index 1dddad2ae1..db9143ec32 100644 --- a/src/exo/worker/engines/mlx/utils_mlx.py +++ b/src/exo/worker/engines/mlx/utils_mlx.py @@ -53,11 +53,15 @@ ) from exo.shared.types.worker.runner_response import ModelLoadingResponse from exo.shared.types.worker.shards import ( + AsymmetricTensorShardMetadata, CfgShardMetadata, PipelineShardMetadata, ShardMetadata, TensorShardMetadata, ) +from exo.worker.engines.mlx.asymmetric_parallel import ( + asymmetric_tensor_auto_parallel, +) from exo.worker.engines.mlx.auto_parallel import ( get_inner_model, get_layers, @@ -69,6 +73,19 @@ def get_weights_size(model_shard_meta: ShardMetadata) -> Memory: + if isinstance(model_shard_meta, AsymmetricTensorShardMetadata): + rank_weight_fraction = ( + model_shard_meta.ratio + if model_shard_meta.device_rank == 0 + else 1.0 - model_shard_meta.ratio + ) + return Memory.from_float_kb( + (model_shard_meta.end_layer - model_shard_meta.start_layer) + / model_shard_meta.n_layers + * model_shard_meta.model_card.storage_size.in_kb + * rank_weight_fraction + ) + return Memory.from_float_kb( (model_shard_meta.end_layer - model_shard_meta.start_layer) / model_shard_meta.n_layers @@ -100,6 +117,7 @@ def mlx_distributed_init( coordination_file = str( Path(tmpdir) / f"hosts_{bound_instance.instance.instance_id}_{rank}.json" ) + group: mx.distributed.Group | None = None # TODO: singleton instances match bound_instance.instance: case MlxRingInstance(hosts_by_node=hosts_by_node, ephemeral_port=_): @@ -125,7 +143,6 @@ def mlx_distributed_init( assert all( jaccl_devices[i][i] is None for i in range(len(jaccl_devices)) ) - # Use RDMA connectivity matrix jaccl_devices_json = json.dumps(jaccl_devices) with open(coordination_file, "w") as f: @@ -140,9 +157,25 @@ def mlx_distributed_init( os.environ["MLX_IBV_DEVICES"] = coordination_file os.environ["MLX_RANK"] = str(rank) os.environ["MLX_JACCL_COORDINATOR"] = jaccl_coordinator - group = mx.distributed.init(backend="jaccl", strict=True) + + max_jaccl_attempts = 8 + for attempt in range(1, max_jaccl_attempts + 1): + try: + group = mx.distributed.init(backend="jaccl", strict=True) + break + except (RuntimeError, ValueError) as exc: + if attempt == max_jaccl_attempts: + raise + backoff = min(2.0 * attempt, 10.0) + logger.warning( + f"rank {rank} JACCL init attempt {attempt}/{max_jaccl_attempts} " + f"failed ({exc}), retrying in {backoff:.0f}s" + ) + time.sleep(backoff) logger.info(f"Rank {rank} mlx distributed initialization complete") + if group is None: + raise RuntimeError("MLX distributed initialization did not return a group") return group @@ -264,6 +297,16 @@ def shard_and_load( case TensorShardMetadata(): logger.info(f"loading model from {model_path} with tensor parallelism") model = yield from tensor_auto_parallel(model, group) + case AsymmetricTensorShardMetadata(): + rank_zero_ratio = shard_metadata.ratio + ratios_list = [rank_zero_ratio, 1.0 - rank_zero_ratio] + logger.info( + f"loading model from {model_path} with asymmetric tensor parallelism " + f"(ratios={[f'{r:.0%}' for r in ratios_list]})" + ) + model = yield from asymmetric_tensor_auto_parallel( + model, group, ratios_list + ) case PipelineShardMetadata(): logger.info(f"loading model from {model_path} with pipeline parallelism") model = yield from pipeline_auto_parallel(model, group, shard_metadata) diff --git a/src/exo/worker/tests/unittests/test_mlx/test_asymmetric_parallel.py b/src/exo/worker/tests/unittests/test_mlx/test_asymmetric_parallel.py new file mode 100644 index 0000000000..bb3b89dd93 --- /dev/null +++ b/src/exo/worker/tests/unittests/test_mlx/test_asymmetric_parallel.py @@ -0,0 +1,120 @@ +"""Tests for asymmetric tensor parallelism ratio finding and sharding.""" + + +class TestFindValidRatios: + """Test the ratio solver that finds valid asymmetric split points.""" + + def test_qwen3_5_full_attention_dimensions_with_divisible_kv_heads(self) -> None: + from exo.worker.engines.mlx.asymmetric_parallel import find_valid_ratios + + ratios = find_valid_ratios( + memory_fractions=[0.73, 0.27], + hidden_size=3072, + num_attention_heads=32, + num_key_value_heads=8, + linear_num_value_heads=64, + linear_num_key_heads=16, + moe_intermediate_size=1024, + num_experts=256, + ) + assert ratios is not None + assert len(ratios) == 2 + assert abs(ratios[0] + ratios[1] - 1.0) < 1e-10 + # All head counts must be exact integers after split + assert 32 * ratios[0] == int(32 * ratios[0]) # attention heads + assert 8 * ratios[0] == int(8 * ratios[0]) # KV heads + assert 64 * ratios[0] == int(64 * ratios[0]) # value heads + assert 16 * ratios[0] == int(16 * ratios[0]) # key heads + + def test_rejects_qwen3_5_122b_two_kv_heads_for_asymmetric_attention(self) -> None: + from exo.worker.engines.mlx.asymmetric_parallel import find_valid_ratios + + ratios = find_valid_ratios( + memory_fractions=[0.73, 0.27], + hidden_size=3072, + num_attention_heads=32, + num_key_value_heads=2, + linear_num_value_heads=64, + linear_num_key_heads=16, + moe_intermediate_size=1024, + num_experts=256, + ) + + assert ratios is None + + def test_llama_70b_dimensions(self) -> None: + from exo.worker.engines.mlx.asymmetric_parallel import find_valid_ratios + + ratios = find_valid_ratios( + memory_fractions=[0.73, 0.27], + hidden_size=8192, + num_attention_heads=64, + num_key_value_heads=8, + ) + assert ratios is not None + assert 64 * ratios[0] == int(64 * ratios[0]) + + def test_nemotron_120b_dimensions(self) -> None: + from exo.worker.engines.mlx.asymmetric_parallel import find_valid_ratios + + ratios = find_valid_ratios( + memory_fractions=[0.73, 0.27], + hidden_size=4096, + num_attention_heads=32, + num_key_value_heads=8, + ) + assert ratios is not None + assert 32 * ratios[0] == int(32 * ratios[0]) + + def test_rejects_impossible_dimensions(self) -> None: + """Prime-number head count with no valid fractional split.""" + from exo.worker.engines.mlx.asymmetric_parallel import find_valid_ratios + + ratios = find_valid_ratios( + memory_fractions=[0.73, 0.27], + hidden_size=3072, + num_attention_heads=7, # prime: cannot split into 2 integer parts > 0.5 + num_key_value_heads=2, + ) + assert ratios is None + + def test_only_two_nodes_supported(self) -> None: + from exo.worker.engines.mlx.asymmetric_parallel import find_valid_ratios + + ratios = find_valid_ratios( + memory_fractions=[0.5, 0.25, 0.25], + hidden_size=4096, + num_attention_heads=32, + num_key_value_heads=8, + ) + assert ratios is None + + def test_ratio_closer_to_target(self) -> None: + """Ratio should be the closest valid one to the memory fraction.""" + from exo.worker.engines.mlx.asymmetric_parallel import find_valid_ratios + + # With 80% target, 0.8125 (13/16) is closer than 0.75 (12/16) + ratios = find_valid_ratios( + memory_fractions=[0.80, 0.20], + hidden_size=3072, + num_attention_heads=32, + num_key_value_heads=16, + linear_num_value_heads=64, + linear_num_key_heads=16, + ) + assert ratios is not None + assert abs(ratios[0] - 0.80) < abs(0.75 - 0.80) + + def test_equal_memory_returns_near_symmetric(self) -> None: + """When memory is roughly equal, ratio should be close to 0.5.""" + from exo.worker.engines.mlx.asymmetric_parallel import find_valid_ratios + + ratios = find_valid_ratios( + memory_fractions=[0.50, 0.50], + hidden_size=4096, + num_attention_heads=32, + num_key_value_heads=8, + ) + # Finder searches > 0.5, so it may find a near-symmetric split + if ratios is not None: + assert ratios[0] < 0.7 # should be close to 0.5