Skip to content

[EP]:Add MCCL all-to-all fallback for MACA EP#2592

Open
Dayuxiaoshui wants to merge 2 commits into
kvcache-ai:mainfrom
Dayuxiaoshui:ep-support-maca
Open

[EP]:Add MCCL all-to-all fallback for MACA EP#2592
Dayuxiaoshui wants to merge 2 commits into
kvcache-ai:mainfrom
Dayuxiaoshui:ep-support-maca

Conversation

@Dayuxiaoshui

@Dayuxiaoshui Dayuxiaoshui commented Jun 24, 2026

Copy link
Copy Markdown
Collaborator

Description

This PR adds an optional MCCL/torch all_to_all_single fallback path for
Mooncake Expert Parallelism on MACA. The path is enabled by
MOONCAKE_EP_USE_TORCH_ALLTOALL=1 and avoids the current MACA P2P kernel path
when validating EP dispatch/combine throughput.

The implementation moves the EP data layout stages into C++/MACA helper
kernels:

  • dispatch pack into fused per-rank payload buffers;
  • receive-side compact into per-expert buffers;
  • combine-side pack back to per-rank payload buffers;
  • combine weighted reduce into the final token output.

The dispatch metadata is fused into the payload all-to-all buffer, so dispatch
uses one MCCL collective instead of separate payload and metadata collectives.

Performance on a 2-GPU MetaX C500 run with tokens=64, hidden=7168,
experts=16, topk=4:

Run Dispatch Combine Dispatch + Combine
Best validated run 363.49 us, 10.12 GB/s 267.17 us, 13.74 GB/s 630.66 us
Post-cleanup quick regression 394.73 us, 9.32 GB/s 279.85 us, 13.11 GB/s 674.58 us

For comparison, the earlier helper-kernel version before metadata fusion was
about 446.55 us dispatch and 310.73 us combine, or about 757.28 us total. The
best validated fused path reduces dispatch+combine latency by about 16.7%.

Module

  • Transfer Engine (mooncake-transfer-engine)
  • Mooncake Store (mooncake-store)
  • Mooncake EP (mooncake-ep)
  • Mooncake PG (mooncake-pg)
  • Integration (mooncake-integration)
  • P2P Store (mooncake-p2p-store)
  • Python Wheel (mooncake-wheel)
  • Common (mooncake-common)
  • Mooncake RL (mooncake-rl)
  • CI/CD
  • Docs
  • Other

Type of Change

  • Bug fix
  • New feature
  • Refactor
  • Breaking change
  • Documentation update
  • Performance improvement
  • Other

How Has This Been Tested?

Test commands:

export PYTHONUNBUFFERED=1
export PYTHONPATH=/home/zhouyuhan/Mooncake/mooncake-wheel:$PYTHONPATH
export LD_LIBRARY_PATH=/home/zhouyuhan/Mooncake/build-metax-ep-ninja/mooncake-common:/home/zhouyuhan/Mooncake/.deps/sysroot/usr/lib/x86_64-linux-gnu:/opt/miniconda3/envs/py310/lib/python3.10/site-packages/torch/lib:/opt/miniconda3/envs/py310/lib:/opt/maca/lib:/opt/maca/mxgpu_llvm/lib:/opt/maca/ompi/lib:$LD_LIBRARY_PATH
export MCCL_P2P_LEVEL=SYS
export MCCL_P2P_DISABLE=0
export MOONCAKE_EP_USE_TORCH_ALLTOALL=1

/opt/miniconda3/envs/py310/bin/python scripts/metax/smoke_ep_torch_alltoall.py \
  --world-size 2 \
  --tokens 32 \
  --hidden 1024 \
  --experts 16 \
  --topk 2 \
  --port 29801

/opt/miniconda3/envs/py310/bin/python scripts/metax/bench_ep_torch_alltoall_split.py \
  --world-size 2 \
  --tokens 64 \
  --hidden 7168 \
  --experts 16 \
  --topk 4 \
  --warmups 3 \
  --iters 5 \
  --port 29802

Test results:

  • Unit tests pass
  • Integration tests pass (if applicable)
  • Manual testing done (describe below)

Manual test output:

smoke: passed

t64/topk=4 quick regression:
dispatch.total: avg=394.73 us min=380.19 us max=410.72 us bw=9.32 GB/s
combine.total: avg=279.85 us min=262.73 us max=324.06 us bw=13.11 GB/s

Additional best validated run:

dispatch.total: avg=363.49 us bw=10.12 GB/s
combine.total: avg=267.17 us bw=13.74 GB/s

