Skip to content
Open
2 changes: 2 additions & 0 deletions dashboard/src/lib/components/ChatSidebar.svelte
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,8 @@
const [shardTag] = getTaggedValue(firstShardWrapped);
if (shardTag === "PipelineShardMetadata") sharding = "Pipeline";
else if (shardTag === "TensorShardMetadata") sharding = "Tensor";
else if (shardTag === "AsymmetricTensorShardMetadata")
sharding = "Asymmetric Tensor";
else if (shardTag === "PrefillDecodeShardMetadata")
sharding = "Prefill/Decode";
}
Expand Down
110 changes: 109 additions & 1 deletion dashboard/src/lib/components/TopologyGraph.svelte
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
topologyData,
isTopologyMinimized,
debugMode,
instances,
nodeThunderboltBridge,
nodeRdmaCtl,
nodeIdentities,
Expand All @@ -31,11 +32,69 @@

// Local views of the store accessors, each wrapped in Svelte's `$derived` rune.
// `instanceData` is the value iterated by getAsymmetricModelShareByNode().
const isMinimized = $derived(isTopologyMinimized());
const data = $derived(topologyData());
const instanceData = $derived(instances());
const debugEnabled = $derived(debugMode());
const tbBridgeData = $derived(nodeThunderboltBridge());
const rdmaCtlData = $derived(nodeRdmaCtl());
const identitiesData = $derived(nodeIdentities());

/**
 * Splits a single-key "tagged" wrapper object into its tag and payload.
 *
 * @param value - Candidate wrapper; any value may be passed.
 * @returns `[tag, payload]` when `value` is an object with exactly one own
 *   enumerable string key; `[null, value]` when it is an object with any other
 *   number of keys; `[null, null]` when it is falsy or not an object.
 */
function getTaggedValue(value: unknown): [string | null, unknown] {
  const isObjectLike = typeof value === "object" && Boolean(value);
  if (!isObjectLike) {
    return [null, null];
  }

  const fields = value as Record<string, unknown>;
  const [tag, ...otherKeys] = Object.keys(fields);

  // Zero keys or more than one key: not a tagged wrapper, hand the value back.
  if (tag === undefined || otherKeys.length > 0) {
    return [null, value];
  }
  return [tag, fields[tag]];
}

/**
 * Builds a lookup from node ID to the fraction of the model assigned to that
 * node by asymmetric tensor sharding.
 *
 * Walks every entry of `instanceData`, unwraps it with `getTaggedValue`, and
 * inspects each node's shard via `shardAssignments.nodeToRunner` and
 * `shardAssignments.runnerToShard`. Only shards tagged
 * `AsymmetricTensorShardMetadata` with `worldSize === 2` and a numeric `ratio`
 * contribute. Rank 0 receives `ratio`; any other rank receives `1 - ratio`.
 * If a node ID is seen more than once, the last value written wins.
 *
 * @returns Map of node ID to share; nodes without a qualifying shard are absent.
 */
function getAsymmetricModelShareByNode(): Map<string, number> {
  type ShardAssignments = {
    nodeToRunner?: Record<string, string>;
    runnerToShard?: Record<string, unknown>;
  };
  type AsymmetricShard = {
    deviceRank?: number;
    ratio?: number;
    worldSize?: number;
  };

  const shareByNode = new Map<string, number>();

  for (const wrappedInstance of Object.values(instanceData)) {
    const payload = getTaggedValue(wrappedInstance)[1];
    if (typeof payload !== "object" || !payload) continue;

    const assignments = (payload as { shardAssignments?: ShardAssignments })
      .shardAssignments;
    const nodeToRunner = assignments?.nodeToRunner;
    const runnerToShard = assignments?.runnerToShard;
    if (!(nodeToRunner && runnerToShard)) continue;

    for (const [nodeId, runnerId] of Object.entries(nodeToRunner)) {
      const [tag, rawShard] = getTaggedValue(runnerToShard[runnerId]);

      const isAsymmetricShard =
        tag === "AsymmetricTensorShardMetadata" &&
        Boolean(rawShard) &&
        typeof rawShard === "object";
      if (!isAsymmetricShard) continue;

      const { deviceRank, ratio, worldSize } = rawShard as AsymmetricShard;
      // Only the two-device split with a numeric ratio is handled.
      if (worldSize !== 2) continue;
      if (typeof ratio !== "number") continue;

      if (deviceRank === 0) {
        shareByNode.set(nodeId, ratio);
      } else {
        shareByNode.set(nodeId, 1 - ratio);
      }
    }
  }

  return shareByNode;
}

