Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,9 @@ if (WITH_EP)
"-DTORCH_CUDA_ARCH_LIST=${_torch_cuda_arch_list_pipe}"
"-DSTAGING_DIR=${EP_PG_STAGING_DIR}"
"-DENGINE_SO_PATH=$<TARGET_FILE:engine>"
"-DPython3_EXECUTABLE=${Python3_EXECUTABLE}"
"-DEP_USE_MUSA=$<IF:$<BOOL:${USE_MUSA}>,1,0>"
"-DEP_USE_MACA=$<IF:$<BOOL:${USE_MACA}>,1,0>"
-P "${CMAKE_CURRENT_SOURCE_DIR}/mooncake-ep/BuildEpExt.cmake"
COMMENT "Building Mooncake EP Python extension(s)"
DEPENDS engine
Expand All @@ -162,7 +164,9 @@ if (WITH_EP)
"-DTORCH_CUDA_ARCH_LIST=${_torch_cuda_arch_list_pipe}"
"-DSTAGING_DIR=${EP_PG_STAGING_DIR}"
"-DENGINE_SO_PATH=$<TARGET_FILE:engine>"
"-DPython3_EXECUTABLE=${Python3_EXECUTABLE}"
"-DEP_USE_MUSA=$<IF:$<BOOL:${USE_MUSA}>,1,0>"
"-DEP_USE_MACA=$<IF:$<BOOL:${USE_MACA}>,1,0>"
-P "${CMAKE_CURRENT_SOURCE_DIR}/mooncake-pg/BuildPgExt.cmake"
COMMENT "Building Mooncake PG Python extension(s)"
DEPENDS engine mooncake_ep_ext
Expand Down
11 changes: 11 additions & 0 deletions mooncake-ep/BuildEpExt.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# STAGING_DIR - destination directory for the built .so files
# ENGINE_SO_PATH - absolute path to the built engine.cpython-XYZ.so
# EP_USE_MUSA - set to "1" when building for MUSA (MTLink path)
# EP_USE_MACA - set to "1" when building for MACA (MetaX GPUs)

cmake_minimum_required(VERSION 3.16)

Expand Down Expand Up @@ -40,6 +41,16 @@ if(EP_USE_MUSA)
else()
unset(ENV{MOONCAKE_EP_USE_MUSA})
endif()
# Propagate the MACA switch to the Python build through the environment.
# When MACA is disabled the variable is cleared so a stale value from the
# calling shell cannot leak into the extension build.
if(NOT EP_USE_MACA)
  unset(ENV{MOONCAKE_EP_USE_MACA})
else()
  set(ENV{MOONCAKE_EP_USE_MACA} "1")
  # Keep MACA_PATH and MACA_HOME in sync: whichever one is defined is copied
  # into the other, and MACA_PATH takes precedence when both are defined.
  if(DEFINED ENV{MACA_PATH})
    set(ENV{MACA_HOME} "$ENV{MACA_PATH}")
  elseif(DEFINED ENV{MACA_HOME})
    set(ENV{MACA_PATH} "$ENV{MACA_HOME}")
  endif()
endif()

# ---------------------------------------------------------------------------
# 2. Ensure engine.so exists in mooncake-wheel/mooncake/ for setup.py linking.
Expand Down
10 changes: 8 additions & 2 deletions mooncake-ep/include/mooncake_ep_configs.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -40,16 +40,22 @@
#endif

#include <cuda_bf16.h>
#ifndef MOONCAKE_EP_USE_MACA
#include <cuda_fp8.h>
#include <cuda_runtime.h>
#include <infiniband/mlx5dv.h>
#endif
#include <cuda_runtime.h>

#if defined(MOONCAKE_EP_USE_MUSA) || defined(MOONCAKE_EP_USE_MACA)
#define MOONCAKE_EP_SPLIT_SEND_RECV 1
#endif