Checklist

  • I have performed a self-review of my own code
  • I have formatted my code using ./scripts/code_format.sh
  • I have run pre-commit run --all-files and all hooks pass
  • I have updated the documentation (if applicable)
  • I have added tests to prove my changes are effective
  • For changes >500 LOC: I have filed an RFC issue

Note: this PR adds more than 500 LOC, mainly from the new MACA all-to-all helper
kernel file. No RFC issue has been filed yet; please file one or uncheck this
item before submitting if required by project policy.

AI Assistance Disclosure

  • No AI tools were used
  • AI tools were used (specify below)

AI tools were used to help inspect the code path, draft the helper-kernel
implementation, summarize benchmark results, and prepare this PR description.

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces MACA support and a PyTorch-based fallback implementation of the alltoall operation, including new CUDA kernels for fused packing, compacting, and reducing. The review feedback focuses on critical optimizations and safety improvements in the CUDA kernels, such as utilizing shared memory to reduce redundant global memory reads, adding bounds checks to prevent out-of-bounds accesses, and guarding kernel launches against empty inputs to avoid invalid configuration errors. Additionally, it is recommended to make the PyTorch alltoall fallback implementation stateless to prevent potential race conditions when multiple MoE layers are interleaved.

Important

The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.

Comment thread mooncake-ep/src/mooncake_ep_alltoall.cu Outdated
Comment on lines +155 to +174
int count = load_i32_words(rank_base + local_expert * fused_hidden);

int begin = 0;
for (int r = 0; r < src_rank; ++r) {
const nv_bfloat16* prev_rank_base =
recv_payload + r * fused_slots_per_rank * fused_hidden;
begin += load_i32_words(prev_rank_base + local_expert * fused_hidden);
}

if (m == 0 && threadIdx.x == 0) {
layout_range[local_expert * num_ranks + src_rank] =
(static_cast<int64_t>(begin) << 32) | static_cast<uint32_t>(count);
atomicAdd(recv_count + local_expert, count);
}

if (m >= count) return;

int src_begin = 0;
for (int e = 0; e < local_expert; ++e)
src_begin += load_i32_words(rank_base + e * fused_hidden);

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

In compact_dispatch_fused_kernel, the loop to compute begin and src_begin is executed redundantly by all 256 threads in the block. Since these values only depend on blockIdx.x (which is constant for all threads in the block), we can compute them once in thread 0 and share them via __shared__ memory. This significantly reduces redundant global memory reads and instruction overhead.

    __shared__ int shared_count;
    __shared__ int shared_begin;
    __shared__ int shared_src_begin;
    if (threadIdx.x == 0) {
        int count = load_i32_words(rank_base + local_expert * fused_hidden);
        shared_count = count;

        int begin = 0;
        for (int r = 0; r < src_rank; ++r) {
            const nv_bfloat16* prev_rank_base = 
                recv_payload + r * fused_slots_per_rank * fused_hidden;
            begin += load_i32_words(prev_rank_base + local_expert * fused_hidden);
        }
        shared_begin = begin;

        int src_begin = 0;
        for (int e = 0; e < local_expert; ++e) 
            src_begin += load_i32_words(rank_base + e * fused_hidden);
        shared_src_begin = src_begin;
    }
    __syncthreads();
    int count = shared_count;
    int begin = shared_begin;

    if (m == 0 && threadIdx.x == 0) {
        layout_range[local_expert * num_ranks + src_rank] =
            (static_cast<int64_t>(begin) << 32) | static_cast<uint32_t>(count);
        atomicAdd(recv_count + local_expert, count);
    }

    if (m >= count) return;
    int src_begin = shared_src_begin;

Comment thread mooncake-ep/src/mooncake_ep_alltoall.cu Outdated
Comment on lines +79 to +81
int64_t expert = topk_idx[i];
if (expert < 0) return;

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Add a bounds check for expert to prevent potential out-of-bounds memory access on counters and expert_offsets if expert >= num_ranks * num_local_experts.

    int64_t expert = topk_idx[i];
    if (expert < 0 || expert >= num_ranks * num_local_experts) return;

Comment thread mooncake-ep/src/mooncake_ep_alltoall.cu Outdated
Comment on lines +283 to +285
count_dispatch_kernel<<<ceil_div(total, 256), 256, 0, stream>>>(
topk_idx.data_ptr<int64_t>(), counts_by_expert.data_ptr<int32_t>(),
num_tokens, num_topk, num_local_experts);

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

