diff --git a/docs/source/design/transfer-engine/index.md b/docs/source/design/transfer-engine/index.md index 9ffbff0272..045f9d84da 100644 --- a/docs/source/design/transfer-engine/index.md +++ b/docs/source/design/transfer-engine/index.md @@ -79,6 +79,10 @@ next data transfer attempt. Evicted and deleted endpoints are moved to an internal `waiting_list_` and reclaimed asynchronously once their outstanding slices drain. Reclaim runs on every new endpoint insertion, and additionally on a ~1 Hz heartbeat from the per-context `monitorWorker`, so the waiting list drains even under failure load where new insertions stall while evictions continue. +### Metadata Version Reliability + +Transfer Engine treats segment names and segment ids as lookup handles rather than stable metadata identities. Replacement nodes may reuse both values, and memory registration can also change descriptor contents. Published segment descriptors therefore carry a single `metadata_version`; RDMA workers use version changes to invalidate cached endpoints and rail health derived from older descriptors. See [Transfer Metadata Version Reliability](metadata-generation-reliability.md) for the detailed design. + ### Fault Handling In a multi-NIC environment, one common failure scenario is the temporary unavailability of a specific NIC, while other routes may still connect two nodes. Mooncake Store is designed to adeptly manage such temporary diff --git a/docs/source/design/transfer-engine/metadata-generation-reliability.md b/docs/source/design/transfer-engine/metadata-generation-reliability.md new file mode 100644 index 0000000000..f6c0dc9a90 --- /dev/null +++ b/docs/source/design/transfer-engine/metadata-generation-reliability.md @@ -0,0 +1,217 @@ +# Transfer Metadata Version Reliability + +## Background + +Transfer Engine uses segment metadata to describe remote devices, memory +regions, keys, and topology. In production, a segment name may be reused after a +node replacement. The old and new nodes are assumed not to overlap, but there is +usually a gap between the old node disappearing and the new node publishing its +metadata. + +The reliability problem is not only node replacement. Dynamic memory +registration and deregistration can also change rkeys, address ranges, and +buffer availability. Readers need one simple way to know that cached metadata +and derived transport resources may be stale. + +This design uses a single descriptor-level `metadata_version`. + +```text +same segment name/id + different metadata_version = metadata changed +``` + +The transport layer does not need to classify the change as replacement, +topology update, or memory-region update. It invalidates resources derived from +the old descriptor and rebuilds from the new descriptor. + +## Goals + +- Use one version field for all segment metadata changes. +- Detect replacement of a node that reuses the same segment name/id. +- Detect dynamic memory-region changes. +- Keep legacy metadata readable. +- Keep the transport invalidation rule small and conservative. +- Avoid introducing backend leases, CAS, fencing, or writer heartbeats in this + first step. + +## Non-Goals + +- Supporting overlapping old and new writers for the same segment name. +- Proving strict global ordering across multiple metadata backends. +- Adding metadata watch/subscription. +- Adding typed metadata errors. + +## Data Model + +`SegmentDesc` contains: + +```cpp +uint64_t metadata_version; +``` + +Meaning: + +- `0` means legacy or unknown version. +- Non-zero values identify a published descriptor revision. +- Every metadata publish should assign a new version. +- Readers treat any version change as invalidating cached transport resources + derived from the previous descriptor. + +`BufferDesc` keeps only lifecycle state: + +```cpp +state = READY | DRAINING | REMOVED +``` + +Meaning: + +- missing or empty state is treated as `READY`; +- `READY` buffers can be selected for new transfers; +- `DRAINING` buffers are not selected for new transfers; +- `REMOVED` is reserved for future explicit tombstones. + +## Version Assignment + +The implementation uses a timestamp-style version: + +```text +metadata_version = max(now_ns, previous_metadata_version + 1) +``` + +This avoids the main problem with a per-process counter: after node replacement, +the new process might also start at version `1`, making replacement invisible to +clients that cached the old `1`. + +The timestamp-style value is not intended to be a perfect distributed clock. It +is a compact monotonic freshness token. Under the current non-overlap model it +is sufficient to make a replacement publish differ from the old descriptor. + +## Write Path + +### Segment Startup + +1. The transport creates the local `SegmentDesc`. +2. `addLocalSegment()` stores it in the local cache. +3. `updateLocalSegmentDesc()` bumps `metadata_version`. +4. The descriptor is encoded and published. + +### Memory Registration + +1. Add the local buffer to `SegmentDesc::buffers`. +2. Mark it `READY` if no state was provided. +3. If the operation publishes metadata, `updateLocalSegmentDesc()` bumps + `metadata_version` once and publishes the descriptor. +4. Batch registration can mutate many local buffers and publish once; this + results in one version bump. + +### Memory Deregistration + +For `update_metadata=true`, deregistration is two-phase: + +1. Mark the buffer `DRAINING`. +2. Bump `metadata_version` and publish. +3. Wait for the deregistration grace period. By default this is aligned with + `MC_METADATA_CACHE_TTL_MS`; `MC_METADATA_DEREG_GRACE_MS` can increase it but + should not be lower than the metadata cache TTL while metacache is enabled. +4. Remove the buffer locally. +5. Bump `metadata_version` and publish again. + +For `update_metadata=false`, local metadata changes are not published and do not +need a version bump. + +## Read Path + +Readers decode `metadata_version` from JSON. This implementation does not decode +older experimental `descriptor_version` or `buffer_version` fields. + +Legacy descriptors without any version fields decode as version `0`. + +## Cache Behavior + +The metadata cache stores immutable descriptor snapshots by segment id/name. +When a refreshed descriptor replaces an older cached descriptor: + +```text +if old.metadata_version != new.metadata_version: + record metadata-version change + replace cached descriptor +``` + +The cache also tracks a soft TTL (`MC_METADATA_CACHE_TTL_MS`, default `1000`). +When a cached descriptor expires, the first caller refreshes it; concurrent +callers keep using the existing snapshot instead of stampeding the metadata +backend. `MC_METADATA_CACHE_TTL_MS=0` keeps the historical non-expiring cache +behavior unless a caller explicitly requests `force_update=true`. + +## RDMA Transport Behavior + +Each worker remembers, per target segment: + +```text +segment_id -> { + metadata_version, + peer_nic_paths +} +``` + +On first observation, the worker records the version and peer NIC paths. + +On later observations: + +```text +if metadata_version changed: + delete old RDMA endpoints for old peer NIC paths + clear rail state for old peer NIC paths + remember new version and paths +``` + +This is intentionally conservative. A memory-region update may not strictly +require deleting endpoints, but using one invalidation rule keeps the first +implementation simple and avoids stale rkey/path coupling. + +## Replacement Gap Behavior + +During the gap between old and new nodes: + +- metadata may be missing; +- endpoint setup may fail; +- in-flight work may complete with errors; +- force refresh may not find the new descriptor yet. + +Expected behavior: + +- workers use existing retry and redispatch logic; +- once the new descriptor appears, refresh observes a new `metadata_version`; +- old RDMA endpoints and rail state are invalidated; +- subsequent transfers use the new descriptor. + +If the gap exceeds the retry budget, the transfer can still fail and the caller +or scheduler should retry later. + +## Observability + +The implementation exposes: + +- `SegmentDesc::dump()` prints `metadata_version`; +- each buffer dump prints lifecycle state; +- metadata dump prints `segmentMetadataVersionChangeCount()`; +- RDMA workers log version-triggered endpoint invalidation. + +Useful future metrics: + +```text +metadata_version_change_total +rdma_endpoint_invalidated_by_metadata_version_total +metadata_refresh_total +metadata_refresh_failure_total +``` + +## Current Limitations + +- There is no metadata watch/subscription yet. +- Metadata storage plugins still expose only get/set/remove, not typed errors. +- Two-phase deregistration relies on a grace period rather than explicit drain + ACKs. +- Batch memory deregistration currently publishes one final descriptor after + local removal. It does not publish per-buffer `DRAINING` state in this patch. +- If overlapping old and new writers become possible, this must be extended + with backend leases, CAS, remove-if-owner, or fencing tokens. diff --git a/mooncake-transfer-engine/include/config.h b/mooncake-transfer-engine/include/config.h index 82eb050460..c79e646c80 100644 --- a/mooncake-transfer-engine/include/config.h +++ b/mooncake-transfer-engine/include/config.h @@ -59,6 +59,11 @@ struct GlobalConfig { // which is minutes. Override via MC_HANDSHAKE_CONNECT_TIMEOUT. int handshake_connect_timeout = 5; bool metacache = true; + // 0 keeps the historical behavior: cached segment descriptors never expire + // unless callers request force_update. Positive values enable + // stale-while-refresh: the first caller after TTL expiry refreshes metadata + // synchronously while concurrent callers can keep using stale cache. + uint64_t metadata_cache_ttl_ms = 1000; int log_level = google::INFO; bool trace = false; int64_t slice_timeout = -1; diff --git a/mooncake-transfer-engine/include/transfer_metadata.h b/mooncake-transfer-engine/include/transfer_metadata.h index 69bd651213..507bc84af4 100644 --- a/mooncake-transfer-engine/include/transfer_metadata.h +++ b/mooncake-transfer-engine/include/transfer_metadata.h @@ -30,6 +30,7 @@ #include #include #include +#include #include "common.h" #include "topology.h" @@ -50,9 +51,15 @@ class TransferMetadata { }; struct BufferDesc { + static constexpr const char *STATE_READY = "READY"; + static constexpr const char *STATE_DRAINING = "DRAINING"; + static constexpr const char *STATE_REMOVED = "REMOVED"; + std::string name; uint64_t addr; uint64_t length; + // Empty means legacy READY. + std::string state; #ifdef ENABLE_MULTI_PROTOCOL std::string protocol; // for multi-protocol mode (cxl/tcp/rdma) #endif @@ -107,6 +114,10 @@ class TransferMetadata { uint64_t cxl_base_addr; // TODO : make these two a union or a std::variant std::string timestamp; + // Monotonic metadata publish version. Segment name/id may be reused + // after replacement, so readers treat a version change as invalidating + // resources derived from the old descriptor. + uint64_t metadata_version = 0; // this is for ascend RankInfoDesc rank_info; @@ -183,6 +194,10 @@ class TransferMetadata { int syncSegmentCache(const std::string &segment_name); + void updateSegmentCacheEntry( + SegmentID segment_id, const std::string &segment_name, + const std::shared_ptr &desc); + int removeSegmentDesc(const std::string &segment_name); int addLocalMemoryBuffer(const BufferDesc &buffer_desc, @@ -225,16 +240,39 @@ class TransferMetadata { void dumpMetadataContentUnlocked(); - private: int encodeSegmentDesc(const SegmentDesc &desc, Json::Value &segmentJSON); + std::shared_ptr decodeSegmentDesc( Json::Value &segmentJSON, const std::string &segment_name); + + uint64_t segmentMetadataVersionChangeCount() const { + return segment_metadata_version_change_count_.load( + std::memory_order_relaxed); + } + + private: int receivePeerMetadata(const Json::Value &peer_json, Json::Value &local_json); int receivePeerNotify(const Json::Value &peer_json, Json::Value &local_json); int receivePeerProbe(const Json::Value &peer_json, Json::Value &local_json); std::string getFullMetadataKey(const std::string &segment_name) const; + void updateSegmentCacheEntryLocked( + SegmentID segment_id, const std::string &segment_name, + const std::shared_ptr &desc); + bool isSegmentCacheFreshLocked(SegmentID segment_id) const; + struct SegmentCacheLookup { + SegmentID segment_id = 0; + std::string segment_name; + std::shared_ptr desc; + bool fresh = false; + bool refresh_owner = false; + }; + SegmentCacheLookup lookupSegmentCacheByName( + const std::string &segment_name); + SegmentCacheLookup lookupSegmentCacheByID(SegmentID segment_id); + std::shared_ptr getLocalSegmentDescByName( + const std::string &segment_name); bool p2p_handshake_mode_{false}; std::string common_key_prefix_; @@ -244,6 +282,8 @@ class TransferMetadata { std::unordered_map> segment_id_to_desc_map_; std::unordered_map segment_name_to_id_map_; + std::unordered_map segment_cache_update_ns_map_; + std::unordered_set segment_refreshing_ids_; RWSpinlock notify_lock_; std::vector notifys; @@ -252,6 +292,7 @@ class TransferMetadata { RpcMetaDesc local_rpc_meta_; std::atomic next_segment_id_; + std::atomic segment_metadata_version_change_count_{0}; std::shared_ptr handshake_plugin_; std::shared_ptr storage_plugin_; diff --git a/mooncake-transfer-engine/include/transport/rdma_transport/rdma_endpoint.h b/mooncake-transfer-engine/include/transport/rdma_transport/rdma_endpoint.h index d1ae09df06..e98fa8e76a 100644 --- a/mooncake-transfer-engine/include/transport/rdma_transport/rdma_endpoint.h +++ b/mooncake-transfer-engine/include/transport/rdma_transport/rdma_endpoint.h @@ -59,6 +59,7 @@ class RdmaEndPoint { int reconstruct(); int deconstruct(); int deconstructLocked(); + void beginDestroyLocked(); public: void setPeerNicPath(const std::string &peer_nic_path); @@ -92,6 +93,11 @@ class RdmaEndPoint { return status_.load(std::memory_order_relaxed) == CONNECTED; } + bool retired() const { + auto status = status_.load(std::memory_order_relaxed); + return status == DESTROYING || status == DESTROYED; + } + // Interrupts the connection, which can be triggered by user or by internal // error. Use setupConnectionsByActive or setupConnectionsByPassive to // reconnect @@ -113,25 +119,8 @@ class RdmaEndPoint { private: int disconnectUnlocked(); - // Resets the connection. - // - // The main difference between this function and `disconnectUnlocked` - // is that it will reconstruct QPs when `CONFIG_ERDMA` is defined. - // Without `CONFIG_ERDMA`, it is essentially the same as - // `disconnectUnlocked` but with additional logging. - // - // This serves as a workaround for Aliyun eRDMA devices (i.e., once a QP is - // transitioned to the RTS state, it cannot be reset to RTS again directly). - // For more details: - // https://github.com/kvcache-ai/Mooncake/pull/1733#discussion_r2992088663 - // - // In practice: - // - Call `resetConnection` if the QPs' state may have transitioned to RTS. - // - Call `disconnectUnlocked` otherwise. - // - // This is mainly used in `setupConnectionsByActive` or - // `setupConnectionsByPassive`. It is NOT invoked in the normal execution - // flow, so a `reason` argument is passed for internal logging purposes. + // Resets only pre-connected handshake attempts. Once an endpoint has ever + // reached CONNECTED, it is retired instead of being reused. int resetConnection(const std::string &reason); public: @@ -199,6 +188,7 @@ class RdmaEndPoint { std::string peer_nic_path_; std::vector peer_qp_num_list_; + bool has_connected_; volatile int *wr_depth_list_; int max_wr_depth_; diff --git a/mooncake-transfer-engine/include/transport/rdma_transport/worker_pool.h b/mooncake-transfer-engine/include/transport/rdma_transport/worker_pool.h index 08ee083ba9..c42af8499b 100644 --- a/mooncake-transfer-engine/include/transport/rdma_transport/worker_pool.h +++ b/mooncake-transfer-engine/include/transport/rdma_transport/worker_pool.h @@ -40,6 +40,8 @@ class WorkerPool { void transferWorker(int thread_id); + bool hasOutstandingCq(int thread_id); + void monitorWorker(); int doProcessContextEvents(); @@ -52,6 +54,12 @@ class WorkerPool { void markRailFailed(const std::string &peer_nic_path); bool isRailAvailable(const std::string &peer_nic_path); + void clearRailState(const std::vector &peer_nic_paths); + + std::vector buildPeerNicPaths( + const Transport::SegmentDesc &desc) const; + void recordPeerMetadataVersion(SegmentID segment_id, + const Transport::SegmentDesc &desc); // Retry helper: increment retry count and return whether retry is allowed static bool shouldRetrySlice(Transport::Slice *slice); @@ -86,8 +94,8 @@ class WorkerPool { std::vector worker_thread_; std::atomic workers_running_; - std::atomic suspended_flag_; + std::atomic parked_worker_count_; std::atomic redispatch_counter_; std::mutex cond_mutex_; @@ -109,6 +117,14 @@ class WorkerPool { std::unordered_map rail_states_; std::mutex rail_state_lock_; + struct TargetMetadataState { + bool initialized = false; + uint64_t metadata_version = 0; + std::vector peer_nic_paths; + }; + std::unordered_map target_metadata_; + std::mutex target_metadata_lock_; + // Rail monitor configuration const static int kRailErrorThreshold = 5; // Errors before pause const static uint64_t kRailPauseNs = 1000000000ull; // 1 second pause diff --git a/mooncake-transfer-engine/src/config.cpp b/mooncake-transfer-engine/src/config.cpp index aef962c230..b4dff2886c 100644 --- a/mooncake-transfer-engine/src/config.cpp +++ b/mooncake-transfer-engine/src/config.cpp @@ -240,6 +240,24 @@ void loadGlobalConfig(GlobalConfig& config) { config.metacache = false; } + const char* metadata_cache_ttl_ms = + std::getenv("MC_METADATA_CACHE_TTL_MS"); + if (metadata_cache_ttl_ms) { + try { + uint64_t val = std::stoull(metadata_cache_ttl_ms); + if (val <= 3600ULL * 1000ULL) { + config.metadata_cache_ttl_ms = val; + } else { + LOG(WARNING) << "Ignore value from environment variable " + "MC_METADATA_CACHE_TTL_MS"; + } + } catch (const std::exception& e) { + LOG(WARNING) << "Invalid MC_METADATA_CACHE_TTL_MS environment " + "value: " + << metadata_cache_ttl_ms << ". Error: " << e.what(); + } + } + const char* handshake_listen_backlog = std::getenv("MC_HANDSHAKE_LISTEN_BACKLOG"); if (handshake_listen_backlog) { @@ -505,6 +523,9 @@ void dumpGlobalConfig() { LOG(INFO) << "max_wr = " << config.max_wr; LOG(INFO) << "max_inline = " << config.max_inline; LOG(INFO) << "mtu_length = " << mtuLengthToString(config.mtu_length); + LOG(INFO) << "metacache = " << (config.metacache ? "true" : "false"); + LOG(INFO) << "metadata_cache_ttl_ms = " + << config.metadata_cache_ttl_ms; LOG(INFO) << "parallel_reg_mr = " << config.parallel_reg_mr; LOG(INFO) << "ib_traffic_class = " << config.ib_traffic_class; { diff --git a/mooncake-transfer-engine/src/transfer_metadata.cpp b/mooncake-transfer-engine/src/transfer_metadata.cpp index 0597daf0d0..7359d6424d 100644 --- a/mooncake-transfer-engine/src/transfer_metadata.cpp +++ b/mooncake-transfer-engine/src/transfer_metadata.cpp @@ -17,8 +17,10 @@ #include #include +#include #include -#include +#include +#include #include "common.h" #include "config.h" @@ -48,6 +50,52 @@ static inline std::string extractProtocolFromConnString( return "etcd"; } +static uint64_t metadataDeregisterGraceMs() { + uint64_t min_grace_ms = + globalConfig().metacache ? globalConfig().metadata_cache_ttl_ms : 0; + const char *env = std::getenv("MC_METADATA_DEREG_GRACE_MS"); + if (!env || !*env) return min_grace_ms; + try { + uint64_t grace_ms = std::stoull(env); + if (grace_ms < min_grace_ms) { + LOG(WARNING) << "MC_METADATA_DEREG_GRACE_MS=" << grace_ms + << " is less than MC_METADATA_CACHE_TTL_MS=" + << min_grace_ms << ", using " << min_grace_ms; + return min_grace_ms; + } + return grace_ms; + } catch (const std::exception &e) { + LOG(WARNING) << "Invalid MC_METADATA_DEREG_GRACE_MS=" << env + << ", using " << min_grace_ms << ": " << e.what(); + return min_grace_ms; + } +} + +static void decodeBufferState(const Json::Value &bufferJSON, + TransferMetadata::BufferDesc &buffer) { + if (bufferJSON.isMember("state")) { + buffer.state = bufferJSON["state"].asString(); + } else { + buffer.state = TransferMetadata::BufferDesc::STATE_READY; + } +} + +static void normalizeSegmentDescVersions( + TransferMetadata::SegmentDesc &desc) { + if (desc.metadata_version == 0) desc.metadata_version = 1; + for (auto &buffer : desc.buffers) { + if (buffer.state.empty()) + buffer.state = TransferMetadata::BufferDesc::STATE_READY; + } +} + +static void bumpSegmentMetadataVersion( + TransferMetadata::SegmentDesc &desc) { + uint64_t now_ns = getCurrentTimeInNano(); + desc.metadata_version = + now_ns > desc.metadata_version ? now_ns : desc.metadata_version + 1; +} + struct TransferNotifyUtil { static Json::Value encode(const TransferMetadata::NotifyDesc &desc) { Json::Value root; @@ -222,6 +270,8 @@ static int encodeMultiProtocolSegmentDesc( if (!desc.rdma_server_name.empty()) { segmentJSON["rdma_server_name"] = desc.rdma_server_name; } + segmentJSON["metadata_version"] = + static_cast(desc.metadata_version); Json::Value protocolJSON(Json::arrayValue); for (const auto &proto : protocols) { if (proto == "rdma") { @@ -246,6 +296,7 @@ static int encodeMultiProtocolSegmentDesc( Json::Value buffersJSON(Json::arrayValue); for (const auto &buffer : desc.buffers) { Json::Value bufferJSON; + if (!buffer.state.empty()) bufferJSON["state"] = buffer.state; bufferJSON["name"] = buffer.name; bufferJSON["length"] = static_cast(buffer.length); bufferJSON["protocol"] = buffer.protocol; @@ -318,6 +369,8 @@ int TransferMetadata::encodeSegmentDesc(const SegmentDesc &desc, segmentJSON["protocol"] = desc.protocol; segmentJSON["tcp_data_port"] = desc.tcp_data_port; segmentJSON["timestamp"] = getCurrentDateTime(); + segmentJSON["metadata_version"] = + static_cast(desc.metadata_version); if (!desc.rdma_server_name.empty()) { segmentJSON["rdma_server_name"] = desc.rdma_server_name; } @@ -338,6 +391,7 @@ int TransferMetadata::encodeSegmentDesc(const SegmentDesc &desc, Json::Value buffersJSON(Json::arrayValue); for (const auto &buffer : desc.buffers) { Json::Value bufferJSON; + if (!buffer.state.empty()) bufferJSON["state"] = buffer.state; bufferJSON["name"] = buffer.name; bufferJSON["addr"] = static_cast(buffer.addr); bufferJSON["length"] = static_cast(buffer.length); @@ -364,6 +418,7 @@ int TransferMetadata::encodeSegmentDesc(const SegmentDesc &desc, Json::Value buffersJSON(Json::arrayValue); for (const auto &buffer : desc.buffers) { Json::Value bufferJSON; + if (!buffer.state.empty()) bufferJSON["state"] = buffer.state; bufferJSON["name"] = buffer.name; bufferJSON["addr"] = static_cast(buffer.addr); bufferJSON["length"] = static_cast(buffer.length); @@ -378,6 +433,7 @@ int TransferMetadata::encodeSegmentDesc(const SegmentDesc &desc, Json::Value buffersJSON(Json::arrayValue); for (const auto &buffer : desc.buffers) { Json::Value bufferJSON; + if (!buffer.state.empty()) bufferJSON["state"] = buffer.state; bufferJSON["name"] = buffer.name; bufferJSON["addr"] = static_cast(buffer.addr); bufferJSON["length"] = static_cast(buffer.length); @@ -396,6 +452,7 @@ int TransferMetadata::encodeSegmentDesc(const SegmentDesc &desc, Json::Value buffersJSON(Json::arrayValue); for (const auto &buffer : desc.buffers) { Json::Value bufferJSON; + if (!buffer.state.empty()) bufferJSON["state"] = buffer.state; bufferJSON["name"] = buffer.name; bufferJSON["addr"] = static_cast(buffer.addr); bufferJSON["length"] = static_cast(buffer.length); @@ -435,6 +492,7 @@ int TransferMetadata::encodeSegmentDesc(const SegmentDesc &desc, Json::Value buffersJSON(Json::arrayValue); for (const auto &buffer : desc.buffers) { Json::Value bufferJSON; + if (!buffer.state.empty()) bufferJSON["state"] = buffer.state; bufferJSON["name"] = buffer.name; bufferJSON["addr"] = static_cast(buffer.addr); bufferJSON["length"] = static_cast(buffer.length); @@ -449,6 +507,7 @@ int TransferMetadata::encodeSegmentDesc(const SegmentDesc &desc, Json::Value buffersJSON(Json::arrayValue); for (const auto &buffer : desc.buffers) { Json::Value bufferJSON; + if (!buffer.state.empty()) bufferJSON["state"] = buffer.state; bufferJSON["name"] = buffer.name; bufferJSON["offset"] = static_cast(buffer.offset); bufferJSON["length"] = static_cast(buffer.length); @@ -515,6 +574,8 @@ decodeMultiProtocolSegmentDesc(Json::Value &segmentJSON, desc->tcp_data_port = segmentJSON["tcp_data_port"].asInt(); if (segmentJSON.isMember("timestamp")) desc->timestamp = segmentJSON["timestamp"].asString(); + if (segmentJSON.isMember("metadata_version")) + desc->metadata_version = segmentJSON["metadata_version"].asUInt64(); if (segmentJSON.isMember("rdma_server_name")) desc->rdma_server_name = segmentJSON["rdma_server_name"].asString(); @@ -556,6 +617,7 @@ decodeMultiProtocolSegmentDesc(Json::Value &segmentJSON, if (buffer_protocol == "cxl") { TransferMetadata::BufferDesc buffer; + decodeBufferState(bufferJSON, buffer); buffer.name = bufferJSON["name"].asString(); buffer.offset = bufferJSON["offset"].asUInt64(); buffer.length = bufferJSON["length"].asUInt64(); @@ -569,6 +631,7 @@ decodeMultiProtocolSegmentDesc(Json::Value &segmentJSON, desc->buffers.push_back(buffer); } else if (buffer_protocol == "rdma") { TransferMetadata::BufferDesc buffer; + decodeBufferState(bufferJSON, buffer); buffer.name = bufferJSON["name"].asString(); buffer.addr = bufferJSON["addr"].asUInt64(); buffer.length = bufferJSON["length"].asUInt64(); @@ -595,6 +658,7 @@ decodeMultiProtocolSegmentDesc(Json::Value &segmentJSON, desc->buffers.push_back(buffer); } else if (buffer_protocol == "tcp") { TransferMetadata::BufferDesc buffer; + decodeBufferState(bufferJSON, buffer); buffer.name = bufferJSON["name"].asString(); buffer.addr = bufferJSON["addr"].asUInt64(); buffer.length = bufferJSON["length"].asUInt64(); @@ -664,6 +728,8 @@ TransferMetadata::decodeSegmentDesc(Json::Value &segmentJSON, desc->tcp_data_port = segmentJSON["tcp_data_port"].asInt(); if (segmentJSON.isMember("timestamp")) desc->timestamp = segmentJSON["timestamp"].asString(); + if (segmentJSON.isMember("metadata_version")) + desc->metadata_version = segmentJSON["metadata_version"].asUInt64(); if (segmentJSON.isMember("rdma_server_name")) desc->rdma_server_name = segmentJSON["rdma_server_name"].asString(); @@ -684,6 +750,7 @@ TransferMetadata::decodeSegmentDesc(Json::Value &segmentJSON, for (const auto &bufferJSON : segmentJSON["buffers"]) { BufferDesc buffer; + decodeBufferState(bufferJSON, buffer); buffer.name = bufferJSON["name"].asString(); buffer.addr = bufferJSON["addr"].asUInt64(); buffer.length = bufferJSON["length"].asUInt64(); @@ -729,6 +796,7 @@ TransferMetadata::decodeSegmentDesc(Json::Value &segmentJSON, for (const auto &bufferJSON : segmentJSON["buffers"]) { BufferDesc buffer; + decodeBufferState(bufferJSON, buffer); buffer.name = bufferJSON["name"].asString(); buffer.addr = bufferJSON["addr"].asUInt64(); buffer.length = bufferJSON["length"].asUInt64(); @@ -753,6 +821,7 @@ TransferMetadata::decodeSegmentDesc(Json::Value &segmentJSON, } else if (desc->protocol == "tcp") { for (const auto &bufferJSON : segmentJSON["buffers"]) { BufferDesc buffer; + decodeBufferState(bufferJSON, buffer); buffer.name = bufferJSON["name"].asString(); buffer.addr = bufferJSON["addr"].asUInt64(); buffer.length = bufferJSON["length"].asUInt64(); @@ -769,6 +838,7 @@ TransferMetadata::decodeSegmentDesc(Json::Value &segmentJSON, desc->protocol == "sunrise_link") { for (const auto &bufferJSON : segmentJSON["buffers"]) { BufferDesc buffer; + decodeBufferState(bufferJSON, buffer); buffer.name = bufferJSON["name"].asString(); buffer.addr = bufferJSON["addr"].asUInt64(); buffer.length = bufferJSON["length"].asUInt64(); @@ -810,6 +880,7 @@ TransferMetadata::decodeSegmentDesc(Json::Value &segmentJSON, for (const auto &bufferJSON : segmentJSON["buffers"]) { BufferDesc buffer; + decodeBufferState(bufferJSON, buffer); buffer.name = bufferJSON["name"].asString(); buffer.addr = bufferJSON["addr"].asUInt64(); buffer.length = bufferJSON["length"].asUInt64(); @@ -845,6 +916,7 @@ TransferMetadata::decodeSegmentDesc(Json::Value &segmentJSON, desc->cxl_base_addr = segmentJSON["cxl_base_addr"].asUInt64(); for (const auto &bufferJSON : segmentJSON["buffers"]) { BufferDesc buffer; + decodeBufferState(bufferJSON, buffer); buffer.name = bufferJSON["name"].asString(); buffer.offset = bufferJSON["offset"].asUInt64(); buffer.length = bufferJSON["length"].asUInt64(); @@ -972,38 +1044,47 @@ int TransferMetadata::syncSegmentCache(const std::string &segment_name) { for (const auto &[name, desc] : updates) { auto it = segment_name_to_id_map_.find(name); if (it != segment_name_to_id_map_.end()) { - segment_id_to_desc_map_[it->second] = desc; + updateSegmentCacheEntryLocked(it->second, name, desc); } } return 0; } +void TransferMetadata::updateSegmentCacheEntry( + SegmentID segment_id, const std::string &segment_name, + const std::shared_ptr &desc) { + RWSpinlock::WriteGuard guard(segment_lock_); + updateSegmentCacheEntryLocked(segment_id, segment_name, desc); +} + std::shared_ptr TransferMetadata::getSegmentDescByName(const std::string &segment_name, bool force_update) { - if (globalConfig().metacache && !force_update) { - RWSpinlock::ReadGuard guard(segment_lock_); - auto iter = segment_name_to_id_map_.find(segment_name); - if (iter != segment_name_to_id_map_.end()) - return segment_id_to_desc_map_[iter->second]; - } + auto local_desc = getLocalSegmentDescByName(segment_name); + if (local_desc) return local_desc; - // Check if it's LOCAL_SEGMENT_ID - { - RWSpinlock::ReadGuard guard(segment_lock_); - auto iter = segment_name_to_id_map_.find(segment_name); - if (iter != segment_name_to_id_map_.end() && - iter->second == LOCAL_SEGMENT_ID) { - return segment_id_to_desc_map_[iter->second]; - } + SegmentCacheLookup cached; + if (globalConfig().metacache && !force_update) { + cached = lookupSegmentCacheByName(segment_name); + if (cached.fresh) return cached.desc; + if (cached.desc && !cached.refresh_owner) return cached.desc; } // Fetch segment descriptor without holding lock (may involve network I/O) auto segment_desc = this->getSegmentDesc(segment_name); - if (!segment_desc) return nullptr; + if (!segment_desc) { + if (cached.desc && cached.refresh_owner) { + RWSpinlock::WriteGuard guard(segment_lock_); + segment_refreshing_ids_.erase(cached.segment_id); + } + return cached.desc; + } // Update cache with write lock RWSpinlock::WriteGuard guard(segment_lock_); + if (cached.desc && cached.refresh_owner) { + segment_refreshing_ids_.erase(cached.segment_id); + } auto iter = segment_name_to_id_map_.find(segment_name); SegmentID segment_id; if (iter != segment_name_to_id_map_.end()) { @@ -1011,37 +1092,48 @@ TransferMetadata::getSegmentDescByName(const std::string &segment_name, } else { segment_id = next_segment_id_.fetch_add(1); } - segment_id_to_desc_map_[segment_id] = segment_desc; - segment_name_to_id_map_[segment_name] = segment_id; + updateSegmentCacheEntryLocked(segment_id, segment_name, segment_desc); return segment_desc; } std::shared_ptr TransferMetadata::getSegmentDescByID(SegmentID segment_id, bool force_update) { - if (segment_id != LOCAL_SEGMENT_ID && - (!globalConfig().metacache || force_update)) { - // Get segment name without holding lock during network I/O - std::string segment_name; - { - RWSpinlock::ReadGuard guard(segment_lock_); - if (!segment_id_to_desc_map_.count(segment_id)) return nullptr; - segment_name = segment_id_to_desc_map_[segment_id]->name; - } + if (segment_id == LOCAL_SEGMENT_ID) { + RWSpinlock::ReadGuard guard(segment_lock_); + if (!segment_id_to_desc_map_.count(segment_id)) return nullptr; + return segment_id_to_desc_map_[segment_id]; + } - // Fetch segment descriptor without holding lock (may involve network - // I/O) - auto segment_desc = getSegmentDesc(segment_name); - if (!segment_desc) return nullptr; + std::string segment_name; + SegmentCacheLookup cached; - // Update cache with write lock - RWSpinlock::WriteGuard guard(segment_lock_); - segment_id_to_desc_map_[segment_id] = segment_desc; - return segment_id_to_desc_map_[segment_id]; + if (globalConfig().metacache && !force_update) { + cached = lookupSegmentCacheByID(segment_id); + if (!cached.desc) return nullptr; + if (cached.fresh) return cached.desc; + if (!cached.refresh_owner) return cached.desc; + segment_name = cached.segment_name; } else { + // Get segment name without holding lock during network I/O. RWSpinlock::ReadGuard guard(segment_lock_); if (!segment_id_to_desc_map_.count(segment_id)) return nullptr; - return segment_id_to_desc_map_[segment_id]; + segment_name = segment_id_to_desc_map_[segment_id]->name; + } + + // Fetch segment descriptor without holding lock (may involve network I/O). + auto segment_desc = getSegmentDesc(segment_name); + { + RWSpinlock::WriteGuard guard(segment_lock_); + if (cached.refresh_owner) segment_refreshing_ids_.erase(segment_id); + if (segment_desc) { + updateSegmentCacheEntryLocked(segment_id, segment_name, + segment_desc); + return segment_id_to_desc_map_[segment_id]; + } } + + if (globalConfig().metacache && !force_update) return cached.desc; + return nullptr; } TransferMetadata::SegmentID TransferMetadata::getSegmentID( @@ -1061,21 +1153,107 @@ TransferMetadata::SegmentID TransferMetadata::getSegmentID( if (segment_name_to_id_map_.count(segment_name)) return segment_name_to_id_map_[segment_name]; SegmentID id = next_segment_id_.fetch_add(1); - segment_id_to_desc_map_[id] = segment_desc; - segment_name_to_id_map_[segment_name] = id; + updateSegmentCacheEntryLocked(id, segment_name, segment_desc); return id; } +void TransferMetadata::updateSegmentCacheEntryLocked( + SegmentID segment_id, const std::string &segment_name, + const std::shared_ptr &desc) { + auto old_iter = segment_id_to_desc_map_.find(segment_id); + if (segment_id != LOCAL_SEGMENT_ID && + old_iter != segment_id_to_desc_map_.end() && old_iter->second && desc && + old_iter->second->metadata_version != desc->metadata_version) { + segment_metadata_version_change_count_.fetch_add( + 1, std::memory_order_relaxed); + LOG(INFO) << "Segment metadata version changed in cache: id=" + << segment_id << " name=" << segment_name + << " old_name=" << old_iter->second->name + << " old_version=" + << old_iter->second->metadata_version + << " new_version=" << desc->metadata_version; + } + + segment_id_to_desc_map_[segment_id] = desc; + segment_name_to_id_map_[segment_name] = segment_id; + segment_cache_update_ns_map_[segment_id] = getCurrentTimeInNano(); +} + +bool TransferMetadata::isSegmentCacheFreshLocked(SegmentID segment_id) const { + if (segment_id == LOCAL_SEGMENT_ID) return true; + if (!globalConfig().metacache) return false; + uint64_t ttl_ms = globalConfig().metadata_cache_ttl_ms; + if (ttl_ms == 0) return true; + auto iter = segment_cache_update_ns_map_.find(segment_id); + if (iter == segment_cache_update_ns_map_.end()) return false; + uint64_t now_ns = getCurrentTimeInNano(); + uint64_t ttl_ns = ttl_ms * 1000ULL * 1000ULL; + return now_ns >= iter->second && now_ns - iter->second <= ttl_ns; +} + +std::shared_ptr +TransferMetadata::getLocalSegmentDescByName(const std::string &segment_name) { + RWSpinlock::ReadGuard guard(segment_lock_); + auto iter = segment_name_to_id_map_.find(segment_name); + if (iter == segment_name_to_id_map_.end() || + iter->second != LOCAL_SEGMENT_ID) { + return nullptr; + } + return segment_id_to_desc_map_[LOCAL_SEGMENT_ID]; +} + +TransferMetadata::SegmentCacheLookup +TransferMetadata::lookupSegmentCacheByName(const std::string &segment_name) { + SegmentCacheLookup lookup; + RWSpinlock::WriteGuard guard(segment_lock_); + auto iter = segment_name_to_id_map_.find(segment_name); + if (iter == segment_name_to_id_map_.end()) return lookup; + + lookup.segment_id = iter->second; + lookup.segment_name = segment_name; + auto desc_iter = segment_id_to_desc_map_.find(lookup.segment_id); + if (desc_iter == segment_id_to_desc_map_.end() || !desc_iter->second) { + return lookup; + } + lookup.desc = desc_iter->second; + lookup.fresh = isSegmentCacheFreshLocked(lookup.segment_id); + if (!lookup.fresh) { + lookup.refresh_owner = + segment_refreshing_ids_.insert(lookup.segment_id).second; + } + return lookup; +} + +TransferMetadata::SegmentCacheLookup +TransferMetadata::lookupSegmentCacheByID(SegmentID segment_id) { + SegmentCacheLookup lookup; + RWSpinlock::WriteGuard guard(segment_lock_); + auto iter = segment_id_to_desc_map_.find(segment_id); + if (iter == segment_id_to_desc_map_.end()) return lookup; + + lookup.segment_id = segment_id; + lookup.desc = iter->second; + if (!lookup.desc) return lookup; + lookup.segment_name = lookup.desc->name; + lookup.fresh = isSegmentCacheFreshLocked(segment_id); + if (!lookup.fresh) { + lookup.refresh_owner = segment_refreshing_ids_.insert(segment_id).second; + } + return lookup; +} + int TransferMetadata::updateLocalSegmentDesc(uint64_t segment_id) { std::shared_ptr desc; { - RWSpinlock::ReadGuard guard(segment_lock_); + RWSpinlock::WriteGuard guard(segment_lock_); auto it = segment_id_to_desc_map_.find(segment_id); if (it == segment_id_to_desc_map_.end() || !it->second) { LOG(ERROR) << "Segment descriptor " << segment_id << " not found"; return ERR_METADATA; } - desc = it->second; + desc = std::make_shared(*it->second); + bumpSegmentMetadataVersion(*desc); + it->second = desc; } return this->updateSegmentDesc(desc->name, *desc); } @@ -1084,8 +1262,11 @@ int TransferMetadata::addLocalSegment(SegmentID segment_id, const std::string &segment_name, std::shared_ptr &&desc) { RWSpinlock::WriteGuard guard(segment_lock_); + if (desc) normalizeSegmentDescVersions(*desc); segment_id_to_desc_map_[segment_id] = desc; segment_name_to_id_map_[segment_name] = segment_id; + segment_cache_update_ns_map_[segment_id] = getCurrentTimeInNano(); + segment_refreshing_ids_.erase(segment_id); return 0; } @@ -1095,6 +1276,8 @@ int TransferMetadata::removeLocalSegment(const std::string &segment_name) { int segment_id = segment_name_to_id_map_[segment_name]; segment_name_to_id_map_.erase(segment_name); segment_id_to_desc_map_.erase(segment_id); + segment_cache_update_ns_map_.erase(segment_id); + segment_refreshing_ids_.erase(segment_id); } return 0; } @@ -1107,7 +1290,10 @@ int TransferMetadata::addLocalMemoryBuffer(const BufferDesc &buffer_desc, auto &segment_desc = segment_id_to_desc_map_[LOCAL_SEGMENT_ID]; *new_segment_desc = *segment_desc; segment_desc = new_segment_desc; - segment_desc->buffers.push_back(buffer_desc); + BufferDesc updated_buffer = buffer_desc; + if (updated_buffer.state.empty()) + updated_buffer.state = BufferDesc::STATE_READY; + segment_desc->buffers.push_back(updated_buffer); } if (update_metadata) return updateLocalSegmentDesc(); return 0; @@ -1116,6 +1302,8 @@ int TransferMetadata::addLocalMemoryBuffer(const BufferDesc &buffer_desc, int TransferMetadata::removeLocalMemoryBuffer(void *addr, bool update_metadata) { bool addr_exist = false; + std::shared_ptr draining_desc; + std::shared_ptr removed_desc; { RWSpinlock::WriteGuard guard(segment_lock_); auto new_segment_desc = std::make_shared(); @@ -1130,17 +1318,56 @@ int TransferMetadata::removeLocalMemoryBuffer(void *addr, (iter->offset + segment_desc->cxl_base_addr) == (uint64_t)addr #endif ) { - segment_desc->buffers.erase(iter); + if (update_metadata) { + iter->state = BufferDesc::STATE_DRAINING; + bumpSegmentMetadataVersion(*segment_desc); + draining_desc = + std::make_shared(*segment_desc); + } addr_exist = true; break; } } } - if (addr_exist) { - if (update_metadata) return updateLocalSegmentDesc(); - return 0; + if (!addr_exist) { + return ERR_ADDRESS_NOT_REGISTERED; } - return ERR_ADDRESS_NOT_REGISTERED; + + if (draining_desc) { + int ret = updateSegmentDesc(draining_desc->name, *draining_desc); + if (ret) return ret; + uint64_t grace_ms = metadataDeregisterGraceMs(); + if (grace_ms) { + std::this_thread::sleep_for(std::chrono::milliseconds(grace_ms)); + } + } + + { + RWSpinlock::WriteGuard guard(segment_lock_); + auto new_segment_desc = std::make_shared(); + auto &segment_desc = segment_id_to_desc_map_[LOCAL_SEGMENT_ID]; + *new_segment_desc = *segment_desc; + segment_desc = new_segment_desc; + for (auto iter = segment_desc->buffers.begin(); + iter != segment_desc->buffers.end(); ++iter) { + if (iter->addr == (uint64_t)addr +#ifdef USE_CXL + || + (iter->offset + segment_desc->cxl_base_addr) == (uint64_t)addr +#endif + ) { + segment_desc->buffers.erase(iter); + if (update_metadata) bumpSegmentMetadataVersion(*segment_desc); + removed_desc = std::make_shared(*segment_desc); + break; + } + } + } + + if (update_metadata && removed_desc) { + return updateSegmentDesc(removed_desc->name, *removed_desc); + } + return 0; } int TransferMetadata::addRpcMetaEntry(const std::string &server_name, diff --git a/mooncake-transfer-engine/src/transfer_metadata_dump.cpp b/mooncake-transfer-engine/src/transfer_metadata_dump.cpp index 093cd50284..2eac8ad4ae 100644 --- a/mooncake-transfer-engine/src/transfer_metadata_dump.cpp +++ b/mooncake-transfer-engine/src/transfer_metadata_dump.cpp @@ -22,6 +22,7 @@ void TransferMetadata::SegmentDesc::dump() const { LOG(INFO) << " rdma server name: " << rdma_server_name; } LOG(INFO) << " protocol: " << protocol; + LOG(INFO) << " metadata version: " << metadata_version; LOG(INFO) << " topology: " << topology.toString(); LOG(INFO) << " devices: "; for (auto& device : devices) { @@ -32,7 +33,9 @@ void TransferMetadata::SegmentDesc::dump() const { for (auto& buffer : buffers) { LOG(INFO) << " buffer type " << buffer.name << ", address " << (void*)buffer.addr << "--" - << (void*)(buffer.addr + buffer.length); + << (void*)(buffer.addr + buffer.length) + << ", state " + << (buffer.state.empty() ? "READY" : buffer.state); } LOG(INFO) << " nvmeof buffers: " << nvmeof_buffers.size() << " items"; LOG(INFO) << " timestamp: " << timestamp; @@ -67,6 +70,8 @@ void TransferMetadata::dumpMetadataContent(const std::string& segment_name, void TransferMetadata::dumpMetadataContentUnlocked() { LOG(INFO) << "-----------------------------------------------------------"; LOG(INFO) << "TransferMetadata::dumpMetadataContent"; + LOG(INFO) << "remote metadata version changes observed: " + << segmentMetadataVersionChangeCount(); LOG(INFO) << "-----------------------------------------------------------"; LOG(INFO) << "=== Cached Segment Descriptors ==="; for (auto& entry : segment_id_to_desc_map_) { @@ -89,4 +94,4 @@ void TransferMetadata::dumpMetadataContentUnlocked() { << entry.second.rpc_port; } } -} // namespace mooncake \ No newline at end of file +} // namespace mooncake diff --git a/mooncake-transfer-engine/src/transport/rdma_transport/endpoint_store.cpp b/mooncake-transfer-engine/src/transport/rdma_transport/endpoint_store.cpp index f6fcfdace9..9fc47179c2 100644 --- a/mooncake-transfer-engine/src/transport/rdma_transport/endpoint_store.cpp +++ b/mooncake-transfer-engine/src/transport/rdma_transport/endpoint_store.cpp @@ -149,9 +149,15 @@ int FIFOEndpointStore::destroyQPs() { } int FIFOEndpointStore::disconnectQPs() { + RWSpinlock::WriteGuard guard(endpoint_map_lock_); for (auto &kv : endpoint_map_) { - kv.second->disconnect(); + kv.second->beginDestroy(); + waiting_list_.insert(kv.second); } + waiting_list_len_ += endpoint_map_.size(); + endpoint_map_.clear(); + fifo_list_.clear(); + fifo_map_.clear(); return 0; } @@ -318,8 +324,16 @@ int SIEVEEndpointStore::destroyQPs() { } int SIEVEEndpointStore::disconnectQPs() { - for (auto &endpoint : waiting_list_) endpoint->disconnect(); - for (auto &kv : endpoint_map_) kv.second.first->disconnect(); + RWSpinlock::WriteGuard guard(endpoint_map_lock_); + for (auto &kv : endpoint_map_) { + kv.second.first->beginDestroy(); + waiting_list_.insert(kv.second.first); + } + waiting_list_len_ += endpoint_map_.size(); + endpoint_map_.clear(); + fifo_list_.clear(); + fifo_map_.clear(); + hand_ = std::nullopt; return 0; } diff --git a/mooncake-transfer-engine/src/transport/rdma_transport/rdma_endpoint.cpp b/mooncake-transfer-engine/src/transport/rdma_transport/rdma_endpoint.cpp index a374a35879..293227acdb 100644 --- a/mooncake-transfer-engine/src/transport/rdma_transport/rdma_endpoint.cpp +++ b/mooncake-transfer-engine/src/transport/rdma_transport/rdma_endpoint.cpp @@ -67,6 +67,7 @@ static void rememberAutoGidSelection( RdmaEndPoint::RdmaEndPoint(RdmaContext &context) : context_(context), status_(INITIALIZING), + has_connected_(false), wr_depth_list_(nullptr), active_(true), cq_outstanding_(nullptr) {} @@ -202,6 +203,10 @@ int RdmaEndPoint::destroyQP() { return deconstruct(); } void RdmaEndPoint::beginDestroy() { RWSpinlock::WriteGuard guard(lock_); + beginDestroyLocked(); +} + +void RdmaEndPoint::beginDestroyLocked() { auto current_status = status_.load(std::memory_order_relaxed); if (current_status == DESTROYING || current_status == DESTROYED) return; @@ -285,9 +290,12 @@ bool RdmaEndPoint::finishDestroy() { void RdmaEndPoint::setPeerNicPath(const std::string &peer_nic_path) { RWSpinlock::WriteGuard guard(lock_); - if (connected()) { - LOG(WARNING) << "Previous connection will be discarded"; - disconnectUnlocked(); + auto curr_status = status_.load(std::memory_order_relaxed); + if (curr_status != INITIALIZING && curr_status != UNCONNECTED) { + LOG(ERROR) << "Cannot change peer NIC path after endpoint lifecycle " + "has started: " + << toString(); + return; } peer_nic_path_ = peer_nic_path; } @@ -655,51 +663,60 @@ int RdmaEndPoint::disconnectUnlocked() { auto curr_status = status_.load(std::memory_order_acquire); if (curr_status != CONNECTED && curr_status != CONNECTING) return 0; - ibv_qp_attr attr; - memset(&attr, 0, sizeof(attr)); - attr.qp_state = IBV_QPS_RESET; - int ret = 0; - for (size_t i = 0; i < qp_list_.size(); ++i) { - int curr_ret = ibv_modify_qp(qp_list_[i], &attr, IBV_QP_STATE); - if (curr_ret) { - PLOG(ERROR) << "Failed to modify QP to RESET"; - ret = ERR_ENDPOINT; + if (!has_connected_) { + // Pre-connected handshake retries are allowed to reuse this endpoint: + // no user WR has been posted yet. eRDMA still needs fresh QPs because + // a QP that reached RTS cannot be reliably reset back to RTS. +#ifdef CONFIG_ERDMA + for (size_t i = 0; i < qp_list_.size(); ++i) { + CHECK_EQ(wr_depth_list_[i], 0) + << "Pre-connected endpoint must not have outstanding WRs"; } - // After resetting QP, the wr_depth_list_ won't change - bool displayed = false; - if (wr_depth_list_[i] != 0) { - if (!displayed) { - LOG(WARNING) << "Outstanding work requests found, CQ will not " - "be generated"; - displayed = true; + return reconstruct(); +#else + ibv_qp_attr attr; + memset(&attr, 0, sizeof(attr)); + attr.qp_state = IBV_QPS_RESET; + int ret = 0; + for (size_t i = 0; i < qp_list_.size(); ++i) { + int curr_ret = ibv_modify_qp(qp_list_[i], &attr, IBV_QP_STATE); + if (curr_ret) { + PLOG(ERROR) << "Failed to modify pre-connected QP to RESET"; + ret = ERR_ENDPOINT; } - __sync_fetch_and_sub(cq_outstanding_, wr_depth_list_[i]); - wr_depth_list_[i] = 0; + CHECK_EQ(wr_depth_list_[i], 0) + << "Pre-connected endpoint must not have outstanding WRs"; } + peer_qp_num_list_.clear(); + status_.store(UNCONNECTED, std::memory_order_release); + return ret; +#endif } - peer_qp_num_list_.clear(); - status_.store(UNCONNECTED, std::memory_order_release); - return ret; + + beginDestroyLocked(); + return 0; } int RdmaEndPoint::resetConnection(const std::string &reason) { auto curr_status = status_.load(std::memory_order_acquire); if (curr_status != CONNECTING && curr_status != CONNECTED) return 0; -#ifdef CONFIG_ERDMA - int ret = reconstruct(); -#else - int ret = disconnectUnlocked(); -#endif - - if (ret) { - LOG(ERROR) << "Failed to reset the endpoint (triggered by: " << reason - << "): error=" << ret; - } else { - LOG(INFO) << "Successfully reset the endpoint (triggered by: " << reason - << ")."; + if (!has_connected_) { + int ret = disconnectUnlocked(); + if (ret) { + LOG(ERROR) << "Failed to reset pre-connected endpoint " + << "(triggered by: " << reason << "): error=" << ret; + } else { + LOG(INFO) << "Successfully reset pre-connected endpoint " + << "(triggered by: " << reason << ")."; + } + return ret; } - return ret; + + LOG(WARNING) << "Retiring endpoint instead of resetting it (triggered by: " + << reason << "): " << toString(); + beginDestroyLocked(); + return ERR_ENDPOINT; } const std::string RdmaEndPoint::toString() const { @@ -896,6 +913,7 @@ int RdmaEndPoint::doSetupConnection(const std::string &peer_gid, } peer_qp_num_list_ = std::move(peer_qp_num_list); + has_connected_ = true; status_.store(CONNECTED, std::memory_order_relaxed); return 0; } diff --git a/mooncake-transfer-engine/src/transport/rdma_transport/rdma_transport.cpp b/mooncake-transfer-engine/src/transport/rdma_transport/rdma_transport.cpp index 6273a26218..a199aba2ab 100644 --- a/mooncake-transfer-engine/src/transport/rdma_transport/rdma_transport.cpp +++ b/mooncake-transfer-engine/src/transport/rdma_transport/rdma_transport.cpp @@ -702,7 +702,9 @@ int RdmaTransport::onSetupRdmaConnections(const HandShakeDesc &peer_desc, // Use existing endpoint or create new one. auto endpoint = context->endpoint(peer_desc.local_nic_path); if (!endpoint) return ERR_ENDPOINT; - return endpoint->setupConnectionsByPassive(peer_desc, local_desc); + int ret = endpoint->setupConnectionsByPassive(peer_desc, local_desc); + if (endpoint->retired()) context->deleteEndpointByPtr(endpoint.get()); + return ret; } int RdmaTransport::initializeRdmaResources() { @@ -747,6 +749,10 @@ int RdmaTransport::selectDevice(SegmentDesc *desc, uint64_t offset, for (buffer_id = 0; buffer_id < static_cast(buffers.size()); ++buffer_id) { const auto &buffer = buffers[buffer_id]; + if (!buffer.state.empty() && + buffer.state != TransferMetadata::BufferDesc::STATE_READY) { + continue; + } // Check if offset is within buffer range if (offset < buffer.addr || length > buffer.length || diff --git a/mooncake-transfer-engine/src/transport/rdma_transport/worker_pool.cpp b/mooncake-transfer-engine/src/transport/rdma_transport/worker_pool.cpp index 9ef0dd1282..577a894f49 100644 --- a/mooncake-transfer-engine/src/transport/rdma_transport/worker_pool.cpp +++ b/mooncake-transfer-engine/src/transport/rdma_transport/worker_pool.cpp @@ -35,7 +35,7 @@ WorkerPool::WorkerPool(RdmaContext &context, int numa_socket_id) : context_(context), numa_socket_id_(numa_socket_id), workers_running_(true), - suspended_flag_(0), + parked_worker_count_(0), redispatch_counter_(0), submitted_slice_count_(0), processed_slice_count_(0) { @@ -82,6 +82,8 @@ int WorkerPool::submitPostSend( << target_id; return ERR_INVALID_ARGUMENT; } + auto &peer_segment_desc = segment_desc_map[target_id]; + recordPeerMetadataVersion(target_id, *peer_segment_desc); } } #else @@ -89,9 +91,18 @@ int WorkerPool::submitPostSend( segment_desc_map; for (auto &slice : slice_list) { auto target_id = slice->target_id; - if (!segment_desc_map.count(target_id)) + if (!segment_desc_map.count(target_id)) { segment_desc_map[target_id] = context_.engine().meta()->getSegmentDescByID(target_id); + if (!segment_desc_map[target_id]) { + segment_desc_map.clear(); + LOG(ERROR) << "Cannot get target segment description #" + << target_id; + return ERR_INVALID_ARGUMENT; + } + auto &peer_segment_desc = segment_desc_map[target_id]; + recordPeerMetadataVersion(target_id, *peer_segment_desc); + } } #endif // CONFIG_CACHE_SEGMENT_DESC @@ -126,6 +137,7 @@ int WorkerPool::submitPostSend( failed_target_ids[slice->target_id] = getCurrentTimeInNano(); continue; } + recordPeerMetadataVersion(slice->target_id, *peer_segment_desc); if (RdmaTransport::selectDevice( peer_segment_desc.get(), slice->rdma.dest_addr, @@ -154,7 +166,7 @@ int WorkerPool::submitPostSend( alt_dev_id < peer_segment_desc->devices.size(); ++alt_dev_id) { if (alt_dev_id == (size_t)device_id) continue; auto alt_path = - MakeNicPath(peer_segment_desc->name, + MakeNicPath(peer_segment_desc->nicPathServerName(), peer_segment_desc->devices[alt_dev_id].name); if (isRailAvailable(alt_path)) { device_id = alt_dev_id; @@ -189,7 +201,8 @@ int WorkerPool::submitPostSend( } submitted_slice_count_.fetch_add(submitted_slice_count); - if (suspended_flag_.load()) { + if (submitted_slice_count && + parked_worker_count_.load(std::memory_order_acquire) > 0) { std::lock_guard lock(cond_mutex_); cond_var_.notify_all(); } @@ -208,7 +221,6 @@ int WorkerPool::submitPostSend( void WorkerPool::performPostSend(int thread_id) { // Fast-fail if context is unhealthy due to catastrophic hardware failure if (!contextHealthy()) { - auto &local_slice_queue = collective_slice_queue_[thread_id]; for (int shard_id = thread_id; shard_id < kShardCount; shard_id += kTransferWorkerCount) { if (slice_queue_count_[shard_id].load(std::memory_order_relaxed) == @@ -419,6 +431,10 @@ void WorkerPool::redispatch(std::vector &slice_list, processed_slice_count_++; } else { auto &peer_segment_desc = segment_desc_map[slice->target_id]; + if (peer_segment_desc) { + recordPeerMetadataVersion(slice->target_id, + *peer_segment_desc); + } int buffer_id, device_id; if (!peer_segment_desc || RdmaTransport::selectDevice(peer_segment_desc.get(), @@ -440,6 +456,14 @@ void WorkerPool::redispatch(std::vector &slice_list, } } +bool WorkerPool::hasOutstandingCq(int thread_id) { + for (int cq_index = thread_id; cq_index < context_.cqCount(); + cq_index += kTransferWorkerCount) { + if (*context_.cqOutstandingCount(cq_index) > 0) return true; + } + return false; +} + void WorkerPool::transferWorker(int thread_id) { bindToSocket(numa_socket_id_); const static uint64_t kWaitPeriodInNano = 100000000; // 100ms @@ -449,18 +473,21 @@ void WorkerPool::transferWorker(int thread_id) { processed_slice_count_.load(std::memory_order_relaxed); auto submitted_slice_count = submitted_slice_count_.load(std::memory_order_relaxed); - if (processed_slice_count == submitted_slice_count) { + if (processed_slice_count == submitted_slice_count && + !hasOutstandingCq(thread_id)) { uint64_t curr_wait_ts = getCurrentTimeInNano(); if (curr_wait_ts - last_wait_ts > kWaitPeriodInNano) { std::unique_lock lock(cond_mutex_); - suspended_flag_.fetch_add(1); + parked_worker_count_.fetch_add(1, std::memory_order_acq_rel); // Double-check condition after acquiring lock to avoid lost - // wakeup + // wakeup. parked_worker_count_ is set before this check so + // producers that submit after it will notify this worker. if (processed_slice_count_.load(std::memory_order_relaxed) == - submitted_slice_count_.load()) { + submitted_slice_count_.load() && + !hasOutstandingCq(thread_id)) { cond_var_.wait_for(lock, std::chrono::seconds(1)); } - suspended_flag_.fetch_sub(1); + parked_worker_count_.fetch_sub(1, std::memory_order_acq_rel); last_wait_ts = curr_wait_ts; } continue; @@ -603,6 +630,63 @@ bool WorkerPool::isRailAvailable(const std::string &peer_nic_path) { return false; } +void WorkerPool::clearRailState( + const std::vector &peer_nic_paths) { + std::lock_guard lock(rail_state_lock_); + for (const auto &peer_nic_path : peer_nic_paths) { + rail_states_.erase(peer_nic_path); + } +} + +std::vector WorkerPool::buildPeerNicPaths( + const Transport::SegmentDesc &desc) const { + std::vector peer_nic_paths; + peer_nic_paths.reserve(desc.devices.size()); + for (const auto &device : desc.devices) { + peer_nic_paths.push_back( + MakeNicPath(desc.nicPathServerName(), device.name)); + } + return peer_nic_paths; +} + +void WorkerPool::recordPeerMetadataVersion( + SegmentID segment_id, const Transport::SegmentDesc &desc) { + auto new_paths = buildPeerNicPaths(desc); + std::vector stale_paths; + bool metadata_version_changed = false; + + { + std::lock_guard lock(target_metadata_lock_); + auto &state = target_metadata_[segment_id]; + if (!state.initialized) { + state.initialized = true; + state.metadata_version = desc.metadata_version; + state.peer_nic_paths = std::move(new_paths); + return; + } + + metadata_version_changed = + state.metadata_version != desc.metadata_version; + if (metadata_version_changed) { + stale_paths = state.peer_nic_paths; + LOG(INFO) << "Peer segment metadata version changed: segment_id=" + << segment_id << " name=" << desc.name + << " old_version=" << state.metadata_version + << " new_version=" << desc.metadata_version + << ", invalidating old RDMA endpoints"; + } + state.metadata_version = desc.metadata_version; + state.peer_nic_paths = std::move(new_paths); + } + + if (!metadata_version_changed) return; + + for (const auto &peer_nic_path : stale_paths) { + context_.deleteEndpoint(peer_nic_path); + } + clearRailState(stale_paths); +} + // Unified retry logic: increment retry count and return whether retry is // allowed bool WorkerPool::shouldRetrySlice(Transport::Slice *slice) { diff --git a/mooncake-transfer-engine/tests/transfer_metadata_test.cpp b/mooncake-transfer-engine/tests/transfer_metadata_test.cpp index d8fd0d889c..f9af4d7d6e 100644 --- a/mooncake-transfer-engine/tests/transfer_metadata_test.cpp +++ b/mooncake-transfer-engine/tests/transfer_metadata_test.cpp @@ -57,6 +57,29 @@ class TransferMetadataTest : public ::testing::Test { std::unique_ptr metadata_client; std::string metadata_server; std::string local_server_name; + + static TransferMetadata::SegmentDesc MakeRdmaSegment( + const std::string& name) { + TransferMetadata::SegmentDesc desc; + desc.name = name; + desc.protocol = "rdma"; + desc.tcp_data_port = 1234; + + TransferMetadata::DeviceDesc device; + device.name = "mlx5_0"; + device.lid = 1; + device.gid = "0000:0000:0000:0000:0000:ffff:7f00:0001"; + desc.devices.push_back(device); + + TransferMetadata::BufferDesc buffer; + buffer.name = "buffer-0"; + buffer.addr = 0x100000; + buffer.length = 4096; + buffer.lkey.push_back(11); + buffer.rkey.push_back(22); + desc.buffers.push_back(buffer); + return desc; + } }; // add and search LocalSegmentMeta @@ -71,6 +94,7 @@ TEST_F(TransferMetadataTest, LocalSegmentTest) { ASSERT_EQ(re, 0); auto des = metadata_client->getSegmentDescByName(segment_name); ASSERT_EQ(des, segment_des); + ASSERT_EQ(des->metadata_version, 1); des = metadata_client->getSegmentDescByID(segment_id, false); ASSERT_EQ(des, segment_des); auto id = metadata_client->getSegmentID(segment_name); @@ -79,6 +103,25 @@ TEST_F(TransferMetadataTest, LocalSegmentTest) { ASSERT_EQ(re, 0); } +TEST_F(TransferMetadataTest, LocalSegmentPreservesExplicitVersion) { + auto segment_des = std::make_shared( + MakeRdmaSegment("explicit_versions")); + segment_des->metadata_version = 8; + segment_des->buffers[0].state = + TransferMetadata::BufferDesc::STATE_DRAINING; + + ASSERT_EQ(metadata_client->addLocalSegment(2222222, "explicit_segment", + std::move(segment_des)), + 0); + + auto desc = metadata_client->getSegmentDescByID(2222222); + ASSERT_TRUE(desc); + ASSERT_EQ(desc->metadata_version, 8); + ASSERT_EQ(desc->buffers.size(), 1); + ASSERT_EQ(desc->buffers[0].state, + TransferMetadata::BufferDesc::STATE_DRAINING); +} + // add and remove LocalMemoryBufferMeta TEST_F(TransferMetadataTest, LocalMemoryBufferTest) { auto segment_des = std::make_shared(); @@ -87,6 +130,9 @@ TEST_F(TransferMetadataTest, LocalMemoryBufferTest) { int re = metadata_client->addLocalSegment( LOCAL_SEGMENT_ID, "test_local_segment", std::move(segment_des)); ASSERT_EQ(re, 0); + auto local_desc = metadata_client->getSegmentDescByID(LOCAL_SEGMENT_ID); + ASSERT_TRUE(local_desc); + auto metadata_version = local_desc->metadata_version; uint64_t addr = 0; for (int i = 0; i < 10; ++i) { TransferMetadata::BufferDesc buffer_des; @@ -95,20 +141,126 @@ TEST_F(TransferMetadataTest, LocalMemoryBufferTest) { re = metadata_client->addLocalMemoryBuffer(buffer_des, false); ASSERT_EQ(re, 0); } + local_desc = metadata_client->getSegmentDescByID(LOCAL_SEGMENT_ID); + ASSERT_TRUE(local_desc); + ASSERT_EQ(local_desc->metadata_version, metadata_version); + ASSERT_EQ(local_desc->buffers.size(), 10); + for (const auto& buffer : local_desc->buffers) { + ASSERT_EQ(buffer.state, TransferMetadata::BufferDesc::STATE_READY); + } + ASSERT_EQ(metadata_client->updateLocalSegmentDesc(), 0); + local_desc = metadata_client->getSegmentDescByID(LOCAL_SEGMENT_ID); + ASSERT_TRUE(local_desc); + ASSERT_GT(local_desc->metadata_version, metadata_version); addr = 1000; re = metadata_client->removeLocalMemoryBuffer((void*)addr, false); ASSERT_EQ(re, ERR_ADDRESS_NOT_REGISTERED); + auto before_remove_metadata_version = local_desc->metadata_version; for (int i = 9; i > 0; --i) { addr = i * 2048; re = metadata_client->removeLocalMemoryBuffer((void*)addr, false); ASSERT_EQ(re, 0); } + local_desc = metadata_client->getSegmentDescByID(LOCAL_SEGMENT_ID); + ASSERT_TRUE(local_desc); + ASSERT_EQ(local_desc->metadata_version, before_remove_metadata_version); + ASSERT_EQ(local_desc->buffers.size(), 1); re = metadata_client->removeLocalSegment("test_local_segment"); ASSERT_EQ(re, 0); } +TEST_F(TransferMetadataTest, LocalMemoryBufferDeregisterWithMetadata) { + auto segment_des = std::make_shared(); + segment_des->name = "test_metadata_deregister"; + segment_des->protocol = "rdma"; + ASSERT_EQ(metadata_client->addLocalSegment( + LOCAL_SEGMENT_ID, "test_metadata_deregister", + std::move(segment_des)), + 0); + + TransferMetadata::BufferDesc buffer_des; + buffer_des.addr = 4096; + buffer_des.length = 1024; + ASSERT_EQ(metadata_client->addLocalMemoryBuffer(buffer_des, false), 0); + + auto local_desc = metadata_client->getSegmentDescByID(LOCAL_SEGMENT_ID); + ASSERT_TRUE(local_desc); + auto metadata_version = local_desc->metadata_version; + + ASSERT_EQ(metadata_client->removeLocalMemoryBuffer((void*)4096, true), 0); + local_desc = metadata_client->getSegmentDescByID(LOCAL_SEGMENT_ID); + ASSERT_TRUE(local_desc); + ASSERT_GT(local_desc->metadata_version, metadata_version); + ASSERT_TRUE(local_desc->buffers.empty()); +} + +TEST_F(TransferMetadataTest, SegmentDescJsonRoundTripIncludesReliabilityFields) { + auto desc = MakeRdmaSegment("json_round_trip"); + desc.metadata_version = 42; + desc.buffers[0].state = TransferMetadata::BufferDesc::STATE_READY; + + Json::Value json; + ASSERT_EQ(metadata_client->encodeSegmentDesc(desc, json), 0); + ASSERT_EQ(json["metadata_version"].asUInt64(), desc.metadata_version); + ASSERT_EQ(json["buffers"][0]["state"].asString(), desc.buffers[0].state); + + auto decoded = metadata_client->decodeSegmentDesc(json, desc.name); + ASSERT_TRUE(decoded); + ASSERT_EQ(decoded->metadata_version, desc.metadata_version); + ASSERT_EQ(decoded->buffers.size(), 1); + ASSERT_EQ(decoded->buffers[0].state, desc.buffers[0].state); +} + +TEST_F(TransferMetadataTest, LegacySegmentDescJsonDecodesAsReady) { + auto desc = MakeRdmaSegment("legacy_json"); + Json::Value json; + ASSERT_EQ(metadata_client->encodeSegmentDesc(desc, json), 0); + json.removeMember("metadata_version"); + json["buffers"][0].removeMember("state"); + + auto decoded = metadata_client->decodeSegmentDesc(json, desc.name); + ASSERT_TRUE(decoded); + ASSERT_EQ(decoded->metadata_version, 0); + ASSERT_EQ(decoded->buffers.size(), 1); + ASSERT_EQ(decoded->buffers[0].state, + TransferMetadata::BufferDesc::STATE_READY); +} + +TEST_F(TransferMetadataTest, SegmentCacheTracksMetadataVersionChanges) { + auto desc_v1 = std::make_shared( + MakeRdmaSegment("remote_segment")); + desc_v1->metadata_version = 1; + + metadata_client->updateSegmentCacheEntry(3333333, "remote_segment", + desc_v1); + ASSERT_EQ(metadata_client->segmentMetadataVersionChangeCount(), 0); + + auto desc_v2 = std::make_shared(*desc_v1); + desc_v2->metadata_version = 2; + metadata_client->updateSegmentCacheEntry(3333333, "remote_segment", + desc_v2); + ASSERT_EQ(metadata_client->segmentMetadataVersionChangeCount(), 1); + + auto desc_v3 = std::make_shared(*desc_v2); + desc_v3->metadata_version = 3; + metadata_client->updateSegmentCacheEntry(3333333, "remote_segment", + desc_v3); + ASSERT_EQ(metadata_client->segmentMetadataVersionChangeCount(), 2); +} + +TEST_F(TransferMetadataTest, SegmentCacheIgnoresNullDescriptor) { + metadata_client->updateSegmentCacheEntry(4444444, "null_remote_segment", + nullptr); + + auto desc = metadata_client->getSegmentDescByID(4444444); + ASSERT_FALSE(desc); +} + // add, get and remove RPCMetaEntryMeta TEST_F(TransferMetadataTest, RpcMetaEntryTest) { + if (metadata_server == P2PHANDSHAKE) { + GTEST_SKIP() << "P2P RPC metadata requires a local listening socket"; + } auto hostname_port = parseHostNameWithPort(local_server_name); TransferMetadata::RpcMetaDesc desc; desc.ip_or_host_name = hostname_port.first.c_str(); @@ -128,4 +280,4 @@ TEST_F(TransferMetadataTest, RpcMetaEntryTest) { int main(int argc, char** argv) { ::testing::InitGoogleTest(&argc, argv); return RUN_ALL_TESTS(); -} \ No newline at end of file +}