[EP]:Add MCCL all-to-all fallback for MACA EP#2592
Conversation
There was a problem hiding this comment.
Code Review
This pull request introduces MACA support and a PyTorch-based fallback implementation of the alltoall operation, including new CUDA kernels for fused packing, compacting, and reducing. The review feedback focuses on critical optimizations and safety improvements in the CUDA kernels, such as utilizing shared memory to reduce redundant global memory reads, adding bounds checks to prevent out-of-bounds accesses, and guarding kernel launches against empty inputs to avoid invalid configuration errors. Additionally, it is recommended to make the PyTorch alltoall fallback implementation stateless to prevent potential race conditions when multiple MoE layers are interleaved.
Important
The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.
| int count = load_i32_words(rank_base + local_expert * fused_hidden); | ||
|
|
||
| int begin = 0; | ||
| for (int r = 0; r < src_rank; ++r) { | ||
| const nv_bfloat16* prev_rank_base = | ||
| recv_payload + r * fused_slots_per_rank * fused_hidden; | ||
| begin += load_i32_words(prev_rank_base + local_expert * fused_hidden); | ||
| } | ||
|
|
||
| if (m == 0 && threadIdx.x == 0) { | ||
| layout_range[local_expert * num_ranks + src_rank] = | ||
| (static_cast<int64_t>(begin) << 32) | static_cast<uint32_t>(count); | ||
| atomicAdd(recv_count + local_expert, count); | ||
| } | ||
|
|
||
| if (m >= count) return; | ||
|
|
||
| int src_begin = 0; | ||
| for (int e = 0; e < local_expert; ++e) | ||
| src_begin += load_i32_words(rank_base + e * fused_hidden); |
There was a problem hiding this comment.
In compact_dispatch_fused_kernel, the loop to compute begin and src_begin is executed redundantly by all 256 threads in the block. Since these values only depend on blockIdx.x (which is constant for all threads in the block), we can compute them once in thread 0 and share them via __shared__ memory. This significantly reduces redundant global memory reads and instruction overhead.
__shared__ int shared_count;
__shared__ int shared_begin;
__shared__ int shared_src_begin;
if (threadIdx.x == 0) {
int count = load_i32_words(rank_base + local_expert * fused_hidden);
shared_count = count;
int begin = 0;
for (int r = 0; r < src_rank; ++r) {
const nv_bfloat16* prev_rank_base =
recv_payload + r * fused_slots_per_rank * fused_hidden;
begin += load_i32_words(prev_rank_base + local_expert * fused_hidden);
}
shared_begin = begin;
int src_begin = 0;
for (int e = 0; e < local_expert; ++e)
src_begin += load_i32_words(rank_base + e * fused_hidden);
shared_src_begin = src_begin;
}
__syncthreads();
int count = shared_count;
int begin = shared_begin;
if (m == 0 && threadIdx.x == 0) {
layout_range[local_expert * num_ranks + src_rank] =
(static_cast<int64_t>(begin) << 32) | static_cast<uint32_t>(count);
atomicAdd(recv_count + local_expert, count);
}
if (m >= count) return;
int src_begin = shared_src_begin;
| int64_t expert = topk_idx[i]; | ||
| if (expert < 0) return; | ||
|
|
| count_dispatch_kernel<<<ceil_div(total, 256), 256, 0, stream>>>( | ||
| topk_idx.data_ptr<int64_t>(), counts_by_expert.data_ptr<int32_t>(), | ||
| num_tokens, num_topk, num_local_experts); |
There was a problem hiding this comment.
If total is 0 (e.g., empty batch or no tokens), launching count_dispatch_kernel with 0 blocks will result in a CUDA invalid configuration error. We should wrap the kernel launch in a check if (total > 0) to handle empty inputs gracefully.
if (total > 0) {
count_dispatch_kernel<<<ceil_div(total, 256), 256, 0, stream>>>(
topk_idx.data_ptr<int64_t>(), counts_by_expert.data_ptr<int32_t>(),
num_tokens, num_topk, num_local_experts);
}
| pack_dispatch_fused_kernel<<<total, 256, 0, stream>>>( | ||
| reinterpret_cast<const nv_bfloat16*>(x.data_ptr()), | ||
| topk_idx.data_ptr<int64_t>(), | ||
| expert_offsets.data_ptr<int32_t>(), counters.data_ptr<int32_t>(), | ||
| reinterpret_cast<nv_bfloat16*>(send_payload.data_ptr()), | ||
| send_route.data_ptr<int64_t>(), num_tokens, hidden, num_topk, | ||
| num_ranks, num_local_experts, max_messages_per_rank); |
There was a problem hiding this comment.
If total is 0 (e.g., empty batch or no tokens), launching pack_dispatch_fused_kernel with 0 blocks will result in a CUDA invalid configuration error. We should wrap the kernel launch in a check if (total > 0) to handle empty inputs gracefully.
if (total > 0) {
pack_dispatch_fused_kernel<<<total, 256, 0, stream>>>(
reinterpret_cast<const nv_bfloat16*>(x.data_ptr()),
topk_idx.data_ptr<int64_t>(),
expert_offsets.data_ptr<int32_t>(), counters.data_ptr<int32_t>(),
reinterpret_cast<nv_bfloat16*>(send_payload.data_ptr()),
send_route.data_ptr<int64_t>(), num_tokens, hidden, num_topk,
num_ranks, num_local_experts, max_messages_per_rank);
}
| compact_dispatch_fused_kernel<<<num_ranks * num_local_experts * | ||
| max_messages_per_rank, | ||
| 256, 0, stream>>>( | ||
| reinterpret_cast<const nv_bfloat16*>(recv_payload.data_ptr()), | ||
| layout_range.data_ptr<int64_t>(), recv_count.data_ptr<int32_t>(), | ||
| return_src_pos.data_ptr<int64_t>(), | ||
| reinterpret_cast<nv_bfloat16*>(packed_recv_x.data_ptr()), | ||
| packed_recv_src_info.data_ptr<int32_t>(), hidden, num_ranks, | ||
| num_local_experts, max_messages_per_rank, num_recv_slots); |
There was a problem hiding this comment.
If max_messages_per_rank is 0, launching compact_dispatch_fused_kernel with 0 blocks will result in a CUDA invalid configuration error. We should wrap the kernel launch in a check if (max_messages_per_rank > 0) to handle empty inputs gracefully.
if (max_messages_per_rank > 0) {
compact_dispatch_fused_kernel<<<num_ranks * num_local_experts *
max_messages_per_rank,
256, 0, stream>>>(
reinterpret_cast<const nv_bfloat16*>(recv_payload.data_ptr()),
layout_range.data_ptr<int64_t>(), recv_count.data_ptr<int32_t>(),
return_src_pos.data_ptr<int64_t>(),
reinterpret_cast<nv_bfloat16*>(packed_recv_x.data_ptr()),
packed_recv_src_info.data_ptr<int32_t>(), hidden, num_ranks,
num_local_experts, max_messages_per_rank, num_recv_slots);
}
| self._torch_alltoall_state = { | ||
| "send_route": send_route, | ||
| "return_src_pos": return_src_pos, | ||
| "num_tokens": num_tokens, | ||
| "num_topk": num_topk, | ||
| "max_messages_per_rank": max_messages_per_rank, | ||
| } |
There was a problem hiding this comment.
The current implementation of _torch_alltoall_routed_dispatch and _torch_alltoall_routed_combine uses a stateful dictionary self._torch_alltoall_state to store send_route and return_src_pos. This can lead to subtle bugs or race conditions if multiple MoE layers are executed concurrently or interleaved. We can make the implementation completely stateless by wrapping send_route and return_src_pos inside the src_info tuple returned by dispatch and passed to combine via handle.
0112980 to
187db4a
Compare
|
Codecov Report✅ All modified and coverable lines are covered by tests. 📢 Thoughts on this report? Let us know! |
Description
This PR adds an optional MCCL/torch
all_to_all_singlefallback path forMooncake Expert Parallelism on MACA. The path is enabled by
MOONCAKE_EP_USE_TORCH_ALLTOALL=1and avoids the current MACA P2P kernel pathwhen validating EP dispatch/combine throughput.
The implementation moves the EP data layout stages into C++/MACA helper
kernels:
The dispatch metadata is fused into the payload all-to-all buffer, so dispatch
uses one MCCL collective instead of separate payload and metadata collectives.
Performance on a 2-GPU MetaX C500 run with
tokens=64,hidden=7168,experts=16,topk=4:For comparison, the earlier helper-kernel version before metadata fusion was
about 446.55 us dispatch and 310.73 us combine, or about 757.28 us total. The
best validated fused path reduces dispatch+combine latency by about 16.7%.
Module
mooncake-transfer-engine)mooncake-store)mooncake-ep)mooncake-pg)mooncake-integration)mooncake-p2p-store)mooncake-wheel)mooncake-common)mooncake-rl)Type of Change
How Has This Been Tested?
Test commands:
Test results:
Manual test output:
Additional best validated run:
Checklist
./scripts/code_format.shpre-commit run --all-filesand all hooks passNote: this PR adds more than 500 LOC, mainly from the new MACA all-to-all helper
kernel file. No RFC issue has been filed yet; please file one or uncheck this
item before submitting if required by project policy.
AI Assistance Disclosure
AI tools were used to help inspect the code path, draft the helper-kernel
implementation, summarize benchmark results, and prepare this PR description.