diff --git a/include/rdma/gin/nccl_ofi_gin_gdaki.h b/include/rdma/gin/nccl_ofi_gin_gdaki.h index c65762d236..eae1c80757 100644 --- a/include/rdma/gin/nccl_ofi_gin_gdaki.h +++ b/include/rdma/gin/nccl_ofi_gin_gdaki.h @@ -13,9 +13,31 @@ bool nccl_ofi_gin_gdaki_enabled(); /* - * The GDAKI plugin. Shared functions (init, devices, listen, connect) - * are nullptr and get copied from the proxy plugin at init time. + * The GDAKI plugin. Shared plugin APIs (declared below) are assigned into + * this plugin at compile time; GDAKI-specific APIs live in + * nccl_ofi_gin_gdaki.cpp. */ extern ncclGin_v13_t nccl_ofi_gin_gdaki_plugin; +/* + * Shared plugin APIs — defined in nccl_ofi_gin_api.cpp. GDAKI reuses these + * directly because they operate on shared types (nccl_ofi_rdma_gin_put_comm + * etc.) produced by connect() in both proxy and GDAKI modes. + */ +ncclResult_t nccl_ofi_gin_init(void **ctx, uint64_t commId, ncclDebugLogger_t logFunction); +ncclResult_t nccl_ofi_gin_devices(int *ndev); +ncclResult_t nccl_ofi_gin_listen(void *ctx, int dev, void *handle, void **listenComm); +ncclResult_t nccl_ofi_gin_connect(void *ctx, void *handles[], int nranks, int rank, + void *listenComm, void **collComm); +ncclResult_t nccl_ofi_gin_regMrSym(void *collComm, void *data, size_t size, int type, + uint64_t mrFlags, void **mhandle, void **ginHandle); +ncclResult_t nccl_ofi_gin_regMrSymDmaBuf(void *collComm, void *data, size_t size, int type, + uint64_t offset, int fd, uint64_t mrFlags, + void **mhandle, void **ginHandle); +ncclResult_t nccl_ofi_gin_deregMrSym(void *collComm, void *mhandle); +ncclResult_t nccl_ofi_gin_closeColl(void *collComm); +ncclResult_t nccl_ofi_gin_closeListen(void *listenComm); +ncclResult_t nccl_ofi_gin_ginProgress(void *collComm); +ncclResult_t nccl_ofi_gin_finalize(void *ctx); + #endif /* NCCL_OFI_GIN_GDAKI_H_ */ diff --git a/src/rdma/gin/nccl_ofi_gin_api.cpp b/src/rdma/gin/nccl_ofi_gin_api.cpp index c725773f74..602f9edc9d 100644 --- 
a/src/rdma/gin/nccl_ofi_gin_api.cpp +++ b/src/rdma/gin/nccl_ofi_gin_api.cpp @@ -30,7 +30,7 @@ struct nccl_ofi_gin_context { explicit nccl_ofi_gin_context(uint64_t id) : comm_id(id) {} }; -static ncclResult_t nccl_ofi_gin_init(void **ctx, uint64_t commId, ncclDebugLogger_t logFunction) +ncclResult_t nccl_ofi_gin_init(void **ctx, uint64_t commId, ncclDebugLogger_t logFunction) { if (ofi_log_function == nullptr) { ofi_log_function = logFunction; @@ -76,25 +76,19 @@ static ncclResult_t nccl_ofi_gin_init(void **ctx, uint64_t commId, ncclDebugLogg } /* - * Morph the exported plugin to GDAKI if requested. - * - * Copy shared functions (init, devices, listen, connect) from the - * proxy plugin into the GDAKI plugin, then overwrite the exported - * symbol with the GDAKI plugin. + * Morph the exported plugin to GDAKI if requested. Shared plugin APIs + * are wired into nccl_ofi_gin_gdaki_plugin at compile time, so we just + * overwrite the exported symbol. */ if (nccl_ofi_gin_gdaki_enabled()) { NCCL_OFI_INFO(NCCL_NET | NCCL_INIT, "gin: GDAKI mode enabled (OFI_NCCL_GIN_GDAKI=1)"); - nccl_ofi_gin_gdaki_plugin.init = ncclGinPlugin_v13.init; - nccl_ofi_gin_gdaki_plugin.devices = ncclGinPlugin_v13.devices; - nccl_ofi_gin_gdaki_plugin.listen = ncclGinPlugin_v13.listen; - nccl_ofi_gin_gdaki_plugin.connect = ncclGinPlugin_v13.connect; memcpy(&ncclGinPlugin_v13, &nccl_ofi_gin_gdaki_plugin, sizeof(ncclGinPlugin_v13)); } return ncclSuccess; } -static ncclResult_t nccl_ofi_gin_devices(int *ndev) +ncclResult_t nccl_ofi_gin_devices(int *ndev) { return nccl_net_ofi_devices(ndev); } @@ -135,7 +129,7 @@ static ncclResult_t nccl_ofi_gin_getProperties(int dev, ncclNetProperties_v11_t return ncclSuccess; } -static ncclResult_t nccl_ofi_gin_listen(void *ctx, int dev, void *handle, void **listenComm) +ncclResult_t nccl_ofi_gin_listen(void *ctx, int dev, void *handle, void **listenComm) { /* Extract communicator ID from GIN context */ nccl_ofi_gin_context *context = static_cast(ctx); @@ -195,7 
+189,7 @@ static ncclResult_t nccl_ofi_gin_listen(void *ctx, int dev, void *handle, void * return ncclSuccess; } -static ncclResult_t nccl_ofi_gin_connect(void *ctx, void *handles[], int nranks, int rank, +ncclResult_t nccl_ofi_gin_connect(void *ctx, void *handles[], int nranks, int rank, void *listenComm, void **collComm) { auto *gin_handles = reinterpret_cast(handles); @@ -214,7 +208,7 @@ static ncclResult_t nccl_ofi_gin_connect(void *ctx, void *handles[], int nranks, return nccl_net_ofi_retval_translate(ret); } -static ncclResult_t nccl_ofi_gin_regMrSymDmaBuf(void *collComm, void *data, size_t size, int type, +ncclResult_t nccl_ofi_gin_regMrSymDmaBuf(void *collComm, void *data, size_t size, int type, uint64_t offset, int fd, uint64_t mrFlags, void **mhandle, void **ginHandle) { @@ -243,14 +237,14 @@ static ncclResult_t nccl_ofi_gin_regMrSymDmaBuf(void *collComm, void *data, size return ncclSuccess; } -static ncclResult_t nccl_ofi_gin_regMrSym(void *collComm, void *data, size_t size, int type, +ncclResult_t nccl_ofi_gin_regMrSym(void *collComm, void *data, size_t size, int type, uint64_t mrFlags, void **mhandle, void **ginHandle) { return nccl_ofi_gin_regMrSymDmaBuf(collComm, data, size, type, 0, -1, mrFlags, mhandle, ginHandle); } -static ncclResult_t nccl_ofi_gin_deregMrSym(void *collComm, void *mhandle) +ncclResult_t nccl_ofi_gin_deregMrSym(void *collComm, void *mhandle) { auto *comm = static_cast(collComm); auto *mr_handle = static_cast(mhandle); @@ -263,7 +257,7 @@ static ncclResult_t nccl_ofi_gin_deregMrSym(void *collComm, void *mhandle) return ncclSuccess; } -static ncclResult_t nccl_ofi_gin_ginProgress(void *collComm) +ncclResult_t nccl_ofi_gin_ginProgress(void *collComm) { auto *gin_comm = static_cast(collComm); int ret = gin_comm->get_resources().progress(); @@ -275,7 +269,7 @@ static ncclResult_t nccl_ofi_gin_ginProgress(void *collComm) return nccl_net_ofi_retval_translate(ret); } -static ncclResult_t nccl_ofi_gin_closeColl(void *collComm) 
+ncclResult_t nccl_ofi_gin_closeColl(void *collComm) { auto *gin_comm = static_cast(collComm); @@ -286,7 +280,7 @@ static ncclResult_t nccl_ofi_gin_closeColl(void *collComm) return nccl_net_ofi_retval_translate(ret); } -static ncclResult_t nccl_ofi_gin_closeListen(void *listenComm) +ncclResult_t nccl_ofi_gin_closeListen(void *listenComm) { delete static_cast(listenComm); return ncclSuccess; @@ -330,7 +324,7 @@ static ncclResult_t nccl_ofi_gin_iput(void *collComm, uint64_t srcOff, void *src 0, nullptr, 0, 0, request); } -static ncclResult_t nccl_ofi_gin_finalize(void *ctx) +ncclResult_t nccl_ofi_gin_finalize(void *ctx) { /* Clean up the GIN context structure. If ctx is NULL, init() was never called or failed, so there's diff --git a/src/rdma/gin/nccl_ofi_gin_gdaki.cpp b/src/rdma/gin/nccl_ofi_gin_gdaki.cpp index e731a40d9f..3cf60d388f 100644 --- a/src/rdma/gin/nccl_ofi_gin_gdaki.cpp +++ b/src/rdma/gin/nccl_ofi_gin_gdaki.cpp @@ -1,10 +1,11 @@ /* * Copyright (c) 2026 Amazon.com, Inc. or its affiliates. All rights reserved. * - * GDAKI stub implementations for the GIN plugin API. - * These provide the full ncclGin_v13_t plugin for GDAKI mode. - * - * Task: GDAKI stub implementation + * GDAKI plugin for the GIN API. Shared APIs (init, devices, listen, connect, + * regMrSym[DmaBuf], deregMrSym, closeColl, closeListen, ginProgress, finalize) + * are reused from the proxy-side implementations in nccl_ofi_gin_api.cpp. + * Only the GDAKI-specific stubs (createContext/destroyContext/get_properties/ + * queryLastError) live here until they are implemented. 
*/ #include "config.h" @@ -73,86 +74,37 @@ static ncclResult_t nccl_ofi_gin_gdaki_destroyContext(void *ginCtx) return ncclSuccess; } -static ncclResult_t nccl_ofi_gin_gdaki_regMrSym(void *collComm, void *data, size_t size, int type, - uint64_t mrFlags, void **mhandle, void **ginHandle) -{ - NCCL_OFI_WARN("gin GDAKI: regMrSym not yet implemented"); - return ncclInternalError; -} - -static ncclResult_t nccl_ofi_gin_gdaki_regMrSymDmaBuf(void *collComm, void *data, size_t size, - int type, uint64_t offset, int fd, - uint64_t mrFlags, void **mhandle, - void **ginHandle) -{ - NCCL_OFI_WARN("gin GDAKI: regMrSymDmaBuf not yet implemented"); - return ncclInternalError; -} - -static ncclResult_t nccl_ofi_gin_gdaki_deregMrSym(void *collComm, void *mhandle) -{ - NCCL_OFI_WARN("gin GDAKI: deregMrSym not yet implemented"); - return ncclSuccess; -} - -static ncclResult_t nccl_ofi_gin_gdaki_closeColl(void *collComm) -{ - NCCL_OFI_WARN("gin GDAKI: closeColl not yet implemented"); - return ncclSuccess; -} - -static ncclResult_t nccl_ofi_gin_gdaki_closeListen(void *listenComm) -{ - NCCL_OFI_WARN("gin GDAKI: closeListen not yet implemented"); - return ncclSuccess; -} - -static ncclResult_t nccl_ofi_gin_gdaki_ginProgress(void *ginCtx) -{ - NCCL_OFI_WARN("gin GDAKI: ginProgress not yet implemented"); - return ncclSuccess; -} - static ncclResult_t nccl_ofi_gin_gdaki_queryLastError(void *ginCtx, bool *hasError) { - NCCL_OFI_WARN("gin GDAKI: queryLastError not yet implemented"); *hasError = false; return ncclSuccess; } -static ncclResult_t nccl_ofi_gin_gdaki_finalize(void *ctx) -{ - NCCL_OFI_WARN("gin GDAKI: finalize not yet implemented"); - return ncclSuccess; -} - /* - * GDAKI plugin. Function pointers for the ncclGin_v13_t interface: - * - iput, iputSignal, iget, iflush, test are nullptr (no CPU involvement - * in GDAKI mode) - * - init, devices, listen, connect are copied from the proxy plugin at - * init time. + * GDAKI plugin. 
Shared APIs are wired directly from nccl_ofi_gin_api.cpp; + * GDAKI-specific ones above. iput/iputSignal/iget/iflush/test are nullptr — + * no CPU involvement in GDAKI mode. */ ncclGin_v13_t nccl_ofi_gin_gdaki_plugin = { .name = "Libfabric_GDAKI", - .init = nullptr, /* Copied from proxy plugin at init time */ - .devices = nullptr, /* Copied from proxy plugin at init time */ + .init = nccl_ofi_gin_init, + .devices = nccl_ofi_gin_devices, .getProperties = nccl_ofi_gin_gdaki_get_properties, - .listen = nullptr, /* Copied from proxy plugin at init time */ - .connect = nullptr, /* Copied from proxy plugin at init time */ + .listen = nccl_ofi_gin_listen, + .connect = nccl_ofi_gin_connect, .createContext = nccl_ofi_gin_gdaki_createContext, - .regMrSym = nccl_ofi_gin_gdaki_regMrSym, - .regMrSymDmaBuf = nccl_ofi_gin_gdaki_regMrSymDmaBuf, - .deregMrSym = nccl_ofi_gin_gdaki_deregMrSym, + .regMrSym = nccl_ofi_gin_regMrSym, + .regMrSymDmaBuf = nccl_ofi_gin_regMrSymDmaBuf, + .deregMrSym = nccl_ofi_gin_deregMrSym, .destroyContext = nccl_ofi_gin_gdaki_destroyContext, - .closeColl = nccl_ofi_gin_gdaki_closeColl, - .closeListen = nccl_ofi_gin_gdaki_closeListen, - .iput = nullptr, /* No CPU involvement in GDAKI mode */ - .iputSignal = nullptr, /* No CPU involvement in GDAKI mode */ - .iget = nullptr, /* No CPU involvement in GDAKI mode */ - .iflush = nullptr, /* No CPU involvement in GDAKI mode */ - .test = nullptr, /* No CPU involvement in GDAKI mode */ - .ginProgress = nccl_ofi_gin_gdaki_ginProgress, + .closeColl = nccl_ofi_gin_closeColl, + .closeListen = nccl_ofi_gin_closeListen, + .iput = nullptr, + .iputSignal = nullptr, + .iget = nullptr, + .iflush = nullptr, + .test = nullptr, + .ginProgress = nccl_ofi_gin_ginProgress, .queryLastError = nccl_ofi_gin_gdaki_queryLastError, - .finalize = nccl_ofi_gin_gdaki_finalize + .finalize = nccl_ofi_gin_finalize };