Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 24 additions & 2 deletions include/rdma/gin/nccl_ofi_gin_gdaki.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,31 @@
bool nccl_ofi_gin_gdaki_enabled();

/*
* The GDAKI plugin. Shared functions (init, devices, listen, connect)
* are nullptr and get copied from the proxy plugin at init time.
* The GDAKI plugin. Shared plugin APIs (declared below) are assigned into
* this plugin at compile time; GDAKI-specific APIs live in
* nccl_ofi_gin_gdaki.cpp.
*/
extern ncclGin_v13_t nccl_ofi_gin_gdaki_plugin;

/*
* Shared plugin APIs — defined in nccl_ofi_gin_api.cpp. GDAKI reuses these
* directly because they operate on shared types (nccl_ofi_rdma_gin_put_comm
* etc.) produced by connect() in both proxy and GDAKI modes.
*/
ncclResult_t nccl_ofi_gin_init(void **ctx, uint64_t commId, ncclDebugLogger_t logFunction);
ncclResult_t nccl_ofi_gin_devices(int *ndev);
ncclResult_t nccl_ofi_gin_listen(void *ctx, int dev, void *handle, void **listenComm);
ncclResult_t nccl_ofi_gin_connect(void *ctx, void *handles[], int nranks, int rank,
void *listenComm, void **collComm);
ncclResult_t nccl_ofi_gin_regMrSym(void *collComm, void *data, size_t size, int type,
uint64_t mrFlags, void **mhandle, void **ginHandle);
ncclResult_t nccl_ofi_gin_regMrSymDmaBuf(void *collComm, void *data, size_t size, int type,
uint64_t offset, int fd, uint64_t mrFlags,
void **mhandle, void **ginHandle);
ncclResult_t nccl_ofi_gin_deregMrSym(void *collComm, void *mhandle);
ncclResult_t nccl_ofi_gin_closeColl(void *collComm);
ncclResult_t nccl_ofi_gin_closeListen(void *listenComm);
ncclResult_t nccl_ofi_gin_ginProgress(void *collComm);
ncclResult_t nccl_ofi_gin_finalize(void *ctx);

#endif /* NCCL_OFI_GIN_GDAKI_H_ */
34 changes: 14 additions & 20 deletions src/rdma/gin/nccl_ofi_gin_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ struct nccl_ofi_gin_context {
explicit nccl_ofi_gin_context(uint64_t id) : comm_id(id) {}
};

static ncclResult_t nccl_ofi_gin_init(void **ctx, uint64_t commId, ncclDebugLogger_t logFunction)
ncclResult_t nccl_ofi_gin_init(void **ctx, uint64_t commId, ncclDebugLogger_t logFunction)
{
if (ofi_log_function == nullptr) {
ofi_log_function = logFunction;
Expand Down Expand Up @@ -76,25 +76,19 @@ static ncclResult_t nccl_ofi_gin_init(void **ctx, uint64_t commId, ncclDebugLogg
}

/*
* Morph the exported plugin to GDAKI if requested.
*
* Copy shared functions (init, devices, listen, connect) from the
* proxy plugin into the GDAKI plugin, then overwrite the exported
* symbol with the GDAKI plugin.
* Morph the exported plugin to GDAKI if requested. Shared plugin APIs
* are wired into nccl_ofi_gin_gdaki_plugin at compile time, so we just
* overwrite the exported symbol.
*/
if (nccl_ofi_gin_gdaki_enabled()) {
NCCL_OFI_INFO(NCCL_NET | NCCL_INIT, "gin: GDAKI mode enabled (OFI_NCCL_GIN_GDAKI=1)");
nccl_ofi_gin_gdaki_plugin.init = ncclGinPlugin_v13.init;
nccl_ofi_gin_gdaki_plugin.devices = ncclGinPlugin_v13.devices;
nccl_ofi_gin_gdaki_plugin.listen = ncclGinPlugin_v13.listen;
nccl_ofi_gin_gdaki_plugin.connect = ncclGinPlugin_v13.connect;
memcpy(&ncclGinPlugin_v13, &nccl_ofi_gin_gdaki_plugin, sizeof(ncclGinPlugin_v13));
}

return ncclSuccess;
}

static ncclResult_t nccl_ofi_gin_devices(int *ndev)
ncclResult_t nccl_ofi_gin_devices(int *ndev)
{
return nccl_net_ofi_devices(ndev);
}
Expand Down Expand Up @@ -135,7 +129,7 @@ static ncclResult_t nccl_ofi_gin_getProperties(int dev, ncclNetProperties_v11_t
return ncclSuccess;
}

static ncclResult_t nccl_ofi_gin_listen(void *ctx, int dev, void *handle, void **listenComm)
ncclResult_t nccl_ofi_gin_listen(void *ctx, int dev, void *handle, void **listenComm)
{
/* Extract communicator ID from GIN context */
nccl_ofi_gin_context *context = static_cast<nccl_ofi_gin_context *>(ctx);
Expand Down Expand Up @@ -195,7 +189,7 @@ static ncclResult_t nccl_ofi_gin_listen(void *ctx, int dev, void *handle, void *
return ncclSuccess;
}

static ncclResult_t nccl_ofi_gin_connect(void *ctx, void *handles[], int nranks, int rank,
ncclResult_t nccl_ofi_gin_connect(void *ctx, void *handles[], int nranks, int rank,
void *listenComm, void **collComm)
{
auto *gin_handles = reinterpret_cast<nccl_net_ofi_conn_handle_t **>(handles);
Expand All @@ -214,7 +208,7 @@ static ncclResult_t nccl_ofi_gin_connect(void *ctx, void *handles[], int nranks,
return nccl_net_ofi_retval_translate(ret);
}

static ncclResult_t nccl_ofi_gin_regMrSymDmaBuf(void *collComm, void *data, size_t size, int type,
ncclResult_t nccl_ofi_gin_regMrSymDmaBuf(void *collComm, void *data, size_t size, int type,
uint64_t offset, int fd, uint64_t mrFlags,
void **mhandle, void **ginHandle)
{
Expand Down Expand Up @@ -243,14 +237,14 @@ static ncclResult_t nccl_ofi_gin_regMrSymDmaBuf(void *collComm, void *data, size
return ncclSuccess;
}

static ncclResult_t nccl_ofi_gin_regMrSym(void *collComm, void *data, size_t size, int type,
ncclResult_t nccl_ofi_gin_regMrSym(void *collComm, void *data, size_t size, int type,
uint64_t mrFlags, void **mhandle, void **ginHandle)
{
return nccl_ofi_gin_regMrSymDmaBuf(collComm, data, size, type, 0, -1, mrFlags, mhandle,
ginHandle);
}

static ncclResult_t nccl_ofi_gin_deregMrSym(void *collComm, void *mhandle)
ncclResult_t nccl_ofi_gin_deregMrSym(void *collComm, void *mhandle)
{
auto *comm = static_cast<nccl_ofi_rdma_gin_put_comm *>(collComm);
auto *mr_handle = static_cast<nccl_ofi_gin_symm_mr_handle_t *>(mhandle);
Expand All @@ -263,7 +257,7 @@ static ncclResult_t nccl_ofi_gin_deregMrSym(void *collComm, void *mhandle)
return ncclSuccess;
}

static ncclResult_t nccl_ofi_gin_ginProgress(void *collComm)
ncclResult_t nccl_ofi_gin_ginProgress(void *collComm)
{
auto *gin_comm = static_cast<nccl_ofi_rdma_gin_put_comm *>(collComm);
int ret = gin_comm->get_resources().progress();
Expand All @@ -275,7 +269,7 @@ static ncclResult_t nccl_ofi_gin_ginProgress(void *collComm)
return nccl_net_ofi_retval_translate(ret);
}

static ncclResult_t nccl_ofi_gin_closeColl(void *collComm)
ncclResult_t nccl_ofi_gin_closeColl(void *collComm)
{
auto *gin_comm = static_cast<nccl_ofi_rdma_gin_put_comm *>(collComm);

Expand All @@ -286,7 +280,7 @@ static ncclResult_t nccl_ofi_gin_closeColl(void *collComm)
return nccl_net_ofi_retval_translate(ret);
}

static ncclResult_t nccl_ofi_gin_closeListen(void *listenComm)
ncclResult_t nccl_ofi_gin_closeListen(void *listenComm)
{
delete static_cast<nccl_ofi_rdma_gin_listen_comm *>(listenComm);
return ncclSuccess;
Expand Down Expand Up @@ -330,7 +324,7 @@ static ncclResult_t nccl_ofi_gin_iput(void *collComm, uint64_t srcOff, void *src
0, nullptr, 0, 0, request);
}

static ncclResult_t nccl_ofi_gin_finalize(void *ctx)
ncclResult_t nccl_ofi_gin_finalize(void *ctx)
{
/* Clean up the GIN context structure.
If ctx is NULL, init() was never called or failed, so there's
Expand Down
96 changes: 24 additions & 72 deletions src/rdma/gin/nccl_ofi_gin_gdaki.cpp
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
/*
* Copyright (c) 2026 Amazon.com, Inc. or its affiliates. All rights reserved.
*
* GDAKI stub implementations for the GIN plugin API.
* These provide the full ncclGin_v13_t plugin for GDAKI mode.
*
* Task: GDAKI stub implementation
* GDAKI plugin for the GIN API. Shared APIs (init, devices, listen, connect,
* regMrSym[DmaBuf], deregMrSym, closeColl, closeListen, ginProgress, finalize)
* are reused from the proxy-side implementations in nccl_ofi_gin_api.cpp.
* Only the GDAKI-specific stubs (createContext/destroyContext/get_properties/
* regMr*-return-error/queryLastError) live here until they are implemented.
*/

#include "config.h"
Expand Down Expand Up @@ -73,86 +74,37 @@ static ncclResult_t nccl_ofi_gin_gdaki_destroyContext(void *ginCtx)
return ncclSuccess;
}

static ncclResult_t nccl_ofi_gin_gdaki_regMrSym(void *collComm, void *data, size_t size, int type,
uint64_t mrFlags, void **mhandle, void **ginHandle)
{
NCCL_OFI_WARN("gin GDAKI: regMrSym not yet implemented");
return ncclInternalError;
}

static ncclResult_t nccl_ofi_gin_gdaki_regMrSymDmaBuf(void *collComm, void *data, size_t size,
int type, uint64_t offset, int fd,
uint64_t mrFlags, void **mhandle,
void **ginHandle)
{
NCCL_OFI_WARN("gin GDAKI: regMrSymDmaBuf not yet implemented");
return ncclInternalError;
}

static ncclResult_t nccl_ofi_gin_gdaki_deregMrSym(void *collComm, void *mhandle)
{
NCCL_OFI_WARN("gin GDAKI: deregMrSym not yet implemented");
return ncclSuccess;
}

static ncclResult_t nccl_ofi_gin_gdaki_closeColl(void *collComm)
{
NCCL_OFI_WARN("gin GDAKI: closeColl not yet implemented");
return ncclSuccess;
}

static ncclResult_t nccl_ofi_gin_gdaki_closeListen(void *listenComm)
{
NCCL_OFI_WARN("gin GDAKI: closeListen not yet implemented");
return ncclSuccess;
}

static ncclResult_t nccl_ofi_gin_gdaki_ginProgress(void *ginCtx)
{
NCCL_OFI_WARN("gin GDAKI: ginProgress not yet implemented");
return ncclSuccess;
}

static ncclResult_t nccl_ofi_gin_gdaki_queryLastError(void *ginCtx, bool *hasError)
{
NCCL_OFI_WARN("gin GDAKI: queryLastError not yet implemented");
*hasError = false;
return ncclSuccess;
}

static ncclResult_t nccl_ofi_gin_gdaki_finalize(void *ctx)
{
NCCL_OFI_WARN("gin GDAKI: finalize not yet implemented");
return ncclSuccess;
}

/*
* GDAKI plugin. Function pointers for the ncclGin_v13_t interface:
* - iput, iputSignal, iget, iflush, test are nullptr (no CPU involvement
* in GDAKI mode)
* - init, devices, listen, connect are copied from the proxy plugin at
* init time.
* GDAKI plugin. Shared APIs are wired directly from nccl_ofi_gin_api.cpp;
* GDAKI-specific ones above. iput/iputSignal/iget/iflush/test are nullptr —
* no CPU involvement in GDAKI mode.
*/
ncclGin_v13_t nccl_ofi_gin_gdaki_plugin = {
.name = "Libfabric_GDAKI",
.init = nullptr, /* Copied from proxy plugin at init time */
.devices = nullptr, /* Copied from proxy plugin at init time */
.init = nccl_ofi_gin_init,
.devices = nccl_ofi_gin_devices,
.getProperties = nccl_ofi_gin_gdaki_get_properties,
.listen = nullptr, /* Copied from proxy plugin at init time */
.connect = nullptr, /* Copied from proxy plugin at init time */
.listen = nccl_ofi_gin_listen,
.connect = nccl_ofi_gin_connect,
.createContext = nccl_ofi_gin_gdaki_createContext,
.regMrSym = nccl_ofi_gin_gdaki_regMrSym,
.regMrSymDmaBuf = nccl_ofi_gin_gdaki_regMrSymDmaBuf,
.deregMrSym = nccl_ofi_gin_gdaki_deregMrSym,
.regMrSym = nccl_ofi_gin_regMrSym,
.regMrSymDmaBuf = nccl_ofi_gin_regMrSymDmaBuf,
.deregMrSym = nccl_ofi_gin_deregMrSym,
.destroyContext = nccl_ofi_gin_gdaki_destroyContext,
.closeColl = nccl_ofi_gin_gdaki_closeColl,
.closeListen = nccl_ofi_gin_gdaki_closeListen,
.iput = nullptr, /* No CPU involvement in GDAKI mode */
.iputSignal = nullptr, /* No CPU involvement in GDAKI mode */
.iget = nullptr, /* No CPU involvement in GDAKI mode */
.iflush = nullptr, /* No CPU involvement in GDAKI mode */
.test = nullptr, /* No CPU involvement in GDAKI mode */
.ginProgress = nccl_ofi_gin_gdaki_ginProgress,
.closeColl = nccl_ofi_gin_closeColl,
.closeListen = nccl_ofi_gin_closeListen,
.iput = nullptr,
.iputSignal = nullptr,
.iget = nullptr,
.iflush = nullptr,
.test = nullptr,
.ginProgress = nccl_ofi_gin_ginProgress,
.queryLastError = nccl_ofi_gin_gdaki_queryLastError,
.finalize = nccl_ofi_gin_gdaki_finalize
.finalize = nccl_ofi_gin_finalize
};
Loading