diff --git a/doc/multi-recv.md b/doc/multi-recv.md index 3910d6feaa..851f8af4af 100644 --- a/doc/multi-recv.md +++ b/doc/multi-recv.md @@ -1,4 +1,4 @@ -# Multi-Recv Support in the RDMA Protocol +# Multi-Recv and Eager Support in the RDMA Protocol ## Overview @@ -14,12 +14,9 @@ N separate buffers (one per remote rank). With `NCCL_NET_SHARED_COMMS=1`, NCCL multiplexes these N sub-channels onto a single network communicator, using tags to distinguish them. -This document describes the multi-recv design and the changes to the RDMA -protocol to support it. - -**Note:** Eager message support for grouped receives is not yet implemented and -will be added in a future change. Currently, eager sends with grouped receives -are disabled. +This document describes the multi-recv design, the eager message extension that +allows small messages to be sent before the receiver posts its receive, and the +ordering constraints that make this work correctly. ## Background: Single Recv Flow @@ -38,6 +35,11 @@ In the baseline RDMA protocol, a single send/recv pair works as follows: 3. **Receiver** gets a write completion with the immediate data, identifies the request, and marks it complete. +For **eager** sends (small messages, ≤ `eager_send_size`), the sender writes the +data into a pre-posted bounce buffer on the receiver *before* the ctrl msg +arrives. The receiver later copies the data from the bounce buffer to the final +destination. + ## Multi-Recv Design ### Control Message Format @@ -56,7 +58,8 @@ struct nccl_net_ofi_ctrl_msg_entry { uint16_t flags; // e.g. recv completion optional uint16_t num_recvs; // N (only in entry[0]) uint8_t recv_idx; // index of this entry (0..N-1) - uint8_t pad[9]; + uint8_t entry_used; // set when consumed by eager or write + uint8_t pad[8]; }; ``` @@ -91,18 +94,158 @@ receiver extracts `recv_idx` from the immediate data and accumulates `cq_entry->len` into `recvs[recv_idx].recv_size`. The `test()` function reports per-sub sizes to NCCL. +## Eager Messages with Multi-Recv + +### The Problem + +Without eager support, every send must wait for the ctrl msg before transmitting +data. This adds a round-trip of latency for small messages. With multi-recv, the +challenge is that the sender doesn't know at eager-send time whether the receiver +will post a single recv or a grouped recv for a given `msg_seq_num`. + +### Eager Message Header + +Each eager message prepends an 8-byte header to the bounce buffer data: + +``` +struct nccl_ofi_eager_msg_header { + uint8_t eager_offset; // position within the eager batch + uint8_t prev_batch_count; // count of previous batch (when offset == 0) + uint16_t prev_msg_seq_num; // seq of previous batch (when offset == 0) + int32_t tag; // NCCL tag for multi-recv routing +}; +``` + +The sender transmits this via `fi_sendmsg` with two iovecs: the header (from a +registered freelist buffer) and the payload (from the user buffer). + +### Sender-Side Eager Queue + +The sender maintains a circular queue of up to `NCCL_OFI_MAX_EAGER_PENDING` (`NCCL_NET_MAX_REQUESTS`) +outstanding eager sends. Key behaviors: + +- **Eager decision**: A send goes eager if there is no ctrl msg, the sender is + not mid-group, `size + 8 ≤ eager_send_size`, the queue is not full, there + are no inflight RDMA writes, and the sender is not in a state where the + queue has undrained entries from a previous batch with `eager_offset_next` + already reset to 0. 
This last condition + (`eager_queue_count == 0 || eager_offset_next > 0`) prevents starting a new + eager batch while the previous batch's entries are still in the queue awaiting + ctrl msg drain. + +- **No seq_num advance**: Eager sends do NOT advance `next_msg_seq_num`. Instead, + `eager_offset_next` increments (0, 1, 2, ...). All eager sends in a batch + share the same `msg_seq_num`. + +- **Drain**: When a ctrl msg arrives (detected in `send()` or `test()`), the + drain function matches queued eager sends against ctrl msg entries: + - **Single recv**: Pop the front entry, mark the send as having received its + ctrl msg, advance `next_msg_seq_num`. + - **Grouped recv**: Rotate the queue, matching by tag. Matched entries are + consumed (`entry_used = 1`). Unmatched entries are pushed back. If all N + sub-recvs are satisfied, advance `next_msg_seq_num`. + +- **Batch boundary tracking**: When `next_msg_seq_num` advances (in the drain + or in the non-eager send path) and `eager_offset_next > 0`, the sender + records `prev_eager_msg_seq_num` and `prev_eager_batch_count` from the + current state, then resets `eager_offset_next` to 0. These values are + stamped into the next batch's `offset == 0` header so the receiver can + verify batch boundaries. The sender initializes `prev_eager_msg_seq_num` + to `0xFFFF` (sentinel) so the receiver can distinguish the very first + eager batch from a later batch that arrives out of order. + +### Receiver-Side Eager Queue + +The receiver maintains a **sorted doubly-linked list** of pending eager messages, +ordered by `(msg_seq_num, eager_offset)`. A pre-allocated pool of +`NCCL_OFI_CTRL_MAILBOX_SIZE` entries avoids dynamic allocation. + +When an eager message arrives (`handle_eager_recv`): +1. Parse the 8-byte header to extract `eager_offset`, `tag`, and batch info. +2. Subtract 8 from `recv_len` (the header is not part of the payload). +3. Insert into the sorted list. +4. Call `drain_recv_eager_queue()`. + +### Ordering Requirements + +**Why ordering matters**: The mapping from `(msg_seq_num, eager_offset)` to a +target recv depends on the recv sequence. Eager offset 0 targets the recv at +`msg_seq_num`. Offset 1 targets the next recv. But a grouped recv consumes +multiple offsets (one per matching tag). Without ordered processing, the receiver +cannot determine which recv an eager message belongs to. + +**Sender ordering**: The sender assigns offsets sequentially (0, 1, 2, ...) and +the drain processes them in FIFO order against ctrl msgs. For grouped recvs, the +drain matches by tag, ensuring each eager send is paired with the correct +sub-receive. + +**Receiver ordering**: The drain processes entries in strict +`(msg_seq_num, eager_offset)` order. Before processing an entry, it verifies +continuity: + +- **First-ever batch** (`has_processed_eager == false`): The entry must have + `eager_offset == 0` and `prev_msg_seq_num == 0xFFFF` (the sentinel value). + This ensures that if a later batch arrives before the first batch (due to + out-of-order delivery), it is not mistakenly processed as the first batch. + +- **offset == 0 (new batch)**: The previous batch must be complete. This is + verified by checking that `last_eager_msg_seq_num == prev_msg_seq_num` and + `last_eager_offset == prev_batch_count - 1`. + +- **offset > 0 (same batch)**: Must be consecutive with the last processed + entry: `last_eager_msg_seq_num == entry.msg_seq_num` and + `last_eager_offset == entry.eager_offset - 1`. 
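+
+To make these checks concrete, here is a minimal sketch of the receiver's
+gating logic (illustrative only: the struct and helper names below are
+stand-ins for the fields described above, not the plugin's actual
+declarations):
+
+```
+struct eager_entry {
+	uint16_t msg_seq_num, prev_msg_seq_num;
+	uint8_t  eager_offset, prev_batch_count;
+};
+
+struct recv_state {
+	bool     has_processed_eager;
+	uint16_t last_eager_msg_seq_num;
+	uint8_t  last_eager_offset;
+};
+
+static bool can_process(const struct recv_state *rs,
+                        const struct eager_entry *e)
+{
+	if (!rs->has_processed_eager) {
+		/* First-ever batch: offset 0 with the 0xFFFF sentinel */
+		return e->eager_offset == 0 && e->prev_msg_seq_num == 0xFFFF;
+	}
+	if (e->eager_offset == 0) {
+		/* New batch: previous batch must be fully processed */
+		return rs->last_eager_msg_seq_num == e->prev_msg_seq_num &&
+		       rs->last_eager_offset == e->prev_batch_count - 1;
+	}
+	/* Same batch: offsets must be consecutive */
+	return rs->last_eager_msg_seq_num == e->msg_seq_num &&
+	       rs->last_eager_offset == e->eager_offset - 1;
+}
+```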
+ +If the check fails (e.g., an earlier offset hasn't arrived yet), the drain stops +and retries later. + +### Target Recv Resolution + +Once an entry passes the continuity check, the drain resolves which recv it +targets using `eager_drain_recv_seq`: + +- Look up the recv at `eager_drain_recv_seq` in the message buffer. +- If the recv completed and was removed (detected via `last_completed_seq`), + advance past it. +- **Single recv**: Eager-copy the data, advance `eager_drain_recv_seq`. +- **Grouped recv**: Match by tag using `eager_match_recv()`. If matched, + eager-copy to the matched sub-recv. If no match, advance `recv_seq` to the + next recv (the eager message belongs to a later recv on this communicator). + +### Eager Copy + +The eager copy reads data from the bounce buffer into the destination buffer +using `fi_read`. The bounce buffer offset is adjusted by `NCCL_OFI_EAGER_HEADER_SIZE` +to skip the header. Each sub-recv has its own `eager_copy_req` to avoid leaking +requests when multiple sub-recvs in a grouped receive are handled by eager. + ## Limitations - **Maximum grouped receives**: `NCCL_OFI_MAX_RECVS = 8` (limited by 3-bit `recv_idx` in immediate data). +- **Maximum outstanding eager sends**: `NCCL_NET_MAX_REQUESTS` (32) per + communicator (`NCCL_OFI_MAX_EAGER_PENDING`). + - **Maximum communicators**: Reduced from 256K to 32K (15-bit `comm_id`) to make room for `recv_idx` in the immediate data. +- **Eager disabled for GPU buffers with 2-iovec sends**: The `fi_sendmsg` with + two iovecs (host header + GPU payload) requires provider support for + scatter-gather across host and device memory. + - **Version gating**: Grouped receives (`maxRecvs > 1`) are only reported for ncclNet v9 and later, where `irecv` uses `size_t` sizes. Earlier versions and the Neuron/sendrecv protocol report `maxRecvs = 1`. -- **Eager sends**: Eager message support for grouped receives is not yet - implemented. When multi-recv is enabled, eager sends continue to work for - single receives (n=1) using the existing eager path. +- **Interleaved eager sends across groups**: When NCCL interleaves `send()` + calls across what will become different grouped receives, the receiver's + eager drain processes entries in strict offset order and cannot skip past + an unresolved entry. If the receiver serializes recv posting (waiting for + recv N to complete before posting recv N+1), this can deadlock. This is + not an issue in practice because NCCL's proxy thread posts recvs + independently without waiting for prior completions. + +- **Eager size overhead**: The 8-byte header reduces the effective eager payload + by 8 bytes. The eager decision accounts for this: + `size + NCCL_OFI_EAGER_HEADER_SIZE ≤ eager_send_size`. diff --git a/include/nccl_ofi_rdma.h b/include/nccl_ofi_rdma.h index bdcd18d5b4..e160554734 100644 --- a/include/nccl_ofi_rdma.h +++ b/include/nccl_ofi_rdma.h @@ -18,6 +18,7 @@ #include "nccl_ofi_idpool.h" #include "nccl_ofi_log.h" #include "nccl_ofi_msgbuff.h" +#include "nccl_ofi_dlist.h" #include "nccl_ofi_scheduler.h" #include "nccl_ofi_topo.h" #include "nccl_ofi_ofiutils.h" @@ -44,6 +45,10 @@ static_assert(MAX_NUM_RAILS <= UINT16_MAX); * @brief Number of bits used for the communicator ID */ #define NCCL_OFI_RDMA_COMM_ID_BITS (15) + +/* + * @brief Number of bits used for the sub-receive index in grouped receives + */ #define NCCL_OFI_RDMA_RECV_IDX_BITS (3) /* Maximum number of comms open simultaneously. 
Eventually this will be @@ -65,9 +70,10 @@ static_assert(MAX_NUM_RAILS <= UINT16_MAX); * communicator ID, and the message sequence number (msg_seq_num). * The data is encoded as follows: * - * | 4-bit segment count | 18-bit comm ID | 10-bit msg_seq_num | + * | 4-bit segment count | 3-bit recv_idx | 15-bit comm ID | 10-bit msg_seq_num | * * - Segment count: number of RDMA writes that will be delivered as part of this message + * - Recv index: sub-receive index within a grouped receive (0 for non-grouped) * - Comm ID: the ID for this communicator * - Message sequence number: message identifier */ @@ -164,6 +170,27 @@ class nccl_net_ofi_rdma_mr_handle_t : public nccl_net_ofi_mr_handle_t { /* Sentinel tag value indicating a ctrl msg entry has been consumed */ #define NCCL_OFI_CTRL_MSG_TAG_INVALID ((int16_t)-1) +/* Maximum number of outstanding eager messages in the sender queue */ +#define NCCL_OFI_MAX_EAGER_PENDING NCCL_NET_MAX_REQUESTS + +/* Size of the eager message header prepended to bounce buffer data */ +#define NCCL_OFI_EAGER_HEADER_SIZE 8 + +/* + * @brief Header prepended to eager message data in the bounce buffer. + * + * The sender prepends this header so the receiver can identify which + * sub-receive (in a grouped receive) the eager data belongs to. + */ +typedef struct nccl_ofi_eager_msg_header { + uint8_t eager_offset; + uint8_t prev_batch_count; /* only meaningful when eager_offset == 0 */ + uint16_t prev_msg_seq_num; /* only meaningful when eager_offset == 0 */ + int32_t tag; +} nccl_ofi_eager_msg_header_t; +static_assert(sizeof(nccl_ofi_eager_msg_header_t) == NCCL_OFI_EAGER_HEADER_SIZE, + "Wrong size for eager message header"); + /* * @brief Control message sub-entry * @@ -194,8 +221,10 @@ typedef struct nccl_net_ofi_ctrl_msg_entry { uint16_t num_recvs; /* Index of this entry within the grouped receive (0..N-1) */ uint8_t recv_idx; + /* Set by sender when this entry has been consumed by either eager or write */ + uint8_t entry_used; /* Padding to 64-byte cache line boundary */ - uint8_t pad[9]; + uint8_t pad[8]; } nccl_net_ofi_ctrl_msg_entry_t; static_assert(sizeof(nccl_net_ofi_ctrl_msg_entry_t) == 64, "Wrong size for RDMA Control message entry"); @@ -325,6 +354,8 @@ typedef struct { typedef struct { /* True for eager messages */ bool eager; + /* True if ctrl msg was received - Only valid if eager=true */ + bool eager_ctrl_msg_received; /* Remote destination buffer offset from base address */ uintptr_t remote_buff_offset; /* Remote buffer length */ @@ -341,8 +372,15 @@ typedef struct { nccl_net_ofi_rdma_mr_handle_t *buff_mr_handle; /* Tag for matching to the correct sub-entry in a grouped receive ctrl msg */ int tag; - /* Sub-receive index within a grouped receive (encoded in immediate data) */ + /* Eager offset within the sender's eager queue (only for eager sends) */ + uint8_t eager_offset; + /* Freelist entry for the eager header buffer (returned on send completion) */ + nccl_ofi_freelist::fl_entry *eager_hdr_fl_entry; + /* Sub-receive index from ctrl msg entry (for RDMA write immediate data) */ uint8_t recv_idx; + /* Previous batch info */ + uint8_t prev_batch_count; + uint16_t prev_msg_seq_num; /* Schedule used to transfer this request. We save the pointer to * reference it when transferring the request over network. 
*/ nccl_net_ofi_schedule_t *schedule; @@ -381,6 +419,8 @@ typedef struct { nccl_net_ofi_rdma_req *eager_rx_buff_req; /* Pointer to recv parent request */ nccl_net_ofi_rdma_req *recv_req; + /* Sub-receive index within grouped receive (0 for single recv) */ + int sub_recv_idx; } rdma_req_eager_copy_data_t; /* @@ -415,6 +455,10 @@ typedef struct rdma_req_recv_sub { int ncompls; /* Total expected segments for this sub-receive */ int total_segms; + /* Flag indicating if this request was consumed */ + bool consumed; + /* (Eager) pointer to eager local copy request for this sub-recv */ + nccl_net_ofi_rdma_req *eager_copy_req; } rdma_req_recv_sub_t; /* @@ -427,9 +471,6 @@ typedef struct { rdma_req_recv_sub_t recvs[NCCL_OFI_MAX_RECVS]; /* Pointer to receive segments child request */ nccl_net_ofi_rdma_req *recv_segms_req; - /* (Eager messages) pointer to eager local copy request. - * Only used when num_recvs == 1 */ - nccl_net_ofi_rdma_req *eager_copy_req; /* Total number of completions. Expect one send ctrl * completion and one completion that indicates that all * segments have arrived. @@ -591,6 +632,15 @@ typedef struct nccl_net_ofi_rdma_send_comm_rail { struct fid_ep *local_ep; } nccl_net_ofi_rdma_send_comm_rail_t; +/* + * @brief Entry in the sender-side eager queue + */ +typedef struct nccl_ofi_eager_queue_entry { + nccl_net_ofi_rdma_req *req; + int tag; + uint8_t eager_offset; +} nccl_ofi_eager_queue_entry_t; + /* * @brief RDMA send communicator * @@ -632,6 +682,22 @@ class nccl_net_ofi_rdma_send_comm : public nccl_net_ofi_send_comm { /* Bitmask of tags used in the current group */ uint32_t group_tag_used; + /* Eager queue: outstanding eager sends awaiting ctrl msg resolution */ + nccl_ofi_eager_queue_entry_t eager_queue[NCCL_OFI_MAX_EAGER_PENDING]; + uint8_t eager_queue_head; + uint8_t eager_queue_tail; + uint8_t eager_queue_count; + /* Next eager offset to assign */ + uint8_t eager_offset_next; + /* Previous batch info (set when drain resets eager_offset_next) */ + uint16_t prev_eager_msg_seq_num; + uint8_t prev_eager_batch_count; + /* Freelist of eager header buffers */ + nccl_ofi_freelist *eager_hdr_fl; + nccl_net_ofi_rdma_mr_handle_t *eager_hdr_mr_handle; + /* Whether eager sends prepend an 8-byte header for multi-recv routing */ + bool use_eager_header; + /* Number of rails */ uint16_t num_rails; /* Number of control rails */ @@ -699,6 +765,20 @@ typedef struct nccl_net_ofi_rdma_flush_buffer { * Rails and control rails are fixed-size arrays of MAX_NUM_RAILS. * The constructor allocates a page-aligned control mailbox. 
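 * The mailbox has NCCL_OFI_CTRL_MAILBOX_SIZE slots; a receive with
 * sequence number s uses slot (s % NCCL_OFI_CTRL_MAILBOX_SIZE), and
 * msg_seq_num is written to the slot last, acting as the ready flag.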
*/ +/* + * @brief Entry in the receiver-side eager arrival queue + */ +typedef struct nccl_ofi_recv_eager_entry { + nccl_ofi_dlist_node link; /* intrusive list node */ + nccl_net_ofi_rdma_req *rx_buff_req; + uint16_t msg_seq_num; + uint8_t eager_offset; + uint8_t prev_batch_count; /* from header, only meaningful when eager_offset == 0 */ + uint16_t prev_msg_seq_num; /* from header, only meaningful when eager_offset == 0 */ + int32_t tag; + size_t recv_len; /* payload length (excluding header) */ +} nccl_ofi_recv_eager_entry_t; + class nccl_net_ofi_rdma_recv_comm : public nccl_net_ofi_recv_comm { public: nccl_net_ofi_rdma_recv_comm(); @@ -715,11 +795,11 @@ class nccl_net_ofi_rdma_recv_comm : public nccl_net_ofi_recv_comm { nccl_net_ofi_rdma_recv_comm_rail_t *get_data_rail(uint16_t rail_id); nccl_net_ofi_rdma_recv_comm_rail_t *get_control_rail(uint16_t rail_id); int allocate_recv_req(nccl_net_ofi_rdma_device_t *device, - int dev_id_arg, uint16_t msg_seq_num, int num_recvs, - void **buffs, size_t *sizes, int *tags, - nccl_net_ofi_rdma_mr_handle_t **buff_mr_handles, - nccl_net_ofi_rdma_req **ret_req, - bool recv_completion_optional); + int dev_id_arg, uint16_t msg_seq_num, int num_recvs, + void **buffs, size_t *sizes, int *tags, + nccl_net_ofi_rdma_mr_handle_t **buff_mr_handles, + nccl_net_ofi_rdma_req **ret_req, + bool recv_completion_optional); /* CM receiver for connection establishment */ nccl_ofi_cm_receiver *receiver; @@ -753,6 +833,24 @@ class nccl_net_ofi_rdma_recv_comm : public nccl_net_ofi_recv_comm { #if HAVE_NVTX_TRACING nvtxDomainHandle_t nvtx_domain[NCCL_OFI_N_NVTX_DOMAIN_PER_COMM]; #endif + /* Receiver-side eager queue: arrived eager messages pending resolution. + * Can be more than NCCL_OFI_MAX_EAGER_PENDING since sender can complete these (see control message) + * and continue sending before they arrive at the receiver + */ + nccl_ofi_dlist recv_eager_list; + nccl_ofi_recv_eager_entry_t recv_eager_pool[NCCL_OFI_CTRL_MAILBOX_SIZE]; + nccl_ofi_dlist recv_eager_free; + /* Whether eager recvs carry an 8-byte header for multi-recv routing */ + bool use_eager_header; + /* Last processed eager entry's coordinates */ + uint16_t last_eager_msg_seq_num; + uint8_t last_eager_offset; + bool has_processed_eager; + /* Current recv seq being filled by the drain */ + uint16_t eager_drain_recv_seq; + /* Last completed recv seq */ + uint16_t last_completed_seq; + nccl_net_ofi_rdma_req *send_close_req; /* Counters for total sent and received control messages */ @@ -1569,6 +1667,10 @@ class nccl_net_ofi_rdma_device_t : public nccl_net_ofi_device_t { /* Maximum number of supported communicator IDs */ uint32_t num_comm_ids; + /* Whether the provider supports mixed host/device iovecs, + * enabling the eager header for multi-recv routing. 
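+	 * Set in get_properties(), which reports grouped receives only when
+	 * the provider advertises the mixed_hmem_iov feature (see the EFA
+	 * feature query in nccl_ofi_net.cpp).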
 */
+	bool supports_eager_header = false;
+
 	/* ID pool */
 	nccl_ofi_idpool_t comm_idpool;

diff --git a/m4/check_pkg_libfabric.m4 b/m4/check_pkg_libfabric.m4
index 14d5c43c5b..47e38f31a3 100644
--- a/m4/check_pkg_libfabric.m4
+++ b/m4/check_pkg_libfabric.m4
@@ -49,6 +49,9 @@ AC_DEFUN([CHECK_PKG_LIBFABRIC], [
     AS_IF([test "${check_pkg_found}" = "yes"],
           [AC_CHECK_HEADERS([rdma/fi_ext.h])])

+    AS_IF([test "${check_pkg_found}" = "yes"],
+          [AC_CHECK_HEADERS([rdma/fi_ext_efa.h])])
+
     AS_IF([test "${check_pkg_found}" = "yes"],
           [AC_CHECK_DECLS([FI_OPT_CUDA_API_PERMITTED,
                            FI_OPT_EFA_USE_DEVICE_RDMA,
@@ -58,11 +61,15 @@ AC_DEFUN([CHECK_PKG_LIBFABRIC], [
                            FI_OPT_MAX_MSG_SIZE,
                            FI_OPT_SHARED_MEMORY_PERMITTED,
                            FI_MR_DMABUF,
-                           FI_OPT_INJECT_RMA_SIZE],
+                           FI_OPT_INJECT_RMA_SIZE,
+                           FI_EFA_FEATURE_OPS],
                           [], [],
                           [AC_INCLUDES_DEFAULT
[#include <rdma/fabric.h>
#ifdef HAVE_RDMA_FI_EXT_H
#include <rdma/fi_ext.h>
+#endif
+#ifdef HAVE_RDMA_FI_EXT_EFA_H
+#include <rdma/fi_ext_efa.h>
#endif]])])

     AS_IF([test "${check_pkg_found}" = "yes"],
diff --git a/src/nccl_ofi_net.cpp b/src/nccl_ofi_net.cpp
index fdbff50490..09835026bc 100644
--- a/src/nccl_ofi_net.cpp
+++ b/src/nccl_ofi_net.cpp
@@ -16,6 +16,9 @@
 #include
 #include
+#ifdef HAVE_RDMA_FI_EXT_EFA_H
+#include <rdma/fi_ext_efa.h>
+#endif
 #include "nccl_ofi.h"
 #include "nccl_ofi_assert.h"
 #include "nccl_ofi_environ.h"
@@ -544,12 +549,28 @@ int nccl_net_ofi_plugin_t::nccl_net_ofi_info_properties(struct fi_info *nic_prov
 		goto exit;
 	}

-	/* Only support multi-recv if eager is disabled for now
-	 * Enabling it requires multiple iovs that point to both host and GPU memory
-	 */
-	if (ofi_nccl_eager_max_size() < 0) {
-		props->max_group_receives = NCCL_OFI_MAX_RECVS;
+#if HAVE_DECL_FI_EFA_FEATURE_OPS
+	{ /* Scope block: C++ forbids goto past auto variable initialization */
+		auto fabric_result = nccl_ofi_ofiutils_fabric_create(nic_prov);
+		if (OFI_UNLIKELY(fabric_result.is_failure())) {
+			NCCL_OFI_WARN("Couldn't open a fabric provider. RC: %d, ERROR: %s",
+				      fabric_result.error_code, fi_strerror(-fabric_result.error_code));
+			ret = fabric_result.error_code;
+			goto error;
+		}
+
+		struct fi_efa_feature_ops *feat_ops = NULL;
+		ret = fi_open_ops(&fabric_result.resource->fid, FI_EFA_FEATURE_OPS, 0,
+				  (void **)&feat_ops, NULL);
+		if (ret != 0) {
+			NCCL_OFI_WARN("fi_open_ops for EFA features failed. 
RC: %d, ERROR: %s", + ret, fi_strerror(-ret)); + ret = 0; + } else if (feat_ops->query("mixed_hmem_iov")) { + props->max_group_receives = NCCL_OFI_MAX_RECVS; + } } +#endif /* name is NULL if device is a part of multirail config */ /* overriding default name only if value is available from provider */ diff --git a/src/nccl_ofi_rdma.cpp b/src/nccl_ofi_rdma.cpp index 7b27569863..0307ba98eb 100644 --- a/src/nccl_ofi_rdma.cpp +++ b/src/nccl_ofi_rdma.cpp @@ -423,7 +423,12 @@ int nccl_net_ofi_rdma_device_t::get_properties(nccl_ofi_properties_t *props) props->port_speed *= plugin_ptr->topo->max_group_size; static_assert(NCCL_OFI_RDMA_COMM_ID_BITS < 31, "NCCL_OFI_RDMA_COMM_ID_BITS must be less than 31 so max_communicators fits in an integer"); + static_assert(NCCL_OFI_RDMA_SEQ_BITS + NCCL_OFI_RDMA_COMM_ID_BITS + NCCL_OFI_RDMA_RECV_IDX_BITS + NUM_NUM_SEG_BITS == 32, + "Immediate data fields must sum to 32 bits"); + static_assert(NCCL_OFI_MAX_RECVS <= (1 << NCCL_OFI_RDMA_RECV_IDX_BITS), + "NCCL_OFI_MAX_RECVS must fit in RECV_IDX_BITS"); props->max_communicators = NCCL_OFI_RDMA_MAX_COMMS; + this->supports_eager_header = (props->max_group_receives > 1); } else { return ret; } @@ -622,6 +627,8 @@ static inline int set_eager_copy_completed(nccl_net_ofi_rdma_req *req) return ret; } + /* Record per-sub-receive size for eager completion */ + recv_data->recvs[eager_copy_data->sub_recv_idx].recv_size += size; /* Add completion to parent request */ if (recv_data->num_recvs > 1) { ret = inc_recv_seg_completion(recv_data->recv_segms_req, size, 1); @@ -759,7 +766,7 @@ static inline int update_send_data_from_remote(nccl_net_ofi_rdma_send_comm *s_co send_data->recv_idx = 0; } else { for (uint16_t i = 0; i < ctrl->entries[0].num_recvs; i++) { - if (ctrl->entries[i].tag == send_data->tag) { + if (ctrl->entries[i].tag == send_data->tag && !ctrl->entries[i].entry_used) { entry = &ctrl->entries[i]; send_data->recv_idx = entry->recv_idx; break; @@ -779,9 +786,8 @@ static inline int update_send_data_from_remote(nccl_net_ofi_rdma_send_comm *s_co send_data->remote_mr_key[rail_id] = entry->mr_key[rail_id]; } - /* Invalidate the entry so duplicate tags from step interleaving - * won't match this already-consumed entry */ - entry->tag = NCCL_OFI_CTRL_MSG_TAG_INVALID; + /* Mark entry as consumed so it won't match again */ + entry->entry_used = 1; /* If recv buffer is smaller than send buffer, we reduce the size of the send req */ nccl_net_ofi_mutex_lock(&req->req_lock); @@ -870,7 +876,7 @@ static inline int free_eager_copy_req(nccl_net_ofi_rdma_req *req, bool dec_infli req, dec_inflight_reqs); } -static inline int alloc_eager_copy_req(nccl_net_ofi_rdma_req *recv_req, nccl_net_ofi_rdma_recv_comm *r_comm, +static inline int alloc_eager_copy_req(nccl_net_ofi_rdma_req *recv_req, nccl_net_ofi_rdma_recv_comm *r_comm, int sub_recv_idx, nccl_net_ofi_rdma_req *rx_buff_req) { nccl_net_ofi_rdma_req *eager_copy_req = allocate_req(r_comm->nccl_ofi_reqs_fl); @@ -887,13 +893,230 @@ static inline int alloc_eager_copy_req(nccl_net_ofi_rdma_req *recv_req, nccl_net rdma_req_eager_copy_data_t *eager_copy_data = get_eager_copy_data(eager_copy_req); eager_copy_data->recv_req = recv_req; eager_copy_data->eager_rx_buff_req = rx_buff_req; + eager_copy_data->sub_recv_idx = sub_recv_idx; assert(get_rx_buff_data(rx_buff_req)->recv_len != 0); - get_recv_data(recv_req)->eager_copy_req = eager_copy_req; + get_recv_data(recv_req)->recvs[sub_recv_idx].eager_copy_req = eager_copy_req; return 0; } +/* + * @brief Try to resolve a single eager entry against a recv 
request.
+ *
+ * For single recv: always matches (copies to recvs[0]).
+ * For grouped recv: matches if tag matches an unsatisfied sub-recv.
+ *
+ * Returns the sub-recv index on match, -1 on no match.
+ */
+static int eager_match_recv(nccl_net_ofi_rdma_req *recv_req, int32_t eager_tag)
+{
+	rdma_req_recv_data_t *recv_data = get_recv_data(recv_req);
+	if (recv_data->num_recvs <= 1) {
+		return 0;
+	}
+	for (int i = 0; i < recv_data->num_recvs; i++) {
+		if ((!recv_data->recvs[i].consumed) && (recv_data->recvs[i].tag == eager_tag)) {
+			recv_data->recvs[i].consumed = true;
+			return i;
+		}
+	}
+	return -1;
+}
+
+/*
+ * @brief Initiate eager copy for a specific sub-recv.
+ *
+ * Zero-sized messages complete immediately after reposting the bounce
+ * buffer. Otherwise an eager copy request is allocated for the target
+ * sub-recv (recvs[sub_idx]) and progressed; grouped receives complete
+ * via inc_recv_seg_completion, single receives via inc_req_completion.
+ */
+static int eager_copy_to_sub_recv(nccl_net_ofi_rdma_recv_comm *r_comm,
+				  nccl_net_ofi_rdma_req *recv_req,
+				  nccl_net_ofi_rdma_req *rx_buff_req,
+				  int sub_idx)
+{
+	rdma_req_recv_data_t *recv_data = get_recv_data(recv_req);
+	rdma_req_rx_buff_data_t *rx_buff_data = get_rx_buff_data(rx_buff_req);
+
+	if (rx_buff_data->recv_len == 0) {
+		/* Zero-sized: repost buffer and complete */
+		int ret = check_post_rx_buff_req(rx_buff_req);
+		if (ret != 0) return ret;
+		if (recv_data->num_recvs > 1) {
+			return inc_recv_seg_completion(recv_data->recv_segms_req, 0, 1);
+		} else {
+			return inc_req_completion(recv_req, 0, recv_data->total_num_compls);
+		}
+	}
+
+	/* Both the single-recv case (sub_idx == 0, num_recvs == 1) and the
+	 * grouped case allocate a per-sub-recv eager copy request; they
+	 * differ only in how the completion is accounted (see above). */
+	int ret = alloc_eager_copy_req(recv_req, r_comm, sub_idx, rx_buff_req);
+	if (ret != 0) return ret;
+	return receive_progress(recv_data->recvs[sub_idx].eager_copy_req, true);
+}
+
+/*
+ * @brief Drain the receiver-side ordered eager queue
+ *        (drain_recv_eager_queue, defined after the helpers below).
+ *
+ * Processes eager entries in offset order against posted recv requests.
+ * - Single recv: consume one entry, advance seq.
+ * - Grouped recv: rotate queue matching tags. Non-matches pushed back.
+ *   Once group fully satisfied (by eager + writes), advance seq.
+ * Stops if group is not fully satisfied by eager alone (writes will finish it).
+ */
+/*
+ * @brief Wraparound-aware sequence number comparison.
+ * Returns true if a comes before b in the sequence space.
+ */
+static inline bool seq_before(uint16_t a, uint16_t b)
+{
+	uint16_t diff = (a - b) & MSG_SEQ_NUM_MASK;
+	return diff > (MSG_SEQ_NUM_MASK >> 1);
+}
+
+/*
+ * @brief Compare two eager entries by (msg_seq_num, eager_offset).
+ * Returns true if a should come before b.
+ */
+static inline bool eager_entry_less(const nccl_ofi_recv_eager_entry_t *a,
+				    const nccl_ofi_recv_eager_entry_t *b)
+{
+	if (a->msg_seq_num != b->msg_seq_num)
+		return seq_before(a->msg_seq_num, b->msg_seq_num);
+	return a->eager_offset < b->eager_offset;
+}
+
+/*
+ * @brief Insert an eager entry into the sorted list.
+ * Walks from tail backwards (most inserts go near the end). 
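+ * Example (illustrative): given the list [(5,0), (5,1), (6,0)] and a
+ * new entry (5,2), the walk from the tail stops at (5,1) and links
+ * (5,2) in before (6,0).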
+ */ +static inline void recv_eager_sorted_insert(nccl_ofi_dlist *list, + nccl_ofi_recv_eager_entry_t *entry) +{ + /* Walk from tail to find insertion point */ + nccl_ofi_dlist_node *pos = list->head.prev; + while (pos != &list->head) { + nccl_ofi_recv_eager_entry_t *existing = + nccl_ofi_dlist_entry(pos, &nccl_ofi_recv_eager_entry_t::link); + if (!eager_entry_less(entry, existing)) + break; + pos = pos->prev; + } + /* Insert after pos */ + entry->link.prev = pos; + entry->link.next = pos->next; + pos->next->prev = &entry->link; + pos->next = &entry->link; +} + +static int drain_recv_eager_queue(nccl_net_ofi_rdma_recv_comm *r_comm) +{ + while (!r_comm->recv_eager_list.empty()) { + nccl_ofi_dlist_node *front = r_comm->recv_eager_list.front(); + nccl_ofi_recv_eager_entry_t *entry = + nccl_ofi_dlist_entry(front, &nccl_ofi_recv_eager_entry_t::link); + + /* Check if this entry is the next in sequence */ + bool can_process; + if (!r_comm->has_processed_eager) { + can_process = ((entry->eager_offset == 0) && (entry->prev_msg_seq_num == 0xFFFF)); + } else if (entry->eager_offset == 0) { + /* New batch: previous batch must be complete */ + can_process = (r_comm->last_eager_msg_seq_num == entry->prev_msg_seq_num + && r_comm->last_eager_offset == entry->prev_batch_count - 1); + } else { + /* Same batch: must be consecutive offset */ + can_process = (r_comm->last_eager_msg_seq_num == entry->msg_seq_num + && r_comm->last_eager_offset == entry->eager_offset - 1); + } + + if (!can_process) + break; + + if (entry->eager_offset == 0) + /* New batch starts at entry->msg_seq_num */ + r_comm->eager_drain_recv_seq = entry->msg_seq_num; + + /* Find the target recv, advancing past completed/non-matching recvs */ + uint16_t recv_seq = r_comm->eager_drain_recv_seq; + bool resolved = false; + while (!resolved) { + void *elem; + nccl_ofi_msgbuff_elemtype_t type; + nccl_ofi_msgbuff_status_t stat; + nccl_ofi_msgbuff_result_t mb_res = r_comm->msgbuff->retrieve( + recv_seq, &elem, &type, &stat); + if (mb_res != NCCL_OFI_MSGBUFF_SUCCESS || type != NCCL_OFI_MSGBUFF_REQ) { + if ((entry->eager_offset > 0) && (recv_seq == r_comm->last_completed_seq)) { + /* We already handled ctrl msg with this seq num but it completed since then, skip to next*/ + recv_seq = (recv_seq + 1) & MSG_SEQ_NUM_MASK; + continue; + } else + break; /* recv not posted yet */ + } + + nccl_net_ofi_rdma_req *recv_req = (nccl_net_ofi_rdma_req *)elem; + rdma_req_recv_data_t *recv_data = get_recv_data(recv_req); + + if (recv_data->num_recvs <= 1) { + /* Single recv: consume this entry */ + front->remove(); + eager_copy_to_sub_recv(r_comm, recv_req, entry->rx_buff_req, 0); + r_comm->recv_eager_free.push_back(&entry->link); + r_comm->eager_drain_recv_seq = + (r_comm->eager_drain_recv_seq + 1) & MSG_SEQ_NUM_MASK; + resolved = true; + } else { + /* Multi recv: match by tag */ + int sub_idx = eager_match_recv(recv_req, entry->tag); + if (sub_idx >= 0) { + front->remove(); + eager_copy_to_sub_recv(r_comm, recv_req, entry->rx_buff_req, sub_idx); + r_comm->recv_eager_free.push_back(&entry->link); + // Check if all sub-recvs are now consumed + bool all_consumed = true; + for (int i = 0; i < recv_data->num_recvs; i++) { + if (!recv_data->recvs[i].consumed) { + all_consumed = false; + break; + } + } + if (all_consumed) { + r_comm->eager_drain_recv_seq = (r_comm->eager_drain_recv_seq + 1) & MSG_SEQ_NUM_MASK; + } + resolved = true; + } else { + /* This eager message belongs in the next recv but don't advance eager_drain_recv_seq + * in case there are additional eagers 
for this message + */ + recv_seq = (recv_seq + 1) & MSG_SEQ_NUM_MASK; + continue; + } + } + } + + if (!resolved) + break; /* recv not posted, stop draining */ + + /* Update last processed */ + r_comm->last_eager_msg_seq_num = entry->msg_seq_num; + r_comm->last_eager_offset = entry->eager_offset; + r_comm->has_processed_eager = true; + } + return 0; +} + /** * @brief Handle receiving an RDMA eager message. */ @@ -910,66 +1133,59 @@ static inline int handle_eager_recv(nccl_net_ofi_rdma_recv_comm *r_comm, return ret; } - nccl_ofi_msgbuff_status_t stat; - nccl_ofi_msgbuff_result_t mb_res = r_comm->msgbuff->insert(msg_seq_num, - rx_buff_req, NCCL_OFI_MSGBUFF_BUFF, &stat); - - if (mb_res == NCCL_OFI_MSGBUFF_SUCCESS) { - /* Inserted! In this case receiver has not yet called recv() for this message, so - return success and initiate eager read when receiver calls recv(). */ - return 0; - } - if (OFI_UNLIKELY(mb_res != NCCL_OFI_MSGBUFF_INVALID_IDX)) { - NCCL_OFI_WARN("Unexpected message insert result (%d) (eager recv)", (int)mb_res); - return -EINVAL; - } - - if (OFI_UNLIKELY(stat != NCCL_OFI_MSGBUFF_INPROGRESS)) { - NCCL_OFI_WARN("Unexpected message status (%d) (ctrl recv)", (int)stat); - return -EINVAL; - } - - // In this case, there is already a req entry here. Initiate eager copy. - void *elem; - nccl_ofi_msgbuff_elemtype_t type; - mb_res = r_comm->msgbuff->retrieve( msg_seq_num, &elem, &type, &stat); - if (OFI_UNLIKELY(mb_res != NCCL_OFI_MSGBUFF_SUCCESS || type != NCCL_OFI_MSGBUFF_REQ)) { - NCCL_OFI_WARN("Invalid message retrieval result for msg %hu", msg_seq_num); - return -EINVAL; - } - nccl_net_ofi_rdma_req *recv_req = (nccl_net_ofi_rdma_req *)elem; - rdma_req_recv_data_t *recv_data = get_recv_data(recv_req); - - rdma_req_rx_buff_data_t *rx_buff_data = get_rx_buff_data(rx_buff_req); - if (rx_buff_data->recv_len == 0) { - /* Special case: for zero-sized messages, we can skip the local read */ - /* Re-post rx buffer */ - ret = check_post_rx_buff_req(rx_buff_req); - if (ret != 0) { - NCCL_OFI_WARN("Failed call to check_post_rx_buff_req"); - return ret; - } - if (recv_data->num_recvs > 1) { - ret = inc_recv_seg_completion(recv_data->recv_segms_req, 0, 1); + /* Without multi-recv, each eager maps 1:1 to a recv by msg_seq_num. + * Look up directly in the msgbuff - no ordering or queue needed. 
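+	 * If the recv was already posted we copy immediately; otherwise the
+	 * bounce buffer is stashed in the msgbuff until recv() retrieves it.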
*/ + if (!r_comm->use_eager_header) { + void *elem; + nccl_ofi_msgbuff_elemtype_t type; + nccl_ofi_msgbuff_status_t stat; + nccl_ofi_msgbuff_result_t mb_res = r_comm->msgbuff->retrieve( + msg_seq_num, &elem, &type, &stat); + if (mb_res == NCCL_OFI_MSGBUFF_SUCCESS && type == NCCL_OFI_MSGBUFF_REQ) { + return eager_copy_to_sub_recv(r_comm, (nccl_net_ofi_rdma_req *)elem, rx_buff_req, 0); } else { - ret = inc_req_completion(recv_req, 0, recv_data->total_num_compls); + mb_res = r_comm->msgbuff->insert(msg_seq_num, + rx_buff_req, NCCL_OFI_MSGBUFF_BUFF, &stat); + if (OFI_UNLIKELY(mb_res != NCCL_OFI_MSGBUFF_SUCCESS)) { + NCCL_OFI_WARN("Unexpected msgbuff insert result for eager msg %hu", msg_seq_num); + return -EINVAL; + } + return 0; } - return ret; } - ret = alloc_eager_copy_req(recv_req, r_comm, rx_buff_req); - if (ret != 0) { - NCCL_OFI_WARN("Failed call to alloc_eager_copy_req"); - return ret; - } + /* Multi-recv path: parse header, enqueue in sorted order and drain */ + rdma_req_rx_buff_data_t *rx_data = get_rx_buff_data(rx_buff_req); + nccl_ofi_eager_msg_header_t *eager_hdr = + (nccl_ofi_eager_msg_header_t *)rx_data->rx_buff_fl_elem->ptr; + uint8_t eager_offset = eager_hdr->eager_offset; + int32_t eager_tag = eager_hdr->tag; + uint8_t prev_batch_count = eager_hdr->prev_batch_count; + uint16_t prev_msg_seq_num = eager_hdr->prev_msg_seq_num; + /* Adjust recv_len to exclude header */ + rx_data->recv_len -= NCCL_OFI_EAGER_HEADER_SIZE; - ret = receive_progress(recv_data->eager_copy_req, true); - if (ret != 0) { - NCCL_OFI_WARN("Failed to post eager read: %d", ret); - return ret; + /* Always enqueue in sorted order and drain */ + if (r_comm->recv_eager_free.empty()) { + NCCL_OFI_WARN("Receiver eager queue full"); + return -ENOMEM; } - return 0; + nccl_ofi_dlist_node *free_node = r_comm->recv_eager_free.pop_front(); + nccl_ofi_recv_eager_entry_t *new_entry = + nccl_ofi_dlist_entry(free_node, &nccl_ofi_recv_eager_entry_t::link); + new_entry->rx_buff_req = rx_buff_req; + new_entry->msg_seq_num = msg_seq_num; + new_entry->eager_offset = eager_offset; + new_entry->tag = eager_tag; + new_entry->prev_batch_count = prev_batch_count; + new_entry->prev_msg_seq_num = prev_msg_seq_num; + new_entry->recv_len = rx_data->recv_len; + recv_eager_sorted_insert(&r_comm->recv_eager_list, new_entry); + + ret = drain_recv_eager_queue(r_comm); + + return ret; } static int finish_connect(nccl_net_ofi_rdma_send_comm *s_comm); @@ -1326,6 +1542,13 @@ int nccl_net_ofi_rdma_context::handle_cq_entry(struct fi_cq_entry *cq_entry_base NCCL_OFI_TRACE_EAGER_SEND_COMPLETE(req->dev_id, rail_id, req->comm, req->msg_seq_num, req); send_data = get_send_data(req); assert(send_data->eager); + /* Return eager header buffer to freelist */ + if (send_data->eager_hdr_fl_entry) { + nccl_net_ofi_rdma_send_comm *sc = (nccl_net_ofi_rdma_send_comm *)req->comm; + sc->eager_hdr_fl->entry_free(send_data->eager_hdr_fl_entry); + send_data->eager_hdr_fl_entry = nullptr; + } + ret = inc_req_completion(req, 0, send_data->total_num_compls); } else if (req->type == NCCL_OFI_RDMA_SEND_CLOSE) { ret = inc_req_completion(req, sizeof(nccl_net_ofi_rdma_close_msg_t), 1); @@ -1898,8 +2121,6 @@ static inline int free_recv_req(nccl_net_ofi_rdma_req *req, (nccl_net_ofi_rdma_recv_comm *)req->comm; rdma_req_recv_data_t *recv_data = get_recv_data(req); nccl_net_ofi_rdma_req *recv_segms_req = recv_data->recv_segms_req; - nccl_net_ofi_rdma_req *eager_copy_req = recv_data->eager_copy_req; - if (recv_segms_req) { ret = recv_segms_req->free(false); if (ret) { @@ -1908,11 
+2129,14 @@ static inline int free_recv_req(nccl_net_ofi_rdma_req *req, } } - if (eager_copy_req) { - ret = eager_copy_req->free(false); - if (ret) { - NCCL_OFI_WARN("Failed to free receive request"); - return ret; + /* Free eager copy requests for all sub-receives */ + for (int i = 0; i < recv_data->num_recvs; i++) { + if (recv_data->recvs[i].eager_copy_req) { + ret = recv_data->recvs[i].eager_copy_req->free(false); + if (ret) { + NCCL_OFI_WARN("Failed to free eager copy request"); + return ret; + } } } @@ -2350,7 +2574,7 @@ static inline bool has_flush_completed(nccl_net_ofi_rdma_req *req) * @brief Check the contents of the control mailbox to check if the * control message has arrived or not */ -static inline bool has_ctrl_msg(nccl_net_ofi_rdma_send_comm* s_comm, uint16_t seq_num, int tag) +static inline bool has_ctrl_msg(nccl_net_ofi_rdma_send_comm* s_comm, uint16_t seq_num, int tag, bool ignore_tag_match) { uint16_t slot = seq_num % NCCL_OFI_CTRL_MAILBOX_SIZE; uint16_t expected = (uint16_t)(seq_num & MSG_SEQ_NUM_MASK); @@ -2366,7 +2590,10 @@ static inline bool has_ctrl_msg(nccl_net_ofi_rdma_send_comm* s_comm, uint16_t se if (READ_ONCE(ctrl->entries[i].msg_seq_num) != expected) { return false; } - if (ctrl->entries[i].tag == tag) { + if (ctrl->entries[i].entry_used == 1) { + continue; + } + if (ignore_tag_match || (ctrl->entries[i].tag == tag)) { return true; } } @@ -2396,13 +2623,19 @@ static inline uint32_t get_ctrl_msg_buff_len(nccl_net_ofi_rdma_send_comm* s_comm * Increment completion for eager sends if the control message has been received * Update request size if the recv buffer is small than the send */ +static bool drain_sender_eager_queue(nccl_net_ofi_rdma_send_comm *s_comm); + static inline int update_send_request(nccl_net_ofi_rdma_send_comm* s_comm, nccl_net_ofi_rdma_req *req) { int ret = 0; rdma_req_send_data_t *send_data = get_send_data(req); + if (s_comm->eager_queue_count > 0) + drain_sender_eager_queue(s_comm); + /* Only increment completion if the send has completed */ - if (send_data->eager && has_ctrl_msg(s_comm, req->msg_seq_num, send_data->tag) && req->ncompls > 0) { + bool recv_ctrl_msg = s_comm->use_eager_header ? send_data->eager_ctrl_msg_received : has_ctrl_msg(s_comm, req->msg_seq_num, send_data->tag, true); + if (send_data->eager && recv_ctrl_msg && req->ncompls > 0) { send_data->remote_len = get_ctrl_msg_buff_len(s_comm, req->msg_seq_num, send_data->tag); if (send_data->remote_len < send_data->buff_len) { @@ -2418,7 +2651,6 @@ static inline int update_send_request(nccl_net_ofi_rdma_send_comm* s_comm, nccl_ NCCL_OFI_WARN("Failed to increase completion count for eager send request"); return ret; } - s_comm->n_ctrl_received += 1; } return ret; @@ -2536,6 +2768,8 @@ int nccl_net_ofi_rdma_req::test(int *done, int *size_p) ret = -EINVAL; goto exit; } + auto *r_comm = reinterpret_cast(this->comm); + r_comm->last_completed_seq = this->msg_seq_num; } if (this->type == NCCL_OFI_RDMA_SEND) { @@ -2988,7 +3222,6 @@ int nccl_net_ofi_rdma_recv_comm::allocate_recv_req( recv_data = get_recv_data(req); /* In the case of early completion, only expect the completion for control msg itself */ recv_data->total_num_compls = recv_completion_optional ? 
1 : 2; - recv_data->eager_copy_req = NULL; recv_data->num_recvs = num_recvs; for (i = 0; i < num_recvs; i++) { recv_data->recvs[i].dst_buff = buffs[i]; @@ -2998,6 +3231,8 @@ int nccl_net_ofi_rdma_recv_comm::allocate_recv_req( recv_data->recvs[i].recv_size = 0; recv_data->recvs[i].ncompls = 0; recv_data->recvs[i].total_segms = 0; + recv_data->recvs[i].eager_copy_req = NULL; + recv_data->recvs[i].consumed = false; } ret = insert_recv_segms_req(this, device, dev_id_arg, msg_seq_num, num_recvs, req); @@ -3008,6 +3243,7 @@ int nccl_net_ofi_rdma_recv_comm::allocate_recv_req( /* Update receiver's control mailbox slot */ uint16_t slot = req->msg_seq_num % NCCL_OFI_CTRL_MAILBOX_SIZE; + /* Zero the slot first, then populate. msg_seq_num is written last as the ready bit. */ memset(&this->ctrl_mailbox[slot], 0, sizeof(nccl_net_ofi_ctrl_msg_t)); this->ctrl_mailbox[slot].entries[0].num_recvs = num_recvs; if (recv_completion_optional) { @@ -3022,6 +3258,7 @@ int nccl_net_ofi_rdma_recv_comm::allocate_recv_req( entry->buff_len = sizes[i]; entry->tag = tags[i]; entry->recv_idx = i; + entry->entry_used = 0; entry->msg_seq_num = req->msg_seq_num & MSG_SEQ_NUM_MASK; for (uint16_t rail_id = 0; rail_id < this->num_rails; rail_id++) { uint64_t rkey = fi_mr_key(mr_h->mr[rail_id].get()); @@ -3201,15 +3438,6 @@ int nccl_net_ofi_rdma_recv_comm::recv(int n, void **buffers, goto error; } - /* - * Disable eager for grouped receives (n > 1). - * Eager requires matching a single rx bounce buffer to a single dest, - * which doesn't generalize to multiple destinations. - */ - if (n > 1) { - eager = false; - } - /* NCCL versions prior to 2.24 require special handling for 0 byte * messages when using user buffer registration. NCCL passes the base * pointer from the user buffer, but passes the registration from the @@ -3248,9 +3476,9 @@ int nccl_net_ofi_rdma_recv_comm::recv(int n, void **buffers, NCCL_OFI_WARN("Failed call to check_post_rx_buff_req"); return ret; } - recv_data->eager_copy_req = NULL; + recv_data->recvs[0].eager_copy_req = NULL; } else { - ret = alloc_eager_copy_req(req, this, rx_buff_req); + ret = alloc_eager_copy_req(req, this, 0, rx_buff_req); if (ret != 0) { goto error; } @@ -3276,7 +3504,7 @@ int nccl_net_ofi_rdma_recv_comm::recv(int n, void **buffers, } if (eager) { - if (recv_data->eager_copy_req == NULL) { + if (recv_data->recvs[0].eager_copy_req == NULL) { /* If we don't need to do eager copy, this recv is already complete */ ret = inc_req_completion(req, 0, recv_data->total_num_compls); if (ret != 0) { @@ -3284,7 +3512,7 @@ int nccl_net_ofi_rdma_recv_comm::recv(int n, void **buffers, } } else { /* Post eager copy */ - ret = receive_progress(recv_data->eager_copy_req, true); + ret = receive_progress(recv_data->recvs[0].eager_copy_req, true); if (ret != 0) { NCCL_OFI_WARN("Failed to issue eager read"); /* TODO: Remove req from message buffer */ @@ -3298,6 +3526,15 @@ int nccl_net_ofi_rdma_recv_comm::recv(int n, void **buffers, /* Increment next_msg_seq_num for next call */ this->next_msg_seq_num = (this->next_msg_seq_num + 1) & MSG_SEQ_NUM_MASK; + /* Drain receiver eager queue: new recv() may unblock pending eager messages */ + if (!this->recv_eager_list.empty()) { + int drain_ret = drain_recv_eager_queue(this); + if (drain_ret != 0) { + ret = drain_ret; + goto error; + } + } + goto exit; free_req: @@ -3707,6 +3944,7 @@ static int recv_comm_process_all_finalizing(void) } nccl_net_ofi_rdma_send_comm::~nccl_net_ofi_rdma_send_comm() { + delete this->eager_hdr_fl; if (this->ctrl_mailbox) { 
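		/* The mailbox MR was deregistered in send_comm_destroy();
		 * only the memory itself is freed here. */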
free(this->ctrl_mailbox); } @@ -3734,6 +3972,8 @@ static int send_comm_destroy(nccl_net_ofi_rdma_send_comm *s_comm) /* Deregister control mailbox */ domain->dereg_mr(s_comm->ctrl_mr_handle); + /* Eager header freelist is cleaned up in destructor */ + /* Release communicator ID */ device->comm_idpool.free_id(s_comm->local_comm_id); @@ -4123,6 +4363,15 @@ nccl_net_ofi_rdma_recv_comm::nccl_net_ofi_rdma_recv_comm() ctrl_mailbox = nullptr; ctrl_mr_handle = nullptr; remote_mailbox_addr = 0; + /* Initialize eager entry pool as free list */ + for (int i = 0; i < NCCL_OFI_CTRL_MAILBOX_SIZE; i++) { + recv_eager_free.push_back(&recv_eager_pool[i].link); + } + has_processed_eager = false; + last_eager_msg_seq_num = 0; + last_eager_offset = 0; + eager_drain_recv_seq = NCCL_OFI_RDMA_MSG_SEQ_NUM_START; + last_completed_seq = NCCL_OFI_RDMA_MSG_SEQ_NUM_START - 1; remote_mr_key = {}; const size_t ctrl_mailbox_raw_size = sizeof(nccl_net_ofi_ctrl_msg_t) * NCCL_OFI_CTRL_MAILBOX_SIZE; @@ -4402,6 +4651,7 @@ static nccl_net_ofi_rdma_recv_comm *prepare_recv_comm(nccl_net_ofi_rdma_domain_t r_comm->next_msg_seq_num = NCCL_OFI_RDMA_MSG_SEQ_NUM_START; r_comm->ep = l_comm_ep->shared_from_this(); + r_comm->use_eager_header = device->supports_eager_header; ep = (nccl_net_ofi_rdma_ep_t *)r_comm->ep.get(); @@ -4472,9 +4722,9 @@ static nccl_net_ofi_rdma_recv_comm *prepare_recv_comm(nccl_net_ofi_rdma_domain_t /* Allocate request freelist */ /* Maximum freelist entries is 4*NCCL_OFI_MAX_REQUESTS because each receive request - can have associated reqs for send_ctrl, recv_segms, and eager_copy */ + can have associated reqs for send_ctrl, recv_segms, and 8 eager_copy */ r_comm->nccl_ofi_reqs_fl = new nccl_ofi_freelist(sizeof(nccl_net_ofi_rdma_req), 16, 16, - 4 * NCCL_OFI_MAX_REQUESTS, + 11 * NCCL_OFI_MAX_REQUESTS, rdma_fl_req_entry_init, rdma_fl_req_entry_fini, "Recv Communicator Requests", true); @@ -4672,6 +4922,7 @@ static int accept_wait_for_connection(nccl_net_ofi_rdma_domain_t *domain, */ r_comm->ep = l_comm->ep; + r_comm->use_eager_header = ep->rdma_endpoint_get_device()->supports_eager_header; /* Initialize connect response message */ nccl_ofi_rdma_connection_info_t conn_resp_msg; @@ -4972,7 +5223,7 @@ static int alloc_rdma_send_req(nccl_net_ofi_rdma_send_comm *s_comm, has not arrived, so we expect one extra completion for the ctrl msg recv. 
*/ send_data->total_num_compls = send_data->schedule->num_xfer_infos + 1; send_data->wdata = GET_RDMA_WRITE_IMM_DATA(s_comm->remote_comm_id, req->msg_seq_num, - send_data->recv_idx, send_data->schedule->num_xfer_infos); + send_data->recv_idx, send_data->schedule->num_xfer_infos); } send_data->eager = eager; @@ -5067,19 +5318,67 @@ static int post_rdma_eager_send(nccl_net_ofi_rdma_req *req, nccl_net_ofi_rdma_send_comm_rail_t *comm_rail, nccl_net_ofi_xfer_info_t *xfer_info) { + nccl_net_ofi_rdma_send_comm *s_comm = (nccl_net_ofi_rdma_send_comm *)req->comm; rdma_req_send_data_t *send_data = get_send_data(req); assert(xfer_info->rail_id < send_data->buff_mr_handle->num_rails); uint16_t rail_id = xfer_info->rail_id; struct fid_mr *rail_mr_handle = send_data->buff_mr_handle->mr[rail_id].get(); - void *desc = fi_mr_desc(rail_mr_handle); - ssize_t rc; - /* Post eager send */ - rc = fi_senddata(comm_rail->local_ep, (void*)(((uintptr_t)send_data->buff) + xfer_info->offset), xfer_info->msg_size, desc, - send_data->wdata, comm_rail->remote_addr, rdma_req_get_ofi_context(req, rail_id)); + + if (!s_comm->use_eager_header) { + /* Single-recv path: plain fi_senddata without header */ + void *desc = fi_mr_desc(rail_mr_handle); + void *data = (void*)(((uintptr_t)send_data->buff) + xfer_info->offset); + + rc = fi_senddata(comm_rail->local_ep, data, xfer_info->msg_size, + desc, send_data->wdata, comm_rail->remote_addr, + rdma_req_get_ofi_context(req, rail_id)); + + if ((rc != 0) && (rc != -FI_EAGAIN)) { + NCCL_OFI_WARN("fi_senddata (eager) failed; RC: %zd, Error: %s", rc, fi_strerror(-rc)); + } + return rc; + } + + /* Multi-recv path: 2-iovec send with [header][payload] */ + /* Allocate eager header from freelist */ + nccl_ofi_freelist::fl_entry *hdr_fl_entry = s_comm->eager_hdr_fl->entry_alloc(); + if (OFI_UNLIKELY(!hdr_fl_entry)) { + NCCL_OFI_WARN("No free eager header buffers"); + return -ENOMEM; + } + nccl_ofi_eager_msg_header_t *hdr = (nccl_ofi_eager_msg_header_t *)hdr_fl_entry->ptr; + send_data->eager_hdr_fl_entry = hdr_fl_entry; + hdr->eager_offset = send_data->eager_offset; + hdr->tag = send_data->tag; + hdr->prev_batch_count = send_data->prev_batch_count; + hdr->prev_msg_seq_num = send_data->prev_msg_seq_num; + + /* Build 2-iovec message: [header][payload] */ + struct iovec iov[2]; + void *desc_arr[2]; + + iov[0].iov_base = hdr; + iov[0].iov_len = NCCL_OFI_EAGER_HEADER_SIZE; + freelist_regmr_fn_handle_t *hdr_mr_fl = (freelist_regmr_fn_handle_t *)hdr_fl_entry->mr_handle; + desc_arr[0] = fi_mr_desc(hdr_mr_fl->mr_handle->mr[rail_id].get()); + + iov[1].iov_base = (void*)(((uintptr_t)send_data->buff) + xfer_info->offset); + iov[1].iov_len = xfer_info->msg_size; + desc_arr[1] = fi_mr_desc(rail_mr_handle); + + struct fi_msg msg = {}; + msg.msg_iov = iov; + msg.desc = desc_arr; + msg.iov_count = 2; + msg.addr = comm_rail->remote_addr; + msg.context = rdma_req_get_ofi_context(req, rail_id); + msg.data = send_data->wdata; + + rc = fi_sendmsg(comm_rail->local_ep, &msg, FI_REMOTE_CQ_DATA); if ((rc != 0) && (rc != -FI_EAGAIN)) { - NCCL_OFI_WARN("fi_senddata failed; RC: %zd, Error: %s", rc, fi_strerror(-rc)); + NCCL_OFI_WARN("fi_sendmsg (eager) failed; RC: %zd, Error: %s", rc, fi_strerror(-rc)); } else if (rc == 0) { NCCL_OFI_TRACE_EAGER_SEND_START(req->dev_id, rail_id, xfer_info->msg_size, req->comm, req->msg_seq_num, req); } @@ -5251,7 +5550,6 @@ static int post_rdma_ctrl(nccl_net_ofi_rdma_req *req) size_t ctrl_msg_len = r_comm->ctrl_mailbox[slot].entries[0].num_recvs * sizeof(nccl_net_ofi_ctrl_msg_entry_t); 
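	/* The ctrl message is variable-length: only the first num_recvs
	 * 64-byte entries of the slot are sent to the remote mailbox. */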
nccl_net_ofi_schedule_t *schedule = NULL;
-
 	if (ep->num_control_rails > 1) {
 		schedule = scheduler->get_schedule(ctrl_msg_len, ep->num_control_rails);
@@ -5322,11 +5620,12 @@ static int post_eager_copy(nccl_net_ofi_rdma_req *req)
 	rdma_req_rx_buff_data_t *rx_buff_data = get_rx_buff_data(eager_copy_data->eager_rx_buff_req);
 	rdma_req_recv_data_t *recv_data = get_recv_data(eager_copy_data->recv_req);

-	/* Validate size of data (eager only used for single recv, so use recvs[0]) */
-	if (recv_data->recvs[0].dst_len < rx_buff_data->recv_len) {
+	/* Validate size of data against target sub-recv */
+	int sub_idx = eager_copy_data->sub_recv_idx;
+	if (recv_data->recvs[sub_idx].dst_len < rx_buff_data->recv_len) {
 		NCCL_OFI_TRACE(NCCL_NET, "Recv buffer (%zu) smaller than eager send size (%zu)",
-			recv_data->recvs[0].dst_len, rx_buff_data->recv_len);
-		rx_buff_data->recv_len = recv_data->recvs[0].dst_len;
+			recv_data->recvs[sub_idx].dst_len, rx_buff_data->recv_len);
+		rx_buff_data->recv_len = recv_data->recvs[sub_idx].dst_len;
 	}

 	// Get communicator rail information to xfer the req
@@ -5339,7 +5638,7 @@ static int post_eager_copy(nccl_net_ofi_rdma_req *req)
 	(freelist_regmr_fn_handle_t *)rx_buff_data->rx_buff_fl_elem->mr_handle;
 	nccl_net_ofi_rdma_mr_handle_t *rx_mr_handle = fl_handle->mr_handle;

-	nccl_net_ofi_rdma_mr_handle_t *dest_mr_handle = recv_data->recvs[0].dest_mr_handle;
+	nccl_net_ofi_rdma_mr_handle_t *dest_mr_handle = recv_data->recvs[sub_idx].dest_mr_handle;

 	assert(rx_rail_id < dest_mr_handle->num_rails);
 	void *desc = fi_mr_desc(dest_mr_handle->mr[rx_rail_id].get());
@@ -5351,8 +5650,10 @@ static int post_eager_copy(nccl_net_ofi_rdma_req *req)
 		return -EIO;
 	}

-	uintptr_t buff_offset = (uintptr_t)rx_buff - rx_mr_handle->base_addr;
-	ssize_t rc = fi_read(comm_rail->local_ep, recv_data->recvs[0].dst_buff,
+	/* Skip the eager header in the bounce buffer if present */
+	uintptr_t buff_offset = (uintptr_t)rx_buff - rx_mr_handle->base_addr +
+		(r_comm->use_eager_header ? NCCL_OFI_EAGER_HEADER_SIZE : 0);
+	ssize_t rc = fi_read(comm_rail->local_ep, recv_data->recvs[sub_idx].dst_buff,
 			     rx_buff_data->recv_len, desc, comm_rail->local_addr,
 			     buff_offset, rx_key, rdma_req_get_ofi_context(req, rx_rail_id));

@@ -5515,6 +5816,96 @@ static inline int check_post_rx_buff_req(nccl_net_ofi_rdma_req *rx_buff_req)
 * @brief Send a message. This "interface function" is called, indirectly, from
 * the application
 */
+
+/*
+ * @brief Drain the sender eager queue against available ctrl msgs.
+ *
+ * For each ctrl msg at next_msg_seq_num:
+ * - Single recv: pop front of queue (first entry satisfies it), advance seq.
+ * - Grouped recv: rotate queue once, matching tags.
+ *
+ * Returns false if no ctrl msg was found for next_msg_seq_num, so the
+ * caller won't check the mailbox again after a partial drain. 
Otherwise there could be a race + * where we didn't drain but we are using a new control message + */ +static bool drain_sender_eager_queue(nccl_net_ofi_rdma_send_comm *s_comm) +{ + while (s_comm->eager_queue_count > 0) { + uint16_t seq = s_comm->next_msg_seq_num; + if (!has_ctrl_msg(s_comm, seq, 0, true)) + return false; + + /* Memory fence: we're about to read ctrl msg contents */ + std::atomic_thread_fence(std::memory_order_acquire); + + uint16_t slot = seq % NCCL_OFI_CTRL_MAILBOX_SIZE; + nccl_net_ofi_ctrl_msg_t *ctrl = &s_comm->ctrl_mailbox[slot]; + uint16_t num_recvs = ctrl->entries[0].num_recvs; + uint8_t head = s_comm->eager_queue_head; + nccl_ofi_eager_queue_entry_t *entry = &s_comm->eager_queue[head]; + rdma_req_send_data_t *send_data = get_send_data(entry->req); + /* Consume element from head. If it won't be found, we'll push it to the end*/ + s_comm->eager_queue_head = (head + 1) % NCCL_OFI_MAX_EAGER_PENDING; + s_comm->eager_queue_count--; + + if (num_recvs <= 1) { + /* Single recv: the front entry satisfies it */ + send_data->eager_ctrl_msg_received = true; + s_comm->n_ctrl_received += 1; + s_comm->next_msg_seq_num = (seq + 1) & MSG_SEQ_NUM_MASK; + /* Reset eager offset counter and prev parameters when advancing msg_seq_num if sent eager since last time */ + if (s_comm->eager_offset_next > 0) { + s_comm->prev_eager_msg_seq_num = seq; + s_comm->prev_eager_batch_count = s_comm->eager_offset_next; + s_comm->eager_offset_next = 0; + } + continue; + } + + if (s_comm->group_sends_remaining == 0) { + /* Start a new group */ + s_comm->n_ctrl_received += 1; + s_comm->group_num_recvs = num_recvs; + s_comm->group_sends_remaining = num_recvs; + } + + /* Search for matching tag in ctrl entries */ + bool found = false; + for (uint16_t e = 0; e < num_recvs; e++) { + if ((ctrl->entries[e].entry_used == 0) && (ctrl->entries[e].tag == entry->tag)) { + send_data->eager_ctrl_msg_received = true; + ctrl->entries[e].entry_used = 1; + s_comm->group_sends_remaining--; + found = true; + break; + } + } + if (!found) { + /* No match: push back */ + uint8_t tail = s_comm->eager_queue_tail; + s_comm->eager_queue[tail] = *entry; + s_comm->eager_queue_tail = (tail + 1) % NCCL_OFI_MAX_EAGER_PENDING; + s_comm->eager_queue_count++; + break; + + } + + if (s_comm->group_sends_remaining == 0) { + /* All sub-recvs satisfied by eager */ + s_comm->next_msg_seq_num = (seq + 1) & MSG_SEQ_NUM_MASK; + s_comm->group_num_recvs = 0; + /* Reset eager offset counter and prev parameters when advancing msg_seq_num if sent eager since last time */ + if (s_comm->eager_offset_next > 0) { + s_comm->prev_eager_msg_seq_num = seq; + s_comm->prev_eager_batch_count = s_comm->eager_offset_next; + s_comm->eager_offset_next = 0; + } + } + } + + return true; +} + int nccl_net_ofi_rdma_send_comm::send(void *data, size_t size, int tag, nccl_net_ofi_mr_handle_t *mhandle, nccl_net_ofi_req **base_req) { @@ -5583,9 +5974,20 @@ int nccl_net_ofi_rdma_send_comm::send(void *data, size_t size, int tag, mr_handle = domain->flush_buff.mr_handle; } - in_group = (s_comm->group_sends_remaining > 0); + /* Drain eager queue against any available ctrl msgs */ + if (s_comm->eager_queue_count > 0) { + have_ctrl = drain_sender_eager_queue(s_comm); + if (have_ctrl) { + /* Control message can be used - Check on specific tag*/ + have_ctrl = has_ctrl_msg(s_comm, s_comm->next_msg_seq_num, tag, false); + } + } else { + have_ctrl = has_ctrl_msg(s_comm, s_comm->next_msg_seq_num, tag, false); + } - have_ctrl = has_ctrl_msg(s_comm, msg_seq_num, tag); + /* Determine state 
again after drain */
+	msg_seq_num = s_comm->next_msg_seq_num;
+	in_group = (s_comm->group_sends_remaining > 0);

 	/* If ctrl msg indicates a grouped recv, start tracking the group */
 	if (have_ctrl && !in_group) {
@@ -5599,12 +6001,16 @@
 	}

 	/* Determine if this should be sent eagerly.
-	 * Eager is disabled when multi-recv requires the sender to wait for
-	 * the ctrl msg to detect grouped receives. Currently multi-recv is
-	 * only supported with eager_send_size == -1, so the size check
-	 * is sufficient to disable eager for multi-recv. */
-	if (!have_ctrl &&
-	    (ssize_t)size <= endpoint->eager_send_size && s_comm->num_inflight_writes == 0) {
+	 * Eager is allowed if: no ctrl msg has arrived, the sender is not
+	 * mid-group, size + header fits within eager_send_size, the queue
+	 * is not full, there are no inflight writes, and we are not starting
+	 * a new batch while the previous batch's entries are still queued
+	 * (keeps the msg_seq_num implied by offset == 0 correct). */
+	if (!have_ctrl && !in_group &&
+	    (ssize_t)(size + NCCL_OFI_EAGER_HEADER_SIZE) <= endpoint->eager_send_size &&
+	    s_comm->eager_queue_count < NCCL_OFI_MAX_EAGER_PENDING &&
+	    s_comm->num_inflight_writes == 0 &&
+	    ((s_comm->eager_queue_count == 0) || (s_comm->eager_offset_next > 0))) {
 		eager = true;
 	}

@@ -5627,6 +6033,23 @@ int nccl_net_ofi_rdma_send_comm::send(void *data, size_t size, int tag,
 		goto error;
 	}

+	if (s_comm->use_eager_header && eager) {
+		/* Set eager offset and enqueue */
+		rdma_req_send_data_t *eager_send_data = get_send_data(req);
+		eager_send_data->eager_offset = s_comm->eager_offset_next;
+		eager_send_data->eager_ctrl_msg_received = false;
+		eager_send_data->prev_batch_count = s_comm->prev_eager_batch_count;
+		eager_send_data->prev_msg_seq_num = s_comm->prev_eager_msg_seq_num;
+
+		uint8_t idx = s_comm->eager_queue_tail;
+		s_comm->eager_queue[idx].req = req;
+		s_comm->eager_queue[idx].tag = tag;
+		s_comm->eager_queue[idx].eager_offset = s_comm->eager_offset_next;
+		s_comm->eager_queue_tail = (idx + 1) % NCCL_OFI_MAX_EAGER_PENDING;
+		s_comm->eager_queue_count++;
+		s_comm->eager_offset_next++;
+	}
+
 	if (have_ctrl) {
 		/*
 		 * For already received RDMA control message, populate
@@ -5672,15 +6095,24 @@ int nccl_net_ofi_rdma_send_comm::send(void *data, size_t size, int tag,
 	/* Return request to NCCL */
 	*base_req = req;

-	/* Increment next_msg_seq_num: for grouped sends, only after the last sub-send */
-	if (in_group) {
-		s_comm->group_sends_remaining--;
-		if (s_comm->group_sends_remaining == 0) {
-			s_comm->group_num_recvs = 0;
+	/* Increment next_msg_seq_num: skip for eager (resolved during drain) */
+	if (!eager || !s_comm->use_eager_header) {
+		uint16_t orig_seq = s_comm->next_msg_seq_num;
+		if (in_group) {
+			s_comm->group_sends_remaining--;
+			if (s_comm->group_sends_remaining == 0) {
+				s_comm->group_num_recvs = 0;
+				s_comm->next_msg_seq_num = (s_comm->next_msg_seq_num + 1) & MSG_SEQ_NUM_MASK;
+			}
+		} else {
 			s_comm->next_msg_seq_num = (s_comm->next_msg_seq_num + 1) & MSG_SEQ_NUM_MASK;
 		}
-	} else {
-		s_comm->next_msg_seq_num = (s_comm->next_msg_seq_num + 1) & MSG_SEQ_NUM_MASK;
+		/* If eager sends happened since the last advance, record the
+		 * batch boundary (prev seq/count) and reset the offset counter */
+		if ((s_comm->eager_offset_next > 0) && (orig_seq != s_comm->next_msg_seq_num)) {
+			s_comm->prev_eager_msg_seq_num = orig_seq;
+			s_comm->prev_eager_batch_count = s_comm->eager_offset_next;
+			s_comm->eager_offset_next = 0;
+		}
 	}
 	goto exit;

@@ -6083,10 +6515,19 @@ int nccl_net_ofi_rdma_ep_t::create_send_comm(nccl_net_ofi_rdma_send_comm **s_com
 	ret_s_comm->type = 
@@ -6083,10 +6515,19 @@ int nccl_net_ofi_rdma_ep_t::create_send_comm(nccl_net_ofi_rdma_send_comm **s_com
 	ret_s_comm->type = NCCL_NET_OFI_SEND_COMM;
 	ret_s_comm->dev_id = dev_id;
 	ret_s_comm->comm_active = true;
+	ret_s_comm->use_eager_header = device->supports_eager_header;
 	ret_s_comm->next_msg_seq_num = NCCL_OFI_RDMA_MSG_SEQ_NUM_START;
 	ret_s_comm->group_sends_remaining = 0;
 	ret_s_comm->group_num_recvs = 0;
 	ret_s_comm->group_tag_used = 0;
+	ret_s_comm->eager_queue_head = 0;
+	ret_s_comm->eager_queue_tail = 0;
+	ret_s_comm->eager_queue_count = 0;
+	ret_s_comm->eager_offset_next = 0;
+	ret_s_comm->prev_eager_msg_seq_num = 0xFFFF;	/* sentinel: no previous eager batch */
+	ret_s_comm->prev_eager_batch_count = 0;
+	ret_s_comm->eager_hdr_fl = nullptr;
+	ret_s_comm->eager_hdr_mr_handle = nullptr;
 
 	/* The connect() API function acquired the endpoint we are using via
 	   get_ep(). Store shared_ptr in the comm to keep ep alive. */
@@ -6125,6 +6566,20 @@ int nccl_net_ofi_rdma_ep_t::create_send_comm(nccl_net_ofi_rdma_send_comm **s_com
 		goto error;
 	}
 
+	/* Create freelist for eager header buffers */
+	ret_s_comm->eager_hdr_fl = new nccl_ofi_freelist(
+		sizeof(nccl_ofi_eager_msg_header_t),
+		NCCL_OFI_CTRL_MAILBOX_SIZE, 16, 0,
+		NULL, NULL,
+		freelist_regmr_host_fn, freelist_deregmr_host_fn,
+		domain_ptr, NCCL_OFI_EAGER_HEADER_SIZE,
+		"Eager Header", true);
+	if (!ret_s_comm->eager_hdr_fl) {
+		NCCL_OFI_WARN("Could not create eager header freelist for dev %d", dev_id);
+		ret = -ENOMEM;
+		goto error;
+	}
+
 #if HAVE_NVTX_TRACING
 	if (ofi_nccl_nvtx_trace_dimension() == NVTX_TRACE_DIMENSION::PER_COMM) {
 		for (int i = 0; i < NCCL_OFI_N_NVTX_DOMAIN_PER_COMM; ++i)
@@ -6540,7 +6995,7 @@ nccl_net_ofi_rdma_ep_t::nccl_net_ofi_rdma_ep_t(std::shared_ptr
 	this->eager_rx_buff_size = (this->eager_send_size == 0) ?
-		EAGER_RX_BUFFER_ALIGNMENT : this->eager_send_size;
+		EAGER_RX_BUFFER_ALIGNMENT : this->eager_send_size + (device->supports_eager_header ? NCCL_OFI_EAGER_HEADER_SIZE : 0);
 
 	ret = this->init_rail_ofi_resources(device, domain_arg.get());
 	if (ret != 0) {
diff --git a/tests/functional/.gitignore b/tests/functional/.gitignore
index ebaa0a2168..0043ab5496 100644
--- a/tests/functional/.gitignore
+++ b/tests/functional/.gitignore
@@ -4,4 +4,5 @@ nccl_message_transfer
 reuse_listen_comm
 ring
 gin
-grouped_recv
\ No newline at end of file
+grouped_recv
+eager_multirecv
diff --git a/tests/functional/Makefile.am b/tests/functional/Makefile.am
index 1c98d3bfc3..692c147aaf 100644
--- a/tests/functional/Makefile.am
+++ b/tests/functional/Makefile.am
@@ -33,7 +33,7 @@ CXXLINK = OMPI_CXX="$(CXX)" MPICH_CXX="$(CXX)" \
 if ENABLE_FUNC_TESTS
 noinst_HEADERS = functional_test.h
-bin_PROGRAMS = nccl_connection nccl_message_transfer ring inflight_close reuse_listen_comm gin grouped_recv
+bin_PROGRAMS = nccl_connection nccl_message_transfer ring inflight_close reuse_listen_comm gin grouped_recv eager_multirecv
 
 base_sources = functional_test.cpp
@@ -45,3 +45,4 @@ reuse_listen_comm_SOURCES = $(base_sources) reuse_listen_comm.cpp
 gin_SOURCES = $(base_sources) gin.cpp
 grouped_recv_SOURCES = $(base_sources) grouped_recv.cpp
 endif
+eager_multirecv_SOURCES = $(base_sources) eager_multirecv.cpp
diff --git a/tests/functional/eager_multirecv.cpp b/tests/functional/eager_multirecv.cpp
new file mode 100644
index 0000000000..832944de49
--- /dev/null
+++ b/tests/functional/eager_multirecv.cpp
@@ -0,0 +1,1496 @@
+/*
+ * Copyright (c) 2026 Amazon.com, Inc. or its affiliates. All rights reserved.
+ *
+ * Functional tests for eager message support with multi-recv.
+ * Tests 5-24 covering: single eager, grouped eager, mixed eager+write,
+ * queue ordering, tag matching, and permutation coverage.
+ */
+
+#include "config.h"
+#include "functional_test.h"
+#include <queue>
+#include <set>
+
+/* Small size guaranteed to go eager (well under 8KB default) */
+static constexpr size_t EAGER_SIZE = 1024;
+/* Large size guaranteed to NOT go eager (over 8KB + header) */
+static constexpr size_t LARGE_SIZE = 16384;
+
+/*
+ * Helper: poll a single request to completion, return size.
+ */
+static void poll_one(test_nccl_net_t *ext_net, void *req, int *out_size)
+{
+	int done = 0;
+	int sz = 0;
+	while (!done) {
+		OFINCCLTHROW(ext_net->test(req, &done, &sz));
+	}
+	if (out_size) *out_size = sz;
+}
+
+/*
+ * Helper: poll N send requests to completion.
+ */
+static void poll_sends(test_nccl_net_t *ext_net, void **reqs, int n)
+{
+	bool all_done = false;
+	while (!all_done) {
+		all_done = true;
+		for (int i = 0; i < n; i++) {
+			if (reqs[i]) {
+				int done = 0;
+				OFINCCLTHROW(ext_net->test(reqs[i], &done, nullptr));
+				if (done) reqs[i] = nullptr;
+				else all_done = false;
+			}
+		}
+	}
+}
+
+/*
+ * Helper: poll a grouped recv request, get per-sub sizes.
+ */
+static void poll_recv(test_nccl_net_t *ext_net, void *req, int *sizes, int n)
+{
+	int done = 0;
+	memset(sizes, 0, sizeof(int) * n);
+	while (!done) {
+		OFINCCLTHROW(ext_net->test(req, &done, sizes));
+	}
+}
+
+/* ================================================================
+ * Test 5: Single recv eager — recv posted AFTER send (forces eager)
+ * ================================================================ */
+
+/*
+ * Helper: post sends using a rotating queue with per-tag ordering.
+ * Sends that return NULL are retried in subsequent rounds.
+ * Within each round, once a tag fails, later sends with the same tag are skipped.
+ * All sends must be pre-allocated and registered before calling.
+ */
+static void post_sends_interleaved(test_nccl_net_t *ext_net, void *sComm,
+				   void **bufs, size_t *sizes, int *tags, void **mhandles,
+				   void **reqs, int count)
+{
+	std::queue<int> sendq;
+	for (int i = 0; i < count; i++) sendq.push(i);
+	while (!sendq.empty()) {
+		std::set<int> blocked_tags;
+		int round_size = (int)sendq.size();
+		for (int r = 0; r < round_size; r++) {
+			int idx = sendq.front();
+			sendq.pop();
+			if (blocked_tags.count(tags[idx])) {
+				sendq.push(idx);
+				continue;
+			}
+			void *req = nullptr;
+			ext_net->isend(sComm, bufs[idx], sizes[idx],
+				       tags[idx], mhandles[idx], nullptr, &req);
+			if (req) {
+				reqs[idx] = req;
+			} else {
+				blocked_tags.insert(tags[idx]);
+				sendq.push(idx);
+			}
+		}
+	}
+	poll_sends(ext_net, reqs, count);
+}
+
+/*
+ * T5: Single recv eager (late recv)
+ * Rank 0 sends 1 small (1024B) message with tag=1.
+ * Rank 1 waits 10ms (so send goes eager), then posts single recv.
+ * Validates data integrity and correct size.
+ * Tests: basic eager path where data arrives before recv is posted.
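+ * Note: the 10ms delay makes it overwhelmingly likely that no ctrl msg has
+ * arrived when isend() runs, so the send takes the eager path; the delay is
+ * a heuristic, not a guarantee.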
+ */ +class Test5_SingleEagerLate : public TestScenario { +public: + Test5_SingleEagerLate() : TestScenario("T5: Single recv eager (late recv)") {} + void run(ThreadContext &ctx) override { + for (size_t d = 0; d < ctx.lcomms.size(); d++) { + void *sComm = ctx.scomms[d], *rComm = ctx.rcomms[d]; + int btype = NCCL_PTR_HOST; + void *sbuf = nullptr, *rbuf = nullptr; + void *smh = nullptr, *rmh = nullptr; + OFINCCLTHROW(allocate_buff(&sbuf, EAGER_SIZE, btype)); + OFINCCLTHROW(allocate_buff(&rbuf, EAGER_SIZE, btype)); + OFINCCLTHROW(initialize_buff(sbuf, EAGER_SIZE, btype, 'E')); + OFINCCLTHROW(ext_net->regMr(sComm, sbuf, EAGER_SIZE, btype, &smh)); + OFINCCLTHROW(ext_net->regMr(rComm, rbuf, EAGER_SIZE, btype, &rmh)); + + if (ctx.rank == 0) { + void *req = nullptr; + post_send(ext_net, sComm, sbuf, EAGER_SIZE, 1, smh, &req); + poll_one(ext_net, req, nullptr); + } else { + /* Small delay so send goes eager */ + usleep(10000); + void *req = nullptr; + size_t sz = EAGER_SIZE; int tag = 1; + post_recv(ext_net, rComm, 1, &rbuf, &sz, &tag, &rmh, &req); + int rsz = 0; + poll_one(ext_net, req, &rsz); + if (rsz != (int)EAGER_SIZE) + throw std::runtime_error("T5: wrong size"); + char *exp = nullptr; + OFINCCLTHROW(allocate_buff((void**)&exp, EAGER_SIZE, btype)); + OFINCCLTHROW(initialize_buff(exp, EAGER_SIZE, btype, 'E')); + OFINCCLTHROW(validate_data((char*)rbuf, exp, EAGER_SIZE, btype)); + deallocate_buffer(exp, btype); + } + ext_net->deregMr(sComm, smh); ext_net->deregMr(rComm, rmh); + deallocate_buffer(sbuf, btype); deallocate_buffer(rbuf, btype); + } + } +}; + +/* ================================================================ + * Test 6: Single recv eager — recv posted BEFORE send + * ================================================================ */ +/* + * T6: Single recv eager (early recv) + * Rank 1 posts single recv first, then signals rank 0 via MPI barrier. + * Rank 0 sends 1 small (1024B) message with tag=1. + * Tests: eager path where recv is posted before data arrives. 
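+ * Note: because the recv (and therefore its ctrl msg) is posted before the
+ * send, the sender may already see the ctrl msg at isend() time and take
+ * the ctrl-driven path instead; the test races the two paths and must
+ * produce correct data either way.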
+ */ +class Test6_SingleEagerEarly : public TestScenario { +public: + Test6_SingleEagerEarly() : TestScenario("T6: Single recv eager (early recv)") {} + void run(ThreadContext &ctx) override { + for (size_t d = 0; d < ctx.lcomms.size(); d++) { + void *sComm = ctx.scomms[d], *rComm = ctx.rcomms[d]; + int btype = NCCL_PTR_HOST; + void *sbuf = nullptr, *rbuf = nullptr, *smh = nullptr, *rmh = nullptr; + OFINCCLTHROW(allocate_buff(&sbuf, EAGER_SIZE, btype)); + OFINCCLTHROW(allocate_buff(&rbuf, EAGER_SIZE, btype)); + OFINCCLTHROW(initialize_buff(sbuf, EAGER_SIZE, btype, 'F')); + OFINCCLTHROW(ext_net->regMr(sComm, sbuf, EAGER_SIZE, btype, &smh)); + OFINCCLTHROW(ext_net->regMr(rComm, rbuf, EAGER_SIZE, btype, &rmh)); + + if (ctx.rank == 1) { + void *req = nullptr; + size_t sz = EAGER_SIZE; int tag = 1; + post_recv(ext_net, rComm, 1, &rbuf, &sz, &tag, &rmh, &req); + /* Signal rank 0 to send */ + MPI_Barrier(ctx.thread_comm); + int rsz = 0; + poll_one(ext_net, req, &rsz); + if (rsz != (int)EAGER_SIZE) + throw std::runtime_error("T6: wrong size"); + } else { + MPI_Barrier(ctx.thread_comm); + void *req = nullptr; + post_send(ext_net, sComm, sbuf, EAGER_SIZE, 1, smh, &req); + poll_one(ext_net, req, nullptr); + } + ext_net->deregMr(sComm, smh); ext_net->deregMr(rComm, rmh); + deallocate_buffer(sbuf, btype); deallocate_buffer(rbuf, btype); + } + } +}; + +/* ================================================================ + * Test 7: Multiple sequential single-recv eager messages + * ================================================================ */ +/* + * T7: 4 sequential single-recv eager messages + * Rank 0 sends 4 small messages back-to-back (tags=1, patterns 'A'-'D'). + * Rank 1 waits 20ms, then posts 4 single recvs sequentially. + * Tests: eager_offset 0-3, sender queue drain across multiple ctrl msgs. 
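+ * Since eager sends do not advance next_msg_seq_num, all four sends form a
+ * single batch sharing one msg_seq_num; each arriving ctrl msg drains one
+ * entry from the queue front and advances the sequence.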
+ */ +class Test7_MultiSeqEager : public TestScenario { +public: + Test7_MultiSeqEager() : TestScenario("T7: 4 sequential single-recv eager") {} + void run(ThreadContext &ctx) override { + constexpr int N = 4; + for (size_t d = 0; d < ctx.lcomms.size(); d++) { + void *sComm = ctx.scomms[d], *rComm = ctx.rcomms[d]; + int btype = NCCL_PTR_HOST; + if (ctx.rank == 0) { + for (int i = 0; i < N; i++) { + void *buf = nullptr, *mh = nullptr, *req = nullptr; + OFINCCLTHROW(allocate_buff(&buf, EAGER_SIZE, btype)); + OFINCCLTHROW(initialize_buff(buf, EAGER_SIZE, btype, 'A' + i)); + OFINCCLTHROW(ext_net->regMr(sComm, buf, EAGER_SIZE, btype, &mh)); + post_send(ext_net, sComm, buf, EAGER_SIZE, 1, mh, &req); + poll_one(ext_net, req, nullptr); + ext_net->deregMr(sComm, mh); + deallocate_buffer(buf, btype); + } + } else { + usleep(20000); /* Let all sends go eager */ + for (int i = 0; i < N; i++) { + void *buf = nullptr, *mh = nullptr, *req = nullptr; + OFINCCLTHROW(allocate_buff(&buf, EAGER_SIZE, btype)); + OFINCCLTHROW(ext_net->regMr(rComm, buf, EAGER_SIZE, btype, &mh)); + size_t sz = EAGER_SIZE; int tag = 1; + post_recv(ext_net, rComm, 1, &buf, &sz, &tag, &mh, &req); + int rsz = 0; + poll_one(ext_net, req, &rsz); + if (rsz != (int)EAGER_SIZE) + throw std::runtime_error("T7: wrong size at " + std::to_string(i)); + char *exp = nullptr; + OFINCCLTHROW(allocate_buff((void**)&exp, EAGER_SIZE, btype)); + OFINCCLTHROW(initialize_buff(exp, EAGER_SIZE, btype, 'A' + i)); + OFINCCLTHROW(validate_data((char*)buf, exp, EAGER_SIZE, btype)); + deallocate_buffer(exp, btype); + ext_net->deregMr(rComm, mh); + deallocate_buffer(buf, btype); + } + } + } + } +}; + +/* ================================================================ + * Test 8: Grouped recv with all eager + * ================================================================ */ +/* + * T8: Grouped recv (n=2) all eager + * Rank 0 sends 2 small messages with tags 10, 11. + * Rank 1 waits 20ms, then posts one grouped irecv(n=2, tags=[10,11]). + * Tests: multi-recv eager routing by tag, per-sub size reporting. 
+ */ +class Test8_GroupedAllEager : public TestScenario { +public: + Test8_GroupedAllEager() : TestScenario("T8: Grouped recv (n=2) all eager") {} + void run(ThreadContext &ctx) override { + test_nccl_properties_t props = {}; + OFINCCLTHROW(ext_net->getProperties(0, &props)); + if (props.maxRecvs < 2) return; + + for (size_t d = 0; d < ctx.lcomms.size(); d++) { + void *sComm = ctx.scomms[d], *rComm = ctx.rcomms[d]; + int btype = NCCL_PTR_HOST; + if (ctx.rank == 0) { + for (int i = 0; i < 2; i++) { + void *buf = nullptr, *mh = nullptr, *req = nullptr; + OFINCCLTHROW(allocate_buff(&buf, EAGER_SIZE, btype)); + OFINCCLTHROW(initialize_buff(buf, EAGER_SIZE, btype, 'P' + i)); + OFINCCLTHROW(ext_net->regMr(sComm, buf, EAGER_SIZE, btype, &mh)); + post_send(ext_net, sComm, buf, EAGER_SIZE, 10 + i, mh, &req); + poll_one(ext_net, req, nullptr); + ext_net->deregMr(sComm, mh); + deallocate_buffer(buf, btype); + } + } else { + usleep(20000); + void *rbufs[2] = {}, *rmh[2] = {}; + size_t sizes[2]; int tags[2]; + for (int i = 0; i < 2; i++) { + OFINCCLTHROW(allocate_buff(&rbufs[i], EAGER_SIZE, btype)); + OFINCCLTHROW(ext_net->regMr(rComm, rbufs[i], EAGER_SIZE, btype, &rmh[i])); + sizes[i] = EAGER_SIZE; + tags[i] = 10 + i; + } + void *req = nullptr; + post_recv(ext_net, rComm, 2, rbufs, sizes, tags, rmh, &req); + int rsizes[2] = {}; + poll_recv(ext_net, req, rsizes, 2); + for (int i = 0; i < 2; i++) { + if (rsizes[i] != (int)EAGER_SIZE) + throw std::runtime_error("T8: wrong size sub " + std::to_string(i)); + char *exp = nullptr; + OFINCCLTHROW(allocate_buff((void**)&exp, EAGER_SIZE, btype)); + OFINCCLTHROW(initialize_buff(exp, EAGER_SIZE, btype, 'P' + i)); + OFINCCLTHROW(validate_data((char*)rbufs[i], exp, EAGER_SIZE, btype)); + deallocate_buffer(exp, btype); + ext_net->deregMr(rComm, rmh[i]); + deallocate_buffer(rbufs[i], btype); + } + } + } + } +}; + +/* ================================================================ + * Test 9: Grouped recv mixed eager + RDMA write + * ================================================================ */ +/* + * T9: Grouped recv (n=3) mixed eager + RDMA write + * Rank 0 sends: 1 small (eager, tag=20) + 2 large (write, tags=21,22). + * Rank 1 posts grouped irecv(n=3, tags=[20,21,22]). + * Tests: mixed eager + write completion in same grouped recv. 
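+ * The small send can land in the receiver's bounce buffer before the
+ * grouped recv is posted, while the two large sends must wait for the
+ * group's ctrl msg; the grouped request completes only once all three
+ * sub-recvs are satisfied, regardless of arrival order.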
+ */ +class Test9_GroupedMixed : public TestScenario { +public: + Test9_GroupedMixed() : TestScenario("T9: Grouped recv (n=3) mixed eager+write") {} + void run(ThreadContext &ctx) override { + test_nccl_properties_t props = {}; + OFINCCLTHROW(ext_net->getProperties(0, &props)); + if (props.maxRecvs < 3) return; + + for (size_t d = 0; d < ctx.lcomms.size(); d++) { + void *sComm = ctx.scomms[d], *rComm = ctx.rcomms[d]; + int btype = NCCL_PTR_HOST; + /* Sub 0: eager, Sub 1: large write, Sub 2: large write */ + size_t send_sizes[3] = {EAGER_SIZE, LARGE_SIZE, LARGE_SIZE}; + if (ctx.rank == 0) { + for (int i = 0; i < 3; i++) { + void *buf = nullptr, *mh = nullptr, *req = nullptr; + OFINCCLTHROW(allocate_buff(&buf, send_sizes[i], btype)); + OFINCCLTHROW(initialize_buff(buf, send_sizes[i], btype, 'M' + i)); + OFINCCLTHROW(ext_net->regMr(sComm, buf, send_sizes[i], btype, &mh)); + post_send(ext_net, sComm, buf, send_sizes[i], 20 + i, mh, &req); + poll_one(ext_net, req, nullptr); + ext_net->deregMr(sComm, mh); + deallocate_buffer(buf, btype); + } + } else { + void *rbufs[3] = {}, *rmh[3] = {}; + size_t sizes[3]; int tags[3]; + for (int i = 0; i < 3; i++) { + sizes[i] = send_sizes[i]; + tags[i] = 20 + i; + OFINCCLTHROW(allocate_buff(&rbufs[i], sizes[i], btype)); + OFINCCLTHROW(ext_net->regMr(rComm, rbufs[i], sizes[i], btype, &rmh[i])); + } + void *req = nullptr; + post_recv(ext_net, rComm, 3, rbufs, sizes, tags, rmh, &req); + int rsizes[3] = {}; + poll_recv(ext_net, req, rsizes, 3); + for (int i = 0; i < 3; i++) { + if (rsizes[i] != (int)send_sizes[i]) + throw std::runtime_error("T9: wrong size sub " + std::to_string(i)); + char *exp = nullptr; + OFINCCLTHROW(allocate_buff((void**)&exp, send_sizes[i], btype)); + OFINCCLTHROW(initialize_buff(exp, send_sizes[i], btype, 'M' + i)); + OFINCCLTHROW(validate_data((char*)rbufs[i], exp, send_sizes[i], btype)); + deallocate_buffer(exp, btype); + ext_net->deregMr(rComm, rmh[i]); + deallocate_buffer(rbufs[i], btype); + } + } + } + } +}; + +/* ================================================================ + * Test 10: Grouped recv, no eager (large messages, regression) + * ================================================================ */ +/* + * T10: Grouped recv (n=2) all large (no eager, regression) + * Rank 0 sends 2 large (16KB) messages with tags 30, 31. + * Rank 1 posts grouped irecv(n=2, tags=[30,31]). + * Tests: non-eager grouped recv still works with new code. 
+ */ +class Test10_GroupedNoEager : public TestScenario { +public: + Test10_GroupedNoEager() : TestScenario("T10: Grouped recv (n=2) all large (no eager)") {} + void run(ThreadContext &ctx) override { + test_nccl_properties_t props = {}; + OFINCCLTHROW(ext_net->getProperties(0, &props)); + if (props.maxRecvs < 2) return; + + for (size_t d = 0; d < ctx.lcomms.size(); d++) { + void *sComm = ctx.scomms[d], *rComm = ctx.rcomms[d]; + int btype = NCCL_PTR_HOST; + if (ctx.rank == 0) { + for (int i = 0; i < 2; i++) { + void *buf = nullptr, *mh = nullptr, *req = nullptr; + OFINCCLTHROW(allocate_buff(&buf, LARGE_SIZE, btype)); + OFINCCLTHROW(initialize_buff(buf, LARGE_SIZE, btype, 'L' + i)); + OFINCCLTHROW(ext_net->regMr(sComm, buf, LARGE_SIZE, btype, &mh)); + post_send(ext_net, sComm, buf, LARGE_SIZE, 30 + i, mh, &req); + poll_one(ext_net, req, nullptr); + ext_net->deregMr(sComm, mh); + deallocate_buffer(buf, btype); + } + } else { + void *rbufs[2] = {}, *rmh[2] = {}; + size_t sizes[2]; int tags[2]; + for (int i = 0; i < 2; i++) { + sizes[i] = LARGE_SIZE; tags[i] = 30 + i; + OFINCCLTHROW(allocate_buff(&rbufs[i], LARGE_SIZE, btype)); + OFINCCLTHROW(ext_net->regMr(rComm, rbufs[i], LARGE_SIZE, btype, &rmh[i])); + } + void *req = nullptr; + post_recv(ext_net, rComm, 2, rbufs, sizes, tags, rmh, &req); + int rsizes[2] = {}; + poll_recv(ext_net, req, rsizes, 2); + for (int i = 0; i < 2; i++) { + if (rsizes[i] != (int)LARGE_SIZE) + throw std::runtime_error("T10: wrong size"); + ext_net->deregMr(rComm, rmh[i]); + deallocate_buffer(rbufs[i], btype); + } + } + } + } +}; + +/* ================================================================ + * Test 11: Eager queue ordering across single + grouped + * 3 eager sends: offset 0 → single recv, offsets 1,2 → grouped(n=2) + * ================================================================ */ +/* + * T11: Eager across single + grouped recv + * Rank 0 sends 3 small messages: tag=1, tag=40, tag=41. + * Rank 1 waits 30ms, then posts: single recv (tag=1), grouped recv (n=2, tags=[40,41]). + * Eager offsets 0->single, 1,2->grouped. Tests cross-recv-type eager resolution. 
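+ * Note: the receiver posts both recvs before polling either (see the
+ * in-line comment below); polling the single recv to completion first
+ * could deadlock, because the remaining eager entries on the sender only
+ * resolve once the grouped recv's ctrl msg has been issued.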
+ */ +class Test11_EagerAcrossSingleGrouped : public TestScenario { +public: + Test11_EagerAcrossSingleGrouped() + : TestScenario("T11: Eager across single + grouped recv") {} + void run(ThreadContext &ctx) override { + test_nccl_properties_t props = {}; + OFINCCLTHROW(ext_net->getProperties(0, &props)); + if (props.maxRecvs < 2) return; + + for (size_t d = 0; d < ctx.lcomms.size(); d++) { + void *sComm = ctx.scomms[d], *rComm = ctx.rcomms[d]; + int btype = NCCL_PTR_HOST; + if (ctx.rank == 0) { + /* Send 3 small messages: tag 1, tag 40, tag 41 */ + int stags[3] = {1, 40, 41}; + for (int i = 0; i < 3; i++) { + void *buf = nullptr, *mh = nullptr, *req = nullptr; + OFINCCLTHROW(allocate_buff(&buf, EAGER_SIZE, btype)); + OFINCCLTHROW(initialize_buff(buf, EAGER_SIZE, btype, 'S' + i)); + OFINCCLTHROW(ext_net->regMr(sComm, buf, EAGER_SIZE, btype, &mh)); + post_send(ext_net, sComm, buf, EAGER_SIZE, stags[i], mh, &req); + poll_one(ext_net, req, nullptr); + ext_net->deregMr(sComm, mh); + deallocate_buffer(buf, btype); + } + } else { + usleep(30000); + /* Post single recv (tag 1) */ + void *rbuf0 = nullptr, *rmh0 = nullptr, *req0 = nullptr; + OFINCCLTHROW(allocate_buff(&rbuf0, EAGER_SIZE, btype)); + OFINCCLTHROW(ext_net->regMr(rComm, rbuf0, EAGER_SIZE, btype, &rmh0)); + size_t sz = EAGER_SIZE; int tag = 1; + post_recv(ext_net, rComm, 1, &rbuf0, &sz, &tag, &rmh0, &req0); + + /* Post grouped recv (n=2, tags 40,41) before polling first recv + * to avoid deadlock from interleaved eager sends */ + void *rbufs[2] = {}, *rmh[2] = {}; + size_t sizes[2]; int tags[2] = {40, 41}; + for (int i = 0; i < 2; i++) { + sizes[i] = EAGER_SIZE; + OFINCCLTHROW(allocate_buff(&rbufs[i], EAGER_SIZE, btype)); + OFINCCLTHROW(ext_net->regMr(rComm, rbufs[i], EAGER_SIZE, btype, &rmh[i])); + } + void *req1 = nullptr; + post_recv(ext_net, rComm, 2, rbufs, sizes, tags, rmh, &req1); + + /* Now poll both */ + int rsz = 0; + poll_one(ext_net, req0, &rsz); + int rsizes[2] = {}; + poll_recv(ext_net, req1, rsizes, 2); + + /* Validate single recv */ + char *exp = nullptr; + OFINCCLTHROW(allocate_buff((void**)&exp, EAGER_SIZE, btype)); + OFINCCLTHROW(initialize_buff(exp, EAGER_SIZE, btype, 'S')); + OFINCCLTHROW(validate_data((char*)rbuf0, exp, EAGER_SIZE, btype)); + deallocate_buffer(exp, btype); + /* Validate grouped recv */ + for (int i = 0; i < 2; i++) { + OFINCCLTHROW(allocate_buff((void**)&exp, EAGER_SIZE, btype)); + OFINCCLTHROW(initialize_buff(exp, EAGER_SIZE, btype, 'S' + 1 + i)); + OFINCCLTHROW(validate_data((char*)rbufs[i], exp, EAGER_SIZE, btype)); + deallocate_buffer(exp, btype); + ext_net->deregMr(rComm, rmh[i]); + deallocate_buffer(rbufs[i], btype); + } + ext_net->deregMr(rComm, rmh0); + deallocate_buffer(rbuf0, btype); + } + } + } +}; + +/* ================================================================ + * Test 12: Eager with tag mismatch pushback across two groups + * Sends: tags [B, D, A, C]. Groups: [B,D] then [A,C]. + * ================================================================ */ +/* + * T12: Eager tag pushback across two groups + * Rank 0 sends 4 small messages with tags [51,53,50,52] (B,D,A,C order). + * Rank 1 waits 30ms, posts: grouped(n=2, tags=[51,53]), grouped(n=2, tags=[50,52]). + * First group matches B,D; A,C are pushed back for second group. + * Tests: tag mismatch handling and re-insertion in sorted queue. 
+ */
+class Test12_TagPushback : public TestScenario {
+public:
+	Test12_TagPushback() : TestScenario("T12: Eager tag pushback across groups") {}
+	void run(ThreadContext &ctx) override {
+		test_nccl_properties_t props = {};
+		OFINCCLTHROW(ext_net->getProperties(0, &props));
+		if (props.maxRecvs < 2) return;
+
+		for (size_t d = 0; d < ctx.lcomms.size(); d++) {
+			void *sComm = ctx.scomms[d], *rComm = ctx.rcomms[d];
+			int btype = NCCL_PTR_HOST;
+			int stags[4] = {51, 53, 50, 52}; /* B, D, A, C */
+			if (ctx.rank == 0) {
+				for (int i = 0; i < 4; i++) {
+					void *buf = nullptr, *mh = nullptr, *req = nullptr;
+					OFINCCLTHROW(allocate_buff(&buf, EAGER_SIZE, btype));
+					OFINCCLTHROW(initialize_buff(buf, EAGER_SIZE, btype, 'a' + i));
+					OFINCCLTHROW(ext_net->regMr(sComm, buf, EAGER_SIZE, btype, &mh));
+					post_send(ext_net, sComm, buf, EAGER_SIZE, stags[i], mh, &req);
+					poll_one(ext_net, req, nullptr);
+					ext_net->deregMr(sComm, mh);
+					deallocate_buffer(buf, btype);
+				}
+			} else {
+				usleep(30000);
+				/* Group 1: tags [B=51, D=53] */
+				void *rbufs1[2] = {}, *rmh1[2] = {};
+				size_t sizes1[2]; int tags1[2] = {51, 53};
+				for (int i = 0; i < 2; i++) {
+					sizes1[i] = EAGER_SIZE;
+					OFINCCLTHROW(allocate_buff(&rbufs1[i], EAGER_SIZE, btype));
+					OFINCCLTHROW(ext_net->regMr(rComm, rbufs1[i], EAGER_SIZE, btype, &rmh1[i]));
+				}
+				void *req1 = nullptr;
+				post_recv(ext_net, rComm, 2, rbufs1, sizes1, tags1, rmh1, &req1);
+
+				/* Group 2: tags [A=50, C=52] */
+				void *rbufs2[2] = {}, *rmh2[2] = {};
+				size_t sizes2[2]; int tags2[2] = {50, 52};
+				for (int i = 0; i < 2; i++) {
+					sizes2[i] = EAGER_SIZE;
+					OFINCCLTHROW(allocate_buff(&rbufs2[i], EAGER_SIZE, btype));
+					OFINCCLTHROW(ext_net->regMr(rComm, rbufs2[i], EAGER_SIZE, btype, &rmh2[i]));
+				}
+				void *req2 = nullptr;
+				post_recv(ext_net, rComm, 2, rbufs2, sizes2, tags2, rmh2, &req2);
+
+				int rs1[2] = {}, rs2[2] = {};
+				poll_recv(ext_net, req1, rs1, 2);
+				poll_recv(ext_net, req2, rs2, 2);
+
+				/* Group1[B]=send0('a'), Group1[D]=send1('b') */
+				char *exp = nullptr;
+				OFINCCLTHROW(allocate_buff((void**)&exp, EAGER_SIZE, btype));
+				OFINCCLTHROW(initialize_buff(exp, EAGER_SIZE, btype, 'a'));
+				OFINCCLTHROW(validate_data((char*)rbufs1[0], exp, EAGER_SIZE, btype));
+				OFINCCLTHROW(initialize_buff(exp, EAGER_SIZE, btype, 'b'));
+				OFINCCLTHROW(validate_data((char*)rbufs1[1], exp, EAGER_SIZE, btype));
+				/* Group2[A]=send2('c'), Group2[C]=send3('d') */
+				OFINCCLTHROW(initialize_buff(exp, EAGER_SIZE, btype, 'c'));
+				OFINCCLTHROW(validate_data((char*)rbufs2[0], exp, EAGER_SIZE, btype));
+				OFINCCLTHROW(initialize_buff(exp, EAGER_SIZE, btype, 'd'));
+				OFINCCLTHROW(validate_data((char*)rbufs2[1], exp, EAGER_SIZE, btype));
+				deallocate_buffer(exp, btype);
+
+				for (int i = 0; i < 2; i++) {
+					ext_net->deregMr(rComm, rmh1[i]); deallocate_buffer(rbufs1[i], btype);
+					ext_net->deregMr(rComm, rmh2[i]); deallocate_buffer(rbufs2[i], btype);
+				}
+			}
+		}
+	}
+};
+
+/* ================================================================
+ * Test 13: Eager queue full (32 messages)
+ * ================================================================ */
+/*
+ * T13: 32 eager messages (queue full)
+ * Rank 0 sends 32 small messages with tag=1, all at once.
+ * Rank 1 waits 50ms, then posts 32 single recvs sequentially.
+ * Each message has distinct data pattern ('0'+i).
+ * Tests: sender eager queue at max capacity, drain across 32 ctrl msgs.
+ */ +class Test13_QueueFull : public TestScenario { +public: + Test13_QueueFull() : TestScenario("T13: 32 eager messages (queue full)") {} + void run(ThreadContext &ctx) override { + constexpr int N = 32; + for (size_t d = 0; d < ctx.lcomms.size(); d++) { + void *sComm = ctx.scomms[d], *rComm = ctx.rcomms[d]; + int btype = NCCL_PTR_HOST; + if (ctx.rank == 0) { + void *reqs[N] = {}; + void *bufs[N] = {}, *mhs[N] = {}; + for (int i = 0; i < N; i++) { + OFINCCLTHROW(allocate_buff(&bufs[i], EAGER_SIZE, btype)); + OFINCCLTHROW(initialize_buff(bufs[i], EAGER_SIZE, btype, '0' + i)); + OFINCCLTHROW(ext_net->regMr(sComm, bufs[i], EAGER_SIZE, btype, &mhs[i])); + post_send(ext_net, sComm, bufs[i], EAGER_SIZE, 1, mhs[i], &reqs[i]); + } + poll_sends(ext_net, reqs, N); + for (int i = 0; i < N; i++) { + ext_net->deregMr(sComm, mhs[i]); + deallocate_buffer(bufs[i], btype); + } + } else { + void *bufs[N] = {}, *mhs[N] = {}, *exps[N] = {}; + for (int i = 0; i < N; i++) { + OFINCCLTHROW(allocate_buff(&bufs[i], EAGER_SIZE, btype)); + OFINCCLTHROW(ext_net->regMr(rComm, bufs[i], EAGER_SIZE, btype, &mhs[i])); + OFINCCLTHROW(allocate_buff(&exps[i], EAGER_SIZE, btype)); + OFINCCLTHROW(initialize_buff(exps[i], EAGER_SIZE, btype, '0' + i)); + } + usleep(50000); + for (int i = 0; i < N; i++) { + void *req = nullptr; + size_t sz = EAGER_SIZE; int tag = 1; + post_recv(ext_net, rComm, 1, &bufs[i], &sz, &tag, &mhs[i], &req); + int rsz = 0; + poll_one(ext_net, req, &rsz); + if (rsz != (int)EAGER_SIZE) + throw std::runtime_error("T13: wrong size at " + std::to_string(i)); + OFINCCLTHROW(validate_data((char*)bufs[i], (char*)exps[i], EAGER_SIZE, btype)); + } + for (int i = 0; i < N; i++) { + ext_net->deregMr(rComm, mhs[i]); + deallocate_buffer(bufs[i], btype); + deallocate_buffer(exps[i], btype); + } + } + } + } +}; + +/* ================================================================ + * Test 14: Eager size boundary + * ================================================================ */ +/* + * T14: Eager size boundary + * Tests two message sizes: 8184B (should go eager: 8184+8=8192) and + * 8185B (should NOT go eager: 8185+8=8193 > 8192). + * Validates correctness at both sizes. Cannot directly verify eager vs write + * path, but ensures no corruption at the boundary. + */ +class Test14_SizeBoundary : public TestScenario { +public: + Test14_SizeBoundary() : TestScenario("T14: Eager size boundary") {} + void run(ThreadContext &ctx) override { + /* We can't directly check if eager was used, but we verify + * correctness at the boundary. The eager threshold is + * eager_send_size; with 8B header, max payload = eager_send_size - 8. + * Default eager_send_size = 8192, so max eager payload = 8184. 
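+		 * Boundary arithmetic: 8184 + 8 = 8192 <= 8192 stays eligible
+		 * for eager, while 8185 + 8 = 8193 > 8192 falls back to the
+		 * write path (assuming the default threshold is in effect).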
*/ + size_t fits = 8184; /* Should go eager */ + size_t no_fit = 8185; /* Should NOT go eager (8185 + 8 = 8193 > 8192) */ + + for (size_t d = 0; d < ctx.lcomms.size(); d++) { + void *sComm = ctx.scomms[d], *rComm = ctx.rcomms[d]; + int btype = NCCL_PTR_HOST; + size_t test_sizes[2] = {fits, no_fit}; + for (int t = 0; t < 2; t++) { + size_t sz = test_sizes[t]; + if (ctx.rank == 0) { + void *buf = nullptr, *mh = nullptr, *req = nullptr; + OFINCCLTHROW(allocate_buff(&buf, sz, btype)); + OFINCCLTHROW(initialize_buff(buf, sz, btype, 'B' + t)); + OFINCCLTHROW(ext_net->regMr(sComm, buf, sz, btype, &mh)); + post_send(ext_net, sComm, buf, sz, 1, mh, &req); + poll_one(ext_net, req, nullptr); + ext_net->deregMr(sComm, mh); + deallocate_buffer(buf, btype); + } else { + void *buf = nullptr, *mh = nullptr, *req = nullptr; + OFINCCLTHROW(allocate_buff(&buf, sz, btype)); + OFINCCLTHROW(ext_net->regMr(rComm, buf, sz, btype, &mh)); + int tag = 1; + post_recv(ext_net, rComm, 1, &buf, &sz, &tag, &mh, &req); + int rsz = 0; + poll_one(ext_net, req, &rsz); + if (rsz != (int)sz) + throw std::runtime_error("T14: wrong size for " + std::to_string(sz)); + ext_net->deregMr(rComm, mh); + deallocate_buffer(buf, btype); + } + MPI_Barrier(ctx.thread_comm); + } + } + } +}; + +/* ================================================================ + * Test 15: Two grouped recvs (n=4), all eager/write permutations + * T0: eager/eager, T1: eager/write, T2: write/eager, T3: write/write + * ================================================================ */ +/* + * T15: Two grouped recvs (n=4), all eager/write permutations per tag + * Tags [60,61,62,63]. Two groups A and B, each n=4 with same tags. + * Per-tag pattern across groups: + * Tag 60: eager/eager, Tag 61: eager/write, Tag 62: write/eager, Tag 63: write/write + * Rank 0 sends 8 messages (4 per group) with appropriate sizes. + * Tests: every combination of eager vs write for same tag across consecutive groups. + */ +class Test15_PermutationEagerWrite : public TestScenario { +public: + Test15_PermutationEagerWrite() + : TestScenario("T15: 2x grouped(n=4) eager/write permutations") {} + void run(ThreadContext &ctx) override { + test_nccl_properties_t props = {}; + OFINCCLTHROW(ext_net->getProperties(0, &props)); + if (props.maxRecvs < 4) return; + + constexpr int N = 4; + int base_tag = 60; + /* Per-group, per-tag: is it eager (small) or write (large)? */ + /* Group A: T0=eager, T1=eager, T2=write, T3=write */ + /* Group B: T0=eager, T1=write, T2=eager, T3=write */ + bool is_eager[2][N] = { + {true, true, false, false}, /* Group A */ + {true, false, true, false}, /* Group B */ + }; + + for (size_t d = 0; d < ctx.lcomms.size(); d++) { + void *sComm = ctx.scomms[d], *rComm = ctx.rcomms[d]; + int btype = NCCL_PTR_HOST; + + if (ctx.rank == 0) { + /* Send group A's 4 messages, then group B's 4 */ + for (int g = 0; g < 2; g++) { + for (int i = 0; i < N; i++) { + size_t sz = is_eager[g][i] ? 
EAGER_SIZE : LARGE_SIZE;
+						char pattern = 'A' + g * N + i;
+						void *buf = nullptr, *mh = nullptr, *req = nullptr;
+						OFINCCLTHROW(allocate_buff(&buf, sz, btype));
+						OFINCCLTHROW(initialize_buff(buf, sz, btype, pattern));
+						OFINCCLTHROW(ext_net->regMr(sComm, buf, sz, btype, &mh));
+						post_send(ext_net, sComm, buf, sz, base_tag + i, mh, &req);
+						poll_one(ext_net, req, nullptr);
+						ext_net->deregMr(sComm, mh);
+						deallocate_buffer(buf, btype);
+					}
+				}
+			} else {
+				/* Sleep to make sure all sends that can be eager are sent */
+				usleep(50000);
+				/* Post two grouped recvs */
+				for (int g = 0; g < 2; g++) {
+					void *rbufs[N] = {}, *rmh[N] = {};
+					size_t sizes[N]; int tags[N];
+					for (int i = 0; i < N; i++) {
+						sizes[i] = is_eager[g][i] ? EAGER_SIZE : LARGE_SIZE;
+						tags[i] = base_tag + i;
+						OFINCCLTHROW(allocate_buff(&rbufs[i], sizes[i], btype));
+						OFINCCLTHROW(ext_net->regMr(rComm, rbufs[i], sizes[i], btype, &rmh[i]));
+					}
+					void *req = nullptr;
+					post_recv(ext_net, rComm, N, rbufs, sizes, tags, rmh, &req);
+					int rsizes[N] = {};
+					poll_recv(ext_net, req, rsizes, N);
+					for (int i = 0; i < N; i++) {
+						size_t expected_sz = is_eager[g][i] ? EAGER_SIZE : LARGE_SIZE;
+						if (rsizes[i] != (int)expected_sz)
+							throw std::runtime_error(
+								"T15: wrong size g=" + std::to_string(g) +
+								" i=" + std::to_string(i));
+						char pattern = 'A' + g * N + i;
+						char *exp = nullptr;
+						OFINCCLTHROW(allocate_buff((void**)&exp, expected_sz, btype));
+						OFINCCLTHROW(initialize_buff(exp, expected_sz, btype, pattern));
+						OFINCCLTHROW(validate_data((char*)rbufs[i], exp, expected_sz, btype));
+						deallocate_buffer(exp, btype);
+						ext_net->deregMr(rComm, rmh[i]);
+						deallocate_buffer(rbufs[i], btype);
+					}
+				}
+			}
+		}
+	}
+};
+
+/* ================================================================
+ * Test 16: 4x grouped(n=2) same tags, verify in-order per tag
+ * Tags [X=70, Y=71] repeated 4 times. Patterns must arrive in order.
+ * ================================================================ */
+/*
+ * T16: 4x grouped(n=2) same tags, verify in-order delivery per tag
+ * Tags [70,71] repeated across 4 grouped recvs.
+ * Rank 0 sends 8 small messages (one per tag per group), alternating tags
+ * 70 and 71, with distinct patterns.
+ * Rank 1 posts 4 grouped recvs (n=2, tags=[70,71]).
+ * Validates patterns arrive in correct order within each tag across groups.
+ * Tests: ordering guarantee when same tags repeat across multiple grouped recvs.
+ */
+class Test16_OrderingPerTag : public TestScenario {
+public:
+	Test16_OrderingPerTag()
+		: TestScenario("T16: 4x grouped(n=2) same tags, ordering per tag") {}
+	void run(ThreadContext &ctx) override {
+		test_nccl_properties_t props = {};
+		OFINCCLTHROW(ext_net->getProperties(0, &props));
+		if (props.maxRecvs < 2) return;
+
+		constexpr int NGROUPS = 4;
+		int tag_x = 70, tag_y = 71;
+
+		for (size_t d = 0; d < ctx.lcomms.size(); d++) {
+			void *sComm = ctx.scomms[d], *rComm = ctx.rcomms[d];
+			int btype = NCCL_PTR_HOST;
+
+			if (ctx.rank == 0) {
+				/* Send 8 messages: alternating tag X, tag Y */
+				for (int g = 0; g < NGROUPS; g++) {
+					for (int sub = 0; sub < 2; sub++) {
+						int tag = (sub == 0) ?
tag_x : tag_y; + char pattern = '0' + g * 2 + sub; + void *buf = nullptr, *mh = nullptr, *req = nullptr; + OFINCCLTHROW(allocate_buff(&buf, EAGER_SIZE, btype)); + OFINCCLTHROW(initialize_buff(buf, EAGER_SIZE, btype, pattern)); + OFINCCLTHROW(ext_net->regMr(sComm, buf, EAGER_SIZE, btype, &mh)); + post_send(ext_net, sComm, buf, EAGER_SIZE, tag, mh, &req); + poll_one(ext_net, req, nullptr); + ext_net->deregMr(sComm, mh); + deallocate_buffer(buf, btype); + } + } + } else { + /* Post 4 grouped recvs, each n=2 with tags [X, Y] */ + void *reqs[NGROUPS] = {}; + void *rbufs[NGROUPS][2] = {}, *rmh[NGROUPS][2] = {}; + for (int g = 0; g < NGROUPS; g++) { + size_t sizes[2] = {EAGER_SIZE, EAGER_SIZE}; + int tags[2] = {tag_x, tag_y}; + void *bufs[2], *mhs[2]; + for (int i = 0; i < 2; i++) { + OFINCCLTHROW(allocate_buff(&rbufs[g][i], EAGER_SIZE, btype)); + OFINCCLTHROW(ext_net->regMr(rComm, rbufs[g][i], EAGER_SIZE, btype, &rmh[g][i])); + bufs[i] = rbufs[g][i]; + mhs[i] = rmh[g][i]; + } + post_recv(ext_net, rComm, 2, bufs, sizes, tags, mhs, &reqs[g]); + } + + /* Poll all to completion */ + for (int g = 0; g < NGROUPS; g++) { + int rsizes[2] = {}; + poll_recv(ext_net, reqs[g], rsizes, 2); + } + + /* Validate ordering: group g should have patterns g*2, g*2+1 */ + for (int g = 0; g < NGROUPS; g++) { + for (int sub = 0; sub < 2; sub++) { + char expected_pattern = '0' + g * 2 + sub; + char *exp = nullptr; + OFINCCLTHROW(allocate_buff((void**)&exp, EAGER_SIZE, btype)); + OFINCCLTHROW(initialize_buff(exp, EAGER_SIZE, btype, expected_pattern)); + OFINCCLTHROW(validate_data((char*)rbufs[g][sub], exp, EAGER_SIZE, btype)); + deallocate_buffer(exp, btype); + ext_net->deregMr(rComm, rmh[g][sub]); + deallocate_buffer(rbufs[g][sub], btype); + } + } + } + } + } +}; + + +/* ================================================================ + * T17: Interleaved sends across two grouped recvs (all writes) + * Two grouped recvs (n=8, tags 0-7). Sender sends tags in order: + * 0,1,2,3,4,5,6, 1(msg2), 7(msg1), 0(msg2),2,3,4,5,6,7 + * The second message's tag=1 arrives before the first message's tag=7. + * All messages are large (write path, no eager). 
+ * ================================================================ */
+class Test17_InterleavedAllWrite : public TestScenario {
+public:
+	Test17_InterleavedAllWrite()
+		: TestScenario("T17: Interleaved sends, 2x grouped(n=8), all write") {}
+	void run(ThreadContext &ctx) override {
+		test_nccl_properties_t props = {};
+		OFINCCLTHROW(ext_net->getProperties(0, &props));
+		if (props.maxRecvs < 8) return;
+
+		for (size_t d = 0; d < 1 && d < ctx.lcomms.size(); d++) {
+			void *sComm = ctx.scomms[d], *rComm = ctx.rcomms[d];
+			int btype = NCCL_PTR_HOST;
+			constexpr int N = 8;
+			/* Send order: msg1 tags 0-6, msg2 tag 1, msg1 tag 7, msg2 tags 0,2-7 */
+			struct { int msg; int tag; } send_order[] = {
+				{0,0},{0,1},{0,2},{0,3},{0,4},{0,5},{0,6},
+				{1,1},	/* msg2 tag 1 before msg1 tag 7 */
+				{0,7},	/* msg1 tag 7 */
+				{1,0},{1,2},{1,3},{1,4},{1,5},{1,6},{1,7}
+			};
+			constexpr int total_sends = 16;
+
+			if (ctx.rank == 0) {
+				const int TOTAL = total_sends;
+				void *sbufs[16] = {}, *smhs[16] = {}, *sreqs[16] = {};
+				for (int i = 0; i < TOTAL; i++) {
+					int tag = send_order[i].tag;
+					char pattern = 'A' + send_order[i].msg * N + tag;
+					OFINCCLTHROW(allocate_buff(&sbufs[i], LARGE_SIZE, btype));
+					OFINCCLTHROW(initialize_buff(sbufs[i], LARGE_SIZE, btype, pattern));
+					OFINCCLTHROW(ext_net->regMr(sComm, sbufs[i], LARGE_SIZE, btype, &smhs[i]));
+				}
+				int stags_arr[TOTAL];
+				for (int i = 0; i < TOTAL; i++) stags_arr[i] = send_order[i].tag;
+				size_t ssizes_arr[TOTAL];
+				for (int i = 0; i < TOTAL; i++) ssizes_arr[i] = LARGE_SIZE;
+				post_sends_interleaved(ext_net, sComm,
+						       sbufs, ssizes_arr, stags_arr, smhs, sreqs, TOTAL);
+				for (int i = 0; i < TOTAL; i++) {
+					ext_net->deregMr(sComm, smhs[i]);
+					deallocate_buffer(sbufs[i], btype);
+				}
+			} else {
+				/* Post all recvs before polling to avoid deadlock from interleaved sends */
+				void *all_rbufs[2][8] = {};
+				void *all_rmh[2][8] = {};
+				void *all_reqs[2] = {};
+				for (int g = 0; g < 2; g++) {
+					size_t sizes[8]; int tags[8];
+					for (int i = 0; i < 8; i++) {
+						sizes[i] = LARGE_SIZE; tags[i] = i;
+						OFINCCLTHROW(allocate_buff(&all_rbufs[g][i], LARGE_SIZE, btype));
+						OFINCCLTHROW(ext_net->regMr(rComm, all_rbufs[g][i], LARGE_SIZE, btype, &all_rmh[g][i]));
+					}
+					post_recv(ext_net, rComm, 8, all_rbufs[g], sizes, tags, all_rmh[g], &all_reqs[g]);
+				}
+				for (int g = 0; g < 2; g++) {
+					int rsizes[8] = {};
+					poll_recv(ext_net, all_reqs[g], rsizes, 8);
+					for (int i = 0; i < 8; i++) {
+						char pattern = 'A' + g * 8 + i;
+						char *exp = nullptr;
+						OFINCCLTHROW(allocate_buff((void**)&exp, LARGE_SIZE, btype));
+						OFINCCLTHROW(initialize_buff(exp, LARGE_SIZE, btype, pattern));
+						OFINCCLTHROW(validate_data((char*)all_rbufs[g][i], exp, LARGE_SIZE, btype));
+						deallocate_buffer(exp, btype);
+						ext_net->deregMr(rComm, all_rmh[g][i]);
+						deallocate_buffer(all_rbufs[g][i], btype);
+					}
+				}
+			}
+		}
+	}
+};
+
+/* ================================================================
+ * T18: Interleaved sends across two grouped recvs (all eager)
+ * Same interleaving as T17 but with small messages (eager path).
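+ * With everything eager, the early m2t1 send reaches the drain while group
+ * 1's tag-1 ctrl entry is already consumed by m1t1, so it is pushed back to
+ * the queue tail and matched only once group 2's ctrl msg arrives.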
+ * ================================================================ */ +class Test18_InterleavedAllEager : public TestScenario { +public: + Test18_InterleavedAllEager() + : TestScenario("T18: Interleaved sends, 2x grouped(n=8), all eager") {} + void run(ThreadContext &ctx) override { + test_nccl_properties_t props = {}; + OFINCCLTHROW(ext_net->getProperties(0, &props)); + if (props.maxRecvs < 8) return; + + for (size_t d = 0; d < 1 && d < ctx.lcomms.size(); d++) { + void *sComm = ctx.scomms[d], *rComm = ctx.rcomms[d]; + int btype = NCCL_PTR_HOST; + constexpr int N = 8; + struct { int msg; int tag; } send_order[] = { + {0,0},{0,1},{0,2},{0,3},{0,4},{0,5},{0,6}, + {1,1},{0,7},{1,0},{1,2},{1,3},{1,4},{1,5},{1,6},{1,7} + }; + constexpr int total_sends = 16; + + if (ctx.rank == 0) { + const int TOTAL = total_sends; + void *sbufs[16] = {}, *smhs[16] = {}, *sreqs[16] = {}; + for (int i = 0; i < TOTAL; i++) { + int tag = send_order[i].tag; + char pattern = 'a' + send_order[i].msg * N + tag; + OFINCCLTHROW(allocate_buff(&sbufs[i], EAGER_SIZE, btype)); + OFINCCLTHROW(initialize_buff(sbufs[i], EAGER_SIZE, btype, pattern)); + OFINCCLTHROW(ext_net->regMr(sComm, sbufs[i], EAGER_SIZE, btype, &smhs[i])); + } + int stags_arr[TOTAL]; + for (int i = 0; i < TOTAL; i++) stags_arr[i] = send_order[i].tag; + size_t ssizes_arr[TOTAL]; + for (int i = 0; i < TOTAL; i++) ssizes_arr[i] = EAGER_SIZE; + post_sends_interleaved(ext_net, sComm, + sbufs, ssizes_arr, stags_arr, smhs, sreqs, TOTAL); + for (int i = 0; i < TOTAL; i++) { + ext_net->deregMr(sComm, smhs[i]); + deallocate_buffer(sbufs[i], btype); + } + } else { + usleep(30000); + /* Post all recvs before polling to avoid deadlock from interleaved eager sends */ + void *all_rbufs[2][8] = {}; + void *all_rmh[2][8] = {}; + void *all_reqs[2] = {}; + for (int g = 0; g < 2; g++) { + size_t sizes[8]; int tags[8]; + for (int i = 0; i < 8; i++) { + sizes[i] = EAGER_SIZE; tags[i] = i; + OFINCCLTHROW(allocate_buff(&all_rbufs[g][i], EAGER_SIZE, btype)); + OFINCCLTHROW(ext_net->regMr(rComm, all_rbufs[g][i], EAGER_SIZE, btype, &all_rmh[g][i])); + } + post_recv(ext_net, rComm, 8, all_rbufs[g], sizes, tags, all_rmh[g], &all_reqs[g]); + } + for (int g = 0; g < 2; g++) { + int rsizes[8] = {}; + poll_recv(ext_net, all_reqs[g], rsizes, 8); + for (int i = 0; i < 8; i++) { + char pattern = 'a' + g * 8 + i; + char *exp = nullptr; + OFINCCLTHROW(allocate_buff((void**)&exp, EAGER_SIZE, btype)); + OFINCCLTHROW(initialize_buff(exp, EAGER_SIZE, btype, pattern)); + OFINCCLTHROW(validate_data((char*)all_rbufs[g][i], exp, EAGER_SIZE, btype)); + deallocate_buffer(exp, btype); + ext_net->deregMr(rComm, all_rmh[g][i]); + deallocate_buffer(all_rbufs[g][i], btype); + } + } + } + } + } +}; + +/* ================================================================ + * T19: Interleaved sends, mixed eager+write within each group + * Two grouped recvs (n=4, tags 0-3). Tags 0,1 are eager, tags 2,3 are write. 
+ * Send order: msg1(0e,1e,2w,3w) interleaved with msg2 starting at tag 1:
+ * 0e(m1), 1e(m1), 2w(m1), 1e(m2), 3w(m1), 0e(m2), 2w(m2), 3w(m2)
+ * ================================================================ */
+class Test19_InterleavedMixed : public TestScenario {
+public:
+	Test19_InterleavedMixed()
+		: TestScenario("T19: Interleaved sends, 2x grouped(n=4), mixed eager+write") {}
+	void run(ThreadContext &ctx) override {
+		test_nccl_properties_t props = {};
+		OFINCCLTHROW(ext_net->getProperties(0, &props));
+		if (props.maxRecvs < 4) return;
+
+		for (size_t d = 0; d < 1 && d < ctx.lcomms.size(); d++) {
+			void *sComm = ctx.scomms[d], *rComm = ctx.rcomms[d];
+			int btype = NCCL_PTR_HOST;
+			constexpr int N = 4;
+			bool is_eager[] = {true, true, false, false};
+			struct { int msg; int tag; } send_order[] = {
+				{0,0},{0,1},{0,2},
+				{1,1},	/* msg2 tag 1 (eager) before msg1 tag 3 */
+				{0,3},
+				{1,0},{1,2},{1,3}
+			};
+			constexpr int total_sends = 8;
+
+			if (ctx.rank == 0) {
+				const int TOTAL = total_sends;
+				void *sbufs[8] = {}, *smhs[8] = {}, *sreqs[8] = {};
+				size_t ssizes[8];
+				for (int i = 0; i < TOTAL; i++) {
+					int tag = send_order[i].tag;
+					size_t sz = is_eager[tag] ? EAGER_SIZE : LARGE_SIZE;
+					char pattern = 'A' + send_order[i].msg * N + tag;
+					ssizes[i] = sz;
+					OFINCCLTHROW(allocate_buff(&sbufs[i], sz, btype));
+					OFINCCLTHROW(initialize_buff(sbufs[i], sz, btype, pattern));
+					OFINCCLTHROW(ext_net->regMr(sComm, sbufs[i], sz, btype, &smhs[i]));
+				}
+				int stags_arr[TOTAL];
+				for (int i = 0; i < TOTAL; i++) stags_arr[i] = send_order[i].tag;
+				post_sends_interleaved(ext_net, sComm,
+						       sbufs, ssizes, stags_arr, smhs, sreqs, TOTAL);
+				for (int i = 0; i < TOTAL; i++) {
+					ext_net->deregMr(sComm, smhs[i]);
+					deallocate_buffer(sbufs[i], btype);
+				}
+			} else {
+				/* Post all recvs before polling to avoid deadlock from interleaved eager sends.
+				 * Buffer sizes mirror the sender: tags 0,1 are eager-sized, tags 2,3 are large. */
+				void *all_rbufs[2][8] = {};
+				void *all_rmh[2][8] = {};
+				void *all_reqs[2] = {};
+				for (int g = 0; g < 2; g++) {
+					size_t sizes[4]; int tags[4];
+					for (int i = 0; i < 4; i++) {
+						sizes[i] = is_eager[i] ? EAGER_SIZE : LARGE_SIZE;
+						tags[i] = i;
+						OFINCCLTHROW(allocate_buff(&all_rbufs[g][i], sizes[i], btype));
+						OFINCCLTHROW(ext_net->regMr(rComm, all_rbufs[g][i], sizes[i], btype, &all_rmh[g][i]));
+					}
+					post_recv(ext_net, rComm, 4, all_rbufs[g], sizes, tags, all_rmh[g], &all_reqs[g]);
+				}
+				for (int g = 0; g < 2; g++) {
+					int rsizes[4] = {};
+					poll_recv(ext_net, all_reqs[g], rsizes, 4);
+					for (int i = 0; i < 4; i++) {
+						size_t expected_sz = is_eager[i] ? EAGER_SIZE : LARGE_SIZE;
+						char pattern = 'A' + g * 4 + i;
+						char *exp = nullptr;
+						OFINCCLTHROW(allocate_buff((void**)&exp, expected_sz, btype));
+						OFINCCLTHROW(initialize_buff(exp, expected_sz, btype, pattern));
+						OFINCCLTHROW(validate_data((char*)all_rbufs[g][i], exp, expected_sz, btype));
+						deallocate_buffer(exp, btype);
+						ext_net->deregMr(rComm, all_rmh[g][i]);
+						deallocate_buffer(all_rbufs[g][i], btype);
+					}
+				}
+			}
+		}
+	}
+};
+
+/* ================================================================
+ * T21: Three grouped recvs interleaved, same tags, different sizes
+ * 3x grouped(n=2, tags [0,1]). Msg sizes: group0=EAGER, group1=LARGE, group2=EAGER.
+ * Sender interleaves: g0t0, g0t1, g1t0, g2t0, g1t1, g2t1
+ * Tests interleaving across 3 groups with mixed eager/write.
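+ * Group 1's large sends cannot go eager, so their seq advance happens on
+ * the non-eager send path while eager entries from groups 0 and 2 may
+ * still be queued, exercising the prev_eager_msg_seq_num /
+ * prev_eager_batch_count batch-boundary stamping.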
+ * ================================================================ */
+class Test21_ThreeGroupsInterleaved : public TestScenario {
+public:
+	Test21_ThreeGroupsInterleaved()
+		: TestScenario("T21: 3x grouped(n=2) interleaved, mixed sizes") {}
+	void run(ThreadContext &ctx) override {
+		test_nccl_properties_t props = {};
+		OFINCCLTHROW(ext_net->getProperties(0, &props));
+		if (props.maxRecvs < 2) return;
+
+		for (size_t d = 0; d < 1 && d < ctx.lcomms.size(); d++) {
+			void *sComm = ctx.scomms[d], *rComm = ctx.rcomms[d];
+			int btype = NCCL_PTR_HOST;
+			size_t group_sizes[] = {EAGER_SIZE, LARGE_SIZE, EAGER_SIZE};
+			struct { int grp; int tag; } send_order[] = {
+				{0,0},{0,1},{1,0},{2,0},{1,1},{2,1}
+			};
+			constexpr int total_sends = 6;
+
+			if (ctx.rank == 0) {
+				const int TOTAL = total_sends;
+				void *sbufs[6] = {}, *smhs[6] = {}, *sreqs[6] = {};
+				size_t ssizes[6];
+				for (int i = 0; i < TOTAL; i++) {
+					int g = send_order[i].grp;
+					int tag = send_order[i].tag;
+					size_t sz = group_sizes[g];
+					char pattern = 'A' + g * 2 + tag;
+					ssizes[i] = sz;
+					OFINCCLTHROW(allocate_buff(&sbufs[i], sz, btype));
+					OFINCCLTHROW(initialize_buff(sbufs[i], sz, btype, pattern));
+					OFINCCLTHROW(ext_net->regMr(sComm, sbufs[i], sz, btype, &smhs[i]));
+				}
+				int stags_arr[TOTAL];
+				for (int i = 0; i < TOTAL; i++) stags_arr[i] = send_order[i].tag;
+				post_sends_interleaved(ext_net, sComm,
+						       sbufs, ssizes, stags_arr, smhs, sreqs, TOTAL);
+				for (int i = 0; i < TOTAL; i++) {
+					ext_net->deregMr(sComm, smhs[i]);
+					deallocate_buffer(sbufs[i], btype);
+				}
+			} else {
+				/* Post all recvs before polling to avoid deadlock from interleaved eager sends.
+				 * Each group's buffers mirror that group's send size. */
+				void *all_rbufs[3][8] = {};
+				void *all_rmh[3][8] = {};
+				void *all_reqs[3] = {};
+				for (int g = 0; g < 3; g++) {
+					size_t sizes[2]; int tags[2];
+					for (int i = 0; i < 2; i++) {
+						sizes[i] = group_sizes[g]; tags[i] = i;
+						OFINCCLTHROW(allocate_buff(&all_rbufs[g][i], group_sizes[g], btype));
+						OFINCCLTHROW(ext_net->regMr(rComm, all_rbufs[g][i], group_sizes[g], btype, &all_rmh[g][i]));
+					}
+					post_recv(ext_net, rComm, 2, all_rbufs[g], sizes, tags, all_rmh[g], &all_reqs[g]);
+				}
+				for (int g = 0; g < 3; g++) {
+					int rsizes[2] = {};
+					poll_recv(ext_net, all_reqs[g], rsizes, 2);
+					for (int i = 0; i < 2; i++) {
+						char pattern = 'A' + g * 2 + i;
+						char *exp = nullptr;
+						OFINCCLTHROW(allocate_buff((void**)&exp, group_sizes[g], btype));
+						OFINCCLTHROW(initialize_buff(exp, group_sizes[g], btype, pattern));
+						OFINCCLTHROW(validate_data((char*)all_rbufs[g][i], exp, group_sizes[g], btype));
+						deallocate_buffer(exp, btype);
+						ext_net->deregMr(rComm, all_rmh[g][i]);
+						deallocate_buffer(all_rbufs[g][i], btype);
+					}
+				}
+			}
+		}
+	}
+};
+
+/* ================================================================
+ * T22: Heavy interleaving: 2x grouped(n=8), every other send from msg2
+ * Send order: m1t0, m2t0, m1t1, m2t1, m1t2, m2t2, ... m1t7, m2t7
+ * Maximum interleaving — every send alternates between the two groups.
+ * All writes (large).
+ * ================================================================ */
+class Test22_MaxInterleaveWrite : public TestScenario {
+public:
+	Test22_MaxInterleaveWrite()
+		: TestScenario("T22: Max interleave 2x grouped(n=8), all write") {}
+	void run(ThreadContext &ctx) override {
+		test_nccl_properties_t props = {};
+		OFINCCLTHROW(ext_net->getProperties(0, &props));
+		if (props.maxRecvs < 8) return;
+
+		for (size_t d = 0; d < 1 && d < ctx.lcomms.size(); d++) {
+			void *sComm = ctx.scomms[d], *rComm = ctx.rcomms[d];
+			int btype = NCCL_PTR_HOST;
+			constexpr int N = 8;
+
+			if (ctx.rank == 0) {
+				constexpr int TOTAL = 16;
+				void *sbufs[TOTAL] = {}, *smhs[TOTAL] = {}, *sreqs[TOTAL] = {};
+				int stags[TOTAL];
+				int si = 0;
+				for (int tag = 0; tag < N; tag++) {
+					for (int g = 0; g < 2; g++) {
+						char pattern = 'A' + g * N + tag;
+						stags[si] = tag;
+						OFINCCLTHROW(allocate_buff(&sbufs[si], LARGE_SIZE, btype));
+						OFINCCLTHROW(initialize_buff(sbufs[si], LARGE_SIZE, btype, pattern));
+						OFINCCLTHROW(ext_net->regMr(sComm, sbufs[si], LARGE_SIZE, btype, &smhs[si]));
+						si++;
+					}
+				}
+				size_t ssizes_arr[TOTAL];
+				for (int i = 0; i < TOTAL; i++) ssizes_arr[i] = LARGE_SIZE;
+				post_sends_interleaved(ext_net, sComm,
+						       sbufs, ssizes_arr, stags, smhs, sreqs, TOTAL);
+				for (int i = 0; i < TOTAL; i++) {
+					ext_net->deregMr(sComm, smhs[i]);
+					deallocate_buffer(sbufs[i], btype);
+				}
+			} else {
+				/* Post all recvs before polling to avoid deadlock from interleaved sends */
+				void *all_rbufs[2][8] = {};
+				void *all_rmh[2][8] = {};
+				void *all_reqs[2] = {};
+				for (int g = 0; g < 2; g++) {
+					size_t sizes[8]; int tags[8];
+					for (int i = 0; i < 8; i++) {
+						sizes[i] = LARGE_SIZE; tags[i] = i;
+						OFINCCLTHROW(allocate_buff(&all_rbufs[g][i], LARGE_SIZE, btype));
+						OFINCCLTHROW(ext_net->regMr(rComm, all_rbufs[g][i], LARGE_SIZE, btype, &all_rmh[g][i]));
+					}
+					post_recv(ext_net, rComm, 8, all_rbufs[g], sizes, tags, all_rmh[g], &all_reqs[g]);
+				}
+				for (int g = 0; g < 2; g++) {
+					int rsizes[8] = {};
+					poll_recv(ext_net, all_reqs[g], rsizes, 8);
+					for (int i = 0; i < 8; i++) {
+						char pattern = 'A' + g * 8 + i;
+						char *exp = nullptr;
+						OFINCCLTHROW(allocate_buff((void**)&exp, LARGE_SIZE, btype));
+						OFINCCLTHROW(initialize_buff(exp, LARGE_SIZE, btype, pattern));
+						OFINCCLTHROW(validate_data((char*)all_rbufs[g][i], exp, LARGE_SIZE, btype));
+						deallocate_buffer(exp, btype);
+						ext_net->deregMr(rComm, all_rmh[g][i]);
+						deallocate_buffer(all_rbufs[g][i], btype);
+					}
+				}
+			}
+		}
+	}
+};
+
+/* ================================================================
+ * T23: Eager spanning multiple grouped recvs
+ * Rank 0 sends 12 small messages: tags [1,2,3,4, 10,11,12,13, 20,21,22,23].
+ * Rank 1 waits 50ms, then posts 3 grouped recvs (n=4 each):
+ *   group 1: tags [1,2,3,4]
+ *   group 2: tags [10,11,12,13]
+ *   group 3: tags [20,21,22,23]
+ * Tests: eager batch spanning across 3 grouped receives with 4 sub-recvs each.
+ * ================================================================ */ +class Test23_EagerSpanMultiGroups : public TestScenario { +public: + Test23_EagerSpanMultiGroups() + : TestScenario("T23: Eager spanning 3 grouped recvs (n=4)") {} + void run(ThreadContext &ctx) override { + test_nccl_properties_t props = {}; + OFINCCLTHROW(ext_net->getProperties(0, &props)); + if (props.maxRecvs < 4) return; + + for (size_t d = 0; d < ctx.lcomms.size(); d++) { + void *sComm = ctx.scomms[d], *rComm = ctx.rcomms[d]; + int btype = NCCL_PTR_HOST; + constexpr int NGROUPS = 3, NSUB = 4; + constexpr int TOTAL = NGROUPS * NSUB; + int stags[TOTAL] = {1,2,3,4, 10,11,12,13, 20,21,22,23}; + + if (ctx.rank == 0) { + void *reqs[TOTAL] = {}; + void *bufs[TOTAL] = {}, *mhs[TOTAL] = {}; + for (int i = 0; i < TOTAL; i++) { + OFINCCLTHROW(allocate_buff(&bufs[i], EAGER_SIZE, btype)); + OFINCCLTHROW(initialize_buff(bufs[i], EAGER_SIZE, btype, 'A' + i)); + OFINCCLTHROW(ext_net->regMr(sComm, bufs[i], EAGER_SIZE, btype, &mhs[i])); + post_send(ext_net, sComm, bufs[i], EAGER_SIZE, stags[i], mhs[i], &reqs[i]); + } + poll_sends(ext_net, reqs, TOTAL); + for (int i = 0; i < TOTAL; i++) { + ext_net->deregMr(sComm, mhs[i]); + deallocate_buffer(bufs[i], btype); + } + } else { + usleep(50000); + /* Post all grouped recvs before polling to avoid deadlock + * from interleaved eager sends */ + void *all_rbufs[NGROUPS][NSUB] = {}; + void *all_rmh[NGROUPS][NSUB] = {}; + void *reqs[NGROUPS] = {}; + for (int g = 0; g < NGROUPS; g++) { + size_t sizes[NSUB]; int tags[NSUB]; + for (int i = 0; i < NSUB; i++) { + sizes[i] = EAGER_SIZE; + tags[i] = stags[g * NSUB + i]; + OFINCCLTHROW(allocate_buff(&all_rbufs[g][i], EAGER_SIZE, btype)); + OFINCCLTHROW(ext_net->regMr(rComm, all_rbufs[g][i], EAGER_SIZE, btype, &all_rmh[g][i])); + } + post_recv(ext_net, rComm, NSUB, all_rbufs[g], sizes, tags, all_rmh[g], &reqs[g]); + } + /* Now poll and validate all */ + for (int g = 0; g < NGROUPS; g++) { + int rsizes[NSUB] = {}; + poll_recv(ext_net, reqs[g], rsizes, NSUB); + for (int i = 0; i < NSUB; i++) { + if (rsizes[i] != (int)EAGER_SIZE) + throw std::runtime_error("T23: wrong size group " + std::to_string(g) + " sub " + std::to_string(i)); + char *exp = nullptr; + OFINCCLTHROW(allocate_buff((void**)&exp, EAGER_SIZE, btype)); + OFINCCLTHROW(initialize_buff(exp, EAGER_SIZE, btype, 'A' + g * NSUB + i)); + OFINCCLTHROW(validate_data((char*)all_rbufs[g][i], exp, EAGER_SIZE, btype)); + deallocate_buffer(exp, btype); + ext_net->deregMr(rComm, all_rmh[g][i]); + deallocate_buffer(all_rbufs[g][i], btype); + } + } + } + } + } +}; + +/* ================================================================ + * T24: Eager spanning single + grouped + single recvs + * Rank 0 sends 6 small messages: tags [1, 50,51, 1, 60,61]. + * Rank 1 waits 50ms, then posts: + * single recv (tag=1) + * grouped recv (n=2, tags=[50,51]) + * single recv (tag=1) + * grouped recv (n=2, tags=[60,61]) + * Tests: eager batch spanning alternating single and grouped recvs. 
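+ * The two tag=1 messages are indistinguishable by tag alone; the receiver
+ * relies on strict (msg_seq_num, eager_offset) processing order to route
+ * the first one to the first single recv and the second to the third recv.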
+ * ================================================================ */ +class Test24_EagerAlternatingSingleGrouped : public TestScenario { +public: + Test24_EagerAlternatingSingleGrouped() + : TestScenario("T24: Eager alternating single+grouped recvs") {} + void run(ThreadContext &ctx) override { + test_nccl_properties_t props = {}; + OFINCCLTHROW(ext_net->getProperties(0, &props)); + if (props.maxRecvs < 2) return; + + for (size_t d = 0; d < ctx.lcomms.size(); d++) { + void *sComm = ctx.scomms[d], *rComm = ctx.rcomms[d]; + int btype = NCCL_PTR_HOST; + constexpr int TOTAL = 6; + int stags[TOTAL] = {1, 50, 51, 1, 60, 61}; + + if (ctx.rank == 0) { + void *reqs[TOTAL] = {}; + void *bufs[TOTAL] = {}, *mhs[TOTAL] = {}; + for (int i = 0; i < TOTAL; i++) { + OFINCCLTHROW(allocate_buff(&bufs[i], EAGER_SIZE, btype)); + OFINCCLTHROW(initialize_buff(bufs[i], EAGER_SIZE, btype, 'P' + i)); + OFINCCLTHROW(ext_net->regMr(sComm, bufs[i], EAGER_SIZE, btype, &mhs[i])); + post_send(ext_net, sComm, bufs[i], EAGER_SIZE, stags[i], mhs[i], &reqs[i]); + } + poll_sends(ext_net, reqs, TOTAL); + for (int i = 0; i < TOTAL; i++) { + ext_net->deregMr(sComm, mhs[i]); + deallocate_buffer(bufs[i], btype); + } + } else { + usleep(50000); + /* Pattern: single, grouped(2), single, grouped(2) */ + struct { int n; int tags[2]; } recvs[4] = { + {1, {1, 0}}, + {2, {50, 51}}, + {1, {1, 0}}, + {2, {60, 61}}, + }; + /* Post all recvs before polling to avoid deadlock + * from interleaved eager sends */ + void *all_rbufs[4][2] = {}; + void *all_rmh[4][2] = {}; + void *all_reqs[4] = {}; + for (int r = 0; r < 4; r++) { + int n = recvs[r].n; + size_t sizes[2]; int tags[2]; + for (int i = 0; i < n; i++) { + sizes[i] = EAGER_SIZE; + tags[i] = recvs[r].tags[i]; + OFINCCLTHROW(allocate_buff(&all_rbufs[r][i], EAGER_SIZE, btype)); + OFINCCLTHROW(ext_net->regMr(rComm, all_rbufs[r][i], EAGER_SIZE, btype, &all_rmh[r][i])); + } + post_recv(ext_net, rComm, n, all_rbufs[r], sizes, tags, all_rmh[r], &all_reqs[r]); + } + /* Now poll and validate all */ + int send_idx = 0; + for (int r = 0; r < 4; r++) { + int n = recvs[r].n; + int rsizes[2] = {}; + poll_recv(ext_net, all_reqs[r], rsizes, n); + for (int i = 0; i < n; i++) { + if (rsizes[i] != (int)EAGER_SIZE) + throw std::runtime_error("T24: wrong size recv " + std::to_string(r) + " sub " + std::to_string(i)); + char *exp = nullptr; + OFINCCLTHROW(allocate_buff((void**)&exp, EAGER_SIZE, btype)); + OFINCCLTHROW(initialize_buff(exp, EAGER_SIZE, btype, 'P' + send_idx + i)); + OFINCCLTHROW(validate_data((char*)all_rbufs[r][i], exp, EAGER_SIZE, btype)); + deallocate_buffer(exp, btype); + ext_net->deregMr(rComm, all_rmh[r][i]); + deallocate_buffer(all_rbufs[r][i], btype); + } + send_idx += n; + } + } + } + } +}; + + +int main(int argc, char *argv[]) +{ + TestSuite suite; + + Test5_SingleEagerLate t5; + Test6_SingleEagerEarly t6; + Test7_MultiSeqEager t7; + Test8_GroupedAllEager t8; + Test9_GroupedMixed t9; + Test10_GroupedNoEager t10; + Test11_EagerAcrossSingleGrouped t11; + Test12_TagPushback t12; + Test13_QueueFull t13; + Test14_SizeBoundary t14; + Test15_PermutationEagerWrite t15; + Test16_OrderingPerTag t16; + + suite.add(&t5); + suite.add(&t6); + suite.add(&t7); + suite.add(&t8); + suite.add(&t9); + suite.add(&t10); + suite.add(&t11); + suite.add(&t12); + suite.add(&t13); + suite.add(&t14); + suite.add(&t15); + suite.add(&t16); + + Test17_InterleavedAllWrite t17; + Test18_InterleavedAllEager t18; + Test19_InterleavedMixed t19; + Test21_ThreeGroupsInterleaved t21; + Test22_MaxInterleaveWrite t22; + + 
+	suite.add(&t17);
+	suite.add(&t18);
+	suite.add(&t19);
+	suite.add(&t21);
+	suite.add(&t22);
+
+	Test23_EagerSpanMultiGroups t23;
+	Test24_EagerAlternatingSingleGrouped t24;
+	suite.add(&t23);
+	suite.add(&t24);
+
+	return suite.run_all();
+}
diff --git a/tests/functional/functional_test.cpp b/tests/functional/functional_test.cpp
index 5cdb114026..f0dae2f44c 100644
--- a/tests/functional/functional_test.cpp
+++ b/tests/functional/functional_test.cpp
@@ -202,6 +202,14 @@ ncclResult_t validate_data(char *recv_buf, char *expected_buf, size_t size, int
 	case NCCL_PTR_HOST:
 		ret = memcmp(recv_buf, expected_buf, size);
 		if (ret != 0) {
+			// Log the first mismatching byte to make failures debuggable
+			for (size_t i = 0; i < size; i++) {
+				if (recv_buf[i] != expected_buf[i]) {
+					NCCL_OFI_WARN("Data validation failed at byte %zu: recv=0x%02x expected=0x%02x (size=%zu)",
+						      i, (unsigned char)recv_buf[i], (unsigned char)expected_buf[i], size);
+					break;
+				}
+			}
 			NCCL_OFI_WARN("Data validation check failed. RC: %d, Buffer Type: %d",
 				      ret, buffer_type);
 			return ncclSystemError;
@@ -952,6 +960,15 @@ ncclResult_t TestSuite::run_all()
 	test_nccl_net_config_t config = { .trafficClass = -1 };
 	OFINCCLCHECK(ext_net->init(&net_ctx, 0, &config, functional_test_logger, nullptr));
 
+	/* Call getProperties on all devices to initialize device-level flags
+	 * (e.g., supports_eager_header) before creating communicators. */
+	int ndev;
+	OFINCCLCHECK(ext_net->devices(&ndev));
+	for (int dev = 0; dev < ndev; dev++) {
+		test_nccl_properties_t props = {};
+		OFINCCLCHECK(ext_net->getProperties(dev, &props));
+	}
+
 	int passed = 0;
 	for (const auto &test : tests) {
 		test->set_ext_net(ext_net, net_ctx);
diff --git a/tests/unit/ctrl_msg.cpp b/tests/unit/ctrl_msg.cpp
index 1ce45007f4..dd2e6cb30d 100644
--- a/tests/unit/ctrl_msg.cpp
+++ b/tests/unit/ctrl_msg.cpp
@@ -21,6 +21,7 @@
 #include "unit_test.h"
 #include "nccl_ofi.h"
 #include "nccl_ofi_rdma.h"
+#include "nccl_ofi_dlist.h"
 
 /* Replicate the tag-matching search from get_ctrl_msg_buff_len / update_send_data_from_remote */
 static nccl_net_ofi_ctrl_msg_entry_t *find_entry_by_tag(nccl_net_ofi_ctrl_msg_t *ctrl, int tag)
@@ -211,6 +212,166 @@ static int test_max_recvs_entries()
 	return 0;
 }
 
+
+/* Replicate the sort helpers from nccl_ofi_rdma.cpp for unit testing */
+static inline bool test_seq_before(uint16_t a, uint16_t b)
+{
+	/* a precedes b iff the wrapped distance (a - b) mod 1024 exceeds half the 10-bit space */
+	uint16_t diff = (a - b) & 0x3FF;
+	return diff > 0x1FF;
+}
+
+static inline bool test_eager_entry_less(const nccl_ofi_recv_eager_entry_t *a,
+					 const nccl_ofi_recv_eager_entry_t *b)
+{
+	if (a->msg_seq_num != b->msg_seq_num)
+		return test_seq_before(a->msg_seq_num, b->msg_seq_num);
+	return a->eager_offset < b->eager_offset;
+}
+
+static inline void test_sorted_insert(nccl_ofi_dlist *list,
+				      nccl_ofi_recv_eager_entry_t *entry)
+{
+	/* Walk backward from the tail; insert after the first node not greater than entry, so equal keys keep arrival order */
+	nccl_ofi_dlist_node *pos = list->head.prev;
+	while (pos != &list->head) {
+		nccl_ofi_recv_eager_entry_t *existing =
+			nccl_ofi_dlist_entry(pos, &nccl_ofi_recv_eager_entry_t::link);
+		if (!test_eager_entry_less(entry, existing))
+			break;
+		pos = pos->prev;
+	}
+	entry->link.prev = pos;
+	entry->link.next = pos->next;
+	pos->next->prev = &entry->link;
+	pos->next = &entry->link;
+}
+
+static int collect_list(nccl_ofi_dlist *list, nccl_ofi_recv_eager_entry_t **out, int max)
+{
+	int n = 0;
+	nccl_ofi_dlist_node *pos;
+	nccl_ofi_dlist_for_each_safe(list, pos) {
+		if (n >= max) break;
+		out[n++] = nccl_ofi_dlist_entry(pos, &nccl_ofi_recv_eager_entry_t::link);
+	}
+	return n;
+}
+
+static int test_eager_sorted_insert()
+{
+	nccl_ofi_dlist list;
+	nccl_ofi_recv_eager_entry_t entries[8] = {};
+
+	/* Test 1: Insert in order */
+	entries[0].msg_seq_num = 1; entries[0].eager_offset = 0;
+	entries[1].msg_seq_num = 1; entries[1].eager_offset = 1;
+	entries[2].msg_seq_num = 2; entries[2].eager_offset = 0;
+	test_sorted_insert(&list, &entries[0]);
+	test_sorted_insert(&list, &entries[1]);
+	test_sorted_insert(&list, &entries[2]);
+
+	nccl_ofi_recv_eager_entry_t *out[8];
+	int n = collect_list(&list, out, 8);
+	if (n != 3 || out[0] != &entries[0] || out[1] != &entries[1] || out[2] != &entries[2]) {
+		NCCL_OFI_WARN("in-order insert failed");
+		return 1;
+	}
+	while (!list.empty()) list.pop_front(); /* reset the list between sub-tests */
+
+	/* Test 2: Insert in reverse order */
+	test_sorted_insert(&list, &entries[2]);
+	test_sorted_insert(&list, &entries[1]);
+	test_sorted_insert(&list, &entries[0]);
+
+	n = collect_list(&list, out, 8);
+	if (n != 3 || out[0] != &entries[0] || out[1] != &entries[1] || out[2] != &entries[2]) {
+		NCCL_OFI_WARN("reverse insert failed");
+		return 1;
+	}
+	while (!list.empty()) list.pop_front();
+
+	/* Test 3: Same seq, different offsets interleaved */
+	entries[3].msg_seq_num = 1; entries[3].eager_offset = 3;
+	entries[4].msg_seq_num = 1; entries[4].eager_offset = 1;
+	entries[5].msg_seq_num = 1; entries[5].eager_offset = 2;
+	entries[6].msg_seq_num = 1; entries[6].eager_offset = 0;
+	test_sorted_insert(&list, &entries[3]);
+	test_sorted_insert(&list, &entries[4]);
+	test_sorted_insert(&list, &entries[5]);
+	test_sorted_insert(&list, &entries[6]);
+
+	n = collect_list(&list, out, 8);
+	if (n != 4 || out[0]->eager_offset != 0 || out[1]->eager_offset != 1 ||
+	    out[2]->eager_offset != 2 || out[3]->eager_offset != 3) {
+		NCCL_OFI_WARN("interleaved offset insert failed");
+		return 1;
+	}
+	while (!list.empty()) list.pop_front();
+
+	/* Test 4: Seq wraparound (1023 before 0 in 10-bit space) */
+	entries[0].msg_seq_num = 1023; entries[0].eager_offset = 0;
+	entries[1].msg_seq_num = 0; entries[1].eager_offset = 0;
+	entries[2].msg_seq_num = 1; entries[2].eager_offset = 0;
+	test_sorted_insert(&list, &entries[2]);
+	test_sorted_insert(&list, &entries[0]);
+	test_sorted_insert(&list, &entries[1]);
+
+	n = collect_list(&list, out, 8);
+	if (n != 3 || out[0]->msg_seq_num != 1023 || out[1]->msg_seq_num != 0 || out[2]->msg_seq_num != 1) {
+		NCCL_OFI_WARN("wraparound insert failed: got seq %d, %d, %d",
+			      out[0]->msg_seq_num, out[1]->msg_seq_num, out[2]->msg_seq_num);
+		return 1;
+	}
+	while (!list.empty()) list.pop_front();
+
+	/* Test 5: Single element */
+	entries[0].msg_seq_num = 5; entries[0].eager_offset = 2;
+	test_sorted_insert(&list, &entries[0]);
+	n = collect_list(&list, out, 8);
+	if (n != 1 || out[0] != &entries[0]) {
+		NCCL_OFI_WARN("single element insert failed");
+		return 1;
+	}
+	while (!list.empty()) list.pop_front();
+
+	/* Test 6: Duplicate key, both entries should remain in the list */
+	entries[0].msg_seq_num = 3; entries[0].eager_offset = 1; entries[0].tag = 10;
+	entries[1].msg_seq_num = 3; entries[1].eager_offset = 1; entries[1].tag = 20;
+	test_sorted_insert(&list, &entries[0]);
+	test_sorted_insert(&list, &entries[1]);
+	n = collect_list(&list, out, 8);
+	if (n != 2) {
+		NCCL_OFI_WARN("duplicate key insert failed: got %d entries", n);
+		return 1;
+	}
+	while (!list.empty()) list.pop_front();
+
+	/* Test 7: Multiple seq batches interleaved */
+	entries[0].msg_seq_num = 2; entries[0].eager_offset = 1;
+	entries[1].msg_seq_num = 1; entries[1].eager_offset = 0;
+	entries[2].msg_seq_num = 2; entries[2].eager_offset = 0;
+	entries[3].msg_seq_num = 1; entries[3].eager_offset = 1;
+	test_sorted_insert(&list, &entries[0]);
+	test_sorted_insert(&list, &entries[1]);
+	test_sorted_insert(&list, &entries[2]);
+	test_sorted_insert(&list, &entries[3]);
+
+	n = collect_list(&list, out, 8);
+	if (n != 4 ||
+	    !(out[0]->msg_seq_num == 1 && out[0]->eager_offset == 0) ||
+	    !(out[1]->msg_seq_num == 1 && out[1]->eager_offset == 1) ||
+	    !(out[2]->msg_seq_num == 2 && out[2]->eager_offset == 0) ||
+	    !(out[3]->msg_seq_num == 2 && out[3]->eager_offset == 1)) {
+		NCCL_OFI_WARN("multi-batch interleaved insert failed");
+		return 1;
+	}
+	while (!list.empty()) list.pop_front();
+
+	printf("PASS: eager sorted insert\n");
+	return 0;
+}
+
 int main(int argc, char *argv[])
 {
 	unit_test_init();
@@ -220,6 +381,7 @@ int main(int argc, char *argv[])
 	rc |= test_tag_matching();
 	rc |= test_ready_bit();
 	rc |= test_max_recvs_entries();
+	rc |= test_eager_sorted_insert();
 
 	if (rc == 0)
 		printf("All ctrl_msg tests passed\n");
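
For readers tripping over the `0x1FF` threshold in `test_seq_before`: it is serial-number arithmetic over the 10-bit `msg_seq_num` space, where `a` precedes `b` exactly when the wrapped distance `(a - b) & 0x3FF` exceeds half the space. That is what makes 1023 sort before 0 in the wraparound case (Test 4). The sketch below is a standalone illustration, not part of the patch; the predicate is re-declared locally so the example compiles on its own:

```
/* Standalone sketch of the 10-bit serial-number comparison that
 * test_seq_before() replicates (illustration only, not patch code). */
#include <cassert>
#include <cstdint>
#include <cstdio>

static bool seq_before(uint16_t a, uint16_t b)
{
	/* Wrapped distance from b back to a; more than half the space
	 * away means a is "behind" b, i.e. a precedes b. */
	return ((a - b) & 0x3FF) > 0x1FF;
}

int main(void)
{
	assert(seq_before(1, 2));       /* ordinary case */
	assert(!seq_before(2, 1));
	assert(seq_before(1023, 0));    /* wraparound: 1023 precedes 0 */
	assert(seq_before(1023, 1));    /* ...and precedes 1 */
	assert(!seq_before(0, 1023));
	/* Caveat: values exactly half the space apart compare "before"
	 * each other in both directions, so the predicate is only sound
	 * while live sequence numbers span fewer than 512 values. */
	assert(seq_before(0, 512) && seq_before(512, 0));
	printf("seq_before sketch OK\n");
	return 0;
}
```

Test 4 above stays within that invariant: 1023, 0, and 1 span only three consecutive values, well under half the space.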