Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
235 changes: 34 additions & 201 deletions include/nccl_ofi.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#define NCCL_OFI_H_

#include <unordered_map>
#include <memory>
#include <rdma/fabric.h>
#include <rdma/fi_errno.h>
#include <rdma/fi_domain.h>
Expand Down Expand Up @@ -285,8 +286,6 @@ class nccl_net_ofi_device_t {
int device_index,
struct fi_info *info);

virtual int release_device() = 0;

virtual int get_properties(nccl_ofi_properties_t *props) = 0;

/**
Expand All @@ -313,7 +312,7 @@ class nccl_net_ofi_device_t {
* of performance tradeoffs (be sure to read the domain
* description below).
*/
nccl_net_ofi_domain_t *get_domain(unsigned int domain_key = 0);
std::shared_ptr<nccl_net_ofi_domain_t> get_domain(unsigned int domain_key = 0);

/* Retrieve an endpoint associated with this device under the requested
* domain scope.
Expand All @@ -322,16 +321,7 @@ class nccl_net_ofi_device_t {
* @param endpoint_key Key for endpoint caching within the domain.
* Caller must provide an explicit key (e.g., TID or comm_id).
*/
nccl_net_ofi_ep_t *get_ep(unsigned int domain_key, long endpoint_key);

/**
* implementation of retreiving a domain from a device. This code
* assumes the device lock is already held, because in the case of
* get_domain() we only need to worry about the device lock, but in
* the device->get_ep call, hold the lock while we're also creating
* the ep.
*/
nccl_net_ofi_domain_t *nccl_net_ofi_device_get_domain_impl(unsigned int domain_key = 0);
std::shared_ptr<nccl_net_ofi_ep_t> get_ep(unsigned int domain_key, long endpoint_key);

/**
* @brief Erase all domain_table elements matching the provided domain
Expand Down Expand Up @@ -366,24 +356,14 @@ class nccl_net_ofi_device_t {
* multiple entities. */
std::mutex device_lock;

protected:
/**
* @brief Base device destructor
*
* Releases resources associated with base device.
*/
virtual ~nccl_net_ofi_device_t();

/**
* @brief Cleanup device resources.
*
* Virtual function to clean up and release each transport type's device resources.
* Set called_cleanup_resources to true at the start of the function to make sure
* it is only called once per device instance.
*
* @return 0 if successfully, negative error code on failure.
*/
virtual int cleanup_resources() = 0;
protected:

/*
* create a new domain. This funcion is a private pure
Expand All @@ -401,17 +381,15 @@ class nccl_net_ofi_device_t {
int release_all_domain_and_ep();

/**
* hash table indexed by thread id of active domains.
* Non-owning cache of domains, indexed by domain key.
* Uses weak_ptr so the table does not keep domains alive.
* Comms own endpoints (shared_ptr), endpoints own domains
* (shared_ptr). When the last comm releases its endpoint,
* the domain is destroyed and the weak_ptr expires. Stale
* entries are purged lazily on the next get_domain() miss.
*/
std::unordered_map<unsigned int, nccl_net_ofi_domain_t *> domain_table;
std::unordered_map<unsigned int, std::weak_ptr<nccl_net_ofi_domain_t>> domain_table;

/**
* Track whether the cleanup_resources function was already called to avoid calling
* multiple time on the same device instance. It being set to true does not
* indicate that the device resources were successfully released since this is set
* to true regardless of whether cleanup_resources finished successfully or not.
*/
bool called_cleanup_resources = false;
};