// torchada maps nv_bfloat16 → __mt_bfloat16 which is an incomplete type on
// MUSA, so sizeof(__mt_bfloat16) fails. mt_bfloat16 (the complete typedef in
// musa_bf16.hpp) requires the MUSA device compiler (mcc) and cannot be
// included from host .cpp files. Use EP_BF16_SIZE: sizeof(nv_bfloat16) on
// CUDA, hardcoded 2 on MUSA and MACA (bfloat16 is 2 bytes on all of them).
#ifdef MOONCAKE_EP_USE_MUSA
#if defined(MOONCAKE_EP_USE_MUSA) || defined(MOONCAKE_EP_USE_MACA)
#define EP_BF16_SIZE 2
#else
#define EP_BF16_SIZE sizeof(nv_bfloat16)
Expand Down
48 changes: 46 additions & 2 deletions mooncake-ep/include/mooncake_ep_device.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,49 @@ __forceinline__ __device__ int get_lane_id() { return threadIdx.x % 32; }
} \
}

#else // !MOONCAKE_EP_USE_MUSA
#elif defined(MOONCAKE_EP_USE_MACA)

// -- FP8 types ---------------------------------------------------------------
// MetaX C500 does not support the FP8 EP path. The dispatch<true> template is
// still compiled because the host selects the kernel through a runtime bool, so
// provide storage stubs and reject actual FP8 use in the host/Python wrappers.
#include <cstdint>
// Raw storage only: one byte per FP8 value, two bytes per packed pair.
using ep_fp8_storage_t = uint8_t;
using ep_fp8x2_storage_t = uint16_t;
// NOTE(review): the guard checks __CUDACC__ / __MCC__; confirm that the MACA
// device compiler defines one of them, otherwise the stub below is never
// declared.
#if defined(__CUDACC__) || defined(__MCC__)
// Placeholder conversion: ignores its argument and always returns 0, so the
// output is meaningless if the FP8 path is ever executed on this backend.
__device__ __forceinline__ ep_fp8x2_storage_t ep_cvt_float2_to_fp8x2(float2) {
return 0;
}
#endif

// -- Device intrinsics -------------------------------------------------------
// Fallback used only when no __activemask macro exists: reports all 32 bits
// set regardless of which lanes are actually executing.
#ifndef __activemask
#define __activemask() (0xffffffff)
#endif

#if defined(__CUDACC__) || defined(__MCC__)
// Lane index derived from threadIdx.x only; assumes 32 lanes per warp and a
// 1-D block — TODO confirm the warp width on the target hardware.
__forceinline__ __device__ int get_lane_id() { return threadIdx.x % 32; }
#endif

// -- Kernel launch (MACA: no cooperative launch) -----------------------------
// Expands to nothing: launch-bound hints are dropped on this backend.
#define EP_LAUNCH_BOUNDS(max_threads, min_blocks)

// Declares the locals _grid, _block and _stream in the caller's scope. The
// expansion has no trailing semicolon, so the call site must supply one.
#define SETUP_LAUNCH_CONFIG(num_sms, num_threads, stream) \
dim3 _grid(num_sms); \
dim3 _block(num_threads); \
cudaStream_t _stream = stream

// Launches with the locals created by SETUP_LAUNCH_CONFIG; the `config`
// argument is accepted but unused. A launch error is only printed to stderr —
// execution continues, so callers get no failure signal from this macro.
// The expansion is two statements (launch + check block), not a single one.
#define LAUNCH_KERNEL(config, kernel, ...) \
kernel<<<_grid, _block, 0, _stream>>>(__VA_ARGS__); \
{ \
auto _err = cudaGetLastError(); \
if (_err != cudaSuccess) { \
fprintf(stderr, "[EP] kernel launch failed: %s\n", \
cudaGetErrorString(_err)); \
} \
}

#else // !MOONCAKE_EP_USE_MUSA && !MOONCAKE_EP_USE_MACA

