diff --git a/mooncake-transfer-engine/include/transport/rdma_transport/rdma_transport.h b/mooncake-transfer-engine/include/transport/rdma_transport/rdma_transport.h index c999b89cda..c7a98a83f7 100644 --- a/mooncake-transfer-engine/include/transport/rdma_transport/rdma_transport.h +++ b/mooncake-transfer-engine/include/transport/rdma_transport/rdma_transport.h @@ -140,6 +140,13 @@ class RdmaTransport : public Transport { // local_server_name_ keeps the TCP-reachable address for P2P routing. std::string rdma_server_name_; std::mutex local_desc_lock_; + // Mooncake#2017: buffers larger than the device max_mr_size are split into + // multiple sub-max_mr_size MRs (one BufferDesc per chunk) so that + // ibv_reg_mr is never silently truncated. unregisterLocalMemory() only + // receives the base addr, so remember each base buffer's chunk + // start-addresses for cleanup. + std::mutex chunk_map_mutex_; + std::unordered_map> chunk_map_; }; using TransferRequest = Transport::TransferRequest; diff --git a/mooncake-transfer-engine/src/transport/rdma_transport/rdma_transport.cpp b/mooncake-transfer-engine/src/transport/rdma_transport/rdma_transport.cpp index d65e7d4abd..b11aba5cdd 100644 --- a/mooncake-transfer-engine/src/transport/rdma_transport/rdma_transport.cpp +++ b/mooncake-transfer-engine/src/transport/rdma_transport/rdma_transport.cpp @@ -18,6 +18,7 @@ #include #include +#include #include #include #include @@ -25,6 +26,7 @@ #include #include #include +#include #include @@ -202,7 +204,6 @@ int RdmaTransport::registerLocalMemoryInternal(void *addr, size_t length, bool update_metadata, bool force_sequential) { (void)remote_accessible; - BufferDesc buffer_desc; const int kBaseAccessRights = IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE | IBV_ACCESS_REMOTE_READ; @@ -211,111 +212,166 @@ int RdmaTransport::registerLocalMemoryInternal(void *addr, size_t length, if (MCIbRelaxedOrderingEnabled) { access_rights |= IBV_ACCESS_RELAXED_ORDERING; } - bool do_pre_touch = context_list_.size() > 0 && - std::thread::hardware_concurrency() >= 4 && - length >= (size_t)4 * 1024 * 1024 * 1024; - if (do_pre_touch) { - // Parallel Pre-touch the memory to speedup the registration process. - int ret = preTouchMemory(addr, length); - if (ret != 0) { - return ret; - } - } - /* Parallel register when: - 1. parallel_reg_mr is enabled via MC_ENABLE_PARALLEL_REG_MR; - 2. parallel_reg_mr not set and multiple contexts exist and memory has been - pre-touched - Note: If memory hasn't been touched, parallel register can be - slower. Details in: https://github.com/kvcache-ai/Mooncake/issues/848 - Note: force_sequential is used by batch operations to avoid nested - parallelism. - */ - int use_parallel_reg = 0; - if (!force_sequential) { - use_parallel_reg = globalConfig().parallel_reg_mr; - if (use_parallel_reg == -1) { - use_parallel_reg = context_list_.size() > 1 && do_pre_touch; + // Mooncake#2017: ibv_reg_mr silently truncates a registration to the device + // max_mr_size, but the metadata would still advertise the full BufferDesc + // length, so any remote RDMA op past the boundary fails with + // IBV_WC_REM_ACCESS_ERR (ionic CQE error 10). Split buffers larger than + // max_mr_size into chunks of <= max_mr_size, register each as its own MR, + // and publish one BufferDesc per chunk (the per-context rkey/lkey lookups + // are address-range based, so each chunk gets the correct key). + size_t chunk_limit = (size_t)globalConfig().max_mr_size; + std::vector> chunks; + if (chunk_limit > 0 && length > chunk_limit) { + for (size_t offset = 0; offset < length;) { + size_t chunk_len = std::min(chunk_limit, length - offset); + chunks.emplace_back(static_cast(addr) + offset, chunk_len); + offset += chunk_len; } + LOG(WARNING) << "Auto-splitting buffer " << addr << " (" << length + << " bytes) into " << chunks.size() + << " chunks of <= " << chunk_limit + << " bytes each (device max_mr_size; Mooncake#2017)"; + } else { + chunks.emplace_back(addr, length); } - auto reg_start = std::chrono::steady_clock::now(); - - if (use_parallel_reg) { - std::vector reg_threads; - reg_threads.reserve(context_list_.size()); - std::vector ret_codes(context_list_.size(), 0); - const int ar = access_rights; // Local copy for lambda capture - - for (size_t i = 0; i < context_list_.size(); ++i) { - reg_threads.emplace_back([this, &ret_codes, i, addr, length, ar]() { - ret_codes[i] = - context_list_[i]->registerMemoryRegion(addr, length, ar); - }); - } + // Resolve the location name once, from the original buffer. + std::string resolved_name; + if (name == kWildcardLocation) { + bool only_first_page = true; + const std::vector entries = + getMemoryLocation(addr, length, only_first_page); + if (entries.empty()) return -1; + resolved_name = entries[0].location; + } else { + resolved_name = name; + } - for (auto &thread : reg_threads) { - thread.join(); + // Best-effort rollback of already-registered chunks [0, up_to_ci]. + auto rollbackChunks = [&](size_t up_to_ci) { + for (size_t ri = 0; ri <= up_to_ci && ri < chunks.size(); ++ri) { + metadata_->removeLocalMemoryBuffer(chunks[ri].first, + update_metadata); + for (auto &context : context_list_) + context->unregisterMemoryRegion(chunks[ri].first); } - - for (size_t i = 0; i < ret_codes.size(); ++i) { - if (ret_codes[i] != 0) { - LOG(ERROR) << "Failed to register memory region with context " - << i; - return ret_codes[i]; + }; + + for (size_t ci = 0; ci < chunks.size(); ++ci) { + void *chunk_addr = chunks[ci].first; + size_t chunk_len = chunks[ci].second; + + // Decide pre-touch from the ORIGINAL buffer length, not the capped + // chunk_len (which is <= max_mr_size, so a chunk_len-based >=4GiB check + // would never fire and silently disable parallel pre-touch). + bool do_pre_touch = context_list_.size() > 0 && + std::thread::hardware_concurrency() >= 4 && + length >= (size_t)4 * 1024 * 1024 * 1024; + if (do_pre_touch) { + // Parallel pre-touch the memory to speed up registration. + int ret = preTouchMemory(chunk_addr, chunk_len); + if (ret != 0) { + if (ci) rollbackChunks(ci - 1); + return ret; } } - } else { - for (size_t i = 0; i < context_list_.size(); ++i) { - int ret = context_list_[i]->registerMemoryRegion(addr, length, - access_rights); - if (ret) { - LOG(ERROR) << "Failed to register memory region with context " - << i; - return ret; + + /* Parallel register when: + 1. parallel_reg_mr is enabled via MC_ENABLE_PARALLEL_REG_MR; + 2. parallel_reg_mr not set, multiple contexts exist, memory pre-touched. + force_sequential is used by batch operations to avoid nested + parallelism. + */ + int use_parallel_reg = 0; + if (!force_sequential) { + use_parallel_reg = globalConfig().parallel_reg_mr; + if (use_parallel_reg == -1) { + use_parallel_reg = context_list_.size() > 1 && do_pre_touch; } } - } - auto reg_end = std::chrono::steady_clock::now(); - auto reg_duration_ms = - std::chrono::duration_cast(reg_end - - reg_start) - .count(); + auto reg_start = std::chrono::steady_clock::now(); - if (globalConfig().trace) { - LOG(INFO) << "registerMemoryRegion: addr=" << addr - << ", length=" << length - << ", contexts=" << context_list_.size() - << ", parallel=" << (use_parallel_reg ? "true" : "false") - << ", duration=" << reg_duration_ms << "ms"; - } + if (use_parallel_reg) { + std::vector reg_threads; + reg_threads.reserve(context_list_.size()); + std::vector ret_codes(context_list_.size(), 0); + const int ar = access_rights; // Local copy for lambda capture - // Collect keys from all contexts - for (auto &context : context_list_) { - buffer_desc.lkey.push_back(context->lkey(addr)); - buffer_desc.rkey.push_back(context->rkey(addr)); - } + for (size_t i = 0; i < context_list_.size(); ++i) { + reg_threads.emplace_back( + [this, &ret_codes, i, chunk_addr, chunk_len, ar]() { + ret_codes[i] = context_list_[i]->registerMemoryRegion( + chunk_addr, chunk_len, ar); + }); + } - // Get the memory location automatically after registered MR(pinned), - // when the name is kWildcardLocation("*"). - if (name == kWildcardLocation) { - bool only_first_page = true; - const std::vector entries = - getMemoryLocation(addr, length, only_first_page); - if (entries.empty()) return -1; - buffer_desc.name = entries[0].location; - } else { - buffer_desc.name = name; - } + for (auto &thread : reg_threads) thread.join(); + + for (size_t i = 0; i < ret_codes.size(); ++i) { + if (ret_codes[i] != 0) { + LOG(ERROR) << "Failed to register memory region (chunk " + << ci << ") with context " << i; + rollbackChunks(ci); + return ret_codes[i]; + } + } + } else { + for (size_t i = 0; i < context_list_.size(); ++i) { + int ret = context_list_[i]->registerMemoryRegion( + chunk_addr, chunk_len, access_rights); + if (ret) { + LOG(ERROR) << "Failed to register memory region (chunk " + << ci << ") with context " << i; + rollbackChunks(ci); + return ret; + } + } + } - buffer_desc.addr = (uint64_t)addr; - buffer_desc.length = length; + auto reg_end = std::chrono::steady_clock::now(); + auto reg_duration_ms = + std::chrono::duration_cast(reg_end - + reg_start) + .count(); + if (globalConfig().trace) { + LOG(INFO) << "registerMemoryRegion: chunk " << ci << "/" + << chunks.size() << ", addr=" << chunk_addr + << ", length=" << chunk_len + << ", contexts=" << context_list_.size() + << ", parallel=" << (use_parallel_reg ? "true" : "false") + << ", duration=" << reg_duration_ms << "ms"; + } + + // Collect per-context keys for THIS chunk (address-range lookup). + BufferDesc buffer_desc; + for (auto &context : context_list_) { + buffer_desc.lkey.push_back(context->lkey(chunk_addr)); + buffer_desc.rkey.push_back(context->rkey(chunk_addr)); + } + buffer_desc.name = resolved_name; + buffer_desc.addr = (uint64_t)chunk_addr; + buffer_desc.length = chunk_len; #ifdef ENABLE_MULTI_PROTOCOL - buffer_desc.protocol = "rdma"; + buffer_desc.protocol = "rdma"; #endif - int rc = metadata_->addLocalMemoryBuffer(buffer_desc, update_metadata); - if (rc) return rc; + int rc = metadata_->addLocalMemoryBuffer(buffer_desc, update_metadata); + if (rc) { + rollbackChunks(ci); + return rc; + } + } + + // Remember chunk start-addresses so unregisterLocalMemory(addr) (which only + // gets the base addr) can clean up every chunk. + if (chunks.size() > 1) { + std::lock_guard lock(chunk_map_mutex_); + std::vector chunk_addrs; + chunk_addrs.reserve(chunks.size()); + for (auto &c : chunks) chunk_addrs.push_back((uint64_t)c.first); + chunk_map_[(uint64_t)addr] = std::move(chunk_addrs); + } return 0; } @@ -326,6 +382,38 @@ int RdmaTransport::unregisterLocalMemory(void *addr, bool update_metadata) { int RdmaTransport::unregisterLocalMemoryInternal(void *addr, bool update_metadata, bool force_sequential) { + // Mooncake#2017: if this base buffer was split into chunks at registration, + // unregister each chunk's MR + metadata entry (unregisterLocalMemory only + // receives the base addr). + std::vector chunk_addrs; + { + std::lock_guard lock(chunk_map_mutex_); + auto it = chunk_map_.find((uint64_t)addr); + if (it != chunk_map_.end()) { + chunk_addrs = std::move(it->second); + chunk_map_.erase(it); + } + } + if (!chunk_addrs.empty()) { + // Unregister EVERY chunk even if one fails; chunk_map_ was already + // erased, so an early return would leak the remaining chunks' MRs + + // metadata. Remember the first error and report it at the end. + int first_err = 0; + for (uint64_t ca : chunk_addrs) { + void *cap = reinterpret_cast(ca); + int rc = metadata_->removeLocalMemoryBuffer(cap, update_metadata); + if (rc && !first_err) first_err = rc; + for (auto &context : context_list_) { + int ret = context->unregisterMemoryRegion(cap); + if (ret) { + LOG(ERROR) << "Failed to unregister chunk MR at " << cap; + if (!first_err) first_err = ret; + } + } + } + return first_err; + } + int rc = metadata_->removeLocalMemoryBuffer(addr, update_metadata); if (rc) return rc; diff --git a/mooncake-transfer-engine/tests/CMakeLists.txt b/mooncake-transfer-engine/tests/CMakeLists.txt index 9bf551eebf..f707e2e5ca 100644 --- a/mooncake-transfer-engine/tests/CMakeLists.txt +++ b/mooncake-transfer-engine/tests/CMakeLists.txt @@ -35,6 +35,12 @@ add_executable(rdma_loopback_test ${WORKSPACE}/rdma_loopback_test.cpp) target_link_libraries(rdma_loopback_test PUBLIC transfer_engine gtest gtest_main ) # add_test(NAME rdma_loopback_test COMMAND rdma_loopback_test) +# Regression test for #2017 (registerLocalMemory must auto-chunk buffers larger +# than the device max_mr_size; loopback WRITE past the boundary must succeed). +add_executable(rdma_large_mr_test ${WORKSPACE}/rdma_large_mr_test.cpp) +target_link_libraries(rdma_large_mr_test PUBLIC transfer_engine gtest gtest_main ) +# add_test(NAME rdma_large_mr_test COMMAND rdma_large_mr_test) # needs an RDMA dev + metadata server + # This test verifies endpoint re-establishment in RDMATransport. add_executable(rdma_endpoint_reestablish_test ${WORKSPACE}/rdma_endpoint_reestablish_test.cpp) target_link_libraries(rdma_endpoint_reestablish_test PUBLIC transfer_engine gtest gtest_main ) diff --git a/mooncake-transfer-engine/tests/rdma_large_mr_test.cpp b/mooncake-transfer-engine/tests/rdma_large_mr_test.cpp new file mode 100644 index 0000000000..4e25625021 --- /dev/null +++ b/mooncake-transfer-engine/tests/rdma_large_mr_test.cpp @@ -0,0 +1,138 @@ +// Copyright 2024 KVCache.AI +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Regression test for issue #2017: RdmaTransport::registerLocalMemory must NOT +// silently truncate a buffer larger than the device max_mr_size. +// +// Bug: registerMemoryRegionInternal shrank `length` to max_mr_size and +// registered a single MR of that size, but registerLocalMemory still published +// BufferDesc.length = full length with that one (truncated) rkey. A remote RDMA +// op whose target address fell past max_mr_size then failed with +// IBV_WC_REM_ACCESS_ERR (e.g. ionic CQE error 10). It is ADDRESS-driven: ops to +// low addresses succeed, ops past the boundary fail -> PD KV transfer ran for a +// while then a worker died mid-run (seen on MI355x/ionic, per-layer KV ~3.25 +// GiB > the ionic 2 GiB max_mr_size). +// +// Fix: split buffers > max_mr_size into <= max_mr_size chunks, register each as +// its own MR, and publish one BufferDesc per chunk. +// +// This test forces the condition at a small, HW-independent size by setting +// MC_MAX_MR_SIZE, then does a LOOPBACK RDMA WRITE whose target lands PAST that +// boundary. Pre-fix: the transfer FAILS (REM_ACCESS_ERR). Post-fix: it +// COMPLETES and the bytes match. Runs on any RDMA device (incl. rdma_rxe / +// loopback). + +#include +#include +#include + +#include +#include +#include + +#include "transfer_engine.h" +#include "transport/transport.h" + +using namespace mooncake; + +namespace mooncake { + +DEFINE_string(metadata_server, "127.0.0.1:2379", + "central metadata server for transfer engine"); + +// Small max_mr_size so the >max_mr_size path is exercised without needing a +// multi-GB allocation. Must be set before TransferEngine init (config reads +// env). +static constexpr size_t kMaxMrSize = 64ull << 20; // 64 MiB +static constexpr size_t kBufferSize = 256ull << 20; // 256 MiB -> 4 chunks + +class RDMALargeMrTest : public ::testing::Test { + public: + void *addr = nullptr; + std::unique_ptr engine; + + protected: + void SetUp() override { + google::InitGoogleLogging("RDMALargeMrTest"); + FLAGS_logtostderr = 1; + // Force a small device max_mr_size cap (config caps to min(env, + // device)). + setenv("MC_MAX_MR_SIZE", std::to_string(kMaxMrSize).c_str(), 1); + engine = std::make_unique(true); + engine->init(FLAGS_metadata_server, "test_node_large_mr"); + addr = numa_alloc_onnode(kBufferSize, 0); + ASSERT_NE(addr, nullptr); + // Register a buffer LARGER than max_mr_size. Pre-fix this silently + // truncates the MR to kMaxMrSize; post-fix it splits into 4 chunks. + int rc = engine->registerLocalMemory(addr, kBufferSize, "cpu:0"); + ASSERT_EQ(rc, 0); + } + + void TearDown() override { + if (engine && addr) engine->unregisterLocalMemory(addr); + if (addr) numa_free(addr, kBufferSize); + google::ShutdownGoogleLogging(); + } +}; + +// Loopback RDMA WRITE whose TARGET lands past max_mr_size. The source stays in +// the first MR; only the destination address exercises the truncation boundary. +TEST_F(RDMALargeMrTest, WritePastMaxMrSizeBoundary) { + const size_t kDataLength = 1ull << 20; // 1 MiB + // Target offset is well past kMaxMrSize (in the 4th chunk). Pre-fix: the + // single 64 MiB MR does not cover this address -> IBV_WC_REM_ACCESS_ERR. + const size_t kTargetOffset = kBufferSize - kDataLength; // ~255 MiB + ASSERT_GT(kTargetOffset, kMaxMrSize); + + for (size_t i = 0; i < kDataLength; ++i) + *((char *)addr + i) = (char)('a' + (lrand48() % 26)); + + auto batch_id = engine->allocateBatchID(1); + TransferRequest entry; + entry.opcode = TransferRequest::WRITE; + entry.length = kDataLength; + entry.source = (uint8_t *)addr; // src in chunk 0 + entry.target_id = LOCAL_SEGMENT_ID; + entry.target_offset = (uint64_t)addr + kTargetOffset; // dst past 64 MiB + + Status s = engine->submitTransfer(batch_id, {entry}); + ASSERT_TRUE(s.ok()); + + TransferStatus status; + bool completed = false; + while (!completed) { + Status gs = engine->getTransferStatus(batch_id, 0, status); + ASSERT_EQ(gs, Status::OK()); + if (status.s == TransferStatusEnum::COMPLETED) + completed = true; + else if (status.s == TransferStatusEnum::FAILED) + break; + } + ASSERT_EQ(engine->freeBatchID(batch_id), Status::OK()); + + // The regression assertion: pre-fix this is FAILED (remote access error); + // post-fix it COMPLETES and the bytes match. + ASSERT_EQ(status.s, TransferStatusEnum::COMPLETED) + << "RDMA WRITE to an address past max_mr_size failed -- buffer MR was " + "truncated and not auto-chunked (issue #2017)."; + ASSERT_EQ(0, memcmp(addr, (char *)addr + kTargetOffset, kDataLength)); +} + +} // namespace mooncake + +int main(int argc, char **argv) { + gflags::ParseCommandLineFlags(&argc, &argv, false); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +}