diff --git a/include/nccl_ofi_rdma.h b/include/nccl_ofi_rdma.h index bdcd18d5b4..a0766ee710 100644 --- a/include/nccl_ofi_rdma.h +++ b/include/nccl_ofi_rdma.h @@ -122,6 +122,9 @@ static_assert(NCCL_OFI_RDMA_MSG_MAX <= (0x10), * size to be fixed. */ typedef uint16_t nccl_ofi_rdma_msg_type_t; +/* Forward declaration for freelist MR registration callback class */ +class freelist_regmr_ep_ctx_t; + /* * @brief Rdma memory registration handle * @@ -132,25 +135,51 @@ class nccl_net_ofi_rdma_mr_handle_t : public nccl_net_ofi_mr_handle_t { public: /** * @brief Default constructor + * + * Every MR handle covers both the data ep rails and the control ep + * rails. With FI_MR_ENDPOINT a single fid_mr may only be bound to + * one fid_ep, so we keep separate arrays for ownership and access: + * + * mr_data[0..num_rails-1] — owned fid_mr objects, bound to data ep rails + * mr_ctrl_owned[0..num_ctrl_rails-1] — owned fid_mr objects, bound to ctrl ep + * rails (FI_MR_ENDPOINT mode only) + * mr_ctrl[0..num_ctrl_rails-1] — non-owning view; always populated after + * registration: + * · FI_MR_ENDPOINT: points into mr_ctrl_owned[] + * · non-endpoint-MR: aliases mr_data[] entries + * + * Using a raw-pointer view for mr_ctrl[] eliminates IO-path conditionals: + * callers always use mr_ctrl[rail_id] without checking whether it is populated. */ - nccl_net_ofi_rdma_mr_handle_t(size_t num_rails_arg) + nccl_net_ofi_rdma_mr_handle_t(size_t num_rails_arg, size_t num_ctrl_rails_arg) : nccl_net_ofi_mr_handle_t(0), num_rails(num_rails_arg), + num_ctrl_rails(num_ctrl_rails_arg), base_addr(0) { + mr_ctrl.fill(nullptr); } /** * @brief Get MR key for RDMA handle - * - * Return MR key associated with first mr array element + * + * Return MR key associated with first data mr array element */ int get_mr_key(uint64_t *mr_key_ptr) override; uint16_t num_rails; + uint16_t num_ctrl_rails; - /* Array of size `num_rails', indexed by rail_id */ - std::array mr; + /* Owned fid_mr objects, one per data ep rail */ + std::array mr_data; + + /* Owned fid_mr objects, one per ctrl ep rail (FI_MR_ENDPOINT mode only). + * In non-endpoint-MR mode entries beyond num_ctrl_rails remain null. */ + std::array mr_ctrl_owned; + + /* Non-owning view into either mr_ctrl_owned[] (FI_MR_ENDPOINT) or mr_data[] + * (non-endpoint-MR). Populated up to num_ctrl_rails after reg_mr_on_device(). */ + std::array mr_ctrl; /* Base address of the registered memory region for offset calculation */ uintptr_t base_addr; @@ -750,6 +779,10 @@ class nccl_net_ofi_rdma_recv_comm : public nccl_net_ofi_recv_comm { /* Free list to track host flush buffers, for sending flush messages */ nccl_ofi_freelist *flush_buff_fl; + /* Context shared by ctrl_buff_fl and flush_buff_fl registration callbacks. + * Freed in recv_comm_destroy(). 
*/ + freelist_regmr_ep_ctx_t *comm_buff_regmr_ctx = nullptr; + #if HAVE_NVTX_TRACING nvtxDomainHandle_t nvtx_domain[NCCL_OFI_N_NVTX_DOMAIN_PER_COMM]; #endif @@ -875,6 +908,7 @@ class nccl_net_ofi_rdma_domain_t : public nccl_net_ofi_domain_t { */ int reg_mr(nccl_ofi_mr_ckey_ref ckey, int type, + nccl_net_ofi_rdma_ep_t *ep, nccl_net_ofi_rdma_mr_handle_t **mhandle); /** @@ -930,11 +964,13 @@ class nccl_net_ofi_rdma_domain_t : public nccl_net_ofi_domain_t { */ int reg_internal_mr(void *data, size_t size, int type, + nccl_net_ofi_rdma_ep_t *ep, nccl_net_ofi_rdma_mr_handle_t **mhandle); #if HAVE_DECL_FI_MR_DMABUF int reg_internal_mr_dma_buf(void *data, int fd, uint64_t offset, size_t size, int type, + nccl_net_ofi_rdma_ep_t *ep, nccl_net_ofi_rdma_mr_handle_t **mhandle); #endif /** @@ -951,9 +987,6 @@ class nccl_net_ofi_rdma_domain_t : public nccl_net_ofi_domain_t { uint16_t num_rails; std::array domain_rails; - /* The flush buffer */ - nccl_net_ofi_rdma_flush_buffer_t flush_buff; - /* List of endpoints and set of addresses they have connections to */ nccl_ofi_ep_addr_list_t ep_addr_list; @@ -961,36 +994,39 @@ class nccl_net_ofi_rdma_domain_t : public nccl_net_ofi_domain_t { /** * @brief RDMA domain destructor. * - * Cleans up RDMA domain resources (flush buffers, ep_table). + * Cleans up RDMA domain resources (ep_table). */ ~nccl_net_ofi_rdma_domain_t() override; - /** - * @brief Allocated and registers buffer to flush RDMA operations. On - * Success, receive communicator holds reference to flush buffer - * and associated memory handle. - * - * @param dev_id - * Device ID - * - * @return 0, on success - * error, on others - */ - int alloc_and_reg_flush_buff(int dev_id); - - /** - * @brief Deregister flush buffer if flush buffer was registered. Deallocate flush buffer. - * - * @return 0, on success - * error, on others - */ - int dealloc_and_dereg_flush_buff(); - private: int reg_mr_on_device(nccl_ofi_mr_ckey_ref ckey, int type, + nccl_net_ofi_rdma_ep_t *ep, nccl_net_ofi_rdma_mr_handle_t **mhandle); + /** + * @brief Bind and enable a memory region on an endpoint + * + * Calls fi_mr_bind followed by fi_mr_enable on the given MR and endpoint. + * + * @param mr + * The memory region to bind and enable + * @param ofi_ep + * The endpoint to bind the MR to + * @param rail_id + * Rail index, used only for diagnostic messages + * @param rail_type + * Human-readable rail type string (e.g. 
"data" or "ctrl"), + * used only for diagnostic messages + * + * @return 0 on success + * non-zero on error + */ + int mr_bind_and_enable(struct fid_mr *mr, + struct fid_ep *ofi_ep, + uint16_t rail_id, + const char *rail_type); + /** * @brief Deregister memory region without acquiring memory region cache lock * @@ -1076,6 +1112,9 @@ class nccl_net_ofi_rdma_ep_rail_t { /* Mutex for rx buffer operations */ pthread_mutex_t rx_buff_mutex; + /* True if this rail posts control (ctrl) rx buffers; false for data (eager) */ + bool is_ctrl = false; + /* Allocate a receive buffer request for this rail (eager or ctrl) */ nccl_net_ofi_rdma_req* (*rx_buff_req_alloc)(nccl_net_ofi_rdma_ep_t *ep, nccl_net_ofi_rdma_ep_rail_t *rail); @@ -1159,7 +1198,6 @@ class nccl_net_ofi_rdma_ep_t : public nccl_net_ofi_ep_t { */ inline nccl_net_ofi_rdma_ep_rail_t *rdma_endpoint_get_rail(uint16_t rail_id) { - assert(!rails.empty()); assert(rail_id < num_rails); return &rails[rail_id]; } @@ -1169,7 +1207,6 @@ class nccl_net_ofi_rdma_ep_t : public nccl_net_ofi_ep_t { */ inline nccl_net_ofi_rdma_ep_rail_t *rdma_endpoint_get_control_rail(uint16_t rail_id) { - assert(!control_rails.empty()); assert(rail_id < num_control_rails); return &control_rails[rail_id]; } @@ -1179,7 +1216,6 @@ class nccl_net_ofi_rdma_ep_t : public nccl_net_ofi_ep_t { */ inline nccl_net_ofi_rdma_cq_rail_t *rdma_endpoint_get_cq_rail(uint16_t rail_id) { - assert(!cq_rails.empty()); assert(rail_id < num_rails); return &cq_rails[rail_id]; } @@ -1190,7 +1226,7 @@ class nccl_net_ofi_rdma_ep_t : public nccl_net_ofi_ep_t { */ inline ofi_cq_ptr &get_ofi_cq_for_cm() override { - assert(!cq_rails.empty()); + assert(num_rails > 0); return cq_rails[0].cq; } @@ -1316,13 +1352,13 @@ class nccl_net_ofi_rdma_ep_t : public nccl_net_ofi_ep_t { uint16_t num_control_rails; /* Array of `num_rails` endpoint rails */ - std::vector rails; + std::array rails = {}; /* Array of `num_control_rails` endpoint rails */ - std::vector control_rails; + std::array control_rails = {}; /* Array of `num_rails` cq rails */ - std::vector cq_rails; + std::array cq_rails = {}; /* Pending requests queue */ std::deque pending_reqs_queue; @@ -1335,6 +1371,23 @@ class nccl_net_ofi_rdma_ep_t : public nccl_net_ofi_ep_t { nccl_ofi_freelist *eager_rx_buff_fl = nullptr; /* Free list of rx buffer requests */ nccl_ofi_freelist *rx_buff_reqs_fl = nullptr; + /* Context passed to the freelist MR registration callback for rx buffer freelists. + * Heap-allocated freelist_regmr_ep_ctx_t; freed in fini_rx_buffers(). */ + freelist_regmr_ep_ctx_t *rx_buff_regmr_ctx = nullptr; + /* MR handle for the flush buffer used by this endpoint. + * In FI_MR_ENDPOINT mode this is a per-endpoint registration bound to + * the endpoint's own flush_buff allocation; otherwise it aliases the + * domain's shared flush buffer MR. + * Set in init_rx_buffers(); the per-endpoint registration (if any) + * is deregistered in fini_rx_buffers(). */ + nccl_net_ofi_rdma_mr_handle_t *flush_buff_mr_handle = nullptr; + /* Flush buffer for this endpoint. + * Set in init_rx_buffers(). + * In FI_MR_ENDPOINT mode: owns a per-endpoint GPU allocation + * (buffer_base / buffer); freed in fini_rx_buffers(). + * In non-endpoint-MR mode: buffer aliases the domain's + * flush_buff.buffer; buffer_base is nullptr (not owned). */ + nccl_net_ofi_rdma_flush_buffer_t flush_buff = {}; /* Size of ctrl rx buffers */ size_t ctrl_rx_buff_size; /* Size of eager rx buffers. 
Will be -1 if eager is entirely @@ -1376,6 +1429,26 @@ class nccl_net_ofi_rdma_ep_t : public nccl_net_ofi_ep_t { */ int init_rx_buffers(); + /** + * @brief Allocate and register a buffer used to flush RDMA operations. + * + * Allocates a GPU (or host, for Neuron) flush buffer into @p fb and + * registers it as an MR via the domain. In FI_MR_ENDPOINT mode the + * MR is bound and enabled on this endpoint; otherwise it is a + * domain-scoped registration. The resulting MR handle is written to + * @p mr_out. + * + * @param dev_id Device ID (used in error messages) + * @param fb Flush-buffer struct to populate + * @param mr_out Receives the registered MR handle + * + * @return 0, on success + * error, on others + */ + int alloc_and_reg_flush_buff(int dev_id, + nccl_net_ofi_rdma_flush_buffer_t &fb, + nccl_net_ofi_rdma_mr_handle_t *&mr_out); + /** * @brief Initialize libfabric resources of endpoint rails */ diff --git a/src/nccl_ofi_rdma.cpp b/src/nccl_ofi_rdma.cpp index 7b27569863..5051067ff3 100644 --- a/src/nccl_ofi_rdma.cpp +++ b/src/nccl_ofi_rdma.cpp @@ -2652,8 +2652,32 @@ int nccl_net_ofi_rdma_domain_t::dereg_mr_no_lock(nccl_net_ofi_rdma_mr_handle_t * } +int nccl_net_ofi_rdma_domain_t::mr_bind_and_enable(struct fid_mr *mr, + struct fid_ep *ofi_ep, + uint16_t rail_id, + const char *rail_type) +{ + int ret = fi_mr_bind(mr, &ofi_ep->fid, 0); + if (ret != 0) { + NCCL_OFI_WARN("fi_mr_bind (%s rail %u) failed: rc=%d, error=%s", + rail_type, rail_id, ret, fi_strerror(-ret)); + return ret; + } + + ret = fi_mr_enable(mr); + if (ret != 0) { + NCCL_OFI_WARN("fi_mr_enable failed on %s rail %u with rc=%d, error=%s", + rail_type, rail_id, ret, fi_strerror(-ret)); + return ret; + } + + return 0; +} + + int nccl_net_ofi_rdma_domain_t::reg_mr_on_device(nccl_ofi_mr_ckey_ref ckey, int type, + nccl_net_ofi_rdma_ep_t *ep, nccl_net_ofi_rdma_mr_handle_t **mhandle) { int ret = 0; @@ -2663,8 +2687,19 @@ int nccl_net_ofi_rdma_domain_t::reg_mr_on_device(nccl_ofi_mr_ckey_ref ckey, *mhandle = NULL; - /* Allocate rdma memory registration handle */ - auto *ret_handle = new nccl_net_ofi_rdma_mr_handle_t(num_rails); + /* Allocate rdma memory registration handle. + * + * In FI_MR_ENDPOINT mode (ep != nullptr) a separate fid_mr is registered + * for each ctrl ep rail, so nctrl = ep->num_control_rails. + * + * In non-endpoint-MR mode (ep == nullptr) a single domain-level fid_mr is + * shared across all rails. We still size mr_ctrl[] to num_rails and alias + * each entry to the corresponding mr_data[] entry so that IO-path callers + * can always use mr_ctrl[rail_id] unconditionally, without checking + * whether the vector is populated. */ + uint16_t nctrl = ep ? 
ep->num_control_rails + : static_cast(num_rails); + auto *ret_handle = new nccl_net_ofi_rdma_mr_handle_t(num_rails, nctrl); if (key_pool->get_size() != 0) { auto key = key_pool->allocate_id(); @@ -2684,7 +2719,7 @@ int nccl_net_ofi_rdma_domain_t::reg_mr_on_device(nccl_ofi_mr_ckey_ref ckey, goto error; } - /* Register memory on each rail */ + /* Register and bind memory on each data rail */ for (uint16_t rail_id = 0; rail_id != num_rails; ++rail_id) { nccl_net_ofi_rdma_domain_rail_t *domain_rail = this->rdma_domain_get_rail(rail_id); @@ -2692,12 +2727,53 @@ int nccl_net_ofi_rdma_domain_t::reg_mr_on_device(nccl_ofi_mr_ckey_ref ckey, &mr_attr, regattr_flags); if (OFI_UNLIKELY(mr_result.is_failure())) { - NCCL_OFI_WARN("Could not register memory on rail %u with flag %lu", + NCCL_OFI_WARN("Could not register memory on data rail %u with flag %lu", rail_id, regattr_flags); ret = mr_result.error_code; goto error; } - ret_handle->mr[rail_id] = std::move(mr_result.resource); + ret_handle->mr_data[rail_id] = std::move(mr_result.resource); + + if (ep != nullptr) { + nccl_net_ofi_rdma_ep_rail_t *ep_rail = ep->rdma_endpoint_get_rail(rail_id); + struct fid_ep *ofi_ep = ep_rail->ofi_ep.get(); + ret = mr_bind_and_enable(ret_handle->mr_data[rail_id].get(), ofi_ep, + rail_id, "data"); + if (ret != 0) goto error; + } + } + + if (ep != nullptr) { + /* FI_MR_ENDPOINT: register a separate fid_mr per ctrl rail, bind it + * to the ctrl ep, and point mr_ctrl[] at the owned objects. */ + for (uint16_t rail_id = 0; rail_id != nctrl; ++rail_id) { + nccl_net_ofi_rdma_domain_rail_t *domain_rail = this->rdma_domain_get_rail(rail_id); + + auto mr_result = nccl_ofi_ofiutils_mr_regattr(domain_rail->domain, + &mr_attr, + regattr_flags); + if (OFI_UNLIKELY(mr_result.is_failure())) { + NCCL_OFI_WARN("Could not register memory on ctrl rail %u with flag %lu", + rail_id, regattr_flags); + ret = mr_result.error_code; + goto error; + } + ret_handle->mr_ctrl_owned[rail_id] = std::move(mr_result.resource); + + nccl_net_ofi_rdma_ep_rail_t *ctrl_rail = ep->rdma_endpoint_get_control_rail(rail_id); + struct fid_ep *ctrl_ofi_ep = ctrl_rail->ofi_ep.get(); + ret = mr_bind_and_enable(ret_handle->mr_ctrl_owned[rail_id].get(), ctrl_ofi_ep, + rail_id, "ctrl"); + if (ret != 0) goto error; + + ret_handle->mr_ctrl[rail_id] = ret_handle->mr_ctrl_owned[rail_id].get(); + } + } else { + /* Non-endpoint-MR: the domain-level fid_mr is valid for all rails. + * Alias mr_ctrl[i] to mr_data[i] so IO-path callers need no conditional. */ + for (uint16_t rail_id = 0; rail_id != nctrl; ++rail_id) { + ret_handle->mr_ctrl[rail_id] = ret_handle->mr_data[rail_id].get(); + } } /* Store base address of registered memory region for offset calculations. 
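
The registration path above follows the standard libfabric FI_MR_ENDPOINT sequence: fi_mr_regattr() creates the fid_mr, fi_mr_bind() attaches it to exactly one endpoint, and fi_mr_enable() makes it usable for data transfer. A minimal standalone sketch of that sequence, outside the plugin (register_mr_for_ep, domain, ep and buf are placeholder names, and the access flags are an assumption, not what the patch requests):

#include <rdma/fabric.h>
#include <rdma/fi_domain.h>
#include <rdma/fi_endpoint.h>

/* Sketch only: register a buffer and make it usable when the provider
 * requires FI_MR_ENDPOINT. Not part of the patch; names are placeholders. */
static int register_mr_for_ep(struct fid_domain *domain, struct fid_ep *ep,
                              void *buf, size_t len, struct fid_mr **mr_out)
{
	struct fid_mr *mr = NULL;
	int rc = fi_mr_reg(domain, buf, len,
	                   FI_SEND | FI_RECV | FI_READ | FI_WRITE |
	                   FI_REMOTE_READ | FI_REMOTE_WRITE,
	                   0 /* offset */, 0 /* requested key */, 0 /* flags */,
	                   &mr, NULL);
	if (rc != 0)
		return rc;

	/* Needed only when mr_mode contains FI_MR_ENDPOINT: bind the MR to
	 * exactly one endpoint, then enable it before first use. */
	rc = fi_mr_bind(mr, &ep->fid, 0);
	if (rc == 0)
		rc = fi_mr_enable(mr);
	if (rc != 0) {
		fi_close(&mr->fid);
		return rc;
	}

	*mr_out = mr;
	return 0;
}

This is the same shape as reg_mr_on_device() plus mr_bind_and_enable() in the hunk above, repeated once per data rail and once per ctrl rail.
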
@@ -2716,6 +2792,7 @@ int nccl_net_ofi_rdma_domain_t::reg_mr_on_device(nccl_ofi_mr_ckey_ref ckey, int nccl_net_ofi_rdma_domain_t::reg_mr(nccl_ofi_mr_ckey_ref ckey, int type, + nccl_net_ofi_rdma_ep_t *ep, nccl_net_ofi_rdma_mr_handle_t **mhandle) { int ret = 0; @@ -2738,7 +2815,7 @@ int nccl_net_ofi_rdma_domain_t::reg_mr(nccl_ofi_mr_ckey_ref ckey, } /* Cache miss */ - ret = this->reg_mr_on_device(ckey, type, &ret_handle); + ret = this->reg_mr_on_device(ckey, type, ep, &ret_handle); if (OFI_UNLIKELY(ret != 0)) { return ret; } @@ -2753,7 +2830,7 @@ int nccl_net_ofi_rdma_domain_t::reg_mr(nccl_ofi_mr_ckey_ref ckey, return ret; } } else { - ret = this->reg_mr_on_device(ckey, type, &ret_handle); + ret = this->reg_mr_on_device(ckey, type, ep, &ret_handle); if (OFI_UNLIKELY(ret != 0)) { return ret; } @@ -2765,35 +2842,29 @@ int nccl_net_ofi_rdma_domain_t::reg_mr(nccl_ofi_mr_ckey_ref ckey, int nccl_net_ofi_rdma_domain_t::reg_internal_mr(void *data, - size_t size, int type, - nccl_net_ofi_rdma_mr_handle_t **mhandle) + size_t size, int type, + nccl_net_ofi_rdma_ep_t *ep, + nccl_net_ofi_rdma_mr_handle_t **mhandle) { assert(system_page_size > 0); assert(NCCL_OFI_IS_PTR_ALIGNED(data, system_page_size)); assert(NCCL_OFI_IS_ALIGNED(size, system_page_size)); - /* TODO: When the endpoint mr feature is supported for RDMA plugin - * pass the endpoint during the mr key create below. For now, we are - * passing nullptr - */ - const nccl_ofi_mr_ckey_t ckey = nccl_ofi_mr_ckey_mk_vec(data, size, nullptr); - return this->reg_mr(&ckey, type, mhandle); + const nccl_ofi_mr_ckey_t ckey = nccl_ofi_mr_ckey_mk_vec(data, size, ep); + return this->reg_mr(&ckey, type, ep, mhandle); } #if HAVE_DECL_FI_MR_DMABUF int nccl_net_ofi_rdma_domain_t::reg_internal_mr_dma_buf(void *data, - int fd, uint64_t offset, size_t size, int type, - nccl_net_ofi_rdma_mr_handle_t **mhandle) + int fd, uint64_t offset, size_t size, int type, + nccl_net_ofi_rdma_ep_t *ep, + nccl_net_ofi_rdma_mr_handle_t **mhandle) { assert(NCCL_OFI_IS_PTR_ALIGNED(data, system_page_size)); assert(NCCL_OFI_IS_ALIGNED(size, system_page_size)); - /* TODO: When the endpoint mr feature is supported for RDMA plugin - * pass the endpoint during the mr key create below. For now, we are - * passing nullptr - */ - const nccl_ofi_mr_ckey_t ckey = nccl_ofi_mr_ckey_mk_dmabuf(fd, offset, size, data, nullptr); - return this->reg_mr(&ckey, type, mhandle); + const nccl_ofi_mr_ckey_t ckey = nccl_ofi_mr_ckey_mk_dmabuf(fd, offset, size, data, ep); + return this->reg_mr(&ckey, type, ep, mhandle); } #endif @@ -2808,11 +2879,12 @@ int nccl_net_ofi_rdma_send_comm::regMr(nccl_ofi_mr_ckey_ref ckey, return domain->reg_mr(ckey, type_param, + endpoint_mr ? endpoint : nullptr, (nccl_net_ofi_rdma_mr_handle_t **)mhandle); } int nccl_net_ofi_rdma_recv_comm::regMr(nccl_ofi_mr_ckey_ref ckey, - int type_param, void **mhandle) + int type_param, void **mhandle) { nccl_net_ofi_rdma_ep_t *endpoint = (nccl_net_ofi_rdma_ep_t *)this->ep.get(); nccl_net_ofi_rdma_domain_t *domain = endpoint->rdma_endpoint_get_domain(); @@ -2822,6 +2894,7 @@ int nccl_net_ofi_rdma_recv_comm::regMr(nccl_ofi_mr_ckey_ref ckey, return domain->reg_mr(ckey, type_param, + endpoint_mr ? endpoint : nullptr, (nccl_net_ofi_rdma_mr_handle_t **)mhandle); } @@ -2831,58 +2904,80 @@ typedef struct { } freelist_regmr_fn_handle_t; /** - * Register host memory for use with the given communicator - * - * This interface is suitable for use with freelist_init_mr. - * - * @param data - * Pointer to memory region. Must be aligned to page size. 
- * @param size - * Size of memory region. Must be a multiple of page size. + * C++ class encapsulating the MR registration and deregistration callbacks + * for freelists. An instance is heap-allocated per communicator or endpoint + * and passed as the opaque context pointer to the freelist. The static + * shims regmr_fn / deregmr_fn cast back to this type, eliminating bare + * void * opaque usage at call sites. */ -static int freelist_regmr_host_fn(void *domain_void_ptr, void *data, size_t size, void **handle) -{ - nccl_net_ofi_rdma_domain_t *domain = (nccl_net_ofi_rdma_domain_t *)domain_void_ptr; +class freelist_regmr_ep_ctx_t { +public: + freelist_regmr_ep_ctx_t(nccl_net_ofi_rdma_domain_t *domain_arg, + nccl_net_ofi_rdma_ep_t *ep_arg) + : domain(domain_arg), ep(ep_arg) {} - nccl_net_ofi_rdma_mr_handle_t *mr_handle; + /** + * Register host memory via the domain MR path. + * + * @param data Pointer to page-aligned memory region. + * @param size Size of memory region (page-multiple). + * @param handle Receives a heap-allocated freelist_regmr_fn_handle_t *. + */ + int regmr(void *data, size_t size, void **handle) + { + nccl_net_ofi_rdma_mr_handle_t *mr_handle; + + freelist_regmr_fn_handle_t *freelist_handle = + (freelist_regmr_fn_handle_t *)malloc(sizeof(freelist_regmr_fn_handle_t)); + if (!freelist_handle) { + NCCL_OFI_WARN("Failed to allocate memory for freelist handle"); + return -ENOMEM; + } - freelist_regmr_fn_handle_t *freelist_handle = - (freelist_regmr_fn_handle_t *)malloc(sizeof(freelist_regmr_fn_handle_t)); - if (!freelist_handle) { - NCCL_OFI_WARN("Failed to allocate memory for freelist handle"); - return -ENOMEM; + int ret = domain->reg_internal_mr(data, size, NCCL_PTR_HOST, ep, &mr_handle); + if (ret != 0) { + NCCL_OFI_WARN("Failed call to reg_mr: %d", ret); + free(freelist_handle); + return -EIO; + } + + freelist_handle->mr_handle = mr_handle; + freelist_handle->domain = domain; + *handle = static_cast(freelist_handle); + return 0; } - int ret = domain->reg_internal_mr(data, size, NCCL_PTR_HOST, &mr_handle); - if (ret != 0) { - NCCL_OFI_WARN("Failed call to reg_mr: %d", ret); + /** Deregister memory that was registered via regmr(). */ + static int deregmr(void *handle) + { + freelist_regmr_fn_handle_t *freelist_handle = + static_cast(handle); + assert(freelist_handle); + int ret = freelist_handle->domain->dereg_mr(freelist_handle->mr_handle); + if (OFI_UNLIKELY(ret != 0)) { + NCCL_OFI_WARN("Failed call to dereg_mr"); + return -EIO; + } free(freelist_handle); - return -EIO; + return 0; } - freelist_handle->mr_handle = mr_handle; - freelist_handle->domain = domain; - *handle = (void *)freelist_handle; - return 0; -} + /** C-style shim matching nccl_ofi_freelist_regmr_fn; pass to freelist constructors. */ + static int regmr_fn(void *opaque, void *data, size_t size, void **handle) + { + return static_cast(opaque)->regmr(data, size, handle); + } -/** - * Deregister host memory registered with freelist_regmr_host_fn - * - * This interface is suitable for use with a freelist. - */ -static int freelist_deregmr_host_fn(void *handle) -{ - freelist_regmr_fn_handle_t *freelist_handle = (freelist_regmr_fn_handle_t *)handle; - assert(freelist_handle); - int ret = freelist_handle->domain->dereg_mr(freelist_handle->mr_handle); - if (OFI_UNLIKELY(ret != 0)) { - NCCL_OFI_WARN("Failed call to dereg_mr"); - return -EIO; + /** C-style shim matching nccl_ofi_freelist_deregmr_fn; pass to freelist constructors. 
*/ + static int deregmr_fn(void *handle) + { + return deregmr(handle); } - free(freelist_handle); - return 0; -} + +private: + nccl_net_ofi_rdma_domain_t *domain; + nccl_net_ofi_rdma_ep_t *ep; /* nullptr when no ep binding needed */ +}; int nccl_net_ofi_rdma_recv_comm::deregMr(nccl_net_ofi_mr_handle_t *mhandle) { @@ -3024,7 +3119,7 @@ int nccl_net_ofi_rdma_recv_comm::allocate_recv_req( entry->recv_idx = i; entry->msg_seq_num = req->msg_seq_num & MSG_SEQ_NUM_MASK; for (uint16_t rail_id = 0; rail_id < this->num_rails; rail_id++) { - uint64_t rkey = fi_mr_key(mr_h->mr[rail_id].get()); + uint64_t rkey = fi_mr_key(mr_h->mr_data[rail_id].get()); if (rkey == FI_KEY_NOTAVAIL) { NCCL_OFI_WARN("RDMA write buffers should be pre-registered"); return -ENOENT; @@ -3111,7 +3206,6 @@ int nccl_net_ofi_rdma_recv_comm::recv(int n, void **buffers, nccl_net_ofi_rdma_req *req = NULL; rdma_req_recv_data_t *recv_data = NULL; nccl_net_ofi_rdma_ep_t *endpoint = NULL; - nccl_net_ofi_rdma_domain_t *domain = NULL; nccl_net_ofi_rdma_device_t *device = NULL; int device_id = 0; nccl_net_ofi_rdma_mr_handle_t **mr_handles = (nccl_net_ofi_rdma_mr_handle_t **)mhandles; @@ -3146,9 +3240,6 @@ int nccl_net_ofi_rdma_recv_comm::recv(int n, void **buffers, endpoint = (nccl_net_ofi_rdma_ep_t *)this->ep.get(); assert(endpoint != NULL); - domain = endpoint->rdma_endpoint_get_domain(); - assert(domain != NULL); - device = endpoint->rdma_endpoint_get_device(); assert(device != NULL); @@ -3224,8 +3315,8 @@ int nccl_net_ofi_rdma_recv_comm::recv(int n, void **buffers, */ for (i = 0 ; i < n ; i++) { if (sizes[i] == 0) { - buffers[i] = domain->flush_buff.buffer; - mr_handles[i] = domain->flush_buff.mr_handle; + buffers[i] = endpoint->flush_buff.buffer; + mr_handles[i] = endpoint->flush_buff_mr_handle; } } @@ -3309,163 +3400,143 @@ int nccl_net_ofi_rdma_recv_comm::recv(int n, void **buffers, return ret; } -int nccl_net_ofi_rdma_domain_t::dealloc_and_dereg_flush_buff() -{ - int ret = 0; - nccl_net_ofi_rdma_mr_handle_t *mr_handle = this->flush_buff.mr_handle; - if (mr_handle) { - ret = this->dereg_mr(mr_handle); - } - if (ret != 0) { - NCCL_OFI_WARN("Failed to deregister flush buffer"); - goto exit; - } - - /* - * Clean up the flush buffer only if it was mapped correctly - */ - if (this->flush_buff.buffer != MAP_FAILED) { -#if HAVE_GPU - ret = nccl_net_ofi_gpu_mem_free(this->flush_buff.buffer_base); -#endif -#if HAVE_NEURON - ret = nccl_net_ofi_dealloc_mr_buffer(this->flush_buff.buffer, - system_page_size); -#endif - if (OFI_UNLIKELY(ret != 0)) { - NCCL_OFI_WARN("Unable to deallocate flush buffer (%d)", ret); - goto exit; - } - this->flush_buff.buffer = MAP_FAILED; - } - - exit: - return ret; -} - /* - * @brief Allocated and registers GPU buffer to flush RDMA operations. On - * Success, receive domain holds reference to flush buffer - * and associated memory handle. - * - * @param dev_id - * Device ID + * @brief Allocate and register a buffer used to flush RDMA operations. * - * @return 0, on success - * error, on others + * See header for full documentation. */ -int nccl_net_ofi_rdma_domain_t::alloc_and_reg_flush_buff(int dev_id) +int nccl_net_ofi_rdma_ep_t::alloc_and_reg_flush_buff(int dev_id, + nccl_net_ofi_rdma_flush_buffer_t &fb, + nccl_net_ofi_rdma_mr_handle_t *&mr_out) { int ret = 0; nccl_net_ofi_rdma_mr_handle_t *mr_handle = NULL; + nccl_net_ofi_rdma_domain_t *rdma_domain = this->rdma_endpoint_get_domain(); + +#if HAVE_NEURON || HAVE_GPU + nccl_net_ofi_rdma_ep_t *bind_ep = endpoint_mr ? 
this : nullptr; +#endif #if HAVE_NEURON int rc; NCCL_OFI_TRACE(NCCL_NET, "Registering buffer for flush operations"); - this->flush_buff.size = NCCL_OFI_FLUSH_SIZE; + fb.size = NCCL_OFI_FLUSH_SIZE; assert(NCCL_OFI_FLUSH_SIZE <= system_page_size); - ret = nccl_net_ofi_alloc_mr_buffer(system_page_size, &(this->flush_buff.buffer)); + ret = nccl_net_ofi_alloc_mr_buffer(system_page_size, &(fb.buffer)); if (OFI_UNLIKELY(ret != 0)) { NCCL_OFI_WARN("Unable to allocate flush buffer (%d)", ret); return ret; } /* make sure flush destination address does not overflow beyond host buffer */ - assert(((NCCL_OFI_DEFAULT_CPU_CACHE_LINE_SIZE * this->num_rails) + this->flush_buff.size) <= system_page_size); + assert(((NCCL_OFI_DEFAULT_CPU_CACHE_LINE_SIZE * this->num_rails) + fb.size) <= system_page_size); - ret = this->reg_internal_mr(this->flush_buff.buffer, system_page_size, - NCCL_PTR_HOST, &mr_handle); + ret = rdma_domain->reg_internal_mr(fb.buffer, system_page_size, + NCCL_PTR_HOST, bind_ep, &mr_handle); if (OFI_UNLIKELY(ret != 0)) { NCCL_OFI_WARN("Could not register dummy buffer for flush, dev: %d", dev_id); - rc = nccl_net_ofi_dealloc_mr_buffer(this->flush_buff.buffer, + rc = nccl_net_ofi_dealloc_mr_buffer(fb.buffer, system_page_size); if (rc != 0) { NCCL_OFI_WARN("Unable to deallocate flush buffer (%d)", rc); } - this->flush_buff.buffer = MAP_FAILED; + fb.buffer = MAP_FAILED; } #endif #if HAVE_GPU - int rc; - NCCL_OFI_TRACE(NCCL_NET, "Registering buffer in GPU for flush operations"); - - /* - * We allocate twice the system page size since GPU memory allocation - * does not guarantee that the allocated memory will be system page aligned. - * Post allocation, we calculate the page aligned ptr and perform - * memory registrations on it. - */ - this->flush_buff.size = 2 * system_page_size; - ret = nccl_net_ofi_gpu_mem_alloc(&(this->flush_buff.buffer_base), this->flush_buff.size); - if (OFI_UNLIKELY(ret != 0)) { - NCCL_OFI_WARN("Unable to allocate flush buffer (%d)", ret); - return ret; - } + { + int rc; + NCCL_OFI_TRACE(NCCL_NET, "Registering buffer in GPU for flush operations"); - /* - * Calculate the ptr aligned to system page size - */ - this->flush_buff.buffer = (void *)NCCL_OFI_ROUND_UP((uintptr_t)this->flush_buff.buffer_base, system_page_size); - - /* Copy flush sentinel value into aligned ptr of gpu buffer */ - ret = nccl_net_ofi_gpu_mem_copy_host_to_device(this->flush_buff.buffer, flush_sentinel, - NCCL_OFI_DEFAULT_CPU_CACHE_LINE_SIZE); - if (OFI_UNLIKELY(ret != 0)) { - NCCL_OFI_WARN("Unable to copy sentinel value to gpu flush buffer (%d)", ret); - return ret; - } - -#if HAVE_DECL_FI_MR_DMABUF - /* - * If dma buf is viable and supported then register flush dummy buffer - * using dma buf for provider access - */ - nccl_net_ofi_rdma_device_t *dev = this->rdma_domain_get_device(); - struct fi_info *nic_prov = dev->get_ofi_info_for_cm(); + /* + * We allocate twice the system page size since GPU memory + * allocation does not guarantee page alignment. Post + * allocation we calculate the page-aligned ptr and perform + * memory registration on it. 
+ */ + fb.size = 2 * system_page_size; + ret = nccl_net_ofi_gpu_mem_alloc(&(fb.buffer_base), fb.size); + if (OFI_UNLIKELY(ret != 0)) { + NCCL_OFI_WARN("Unable to allocate flush buffer (%d)", ret); + return ret; + } - if (nccl_ofi_dmabuf_viable_and_supported(nic_prov)) { - size_t offset = 0; - int fd; + /* Calculate the ptr aligned to system page size */ + fb.buffer = (void *)NCCL_OFI_ROUND_UP((uintptr_t)fb.buffer_base, + system_page_size); - /* - * Retrieve the fd and offset and the aligned ptr used for dma buf - */ - ret = nccl_net_ofi_gpu_get_dma_buf_fd(this->flush_buff.buffer, system_page_size, &fd, &offset); + /* Copy flush sentinel value into aligned ptr of gpu buffer */ + ret = nccl_net_ofi_gpu_mem_copy_host_to_device(fb.buffer, flush_sentinel, + NCCL_OFI_DEFAULT_CPU_CACHE_LINE_SIZE); if (OFI_UNLIKELY(ret != 0)) { - NCCL_OFI_WARN("Unable to retrieve flush buffer fd (%d)", ret); + NCCL_OFI_WARN("Unable to copy sentinel value to gpu flush buffer (%d)", ret); + rc = nccl_net_ofi_gpu_mem_free(fb.buffer_base); + if (rc != 0) { + NCCL_OFI_WARN("Unable to deallocate flush buffer (%d)", rc); + } + fb.buffer_base = nullptr; return ret; } - NCCL_OFI_TRACE(NCCL_NET, "Registering flush buffer using DMA BUF fd: %d offset: %ld", fd, offset); +#if HAVE_DECL_FI_MR_DMABUF + /* + * If dma buf is viable and supported then register the flush + * dummy buffer using dma buf for provider access. + */ + { + nccl_net_ofi_rdma_device_t *dev = rdma_domain->rdma_domain_get_device(); + struct fi_info *nic_prov = dev->get_ofi_info_for_cm(); + + if (nccl_ofi_dmabuf_viable_and_supported(nic_prov)) { + size_t offset = 0; + int fd; + + ret = nccl_net_ofi_gpu_get_dma_buf_fd(fb.buffer, + system_page_size, + &fd, &offset); + if (OFI_UNLIKELY(ret != 0)) { + NCCL_OFI_WARN("Unable to retrieve flush buffer fd (%d)", ret); + rc = nccl_net_ofi_gpu_mem_free(fb.buffer_base); + if (rc != 0) { + NCCL_OFI_WARN("Unable to deallocate flush buffer (%d)", rc); + } + fb.buffer_base = nullptr; + return ret; + } - ret = this->reg_internal_mr_dma_buf(this->flush_buff.buffer, fd, offset, system_page_size, - NCCL_PTR_CUDA, &mr_handle); - close(fd); - } else { - ret = this->reg_internal_mr(this->flush_buff.buffer, system_page_size, NCCL_PTR_CUDA, &mr_handle); - } + NCCL_OFI_TRACE(NCCL_NET, "Registering flush buffer using DMA BUF fd: %d offset: %ld", fd, offset); + ret = rdma_domain->reg_internal_mr_dma_buf(fb.buffer, fd, offset, + system_page_size, + NCCL_PTR_CUDA, + bind_ep, &mr_handle); + close(fd); + } else { + ret = rdma_domain->reg_internal_mr(fb.buffer, system_page_size, + NCCL_PTR_CUDA, bind_ep, &mr_handle); + } + } #else - ret = this->reg_internal_mr(this->flush_buff.buffer, system_page_size, NCCL_PTR_CUDA, &mr_handle); + ret = rdma_domain->reg_internal_mr(fb.buffer, system_page_size, + NCCL_PTR_CUDA, bind_ep, &mr_handle); #endif - if (OFI_UNLIKELY(ret != 0)) { - NCCL_OFI_WARN("Could not register dummy buffer for flush, dev: %d", - dev_id); - - rc = nccl_net_ofi_gpu_mem_free(&this->flush_buff.buffer_base); - if (rc != 0) { - NCCL_OFI_WARN("Unable to deallocate flush buffer (%d)", - rc); + if (OFI_UNLIKELY(ret != 0)) { + NCCL_OFI_WARN("Could not register dummy buffer for flush, dev: %d", dev_id); + rc = nccl_net_ofi_gpu_mem_free(fb.buffer_base); + if (rc != 0) { + NCCL_OFI_WARN("Unable to deallocate flush buffer (%d)", rc); + } + fb.buffer_base = nullptr; + fb.buffer = MAP_FAILED; } - this->flush_buff.buffer = MAP_FAILED; } #endif - this->flush_buff.mr_handle = mr_handle; + mr_out = mr_handle; return ret; } @@ -3512,6 +3583,8 @@ 
static int recv_comm_destroy(nccl_net_ofi_rdma_recv_comm *r_comm) delete r_comm->ctrl_buff_fl; delete r_comm->flush_buff_fl; + delete r_comm->comm_buff_regmr_ctx; + r_comm->comm_buff_regmr_ctx = nullptr; delete r_comm->nccl_ofi_reqs_fl; delete r_comm->msgbuff; @@ -4171,7 +4244,7 @@ static int alloc_rdma_read_req(nccl_net_ofi_rdma_recv_comm *r_comm, nccl_net_ofi_rdma_req **ret_req) { uint64_t flags = 0; - struct fid_mr *rail_mr_handle = buff_mr_handle->mr[0].get(); + struct fid_mr *rail_mr_handle = buff_mr_handle->mr_data[0].get(); void *desc = fi_mr_desc(rail_mr_handle); *ret_req = NULL; @@ -4336,6 +4409,7 @@ static nccl_net_ofi_rdma_recv_comm *prepare_recv_comm(nccl_net_ofi_rdma_domain_t size_t comm_id = 0; nccl_net_ofi_rdma_recv_comm *r_comm = NULL; nccl_net_ofi_rdma_ep_t *ep = NULL; + freelist_regmr_ep_ctx_t *comm_ctx = nullptr; nccl_net_ofi_rdma_device_t *device = domain->rdma_domain_get_device(); int dev_id = device->dev_id; int num_rails = l_comm_ep->num_rails; @@ -4391,7 +4465,7 @@ static nccl_net_ofi_rdma_recv_comm *prepare_recv_comm(nccl_net_ofi_rdma_domain_t ret = domain->reg_internal_mr(r_comm->ctrl_mailbox, NCCL_OFI_ROUND_UP(sizeof(nccl_net_ofi_ctrl_msg_t) * NCCL_OFI_CTRL_MAILBOX_SIZE, system_page_size), - NCCL_PTR_HOST, &mr_handle); + NCCL_PTR_HOST, endpoint_mr ? l_comm_ep : nullptr, &mr_handle); if (ret != 0) { NCCL_OFI_WARN("Failed to register memory for the control mailbox: %d", ret); goto error; @@ -4489,18 +4563,26 @@ static nccl_net_ofi_rdma_recv_comm *prepare_recv_comm(nccl_net_ofi_rdma_domain_t return NULL; } + /* Allocate a freelist registration context for ctrl_buff_fl and flush_buff_fl. + * In FI_MR_ENDPOINT mode the ep pointer is set so that each registered MR is + * bound and enabled against this endpoint. In non-endpoint-MR mode ep is nullptr + * so the bind/enable step is skipped. + * Must outlive both freelists; freed in recv_comm_destroy(). */ + comm_ctx = new freelist_regmr_ep_ctx_t{domain, endpoint_mr ? 
ep : nullptr}; + r_comm->comm_buff_regmr_ctx = comm_ctx; + r_comm->ctrl_buff_fl = new nccl_ofi_freelist(sizeof(nccl_net_ofi_rdma_close_msg_t), 8, 8, NCCL_OFI_MAX_REQUESTS, NULL, NULL, - freelist_regmr_host_fn, - freelist_deregmr_host_fn, domain, 1, + freelist_regmr_ep_ctx_t::regmr_fn, + freelist_regmr_ep_ctx_t::deregmr_fn, comm_ctx, 1, "Ctrl Buffer", true); /* Allocate flush buffer freelist */ r_comm->flush_buff_fl = new nccl_ofi_freelist(NCCL_OFI_DEFAULT_CPU_CACHE_LINE_SIZE * MAX_NUM_RAILS, 8, 8, NCCL_OFI_MAX_REQUESTS, NULL, NULL, - freelist_regmr_host_fn, - freelist_deregmr_host_fn, domain, + freelist_regmr_ep_ctx_t::regmr_fn, + freelist_regmr_ep_ctx_t::deregmr_fn, comm_ctx, NCCL_OFI_DEFAULT_CPU_CACHE_LINE_SIZE, "Flush Buffer", true); @@ -5033,7 +5115,7 @@ static int post_rdma_write(nccl_net_ofi_rdma_req *req, rdma_req_send_data_t *send_data = get_send_data(req); assert(xfer_info->rail_id < send_data->buff_mr_handle->num_rails); uint16_t rail_id = xfer_info->rail_id; - struct fid_mr *rail_mr_handle = send_data->buff_mr_handle->mr[rail_id].get(); + struct fid_mr *rail_mr_handle = send_data->buff_mr_handle->mr_data[rail_id].get(); void *desc = fi_mr_desc(rail_mr_handle); ssize_t rc; @@ -5070,7 +5152,7 @@ static int post_rdma_eager_send(nccl_net_ofi_rdma_req *req, rdma_req_send_data_t *send_data = get_send_data(req); assert(xfer_info->rail_id < send_data->buff_mr_handle->num_rails); uint16_t rail_id = xfer_info->rail_id; - struct fid_mr *rail_mr_handle = send_data->buff_mr_handle->mr[rail_id].get(); + struct fid_mr *rail_mr_handle = send_data->buff_mr_handle->mr_data[rail_id].get(); void *desc = fi_mr_desc(rail_mr_handle); ssize_t rc; @@ -5095,7 +5177,12 @@ static int post_rx_buffer(nccl_net_ofi_rdma_req *req, nccl_ofi_freelist::fl_entry *rx_buff_fl_elem = rx_buff_data->rx_buff_fl_elem; freelist_regmr_fn_handle_t *fl_mr_handle = (freelist_regmr_fn_handle_t *)rx_buff_fl_elem->mr_handle; - void *desc = fi_mr_desc(fl_mr_handle->mr_handle->mr[rx_buff_data->rail->rail_id].get()); + nccl_net_ofi_rdma_mr_handle_t *mr_h = fl_mr_handle->mr_handle; + uint16_t rid = rx_buff_data->rail->rail_id; + struct fid_mr *raw_mr = ep_rail->is_ctrl + ? 
mr_h->mr_ctrl[rid] + : mr_h->mr_data[rid].get(); + void *desc = fi_mr_desc(raw_mr); struct iovec iov; struct fi_msg msg; uint64_t flags = 0; @@ -5223,8 +5310,8 @@ static ssize_t send_ctrl_post(nccl_net_ofi_rdma_recv_comm *r_comm, nccl_net_ofi_rdma_recv_comm_rail_t *comm_rail = r_comm->get_control_rail(rail_id); - assert(rail_id < mr_handle->num_rails); - void *desc = fi_mr_desc(mr_handle->mr[rail_id].get()); + assert(rail_id < mr_handle->num_ctrl_rails); + void *desc = fi_mr_desc(mr_handle->mr_ctrl[rail_id]); ssize_t rc = fi_send(comm_rail->local_ep, ctrl_fl_elem->ptr, size, @@ -5271,7 +5358,7 @@ static int post_rdma_ctrl(nccl_net_ofi_rdma_req *req) rail_id = 0; } - void *desc = fi_mr_desc(r_comm->ctrl_mr_handle->mr[rail_id].get()); + void *desc = fi_mr_desc(r_comm->ctrl_mr_handle->mr_ctrl[rail_id]); nccl_net_ofi_rdma_recv_comm_rail_t *comm_rail = r_comm->get_control_rail(rail_id); ssize_t rc = fi_write(comm_rail->local_ep, &r_comm->ctrl_mailbox[slot], @@ -5342,10 +5429,10 @@ static int post_eager_copy(nccl_net_ofi_rdma_req *req) nccl_net_ofi_rdma_mr_handle_t *dest_mr_handle = recv_data->recvs[0].dest_mr_handle; assert(rx_rail_id < dest_mr_handle->num_rails); - void *desc = fi_mr_desc(dest_mr_handle->mr[rx_rail_id].get()); + void *desc = fi_mr_desc(dest_mr_handle->mr_data[rx_rail_id].get()); void *rx_buff = rx_buff_data->rx_buff_fl_elem->ptr; - uint64_t rx_key = fi_mr_key(rx_mr_handle->mr[rx_rail_id].get()); + uint64_t rx_key = fi_mr_key(rx_mr_handle->mr_data[rx_rail_id].get()); if (rx_key == FI_KEY_NOTAVAIL) { NCCL_OFI_WARN("Failed to get rx_key"); return -EIO; @@ -5370,7 +5457,6 @@ static int post_flush_req(nccl_net_ofi_rdma_req *req) { nccl_net_ofi_rdma_recv_comm *r_comm = (nccl_net_ofi_rdma_recv_comm *)req->comm; nccl_net_ofi_rdma_ep_t *ep = (nccl_net_ofi_rdma_ep_t *)r_comm->ep.get(); - nccl_net_ofi_rdma_domain_t *domain = ep->rdma_endpoint_get_domain(); rdma_req_flush_data_t *flush_data = get_flush_data(req); nccl_net_ofi_rdma_recv_comm_rail_t *comm_rail; ssize_t rc = 0; @@ -5380,8 +5466,8 @@ static int post_flush_req(nccl_net_ofi_rdma_req *req) comm_rail = r_comm->get_data_rail(rail_id); struct fid_mr *mr_handle = NULL; - void *desc = fi_mr_desc(domain->flush_buff.mr_handle->mr[rail_id].get()); - mr_handle = flush_data->mr_handle->mr[rail_id].get(); + void *desc = fi_mr_desc(ep->flush_buff_mr_handle->mr_data[rail_id].get()); + mr_handle = flush_data->mr_handle->mr_data[rail_id].get(); uint64_t cuda_key = 0ULL; @@ -5396,7 +5482,7 @@ static int post_flush_req(nccl_net_ofi_rdma_req *req) } } - nccl_net_ofi_rdma_flush_buffer_t *f_buff = &domain->flush_buff; + nccl_net_ofi_rdma_flush_buffer_t *f_buff = &ep->flush_buff; uintptr_t host_buff_addr = (uintptr_t)f_buff->buffer + (NCCL_OFI_DEFAULT_CPU_CACHE_LINE_SIZE * rail_id); uintptr_t buff_offset = (uintptr_t)flush_data->data - flush_data->mr_handle->base_addr; rc = fi_read(comm_rail->local_ep, @@ -5420,7 +5506,6 @@ static int post_flush_req(nccl_net_ofi_rdma_req *req) { nccl_net_ofi_rdma_recv_comm *r_comm = (nccl_net_ofi_rdma_recv_comm *)req->comm; nccl_net_ofi_rdma_ep_t *ep = (nccl_net_ofi_rdma_ep_t *)r_comm->ep.get(); - nccl_net_ofi_rdma_domain_t *domain = ep->rdma_endpoint_get_domain(); rdma_req_flush_data_t *flush_data = get_flush_data(req); nccl_net_ofi_rdma_recv_comm_rail_t *comm_rail; ssize_t rc = 0; @@ -5432,8 +5517,8 @@ static int post_flush_req(nccl_net_ofi_rdma_req *req) freelist_regmr_fn_handle_t *fl_handle = (freelist_regmr_fn_handle_t *)flush_data->flush_fl_elem->mr_handle; - void *desc = 
fi_mr_desc(fl_handle->mr_handle->mr[rail_id].get()); - mr_handle = domain->flush_buff.mr_handle->mr[rail_id].get(); + void *desc = fi_mr_desc(fl_handle->mr_handle->mr_data[rail_id].get()); + mr_handle = ep->flush_buff_mr_handle->mr_data[rail_id].get(); uint64_t cuda_key = 0ULL; if (mr_handle != NULL) { @@ -5447,7 +5532,7 @@ static int post_flush_req(nccl_net_ofi_rdma_req *req) } uint64_t *host_buff_addr = get_flush_buffer_for_rail(flush_data->flush_fl_elem->ptr, rail_id); - uintptr_t buff_offset = (uintptr_t)domain->flush_buff.buffer - domain->flush_buff.mr_handle->base_addr; + uintptr_t buff_offset = (uintptr_t)ep->flush_buff.buffer - ep->flush_buff_mr_handle->base_addr; rc = fi_read(comm_rail->local_ep, (void *)host_buff_addr, NCCL_OFI_DEFAULT_CPU_CACHE_LINE_SIZE, desc, comm_rail->local_addr, @@ -5522,7 +5607,6 @@ int nccl_net_ofi_rdma_send_comm::send(void *data, size_t size, int tag, nccl_net_ofi_rdma_send_comm *s_comm = this; nccl_net_ofi_rdma_mr_handle_t *mr_handle = (nccl_net_ofi_rdma_mr_handle_t *)mhandle; nccl_net_ofi_rdma_ep_t *endpoint = NULL; - nccl_net_ofi_rdma_domain_t *domain = NULL; nccl_net_ofi_rdma_req *req = NULL; uint16_t msg_seq_num = s_comm->next_msg_seq_num; bool have_ctrl = false; @@ -5548,9 +5632,6 @@ int nccl_net_ofi_rdma_send_comm::send(void *data, size_t size, int tag, endpoint = (nccl_net_ofi_rdma_ep_t *)s_comm->ep.get(); assert(endpoint != NULL); - domain = endpoint->rdma_endpoint_get_domain(); - assert(domain != NULL); - std::lock_guard eplock(endpoint->ep_lock); CHECK_ENDPOINT_ACTIVE(endpoint, "send"); @@ -5579,8 +5660,8 @@ int nccl_net_ofi_rdma_send_comm::send(void *data, size_t size, int tag, * pointer and MR */ if (size == 0) { - data = domain->flush_buff.buffer; - mr_handle = domain->flush_buff.mr_handle; + data = endpoint->flush_buff.buffer; + mr_handle = endpoint->flush_buff_mr_handle; } in_group = (s_comm->group_sends_remaining > 0); @@ -5748,10 +5829,14 @@ void nccl_net_ofi_rdma_ep_t::prepare_send_connect_message(uint32_t local_comm_id /* Send s_comm's control mailbox offset to receiver */ conn_msg->ctrl_addr = (uintptr_t)ctrl_msg - ctrl_msg_mr_handle->base_addr; - /* Send s_comm's control mailbox mr_key */ - for (uint16_t rail_id = 0; rail_id != num_rails; ++rail_id) { - uint64_t rkey = fi_mr_key(ctrl_msg_mr_handle->mr[rail_id].get()); - conn_msg->ctrl_mr_key[rail_id] = rkey; + /* Send the MR key for each ctrl ep rail so the receiver can fi_write + * into our ctrl mailbox. The receiver will issue fi_write operations + * through our ctrl ep rails, so in FI_MR_ENDPOINT mode the rkey must + * come from an fid_mr bound to a ctrl ep (mr_ctrl[]). In + * non-endpoint-MR mode mr_ctrl[i] is the same object as mr_data[i] + * because a single domain-level MR is valid for all endpoints. */ + for (uint16_t rail_id = 0; rail_id != num_control_rails; ++rail_id) { + conn_msg->ctrl_mr_key[rail_id] = fi_mr_key(ctrl_msg_mr_handle->mr_ctrl[rail_id]); } /* Set number of rails to be sent back to remote for verification */ @@ -5834,20 +5919,30 @@ int nccl_net_ofi_rdma_ep_t::init_rx_buffers() "Rx Buffer Requests", enable_freelist_leak_detection); + /* Allocate a context for freelist MR registration callbacks. + * In FI_MR_ENDPOINT mode ep is set so newly registered MRs are + * bound and enabled against this endpoint; otherwise ep is nullptr. + * Must outlive the freelists; freed in fini_rx_buffers(). */ + freelist_regmr_ep_ctx_t *rx_ctx = new freelist_regmr_ep_ctx_t{domain_ptr, endpoint_mr ? 
this : nullptr}; + this->rx_buff_regmr_ctx = rx_ctx; + + /* Ctrl rx buffers are posted via fi_recvmsg on the ctrl ep rail; the unified + * MR handle binds to both data and ctrl ep rails so rx_ctx works for both. */ this->ctrl_rx_buff_fl = new nccl_ofi_freelist(this->ctrl_rx_buff_size, ofi_nccl_rdma_min_posted_control_buffers(), 16, 0, NULL, NULL, - freelist_regmr_host_fn, freelist_deregmr_host_fn, - domain_ptr, 1, + freelist_regmr_ep_ctx_t::regmr_fn, freelist_regmr_ep_ctx_t::deregmr_fn, + rx_ctx, 1, "Ctrl Rx Buffer", enable_freelist_leak_detection); if (this->eager_rx_buff_size > 0) { + /* Eager rx buffers are posted on data ep rails; use rx_ctx. */ this->eager_rx_buff_fl = new nccl_ofi_freelist(this->eager_rx_buff_size, ofi_nccl_rdma_min_posted_eager_buffers(), 16, 0, NULL, NULL, - freelist_regmr_host_fn, freelist_deregmr_host_fn, - domain_ptr, EAGER_RX_BUFFER_ALIGNMENT, + freelist_regmr_ep_ctx_t::regmr_fn, freelist_regmr_ep_ctx_t::deregmr_fn, + rx_ctx, EAGER_RX_BUFFER_ALIGNMENT, "Eager Rx Buffer", enable_freelist_leak_detection); } else { @@ -5871,6 +5966,7 @@ int nccl_net_ofi_rdma_ep_t::init_rx_buffers() rail->num_rx_buff_posted = 0; nccl_net_ofi_mutex_init(&rail->rx_buff_mutex, NULL); rail->rx_buff_req_alloc = ctrl_rx_buff_req_alloc; + rail->is_ctrl = true; } for (uint16_t rail_id = 0; rail_id < this->num_rails; ++rail_id) { @@ -5889,8 +5985,21 @@ int nccl_net_ofi_rdma_ep_t::init_rx_buffers() rail->num_rx_buff_posted = 0; nccl_net_ofi_mutex_init(&rail->rx_buff_mutex, NULL); rail->rx_buff_req_alloc = eager_rx_buff_req_alloc; + rail->is_ctrl = false; } +#if HAVE_GPU || HAVE_NEURON + { + nccl_net_ofi_rdma_device_t *dev = domain_ptr->rdma_domain_get_device(); + ret = this->alloc_and_reg_flush_buff(dev->dev_id, + this->flush_buff, + this->flush_buff_mr_handle); + if (OFI_UNLIKELY(ret != 0)) { + return ret; + } + } +#endif /* HAVE_GPU / HAVE_NEURON */ + return ret; } @@ -5908,6 +6017,30 @@ int nccl_net_ofi_rdma_ep_t::fini_rx_buffers() delete this->rx_buff_reqs_fl; + /* Deregister the flush buffer MR (always owned by this endpoint). */ + if (this->flush_buff_mr_handle != nullptr) { + nccl_net_ofi_rdma_domain_t *domain_ptr = this->rdma_endpoint_get_domain(); + domain_ptr->dereg_mr(this->flush_buff_mr_handle); + this->flush_buff_mr_handle = nullptr; + } + +#if HAVE_GPU + /* Free the per-endpoint flush buffer GPU allocation. */ + if (this->flush_buff.buffer_base != nullptr) { + nccl_net_ofi_gpu_mem_free(this->flush_buff.buffer_base); + this->flush_buff.buffer_base = nullptr; + this->flush_buff.buffer = MAP_FAILED; + } +#endif /* HAVE_GPU */ + +#if HAVE_NEURON + /* Free the per-endpoint host flush buffer. 
*/ + if (this->flush_buff.buffer != MAP_FAILED && this->flush_buff.buffer != nullptr) { + nccl_net_ofi_dealloc_mr_buffer(this->flush_buff.buffer, system_page_size); + this->flush_buff.buffer = MAP_FAILED; + } +#endif /* HAVE_NEURON */ + for (uint16_t rail_id = 0; rail_id < this->num_rails; ++rail_id) { rail = this->rdma_endpoint_get_rail(rail_id); nccl_net_ofi_mutex_destroy(&rail->rx_buff_mutex); @@ -5925,8 +6058,8 @@ int nccl_net_ofi_rdma_ep_t::fini_rx_buffers() int nccl_net_ofi_rdma_mr_handle_t::get_mr_key(uint64_t *mr_key_ptr) { int ret = 0; - assert(!this->mr.empty()); - uint64_t key = fi_mr_key(this->mr[0].get()); + assert(num_rails > 0); + uint64_t key = fi_mr_key(this->mr_data[0].get()); if (OFI_UNLIKELY(key == FI_KEY_NOTAVAIL)) { ret = -ENOENT; NCCL_OFI_WARN("Error retrieving MR key, leaking key"); @@ -6032,7 +6165,7 @@ int nccl_net_ofi_rdma_send_comm::write(void* src, size_t size, void* mhandle, uint64_t dest, uint64_t mr_key, nccl_net_ofi_req ** base_req) { nccl_net_ofi_rdma_mr_handle_t *mr_handle = (nccl_net_ofi_rdma_mr_handle_t *)mhandle; - struct fid_mr *rail_mr_handle = mr_handle->mr[0].get(); + struct fid_mr *rail_mr_handle = mr_handle->mr_data[0].get(); void *desc = fi_mr_desc(rail_mr_handle); uint64_t flags = 0; return rma_write_impl(this, src, size, desc, dest, mr_key, flags, base_req); @@ -6118,7 +6251,7 @@ int nccl_net_ofi_rdma_ep_t::create_send_comm(nccl_net_ofi_rdma_send_comm **s_com /* Allocate control mailbox */ ret = domain_ptr->reg_internal_mr(ret_s_comm->ctrl_mailbox, NCCL_OFI_ROUND_UP(sizeof(nccl_net_ofi_ctrl_msg_t) * NCCL_OFI_CTRL_MAILBOX_SIZE, system_page_size), - NCCL_PTR_HOST, &ret_s_comm->ctrl_mr_handle); + NCCL_PTR_HOST, endpoint_mr ? this : nullptr, &ret_s_comm->ctrl_mr_handle); if (ret != 0) { NCCL_OFI_WARN("Could not register memory for control mailbox for dev %d", dev_id); ret = -ENOMEM; @@ -6432,12 +6565,29 @@ nccl_net_ofi_rdma_ep_t::~nccl_net_ofi_rdma_ep_t() } /* Ideally we would "un-post" the rx buffers, but this - * should be accomplished by closing the endpoint. */ - this->release_rdma_ep_resources(device->dev_id); + * should be accomplished by closing the endpoint. + * + * In FI_MR_ENDPOINT mode the libfabric spec requires that MRs bound to + * an endpoint are closed before the endpoint itself. fini_rx_buffers() + * deregisters per-endpoint MRs, so it must run before + * release_rdma_ep_resources() in that mode. For non-endpoint-MR + * providers the original order is preserved. 
+ */ + int err_code; + if (endpoint_mr) { + err_code = this->fini_rx_buffers(); + if (err_code != 0) { + NCCL_OFI_WARN("rdma endpoint destructor: tearing down freelists failed, rc %d", err_code); + } - int err_code = this->fini_rx_buffers(); - if (err_code != 0) { - NCCL_OFI_WARN("rdma endpoint destructor: tearing down freelists failed, rc %d", err_code); + this->release_rdma_ep_resources(device->dev_id); + } else { + this->release_rdma_ep_resources(device->dev_id); + + err_code = this->fini_rx_buffers(); + if (err_code != 0) { + NCCL_OFI_WARN("rdma endpoint destructor: tearing down freelists failed, rc %d", err_code); + } } err_code = nccl_net_ofi_mutex_destroy(&this->pending_reqs_lock); @@ -6520,12 +6670,8 @@ nccl_net_ofi_rdma_ep_t::nccl_net_ofi_rdma_ep_t(std::shared_ptrnum_control_rails = 1; } - /* Zero-initialize the rails and control_rails vector elements */ - this->rails.resize(this->num_rails); - - this->control_rails.resize(this->num_control_rails); - - this->cq_rails.resize(this->num_rails); + /* Zero-initialize the rails and control_rails array elements + * via the = {} initializers on the member declarations. */ ret = nccl_net_ofi_mutex_init(&this->pending_reqs_lock, NULL); if (ret != 0) { @@ -6564,11 +6710,6 @@ nccl_net_ofi_rdma_ep_t::nccl_net_ofi_rdma_ep_t(std::shared_ptrdealloc_and_dereg_flush_buff(); - if (err_code != 0) { - NCCL_OFI_WARN("Failed to deregister flush buffer pool"); - } - /* Check for leaked endpoints. With weak_ptr entries, expired * entries are harmless (stale cache). But a live weak_ptr * means a comm still holds a shared_ptr to an ep, which @@ -6588,7 +6729,6 @@ nccl_net_ofi_rdma_domain_t::nccl_net_ofi_rdma_domain_t(nccl_net_ofi_rdma_device_ unsigned int domain_key_arg) : nccl_net_ofi_domain_t(device_arg, domain_key_arg) { - int ret = 0; if (OFI_UNLIKELY(device_arg == nullptr)) { NCCL_OFI_WARN("Invalid device provided"); throw std::runtime_error("RDMA domain constructor: invalid device provided"); @@ -6610,14 +6750,6 @@ nccl_net_ofi_rdma_domain_t::nccl_net_ofi_rdma_domain_t(nccl_net_ofi_rdma_device_ } domain_rail->domain = std::move(domain_result.resource); } - - /* - * Setup flush resources. - */ - ret = this->alloc_and_reg_flush_buff(device_arg->dev_id); - if (OFI_UNLIKELY(ret != 0)) { - throw std::runtime_error("RDMA domain constructor: flush buffer alloc/reg failed"); - } } nccl_net_ofi_domain_t *nccl_net_ofi_rdma_device_t::create_domain(unsigned int domain_key) @@ -6856,7 +6988,7 @@ static void get_hints(struct fi_info *hints) hints->ep_attr->type = FI_EP_RDM; hints->domain_attr->mr_mode = FI_MR_LOCAL | FI_MR_HMEM | FI_MR_VIRT_ADDR | - FI_MR_ALLOCATED | FI_MR_PROV_KEY; + FI_MR_ALLOCATED | FI_MR_PROV_KEY | FI_MR_ENDPOINT; hints->domain_attr->mr_key_size = (size_t) ofi_nccl_mr_key_size(); hints->domain_attr->threading = FI_THREAD_COMPLETION; @@ -7043,11 +7175,6 @@ int nccl_net_ofi_rdma_init(const char *provider_filter, return ret; } - if (endpoint_mr) { - NCCL_OFI_WARN("RDMA protocol does not support endpoint memory registration."); - return -ENOTSUP; - } - if ((ssize_t)ofi_nccl_eager_max_size() > (ssize_t)ofi_nccl_min_stripe_size()) { NCCL_OFI_WARN("Invalid value for EAGER_MAX_SIZE"); return -ENOTSUP;
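
The destructor reordering in the endpoint destructor hunk above follows from the libfabric rule that an MR bound to an endpoint (FI_MR_ENDPOINT) must be closed before that endpoint. A two-call sketch of the required ordering, with placeholder names and error handling omitted:

#include <rdma/fabric.h>
#include <rdma/fi_domain.h>
#include <rdma/fi_endpoint.h>

/* Sketch only: under FI_MR_ENDPOINT the bound MR must be closed first. */
static void close_ep_bound_mr_then_ep(struct fid_mr *mr, struct fid_ep *ep)
{
	fi_close(&mr->fid);   /* deregister the MR bound to this endpoint */
	fi_close(&ep->fid);   /* only then close the endpoint itself */
}

In the patch this ordering is what fini_rx_buffers() before release_rdma_ep_resources() provides when endpoint_mr is set.
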
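
freelist_regmr_ep_ctx_t also illustrates a small, reusable pattern: a typed context object is handed through a C-style callback interface as void *, and static shims recover the type before forwarding to member functions. A self-contained sketch of that pattern under assumed names (fl_regmr_ctx and its members are illustrative, not the plugin's freelist API):

#include <cstddef>
#include <iostream>

/* Sketch of the opaque-context plus static-shim pattern (assumed names). */
struct fl_regmr_ctx {
	int bound_value;   /* stands in for the domain/ep pointers the real ctx carries */

	int regmr(void *data, size_t size, void **handle)
	{
		std::cout << "register " << size << " bytes (ctx=" << bound_value << ")\n";
		*handle = data;   /* real code would return an MR handle here */
		return 0;
	}

	/* C-style shim matching the callback signature the consumer expects. */
	static int regmr_fn(void *opaque, void *data, size_t size, void **handle)
	{
		return static_cast<fl_regmr_ctx *>(opaque)->regmr(data, size, handle);
	}
};

int main()
{
	fl_regmr_ctx ctx{42};
	char buf[64];
	void *h = nullptr;
	/* The consumer only ever sees the function pointer and the void * context. */
	int (*cb)(void *, void *, size_t, void **) = &fl_regmr_ctx::regmr_fn;
	return cb(&ctx, buf, sizeof(buf), &h);
}

The plugin's version carries the domain and (optionally) the endpoint instead of an int, so the same shims serve both the FI_MR_ENDPOINT and the domain-scoped registration modes.
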