// -- FP8 types (CUDA native names) -------------------------------------------
#include <cuda_fp8.h>
Expand Down Expand Up @@ -86,7 +128,9 @@ __forceinline__ __device__ int get_lane_id() {
#define LAUNCH_KERNEL(config, kernel, ...) \
CUDA_CHECK(cudaLaunchKernelEx(config, kernel, ##__VA_ARGS__))

#endif // MOONCAKE_EP_USE_MUSA
#endif // MOONCAKE_EP_USE_MUSA / MOONCAKE_EP_USE_MACA

// IB verbs (mlx5dv) are needed on every platform except MACA
#ifndef MOONCAKE_EP_USE_MACA
#include <infiniband/mlx5dv.h>
#endif
2 changes: 2 additions & 0 deletions mooncake-ep/include/mooncake_ep_event.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ struct EventHandle {
// Makes the current CUDA stream wait on this handle's event.
// Dereferences `event` without a null check.
void current_stream_wait() const {
at::cuda::getCurrentCUDAStream().unwrap().wait(*event);
}

// Forwards to event->synchronize(); presumably blocks the calling host thread
// until the event has completed — confirm against the event type's docs.
// Dereferences `event` without a null check.
void synchronize() const { event->synchronize(); }
};

inline torch::Event create_event(const at::cuda::CUDAStream& s) {
Expand Down
2 changes: 1 addition & 1 deletion mooncake-ep/include/mooncake_ep_exception.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ class EPException : public std::exception {
#endif

#ifndef EP_DEVICE_ASSERT
#ifdef MOONCAKE_EP_USE_MUSA
#if defined(MOONCAKE_EP_USE_MUSA) || defined(MOONCAKE_EP_USE_MACA)
// MUSA SDK 4.3.x can turn kernels that merely contain a device-side __trap()
// branch into illegal memory accesses, even when the assertion condition is
// true. Keep these invariants as host/static checks on MUSA builds; MACA
// builds take the same path.
Expand Down
44 changes: 22 additions & 22 deletions mooncake-ep/include/mooncake_ep_utils.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ struct VecInt<16> {
};

// ---- TMA / mbarrier helpers (CUDA only) ----
#ifndef MOONCAKE_EP_USE_MUSA
#if !defined(MOONCAKE_EP_USE_MUSA) && !defined(MOONCAKE_EP_USE_MACA)

__device__ __forceinline__ void fence_view_async_shared() {
asm volatile("fence.proxy.async.shared::cta; \n" ::);
Expand All @@ -58,16 +58,16 @@ __device__ __forceinline__ void fence_barrier_init() {
asm volatile("fence.mbarrier_init.release.cluster; \n" ::);
}

__device__ __forceinline__ void mbarrier_init(uint64_t *mbar_ptr,
__device__ __forceinline__ void mbarrier_init(uint64_t* mbar_ptr,
uint32_t arrive_count) {
auto mbar_int_ptr =
static_cast<uint32_t>(__cvta_generic_to_shared(mbar_ptr));
asm volatile("mbarrier.init.shared::cta.b64 [%1], %0;" ::"r"(arrive_count),
"r"(mbar_int_ptr));
}

__device__ __forceinline__ void mbarrier_wait(uint64_t *mbar_ptr,
uint32_t &phase) {
__device__ __forceinline__ void mbarrier_wait(uint64_t* mbar_ptr,
uint32_t& phase) {
auto mbar_int_ptr =
static_cast<uint32_t>(__cvta_generic_to_shared(mbar_ptr));
asm volatile(
Expand All @@ -84,7 +84,7 @@ __device__ __forceinline__ void mbarrier_wait(uint64_t *mbar_ptr,
}

__device__ __forceinline__ void mbarrier_arrive_and_expect_tx(
uint64_t *mbar_ptr, int num_bytes) {
uint64_t* mbar_ptr, int num_bytes) {
auto mbar_int_ptr =
static_cast<uint32_t>(__cvta_generic_to_shared(mbar_ptr));
asm volatile(
Expand All @@ -100,9 +100,9 @@ __device__ __forceinline__ void tma_store_fence() {
constexpr uint64_t kEvictFirst = 0x12f0000000000000;
constexpr uint64_t kEvictNormal = 0x1000000000000000;

__device__ __forceinline__ void tma_load_1d(const void *smem_ptr,
const void *gmem_ptr,
uint64_t *mbar_ptr, int num_bytes,
__device__ __forceinline__ void tma_load_1d(const void* smem_ptr,
const void* gmem_ptr,
uint64_t* mbar_ptr, int num_bytes,
bool evict_first = true) {
auto mbar_int_ptr =
static_cast<uint32_t>(__cvta_generic_to_shared(mbar_ptr));
Expand All @@ -116,8 +116,8 @@ __device__ __forceinline__ void tma_load_1d(const void *smem_ptr,
: "memory");
}

__device__ __forceinline__ void tma_store_1d(const void *smem_ptr,
const void *gmem_ptr,
__device__ __forceinline__ void tma_store_1d(const void* smem_ptr,
const void* gmem_ptr,
int num_bytes,
bool evict_first = true) {
auto smem_int_ptr =
Expand All @@ -136,7 +136,7 @@ __device__ __forceinline__ void tma_store_wait() {
asm volatile("cp.async.bulk.wait_group.read %0;" ::"n"(N) : "memory");
}

#endif // MOONCAKE_EP_USE_MUSA
#endif // !MOONCAKE_EP_USE_MUSA && !MOONCAKE_EP_USE_MACA

template <typename dtype_t>
__host__ __device__ dtype_t cell_div(dtype_t a, dtype_t b) {
Expand All @@ -150,43 +150,43 @@ __host__ __device__ dtype_t align(dtype_t a, dtype_t b) {

__forceinline__ __device__ void get_channel_task_range(int num_tokens,
int num_sms, int sm_id,
int &token_start_idx,
int &token_end_idx) {
int& token_start_idx,
int& token_end_idx) {
int num_tokens_per_sm = cell_div(num_tokens, num_sms);
token_start_idx = min(num_tokens_per_sm * sm_id, num_tokens);
token_end_idx = min(token_start_idx + num_tokens_per_sm, num_tokens);
}

template <typename dtype_a_t, typename dtype_b_t>
__device__ __forceinline__ dtype_b_t pack2(const dtype_a_t &x,
const dtype_a_t &y) {
__device__ __forceinline__ dtype_b_t pack2(const dtype_a_t& x,
const dtype_a_t& y) {
EP_STATIC_ASSERT(sizeof(dtype_a_t) * 2 == sizeof(dtype_b_t),
"Invalid dtypes");
dtype_b_t packed;
auto unpacked_ptr = reinterpret_cast<dtype_a_t *>(&packed);
auto unpacked_ptr = reinterpret_cast<dtype_a_t*>(&packed);
unpacked_ptr[0] = x, unpacked_ptr[1] = y;
return packed;
}

template <typename dtype_a_t, typename dtype_b_t>
__device__ __forceinline__ void unpack2(const dtype_b_t &packed, dtype_a_t &x,
dtype_a_t &y) {
__device__ __forceinline__ void unpack2(const dtype_b_t& packed, dtype_a_t& x,
dtype_a_t& y) {
EP_STATIC_ASSERT(sizeof(dtype_a_t) * 2 == sizeof(dtype_b_t),
"Invalid dtypes");
auto unpacked_ptr = reinterpret_cast<const dtype_a_t *>(&packed);
auto unpacked_ptr = reinterpret_cast<const dtype_a_t*>(&packed);
x = unpacked_ptr[0], y = unpacked_ptr[1];
}

template <typename dtype_t>
__device__ __forceinline__ dtype_t broadcast(dtype_t &ptr, int src_lane_idx) {
__device__ __forceinline__ dtype_t broadcast(dtype_t& ptr, int src_lane_idx) {
EP_STATIC_ASSERT(sizeof(dtype_t) % sizeof(int) == 0, "");
auto send_int_values = reinterpret_cast<int *>(&ptr);
auto send_int_values = reinterpret_cast<int*>(&ptr);
int recv_int_values[sizeof(dtype_t) / sizeof(int)];
#pragma unroll
for (int i = 0; i < sizeof(dtype_t) / sizeof(int); ++i)
recv_int_values[i] =
__shfl_sync(0xffffffff, send_int_values[i], src_lane_idx);
return *reinterpret_cast<dtype_t *>(recv_int_values);
return *reinterpret_cast<dtype_t*>(recv_int_values);
}

__forceinline__ __device__ int warp_reduce_sum(int value) {
Expand Down
42 changes: 38 additions & 4 deletions mooncake-ep/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@
import torch

use_musa = os.getenv("MOONCAKE_EP_USE_MUSA", "").upper() in {"1", "ON", "TRUE", "YES"}
# MACA is selected either explicitly through MOONCAKE_EP_USE_MACA or
# implicitly when the installed torch build exposes a non-None
# torch.version.maca.
use_maca = os.getenv("MOONCAKE_EP_USE_MACA", "").upper() in {
    "1",
    "ON",
    "TRUE",
    "YES",
} or getattr(torch.version, "maca", None) is not None
if use_musa:
try:
import torchada # noqa: F401
Expand All @@ -28,12 +32,33 @@

abi_flag = int(torch._C._GLIBCXX_USE_CXX11_ABI)
current_dir = os.path.abspath(os.path.dirname(__file__))
# repo_dir is the parent directory of current_dir; the sysroot lives at
# <repo_dir>/.deps/sysroot/usr.
_parent_of_current = os.path.join(current_dir, os.pardir)
repo_dir = os.path.abspath(_parent_of_current)
sysroot_dir = os.path.join(repo_dir, *(".deps", "sysroot", "usr"))


def existing_dirs(*paths):
    """Return the given paths that are existing directories, in input order."""
    found = []
    for candidate in paths:
        if os.path.isdir(candidate):
            found.append(candidate)
    return found


# Header and library search paths inside the sysroot. Candidates that are not
# directories on disk are dropped by existing_dirs(), so either list may end
# up empty.
_sysroot_include_root = os.path.join(sysroot_dir, "include")
sysroot_include_dirs = existing_dirs(
    _sysroot_include_root,
    os.path.join(_sysroot_include_root, "jsoncpp"),
    os.path.join(_sysroot_include_root, "libnl3"),
)
_sysroot_lib_root = os.path.join(sysroot_dir, "lib")
sysroot_library_dirs = existing_dirs(
    os.path.join(_sysroot_lib_root, "x86_64-linux-gnu"),
    _sysroot_lib_root,
)

# Host compiler flags: the libstdc++ ABI define first, then the common options.
abi_define = f"-D_GLIBCXX_USE_CXX11_ABI={abi_flag}"
cxx_args = [abi_define]
cxx_args.extend(["-std=c++20", "-O3", "-g0"])

# Default link and include settings; backend-specific branches may replace or
# extend them.
cuda_libraries = ["ibverbs", "mlx5"]
cuda_library_dirs = list()
include_dirs = [
    os.path.join(current_dir, relative)
    for relative in ("include", "../mooncake-transfer-engine/include")
]

if use_musa:
cuda_libraries = []
Expand All @@ -48,6 +73,18 @@
"--cuda-gpu-arch=mp_31",
"-O3",
]
elif use_maca:
cuda_libraries = []
cuda_library_dirs = sysroot_library_dirs.copy()
include_dirs += sysroot_include_dirs
maca_defines = ["-DUSE_MACA", "-DMOONCAKE_EP_USE_MACA=1"]
cxx_args += maca_defines
device_args = [
abi_define,
*maca_defines,
"-std=c++20",
"-O3",
]
else:
cxx_args.append("-DUSE_CUDA")
device_args = [
Expand All @@ -72,10 +109,7 @@
ext_modules=[
CUDAExtension(
name=module_name,
include_dirs=[
os.path.join(current_dir, "include"),
os.path.join(current_dir, "../mooncake-transfer-engine/include"),
],
include_dirs=include_dirs,
sources=[
"src/ep_py.cpp",
"src/mooncake_ep_buffer.cpp",
Expand Down
5 changes: 4 additions & 1 deletion mooncake-ep/src/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
add_library(mooncake_ep ep_py.cpp mooncake_ep_buffer.cpp mooncake_ep_kernel.cu)
add_library(mooncake_ep
ep_py.cpp
mooncake_ep_buffer.cpp
mooncake_ep_kernel.cu)

set_target_properties(mooncake_ep PROPERTIES POSITION_INDEPENDENT_CODE ON)
target_link_libraries(mooncake_ep PUBLIC ${TORCH_LIBRARIES} transfer_engine ibverbs mlx5)
3 changes: 2 additions & 1 deletion mooncake-ep/src/ep_py.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {

py::class_<EventHandle>(m, "EventHandle")
.def(py::init<>())
.def("current_stream_wait", &EventHandle::current_stream_wait);
.def("current_stream_wait", &EventHandle::current_stream_wait)
.def("synchronize", &EventHandle::synchronize);

m.attr("MAX_QP_COUNT") = pybind11::int_(MAX_QP_COUNT);

Expand Down
Loading
Loading