147 changes: 110 additions & 37 deletions include/nccl_ofi_rdma.h
@@ -122,6 +122,9 @@ static_assert(NCCL_OFI_RDMA_MSG_MAX <= (0x10),
* size to be fixed.
*/
typedef uint16_t nccl_ofi_rdma_msg_type_t;
/* Forward declaration for freelist MR registration callback class */
class freelist_regmr_ep_ctx_t;

/*
* @brief Rdma memory registration handle
*
@@ -132,25 +135,51 @@ class nccl_net_ofi_rdma_mr_handle_t : public nccl_net_ofi_mr_handle_t {
public:
/**
* @brief Constructor
*
* Every MR handle covers both the data ep rails and the control ep
* rails. With FI_MR_ENDPOINT a single fid_mr may only be bound to
* one fid_ep, so we keep separate arrays for ownership and access:
*
* mr_data[0..num_rails-1] — owned fid_mr objects, bound to data ep rails
* mr_ctrl_owned[0..num_ctrl_rails-1] — owned fid_mr objects, bound to ctrl ep
* rails (FI_MR_ENDPOINT mode only)
* mr_ctrl[0..num_ctrl_rails-1] — non-owning view; always populated after
* registration:
* · FI_MR_ENDPOINT: points into mr_ctrl_owned[]
* · non-endpoint-MR: aliases mr_data[] entries
*
* Using a raw-pointer view for mr_ctrl[] eliminates IO-path conditionals:
* callers always index mr_ctrl[rail_id] without branching on the MR mode.
*/
nccl_net_ofi_rdma_mr_handle_t(size_t num_rails_arg)
nccl_net_ofi_rdma_mr_handle_t(size_t num_rails_arg, size_t num_ctrl_rails_arg)
: nccl_net_ofi_mr_handle_t(0),
num_rails(num_rails_arg),
num_ctrl_rails(num_ctrl_rails_arg),
base_addr(0)
{
mr_ctrl.fill(nullptr);
}

/**
* @brief Get MR key for RDMA handle
*
* Return MR key associated with first mr array element
*
* Return MR key associated with first data mr array element
*/
int get_mr_key(uint64_t *mr_key_ptr) override;

uint16_t num_rails;
uint16_t num_ctrl_rails;

/* Array of size `num_rails', indexed by rail_id */
std::array<ofi_mr_ptr, MAX_NUM_RAILS> mr;
/* Owned fid_mr objects, one per data ep rail */
std::array<ofi_mr_ptr, MAX_NUM_RAILS> mr_data;

/* Owned fid_mr objects, one per ctrl ep rail; populated in FI_MR_ENDPOINT
* mode only. In non-endpoint-MR mode these entries remain null. */
std::array<ofi_mr_ptr, MAX_NUM_RAILS> mr_ctrl_owned;

/* Non-owning view into either mr_ctrl_owned[] (FI_MR_ENDPOINT) or mr_data[]
* (non-endpoint-MR). Populated up to num_ctrl_rails after reg_mr_on_device(). */
std::array<struct fid_mr *, MAX_NUM_RAILS> mr_ctrl;

/* Base address of the registered memory region for offset calculation */
uintptr_t base_addr;
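
To make the ownership/view split concrete, here is a minimal sketch of how registration could fill the non-owning mr_ctrl[] view in both modes. It is illustrative only: endpoint_mr_enabled() and the assumption that ofi_mr_ptr exposes a get() accessor are not taken from this PR.

/* Hypothetical post-registration fixup (not the PR's actual code). */
for (uint16_t rail_id = 0; rail_id < handle->num_ctrl_rails; ++rail_id) {
	if (endpoint_mr_enabled()) {
		/* FI_MR_ENDPOINT: each ctrl rail owns its own registration */
		handle->mr_ctrl[rail_id] = handle->mr_ctrl_owned[rail_id].get();
	} else {
		/* Otherwise the ctrl view aliases the data-rail registrations */
		handle->mr_ctrl[rail_id] = handle->mr_data[rail_id].get();
	}
}

Because the view is populated either way, IO paths can index mr_ctrl[rail_id] unconditionally, which is exactly the simplification the class comment above describes.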
@@ -750,6 +779,10 @@ class nccl_net_ofi_rdma_recv_comm : public nccl_net_ofi_recv_comm {
/* Free list to track host flush buffers, for sending flush messages */
nccl_ofi_freelist *flush_buff_fl;

/* Context shared by ctrl_buff_fl and flush_buff_fl registration callbacks.
* Freed in recv_comm_destroy(). */
freelist_regmr_ep_ctx_t *comm_buff_regmr_ctx = nullptr;

#if HAVE_NVTX_TRACING
nvtxDomainHandle_t nvtx_domain[NCCL_OFI_N_NVTX_DOMAIN_PER_COMM];
#endif
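
The forward-declared freelist_regmr_ep_ctx_t is not defined in this header. A plausible shape, with member names that are pure assumptions, would carry whatever the freelist MR-registration callback needs: the domain that performs the registration and, under FI_MR_ENDPOINT, the endpoint to bind against.

/* Speculative sketch only; the real definition lives outside this diff. */
class freelist_regmr_ep_ctx_t {
public:
	nccl_net_ofi_rdma_domain_t *domain; /* performs the MR registration */
	nccl_net_ofi_rdma_ep_t *ep;         /* bind target under FI_MR_ENDPOINT */
};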
@@ -875,6 +908,7 @@ class nccl_net_ofi_rdma_domain_t : public nccl_net_ofi_domain_t {
*/
int reg_mr(nccl_ofi_mr_ckey_ref ckey,
int type,
nccl_net_ofi_rdma_ep_t *ep,
nccl_net_ofi_rdma_mr_handle_t **mhandle);

/**
@@ -930,11 +964,13 @@ class nccl_net_ofi_rdma_domain_t : public nccl_net_ofi_domain_t {
*/
int reg_internal_mr(void *data,
size_t size, int type,
nccl_net_ofi_rdma_ep_t *ep,
nccl_net_ofi_rdma_mr_handle_t **mhandle);

#if HAVE_DECL_FI_MR_DMABUF
int reg_internal_mr_dma_buf(void *data,
int fd, uint64_t offset, size_t size, int type,
nccl_net_ofi_rdma_ep_t *ep,
nccl_net_ofi_rdma_mr_handle_t **mhandle);
#endif
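
Callers of these internal-registration helpers now thread the endpoint through. A hypothetical call site (buff, buff_size, and the surrounding control flow are illustrative) might look like:

nccl_net_ofi_rdma_mr_handle_t *mr_handle = nullptr;
int ret = domain->reg_internal_mr(buff, buff_size, NCCL_PTR_HOST,
				  ep, &mr_handle);
if (ret != 0) {
	/* Registration, or the per-endpoint bind under FI_MR_ENDPOINT, failed */
	return ret;
}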
/**
@@ -951,46 +987,46 @@ class nccl_net_ofi_rdma_domain_t : public nccl_net_ofi_domain_t {
uint16_t num_rails;
std::array<nccl_net_ofi_rdma_domain_rail_t, MAX_NUM_RAILS> domain_rails;

/* The flush buffer */
nccl_net_ofi_rdma_flush_buffer_t flush_buff;

/* List of endpoints and set of addresses they have connections to */
nccl_ofi_ep_addr_list_t ep_addr_list;

protected:
/**
* @brief RDMA domain destructor.
*
* Cleans up RDMA domain resources (flush buffers, ep_table).
* Cleans up RDMA domain resources (ep_table).
*/
~nccl_net_ofi_rdma_domain_t() override;

/**
* @brief Allocates and registers a buffer to flush RDMA operations. On
* success, the receive communicator holds a reference to the flush buffer
* and its associated memory handle.
*
* @param dev_id
* Device ID
*
* @return 0, on success
* error, on others
*/
int alloc_and_reg_flush_buff(int dev_id);

/**
* @brief Deregister flush buffer if flush buffer was registered. Deallocate flush buffer.
*
* @return 0, on success
* error, on others
*/
int dealloc_and_dereg_flush_buff();

private:
int reg_mr_on_device(nccl_ofi_mr_ckey_ref ckey,
int type,
nccl_net_ofi_rdma_ep_t *ep,
nccl_net_ofi_rdma_mr_handle_t **mhandle);

/**
* @brief Bind and enable a memory region on an endpoint
*
* Calls fi_mr_bind followed by fi_mr_enable on the given MR and endpoint.
*
* @param mr
* The memory region to bind and enable
* @param ofi_ep
* The endpoint to bind the MR to
* @param rail_id
* Rail index, used only for diagnostic messages
* @param rail_type
* Human-readable rail type string (e.g. "data" or "ctrl"),
* used only for diagnostic messages
*
* @return 0 on success
* non-zero on error
*/
int mr_bind_and_enable(struct fid_mr *mr,
struct fid_ep *ofi_ep,
uint16_t rail_id,
const char *rail_type);
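
A minimal sketch of the sequence this helper wraps, built on the standard libfabric calls fi_mr_bind(3) and fi_mr_enable(3); the warning messages are illustrative and the PR's actual body may differ:

int nccl_net_ofi_rdma_domain_t::mr_bind_and_enable(struct fid_mr *mr,
						   struct fid_ep *ofi_ep,
						   uint16_t rail_id,
						   const char *rail_type)
{
	/* Bind the MR to the endpoint, as FI_MR_ENDPOINT requires */
	int ret = fi_mr_bind(mr, &ofi_ep->fid, 0);
	if (ret != 0) {
		NCCL_OFI_WARN("fi_mr_bind failed for %s rail %u: %s",
			      rail_type, rail_id, fi_strerror(-ret));
		return ret;
	}

	/* A bound MR must be explicitly enabled before first use */
	ret = fi_mr_enable(mr);
	if (ret != 0) {
		NCCL_OFI_WARN("fi_mr_enable failed for %s rail %u: %s",
			      rail_type, rail_id, fi_strerror(-ret));
	}
	return ret;
}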

/**
* @brief Deregister memory region without acquiring memory region cache lock
*
@@ -1076,6 +1112,9 @@ class nccl_net_ofi_rdma_ep_rail_t {
/* Mutex for rx buffer operations */
pthread_mutex_t rx_buff_mutex;

/* True if this rail posts control (ctrl) rx buffers; false for data (eager) */
bool is_ctrl = false;

/* Allocate a receive buffer request for this rail (eager or ctrl) */
nccl_net_ofi_rdma_req* (*rx_buff_req_alloc)(nccl_net_ofi_rdma_ep_t *ep,
nccl_net_ofi_rdma_ep_rail_t *rail);
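
The new is_ctrl flag lets rail setup select the allocator once instead of branching on rail type in the IO path. A hypothetical wiring, where is_control_rail, ctrl_rx_buff_req_alloc, and eager_rx_buff_req_alloc are assumed names:

rail->is_ctrl = is_control_rail;
rail->rx_buff_req_alloc = rail->is_ctrl ? ctrl_rx_buff_req_alloc
					: eager_rx_buff_req_alloc;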
@@ -1159,7 +1198,6 @@ class nccl_net_ofi_rdma_ep_t : public nccl_net_ofi_ep_t {
*/
inline nccl_net_ofi_rdma_ep_rail_t *rdma_endpoint_get_rail(uint16_t rail_id)
{
assert(!rails.empty());
assert(rail_id < num_rails);
return &rails[rail_id];
}
@@ -1169,7 +1207,6 @@ class nccl_net_ofi_rdma_ep_t : public nccl_net_ofi_ep_t {
*/
inline nccl_net_ofi_rdma_ep_rail_t *rdma_endpoint_get_control_rail(uint16_t rail_id)
{
assert(!control_rails.empty());
assert(rail_id < num_control_rails);
return &control_rails[rail_id];
}
@@ -1179,7 +1216,6 @@ class nccl_net_ofi_rdma_ep_t : public nccl_net_ofi_ep_t {
*/
inline nccl_net_ofi_rdma_cq_rail_t *rdma_endpoint_get_cq_rail(uint16_t rail_id)
{
assert(!cq_rails.empty());
assert(rail_id < num_rails);
return &cq_rails[rail_id];
}
@@ -1190,7 +1226,7 @@ class nccl_net_ofi_rdma_ep_t : public nccl_net_ofi_ep_t {
*/
inline ofi_cq_ptr &get_ofi_cq_for_cm() override
{
assert(!cq_rails.empty());
assert(num_rails > 0);
return cq_rails[0].cq;
}

@@ -1316,13 +1352,13 @@ class nccl_net_ofi_rdma_ep_t : public nccl_net_ofi_ep_t {
uint16_t num_control_rails;

/* Array of `num_rails` endpoint rails */
std::vector<nccl_net_ofi_rdma_ep_rail_t> rails;
std::array<nccl_net_ofi_rdma_ep_rail_t, MAX_NUM_RAILS> rails = {};

/* Array of `num_control_rails` endpoint rails */
std::vector<nccl_net_ofi_rdma_ep_rail_t> control_rails;
std::array<nccl_net_ofi_rdma_ep_rail_t, MAX_NUM_RAILS> control_rails = {};

/* Array of `num_rails` cq rails */
std::vector<nccl_net_ofi_rdma_cq_rail_t> cq_rails;
std::array<nccl_net_ofi_rdma_cq_rail_t, MAX_NUM_RAILS> cq_rails = {};

/* Pending requests queue */
std::deque<nccl_net_ofi_rdma_req *> pending_reqs_queue;
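
The move from heap-backed std::vector to fixed-capacity std::array is why the accessor asserts earlier in this class dropped their !empty() checks: a non-zero-length std::array is never empty, so only the logical rail counts are meaningful. A self-contained illustration (the MAX_NUM_RAILS value here is arbitrary, not the plugin's):

#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>

constexpr std::size_t MAX_NUM_RAILS = 4;        /* illustrative capacity */

int main()
{
	std::array<int, MAX_NUM_RAILS> rails = {};  /* never "empty": size is fixed */
	static_assert(std::tuple_size<decltype(rails)>::value > 0,
		      "emptiness checks on a fixed-size array are vacuous");

	std::uint16_t num_rails = 2;                /* logical rail count in use */
	std::uint16_t rail_id = 1;
	assert(rail_id < num_rails);                /* the check that still matters */
	return rails[rail_id];
}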
@@ -1335,6 +1371,23 @@ class nccl_net_ofi_rdma_ep_t : public nccl_net_ofi_ep_t {
nccl_ofi_freelist *eager_rx_buff_fl = nullptr;
/* Free list of rx buffer requests */
nccl_ofi_freelist *rx_buff_reqs_fl = nullptr;
/* Context passed to the freelist MR registration callback for rx buffer freelists.
* Heap-allocated freelist_regmr_ep_ctx_t; freed in fini_rx_buffers(). */
freelist_regmr_ep_ctx_t *rx_buff_regmr_ctx = nullptr;
/* MR handle for the flush buffer used by this endpoint.
* In FI_MR_ENDPOINT mode this is a per-endpoint registration bound to
* the endpoint's own flush_buff allocation; otherwise it aliases the
* domain's shared flush buffer MR.
* Set in init_rx_buffers(); the per-endpoint registration (if any)
* is deregistered in fini_rx_buffers(). */
nccl_net_ofi_rdma_mr_handle_t *flush_buff_mr_handle = nullptr;
/* Flush buffer for this endpoint.
* Set in init_rx_buffers().
* In FI_MR_ENDPOINT mode: owns a per-endpoint GPU allocation
* (buffer_base / buffer); freed in fini_rx_buffers().
* In non-endpoint-MR mode: buffer aliases the domain's
* flush_buff.buffer; buffer_base is nullptr (not owned). */
nccl_net_ofi_rdma_flush_buffer_t flush_buff = {};
/* Size of ctrl rx buffers */
size_t ctrl_rx_buff_size;
/* Size of eager rx buffers. Will be -1 if eager is entirely
@@ -1376,6 +1429,26 @@ class nccl_net_ofi_rdma_ep_t : public nccl_net_ofi_ep_t {
*/
int init_rx_buffers();

/**
* @brief Allocate and register a buffer used to flush RDMA operations.
*
* Allocates a GPU (or host, for Neuron) flush buffer into @p fb and
* registers it as an MR via the domain. In FI_MR_ENDPOINT mode the
* MR is bound and enabled on this endpoint; otherwise it is a
* domain-scoped registration. The resulting MR handle is written to
* @p mr_out.
*
* @param dev_id Device ID (used in error messages)
* @param fb Flush-buffer struct to populate
* @param mr_out Receives the registered MR handle
*
* @return 0, on success
* error, on others
*/
int alloc_and_reg_flush_buff(int dev_id,
nccl_net_ofi_rdma_flush_buffer_t &fb,
nccl_net_ofi_rdma_mr_handle_t *&mr_out);
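
Taken together with the flush_buff and flush_buff_mr_handle members above, a hypothetical init/teardown flow under the documented ownership rule; rdma_domain, dereg_mr(), and free_flush_buffer() are assumed names, not APIs from this PR:

/* Sketch: inside init_rx_buffers() */
int ret = alloc_and_reg_flush_buff(dev_id, flush_buff, flush_buff_mr_handle);
if (ret != 0)
	return ret;

/* Sketch: inside fini_rx_buffers(). A non-null buffer_base marks
 * per-endpoint ownership (FI_MR_ENDPOINT); aliased domain state is
 * left untouched. */
if (flush_buff.buffer_base != nullptr) {
	rdma_domain->dereg_mr(flush_buff_mr_handle);
	free_flush_buffer(flush_buff.buffer_base);
}
flush_buff_mr_handle = nullptr;
flush_buff = {};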

/**
* @brief Initialize libfabric resources of endpoint rails
*/