If total is 0 (e.g., empty batch or no tokens), launching count_dispatch_kernel with 0 blocks will result in a CUDA invalid configuration error. We should wrap the kernel launch in a check if (total > 0) to handle empty inputs gracefully.

    if (total > 0) {
        count_dispatch_kernel<<<ceil_div(total, 256), 256, 0, stream>>>(
            topk_idx.data_ptr<int64_t>(), counts_by_expert.data_ptr<int32_t>(),
            num_tokens, num_topk, num_local_experts);
    }

Comment thread mooncake-ep/src/mooncake_ep_alltoall.cu Outdated
Comment on lines +298 to +304
pack_dispatch_fused_kernel<<<total, 256, 0, stream>>>(
reinterpret_cast<const nv_bfloat16*>(x.data_ptr()),
topk_idx.data_ptr<int64_t>(),
expert_offsets.data_ptr<int32_t>(), counters.data_ptr<int32_t>(),
reinterpret_cast<nv_bfloat16*>(send_payload.data_ptr()),
send_route.data_ptr<int64_t>(), num_tokens, hidden, num_topk,
num_ranks, num_local_experts, max_messages_per_rank);

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

If total is 0 (e.g., empty batch or no tokens), launching pack_dispatch_fused_kernel with 0 blocks will result in a CUDA invalid configuration error. We should wrap the kernel launch in a check if (total > 0) to handle empty inputs gracefully.

    if (total > 0) {
        pack_dispatch_fused_kernel<<<total, 256, 0, stream>>>(
            reinterpret_cast<const nv_bfloat16*>(x.data_ptr()),
            topk_idx.data_ptr<int64_t>(),
            expert_offsets.data_ptr<int32_t>(), counters.data_ptr<int32_t>(),
            reinterpret_cast<nv_bfloat16*>(send_payload.data_ptr()),
            send_route.data_ptr<int64_t>(), num_tokens, hidden, num_topk,
            num_ranks, num_local_experts, max_messages_per_rank);
    }

Comment thread mooncake-ep/src/mooncake_ep_alltoall.cu Outdated
Comment on lines +351 to +359
compact_dispatch_fused_kernel<<<num_ranks * num_local_experts *
max_messages_per_rank,
256, 0, stream>>>(
reinterpret_cast<const nv_bfloat16*>(recv_payload.data_ptr()),
layout_range.data_ptr<int64_t>(), recv_count.data_ptr<int32_t>(),
return_src_pos.data_ptr<int64_t>(),
reinterpret_cast<nv_bfloat16*>(packed_recv_x.data_ptr()),
packed_recv_src_info.data_ptr<int32_t>(), hidden, num_ranks,
num_local_experts, max_messages_per_rank, num_recv_slots);

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

If max_messages_per_rank is 0, launching compact_dispatch_fused_kernel with 0 blocks will result in a CUDA invalid configuration error. We should wrap the kernel launch in a check if (max_messages_per_rank > 0) to handle empty inputs gracefully.

    if (max_messages_per_rank > 0) {
        compact_dispatch_fused_kernel<<<num_ranks * num_local_experts *
                                            max_messages_per_rank,
                                        256, 0, stream>>>(
            reinterpret_cast<const nv_bfloat16*>(recv_payload.data_ptr()),
            layout_range.data_ptr<int64_t>(), recv_count.data_ptr<int32_t>(),
            return_src_pos.data_ptr<int64_t>(),
            reinterpret_cast<nv_bfloat16*>(packed_recv_x.data_ptr()),
            packed_recv_src_info.data_ptr<int32_t>(), hidden, num_ranks,
            num_local_experts, max_messages_per_rank, num_recv_slots);
    }

Comment on lines +669 to +675
self._torch_alltoall_state = {
"send_route": send_route,
"return_src_pos": return_src_pos,
"num_tokens": num_tokens,
"num_topk": num_topk,
"max_messages_per_rank": max_messages_per_rank,
}

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The current implementation of _torch_alltoall_routed_dispatch and _torch_alltoall_routed_combine uses a stateful dictionary self._torch_alltoall_state to store send_route and return_src_pos. This can lead to subtle bugs or race conditions if multiple MoE layers are executed concurrently or interleaved. We can make the implementation completely stateless by wrapping send_route and return_src_pos inside the src_info tuple returned by dispatch and passed to combine via handle.

@codecov-commenter

Copy link
Copy Markdown

⚠️ Please install the 'codecov app svg image' to ensure uploads and comments are reliably processed by Codecov.

Codecov Report

✅ All modified and coverable lines are covered by tests.

📢 Thoughts on this report? Let us know!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[RFC]: Mooncake EP Multi-Vendor GPU Adaptation Design [Feature Request]: MACA Support for Mooncake EP/PG

2 participants