From 187db4a49a02936fc6f43f301009f83bcc88d2aa Mon Sep 17 00:00:00 2001 From: Dayuxiaoshui <792179245@qq.com> Date: Wed, 24 Jun 2026 06:17:52 +0000 Subject: [PATCH 1/3] Add MCCL all-to-all fallback for MACA EP --- mooncake-ep/include/mooncake_ep_alltoall.h | 26 ++ mooncake-ep/setup.py | 43 +- mooncake-ep/src/CMakeLists.txt | 6 +- mooncake-ep/src/ep_py.cpp | 7 + mooncake-ep/src/mooncake_ep_alltoall.cu | 406 +++++++++++++++++ mooncake-wheel/mooncake/mooncake_ep_buffer.py | 413 ++++++++++++++++-- 6 files changed, 870 insertions(+), 31 deletions(-) create mode 100644 mooncake-ep/include/mooncake_ep_alltoall.h create mode 100644 mooncake-ep/src/mooncake_ep_alltoall.cu diff --git a/mooncake-ep/include/mooncake_ep_alltoall.h b/mooncake-ep/include/mooncake_ep_alltoall.h new file mode 100644 index 0000000000..2b6b07c2a7 --- /dev/null +++ b/mooncake-ep/include/mooncake_ep_alltoall.h @@ -0,0 +1,26 @@ +#pragma once + +#include + +namespace mooncake { + +std::tuple torch_alltoall_pack_dispatch_fused( + const torch::Tensor& x, const torch::Tensor& topk_idx, int num_experts, + int num_ranks); + +std::tuple +torch_alltoall_compact_dispatch_fused(const torch::Tensor& recv_payload, + int num_local_experts, + int num_max_dispatch_tokens_per_rank); + +torch::Tensor torch_alltoall_pack_combine(const torch::Tensor& expert_buffers, + const torch::Tensor& return_src_pos, + int num_ranks, + int max_messages_per_rank); + +torch::Tensor torch_alltoall_reduce_combine( + const torch::Tensor& recv_payload, const torch::Tensor& send_route, + const torch::Tensor& topk_weights, const std::optional& out); + +} // namespace mooncake diff --git a/mooncake-ep/setup.py b/mooncake-ep/setup.py index 0345c54865..bf1614b6a3 100644 --- a/mooncake-ep/setup.py +++ b/mooncake-ep/setup.py @@ -5,6 +5,10 @@ import torch use_musa = os.getenv("MOONCAKE_EP_USE_MUSA", "").upper() in {"1", "ON", "TRUE", "YES"} +use_maca = ( + os.getenv("MOONCAKE_EP_USE_MACA", "").upper() in {"1", "ON", "TRUE", "YES"} + or (hasattr(torch.version, "maca") and torch.version.maca is not None) +) if use_musa: try: import torchada # noqa: F401 @@ -28,12 +32,33 @@ abi_flag = int(torch._C._GLIBCXX_USE_CXX11_ABI) current_dir = os.path.abspath(os.path.dirname(__file__)) +repo_dir = os.path.abspath(os.path.join(current_dir, os.pardir)) +sysroot_dir = os.path.join(repo_dir, ".deps", "sysroot", "usr") + + +def existing_dirs(*paths): + return [path for path in paths if os.path.isdir(path)] + + +sysroot_include_dirs = existing_dirs( + os.path.join(sysroot_dir, "include"), + os.path.join(sysroot_dir, "include", "jsoncpp"), + os.path.join(sysroot_dir, "include", "libnl3"), +) +sysroot_library_dirs = existing_dirs( + os.path.join(sysroot_dir, "lib", "x86_64-linux-gnu"), + os.path.join(sysroot_dir, "lib"), +) abi_define = f"-D_GLIBCXX_USE_CXX11_ABI={abi_flag}" cxx_args = [abi_define, "-std=c++20", "-O3", "-g0"] cuda_libraries = ["ibverbs", "mlx5"] cuda_library_dirs = [] +include_dirs = [ + os.path.join(current_dir, "include"), + os.path.join(current_dir, "../mooncake-transfer-engine/include"), +] if use_musa: cuda_libraries = [] @@ -48,6 +73,18 @@ "--cuda-gpu-arch=mp_31", "-O3", ] +elif use_maca: + cuda_libraries = [] + cuda_library_dirs = sysroot_library_dirs.copy() + include_dirs += sysroot_include_dirs + maca_defines = ["-DUSE_MACA", "-DMOONCAKE_EP_USE_MACA=1"] + cxx_args += maca_defines + device_args = [ + abi_define, + *maca_defines, + "-std=c++20", + "-O3", + ] else: cxx_args.append("-DUSE_CUDA") device_args = [ @@ -72,14 +109,12 @@ ext_modules=[ CUDAExtension( name=module_name, - include_dirs=[ - os.path.join(current_dir, "include"), - os.path.join(current_dir, "../mooncake-transfer-engine/include"), - ], + include_dirs=include_dirs, sources=[ "src/ep_py.cpp", "src/mooncake_ep_buffer.cpp", "src/mooncake_ep_kernel.cu", + "src/mooncake_ep_alltoall.cu", ], extra_compile_args={"cxx": cxx_args, "nvcc": device_args}, libraries=cuda_libraries, diff --git a/mooncake-ep/src/CMakeLists.txt b/mooncake-ep/src/CMakeLists.txt index 574ab514c0..1394615ea3 100644 --- a/mooncake-ep/src/CMakeLists.txt +++ b/mooncake-ep/src/CMakeLists.txt @@ -1,4 +1,8 @@ -add_library(mooncake_ep ep_py.cpp mooncake_ep_buffer.cpp mooncake_ep_kernel.cu) +add_library(mooncake_ep + ep_py.cpp + mooncake_ep_buffer.cpp + mooncake_ep_kernel.cu + mooncake_ep_alltoall.cu) set_target_properties(mooncake_ep PROPERTIES POSITION_INDEPENDENT_CODE ON) target_link_libraries(mooncake_ep PUBLIC ${TORCH_LIBRARIES} transfer_engine ibverbs mlx5) diff --git a/mooncake-ep/src/ep_py.cpp b/mooncake-ep/src/ep_py.cpp index bd3cf53a99..9937311b77 100644 --- a/mooncake-ep/src/ep_py.cpp +++ b/mooncake-ep/src/ep_py.cpp @@ -1,4 +1,5 @@ #include +#include #include #include #include @@ -13,6 +14,12 @@ namespace mooncake { PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("get_ep_buffer_size_hint", &get_ep_buffer_size_hint); + m.def("torch_alltoall_pack_dispatch_fused", + &torch_alltoall_pack_dispatch_fused); + m.def("torch_alltoall_compact_dispatch_fused", + &torch_alltoall_compact_dispatch_fused); + m.def("torch_alltoall_pack_combine", &torch_alltoall_pack_combine); + m.def("torch_alltoall_reduce_combine", &torch_alltoall_reduce_combine); py::class_(m, "EventHandle") .def(py::init<>()) diff --git a/mooncake-ep/src/mooncake_ep_alltoall.cu b/mooncake-ep/src/mooncake_ep_alltoall.cu new file mode 100644 index 0000000000..fb1493a550 --- /dev/null +++ b/mooncake-ep/src/mooncake_ep_alltoall.cu @@ -0,0 +1,406 @@ +#include + +#include +#include +#include + +#include + +namespace mooncake { +namespace { + +inline void check_cuda(const torch::Tensor& t, const char* name) { + TORCH_CHECK(t.is_cuda(), name, " must be a CUDA tensor"); + TORCH_CHECK(t.is_contiguous(), name, " must be contiguous"); +} + +inline int ceil_div(int a, int b) { return (a + b - 1) / b; } + +__device__ __forceinline__ void store_i32_words(nv_bfloat16* dst, + int32_t value) { + auto* words = reinterpret_cast(dst); + uint32_t raw = static_cast(value); + words[0] = static_cast(raw & 0xffffu); + words[1] = static_cast(raw >> 16); +} + +__device__ __forceinline__ int32_t load_i32_words(const nv_bfloat16* src) { + const auto* words = reinterpret_cast(src); + uint32_t raw = static_cast(words[0]) | + (static_cast(words[1]) << 16); + return static_cast(raw); +} + +__global__ void count_dispatch_kernel(const int64_t* topk_idx, + int32_t* counts_by_expert, int num_tokens, + int num_topk, int num_local_experts) { + int i = blockIdx.x * blockDim.x + threadIdx.x; + int total = num_tokens * num_topk; + if (i >= total) return; + + int64_t expert = topk_idx[i]; + if (expert < 0) return; + int dst_rank = static_cast(expert / num_local_experts); + int local_expert = static_cast(expert % num_local_experts); + atomicAdd(counts_by_expert + dst_rank * num_local_experts + local_expert, + 1); +} + +__global__ void prefix_counts_kernel(const int32_t* counts_by_expert, + int32_t* expert_offsets, int num_ranks, + int num_local_experts) { + int r = blockIdx.x; + if (r >= num_ranks || threadIdx.x != 0) return; + int running = 0; + for (int e = 0; e < num_local_experts; ++e) { + expert_offsets[r * num_local_experts + e] = running; + running += counts_by_expert[r * num_local_experts + e]; + } +} + +__global__ void pack_dispatch_fused_kernel( + const nv_bfloat16* x, const int64_t* topk_idx, + const int32_t* expert_offsets, int32_t* counters, nv_bfloat16* send_payload, + int64_t* send_route, int num_tokens, int hidden, int num_topk, + int num_ranks, int num_local_experts, int max_messages_per_rank) { + int i = blockIdx.x; + int tid = threadIdx.x; + int total = num_tokens * num_topk; + if (i >= total) return; + + int64_t expert = topk_idx[i]; + if (expert < 0) return; + + int token = i / num_topk; + int slot = i - token * num_topk; + int dst_rank = static_cast(expert / num_local_experts); + int local_expert = static_cast(expert % num_local_experts); + int group = dst_rank * num_local_experts + local_expert; + __shared__ int block_local_pos; + if (tid == 0) block_local_pos = atomicAdd(counters + group, 1); + __syncthreads(); + int local_pos = block_local_pos; + int pos = expert_offsets[group] + local_pos; + if (pos >= max_messages_per_rank) return; + + int fused_hidden = hidden + 2; + int fused_slots_per_rank = max_messages_per_rank + num_local_experts; + nv_bfloat16* dst = send_payload + (dst_rank * fused_slots_per_rank + + num_local_experts + pos) * + fused_hidden; + const nv_bfloat16* src = x + token * hidden; + for (int h = tid; h < hidden; h += blockDim.x) dst[h] = src[h]; + if (tid == 0) { + store_i32_words(dst + hidden, slot * num_tokens + token); + int route_idx = i * 4; + send_route[route_idx + 0] = dst_rank; + send_route[route_idx + 1] = pos; + send_route[route_idx + 2] = token; + send_route[route_idx + 3] = slot; + } +} + +__global__ void append_counts_to_payload_kernel( + const int32_t* counts_by_expert, nv_bfloat16* send_payload, int num_ranks, + int num_local_experts, int max_messages_per_rank, int hidden) { + int i = blockIdx.x * blockDim.x + threadIdx.x; + int total = num_ranks * num_local_experts; + if (i >= total) return; + int dst_rank = i / num_local_experts; + int local_expert = i - dst_rank * num_local_experts; + int fused_hidden = hidden + 2; + int fused_slots_per_rank = max_messages_per_rank + num_local_experts; + nv_bfloat16* dst = + send_payload + + (dst_rank * fused_slots_per_rank + local_expert) * fused_hidden; + store_i32_words(dst, counts_by_expert[i]); +} + +__global__ void compact_dispatch_fused_kernel( + const nv_bfloat16* recv_payload, int64_t* layout_range, int32_t* recv_count, + int64_t* return_src_pos, nv_bfloat16* packed_recv_x, + int32_t* packed_recv_src_info, int hidden, int num_ranks, + int num_local_experts, int max_messages_per_rank, int num_recv_slots) { + int idx = blockIdx.x; + int m = idx % max_messages_per_rank; + int pair = idx / max_messages_per_rank; + int src_rank = pair / num_local_experts; + int local_expert = pair - src_rank * num_local_experts; + if (src_rank >= num_ranks) return; + + int fused_hidden = hidden + 2; + int fused_slots_per_rank = max_messages_per_rank + num_local_experts; + const nv_bfloat16* rank_base = + recv_payload + src_rank * fused_slots_per_rank * fused_hidden; + int count = load_i32_words(rank_base + local_expert * fused_hidden); + + int begin = 0; + for (int r = 0; r < src_rank; ++r) { + const nv_bfloat16* prev_rank_base = + recv_payload + r * fused_slots_per_rank * fused_hidden; + begin += load_i32_words(prev_rank_base + local_expert * fused_hidden); + } + + if (m == 0 && threadIdx.x == 0) { + layout_range[local_expert * num_ranks + src_rank] = + (static_cast(begin) << 32) | static_cast(count); + atomicAdd(recv_count + local_expert, count); + } + + if (m >= count) return; + + int src_begin = 0; + for (int e = 0; e < local_expert; ++e) + src_begin += load_i32_words(rank_base + e * fused_hidden); + + int dst_pos = begin + m; + int src_pos = src_begin + m; + if (dst_pos >= num_recv_slots || src_pos >= max_messages_per_rank) return; + + const nv_bfloat16* src = + rank_base + (num_local_experts + src_pos) * fused_hidden; + nv_bfloat16* dst = + packed_recv_x + (local_expert * num_recv_slots + dst_pos) * hidden; + for (int h = threadIdx.x; h < hidden; h += blockDim.x) dst[h] = src[h]; + if (threadIdx.x == 0) { + packed_recv_src_info[local_expert * num_recv_slots + dst_pos] = + load_i32_words(src + hidden); + int route_idx = (local_expert * num_recv_slots + dst_pos) * 2; + return_src_pos[route_idx + 0] = src_rank; + return_src_pos[route_idx + 1] = src_pos; + } +} + +__global__ void pack_combine_kernel(const nv_bfloat16* expert_buffers, + const int64_t* return_src_pos, + nv_bfloat16* send_payload, int hidden, + int max_messages_per_rank, + int num_recv_slots) { + int idx = blockIdx.x; + int tid = threadIdx.x; + int local_expert = idx / num_recv_slots; + int expert_pos = idx - local_expert * num_recv_slots; + int route_idx = idx * 2; + int64_t dst_rank = return_src_pos[route_idx + 0]; + int64_t dst_pos = return_src_pos[route_idx + 1]; + if (dst_rank < 0 || dst_pos < 0) return; + + nv_bfloat16* dst = + send_payload + (dst_rank * max_messages_per_rank + dst_pos) * hidden; + const nv_bfloat16* src = + expert_buffers + (local_expert * num_recv_slots + expert_pos) * hidden; + for (int h = tid; h < hidden; h += blockDim.x) dst[h] = src[h]; +} + +__global__ void reduce_combine_kernel(const nv_bfloat16* recv_payload, + const int64_t* send_route, + const float* topk_weights, + nv_bfloat16* combined, int num_tokens, + int hidden, int num_topk, + int max_messages_per_rank) { + int token = blockIdx.x; + int h = blockIdx.y * blockDim.x + threadIdx.x; + if (token >= num_tokens || h >= hidden) return; + + float acc = 0.0f; + for (int slot = 0; slot < num_topk; ++slot) { + int route = token * num_topk + slot; + int64_t rank = send_route[route * 4 + 0]; + int64_t pos = send_route[route * 4 + 1]; + if (rank < 0 || pos < 0) continue; + const nv_bfloat16* src = + recv_payload + (rank * max_messages_per_rank + pos) * hidden; + float weight = topk_weights[route]; + acc += __bfloat162float(src[h]) * weight; + } + combined[token * hidden + h] = __float2bfloat16(acc); +} + +} // namespace + +std::tuple torch_alltoall_pack_dispatch_fused( + const torch::Tensor& x, const torch::Tensor& topk_idx, int num_experts, + int num_ranks) { + check_cuda(x, "x"); + check_cuda(topk_idx, "topk_idx"); + TORCH_CHECK(x.scalar_type() == torch::kBFloat16, "x must be bfloat16"); + TORCH_CHECK(topk_idx.scalar_type() == torch::kInt64, + "topk_idx must be int64"); + TORCH_CHECK(x.dim() == 2 && topk_idx.dim() == 2, + "x/topk_idx must be 2D tensors"); + TORCH_CHECK(x.size(0) == topk_idx.size(0), + "x and topk_idx must have the same token count"); + TORCH_CHECK(num_ranks > 0, "num_ranks must be positive"); + TORCH_CHECK(num_experts > 0 && num_experts % num_ranks == 0, + "num_experts must be positive and divisible by num_ranks"); + + const c10::cuda::CUDAGuard guard(x.device()); + auto stream = at::cuda::getCurrentCUDAStream(); + int num_tokens = static_cast(x.size(0)); + int hidden = static_cast(x.size(1)); + int num_topk = static_cast(topk_idx.size(1)); + int num_local_experts = num_experts / num_ranks; + int max_messages_per_rank = num_tokens * num_topk; + int fused_hidden = hidden + 2; + int fused_slots_per_rank = max_messages_per_rank + num_local_experts; + auto int32_opts = + torch::TensorOptions().dtype(torch::kInt32).device(x.device()); + auto int64_opts = + torch::TensorOptions().dtype(torch::kInt64).device(x.device()); + + auto counts_by_expert = + torch::zeros({num_ranks, num_local_experts}, int32_opts); + int total = num_tokens * num_topk; + count_dispatch_kernel<<>>( + topk_idx.data_ptr(), counts_by_expert.data_ptr(), + num_tokens, num_topk, num_local_experts); + + auto expert_offsets = + torch::empty({num_ranks, num_local_experts}, int32_opts); + prefix_counts_kernel<<>>( + counts_by_expert.data_ptr(), + expert_offsets.data_ptr(), num_ranks, num_local_experts); + + auto send_payload = torch::zeros( + {num_ranks, fused_slots_per_rank, fused_hidden}, x.options()); + auto send_route = torch::full({total, 4}, -1, int64_opts); + auto counters = torch::zeros({num_ranks, num_local_experts}, int32_opts); + pack_dispatch_fused_kernel<<>>( + reinterpret_cast(x.data_ptr()), + topk_idx.data_ptr(), expert_offsets.data_ptr(), + counters.data_ptr(), + reinterpret_cast(send_payload.data_ptr()), + send_route.data_ptr(), num_tokens, hidden, num_topk, num_ranks, + num_local_experts, max_messages_per_rank); + append_counts_to_payload_kernel<<< + ceil_div(num_ranks * num_local_experts, 256), 256, 0, stream>>>( + counts_by_expert.data_ptr(), + reinterpret_cast(send_payload.data_ptr()), num_ranks, + num_local_experts, max_messages_per_rank, hidden); + return {send_payload, send_route}; +} + +std::tuple +torch_alltoall_compact_dispatch_fused(const torch::Tensor& recv_payload, + int num_local_experts, + int num_max_dispatch_tokens_per_rank) { + check_cuda(recv_payload, "recv_payload"); + TORCH_CHECK(recv_payload.scalar_type() == torch::kBFloat16, + "recv_payload must be bfloat16"); + TORCH_CHECK(recv_payload.dim() == 3, "recv_payload must be a 3D tensor"); + TORCH_CHECK(num_local_experts > 0, "num_local_experts must be positive"); + TORCH_CHECK(num_max_dispatch_tokens_per_rank > 0, + "num_max_dispatch_tokens_per_rank must be positive"); + const c10::cuda::CUDAGuard guard(recv_payload.device()); + auto stream = at::cuda::getCurrentCUDAStream(); + int num_ranks = static_cast(recv_payload.size(0)); + int fused_slots_per_rank = static_cast(recv_payload.size(1)); + int fused_hidden = static_cast(recv_payload.size(2)); + TORCH_CHECK(fused_hidden > 2, "recv_payload hidden dimension is invalid"); + TORCH_CHECK(fused_slots_per_rank > num_local_experts, + "recv_payload slot dimension is invalid"); + int hidden = fused_hidden - 2; + int max_messages_per_rank = fused_slots_per_rank - num_local_experts; + int num_recv_slots = num_ranks * num_max_dispatch_tokens_per_rank; + auto int32_opts = torch::TensorOptions() + .dtype(torch::kInt32) + .device(recv_payload.device()); + auto int64_opts = torch::TensorOptions() + .dtype(torch::kInt64) + .device(recv_payload.device()); + + auto packed_recv_x = torch::zeros( + {num_local_experts, num_recv_slots, hidden}, recv_payload.options()); + auto packed_recv_src_info = + torch::full({num_local_experts, num_recv_slots}, -1, int32_opts); + auto layout_range = + torch::zeros({num_local_experts, num_ranks}, int64_opts); + auto recv_count = torch::zeros({num_local_experts}, int32_opts); + auto return_src_pos = + torch::full({num_local_experts, num_recv_slots, 2}, -1, int64_opts); + compact_dispatch_fused_kernel<<>>( + reinterpret_cast(recv_payload.data_ptr()), + layout_range.data_ptr(), recv_count.data_ptr(), + return_src_pos.data_ptr(), + reinterpret_cast(packed_recv_x.data_ptr()), + packed_recv_src_info.data_ptr(), hidden, num_ranks, + num_local_experts, max_messages_per_rank, num_recv_slots); + return {packed_recv_x, packed_recv_src_info, layout_range, recv_count, + return_src_pos}; +} + +torch::Tensor torch_alltoall_pack_combine(const torch::Tensor& expert_buffers, + const torch::Tensor& return_src_pos, + int num_ranks, + int max_messages_per_rank) { + check_cuda(expert_buffers, "expert_buffers"); + check_cuda(return_src_pos, "return_src_pos"); + TORCH_CHECK(expert_buffers.scalar_type() == torch::kBFloat16, + "expert_buffers must be bfloat16"); + TORCH_CHECK(return_src_pos.scalar_type() == torch::kInt64, + "return_src_pos must be int64"); + TORCH_CHECK(expert_buffers.dim() == 3, + "expert_buffers must be a 3D tensor"); + TORCH_CHECK(return_src_pos.dim() == 3 && return_src_pos.size(2) == 2, + "return_src_pos must have shape [experts, slots, 2]"); + TORCH_CHECK(num_ranks > 0 && max_messages_per_rank > 0, + "num_ranks and max_messages_per_rank must be positive"); + const c10::cuda::CUDAGuard guard(expert_buffers.device()); + auto stream = at::cuda::getCurrentCUDAStream(); + int hidden = static_cast(expert_buffers.size(2)); + int num_recv_slots = static_cast(expert_buffers.size(1)); + auto send_payload = torch::zeros({num_ranks, max_messages_per_rank, hidden}, + expert_buffers.options()); + int num_local_experts = static_cast(expert_buffers.size(0)); + pack_combine_kernel<<>>( + reinterpret_cast(expert_buffers.data_ptr()), + return_src_pos.data_ptr(), + reinterpret_cast(send_payload.data_ptr()), hidden, + max_messages_per_rank, num_recv_slots); + return send_payload; +} + +torch::Tensor torch_alltoall_reduce_combine( + const torch::Tensor& recv_payload, const torch::Tensor& send_route, + const torch::Tensor& topk_weights, + const std::optional& out) { + check_cuda(recv_payload, "recv_payload"); + check_cuda(send_route, "send_route"); + check_cuda(topk_weights, "topk_weights"); + TORCH_CHECK(recv_payload.scalar_type() == torch::kBFloat16, + "recv_payload must be bfloat16"); + TORCH_CHECK(send_route.scalar_type() == torch::kInt64, + "send_route must be int64"); + TORCH_CHECK(topk_weights.scalar_type() == torch::kFloat32, + "topk_weights must be float32"); + TORCH_CHECK(recv_payload.dim() == 3 && send_route.dim() == 2 && + topk_weights.dim() == 2, + "recv_payload, send_route, and topk_weights must be 3D/2D/2D"); + TORCH_CHECK(send_route.size(1) == 4, + "send_route must have shape [tokens * topk, 4]"); + const c10::cuda::CUDAGuard guard(recv_payload.device()); + auto stream = at::cuda::getCurrentCUDAStream(); + int hidden = static_cast(recv_payload.size(2)); + int max_messages_per_rank = static_cast(recv_payload.size(1)); + int num_tokens = static_cast(topk_weights.size(0)); + int num_topk = static_cast(topk_weights.size(1)); + TORCH_CHECK(send_route.size(0) == num_tokens * num_topk, + "send_route length must match topk_weights"); + auto combined = out.has_value() ? out.value() + : torch::empty({num_tokens, hidden}, + recv_payload.options()); + TORCH_CHECK(combined.scalar_type() == torch::kBFloat16, + "out must be bfloat16"); + dim3 grid(num_tokens, ceil_div(hidden, 256)); + reduce_combine_kernel<<>>( + reinterpret_cast(recv_payload.data_ptr()), + send_route.data_ptr(), topk_weights.data_ptr(), + reinterpret_cast(combined.data_ptr()), num_tokens, hidden, + num_topk, max_messages_per_rank); + return combined; +} + +} // namespace mooncake diff --git a/mooncake-wheel/mooncake/mooncake_ep_buffer.py b/mooncake-wheel/mooncake/mooncake_ep_buffer.py index b3deb7e762..00ed2d1d95 100644 --- a/mooncake-wheel/mooncake/mooncake_ep_buffer.py +++ b/mooncake-wheel/mooncake/mooncake_ep_buffer.py @@ -4,6 +4,26 @@ from typing import Any, Callable, List, Tuple, Optional, Union +def _env_enabled(name: str, default: bool = False) -> bool: + value = os.getenv(name) + if value is None: + return default + return value.upper() in {"1", "ON", "TRUE", "YES"} + + +_USE_MACA = ( + _env_enabled("MOONCAKE_EP_USE_MACA") + or bool(getattr(torch.version, "maca", None)) +) +_USE_SPLIT_SEND_RECV = ( + _env_enabled("MOONCAKE_EP_USE_MUSA") + or _USE_MACA +) +_MACA_PHASE_FENCE = os.getenv("MOONCAKE_EP_MACA_PHASE_FENCE", "p2p").lower() +_USE_TORCH_ALLTOALL = _env_enabled("MOONCAKE_EP_USE_TORCH_ALLTOALL") +_PROFILE_TORCH_ALLTOALL = _env_enabled("MOONCAKE_EP_PROFILE_TORCH_ALLTOALL") + + class EventOverlap: """ A wrapper class to manage CUDA events, also for better overlapping convenience. @@ -64,28 +84,107 @@ def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: class Buffer: def __init__(self, group: dist.ProcessGroup, num_ep_buffer_bytes: int = 0): - from mooncake import ep - # Initialize the CPP runtime self.rank = group.rank() self.group_size = group.size() self.group = group self.num_ep_buffer_bytes = num_ep_buffer_bytes self.backend = self.group - # NIC auto-detection happens inside ep.Buffer via Topology::discover(). - self.runtime = ep.Buffer( - self.rank, self.group_size, num_ep_buffer_bytes - ) + self._use_torch_alltoall = _USE_TORCH_ALLTOALL + if self._use_torch_alltoall: + self.runtime = None + else: + from mooncake import ep + + # NIC auto-detection happens inside ep.Buffer via Topology::discover(). + self.runtime = ep.Buffer( + self.rank, self.group_size, num_ep_buffer_bytes + ) # Fallback flag and buffers. # Note: `sync_nvlink_ipc_handles()` can mutate C++ `ibgda_disabled_` (True->False when # P2P+IPC succeeds for all ranks). We re-evaluate after IPC sync below. - self._use_fallback = bool(self.runtime.ibgda_disabled()) + self._use_fallback = self._use_torch_alltoall or bool( + self.runtime is not None and self.runtime.ibgda_disabled() + ) self._fallback_next_combine_buffer: Optional[torch.Tensor] = None + self._torch_alltoall_profile = {} + self._torch_alltoall_state = {} + self._maca_phase_token: Optional[torch.Tensor] = None + self._maca_phase_recv_tokens: Optional[List[torch.Tensor]] = None self.connect() + + def _maca_phase_fence(self) -> None: + if not _USE_MACA or _MACA_PHASE_FENCE in {"", "0", "off", "none"}: + return + + # Compatibility fence between SEND and RECV. The EP payload still + # uses the P2P fast path; this only keeps rank phases aligned on MACA. + if _MACA_PHASE_FENCE == "barrier": + torch.cuda.synchronize() + dist.barrier(self.group) + return + if _MACA_PHASE_FENCE == "p2p": + if self._maca_phase_token is None: + self._maca_phase_token = torch.empty( + 1, dtype=torch.int32, device="cuda" + ) + if self._maca_phase_recv_tokens is None: + self._maca_phase_recv_tokens = [ + torch.empty(1, dtype=torch.int32, device="cuda") + for _ in range(self.group_size) + ] + self._maca_phase_token.fill_(1) + ops = [] + for peer in range(self.group_size): + if peer == self.rank: + continue + ops.append( + dist.P2POp( + dist.isend, self._maca_phase_token, peer, self.group + ) + ) + ops.append( + dist.P2POp( + dist.irecv, + self._maca_phase_recv_tokens[peer], + peer, + self.group, + ) + ) + if not ops: + return + for work in dist.batch_isend_irecv(ops): + work.wait() + return + if _MACA_PHASE_FENCE != "allreduce": + raise ValueError( + "MOONCAKE_EP_MACA_PHASE_FENCE must be one of: " + "p2p, allreduce, barrier, none" + ) + if self._maca_phase_token is None: + self._maca_phase_token = torch.empty( + 1, dtype=torch.int32, device="cuda" + ) + self._maca_phase_token.fill_(1) + dist.all_reduce( + self._maca_phase_token, op=dist.ReduceOp.SUM, group=self.group + ) + + def _wrap_maca_recv_hook(self, hook: Optional[Callable]) -> Callable: + def wrapped_hook() -> None: + self._maca_phase_fence() + if hook is not None: + hook() + + return wrapped_hook def connect(self, is_update: bool = False): from mooncake import ep + if self._use_torch_alltoall: + self._use_fallback = True + return + if not self._use_fallback: (raddr, rkey) = self.runtime.get_mr_info() @@ -200,8 +299,26 @@ def connect(self, is_update: bool = False): def update_ep_member(self): + if self._use_torch_alltoall: + return self.connect(True) + def _active_ranks_tensor( + self, device: torch.device, dtype: torch.dtype = torch.int32 + ) -> torch.Tensor: + if self._use_torch_alltoall: + return torch.ones((self.group_size,), dtype=dtype, device=device) + + try: + from mooncake.ep import get_active_ranks + + return get_active_ranks(self.backend).to(device=device, dtype=dtype) + except Exception: + return torch.ones((self.group_size,), dtype=dtype, device=device) + + def _active_ranks_list(self, device: torch.device) -> List[int]: + return self._active_ranks_tensor(device=device, dtype=torch.int32).tolist() + @staticmethod def get_ep_buffer_size_hint( num_max_dispatch_tokens_per_rank: int, @@ -209,6 +326,15 @@ def get_ep_buffer_size_hint( num_ranks: int, num_experts: int, ) -> int: + if _USE_TORCH_ALLTOALL: + return ( + 4 + * num_experts + * num_max_dispatch_tokens_per_rank + * (32 + hidden * 2) + + 4 * num_experts * 4 + ) + from mooncake.ep import get_ep_buffer_size_hint return get_ep_buffer_size_hint( @@ -224,7 +350,7 @@ def dispatch( num_max_dispatch_tokens_per_rank: int, num_experts: int, timeout_us: int, - use_fp8: bool = True, + use_fp8: Optional[bool] = None, async_finish: bool = False, return_recv_hook: bool = False, ) -> Tuple[ @@ -234,23 +360,30 @@ def dispatch( EventOverlap, Callable, ]: - # MUSA does not support cooperative grid sync, so the C++ runtime + if use_fp8 is None: + use_fp8 = not _USE_MACA + elif _USE_MACA and use_fp8: + raise NotImplementedError("FP8 dispatch is not supported on MACA") + + # MUSA/MACA do not support cooperative grid sync, so the C++ runtime # splits no-hook calls into SEND -> phase-ack -> RECV instead of using # a single cooperative kernel. async_finish still returns a stream # event, but it is not the CUDA single-kernel cooperative path. - if os.getenv("MOONCAKE_EP_USE_MUSA") and async_finish: + if _USE_SPLIT_SEND_RECV and async_finish: import warnings warnings.warn( - "MUSA async_finish uses split SEND/RECV kernels plus a stream " + "async_finish uses split SEND/RECV kernels plus a stream " "event, not CUDA cooperative single-kernel async semantics.", RuntimeWarning, stacklevel=2, ) - if self._use_fallback: - from mooncake.ep import get_active_ranks + runtime_return_recv_hook = return_recv_hook or ( + _USE_MACA and not self._use_fallback + ) + if self._use_fallback: ( packed_recv_x, packed_recv_x_scales, @@ -267,7 +400,7 @@ def dispatch( use_fp8, return_recv_hook, ) - backend_active_ranks = get_active_ranks(self.backend).to( + backend_active_ranks = self._active_ranks_tensor( device=active_ranks.device, dtype=active_ranks.dtype ) if active_ranks.numel() == backend_active_ranks.numel(): @@ -290,8 +423,13 @@ def dispatch( timeout_us, use_fp8, async_finish, - return_recv_hook, + runtime_return_recv_hook, ) + if _USE_MACA: + hook = self._wrap_maca_recv_hook(hook) + if not return_recv_hook: + hook() + hook = None handle = ( packed_recv_src_info, packed_recv_layout_range, @@ -330,12 +468,12 @@ def combine( return_recv_hook: bool = False, out: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, EventOverlap, Callable]: - # Same MUSA split-kernel behavior as dispatch(). - if os.getenv("MOONCAKE_EP_USE_MUSA") and async_finish: + # Same split-kernel behavior as dispatch(). + if _USE_SPLIT_SEND_RECV and async_finish: import warnings warnings.warn( - "MUSA async_finish uses split SEND/RECV kernels plus a stream " + "async_finish uses split SEND/RECV kernels plus a stream " "event, not CUDA cooperative single-kernel async semantics.", RuntimeWarning, stacklevel=2, @@ -348,9 +486,11 @@ def combine( hidden, num_experts, ) = handle - if self._use_fallback: - from mooncake.ep import get_active_ranks + runtime_return_recv_hook = return_recv_hook or ( + _USE_MACA and not self._use_fallback + ) + if self._use_fallback: combined_x, event, hook = self._fallback_combine( x, topk_idx, @@ -363,7 +503,7 @@ def combine( return_recv_hook, out, ) - backend_active_ranks = get_active_ranks(self.backend).to( + backend_active_ranks = self._active_ranks_tensor( device=active_ranks.device, dtype=active_ranks.dtype ) if active_ranks.numel() == backend_active_ranks.numel(): @@ -381,9 +521,14 @@ def combine( timeout_us, zero_copy, async_finish, - return_recv_hook, + runtime_return_recv_hook, out, ) + if _USE_MACA: + hook = self._wrap_maca_recv_hook(hook) + if not return_recv_hook: + hook() + hook = None tensors_to_record = ( x, topk_idx, @@ -437,6 +582,23 @@ class _DummyEvent: def current_stream_wait(self): torch.cuda.synchronize() + class _CudaTimer: + def __init__(self, enabled: bool) -> None: + self.enabled = enabled + self.samples = {} + + def measure(self, name: str, fn: Callable[[], Any]) -> Any: + if not self.enabled: + return fn() + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + start.record() + result = fn() + end.record() + torch.cuda.synchronize() + self.samples[name] = start.elapsed_time(end) / 1000.0 + return result + @staticmethod def _fp8_cast(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: assert x.dim() == 2 and x.size(1) % 128 == 0 @@ -449,6 +611,185 @@ def _fp8_cast(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: x_scales = (x_amax / 448.0).view(m, -1) return x_fp8, x_scales + def _torch_alltoall_routed_dispatch( + self, + x: torch.Tensor, + topk_idx: torch.Tensor, + num_max_dispatch_tokens_per_rank: int, + num_experts: int, + return_recv_hook: bool, + ): + with torch.profiler.record_function("dispatch.torch_alltoall_routed"): + timer = Buffer._CudaTimer(_PROFILE_TORCH_ALLTOALL) + num_tokens = x.size(0) + num_topk = topk_idx.size(1) + num_ranks = self.group_size + num_local_experts = num_experts // num_ranks + max_messages_per_rank = num_tokens * num_topk + + if x.dtype != torch.bfloat16: + raise NotImplementedError( + "torch all-to-all fallback currently supports bfloat16 only" + ) + + from mooncake import ep + + send_payload, send_route = timer.measure( + "dispatch_pack", + lambda: ep.torch_alltoall_pack_dispatch_fused( + x.contiguous(), + topk_idx.contiguous(), + num_experts, + num_ranks, + ), + ) + recv_payload = torch.empty_like(send_payload) + timer.measure( + "dispatch_a2a_payload", + lambda: dist.all_to_all_single( + recv_payload, send_payload, group=self.group + ), + ) + if _PROFILE_TORCH_ALLTOALL: + timer.samples["dispatch_a2a_meta"] = 0.0 + ( + packed_recv_x, + packed_recv_src_info, + packed_recv_layout_range, + packed_recv_count, + return_src_pos, + ) = timer.measure( + "dispatch_compact", + lambda: ep.torch_alltoall_compact_dispatch_fused( + recv_payload, + num_local_experts, + num_max_dispatch_tokens_per_rank, + ), + ) + self._torch_alltoall_state = { + "send_route": send_route, + "return_src_pos": return_src_pos, + "num_tokens": num_tokens, + "num_topk": num_topk, + "max_messages_per_rank": max_messages_per_rank, + } + self._fallback_next_combine_buffer = torch.empty_like(packed_recv_x) + self._torch_alltoall_profile = timer.samples + hook = (lambda: None) if return_recv_hook else (lambda: None) + event = Buffer._DummyEvent() + return ( + packed_recv_x, + None, + packed_recv_count, + packed_recv_src_info, + packed_recv_layout_range, + event, + hook, + ) + + def _torch_alltoall_routed_combine( + self, + x: torch.Tensor, + topk_idx: torch.Tensor, + topk_weights: torch.Tensor, + layout_range: torch.Tensor, + zero_copy: bool, + return_recv_hook: bool, + out: Optional[torch.Tensor], + ): + with torch.profiler.record_function("combine.torch_alltoall_routed"): + timer = Buffer._CudaTimer(_PROFILE_TORCH_ALLTOALL) + expert_buffers = self._fallback_next_combine_buffer if zero_copy else x + if expert_buffers is None: + raise RuntimeError( + "zero_copy combine called before dispatch buffer allocation" + ) + if expert_buffers.dtype != torch.bfloat16: + expert_buffers = expert_buffers.to(torch.bfloat16) + + num_tokens, num_topk = topk_idx.shape + num_ranks = self.group_size + max_messages_per_rank = num_tokens * num_topk + state = self._torch_alltoall_state + if not state: + raise RuntimeError("combine called without torch all-to-all dispatch state") + + from mooncake import ep + + send_payload = timer.measure( + "combine_pack", + lambda: ep.torch_alltoall_pack_combine( + expert_buffers.contiguous(), + state["return_src_pos"], + num_ranks, + max_messages_per_rank, + ), + ) + recv_payload = torch.empty_like(send_payload) + timer.measure( + "combine_a2a_payload", + lambda: dist.all_to_all_single( + recv_payload, send_payload, group=self.group + ), + ) + combined = timer.measure( + "combine_reduce", + lambda: ep.torch_alltoall_reduce_combine( + recv_payload, + state["send_route"], + topk_weights.contiguous(), + out, + ), + ) + self._torch_alltoall_profile = timer.samples + hook = (lambda: None) if return_recv_hook else (lambda: None) + event = Buffer._DummyEvent() + return combined, event, hook + + def _torch_alltoall_dispatch( + self, + x: torch.Tensor, + topk_idx: torch.Tensor, + num_max_dispatch_tokens_per_rank: int, + num_experts: int, + use_fp8: bool, + return_recv_hook: bool, + ): + if use_fp8: + raise NotImplementedError( + "FP8 dispatch is not supported by torch all-to-all fallback" + ) + return self._torch_alltoall_routed_dispatch( + x, + topk_idx, + num_max_dispatch_tokens_per_rank, + num_experts, + return_recv_hook, + ) + + def _torch_alltoall_combine( + self, + x: torch.Tensor, + topk_idx: torch.Tensor, + topk_weights: torch.Tensor, + src_info: torch.Tensor, + layout_range: torch.Tensor, + num_max_dispatch_tokens_per_rank: int, + num_experts: int, + zero_copy: bool, + return_recv_hook: bool, + out: Optional[torch.Tensor], + ): + return self._torch_alltoall_routed_combine( + x, + topk_idx, + topk_weights, + layout_range, + zero_copy, + return_recv_hook, + out, + ) + def _fallback_dispatch( self, x: torch.Tensor, @@ -458,7 +799,15 @@ def _fallback_dispatch( use_fp8: bool, return_recv_hook: bool, ): - from mooncake.ep import get_active_ranks + if self._use_torch_alltoall: + return self._torch_alltoall_dispatch( + x, + topk_idx, + num_max_dispatch_tokens_per_rank, + num_experts, + use_fp8, + return_recv_hook, + ) with torch.profiler.record_function("dispatch"): num_tokens, hidden = x.shape @@ -476,7 +825,7 @@ def _fallback_dispatch( ] dist.all_gather(num_tokens_list, num_tokens_tensor, group=self.group) num_tokens_per_rank = [t.item() for t in num_tokens_list] - backend_active_ranks = get_active_ranks(self.backend).tolist() + backend_active_ranks = self._active_ranks_list(x.device) for i in range(num_ranks): if backend_active_ranks[i] == 0: num_tokens_per_rank[i] = 0 @@ -682,7 +1031,19 @@ def _fallback_combine( return_recv_hook: bool, out: Optional[torch.Tensor], ): - from mooncake.ep import get_active_ranks + if self._use_torch_alltoall: + return self._torch_alltoall_combine( + x, + topk_idx, + topk_weights, + src_info, + layout_range, + num_max_dispatch_tokens_per_rank, + num_experts, + zero_copy, + return_recv_hook, + out, + ) with torch.profiler.record_function("combine"): num_tokens = topk_idx.size(0) @@ -702,7 +1063,7 @@ def _fallback_combine( ] dist.all_gather(num_tokens_list, num_tokens_tensor, group=self.group) num_tokens_per_rank = [t.item() for t in num_tokens_list] - backend_active_ranks = get_active_ranks(self.backend).tolist() + backend_active_ranks = self._active_ranks_list(topk_idx.device) for i in range(num_ranks): if backend_active_ranks[i] == 0: num_tokens_per_rank[i] = 0 From 10103df17d5a4cb3f465404242acf3807da01120 Mon Sep 17 00:00:00 2001 From: Dayuxiaoshui <792179245@qq.com> Date: Fri, 26 Jun 2026 10:00:29 +0000 Subject: [PATCH 2/3] Stabilize MACA EP P2P path --- CMakeLists.txt | 4 + mooncake-ep/BuildEpExt.cmake | 11 + mooncake-ep/include/mooncake_ep_alltoall.h | 26 -- mooncake-ep/include/mooncake_ep_configs.cuh | 10 +- mooncake-ep/include/mooncake_ep_device.h | 48 ++- mooncake-ep/include/mooncake_ep_event.h | 2 + mooncake-ep/include/mooncake_ep_exception.cuh | 2 +- mooncake-ep/include/mooncake_ep_utils.cuh | 4 +- mooncake-ep/setup.py | 1 - mooncake-ep/src/CMakeLists.txt | 3 +- mooncake-ep/src/ep_py.cpp | 10 +- mooncake-ep/src/mooncake_ep_alltoall.cu | 406 ------------------ mooncake-ep/src/mooncake_ep_buffer.cpp | 37 +- mooncake-ep/src/mooncake_ep_kernel.cu | 42 +- mooncake-ep/tests/test_ep_grid.py | 10 +- mooncake-pg/BuildPgExt.cmake | 11 + mooncake-pg/setup.py | 40 +- .../include/CMakeLists.txt | 1 + .../include/transfer_engine.h | 6 +- .../include/transfer_engine_impl.h | 9 +- .../include/transport/device/comm_device.cuh | 4 + .../include/transport/device/device_ops.cuh | 2 + .../include/transport/device/ibgda_device.cuh | 28 ++ .../transport/device/maca/maca_ops.cuh | 96 +++++ .../src/transfer_engine.cpp | 6 +- .../src/transfer_engine_impl.cpp | 3 +- .../src/transport/CMakeLists.txt | 12 +- .../src/transport/device/CMakeLists.txt | 7 +- .../ibgda_device_transport_maca_stub.cpp | 34 ++ .../transport/device/p2p_device_transport.cpp | 249 ++++++++++- mooncake-wheel/mooncake/mooncake_ep_buffer.py | 394 +++++------------ mooncake-wheel/tests/test_mooncake_ep.py | 8 +- 32 files changed, 741 insertions(+), 785 deletions(-) delete mode 100644 mooncake-ep/include/mooncake_ep_alltoall.h delete mode 100644 mooncake-ep/src/mooncake_ep_alltoall.cu create mode 100644 mooncake-transfer-engine/include/transport/device/maca/maca_ops.cuh create mode 100644 mooncake-transfer-engine/src/transport/device/ibgda_device_transport_maca_stub.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index 2872bd1ad4..122c50ff52 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -145,7 +145,9 @@ if (WITH_EP) "-DTORCH_CUDA_ARCH_LIST=${_torch_cuda_arch_list_pipe}" "-DSTAGING_DIR=${EP_PG_STAGING_DIR}" "-DENGINE_SO_PATH=$" + "-DPython3_EXECUTABLE=${Python3_EXECUTABLE}" "-DEP_USE_MUSA=$,1,0>" + "-DEP_USE_MACA=$,1,0>" -P "${CMAKE_CURRENT_SOURCE_DIR}/mooncake-ep/BuildEpExt.cmake" COMMENT "Building Mooncake EP Python extension(s)" DEPENDS engine @@ -162,7 +164,9 @@ if (WITH_EP) "-DTORCH_CUDA_ARCH_LIST=${_torch_cuda_arch_list_pipe}" "-DSTAGING_DIR=${EP_PG_STAGING_DIR}" "-DENGINE_SO_PATH=$" + "-DPython3_EXECUTABLE=${Python3_EXECUTABLE}" "-DEP_USE_MUSA=$,1,0>" + "-DEP_USE_MACA=$,1,0>" -P "${CMAKE_CURRENT_SOURCE_DIR}/mooncake-pg/BuildPgExt.cmake" COMMENT "Building Mooncake PG Python extension(s)" DEPENDS engine mooncake_ep_ext diff --git a/mooncake-ep/BuildEpExt.cmake b/mooncake-ep/BuildEpExt.cmake index c94529cbb7..4a5a661ed4 100644 --- a/mooncake-ep/BuildEpExt.cmake +++ b/mooncake-ep/BuildEpExt.cmake @@ -11,6 +11,7 @@ # STAGING_DIR - destination directory for the built .so files # ENGINE_SO_PATH - absolute path to the built engine.cpython-XYZ.so # EP_USE_MUSA - set to "1" when building for MUSA (MTLink path) +# EP_USE_MACA - set to "1" when building for MACA (MTLink path) cmake_minimum_required(VERSION 3.16) @@ -40,6 +41,16 @@ if(EP_USE_MUSA) else() unset(ENV{MOONCAKE_EP_USE_MUSA}) endif() +if(EP_USE_MACA) + set(ENV{MOONCAKE_EP_USE_MACA} "1") + if(DEFINED ENV{MACA_PATH}) + set(ENV{MACA_HOME} "$ENV{MACA_PATH}") + elseif(DEFINED ENV{MACA_HOME}) + set(ENV{MACA_PATH} "$ENV{MACA_HOME}") + endif() +else() + unset(ENV{MOONCAKE_EP_USE_MACA}) +endif() # --------------------------------------------------------------------------- # 2. Ensure engine.so exists in mooncake-wheel/mooncake/ for setup.py linking. diff --git a/mooncake-ep/include/mooncake_ep_alltoall.h b/mooncake-ep/include/mooncake_ep_alltoall.h deleted file mode 100644 index 2b6b07c2a7..0000000000 --- a/mooncake-ep/include/mooncake_ep_alltoall.h +++ /dev/null @@ -1,26 +0,0 @@ -#pragma once - -#include - -namespace mooncake { - -std::tuple torch_alltoall_pack_dispatch_fused( - const torch::Tensor& x, const torch::Tensor& topk_idx, int num_experts, - int num_ranks); - -std::tuple -torch_alltoall_compact_dispatch_fused(const torch::Tensor& recv_payload, - int num_local_experts, - int num_max_dispatch_tokens_per_rank); - -torch::Tensor torch_alltoall_pack_combine(const torch::Tensor& expert_buffers, - const torch::Tensor& return_src_pos, - int num_ranks, - int max_messages_per_rank); - -torch::Tensor torch_alltoall_reduce_combine( - const torch::Tensor& recv_payload, const torch::Tensor& send_route, - const torch::Tensor& topk_weights, const std::optional& out); - -} // namespace mooncake diff --git a/mooncake-ep/include/mooncake_ep_configs.cuh b/mooncake-ep/include/mooncake_ep_configs.cuh index 103f6b8f09..1e7f0c2149 100644 --- a/mooncake-ep/include/mooncake_ep_configs.cuh +++ b/mooncake-ep/include/mooncake_ep_configs.cuh @@ -40,16 +40,22 @@ #endif #include +#ifndef MOONCAKE_EP_USE_MACA #include -#include #include +#endif +#include + +#if defined(MOONCAKE_EP_USE_MUSA) || defined(MOONCAKE_EP_USE_MACA) +#define MOONCAKE_EP_SPLIT_SEND_RECV 1 +#endif // torchada maps nv_bfloat16 → __mt_bfloat16 which is an incomplete type on // MUSA, so sizeof(__mt_bfloat16) fails. mt_bfloat16 (the complete typedef in // musa_bf16.hpp) requires the MUSA device compiler (mcc) and cannot be // included from host .cpp files. Use EP_BF16_SIZE: sizeof(nv_bfloat16) on // CUDA, hardcoded 2 on MUSA (both are 2 bytes). -#ifdef MOONCAKE_EP_USE_MUSA +#if defined(MOONCAKE_EP_USE_MUSA) || defined(MOONCAKE_EP_USE_MACA) #define EP_BF16_SIZE 2 #else #define EP_BF16_SIZE sizeof(nv_bfloat16) diff --git a/mooncake-ep/include/mooncake_ep_device.h b/mooncake-ep/include/mooncake_ep_device.h index 244751bec0..9322aea49b 100644 --- a/mooncake-ep/include/mooncake_ep_device.h +++ b/mooncake-ep/include/mooncake_ep_device.h @@ -49,7 +49,49 @@ __forceinline__ __device__ int get_lane_id() { return threadIdx.x % 32; } } \ } -#else // !MOONCAKE_EP_USE_MUSA +#elif defined(MOONCAKE_EP_USE_MACA) + +// -- FP8 types --------------------------------------------------------------- +// MetaX C500 does not support the FP8 EP path. The dispatch template is +// still compiled because the host selects the kernel through a runtime bool, so +// provide storage stubs and reject actual FP8 use in the host/Python wrappers. +#include +using ep_fp8_storage_t = uint8_t; +using ep_fp8x2_storage_t = uint16_t; +#if defined(__CUDACC__) || defined(__MCC__) +__device__ __forceinline__ ep_fp8x2_storage_t ep_cvt_float2_to_fp8x2(float2) { + return 0; +} +#endif + +// -- Device intrinsics ------------------------------------------------------- +#ifndef __activemask +#define __activemask() (0xffffffff) +#endif + +#if defined(__CUDACC__) || defined(__MCC__) +__forceinline__ __device__ int get_lane_id() { return threadIdx.x % 32; } +#endif + +// -- Kernel launch (MACA: no cooperative launch) ----------------------------- +#define EP_LAUNCH_BOUNDS(max_threads, min_blocks) + +#define SETUP_LAUNCH_CONFIG(num_sms, num_threads, stream) \ + dim3 _grid(num_sms); \ + dim3 _block(num_threads); \ + cudaStream_t _stream = stream + +#define LAUNCH_KERNEL(config, kernel, ...) \ + kernel<<<_grid, _block, 0, _stream>>>(__VA_ARGS__); \ + { \ + auto _err = cudaGetLastError(); \ + if (_err != cudaSuccess) { \ + fprintf(stderr, "[EP] kernel launch failed: %s\n", \ + cudaGetErrorString(_err)); \ + } \ + } + +#else // !MOONCAKE_EP_USE_MUSA && !MOONCAKE_EP_USE_MACA // -- FP8 types (CUDA native names) ------------------------------------------- #include @@ -86,7 +128,9 @@ __forceinline__ __device__ int get_lane_id() { #define LAUNCH_KERNEL(config, kernel, ...) \ CUDA_CHECK(cudaLaunchKernelEx(config, kernel, ##__VA_ARGS__)) -#endif // MOONCAKE_EP_USE_MUSA +#endif // MOONCAKE_EP_USE_MUSA / MOONCAKE_EP_USE_MACA // Both platforms need IB verbs +#ifndef MOONCAKE_EP_USE_MACA #include +#endif diff --git a/mooncake-ep/include/mooncake_ep_event.h b/mooncake-ep/include/mooncake_ep_event.h index 6dc6eb63fc..4809704983 100644 --- a/mooncake-ep/include/mooncake_ep_event.h +++ b/mooncake-ep/include/mooncake_ep_event.h @@ -25,6 +25,8 @@ struct EventHandle { void current_stream_wait() const { at::cuda::getCurrentCUDAStream().unwrap().wait(*event); } + + void synchronize() const { event->synchronize(); } }; inline torch::Event create_event(const at::cuda::CUDAStream& s) { diff --git a/mooncake-ep/include/mooncake_ep_exception.cuh b/mooncake-ep/include/mooncake_ep_exception.cuh index 060744594a..097421f3a2 100644 --- a/mooncake-ep/include/mooncake_ep_exception.cuh +++ b/mooncake-ep/include/mooncake_ep_exception.cuh @@ -42,7 +42,7 @@ class EPException : public std::exception { #endif #ifndef EP_DEVICE_ASSERT -#ifdef MOONCAKE_EP_USE_MUSA +#if defined(MOONCAKE_EP_USE_MUSA) || defined(MOONCAKE_EP_USE_MACA) // MUSA SDK 4.3.x can turn kernels that merely contain a device-side __trap() // branch into illegal memory accesses, even when the assertion condition is // true. Keep these invariants as host/static checks on MUSA builds. diff --git a/mooncake-ep/include/mooncake_ep_utils.cuh b/mooncake-ep/include/mooncake_ep_utils.cuh index 90492f67ed..7b6b194c2b 100644 --- a/mooncake-ep/include/mooncake_ep_utils.cuh +++ b/mooncake-ep/include/mooncake_ep_utils.cuh @@ -48,7 +48,7 @@ struct VecInt<16> { }; // ---- TMA / mbarrier helpers (CUDA only) ---- -#ifndef MOONCAKE_EP_USE_MUSA +#if !defined(MOONCAKE_EP_USE_MUSA) && !defined(MOONCAKE_EP_USE_MACA) __device__ __forceinline__ void fence_view_async_shared() { asm volatile("fence.proxy.async.shared::cta; \n" ::); @@ -136,7 +136,7 @@ __device__ __forceinline__ void tma_store_wait() { asm volatile("cp.async.bulk.wait_group.read %0;" ::"n"(N) : "memory"); } -#endif // MOONCAKE_EP_USE_MUSA +#endif // !MOONCAKE_EP_USE_MUSA && !MOONCAKE_EP_USE_MACA template __host__ __device__ dtype_t cell_div(dtype_t a, dtype_t b) { diff --git a/mooncake-ep/setup.py b/mooncake-ep/setup.py index bf1614b6a3..955af759bd 100644 --- a/mooncake-ep/setup.py +++ b/mooncake-ep/setup.py @@ -114,7 +114,6 @@ def existing_dirs(*paths): "src/ep_py.cpp", "src/mooncake_ep_buffer.cpp", "src/mooncake_ep_kernel.cu", - "src/mooncake_ep_alltoall.cu", ], extra_compile_args={"cxx": cxx_args, "nvcc": device_args}, libraries=cuda_libraries, diff --git a/mooncake-ep/src/CMakeLists.txt b/mooncake-ep/src/CMakeLists.txt index 1394615ea3..a88f10d97b 100644 --- a/mooncake-ep/src/CMakeLists.txt +++ b/mooncake-ep/src/CMakeLists.txt @@ -1,8 +1,7 @@ add_library(mooncake_ep ep_py.cpp mooncake_ep_buffer.cpp - mooncake_ep_kernel.cu - mooncake_ep_alltoall.cu) + mooncake_ep_kernel.cu) set_target_properties(mooncake_ep PROPERTIES POSITION_INDEPENDENT_CODE ON) target_link_libraries(mooncake_ep PUBLIC ${TORCH_LIBRARIES} transfer_engine ibverbs mlx5) diff --git a/mooncake-ep/src/ep_py.cpp b/mooncake-ep/src/ep_py.cpp index 9937311b77..361df2f96f 100644 --- a/mooncake-ep/src/ep_py.cpp +++ b/mooncake-ep/src/ep_py.cpp @@ -1,5 +1,4 @@ #include -#include #include #include #include @@ -14,16 +13,11 @@ namespace mooncake { PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("get_ep_buffer_size_hint", &get_ep_buffer_size_hint); - m.def("torch_alltoall_pack_dispatch_fused", - &torch_alltoall_pack_dispatch_fused); - m.def("torch_alltoall_compact_dispatch_fused", - &torch_alltoall_compact_dispatch_fused); - m.def("torch_alltoall_pack_combine", &torch_alltoall_pack_combine); - m.def("torch_alltoall_reduce_combine", &torch_alltoall_reduce_combine); py::class_(m, "EventHandle") .def(py::init<>()) - .def("current_stream_wait", &EventHandle::current_stream_wait); + .def("current_stream_wait", &EventHandle::current_stream_wait) + .def("synchronize", &EventHandle::synchronize); m.attr("MAX_QP_COUNT") = pybind11::int_(MAX_QP_COUNT); diff --git a/mooncake-ep/src/mooncake_ep_alltoall.cu b/mooncake-ep/src/mooncake_ep_alltoall.cu deleted file mode 100644 index fb1493a550..0000000000 --- a/mooncake-ep/src/mooncake_ep_alltoall.cu +++ /dev/null @@ -1,406 +0,0 @@ -#include - -#include -#include -#include - -#include - -namespace mooncake { -namespace { - -inline void check_cuda(const torch::Tensor& t, const char* name) { - TORCH_CHECK(t.is_cuda(), name, " must be a CUDA tensor"); - TORCH_CHECK(t.is_contiguous(), name, " must be contiguous"); -} - -inline int ceil_div(int a, int b) { return (a + b - 1) / b; } - -__device__ __forceinline__ void store_i32_words(nv_bfloat16* dst, - int32_t value) { - auto* words = reinterpret_cast(dst); - uint32_t raw = static_cast(value); - words[0] = static_cast(raw & 0xffffu); - words[1] = static_cast(raw >> 16); -} - -__device__ __forceinline__ int32_t load_i32_words(const nv_bfloat16* src) { - const auto* words = reinterpret_cast(src); - uint32_t raw = static_cast(words[0]) | - (static_cast(words[1]) << 16); - return static_cast(raw); -} - -__global__ void count_dispatch_kernel(const int64_t* topk_idx, - int32_t* counts_by_expert, int num_tokens, - int num_topk, int num_local_experts) { - int i = blockIdx.x * blockDim.x + threadIdx.x; - int total = num_tokens * num_topk; - if (i >= total) return; - - int64_t expert = topk_idx[i]; - if (expert < 0) return; - int dst_rank = static_cast(expert / num_local_experts); - int local_expert = static_cast(expert % num_local_experts); - atomicAdd(counts_by_expert + dst_rank * num_local_experts + local_expert, - 1); -} - -__global__ void prefix_counts_kernel(const int32_t* counts_by_expert, - int32_t* expert_offsets, int num_ranks, - int num_local_experts) { - int r = blockIdx.x; - if (r >= num_ranks || threadIdx.x != 0) return; - int running = 0; - for (int e = 0; e < num_local_experts; ++e) { - expert_offsets[r * num_local_experts + e] = running; - running += counts_by_expert[r * num_local_experts + e]; - } -} - -__global__ void pack_dispatch_fused_kernel( - const nv_bfloat16* x, const int64_t* topk_idx, - const int32_t* expert_offsets, int32_t* counters, nv_bfloat16* send_payload, - int64_t* send_route, int num_tokens, int hidden, int num_topk, - int num_ranks, int num_local_experts, int max_messages_per_rank) { - int i = blockIdx.x; - int tid = threadIdx.x; - int total = num_tokens * num_topk; - if (i >= total) return; - - int64_t expert = topk_idx[i]; - if (expert < 0) return; - - int token = i / num_topk; - int slot = i - token * num_topk; - int dst_rank = static_cast(expert / num_local_experts); - int local_expert = static_cast(expert % num_local_experts); - int group = dst_rank * num_local_experts + local_expert; - __shared__ int block_local_pos; - if (tid == 0) block_local_pos = atomicAdd(counters + group, 1); - __syncthreads(); - int local_pos = block_local_pos; - int pos = expert_offsets[group] + local_pos; - if (pos >= max_messages_per_rank) return; - - int fused_hidden = hidden + 2; - int fused_slots_per_rank = max_messages_per_rank + num_local_experts; - nv_bfloat16* dst = send_payload + (dst_rank * fused_slots_per_rank + - num_local_experts + pos) * - fused_hidden; - const nv_bfloat16* src = x + token * hidden; - for (int h = tid; h < hidden; h += blockDim.x) dst[h] = src[h]; - if (tid == 0) { - store_i32_words(dst + hidden, slot * num_tokens + token); - int route_idx = i * 4; - send_route[route_idx + 0] = dst_rank; - send_route[route_idx + 1] = pos; - send_route[route_idx + 2] = token; - send_route[route_idx + 3] = slot; - } -} - -__global__ void append_counts_to_payload_kernel( - const int32_t* counts_by_expert, nv_bfloat16* send_payload, int num_ranks, - int num_local_experts, int max_messages_per_rank, int hidden) { - int i = blockIdx.x * blockDim.x + threadIdx.x; - int total = num_ranks * num_local_experts; - if (i >= total) return; - int dst_rank = i / num_local_experts; - int local_expert = i - dst_rank * num_local_experts; - int fused_hidden = hidden + 2; - int fused_slots_per_rank = max_messages_per_rank + num_local_experts; - nv_bfloat16* dst = - send_payload + - (dst_rank * fused_slots_per_rank + local_expert) * fused_hidden; - store_i32_words(dst, counts_by_expert[i]); -} - -__global__ void compact_dispatch_fused_kernel( - const nv_bfloat16* recv_payload, int64_t* layout_range, int32_t* recv_count, - int64_t* return_src_pos, nv_bfloat16* packed_recv_x, - int32_t* packed_recv_src_info, int hidden, int num_ranks, - int num_local_experts, int max_messages_per_rank, int num_recv_slots) { - int idx = blockIdx.x; - int m = idx % max_messages_per_rank; - int pair = idx / max_messages_per_rank; - int src_rank = pair / num_local_experts; - int local_expert = pair - src_rank * num_local_experts; - if (src_rank >= num_ranks) return; - - int fused_hidden = hidden + 2; - int fused_slots_per_rank = max_messages_per_rank + num_local_experts; - const nv_bfloat16* rank_base = - recv_payload + src_rank * fused_slots_per_rank * fused_hidden; - int count = load_i32_words(rank_base + local_expert * fused_hidden); - - int begin = 0; - for (int r = 0; r < src_rank; ++r) { - const nv_bfloat16* prev_rank_base = - recv_payload + r * fused_slots_per_rank * fused_hidden; - begin += load_i32_words(prev_rank_base + local_expert * fused_hidden); - } - - if (m == 0 && threadIdx.x == 0) { - layout_range[local_expert * num_ranks + src_rank] = - (static_cast(begin) << 32) | static_cast(count); - atomicAdd(recv_count + local_expert, count); - } - - if (m >= count) return; - - int src_begin = 0; - for (int e = 0; e < local_expert; ++e) - src_begin += load_i32_words(rank_base + e * fused_hidden); - - int dst_pos = begin + m; - int src_pos = src_begin + m; - if (dst_pos >= num_recv_slots || src_pos >= max_messages_per_rank) return; - - const nv_bfloat16* src = - rank_base + (num_local_experts + src_pos) * fused_hidden; - nv_bfloat16* dst = - packed_recv_x + (local_expert * num_recv_slots + dst_pos) * hidden; - for (int h = threadIdx.x; h < hidden; h += blockDim.x) dst[h] = src[h]; - if (threadIdx.x == 0) { - packed_recv_src_info[local_expert * num_recv_slots + dst_pos] = - load_i32_words(src + hidden); - int route_idx = (local_expert * num_recv_slots + dst_pos) * 2; - return_src_pos[route_idx + 0] = src_rank; - return_src_pos[route_idx + 1] = src_pos; - } -} - -__global__ void pack_combine_kernel(const nv_bfloat16* expert_buffers, - const int64_t* return_src_pos, - nv_bfloat16* send_payload, int hidden, - int max_messages_per_rank, - int num_recv_slots) { - int idx = blockIdx.x; - int tid = threadIdx.x; - int local_expert = idx / num_recv_slots; - int expert_pos = idx - local_expert * num_recv_slots; - int route_idx = idx * 2; - int64_t dst_rank = return_src_pos[route_idx + 0]; - int64_t dst_pos = return_src_pos[route_idx + 1]; - if (dst_rank < 0 || dst_pos < 0) return; - - nv_bfloat16* dst = - send_payload + (dst_rank * max_messages_per_rank + dst_pos) * hidden; - const nv_bfloat16* src = - expert_buffers + (local_expert * num_recv_slots + expert_pos) * hidden; - for (int h = tid; h < hidden; h += blockDim.x) dst[h] = src[h]; -} - -__global__ void reduce_combine_kernel(const nv_bfloat16* recv_payload, - const int64_t* send_route, - const float* topk_weights, - nv_bfloat16* combined, int num_tokens, - int hidden, int num_topk, - int max_messages_per_rank) { - int token = blockIdx.x; - int h = blockIdx.y * blockDim.x + threadIdx.x; - if (token >= num_tokens || h >= hidden) return; - - float acc = 0.0f; - for (int slot = 0; slot < num_topk; ++slot) { - int route = token * num_topk + slot; - int64_t rank = send_route[route * 4 + 0]; - int64_t pos = send_route[route * 4 + 1]; - if (rank < 0 || pos < 0) continue; - const nv_bfloat16* src = - recv_payload + (rank * max_messages_per_rank + pos) * hidden; - float weight = topk_weights[route]; - acc += __bfloat162float(src[h]) * weight; - } - combined[token * hidden + h] = __float2bfloat16(acc); -} - -} // namespace - -std::tuple torch_alltoall_pack_dispatch_fused( - const torch::Tensor& x, const torch::Tensor& topk_idx, int num_experts, - int num_ranks) { - check_cuda(x, "x"); - check_cuda(topk_idx, "topk_idx"); - TORCH_CHECK(x.scalar_type() == torch::kBFloat16, "x must be bfloat16"); - TORCH_CHECK(topk_idx.scalar_type() == torch::kInt64, - "topk_idx must be int64"); - TORCH_CHECK(x.dim() == 2 && topk_idx.dim() == 2, - "x/topk_idx must be 2D tensors"); - TORCH_CHECK(x.size(0) == topk_idx.size(0), - "x and topk_idx must have the same token count"); - TORCH_CHECK(num_ranks > 0, "num_ranks must be positive"); - TORCH_CHECK(num_experts > 0 && num_experts % num_ranks == 0, - "num_experts must be positive and divisible by num_ranks"); - - const c10::cuda::CUDAGuard guard(x.device()); - auto stream = at::cuda::getCurrentCUDAStream(); - int num_tokens = static_cast(x.size(0)); - int hidden = static_cast(x.size(1)); - int num_topk = static_cast(topk_idx.size(1)); - int num_local_experts = num_experts / num_ranks; - int max_messages_per_rank = num_tokens * num_topk; - int fused_hidden = hidden + 2; - int fused_slots_per_rank = max_messages_per_rank + num_local_experts; - auto int32_opts = - torch::TensorOptions().dtype(torch::kInt32).device(x.device()); - auto int64_opts = - torch::TensorOptions().dtype(torch::kInt64).device(x.device()); - - auto counts_by_expert = - torch::zeros({num_ranks, num_local_experts}, int32_opts); - int total = num_tokens * num_topk; - count_dispatch_kernel<<>>( - topk_idx.data_ptr(), counts_by_expert.data_ptr(), - num_tokens, num_topk, num_local_experts); - - auto expert_offsets = - torch::empty({num_ranks, num_local_experts}, int32_opts); - prefix_counts_kernel<<>>( - counts_by_expert.data_ptr(), - expert_offsets.data_ptr(), num_ranks, num_local_experts); - - auto send_payload = torch::zeros( - {num_ranks, fused_slots_per_rank, fused_hidden}, x.options()); - auto send_route = torch::full({total, 4}, -1, int64_opts); - auto counters = torch::zeros({num_ranks, num_local_experts}, int32_opts); - pack_dispatch_fused_kernel<<>>( - reinterpret_cast(x.data_ptr()), - topk_idx.data_ptr(), expert_offsets.data_ptr(), - counters.data_ptr(), - reinterpret_cast(send_payload.data_ptr()), - send_route.data_ptr(), num_tokens, hidden, num_topk, num_ranks, - num_local_experts, max_messages_per_rank); - append_counts_to_payload_kernel<<< - ceil_div(num_ranks * num_local_experts, 256), 256, 0, stream>>>( - counts_by_expert.data_ptr(), - reinterpret_cast(send_payload.data_ptr()), num_ranks, - num_local_experts, max_messages_per_rank, hidden); - return {send_payload, send_route}; -} - -std::tuple -torch_alltoall_compact_dispatch_fused(const torch::Tensor& recv_payload, - int num_local_experts, - int num_max_dispatch_tokens_per_rank) { - check_cuda(recv_payload, "recv_payload"); - TORCH_CHECK(recv_payload.scalar_type() == torch::kBFloat16, - "recv_payload must be bfloat16"); - TORCH_CHECK(recv_payload.dim() == 3, "recv_payload must be a 3D tensor"); - TORCH_CHECK(num_local_experts > 0, "num_local_experts must be positive"); - TORCH_CHECK(num_max_dispatch_tokens_per_rank > 0, - "num_max_dispatch_tokens_per_rank must be positive"); - const c10::cuda::CUDAGuard guard(recv_payload.device()); - auto stream = at::cuda::getCurrentCUDAStream(); - int num_ranks = static_cast(recv_payload.size(0)); - int fused_slots_per_rank = static_cast(recv_payload.size(1)); - int fused_hidden = static_cast(recv_payload.size(2)); - TORCH_CHECK(fused_hidden > 2, "recv_payload hidden dimension is invalid"); - TORCH_CHECK(fused_slots_per_rank > num_local_experts, - "recv_payload slot dimension is invalid"); - int hidden = fused_hidden - 2; - int max_messages_per_rank = fused_slots_per_rank - num_local_experts; - int num_recv_slots = num_ranks * num_max_dispatch_tokens_per_rank; - auto int32_opts = torch::TensorOptions() - .dtype(torch::kInt32) - .device(recv_payload.device()); - auto int64_opts = torch::TensorOptions() - .dtype(torch::kInt64) - .device(recv_payload.device()); - - auto packed_recv_x = torch::zeros( - {num_local_experts, num_recv_slots, hidden}, recv_payload.options()); - auto packed_recv_src_info = - torch::full({num_local_experts, num_recv_slots}, -1, int32_opts); - auto layout_range = - torch::zeros({num_local_experts, num_ranks}, int64_opts); - auto recv_count = torch::zeros({num_local_experts}, int32_opts); - auto return_src_pos = - torch::full({num_local_experts, num_recv_slots, 2}, -1, int64_opts); - compact_dispatch_fused_kernel<<>>( - reinterpret_cast(recv_payload.data_ptr()), - layout_range.data_ptr(), recv_count.data_ptr(), - return_src_pos.data_ptr(), - reinterpret_cast(packed_recv_x.data_ptr()), - packed_recv_src_info.data_ptr(), hidden, num_ranks, - num_local_experts, max_messages_per_rank, num_recv_slots); - return {packed_recv_x, packed_recv_src_info, layout_range, recv_count, - return_src_pos}; -} - -torch::Tensor torch_alltoall_pack_combine(const torch::Tensor& expert_buffers, - const torch::Tensor& return_src_pos, - int num_ranks, - int max_messages_per_rank) { - check_cuda(expert_buffers, "expert_buffers"); - check_cuda(return_src_pos, "return_src_pos"); - TORCH_CHECK(expert_buffers.scalar_type() == torch::kBFloat16, - "expert_buffers must be bfloat16"); - TORCH_CHECK(return_src_pos.scalar_type() == torch::kInt64, - "return_src_pos must be int64"); - TORCH_CHECK(expert_buffers.dim() == 3, - "expert_buffers must be a 3D tensor"); - TORCH_CHECK(return_src_pos.dim() == 3 && return_src_pos.size(2) == 2, - "return_src_pos must have shape [experts, slots, 2]"); - TORCH_CHECK(num_ranks > 0 && max_messages_per_rank > 0, - "num_ranks and max_messages_per_rank must be positive"); - const c10::cuda::CUDAGuard guard(expert_buffers.device()); - auto stream = at::cuda::getCurrentCUDAStream(); - int hidden = static_cast(expert_buffers.size(2)); - int num_recv_slots = static_cast(expert_buffers.size(1)); - auto send_payload = torch::zeros({num_ranks, max_messages_per_rank, hidden}, - expert_buffers.options()); - int num_local_experts = static_cast(expert_buffers.size(0)); - pack_combine_kernel<<>>( - reinterpret_cast(expert_buffers.data_ptr()), - return_src_pos.data_ptr(), - reinterpret_cast(send_payload.data_ptr()), hidden, - max_messages_per_rank, num_recv_slots); - return send_payload; -} - -torch::Tensor torch_alltoall_reduce_combine( - const torch::Tensor& recv_payload, const torch::Tensor& send_route, - const torch::Tensor& topk_weights, - const std::optional& out) { - check_cuda(recv_payload, "recv_payload"); - check_cuda(send_route, "send_route"); - check_cuda(topk_weights, "topk_weights"); - TORCH_CHECK(recv_payload.scalar_type() == torch::kBFloat16, - "recv_payload must be bfloat16"); - TORCH_CHECK(send_route.scalar_type() == torch::kInt64, - "send_route must be int64"); - TORCH_CHECK(topk_weights.scalar_type() == torch::kFloat32, - "topk_weights must be float32"); - TORCH_CHECK(recv_payload.dim() == 3 && send_route.dim() == 2 && - topk_weights.dim() == 2, - "recv_payload, send_route, and topk_weights must be 3D/2D/2D"); - TORCH_CHECK(send_route.size(1) == 4, - "send_route must have shape [tokens * topk, 4]"); - const c10::cuda::CUDAGuard guard(recv_payload.device()); - auto stream = at::cuda::getCurrentCUDAStream(); - int hidden = static_cast(recv_payload.size(2)); - int max_messages_per_rank = static_cast(recv_payload.size(1)); - int num_tokens = static_cast(topk_weights.size(0)); - int num_topk = static_cast(topk_weights.size(1)); - TORCH_CHECK(send_route.size(0) == num_tokens * num_topk, - "send_route length must match topk_weights"); - auto combined = out.has_value() ? out.value() - : torch::empty({num_tokens, hidden}, - recv_payload.options()); - TORCH_CHECK(combined.scalar_type() == torch::kBFloat16, - "out must be bfloat16"); - dim3 grid(num_tokens, ceil_div(hidden, 256)); - reduce_combine_kernel<<>>( - reinterpret_cast(recv_payload.data_ptr()), - send_route.data_ptr(), topk_weights.data_ptr(), - reinterpret_cast(combined.data_ptr()), num_tokens, hidden, - num_topk, max_messages_per_rank); - return combined; -} - -} // namespace mooncake diff --git a/mooncake-ep/src/mooncake_ep_buffer.cpp b/mooncake-ep/src/mooncake_ep_buffer.cpp index e8819f36c8..d7d5d37a42 100644 --- a/mooncake-ep/src/mooncake_ep_buffer.cpp +++ b/mooncake-ep/src/mooncake_ep_buffer.cpp @@ -1,5 +1,7 @@ #include #include +#include +#include #include #include @@ -23,6 +25,17 @@ static bool initRdmaTransport(device::RdmaTransport* t, void* gdr_buffer, return ret == 0; } +static bool macaHostPhaseFenceCoversPeers() { +#ifdef MOONCAKE_EP_USE_MACA + const char* env = std::getenv("MOONCAKE_EP_MACA_PHASE_FENCE"); + if (env == nullptr || env[0] == '\0') return true; + return std::strcmp(env, "0") != 0 && std::strcmp(env, "off") != 0 && + std::strcmp(env, "none") != 0; +#else + return false; +#endif +} + MooncakeEpBuffer::MooncakeEpBuffer(int rank, int num_ranks, int64_t num_ep_buffer_bytes, TransferEngine* engine) @@ -202,7 +215,7 @@ MooncakeEpBuffer::dispatch(const torch::Tensor& x, void** ipc_ptrs = p2p_transport_->peerPtrsTablePtr(); auto mark_send_done = [=]() { -#ifdef MOONCAKE_EP_USE_MUSA +#ifdef MOONCAKE_EP_SPLIT_SEND_RECV mooncake::mark_phase_ack(gdr_buffer, nvlink_avail, ipc_ptrs, buffer.rdma_send_signal_buffer, rank, num_ranks, phase_epoch, launch_stream); @@ -210,7 +223,7 @@ MooncakeEpBuffer::dispatch(const torch::Tensor& x, }; auto wait_peer_send_done = [=]() { -#ifdef MOONCAKE_EP_USE_MUSA +#ifdef MOONCAKE_EP_SPLIT_SEND_RECV mooncake::wait_phase_ack(buffer.rdma_send_signal_buffer, rank, num_ranks, phase_epoch, launch_stream, timeout_ticks); @@ -218,7 +231,7 @@ MooncakeEpBuffer::dispatch(const torch::Tensor& x, }; auto mark_and_wait_peer_send_done = [=]() { -#ifdef MOONCAKE_EP_USE_MUSA +#ifdef MOONCAKE_EP_SPLIT_SEND_RECV mooncake::mark_and_wait_phase_ack( gdr_buffer, nvlink_avail, ipc_ptrs, buffer.rdma_send_signal_buffer, rank, num_ranks, phase_epoch, launch_stream, timeout_ticks); @@ -244,7 +257,7 @@ MooncakeEpBuffer::dispatch(const torch::Tensor& x, launcher(LOW_LATENCY_SEND_PHASE); mark_send_done(); } else { -#ifdef MOONCAKE_EP_USE_MUSA +#ifdef MOONCAKE_EP_SPLIT_SEND_RECV launcher(LOW_LATENCY_SEND_PHASE); mark_and_wait_peer_send_done(); launcher(LOW_LATENCY_RECV_PHASE); @@ -260,6 +273,8 @@ MooncakeEpBuffer::dispatch(const torch::Tensor& x, // before the stream-wait happens, so in Python API, we must wrap // all tensors into the event handle. event = EventHandle(launch_stream); + } else if (return_recv_hook && macaHostPhaseFenceCoversPeers()) { + event = EventHandle(launch_stream); } else if (not return_recv_hook) { stream_wait(compute_stream, launch_stream); } @@ -268,7 +283,7 @@ MooncakeEpBuffer::dispatch(const torch::Tensor& x, std::optional> recv_hook = std::nullopt; if (return_recv_hook) recv_hook = [=]() { - wait_peer_send_done(); + if (!macaHostPhaseFenceCoversPeers()) wait_peer_send_done(); launcher(LOW_LATENCY_RECV_PHASE); }; @@ -358,7 +373,7 @@ MooncakeEpBuffer::combine(const torch::Tensor& x, const torch::Tensor& topk_idx, void** ipc_ptrs = p2p_transport_->peerPtrsTablePtr(); auto mark_send_done = [=]() { -#ifdef MOONCAKE_EP_USE_MUSA +#ifdef MOONCAKE_EP_SPLIT_SEND_RECV mooncake::mark_phase_ack(gdr_buffer, nvlink_avail, ipc_ptrs, buffer.rdma_send_signal_buffer, rank, num_ranks, phase_epoch, launch_stream); @@ -366,7 +381,7 @@ MooncakeEpBuffer::combine(const torch::Tensor& x, const torch::Tensor& topk_idx, }; auto wait_peer_send_done = [=]() { -#ifdef MOONCAKE_EP_USE_MUSA +#ifdef MOONCAKE_EP_SPLIT_SEND_RECV mooncake::wait_phase_ack(buffer.rdma_send_signal_buffer, rank, num_ranks, phase_epoch, launch_stream, timeout_ticks); @@ -374,7 +389,7 @@ MooncakeEpBuffer::combine(const torch::Tensor& x, const torch::Tensor& topk_idx, }; auto mark_and_wait_peer_send_done = [=]() { -#ifdef MOONCAKE_EP_USE_MUSA +#ifdef MOONCAKE_EP_SPLIT_SEND_RECV mooncake::mark_and_wait_phase_ack( gdr_buffer, nvlink_avail, ipc_ptrs, buffer.rdma_send_signal_buffer, rank, num_ranks, phase_epoch, launch_stream, timeout_ticks); @@ -400,7 +415,7 @@ MooncakeEpBuffer::combine(const torch::Tensor& x, const torch::Tensor& topk_idx, launcher(LOW_LATENCY_SEND_PHASE); mark_send_done(); } else { -#ifdef MOONCAKE_EP_USE_MUSA +#ifdef MOONCAKE_EP_SPLIT_SEND_RECV launcher(LOW_LATENCY_SEND_PHASE); mark_and_wait_peer_send_done(); launcher(LOW_LATENCY_RECV_PHASE); @@ -416,6 +431,8 @@ MooncakeEpBuffer::combine(const torch::Tensor& x, const torch::Tensor& topk_idx, // before the stream-wait happens, so in Python API, we must wrap // all tensors into the event handle. event = EventHandle(launch_stream); + } else if (return_recv_hook && macaHostPhaseFenceCoversPeers()) { + event = EventHandle(launch_stream); } else if (not return_recv_hook) { stream_wait(compute_stream, launch_stream); } @@ -424,7 +441,7 @@ MooncakeEpBuffer::combine(const torch::Tensor& x, const torch::Tensor& topk_idx, std::optional> recv_hook = std::nullopt; if (return_recv_hook) recv_hook = [=]() { - wait_peer_send_done(); + if (!macaHostPhaseFenceCoversPeers()) wait_peer_send_done(); launcher(LOW_LATENCY_RECV_PHASE); }; diff --git a/mooncake-ep/src/mooncake_ep_kernel.cu b/mooncake-ep/src/mooncake_ep_kernel.cu index 0ef1509784..0cede9695f 100644 --- a/mooncake-ep/src/mooncake_ep_kernel.cu +++ b/mooncake-ep/src/mooncake_ep_kernel.cu @@ -153,6 +153,22 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales, const auto warp_group_id = warp_id / kNumWarpsPerGroup; const auto sub_warp_id = warp_id % kNumWarpsPerGroup; const auto responsible_expert_idx = sm_id * kNumWarpGroups + warp_group_id; +#ifdef MOONCAKE_EP_USE_MACA + // C500 reports 64-thread hardware warps. Do not split the last hardware + // warp by assigning only the final 32-thread pseudo-warp to count work. + // Reserve one full warp group from the data path, but write counts from a + // single 32-thread lane group to avoid duplicate per-expert increments. + const bool is_count_warp = warp_group_id == kNumWarpGroups - 1; + const bool is_count_worker = is_count_warp && sub_warp_id == 0; + const bool is_data_warp = warp_group_id < kNumWarpGroups - 1; + const int num_send_threads = + (kNumWarpGroups - 1) * kNumWarpsPerGroup * 32; +#else + const bool is_count_warp = warp_id == num_warps - 1; + const bool is_count_worker = is_count_warp; + const bool is_data_warp = warp_id < num_warps - 1; + const int num_send_threads = (num_warps - 1) * 32; +#endif // FP8 staffs constexpr int kNumPerChannels = 128; @@ -183,14 +199,16 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales, // Expert counts __shared__ int shared_num_tokens_sent_per_expert[kNumWarpGroups]; - // There are 2 kinds of warps in this part: - // 1. The first-kind warps for FP8 cast and sending top-k tokens - // 2. The last warp for reading `topk_idx` and count for per-expert information - if (warp_id < num_warps - 1) { + // There are 2 kinds of execution lanes in this part: + // 1. Data lanes for FP8 cast and sending top-k tokens. + // 2. Count lanes for reading `topk_idx` and per-expert token counts. + // MACA reserves a full warp group for the count path; CUDA keeps the + // original final 32-thread warp behavior. + if (is_data_warp) { constexpr int kNumElemsPerRead = sizeof(int4) / EP_BF16_SIZE; EP_DEVICE_ASSERT(kHidden % kNumElemsPerRead == 0); EP_STATIC_ASSERT(kNumElemsPerRead * 32 % kNumPerChannels == 0, "Invalid vectorization"); - const auto num_threads = (num_warps - 1) * 32; + const auto num_threads = num_send_threads; const size_t hidden_bf16_int4 = kHidden / kNumElemsPerRead; for (int token_idx = sm_id; token_idx < num_tokens; token_idx += num_sms) { @@ -274,15 +292,17 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales, lane_id == 0 ? mc_atomic_add_release(atomic_finish_counter_per_expert + dst_expert_idx, 1) : 0; } } - } else if (warp_id == num_warps - 1) { -#ifdef MOONCAKE_EP_USE_MUSA + } else if (is_count_warp) { +#ifdef MOONCAKE_EP_SPLIT_SEND_RECV // Participate in __syncthreads() barriers from data warps. // Each token iteration in the send loop above calls - // __syncthreads() once; the count warp must match. + // __syncthreads() once; the count path must match. for (int token_idx = sm_id; token_idx < num_tokens; token_idx += num_sms) { __syncthreads(); } #endif + } + if (is_count_worker) { EP_DEVICE_ASSERT(num_sms > 1); if (sm_id == 0) { // The first SM is also responsible for cleaning the next buffer @@ -634,9 +654,9 @@ combine(void* combined_x, int32_t* active_ranks, } } } -#ifdef MOONCAKE_EP_USE_MUSA - // mc_grid_sync() is a no-op on MUSA; use a block-wide fence/barrier before - // reduction so threads see peer writes. +#ifdef MOONCAKE_EP_SPLIT_SEND_RECV + // mc_grid_sync() is a no-op on split-kernel platforms; use a block-wide + // fence/barrier before reduction so threads see peer writes. __syncthreads(); mc_fence(); __syncthreads(); diff --git a/mooncake-ep/tests/test_ep_grid.py b/mooncake-ep/tests/test_ep_grid.py index 2d5e42ddea..456c2944d2 100644 --- a/mooncake-ep/tests/test_ep_grid.py +++ b/mooncake-ep/tests/test_ep_grid.py @@ -22,6 +22,13 @@ def using_musa_backend() -> bool: } +def using_maca_backend() -> bool: + return ( + os.getenv("MOONCAKE_EP_USE_MACA", "").upper() in {"1", "ON", "TRUE", "YES"} + or bool(getattr(torch.version, "maca", None)) + ) + + def import_torchada_if_needed(): if not using_musa_backend(): return @@ -289,8 +296,9 @@ def make_test_name(cfg): def generate_tests(): + fp8_options = [False] if using_maca_backend() else [False, True] test_grid = { - "use_fp8": [False, True], + "use_fp8": fp8_options, "zero_copy": [False, True], "async_finish": [False, True], "return_recv_hook": [False, True], diff --git a/mooncake-pg/BuildPgExt.cmake b/mooncake-pg/BuildPgExt.cmake index ead67c3e8b..1829f4e3b3 100644 --- a/mooncake-pg/BuildPgExt.cmake +++ b/mooncake-pg/BuildPgExt.cmake @@ -12,6 +12,7 @@ # STAGING_DIR - destination directory for the built .so files # ENGINE_SO_PATH - absolute path to the built engine.cpython-XYZ.so # EP_USE_MUSA - set to "1" when building for MUSA (MTLink path) +# EP_USE_MACA - set to "1" when building for MACA (MTLink path) cmake_minimum_required(VERSION 3.16) @@ -41,6 +42,16 @@ if(EP_USE_MUSA) else() unset(ENV{MOONCAKE_EP_USE_MUSA}) endif() +if(EP_USE_MACA) + set(ENV{MOONCAKE_EP_USE_MACA} "1") + if(DEFINED ENV{MACA_PATH}) + set(ENV{MACA_HOME} "$ENV{MACA_PATH}") + elseif(DEFINED ENV{MACA_HOME}) + set(ENV{MACA_PATH} "$ENV{MACA_HOME}") + endif() +else() + unset(ENV{MOONCAKE_EP_USE_MACA}) +endif() # --------------------------------------------------------------------------- # 2. Ensure engine.so exists in mooncake-wheel/mooncake/ for setup.py linking. diff --git a/mooncake-pg/setup.py b/mooncake-pg/setup.py index e839339b5a..1fe68b53c2 100644 --- a/mooncake-pg/setup.py +++ b/mooncake-pg/setup.py @@ -28,13 +28,37 @@ abi_flag = int(torch._C._GLIBCXX_USE_CXX11_ABI) current_dir = os.path.abspath(os.path.dirname(__file__)) +repo_dir = os.path.abspath(os.path.join(current_dir, os.pardir)) +sysroot_dir = os.path.join(repo_dir, ".deps", "sysroot", "usr") + + +def existing_dirs(*paths): + return [path for path in paths if os.path.isdir(path)] + + +sysroot_include_dirs = existing_dirs( + os.path.join(sysroot_dir, "include"), + os.path.join(sysroot_dir, "include", "jsoncpp"), + os.path.join(sysroot_dir, "include", "libnl3"), +) +sysroot_library_dirs = existing_dirs( + os.path.join(sysroot_dir, "lib", "x86_64-linux-gnu"), + os.path.join(sysroot_dir, "lib"), +) abi_define = f"-D_GLIBCXX_USE_CXX11_ABI={abi_flag}" cxx_args = [abi_define, "-std=c++20", "-O3", "-g0"] cuda_libraries = ["ibverbs", "mlx5"] cuda_library_dirs = [] -use_maca = hasattr(torch.version, "maca") and torch.version.maca is not None +include_dirs = [ + os.path.join(current_dir, "include"), + os.path.join(current_dir, "../mooncake-transfer-engine/include"), +] +use_maca = ( + os.getenv("MOONCAKE_EP_USE_MACA", "").upper() in {"1", "ON", "TRUE", "YES"} + or (hasattr(torch.version, "maca") and torch.version.maca is not None) +) if use_musa: musa_defines = ["-DUSE_MUSA", "-DMOONCAKE_EP_USE_MUSA=1"] @@ -50,7 +74,10 @@ ] else: if use_maca: - cxx_args.append("-DUSE_MACA") + cuda_libraries = [] + cuda_library_dirs = sysroot_library_dirs.copy() + include_dirs += sysroot_include_dirs + cxx_args += ["-DUSE_MACA", "-DMOONCAKE_EP_USE_MACA=1"] device_args = [ abi_define, "-std=c++20", @@ -60,10 +87,10 @@ "-g0", ] if use_maca: - device_args.append("-DUSE_MACA") + device_args += ["-DUSE_MACA", "-DMOONCAKE_EP_USE_MACA=1"] # Link against the CUDA driver stub library if available. # Same approach as mooncake-ep/setup.py. - if CUDA_HOME is not None: + if not use_maca and CUDA_HOME is not None: cuda_stub_dir = os.path.join(CUDA_HOME, "lib64", "stubs") cuda_stub_lib = os.path.join(cuda_stub_dir, "libcuda.so") if os.path.exists(cuda_stub_lib): @@ -75,10 +102,7 @@ ext_modules=[ CUDAExtension( name=module_name, - include_dirs=[ - os.path.join(current_dir, "include"), - os.path.join(current_dir, "../mooncake-transfer-engine/include"), - ], + include_dirs=include_dirs, sources=[ "src/pg_py.cpp", "src/mooncake_backend.cpp", diff --git a/mooncake-transfer-engine/include/CMakeLists.txt b/mooncake-transfer-engine/include/CMakeLists.txt index c9486b9b52..5500d1ae45 100644 --- a/mooncake-transfer-engine/include/CMakeLists.txt +++ b/mooncake-transfer-engine/include/CMakeLists.txt @@ -19,6 +19,7 @@ install(FILES transport/device/p2p_device.cuh DESTINATION include/transport/devi install(FILES transport/device/ibgda_device.cuh DESTINATION include/transport/device) install(FILES transport/device/cuda/cuda_ops.cuh DESTINATION include/transport/device/cuda) install(FILES transport/device/musa/musa_ops.cuh DESTINATION include/transport/device/musa) +install(FILES transport/device/maca/maca_ops.cuh DESTINATION include/transport/device/maca) # IBGDA library headers install(DIRECTORY transport/device/ibgda/ DESTINATION include/transport/device/ibgda diff --git a/mooncake-transfer-engine/include/transfer_engine.h b/mooncake-transfer-engine/include/transfer_engine.h index 8d5b8d2e71..0e8233d76d 100644 --- a/mooncake-transfer-engine/include/transfer_engine.h +++ b/mooncake-transfer-engine/include/transfer_engine.h @@ -25,7 +25,8 @@ class TransferEngineImpl; namespace tent { class TransferEngine; }; -#if defined(USE_CUDA) || defined(USE_MUSA) +#if (defined(USE_CUDA) || defined(USE_MUSA) || defined(USE_MACA)) && \ + !defined(USE_CXI) namespace device { class P2pTransport; class RdmaTransport; @@ -156,7 +157,8 @@ class TransferEngine { Transport* getTransport(const std::string& proto); -#if defined(USE_CUDA) || defined(USE_MUSA) +#if (defined(USE_CUDA) || defined(USE_MUSA) || defined(USE_MACA)) && \ + !defined(USE_CXI) // Device transport accessors (P2P + IBGDA). Lazily created on first // call and owned by the TransferEngine. These allow EP (and future // CPU-proxy paths) to obtain device transports from an engine instance diff --git a/mooncake-transfer-engine/include/transfer_engine_impl.h b/mooncake-transfer-engine/include/transfer_engine_impl.h index b1e4fff7e8..c5d919d232 100644 --- a/mooncake-transfer-engine/include/transfer_engine_impl.h +++ b/mooncake-transfer-engine/include/transfer_engine_impl.h @@ -33,7 +33,8 @@ #include "transfer_metadata.h" #include "transfer_engine.h" #include "transport/transport.h" -#if defined(USE_CUDA) || defined(USE_MUSA) +#if (defined(USE_CUDA) || defined(USE_MUSA) || defined(USE_MACA)) && \ + !defined(USE_CXI) #include "transport/device/device_transport.h" #endif #ifdef WITH_METRICS @@ -344,7 +345,8 @@ class TransferEngineImpl { return multi_transports_->getTransport(proto); } -#if defined(USE_CUDA) || defined(USE_MUSA) +#if (defined(USE_CUDA) || defined(USE_MUSA) || defined(USE_MACA)) && \ + !defined(USE_CXI) // Device transport accessors — lazily created, owned by this impl. device::P2pTransport* getOrCreateP2pTransport(int num_ranks); device::RdmaTransport* getOrCreateRdmaTransport( @@ -427,7 +429,8 @@ class TransferEngineImpl { std::vector filter_; bool use_barex_ = false; -#if defined(USE_CUDA) || defined(USE_MUSA) +#if (defined(USE_CUDA) || defined(USE_MUSA) || defined(USE_MACA)) && \ + !defined(USE_CXI) // Device transports (P2P + IBGDA) — lazily created, owned by this impl. // Referenced by EP and future CPU-proxy paths. std::unique_ptr p2p_transport_; diff --git a/mooncake-transfer-engine/include/transport/device/comm_device.cuh b/mooncake-transfer-engine/include/transport/device/comm_device.cuh index 29e6c6c6f4..a3be4bd619 100644 --- a/mooncake-transfer-engine/include/transport/device/comm_device.cuh +++ b/mooncake-transfer-engine/include/transport/device/comm_device.cuh @@ -36,7 +36,11 @@ __device__ __forceinline__ CommCtx make_comm_ctx( ctx.p2p.peer_ptrs = ipc_peer_ptrs; ctx.p2p.local_base = gdr_buffer; +#ifdef MOONCAKE_EP_USE_MACA + ctx.ibgda.qp_devctxs = qp_devctxs; +#else ctx.ibgda.qp_devctxs = reinterpret_cast(qp_devctxs); +#endif ctx.ibgda.raddrs = reinterpret_cast(raddrs); ctx.ibgda.rkeys = reinterpret_cast(rkeys); ctx.ibgda.local_atomic_base = rdma_send_signal_buffer; diff --git a/mooncake-transfer-engine/include/transport/device/device_ops.cuh b/mooncake-transfer-engine/include/transport/device/device_ops.cuh index 2835f81533..29bf4b7d01 100644 --- a/mooncake-transfer-engine/include/transport/device/device_ops.cuh +++ b/mooncake-transfer-engine/include/transport/device/device_ops.cuh @@ -7,6 +7,8 @@ #ifdef MOONCAKE_EP_USE_MUSA #include "transport/device/musa/musa_ops.cuh" +#elif defined(MOONCAKE_EP_USE_MACA) +#include "transport/device/maca/maca_ops.cuh" #else #include "transport/device/cuda/cuda_ops.cuh" #endif diff --git a/mooncake-transfer-engine/include/transport/device/ibgda_device.cuh b/mooncake-transfer-engine/include/transport/device/ibgda_device.cuh index 15fcb990c8..2376f3b42e 100644 --- a/mooncake-transfer-engine/include/transport/device/ibgda_device.cuh +++ b/mooncake-transfer-engine/include/transport/device/ibgda_device.cuh @@ -11,6 +11,32 @@ #include #include "transport/device/device_ops.cuh" +#ifdef MOONCAKE_EP_USE_MACA + +namespace mooncake { +namespace device { + +struct IbgdaContext { + void* qp_devctxs; + const uint64_t* raddrs; + const uint32_t* rkeys; + const void* local_atomic_base; + const void* remote_atomic_base; +}; + +__device__ __forceinline__ void mc_ibgda_put(const IbgdaContext&, int, int, int, + int, const void*, uint64_t, + uint32_t) {} + +__device__ __forceinline__ void mc_ibgda_red_add(const IbgdaContext&, int, int, + int, int, uint64_t, uint64_t, + int32_t) {} + +} // namespace device +} // namespace mooncake + +#else // !MOONCAKE_EP_USE_MACA + #ifndef MOONCAKE_EP_USE_MUSA #include #endif @@ -204,3 +230,5 @@ __device__ __forceinline__ void mc_ibgda_red_add( } // namespace device } // namespace mooncake + +#endif // MOONCAKE_EP_USE_MACA diff --git a/mooncake-transfer-engine/include/transport/device/maca/maca_ops.cuh b/mooncake-transfer-engine/include/transport/device/maca/maca_ops.cuh new file mode 100644 index 0000000000..50fdff72f7 --- /dev/null +++ b/mooncake-transfer-engine/include/transport/device/maca/maca_ops.cuh @@ -0,0 +1,96 @@ +// MACA implementations of device-side memory ordering primitives. +// +// MACA's cu-bridge compiler accepts CUDA-like intrinsics, but does not reliably +// compile the PTX acquire/release/barrier instructions used by the CUDA path. +#pragma once + +#include + +namespace mooncake { +namespace device { + +__device__ __forceinline__ int mc_ld_acquire(const int* ptr) { + __threadfence_system(); + return *const_cast(ptr); +} + +__device__ __forceinline__ uint64_t mc_ld_acquire_u64(const uint64_t* ptr) { + __threadfence_system(); + return *const_cast(ptr); +} + +__device__ __forceinline__ void mc_st_release(const int* ptr, int val) { + *const_cast(ptr) = val; + __threadfence_system(); +} + +__device__ __forceinline__ void mc_st_release_u32(const uint32_t* ptr, + uint32_t val) { + *const_cast(ptr) = val; + __threadfence_system(); +} + +__device__ __forceinline__ void mc_st_release_u64(const uint64_t* ptr, + uint64_t val) { + *const_cast(ptr) = val; + __threadfence_system(); +} + +__device__ __forceinline__ int mc_atomic_add_release(const int* ptr, int val) { + int ret = atomicAdd(const_cast(ptr), val); + __threadfence_system(); + return ret; +} + +__device__ __forceinline__ int4 mc_ld_nc(const int4* ptr) { return __ldg(ptr); } + +__device__ __forceinline__ int mc_ld_nc_s32(const int* ptr) { + return __ldg(ptr); +} + +__device__ __forceinline__ float mc_ld_nc_f32(const float* ptr) { + return __ldg(ptr); +} + +__device__ __forceinline__ int64_t mc_ld_nc_s64(const int64_t* ptr) { + return __ldg(ptr); +} + +__device__ __forceinline__ void mc_st_na(const int4* ptr, const int4& val) { + *const_cast(ptr) = val; +} + +__device__ __forceinline__ void mc_bar_init() {} + +__device__ __forceinline__ void mc_bar_sync(int /*bar_id*/, + int /*num_threads*/) { + __syncthreads(); +} + +__device__ __forceinline__ void mc_grid_sync() {} + +__device__ __forceinline__ void mc_fence() { __threadfence_system(); } + +__device__ __forceinline__ void mc_fence_barrier_fence() { + mc_fence(); + mc_bar_sync(0, 0); + mc_fence(); +} + +__device__ __forceinline__ uint16_t mc_bswap16(uint16_t x) { + return (uint16_t)(((x & 0x00FFu) << 8) | ((x & 0xFF00u) >> 8)); +} + +__device__ __forceinline__ uint32_t mc_bswap32(uint32_t x) { + return ((x & 0x000000FFu) << 24) | ((x & 0x0000FF00u) << 8) | + ((x & 0x00FF0000u) >> 8) | ((x & 0xFF000000u) >> 24); +} + +__device__ __forceinline__ uint64_t mc_bswap64(uint64_t x) { + uint32_t hi = mc_bswap32((uint32_t)(x >> 32)); + uint32_t lo = mc_bswap32((uint32_t)(x)); + return ((uint64_t)lo << 32) | hi; +} + +} // namespace device +} // namespace mooncake diff --git a/mooncake-transfer-engine/src/transfer_engine.cpp b/mooncake-transfer-engine/src/transfer_engine.cpp index 3747244e76..7bd2ced5d2 100644 --- a/mooncake-transfer-engine/src/transfer_engine.cpp +++ b/mooncake-transfer-engine/src/transfer_engine.cpp @@ -179,7 +179,8 @@ Transport* TransferEngine::getTransport(const std::string& proto) { return impl_->getTransport(proto); } -#if defined(USE_CUDA) || defined(USE_MUSA) +#if (defined(USE_CUDA) || defined(USE_MUSA) || defined(USE_MACA)) && \ + !defined(USE_CXI) device::P2pTransport* TransferEngine::getOrCreateP2pTransport(int num_ranks) { return impl_->getOrCreateP2pTransport(num_ranks); } @@ -595,7 +596,8 @@ Transport* TransferEngine::getTransport(const std::string& proto) { return impl_->getTransport(proto); } -#if defined(USE_CUDA) || defined(USE_MUSA) +#if (defined(USE_CUDA) || defined(USE_MUSA) || defined(USE_MACA)) && \ + !defined(USE_CXI) device::P2pTransport* TransferEngine::getOrCreateP2pTransport(int num_ranks) { if (use_tent_) return nullptr; return impl_->getOrCreateP2pTransport(num_ranks); diff --git a/mooncake-transfer-engine/src/transfer_engine_impl.cpp b/mooncake-transfer-engine/src/transfer_engine_impl.cpp index 83960ad575..a212be23d2 100644 --- a/mooncake-transfer-engine/src/transfer_engine_impl.cpp +++ b/mooncake-transfer-engine/src/transfer_engine_impl.cpp @@ -444,7 +444,8 @@ int TransferEngineImpl::uninstallTransport(const std::string& proto) { return 0; } -#if defined(USE_CUDA) || defined(USE_MUSA) +#if (defined(USE_CUDA) || defined(USE_MUSA) || defined(USE_MACA)) && \ + !defined(USE_CXI) device::P2pTransport* TransferEngineImpl::getOrCreateP2pTransport( int num_ranks) { if (!p2p_transport_) { diff --git a/mooncake-transfer-engine/src/transport/CMakeLists.txt b/mooncake-transfer-engine/src/transport/CMakeLists.txt index 47a1dc99cf..5b7ecdc0eb 100644 --- a/mooncake-transfer-engine/src/transport/CMakeLists.txt +++ b/mooncake-transfer-engine/src/transport/CMakeLists.txt @@ -78,10 +78,18 @@ if (USE_EFA) target_link_libraries(transport PRIVATE fabric) endif() -if(USE_CUDA OR USE_MUSA) +if(USE_CXI) + add_subdirectory(cxi_transport) + target_sources(transport PUBLIC $) + target_link_libraries(transport PRIVATE fabric) +endif() + +if(USE_CUDA OR USE_MUSA OR USE_MACA) add_subdirectory(device) target_sources(transport PUBLIC $) # device_transport (ibgda_device_transport.cpp / mlx5gda.cpp) calls libmlx5 # DevX symbols (mlx5dv_devx_*, mlx5dv_init_obj) directly. - target_link_libraries(transport PUBLIC mlx5) + if((USE_CUDA OR USE_MUSA) AND NOT USE_CXI) + target_link_libraries(transport PUBLIC mlx5) + endif() endif() diff --git a/mooncake-transfer-engine/src/transport/device/CMakeLists.txt b/mooncake-transfer-engine/src/transport/device/CMakeLists.txt index 20ec48c850..60fa9872d7 100644 --- a/mooncake-transfer-engine/src/transport/device/CMakeLists.txt +++ b/mooncake-transfer-engine/src/transport/device/CMakeLists.txt @@ -6,8 +6,10 @@ # written ldflags (Go p2p-store / mooncake-store), which bypass CMake's # target_link_libraries propagation. set(DEVICE_TRANSPORT_SOURCES p2p_device_transport.cpp) -if(USE_CUDA OR USE_MUSA) +if((USE_CUDA OR USE_MUSA) AND NOT USE_CXI) list(APPEND DEVICE_TRANSPORT_SOURCES ibgda_device_transport.cpp mlx5gda.cpp) +elseif(USE_MACA) + list(APPEND DEVICE_TRANSPORT_SOURCES ibgda_device_transport_maca_stub.cpp) endif() add_library(device_transport OBJECT ${DEVICE_TRANSPORT_SOURCES}) @@ -20,3 +22,6 @@ if(USE_MUSA) target_include_directories(device_transport PRIVATE /usr/local/musa/include) target_compile_definitions(device_transport PRIVATE USE_MUSA) endif() +if(USE_MACA) + target_compile_definitions(device_transport PRIVATE USE_MACA) +endif() diff --git a/mooncake-transfer-engine/src/transport/device/ibgda_device_transport_maca_stub.cpp b/mooncake-transfer-engine/src/transport/device/ibgda_device_transport_maca_stub.cpp new file mode 100644 index 0000000000..d59f8409e6 --- /dev/null +++ b/mooncake-transfer-engine/src/transport/device/ibgda_device_transport_maca_stub.cpp @@ -0,0 +1,34 @@ +#include "transport/device/device_transport.h" + +namespace mooncake { +namespace device { + +class NullRdmaTransport : public RdmaTransport { + public: + int initialize(const std::string&, int, int) override { return -1; } + int registerMemory(void*, size_t) override { return -1; } + int allocateControlBuffer() override { return -1; } + int createQueuePairs(void*) override { return -1; } + int recreateQueuePairs(void*) override { return -1; } + int connectPeers(int, bool, const std::vector&, + const std::vector&, const std::vector&, + const std::vector&, const std::vector&, + const std::vector&, + const std::vector&) override { + return -1; + } + RdmaLocalMetadata localMetadata() const override { return {}; } + void* raddrsPtr() override { return nullptr; } + void* rkeysPtr() override { return nullptr; } + void* qpDevCtxsPtr() override { return nullptr; } + bool isRoce() const override { return false; } + int gidIndex() const override { return -1; } +}; + +std::unique_ptr createIbgdaDeviceTransport( + const std::vector&) { + return std::make_unique(); +} + +} // namespace device +} // namespace mooncake diff --git a/mooncake-transfer-engine/src/transport/device/p2p_device_transport.cpp b/mooncake-transfer-engine/src/transport/device/p2p_device_transport.cpp index 23e0225531..1f259ec068 100644 --- a/mooncake-transfer-engine/src/transport/device/p2p_device_transport.cpp +++ b/mooncake-transfer-engine/src/transport/device/p2p_device_transport.cpp @@ -21,13 +21,151 @@ #include #include +#include +#include #include +#include #include "cuda_alike.h" namespace mooncake { namespace device { +#ifdef USE_MACA +namespace { + +bool parseBoolEnv(const char* name) { + const char* value = std::getenv(name); + if (value == nullptr) return false; + std::string s(value); + std::transform(s.begin(), s.end(), s.begin(), [](unsigned char c) { + return static_cast(std::tolower(c)); + }); + return s == "1" || s == "on" || s == "true" || s == "yes"; +} + +std::string getLowerEnv(const char* name) { + const char* value = std::getenv(name); + if (value == nullptr) return ""; + std::string s(value); + std::transform(s.begin(), s.end(), s.begin(), [](unsigned char c) { + return static_cast(std::tolower(c)); + }); + return s; +} + +int macaAllocFlagFromMode(const std::string& mode, const char* env_name) { + if (mode.empty() || mode == "default" || mode == "cuda") + return mcDeviceMallocDefault; + if (mode == "fine" || mode == "finegrained" || mode == "fine-grained") + return mcDeviceMallocFinegrained; + if (mode == "signal") return mcMallocSignalMemory; + if (mode == "wc" || mode == "writecoherence" || mode == "write-coherence") + return mcDeviceMallocWriteCoherence; + if (mode == "pcie" || mode == "pcie-uncache" || mode == "map-pcie") + return mcDeviceMallocMapPcieDefault; + if (mode == "pcie-wc" || mode == "map-pcie-wc") + return mcDeviceMallocMapPcieCoherence; + if (mode == "fixed" || mode == "fixed-uncache") + return mcDeviceMallocFixedMemDefault; + if (mode == "fixed-wc") return mcDeviceMallocFixedMemCoherence; + LOG(WARNING) << "[EP P2P] unknown " << env_name << "=" << mode + << ", using default cudaMalloc"; + return mcDeviceMallocDefault; +} + +int macaAllocFlagFromEnv() { + return macaAllocFlagFromMode(getLowerEnv("MOONCAKE_EP_MACA_ALLOC"), + "MOONCAKE_EP_MACA_ALLOC"); +} + +std::string macaIpcMode() { + std::string mode = getLowerEnv("MOONCAKE_EP_MACA_IPC"); + return mode.empty() ? "normal" : mode; +} + +bool parseNonNegativeInt(const std::string& token, int* value) { + if (token.empty()) return false; + int result = 0; + for (char c : token) { + if (!std::isdigit(static_cast(c))) return false; + result = result * 10 + (c - '0'); + } + *value = result; + return true; +} + +int physicalDeviceFromVisibleList(int logical_device) { + const char* visible = std::getenv("CUDA_VISIBLE_DEVICES"); + if (visible == nullptr || visible[0] == '\0') return logical_device; + + std::string list(visible); + size_t begin = 0; + int logical = 0; + while (begin <= list.size()) { + size_t end = list.find(',', begin); + if (end == std::string::npos) end = list.size(); + std::string token = list.substr(begin, end - begin); + token.erase(std::remove_if( + token.begin(), token.end(), + [](unsigned char c) { return std::isspace(c) != 0; }), + token.end()); + if (logical == logical_device) { + int physical = logical_device; + return parseNonNegativeInt(token, &physical) ? physical + : logical_device; + } + if (end == list.size()) break; + begin = end + 1; + ++logical; + } + return logical_device; +} + +bool pairListed(const std::string& pairs, int src, int dst) { + size_t begin = 0; + while (begin <= pairs.size()) { + size_t end = pairs.find(',', begin); + if (end == std::string::npos) end = pairs.size(); + std::string item = pairs.substr(begin, end - begin); + item.erase(std::remove_if( + item.begin(), item.end(), + [](unsigned char c) { return std::isspace(c) != 0; }), + item.end()); + + size_t dash = item.find('-'); + if (dash != std::string::npos) { + int a = -1, b = -1; + if (parseNonNegativeInt(item.substr(0, dash), &a) && + parseNonNegativeInt(item.substr(dash + 1), &b)) { + if ((a == src && b == dst) || (a == dst && b == src)) + return true; + } + } + + if (end == pairs.size()) break; + begin = end + 1; + } + return false; +} + +bool macaP2pPairAllowed(int src_physical, int dst_physical) { + if (parseBoolEnv("MOONCAKE_EP_MACA_ALLOW_NODE_P2P")) return true; + + const char* explicit_pairs = std::getenv("MOONCAKE_EP_MACA_P2P_PAIRS"); + if (explicit_pairs != nullptr && explicit_pairs[0] != '\0') + return pairListed(explicit_pairs, src_physical, dst_physical); + + // C500 exposes two direct MetaXLink islands by default: 0<->1 and 2<->3. + // NODE pairs may report canAccessPeer=1, but EP kernel peer stores can hang + // waiting for device-side signals on those paths. + return src_physical / 2 == dst_physical / 2 && + std::abs(src_physical - dst_physical) == 1; +} + +} // namespace +#endif + class P2pDeviceTransportImpl : public P2pTransport { public: explicit P2pDeviceTransportImpl(int num_ranks) : num_ranks_(num_ranks) { @@ -54,22 +192,79 @@ class P2pDeviceTransportImpl : public P2pTransport { void* allocateBuffer(size_t bytes) override { void* ptr = nullptr; +#ifdef USE_MACA + int alloc_flag = macaAllocFlagFromEnv(); + cudaError_t err = alloc_flag == mcDeviceMallocDefault + ? cudaMalloc(&ptr, bytes) + : mcExtMallocWithFlags(&ptr, bytes, alloc_flag); +#else cudaError_t err = cudaMalloc(&ptr, bytes); +#endif if (err != cudaSuccess) { - LOG(ERROR) << "[EP P2P] cudaMalloc(" << bytes + LOG(ERROR) << "[EP P2P] device allocation(" << bytes << ") failed: " << cudaGetErrorString(err); return nullptr; } +#ifdef USE_MACA + if (alloc_flag != mcDeviceMallocDefault) { + LOG(INFO) << "[EP P2P] allocated MACA buffer with " + "mcExtMallocWithFlags flag=" + << alloc_flag; + } +#endif return ptr; } void freeBuffer(void* ptr) override { cudaFree(ptr); } std::vector exportIpcHandle(void* ptr) override { +#ifdef USE_MACA + if (parseBoolEnv("MOONCAKE_EP_MACA_DISABLE_IPC")) { + LOG(INFO) << "[EP P2P] MACA IPC handle export disabled by " + "MOONCAKE_EP_MACA_DISABLE_IPC"; + return {}; + } + + cudaPointerAttributes attr{}; + cudaError_t attr_err = cudaPointerGetAttributes(&attr, ptr); + if (attr_err != cudaSuccess || attr.type != cudaMemoryTypeDevice || + attr.devicePointer == nullptr) { + LOG(WARNING) << "[EP P2P] skip MACA IPC handle export for " + << "non-device pointer=" << ptr + << ", attr_err=" << cudaGetErrorString(attr_err) + << ", type=" << attr.type + << ", devicePointer=" << attr.devicePointer + << ", allocationFlags=" << attr.allocationFlags; + return {}; + } + + std::string ipc_mode = macaIpcMode(); + if (ipc_mode == "cross-v2" || ipc_mode == "cross_v2") { + mcIpcCrossMemHandle_t handle; + cudaError_t err = mcIpcGetMemHandleCross_v2(&handle, ptr); + if (err != cudaSuccess) { + LOG(ERROR) << "[EP P2P] mcIpcGetMemHandleCross_v2 failed: " + << cudaGetErrorString(err); + return {}; + } + constexpr size_t kHandleBytes = sizeof(mcIpcCrossMemHandle_t); + constexpr size_t kNumInt32s = + (kHandleBytes + sizeof(int32_t) - 1) / sizeof(int32_t); + std::vector result(kNumInt32s); + memcpy(result.data(), &handle, kHandleBytes); + return result; + } +#endif cudaIpcMemHandle_t handle; +#ifdef USE_MACA + cudaError_t err = macaIpcMode() == "cross" + ? mcIpcGetMemHandleCross(&handle, ptr) + : cudaIpcGetMemHandle(&handle, ptr); +#else cudaError_t err = cudaIpcGetMemHandle(&handle, ptr); +#endif if (err != cudaSuccess) { - LOG(ERROR) << "[EP P2P] cudaIpcGetMemHandle failed: " + LOG(ERROR) << "[EP P2P] IPC handle export failed: " << cudaGetErrorString(err); return {}; } @@ -112,6 +307,20 @@ class P2pDeviceTransportImpl : public P2pTransport { << "): canAccessPeer=" << can_access; if (!can_access) continue; +#ifdef USE_MACA + int src_physical = physicalDeviceFromVisibleList(device_id); + int dst_physical = physicalDeviceFromVisibleList(dst_device); + if (!macaP2pPairAllowed(src_physical, dst_physical)) { + LOG(INFO) << "[EP P2P] rank " << rank << " physical GPU" + << src_physical << " -> rank " << dst + << " physical GPU" << dst_physical + << " disabled for MACA EP fast path; set " + "MOONCAKE_EP_MACA_ALLOW_NODE_P2P=1 or " + "MOONCAKE_EP_MACA_P2P_PAIRS to override"; + continue; + } +#endif + cudaError_t err = cudaDeviceEnablePeerAccess(dst_device, 0); if (err != cudaSuccess && err != cudaErrorPeerAccessAlreadyEnabled) { @@ -129,14 +338,50 @@ class P2pDeviceTransportImpl : public P2pTransport { constexpr size_t kHandleBytes = sizeof(cudaIpcMemHandle_t); constexpr size_t kNumInt32s = (kHandleBytes + sizeof(int32_t) - 1) / sizeof(int32_t); +#ifdef USE_MACA + std::string ipc_mode = macaIpcMode(); + if (ipc_mode == "cross-v2" || ipc_mode == "cross_v2") { + constexpr size_t kCrossHandleBytes = + sizeof(mcIpcCrossMemHandle_t); + constexpr size_t kCrossNumInt32s = + (kCrossHandleBytes + sizeof(int32_t) - 1) / sizeof(int32_t); + if (h.size() < kCrossNumInt32s) continue; + mcIpcCrossMemHandle_t handle; + memcpy(&handle, h.data(), kCrossHandleBytes); + void* peer_ptr = nullptr; + err = mcIpcOpenMemHandleCross_v2( + &peer_ptr, &handle, cudaIpcMemLazyEnablePeerAccess); + if (err != cudaSuccess) { + LOG(WARNING) + << "[EP P2P] rank " << rank + << " failed to open cross_v2 IPC handle for rank " + << dst << ": " << cudaGetErrorString(err); + continue; + } + LOG(INFO) << "[EP P2P] rank " << rank + << " opened cross_v2 IPC handle for rank " << dst + << ": peer_ptr=" << peer_ptr; + available[dst] = 1; + peer_ptrs_host_[dst] = peer_ptr; + continue; + } +#endif if (h.size() < kNumInt32s) continue; cudaIpcMemHandle_t handle; memcpy(&handle, h.data(), kHandleBytes); void* peer_ptr = nullptr; +#ifdef USE_MACA + err = ipc_mode == "cross" + ? mcIpcOpenMemHandleCross(&peer_ptr, handle, + cudaIpcMemLazyEnablePeerAccess) + : cudaIpcOpenMemHandle(&peer_ptr, handle, + cudaIpcMemLazyEnablePeerAccess); +#else err = cudaIpcOpenMemHandle(&peer_ptr, handle, cudaIpcMemLazyEnablePeerAccess); +#endif if (err != cudaSuccess) { LOG(WARNING) << "[EP P2P] rank " << rank << " failed to open IPC handle for rank " << dst diff --git a/mooncake-wheel/mooncake/mooncake_ep_buffer.py b/mooncake-wheel/mooncake/mooncake_ep_buffer.py index 00ed2d1d95..daf7bbe622 100644 --- a/mooncake-wheel/mooncake/mooncake_ep_buffer.py +++ b/mooncake-wheel/mooncake/mooncake_ep_buffer.py @@ -20,8 +20,12 @@ def _env_enabled(name: str, default: bool = False) -> bool: or _USE_MACA ) _MACA_PHASE_FENCE = os.getenv("MOONCAKE_EP_MACA_PHASE_FENCE", "p2p").lower() -_USE_TORCH_ALLTOALL = _env_enabled("MOONCAKE_EP_USE_TORCH_ALLTOALL") -_PROFILE_TORCH_ALLTOALL = _env_enabled("MOONCAKE_EP_PROFILE_TORCH_ALLTOALL") +_DEBUG_INIT = _env_enabled("MOONCAKE_EP_DEBUG_INIT") + + +def _debug_init(rank: int, message: str) -> None: + if _DEBUG_INIT: + print(f"[rank {rank}] {message}", flush=True) class EventOverlap: @@ -84,53 +88,63 @@ def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: class Buffer: def __init__(self, group: dist.ProcessGroup, num_ep_buffer_bytes: int = 0): + from mooncake import ep + # Initialize the CPP runtime self.rank = group.rank() self.group_size = group.size() self.group = group self.num_ep_buffer_bytes = num_ep_buffer_bytes self.backend = self.group - self._use_torch_alltoall = _USE_TORCH_ALLTOALL - if self._use_torch_alltoall: - self.runtime = None - else: - from mooncake import ep - - # NIC auto-detection happens inside ep.Buffer via Topology::discover(). - self.runtime = ep.Buffer( - self.rank, self.group_size, num_ep_buffer_bytes - ) + # NIC auto-detection happens inside ep.Buffer via Topology::discover(). + _debug_init(self.rank, "before ep.Buffer") + self.runtime = ep.Buffer( + self.rank, self.group_size, num_ep_buffer_bytes + ) + _debug_init(self.rank, "after ep.Buffer") # Fallback flag and buffers. # Note: `sync_nvlink_ipc_handles()` can mutate C++ `ibgda_disabled_` (True->False when # P2P+IPC succeeds for all ranks). We re-evaluate after IPC sync below. - self._use_fallback = self._use_torch_alltoall or bool( - self.runtime is not None and self.runtime.ibgda_disabled() - ) + self._use_fallback = bool(self.runtime.ibgda_disabled()) self._fallback_next_combine_buffer: Optional[torch.Tensor] = None - self._torch_alltoall_profile = {} - self._torch_alltoall_state = {} self._maca_phase_token: Optional[torch.Tensor] = None self._maca_phase_recv_tokens: Optional[List[torch.Tensor]] = None self.connect() - def _maca_phase_fence(self) -> None: + def _maca_phase_fence(self, send_event: Optional[Any] = None) -> None: if not _USE_MACA or _MACA_PHASE_FENCE in {"", "0", "off", "none"}: return + backend = dist.get_backend(self.group) + fence_device = torch.device("cpu" if backend == "gloo" else "cuda") + + def wait_send_done() -> None: + if send_event is not None: + send_event.synchronize() + else: + torch.cuda.synchronize() + # Compatibility fence between SEND and RECV. The EP payload still # uses the P2P fast path; this only keeps rank phases aligned on MACA. if _MACA_PHASE_FENCE == "barrier": - torch.cuda.synchronize() + wait_send_done() dist.barrier(self.group) return if _MACA_PHASE_FENCE == "p2p": - if self._maca_phase_token is None: + wait_send_done() + if ( + self._maca_phase_token is None + or self._maca_phase_token.device != fence_device + ): self._maca_phase_token = torch.empty( - 1, dtype=torch.int32, device="cuda" + 1, dtype=torch.int32, device=fence_device ) - if self._maca_phase_recv_tokens is None: + if ( + self._maca_phase_recv_tokens is None + or self._maca_phase_recv_tokens[0].device != fence_device + ): self._maca_phase_recv_tokens = [ - torch.empty(1, dtype=torch.int32, device="cuda") + torch.empty(1, dtype=torch.int32, device=fence_device) for _ in range(self.group_size) ] self._maca_phase_token.fill_(1) @@ -161,18 +175,24 @@ def _maca_phase_fence(self) -> None: "MOONCAKE_EP_MACA_PHASE_FENCE must be one of: " "p2p, allreduce, barrier, none" ) - if self._maca_phase_token is None: + wait_send_done() + if ( + self._maca_phase_token is None + or self._maca_phase_token.device != fence_device + ): self._maca_phase_token = torch.empty( - 1, dtype=torch.int32, device="cuda" + 1, dtype=torch.int32, device=fence_device ) self._maca_phase_token.fill_(1) dist.all_reduce( self._maca_phase_token, op=dist.ReduceOp.SUM, group=self.group ) - def _wrap_maca_recv_hook(self, hook: Optional[Callable]) -> Callable: + def _wrap_maca_recv_hook( + self, hook: Optional[Callable], send_event: Optional[Any] + ) -> Callable: def wrapped_hook() -> None: - self._maca_phase_fence() + self._maca_phase_fence(send_event) if hook is not None: hook() @@ -181,12 +201,11 @@ def wrapped_hook() -> None: def connect(self, is_update: bool = False): from mooncake import ep - if self._use_torch_alltoall: - self._use_fallback = True - return - + _debug_init(self.rank, f"connect start fallback={self._use_fallback}") if not self._use_fallback: + _debug_init(self.rank, "before get_mr_info") (raddr, rkey) = self.runtime.get_mr_info() + _debug_init(self.rank, "after get_mr_info") raddr = torch.tensor([raddr], dtype=torch.int64, device="cuda") raddrs = [ @@ -256,57 +275,79 @@ def connect(self, is_update: bool = False): dist.all_gather(interface_ids_list, interface_id_t, self.group) interface_ids = torch.cat(interface_ids_list).tolist() - from mooncake.ep import get_active_ranks - active_ranks_mask = get_active_ranks(self.backend).tolist() + active_ranks_mask = self._active_ranks_list(torch.device("cuda")) self.runtime.sync_ibgda_peers( raddrs, rkeys, peer_qpns, peer_lids, subnet_prefixes, interface_ids, active_ranks_mask ) - try: - local_handle_ints = self.runtime.get_ipc_handle() - # pybind11 converts std::vector to a list of integers - local_handle_tensor = torch.tensor( - local_handle_ints, dtype=torch.int32, device="cuda" - ) - handles = [ - torch.empty(len(local_handle_ints), dtype=torch.int32, device="cuda") - for _ in range(self.group_size) - ] - dist.all_gather(handles, local_handle_tensor, self.group) - remote_handles = [h.tolist() for h in handles] - from mooncake.ep import get_active_ranks - active_ranks_mask = get_active_ranks(self.backend).tolist() - self.runtime.sync_nvlink_ipc_handles(remote_handles, - active_ranks_mask) - except Exception as e: - import warnings - - warnings.warn( - f"[Rank {self.rank}] Failed to exchange IPC handles: {e}. Falling back.", - RuntimeWarning, - stacklevel=2, - ) + if self.group_size == 1: + _debug_init(self.rank, "single-rank skip ipc handle export") + self._use_fallback = False + _debug_init(self.rank, "connect done fallback=False") + return + else: + try: + _debug_init(self.rank, "before get_ipc_handle") + local_handle_ints = self.runtime.get_ipc_handle() + _debug_init( + self.rank, f"after get_ipc_handle len={len(local_handle_ints)}" + ) + # pybind11 converts std::vector to a list of integers + local_handle_tensor = torch.tensor( + local_handle_ints, dtype=torch.int32, device="cuda" + ) + handles = [ + torch.empty(len(local_handle_ints), dtype=torch.int32, device="cuda") + for _ in range(self.group_size) + ] + _debug_init(self.rank, "before all_gather ipc handles") + dist.all_gather(handles, local_handle_tensor, self.group) + _debug_init(self.rank, "after all_gather ipc handles") + remote_handles = [h.tolist() for h in handles] + _debug_init(self.rank, "before get_active_ranks") + active_ranks_mask = self._active_ranks_list(torch.device("cuda")) + _debug_init( + self.rank, + f"before sync_nvlink_ipc_handles active={active_ranks_mask}", + ) + self.runtime.sync_nvlink_ipc_handles(remote_handles, active_ranks_mask) + _debug_init(self.rank, "after sync_nvlink_ipc_handles") + except Exception as e: + import warnings + + warnings.warn( + f"[Rank {self.rank}] Failed to exchange IPC handles: {e}. Falling back.", + RuntimeWarning, + stacklevel=2, + ) use_fast_path = False try: + _debug_init(self.rank, "before use_fast_path") use_fast_path = bool(self.runtime.use_fast_path()) + _debug_init(self.rank, f"after use_fast_path fast={use_fast_path}") except Exception: ibgda_disabled = bool(self.runtime.ibgda_disabled()) use_fast_path = not ibgda_disabled self._use_fallback = not use_fast_path + _debug_init(self.rank, f"connect done fallback={self._use_fallback}") def update_ep_member(self): - if self._use_torch_alltoall: - return self.connect(True) + def _is_mooncake_backend(self) -> bool: + try: + return dist.get_backend(self.group) == "mooncake" + except Exception: + return False + def _active_ranks_tensor( self, device: torch.device, dtype: torch.dtype = torch.int32 ) -> torch.Tensor: - if self._use_torch_alltoall: + if not self._is_mooncake_backend(): return torch.ones((self.group_size,), dtype=dtype, device=device) try: @@ -326,15 +367,6 @@ def get_ep_buffer_size_hint( num_ranks: int, num_experts: int, ) -> int: - if _USE_TORCH_ALLTOALL: - return ( - 4 - * num_experts - * num_max_dispatch_tokens_per_rank - * (32 + hidden * 2) - + 4 * num_experts * 4 - ) - from mooncake.ep import get_ep_buffer_size_hint return get_ep_buffer_size_hint( @@ -426,7 +458,7 @@ def dispatch( runtime_return_recv_hook, ) if _USE_MACA: - hook = self._wrap_maca_recv_hook(hook) + hook = self._wrap_maca_recv_hook(hook, event) if not return_recv_hook: hook() hook = None @@ -525,7 +557,7 @@ def combine( out, ) if _USE_MACA: - hook = self._wrap_maca_recv_hook(hook) + hook = self._wrap_maca_recv_hook(hook, event) if not return_recv_hook: hook() hook = None @@ -582,23 +614,6 @@ class _DummyEvent: def current_stream_wait(self): torch.cuda.synchronize() - class _CudaTimer: - def __init__(self, enabled: bool) -> None: - self.enabled = enabled - self.samples = {} - - def measure(self, name: str, fn: Callable[[], Any]) -> Any: - if not self.enabled: - return fn() - start = torch.cuda.Event(enable_timing=True) - end = torch.cuda.Event(enable_timing=True) - start.record() - result = fn() - end.record() - torch.cuda.synchronize() - self.samples[name] = start.elapsed_time(end) / 1000.0 - return result - @staticmethod def _fp8_cast(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: assert x.dim() == 2 and x.size(1) % 128 == 0 @@ -611,185 +626,6 @@ def _fp8_cast(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: x_scales = (x_amax / 448.0).view(m, -1) return x_fp8, x_scales - def _torch_alltoall_routed_dispatch( - self, - x: torch.Tensor, - topk_idx: torch.Tensor, - num_max_dispatch_tokens_per_rank: int, - num_experts: int, - return_recv_hook: bool, - ): - with torch.profiler.record_function("dispatch.torch_alltoall_routed"): - timer = Buffer._CudaTimer(_PROFILE_TORCH_ALLTOALL) - num_tokens = x.size(0) - num_topk = topk_idx.size(1) - num_ranks = self.group_size - num_local_experts = num_experts // num_ranks - max_messages_per_rank = num_tokens * num_topk - - if x.dtype != torch.bfloat16: - raise NotImplementedError( - "torch all-to-all fallback currently supports bfloat16 only" - ) - - from mooncake import ep - - send_payload, send_route = timer.measure( - "dispatch_pack", - lambda: ep.torch_alltoall_pack_dispatch_fused( - x.contiguous(), - topk_idx.contiguous(), - num_experts, - num_ranks, - ), - ) - recv_payload = torch.empty_like(send_payload) - timer.measure( - "dispatch_a2a_payload", - lambda: dist.all_to_all_single( - recv_payload, send_payload, group=self.group - ), - ) - if _PROFILE_TORCH_ALLTOALL: - timer.samples["dispatch_a2a_meta"] = 0.0 - ( - packed_recv_x, - packed_recv_src_info, - packed_recv_layout_range, - packed_recv_count, - return_src_pos, - ) = timer.measure( - "dispatch_compact", - lambda: ep.torch_alltoall_compact_dispatch_fused( - recv_payload, - num_local_experts, - num_max_dispatch_tokens_per_rank, - ), - ) - self._torch_alltoall_state = { - "send_route": send_route, - "return_src_pos": return_src_pos, - "num_tokens": num_tokens, - "num_topk": num_topk, - "max_messages_per_rank": max_messages_per_rank, - } - self._fallback_next_combine_buffer = torch.empty_like(packed_recv_x) - self._torch_alltoall_profile = timer.samples - hook = (lambda: None) if return_recv_hook else (lambda: None) - event = Buffer._DummyEvent() - return ( - packed_recv_x, - None, - packed_recv_count, - packed_recv_src_info, - packed_recv_layout_range, - event, - hook, - ) - - def _torch_alltoall_routed_combine( - self, - x: torch.Tensor, - topk_idx: torch.Tensor, - topk_weights: torch.Tensor, - layout_range: torch.Tensor, - zero_copy: bool, - return_recv_hook: bool, - out: Optional[torch.Tensor], - ): - with torch.profiler.record_function("combine.torch_alltoall_routed"): - timer = Buffer._CudaTimer(_PROFILE_TORCH_ALLTOALL) - expert_buffers = self._fallback_next_combine_buffer if zero_copy else x - if expert_buffers is None: - raise RuntimeError( - "zero_copy combine called before dispatch buffer allocation" - ) - if expert_buffers.dtype != torch.bfloat16: - expert_buffers = expert_buffers.to(torch.bfloat16) - - num_tokens, num_topk = topk_idx.shape - num_ranks = self.group_size - max_messages_per_rank = num_tokens * num_topk - state = self._torch_alltoall_state - if not state: - raise RuntimeError("combine called without torch all-to-all dispatch state") - - from mooncake import ep - - send_payload = timer.measure( - "combine_pack", - lambda: ep.torch_alltoall_pack_combine( - expert_buffers.contiguous(), - state["return_src_pos"], - num_ranks, - max_messages_per_rank, - ), - ) - recv_payload = torch.empty_like(send_payload) - timer.measure( - "combine_a2a_payload", - lambda: dist.all_to_all_single( - recv_payload, send_payload, group=self.group - ), - ) - combined = timer.measure( - "combine_reduce", - lambda: ep.torch_alltoall_reduce_combine( - recv_payload, - state["send_route"], - topk_weights.contiguous(), - out, - ), - ) - self._torch_alltoall_profile = timer.samples - hook = (lambda: None) if return_recv_hook else (lambda: None) - event = Buffer._DummyEvent() - return combined, event, hook - - def _torch_alltoall_dispatch( - self, - x: torch.Tensor, - topk_idx: torch.Tensor, - num_max_dispatch_tokens_per_rank: int, - num_experts: int, - use_fp8: bool, - return_recv_hook: bool, - ): - if use_fp8: - raise NotImplementedError( - "FP8 dispatch is not supported by torch all-to-all fallback" - ) - return self._torch_alltoall_routed_dispatch( - x, - topk_idx, - num_max_dispatch_tokens_per_rank, - num_experts, - return_recv_hook, - ) - - def _torch_alltoall_combine( - self, - x: torch.Tensor, - topk_idx: torch.Tensor, - topk_weights: torch.Tensor, - src_info: torch.Tensor, - layout_range: torch.Tensor, - num_max_dispatch_tokens_per_rank: int, - num_experts: int, - zero_copy: bool, - return_recv_hook: bool, - out: Optional[torch.Tensor], - ): - return self._torch_alltoall_routed_combine( - x, - topk_idx, - topk_weights, - layout_range, - zero_copy, - return_recv_hook, - out, - ) - def _fallback_dispatch( self, x: torch.Tensor, @@ -799,16 +635,6 @@ def _fallback_dispatch( use_fp8: bool, return_recv_hook: bool, ): - if self._use_torch_alltoall: - return self._torch_alltoall_dispatch( - x, - topk_idx, - num_max_dispatch_tokens_per_rank, - num_experts, - use_fp8, - return_recv_hook, - ) - with torch.profiler.record_function("dispatch"): num_tokens, hidden = x.shape k = topk_idx.size(1) @@ -1031,20 +857,6 @@ def _fallback_combine( return_recv_hook: bool, out: Optional[torch.Tensor], ): - if self._use_torch_alltoall: - return self._torch_alltoall_combine( - x, - topk_idx, - topk_weights, - src_info, - layout_range, - num_max_dispatch_tokens_per_rank, - num_experts, - zero_copy, - return_recv_hook, - out, - ) - with torch.profiler.record_function("combine"): num_tokens = topk_idx.size(0) hidden = (x if not zero_copy else self._fallback_next_combine_buffer).size( diff --git a/mooncake-wheel/tests/test_mooncake_ep.py b/mooncake-wheel/tests/test_mooncake_ep.py index fa7636fa4e..b00488ff78 100644 --- a/mooncake-wheel/tests/test_mooncake_ep.py +++ b/mooncake-wheel/tests/test_mooncake_ep.py @@ -1,4 +1,5 @@ import random +import os import torch import torch.distributed as dist from functools import partial @@ -6,6 +7,11 @@ from mooncake.mooncake_ep_buffer import Buffer from ep_test_utils import init_dist, bench, bench_kineto, calc_diff, hash_tensor, per_token_cast_back +_USE_MACA = ( + os.getenv("MOONCAKE_EP_USE_MACA", "").upper() in {"1", "ON", "TRUE", "YES"} + or bool(getattr(torch.version, "maca", None)) +) + def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int, rank: int, num_ranks: int, group: dist.ProcessGroup, cpu_group: dist.ProcessGroup, buffer: Buffer, seed: int = 0): @@ -34,7 +40,7 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int, hash_value, num_times = 0, 0 active_ranks = torch.ones((num_tokens, ), dtype=torch.int32, device='cuda') for return_recv_hook in (False, True): - for dispatch_use_fp8 in (False, True): + for dispatch_use_fp8 in ([False] if _USE_MACA else [False, True]): num_times += 1 for i in range((num_times % 2) + 1): packed_recv_x, packed_recv_count, handle, event, hook = \ From 3b837b55439fce9a98c0f2f8b2d0ca81a7ebfa4e Mon Sep 17 00:00:00 2001 From: Dayuxiaoshui <792179245@qq.com> Date: Fri, 26 Jun 2026 10:00:29 +0000 Subject: [PATCH 3/3] Stabilize MACA EP P2P path --- mooncake-ep/src/CMakeLists.txt | 5 +- mooncake-ep/src/mooncake_ep_buffer.cpp | 6 +- .../src/transport/device/CMakeLists.txt | 3 + mooncake-wheel/mooncake/mooncake_ep_buffer.py | 115 +++++------------- 4 files changed, 37 insertions(+), 92 deletions(-) diff --git a/mooncake-ep/src/CMakeLists.txt b/mooncake-ep/src/CMakeLists.txt index a88f10d97b..574ab514c0 100644 --- a/mooncake-ep/src/CMakeLists.txt +++ b/mooncake-ep/src/CMakeLists.txt @@ -1,7 +1,4 @@ -add_library(mooncake_ep - ep_py.cpp - mooncake_ep_buffer.cpp - mooncake_ep_kernel.cu) +add_library(mooncake_ep ep_py.cpp mooncake_ep_buffer.cpp mooncake_ep_kernel.cu) set_target_properties(mooncake_ep PROPERTIES POSITION_INDEPENDENT_CODE ON) target_link_libraries(mooncake_ep PUBLIC ${TORCH_LIBRARIES} transfer_engine ibverbs mlx5) diff --git a/mooncake-ep/src/mooncake_ep_buffer.cpp b/mooncake-ep/src/mooncake_ep_buffer.cpp index d7d5d37a42..79aa4fbe7d 100644 --- a/mooncake-ep/src/mooncake_ep_buffer.cpp +++ b/mooncake-ep/src/mooncake_ep_buffer.cpp @@ -1,7 +1,6 @@ #include #include #include -#include #include #include @@ -27,10 +26,7 @@ static bool initRdmaTransport(device::RdmaTransport* t, void* gdr_buffer, static bool macaHostPhaseFenceCoversPeers() { #ifdef MOONCAKE_EP_USE_MACA - const char* env = std::getenv("MOONCAKE_EP_MACA_PHASE_FENCE"); - if (env == nullptr || env[0] == '\0') return true; - return std::strcmp(env, "0") != 0 && std::strcmp(env, "off") != 0 && - std::strcmp(env, "none") != 0; + return true; #else return false; #endif diff --git a/mooncake-transfer-engine/src/transport/device/CMakeLists.txt b/mooncake-transfer-engine/src/transport/device/CMakeLists.txt index 527d9eebe4..8e64a68a7a 100644 --- a/mooncake-transfer-engine/src/transport/device/CMakeLists.txt +++ b/mooncake-transfer-engine/src/transport/device/CMakeLists.txt @@ -24,3 +24,6 @@ endif() if(USE_MACA) target_compile_definitions(device_transport PRIVATE USE_MACA) endif() +if(USE_MACA) + target_compile_definitions(device_transport PRIVATE USE_MACA) +endif() diff --git a/mooncake-wheel/mooncake/mooncake_ep_buffer.py b/mooncake-wheel/mooncake/mooncake_ep_buffer.py index daf7bbe622..e2d198a033 100644 --- a/mooncake-wheel/mooncake/mooncake_ep_buffer.py +++ b/mooncake-wheel/mooncake/mooncake_ep_buffer.py @@ -19,13 +19,6 @@ def _env_enabled(name: str, default: bool = False) -> bool: _env_enabled("MOONCAKE_EP_USE_MUSA") or _USE_MACA ) -_MACA_PHASE_FENCE = os.getenv("MOONCAKE_EP_MACA_PHASE_FENCE", "p2p").lower() -_DEBUG_INIT = _env_enabled("MOONCAKE_EP_DEBUG_INIT") - - -def _debug_init(rank: int, message: str) -> None: - if _DEBUG_INIT: - print(f"[rank {rank}] {message}", flush=True) class EventOverlap: @@ -97,11 +90,9 @@ def __init__(self, group: dist.ProcessGroup, num_ep_buffer_bytes: int = 0): self.num_ep_buffer_bytes = num_ep_buffer_bytes self.backend = self.group # NIC auto-detection happens inside ep.Buffer via Topology::discover(). - _debug_init(self.rank, "before ep.Buffer") self.runtime = ep.Buffer( self.rank, self.group_size, num_ep_buffer_bytes ) - _debug_init(self.rank, "after ep.Buffer") # Fallback flag and buffers. # Note: `sync_nvlink_ipc_handles()` can mutate C++ `ibgda_disabled_` (True->False when # P2P+IPC succeeds for all ranks). We re-evaluate after IPC sync below. @@ -112,7 +103,7 @@ def __init__(self, group: dist.ProcessGroup, num_ep_buffer_bytes: int = 0): self.connect() def _maca_phase_fence(self, send_event: Optional[Any] = None) -> None: - if not _USE_MACA or _MACA_PHASE_FENCE in {"", "0", "off", "none"}: + if not _USE_MACA: return backend = dist.get_backend(self.group) @@ -126,55 +117,6 @@ def wait_send_done() -> None: # Compatibility fence between SEND and RECV. The EP payload still # uses the P2P fast path; this only keeps rank phases aligned on MACA. - if _MACA_PHASE_FENCE == "barrier": - wait_send_done() - dist.barrier(self.group) - return - if _MACA_PHASE_FENCE == "p2p": - wait_send_done() - if ( - self._maca_phase_token is None - or self._maca_phase_token.device != fence_device - ): - self._maca_phase_token = torch.empty( - 1, dtype=torch.int32, device=fence_device - ) - if ( - self._maca_phase_recv_tokens is None - or self._maca_phase_recv_tokens[0].device != fence_device - ): - self._maca_phase_recv_tokens = [ - torch.empty(1, dtype=torch.int32, device=fence_device) - for _ in range(self.group_size) - ] - self._maca_phase_token.fill_(1) - ops = [] - for peer in range(self.group_size): - if peer == self.rank: - continue - ops.append( - dist.P2POp( - dist.isend, self._maca_phase_token, peer, self.group - ) - ) - ops.append( - dist.P2POp( - dist.irecv, - self._maca_phase_recv_tokens[peer], - peer, - self.group, - ) - ) - if not ops: - return - for work in dist.batch_isend_irecv(ops): - work.wait() - return - if _MACA_PHASE_FENCE != "allreduce": - raise ValueError( - "MOONCAKE_EP_MACA_PHASE_FENCE must be one of: " - "p2p, allreduce, barrier, none" - ) wait_send_done() if ( self._maca_phase_token is None @@ -183,10 +125,36 @@ def wait_send_done() -> None: self._maca_phase_token = torch.empty( 1, dtype=torch.int32, device=fence_device ) + if ( + self._maca_phase_recv_tokens is None + or self._maca_phase_recv_tokens[0].device != fence_device + ): + self._maca_phase_recv_tokens = [ + torch.empty(1, dtype=torch.int32, device=fence_device) + for _ in range(self.group_size) + ] self._maca_phase_token.fill_(1) - dist.all_reduce( - self._maca_phase_token, op=dist.ReduceOp.SUM, group=self.group - ) + ops = [] + for peer in range(self.group_size): + if peer == self.rank: + continue + ops.append( + dist.P2POp( + dist.isend, self._maca_phase_token, peer, self.group + ) + ) + ops.append( + dist.P2POp( + dist.irecv, + self._maca_phase_recv_tokens[peer], + peer, + self.group, + ) + ) + if not ops: + return + for work in dist.batch_isend_irecv(ops): + work.wait() def _wrap_maca_recv_hook( self, hook: Optional[Callable], send_event: Optional[Any] @@ -201,11 +169,8 @@ def wrapped_hook() -> None: def connect(self, is_update: bool = False): from mooncake import ep - _debug_init(self.rank, f"connect start fallback={self._use_fallback}") if not self._use_fallback: - _debug_init(self.rank, "before get_mr_info") (raddr, rkey) = self.runtime.get_mr_info() - _debug_init(self.rank, "after get_mr_info") raddr = torch.tensor([raddr], dtype=torch.int64, device="cuda") raddrs = [ @@ -282,17 +247,13 @@ def connect(self, is_update: bool = False): ) if self.group_size == 1: - _debug_init(self.rank, "single-rank skip ipc handle export") + # No peer can import this IPC handle in single-rank EP. Skipping + # export also avoids unnecessary driver IPC calls on MACA. self._use_fallback = False - _debug_init(self.rank, "connect done fallback=False") return else: try: - _debug_init(self.rank, "before get_ipc_handle") local_handle_ints = self.runtime.get_ipc_handle() - _debug_init( - self.rank, f"after get_ipc_handle len={len(local_handle_ints)}" - ) # pybind11 converts std::vector to a list of integers local_handle_tensor = torch.tensor( local_handle_ints, dtype=torch.int32, device="cuda" @@ -301,18 +262,10 @@ def connect(self, is_update: bool = False): torch.empty(len(local_handle_ints), dtype=torch.int32, device="cuda") for _ in range(self.group_size) ] - _debug_init(self.rank, "before all_gather ipc handles") dist.all_gather(handles, local_handle_tensor, self.group) - _debug_init(self.rank, "after all_gather ipc handles") remote_handles = [h.tolist() for h in handles] - _debug_init(self.rank, "before get_active_ranks") active_ranks_mask = self._active_ranks_list(torch.device("cuda")) - _debug_init( - self.rank, - f"before sync_nvlink_ipc_handles active={active_ranks_mask}", - ) self.runtime.sync_nvlink_ipc_handles(remote_handles, active_ranks_mask) - _debug_init(self.rank, "after sync_nvlink_ipc_handles") except Exception as e: import warnings @@ -324,16 +277,12 @@ def connect(self, is_update: bool = False): use_fast_path = False try: - _debug_init(self.rank, "before use_fast_path") use_fast_path = bool(self.runtime.use_fast_path()) - _debug_init(self.rank, f"after use_fast_path fast={use_fast_path}") except Exception: ibgda_disabled = bool(self.runtime.ibgda_disabled()) use_fast_path = not ibgda_disabled self._use_fallback = not use_fast_path - _debug_init(self.rank, f"connect done fallback={self._use_fallback}") - def update_ep_member(self): self.connect(True)