-
Notifications
You must be signed in to change notification settings - Fork 895
fix(rdma): auto-chunk MRs larger than device max_mr_size (#2017) #2644
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -18,13 +18,15 @@ | |
| #include <sys/mman.h> | ||
| #include <sys/time.h> | ||
|
|
||
| #include <algorithm> | ||
| #include <cassert> | ||
| #include <chrono> | ||
| #include <cstddef> | ||
| #include <cstdlib> | ||
| #include <future> | ||
| #include <set> | ||
| #include <thread> | ||
| #include <utility> | ||
|
|
||
| #include <dlfcn.h> | ||
|
|
||
|
|
@@ -202,7 +204,6 @@ int RdmaTransport::registerLocalMemoryInternal(void *addr, size_t length, | |
| bool update_metadata, | ||
| bool force_sequential) { | ||
| (void)remote_accessible; | ||
| BufferDesc buffer_desc; | ||
| const int kBaseAccessRights = IBV_ACCESS_LOCAL_WRITE | | ||
| IBV_ACCESS_REMOTE_WRITE | | ||
| IBV_ACCESS_REMOTE_READ; | ||
|
|
@@ -211,111 +212,166 @@ int RdmaTransport::registerLocalMemoryInternal(void *addr, size_t length, | |
| if (MCIbRelaxedOrderingEnabled) { | ||
| access_rights |= IBV_ACCESS_RELAXED_ORDERING; | ||
| } | ||
| bool do_pre_touch = context_list_.size() > 0 && | ||
| std::thread::hardware_concurrency() >= 4 && | ||
| length >= (size_t)4 * 1024 * 1024 * 1024; | ||
| if (do_pre_touch) { | ||
| // Parallel Pre-touch the memory to speedup the registration process. | ||
| int ret = preTouchMemory(addr, length); | ||
| if (ret != 0) { | ||
| return ret; | ||
| } | ||
| } | ||
|
|
||
| /* Parallel register when: | ||
| 1. parallel_reg_mr is enabled via MC_ENABLE_PARALLEL_REG_MR; | ||
| 2. parallel_reg_mr not set and multiple contexts exist and memory has been | ||
| pre-touched | ||
| Note: If memory hasn't been touched, parallel register can be | ||
| slower. Details in: https://github.com/kvcache-ai/Mooncake/issues/848 | ||
| Note: force_sequential is used by batch operations to avoid nested | ||
| parallelism. | ||
| */ | ||
| int use_parallel_reg = 0; | ||
| if (!force_sequential) { | ||
| use_parallel_reg = globalConfig().parallel_reg_mr; | ||
| if (use_parallel_reg == -1) { | ||
| use_parallel_reg = context_list_.size() > 1 && do_pre_touch; | ||
| // Mooncake#2017: ibv_reg_mr silently truncates a registration to the device | ||
| // max_mr_size, but the metadata would still advertise the full BufferDesc | ||
| // length, so any remote RDMA op past the boundary fails with | ||
| // IBV_WC_REM_ACCESS_ERR (ionic CQE error 10). Split buffers larger than | ||
| // max_mr_size into chunks of <= max_mr_size, register each as its own MR, | ||
| // and publish one BufferDesc per chunk (the per-context rkey/lkey lookups | ||
| // are address-range based, so each chunk gets the correct key). | ||
| size_t chunk_limit = (size_t)globalConfig().max_mr_size; | ||
| std::vector<std::pair<void *, size_t>> chunks; | ||
| if (chunk_limit > 0 && length > chunk_limit) { | ||
| for (size_t offset = 0; offset < length;) { | ||
| size_t chunk_len = std::min(chunk_limit, length - offset); | ||
| chunks.emplace_back(static_cast<char *>(addr) + offset, chunk_len); | ||
| offset += chunk_len; | ||
| } | ||
| LOG(WARNING) << "Auto-splitting buffer " << addr << " (" << length | ||
| << " bytes) into " << chunks.size() | ||
| << " chunks of <= " << chunk_limit | ||
| << " bytes each (device max_mr_size; Mooncake#2017)"; | ||
| } else { | ||
| chunks.emplace_back(addr, length); | ||
| } | ||
|
|
||
| auto reg_start = std::chrono::steady_clock::now(); | ||
|
|
||
| if (use_parallel_reg) { | ||
| std::vector<std::thread> reg_threads; | ||
| reg_threads.reserve(context_list_.size()); | ||
| std::vector<int> ret_codes(context_list_.size(), 0); | ||
| const int ar = access_rights; // Local copy for lambda capture | ||
|
|
||
| for (size_t i = 0; i < context_list_.size(); ++i) { | ||
| reg_threads.emplace_back([this, &ret_codes, i, addr, length, ar]() { | ||
| ret_codes[i] = | ||
| context_list_[i]->registerMemoryRegion(addr, length, ar); | ||
| }); | ||
| } | ||
| // Resolve the location name once, from the original buffer. | ||
| std::string resolved_name; | ||
| if (name == kWildcardLocation) { | ||
| bool only_first_page = true; | ||
| const std::vector<MemoryLocationEntry> entries = | ||
| getMemoryLocation(addr, length, only_first_page); | ||
| if (entries.empty()) return -1; | ||
| resolved_name = entries[0].location; | ||
| } else { | ||
| resolved_name = name; | ||
| } | ||
|
|
||
| for (auto &thread : reg_threads) { | ||
| thread.join(); | ||
| // Best-effort rollback of already-registered chunks [0, up_to_ci]. | ||
| auto rollbackChunks = [&](size_t up_to_ci) { | ||
| for (size_t ri = 0; ri <= up_to_ci && ri < chunks.size(); ++ri) { | ||
| metadata_->removeLocalMemoryBuffer(chunks[ri].first, | ||
| update_metadata); | ||
| for (auto &context : context_list_) | ||
| context->unregisterMemoryRegion(chunks[ri].first); | ||
| } | ||
|
|
||
| for (size_t i = 0; i < ret_codes.size(); ++i) { | ||
| if (ret_codes[i] != 0) { | ||
| LOG(ERROR) << "Failed to register memory region with context " | ||
| << i; | ||
| return ret_codes[i]; | ||
| }; | ||
|
|
||
| for (size_t ci = 0; ci < chunks.size(); ++ci) { | ||
| void *chunk_addr = chunks[ci].first; | ||
| size_t chunk_len = chunks[ci].second; | ||
|
|
||
| // Decide pre-touch from the ORIGINAL buffer length, not the capped | ||
| // chunk_len (which is <= max_mr_size, so a chunk_len-based >=4GiB check | ||
| // would never fire and silently disable parallel pre-touch). | ||
| bool do_pre_touch = context_list_.size() > 0 && | ||
| std::thread::hardware_concurrency() >= 4 && | ||
| length >= (size_t)4 * 1024 * 1024 * 1024; | ||
| if (do_pre_touch) { | ||
| // Parallel pre-touch the memory to speed up registration. | ||
| int ret = preTouchMemory(chunk_addr, chunk_len); | ||
| if (ret != 0) { | ||
| if (ci) rollbackChunks(ci - 1); | ||
| return ret; | ||
| } | ||
| } | ||
| } else { | ||
| for (size_t i = 0; i < context_list_.size(); ++i) { | ||
| int ret = context_list_[i]->registerMemoryRegion(addr, length, | ||
| access_rights); | ||
| if (ret) { | ||
| LOG(ERROR) << "Failed to register memory region with context " | ||
| << i; | ||
| return ret; | ||
|
|
||
| /* Parallel register when: | ||
| 1. parallel_reg_mr is enabled via MC_ENABLE_PARALLEL_REG_MR; | ||
| 2. parallel_reg_mr not set, multiple contexts exist, memory pre-touched. | ||
| force_sequential is used by batch operations to avoid nested | ||
| parallelism. | ||
| */ | ||
| int use_parallel_reg = 0; | ||
| if (!force_sequential) { | ||
| use_parallel_reg = globalConfig().parallel_reg_mr; | ||
| if (use_parallel_reg == -1) { | ||
| use_parallel_reg = context_list_.size() > 1 && do_pre_touch; | ||
| } | ||
| } | ||
| } | ||
|
|
||
| auto reg_end = std::chrono::steady_clock::now(); | ||
| auto reg_duration_ms = | ||
| std::chrono::duration_cast<std::chrono::milliseconds>(reg_end - | ||
| reg_start) | ||
| .count(); | ||
| auto reg_start = std::chrono::steady_clock::now(); | ||
|
|
||
| if (globalConfig().trace) { | ||
| LOG(INFO) << "registerMemoryRegion: addr=" << addr | ||
| << ", length=" << length | ||
| << ", contexts=" << context_list_.size() | ||
| << ", parallel=" << (use_parallel_reg ? "true" : "false") | ||
| << ", duration=" << reg_duration_ms << "ms"; | ||
| } | ||
| if (use_parallel_reg) { | ||
| std::vector<std::thread> reg_threads; | ||
| reg_threads.reserve(context_list_.size()); | ||
| std::vector<int> ret_codes(context_list_.size(), 0); | ||
| const int ar = access_rights; // Local copy for lambda capture | ||
|
|
||
| // Collect keys from all contexts | ||
| for (auto &context : context_list_) { | ||
| buffer_desc.lkey.push_back(context->lkey(addr)); | ||
| buffer_desc.rkey.push_back(context->rkey(addr)); | ||
| } | ||
| for (size_t i = 0; i < context_list_.size(); ++i) { | ||
| reg_threads.emplace_back( | ||
| [this, &ret_codes, i, chunk_addr, chunk_len, ar]() { | ||
| ret_codes[i] = context_list_[i]->registerMemoryRegion( | ||
| chunk_addr, chunk_len, ar); | ||
| }); | ||
| } | ||
|
|
||
| // Get the memory location automatically after registered MR(pinned), | ||
| // when the name is kWildcardLocation("*"). | ||
| if (name == kWildcardLocation) { | ||
| bool only_first_page = true; | ||
| const std::vector<MemoryLocationEntry> entries = | ||
| getMemoryLocation(addr, length, only_first_page); | ||
| if (entries.empty()) return -1; | ||
| buffer_desc.name = entries[0].location; | ||
| } else { | ||
| buffer_desc.name = name; | ||
| } | ||
| for (auto &thread : reg_threads) thread.join(); | ||
|
|
||
| for (size_t i = 0; i < ret_codes.size(); ++i) { | ||
| if (ret_codes[i] != 0) { | ||
| LOG(ERROR) << "Failed to register memory region (chunk " | ||
| << ci << ") with context " << i; | ||
| rollbackChunks(ci); | ||
| return ret_codes[i]; | ||
| } | ||
| } | ||
| } else { | ||
| for (size_t i = 0; i < context_list_.size(); ++i) { | ||
| int ret = context_list_[i]->registerMemoryRegion( | ||
| chunk_addr, chunk_len, access_rights); | ||
| if (ret) { | ||
| LOG(ERROR) << "Failed to register memory region (chunk " | ||
| << ci << ") with context " << i; | ||
| rollbackChunks(ci); | ||
| return ret; | ||
| } | ||
| } | ||
| } | ||
|
|
||
| buffer_desc.addr = (uint64_t)addr; | ||
| buffer_desc.length = length; | ||
| auto reg_end = std::chrono::steady_clock::now(); | ||
| auto reg_duration_ms = | ||
| std::chrono::duration_cast<std::chrono::milliseconds>(reg_end - | ||
| reg_start) | ||
| .count(); | ||
| if (globalConfig().trace) { | ||
| LOG(INFO) << "registerMemoryRegion: chunk " << ci << "/" | ||
| << chunks.size() << ", addr=" << chunk_addr | ||
| << ", length=" << chunk_len | ||
| << ", contexts=" << context_list_.size() | ||
| << ", parallel=" << (use_parallel_reg ? "true" : "false") | ||
| << ", duration=" << reg_duration_ms << "ms"; | ||
| } | ||
|
|
||
| // Collect per-context keys for THIS chunk (address-range lookup). | ||
| BufferDesc buffer_desc; | ||
| for (auto &context : context_list_) { | ||
| buffer_desc.lkey.push_back(context->lkey(chunk_addr)); | ||
| buffer_desc.rkey.push_back(context->rkey(chunk_addr)); | ||
| } | ||
| buffer_desc.name = resolved_name; | ||
| buffer_desc.addr = (uint64_t)chunk_addr; | ||
| buffer_desc.length = chunk_len; | ||
| #ifdef ENABLE_MULTI_PROTOCOL | ||
| buffer_desc.protocol = "rdma"; | ||
| buffer_desc.protocol = "rdma"; | ||
| #endif | ||
| int rc = metadata_->addLocalMemoryBuffer(buffer_desc, update_metadata); | ||
| if (rc) return rc; | ||
| int rc = metadata_->addLocalMemoryBuffer(buffer_desc, update_metadata); | ||
| if (rc) { | ||
| rollbackChunks(ci); | ||
| return rc; | ||
| } | ||
| } | ||
|
|
||
| // Remember chunk start-addresses so unregisterLocalMemory(addr) (which only | ||
| // gets the base addr) can clean up every chunk. | ||
| if (chunks.size() > 1) { | ||
| std::lock_guard<std::mutex> lock(chunk_map_mutex_); | ||
| std::vector<uint64_t> chunk_addrs; | ||
| chunk_addrs.reserve(chunks.size()); | ||
| for (auto &c : chunks) chunk_addrs.push_back((uint64_t)c.first); | ||
| chunk_map_[(uint64_t)addr] = std::move(chunk_addrs); | ||
| } | ||
| return 0; | ||
| } | ||
|
|
||
|
|
@@ -326,6 +382,38 @@ int RdmaTransport::unregisterLocalMemory(void *addr, bool update_metadata) { | |
| int RdmaTransport::unregisterLocalMemoryInternal(void *addr, | ||
| bool update_metadata, | ||
| bool force_sequential) { | ||
| // Mooncake#2017: if this base buffer was split into chunks at registration, | ||
| // unregister each chunk's MR + metadata entry (unregisterLocalMemory only | ||
| // receives the base addr). | ||
| std::vector<uint64_t> chunk_addrs; | ||
| { | ||
| std::lock_guard<std::mutex> lock(chunk_map_mutex_); | ||
| auto it = chunk_map_.find((uint64_t)addr); | ||
| if (it != chunk_map_.end()) { | ||
| chunk_addrs = std::move(it->second); | ||
| chunk_map_.erase(it); | ||
| } | ||
| } | ||
| if (!chunk_addrs.empty()) { | ||
| // Unregister EVERY chunk even if one fails; chunk_map_ was already | ||
| // erased, so an early return would leak the remaining chunks' MRs + | ||
| // metadata. Remember the first error and report it at the end. | ||
| int first_err = 0; | ||
| for (uint64_t ca : chunk_addrs) { | ||
| void *cap = reinterpret_cast<void *>(ca); | ||
| int rc = metadata_->removeLocalMemoryBuffer(cap, update_metadata); | ||
| if (rc && !first_err) first_err = rc; | ||
| for (auto &context : context_list_) { | ||
| int ret = context->unregisterMemoryRegion(cap); | ||
| if (ret) { | ||
| LOG(ERROR) << "Failed to unregister chunk MR at " << cap; | ||
| if (!first_err) first_err = ret; | ||
| } | ||
| } | ||
| } | ||
| return first_err; | ||
| } | ||
|
Comment on lines
+397
to
+415
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
If unregistration fails for any chunk, returning early prevents the remaining chunks from being unregistered. Since the chunk addresses have already been erased from `chunk_map_`, an early return would leak the MRs and metadata of the remaining chunks. Unregister every chunk, remember the first error, and return it at the end:
if (!chunk_addrs.empty()) {
int first_err = 0;
for (uint64_t ca : chunk_addrs) {
void *cap = reinterpret_cast<void *>(ca);
int rc = metadata_->removeLocalMemoryBuffer(cap, update_metadata);
if (rc && !first_err) {
first_err = rc;
}
for (auto &context : context_list_) {
int ret = context->unregisterMemoryRegion(cap);
if (ret) {
LOG(ERROR) << "Failed to unregister chunk MR at " << cap;
if (!first_err) {
first_err = ret;
}
}
}
}
return first_err;
} |
||
|
|
||
| int rc = metadata_->removeLocalMemoryBuffer(addr, update_metadata); | ||
| if (rc) return rc; | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If this fails, print a warning log.