function getNodeLabel(nodeId: string): string {
const node = data?.nodes?.[nodeId];
return node?.friendly_name || nodeId.slice(0, 8);
Expand Down Expand Up @@ -166,6 +225,7 @@
const nodes = data.nodes || {};
const edges = data.edges || [];
const nodeIds = Object.keys(nodes);
const asymmetricModelShareByNode = getAsymmetricModelShareByNode();

const rect = svgContainer.getBoundingClientRect();
const width = rect.width;
Expand Down Expand Up @@ -562,6 +622,7 @@
const isFilteredOut =
filteredNodes.size > 0 && !filteredNodes.has(nodeInfo.id);
const isHovered = hoveredNodeId === nodeInfo.id && !isInFilter;
const asymmetricShare = asymmetricModelShareByNode.get(nodeInfo.id);

// Holographic wireframe colors - bright yellow for filter, subtle yellow for hover, grey for filtered out
const wireColor = isInFilter
Expand Down Expand Up @@ -1137,12 +1198,59 @@
.text(` (${ramUsagePercent.toFixed(0)}%)`);
}

// Asymmetric-share badge: drawn only when this node has an entry in the
// asymmetric model-share lookup (asymmetricShare is undefined otherwise).
if (asymmetricShare !== undefined) {
// The share is a fraction; display it as a whole-number percentage.
const sharePercent = Math.round(asymmetricShare * 100);
// Vertical centre of the badge: nodeInfo.y plus half the icon height plus an
// offset that depends on the current label mode.
const badgeY =
nodeInfo.y +
iconBaseHeight / 2 +
(showFullLabels ? 34 : showCompactLabels ? 24 : 22);
const badgeText = `${sharePercent}% MODEL`;
// Both non-full label modes use the same font size (7).
const badgeFontSize = showFullLabels ? 10 : showCompactLabels ? 7 : 7;
// Width is the larger of a per-mode minimum and an estimate derived from the
// text length. NOTE(review): 0.62 looks like an approximate glyph-width to
// font-size ratio and 12 like total horizontal padding — confirm.
const badgeWidth = Math.max(
showFullLabels ? 70 : 52,
badgeText.length * badgeFontSize * 0.62 + 12,
);
const badgeHeight = showFullLabels ? 16 : 12;

// Background rectangle centred on (nodeInfo.x, badgeY); rx of half the height
// rounds the ends.
nodeG
.append("rect")
.attr("x", nodeInfo.x - badgeWidth / 2)
.attr("y", badgeY - badgeHeight / 2)
.attr("width", badgeWidth)
.attr("height", badgeHeight)
.attr("rx", badgeHeight / 2)
.attr("fill", "rgba(255,215,0,0.14)")
.attr("stroke", "rgba(255,215,0,0.55)")
.attr("stroke-width", 1);

// Badge label, anchored at the horizontal centre; y is nudged by 0.5 relative
// to the rectangle's centre.
nodeG
.append("text")
.attr("x", nodeInfo.x)
.attr("y", badgeY + 0.5)
.attr("text-anchor", "middle")
.attr("dominant-baseline", "middle")
.attr("fill", "rgba(255,215,0,0.95)")
.attr("font-size", badgeFontSize)
.attr("font-weight", "700")
.attr("font-family", "SF Mono, Monaco, monospace")
.attr("letter-spacing", "0.04em")
.text(badgeText);
}

// Debug mode: Show TB bridge and RDMA status
if (debugEnabled) {
let debugLabelY =
nodeInfo.y +
iconBaseHeight / 2 +
(showFullLabels ? 32 : showCompactLabels ? 26 : 22);
(asymmetricShare !== undefined
? showFullLabels
? 52
: 38
: showFullLabels
? 32
: showCompactLabels
? 26
: 22);
const debugFontSize = showFullLabels ? 9 : 7;
const debugLineHeight = showFullLabels ? 11 : 9;

Expand Down
2 changes: 2 additions & 0 deletions dashboard/src/lib/stores/app.svelte.ts
Original file line number Diff line number Diff line change
Expand Up @@ -945,6 +945,8 @@ class AppStore {
const [shardTag] = this.getTaggedValue(firstShardWrapped);
if (shardTag === "PipelineShardMetadata") sharding = "Pipeline";
else if (shardTag === "TensorShardMetadata") sharding = "Tensor";
else if (shardTag === "AsymmetricTensorShardMetadata")
sharding = "Asymmetric Tensor";
else if (shardTag === "PrefillDecodeShardMetadata")
sharding = "Prefill/Decode";
}
Expand Down
11 changes: 9 additions & 2 deletions dashboard/src/routes/+page.svelte
Original file line number Diff line number Diff line change
Expand Up @@ -147,9 +147,14 @@
if (!rdmaCtl) return false;
const ids = tbIdentifiers;
if (!ids) return false;
// Find nodes with TB5 hardware (any TB interface)
// Find nodes with TB5 hardware (link speed > 40 Gb/s)
const tb5NodeIds = Object.entries(ids)
.filter(([_, node]) => node.interfaces.length > 0)
.filter(([_, node]) =>
node.interfaces.some((iface: { linkSpeed: string }) => {
const match = iface.linkSpeed.match(/(\d+)\s*Gb/i);
return match != null && parseInt(match[1], 10) > 40;
}),
)
.map(([id]) => id);
if (tb5NodeIds.length < 2) return false;
// At least one TB5 node has RDMA disabled
Expand Down Expand Up @@ -2079,6 +2084,8 @@
const [shardTag] = getTagged(firstShardWrapped);
if (shardTag === "PipelineShardMetadata") sharding = "Pipeline";
else if (shardTag === "TensorShardMetadata") sharding = "Tensor";
else if (shardTag === "AsymmetricTensorShardMetadata")
sharding = "Asymmetric Tensor";
else if (shardTag === "PrefillDecodeShardMetadata")
sharding = "Prefill/Decode";
}
Expand Down
33 changes: 27 additions & 6 deletions src/exo/api/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@
)
from exo.shared.types.worker.downloads import DownloadCompleted
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
from exo.shared.types.worker.shards import Sharding
from exo.shared.types.worker.shards import AsymmetricTensorShardMetadata, Sharding
from exo.utils.banner import print_startup_banner
from exo.utils.channels import Receiver, Sender, channel
from exo.utils.disk_event_log import DiskEventLog
Expand Down Expand Up @@ -587,11 +587,32 @@ async def get_placement_previews(
memory_delta_by_node: dict[str, int] = {}
if placement_node_ids:
total_bytes = model_card.storage_size.in_bytes
per_node = total_bytes // len(placement_node_ids)
remainder = total_bytes % len(placement_node_ids)
for index, node_id in enumerate(sorted(placement_node_ids, key=str)):
extra = 1 if index < remainder else 0
memory_delta_by_node[str(node_id)] = per_node + extra
asymmetric_shards: dict[NodeId, AsymmetricTensorShardMetadata] = {}
for (
node_id,
runner_id,
) in shard_assignments.node_to_runner.items():
shard_metadata = shard_assignments.runner_to_shard[runner_id]
if isinstance(shard_metadata, AsymmetricTensorShardMetadata):
asymmetric_shards[node_id] = shard_metadata
if asymmetric_shards:
for node_id, shard_metadata in asymmetric_shards.items():
rank_weight_fraction = (
shard_metadata.ratio
if shard_metadata.device_rank == 0
else 1.0 - shard_metadata.ratio
)
memory_delta_by_node[str(node_id)] = int(
total_bytes * rank_weight_fraction
)
else:
per_node = total_bytes // len(placement_node_ids)
remainder = total_bytes % len(placement_node_ids)
for index, node_id in enumerate(
sorted(placement_node_ids, key=str)
):
extra = 1 if index < remainder else 0
memory_delta_by_node[str(node_id)] = per_node + extra

if (
model_card.model_id,
Expand Down
Loading