Expand All @@ -424,7 +402,7 @@ class nccl_net_ofi_device_t {
* generally it is expected that calls into resources that share the
* same domain will share the same lock.
*/
class nccl_net_ofi_domain_t {
class nccl_net_ofi_domain_t : public std::enable_shared_from_this<nccl_net_ofi_domain_t> {
public:
/**
* @brief Default constructor.
Expand Down Expand Up @@ -463,7 +441,7 @@ class nccl_net_ofi_domain_t {
*
* Pure virtual function to allocate a new endpoint structure
*/
virtual nccl_net_ofi_ep_t *create_endpoint() = 0;
virtual std::shared_ptr<nccl_net_ofi_ep_t> create_endpoint() = 0;

/**
* @brief Returns the base domain's device back-pointer.
Expand All @@ -473,24 +451,6 @@ class nccl_net_ofi_domain_t {
return device;
}

/**
* @brief Increments the base domain reference count.
* This needs to be protected with device_lock as
* domain life cycle is managed at device level
*/
inline void increment_ref_cnt() {
ref_cnt++;
}

/**
* @brief Decrements the base domain reference count.
* This needs to be protected with device_lock as
* domain life cycle is managed at device level
*/
inline void decrement_ref_cnt() {
ref_cnt--;
}

/*
* Retrieve an endpoint for this domain. If a suitable
* endpoint does not exist, call create_endpoint() to create
Expand All @@ -499,22 +459,7 @@ class nccl_net_ofi_domain_t {
* @param endpoint_key Key for endpoint caching.
* Caller must provide an explicit key (e.g., TID or comm_id).
*/
nccl_net_ofi_ep_t *get_ep(long endpoint_key);

/**
* @brief Release resources associated with the domain
*
* @param skip_device_lock
* false, taking device lock by default.
* ture, not taking device lock when caller takes it.
* @param force_cleanup
* false, not release when endpoint exists.
* true, release no matter endpoint exists nor not.
*
* TODO: Disable thread safety analysis until conditional locking is
* removed.
*/
int release_domain(bool skip_device_lock, bool force_cleanup) NO_THREAD_SAFETY_ANALYSIS;
std::shared_ptr<nccl_net_ofi_ep_t> get_ep(long endpoint_key);
Comment thread
bwbarrett marked this conversation as resolved.

/*
* Protocol-agnostic MR cache for this device.
Expand All @@ -531,83 +476,36 @@ class nccl_net_ofi_domain_t {
*/
void remove_ep_from_map(nccl_net_ofi_ep_t *ep);

/**
* @brief Increment base domain's unreleased_inactive_ep_counter
*/
inline void inc_unreleased_inactive_ep_counter()
{
++unreleased_inactive_ep_counter;
}

/**
* @brief Decrement base domain's unreleased_inactive_ep_counter
*/
inline void dec_unreleased_inactive_ep_counter()
{
--unreleased_inactive_ep_counter;
}

protected:
/**
* @brief Destructor.
*
* Cleans up base domain resources.
*/
virtual ~nccl_net_ofi_domain_t();

/**
* @brief Cleanup domain resources.
*
* Virtual function to clean up and release each transport type's domain resources.
* Set called_cleanup_resources to true at the start of the function to make sure
* it is only called once per domain instance.
*
* @return 0 if successfully, negative error code on failure.
*/
virtual int cleanup_resources() = 0;
protected:

/* Backpointer to the device associated with this domain. */
nccl_net_ofi_device_t *const device = nullptr;

/* The Domain index or a key in the device domain table */
const unsigned int domain_key;

/* Domain reference counter for resource management.
*
* In some modes (right now, endpoint_per_communicator), we create
* multiple endpoints per domain. This counter tracks the number
* of endpoints created on this domain. When it reaches 0, the
* domain can be destroyed. */
size_t ref_cnt;

/**
* release all endpoints. This function is a private
* function, which is called only during cleanup_resources() to free allocated
* endpoints.
* Release all endpoints by clearing the ep_table.
*/
int release_all_ep();

/**
* hash table indexed by thread id of active endpoints.
* Non-owning cache of endpoints, indexed by endpoint key
* (thread ID or comm_id for GIN). Uses weak_ptr so the
* table does not keep endpoints alive. Comms own endpoints
* via shared_ptr (ep). When the last comm releases
* its endpoint, the ep is destroyed and the weak_ptr
* expires. Stale entries are purged lazily on the next
* get_ep() miss.
*/
std::unordered_map<long, nccl_net_ofi_ep_t *> ep_table;

/**
* Number of endpoints that have been deactivated but not freed
*
* This counter is used for a diagnostic when the domain is closed,
* to track inactive ednpoint (which aren't in the ep table) which
* were never closed
*/
size_t unreleased_inactive_ep_counter = 0;

/**
* Track whether the cleanup_resources function was already called to avoid calling
* multiple time on the same domain instance. It being set to true does not
* indicate that the domain resources were successfully released since this is set
* to true regardless of whether cleanup_resources finished successfully or not.
*/
bool called_cleanup_resources = false;
std::unordered_map<long, std::weak_ptr<nccl_net_ofi_ep_t>> ep_table;
Comment thread
bwbarrett marked this conversation as resolved.
};


Expand All @@ -626,7 +524,7 @@ class nccl_net_ofi_domain_t {
* call to get_ep() or during initialization is left to the
* implementation.
*/
class nccl_net_ofi_ep_t {
class nccl_net_ofi_ep_t : public std::enable_shared_from_this<nccl_net_ofi_ep_t> {
public:
/**
* @brief Default constructor.
Expand All @@ -635,7 +533,7 @@ class nccl_net_ofi_ep_t {
* Expectation is that this will be called by a transport's endpoint
* constructor
*/
nccl_net_ofi_ep_t(nccl_net_ofi_domain_t *domain);
nccl_net_ofi_ep_t(std::shared_ptr<nccl_net_ofi_domain_t> domain);

/* Create a receiving object and provide a handle to it.
*
Expand Down Expand Up @@ -677,43 +575,6 @@ class nccl_net_ofi_ep_t {
*/
virtual ofi_cq_ptr &get_ofi_cq_for_cm() = 0;

/**
* @brief Release nccl_ofi_ep.
*
* Decrease reference counter. Release resources and free
* endpoint if reference counter becomes zero. Must be
* protected by lock stored in base_dev.
*
* @param skip_lock
* false, taking domain lock by default.
* ture, not taking domain lock when caller takes it.
* @param force_cleanup
* false, not release when endpoint has ref count.
* true, release no matter endpoint has ref count or not.
*
* TODO: Disable thread safety analysis until conditional locking is
* removed.
*/
virtual int release_ep(bool skip_lock, bool force_cleanup) NO_THREAD_SAFETY_ANALYSIS;

/**
* @brief Increments the base endpoint reference count.
* This needs to be protected with domain_lock as
* endpoint life cycle is managed at domain level
*/
inline void increment_ref_cnt() {
ref_cnt++;
}

/**
* @brief Decrements the base endpoint reference count.
* This needs to be protected with domain_lock as
* endpoint life cycle is managed at domain level
*/
inline void decrement_ref_cnt() {
ref_cnt--;
}

nccl_ofi_spinlock ep_lock;

/*
Expand Down Expand Up @@ -745,46 +606,16 @@ class nccl_net_ofi_ep_t {
return *domain;
}

protected:
/**
* @brief Virtual destructor.
* Virtual function called when resources associated with
* the ep should be destroyed. Device lock will be held when
* this function is called.
*/
virtual ~nccl_net_ofi_ep_t() = default;

/**
* @brief Cleanup endpoint resources.
*
* Virtual function to clean up and release each transport type's endpoint resources.
* Set called_cleanup_resources to true at the start of the function to make sure it
* is only called once per endpoint instance.
*
* @return 0 if successfully, negative error code on failure.
*/
virtual int cleanup_resources() = 0;

/* Backpointer to the domain associated with this ep. */
nccl_net_ofi_domain_t *domain = nullptr;

/**
* Track whether the cleanup_resources function was already called to avoid calling
* multiple time on the same endpoint instance. It being set to true does not
* indicate that the endpoint resources were successfully released since this is set
* to true regardless of whether cleanup_resources finished successfully or not.
*/
bool called_cleanup_resources = false;

/* Endpoint reference counter for resource management.
* sendrecv_get_ep()/sendrecv_release_ep() must be called in
* pair when an object is acquired to use and
* released. sendrecv_get_ep() allocates a new object when it
* is called for the first time. sendrecv_get_ep() creates the
* endpoint libfabric resources if the reference counter was
* zero. sendrecv_release_ep() releases the resources if the
* reference counter is decreased down to zero. */
int ref_cnt;
protected:
/* Backpointer to the domain associated with this ep.
* Holds a shared_ptr to keep the domain alive as long as
* this endpoint exists. */
std::shared_ptr<nccl_net_ofi_domain_t> domain;
};

enum nccl_net_ofi_comm_type_t {
Expand All @@ -807,7 +638,9 @@ class nccl_net_ofi_comm {
virtual ~nccl_net_ofi_comm() = default;

enum nccl_net_ofi_comm_type_t type;
nccl_net_ofi_ep_t *ep;
/* Shared ownership of the endpoint. Keeps ep alive as long
* as this comm exists. */
std::shared_ptr<nccl_net_ofi_ep_t> ep;
int dev_id;
};

Expand Down
Loading
Loading