Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 71 additions & 0 deletions candle-core/src/cuda_backend/device.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,25 @@ use std::sync::{Arc, Mutex, RwLock};

use super::{CudaError, CudaStorage, CudaStorageSlice, WrapErr};

/// Enable peer access from the *currently bound* CUDA context to the
/// supplied peer context. Caller is responsible for ensuring the right
/// context is current (`bind_to_thread`) before calling.
///
/// Treats `CUDA_ERROR_PEER_ACCESS_ALREADY_ENABLED` (704) as success so
/// the public `enable_peer_access` is safely idempotent across repeat
/// calls between the same context pair. Other errors surface via
/// `WrapErr`.
fn enable_peer_access_one_way(peer_ctx: cudarc::driver::sys::CUcontext) -> Result<()> {
let res = unsafe { cudarc::driver::sys::cuCtxEnablePeerAccess(peer_ctx, 0) };
if res == cudarc::driver::sys::cudaError_enum::CUDA_SUCCESS
|| res == cudarc::driver::sys::cudaError_enum::CUDA_ERROR_PEER_ACCESS_ALREADY_ENABLED
{
Ok(())
} else {
Err(cudarc::driver::DriverError(res)).w()
}
}

/// Unique identifier for cuda devices.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct DeviceId(usize);
Expand Down Expand Up @@ -173,6 +192,58 @@ impl CudaDevice {
self.context.is_event_tracking()
}

/// Enable peer access between this device's context and another's, in
/// both directions, so that GPU-direct cross-card tensor operations
/// (`memcpy_peer_async` / `Tensor::to_device(&other_cuda)`) can route
/// over NVLink (or PCIe P2P) instead of erroring with
/// `CUDA_ERROR_INVALID_CONTEXT`.
///
/// Idempotent in two senses:
/// - Calling on the same `CudaDevice` (i.e. same ordinal) is a no-op,
/// returning `Ok(())`. Same-context "peer" access is meaningless.
/// - Calling repeatedly between the same two contexts is safe.
/// `cuCtxEnablePeerAccess` returns `CUDA_ERROR_PEER_ACCESS_ALREADY_ENABLED`
/// on the second call, which this method maps to `Ok(())`.
///
/// Operators that hold multiple `CudaDevice` instances on different
/// ordinals must call this between every pair they intend to do
/// cross-card transfers across. Without it, cudarc's safe
/// `memcpy_dtod` correctly dispatches to `memcpy_peer_async` when
/// source and destination contexts differ — but `memcpy_peer_async`
/// requires peer access to have been enabled first, otherwise the
/// driver rejects with `CUDA_ERROR_INVALID_CONTEXT`.
///
/// # Errors
///
/// Returns the underlying CUDA error if the device pair doesn't
/// support peer access (e.g. some heterogeneous or
/// IOMMU-isolated configurations) or if either context is in a
/// terminal state. Check `cuDeviceCanAccessPeer` separately if you
/// need to probe support before attempting to enable it.
pub fn enable_peer_access(&self, other: &Self) -> Result<()> {
let self_ord = self.context.ordinal();
let other_ord = other.context.ordinal();
if self_ord == other_ord {
// Same physical device. No peer to enable; not an error.
return Ok(());
}
let self_ctx = self.context.cu_ctx();
let other_ctx = other.context.cu_ctx();
// self ←→ other: enable each direction. The driver rejects
// `cuCtxEnablePeerAccess` with PEER_ACCESS_ALREADY_ENABLED if
// the call is a no-op (already enabled); fold that into Ok(())
// so the helper is safely idempotent across repeat calls.
self.context.bind_to_thread().w()?;
enable_peer_access_one_way(other_ctx)?;
other.context.bind_to_thread().w()?;
enable_peer_access_one_way(self_ctx)?;
// Restore self as current — bind_to_thread above pushed `other`
// onto this OS thread; callers that proceed with `self` work
// immediately after this method shouldn't have to re-bind.
self.context.bind_to_thread().w()?;
Ok(())
}

#[cfg(all(feature = "ug", not(target_arch = "wasm32")))]
pub fn compile(
&self,
Expand Down
25 changes: 25 additions & 0 deletions candle-core/src/device.rs
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,31 @@ impl Device {
Ok(Self::Cuda(crate::CudaDevice::new_with_stream(ordinal)?))
}

/// Enable bidirectional peer access between two CUDA devices on
/// different ordinals so GPU-direct cross-card tensor operations
/// (`Tensor::to_device(&other_cuda)` →
/// `cudarc::CudaStream::clone_dtod` → `memcpy_peer_async`) succeed.
///
/// Idempotent: same-ordinal pairs and already-enabled pairs both
/// return `Ok(())`. See
/// [`crate::CudaDevice::enable_peer_access`] for the underlying
/// semantics.
///
/// Returns an error if either device is not CUDA (e.g. `Cpu` or
/// `Metal`), or if the underlying driver call rejects the request
/// (peer access unsupported on this hardware pair, etc.).
#[cfg(feature = "cuda")]
pub fn enable_peer_access(&self, other: &Self) -> Result<()> {
match (self, other) {
(Self::Cuda(a), Self::Cuda(b)) => a.enable_peer_access(b),
_ => crate::bail!(
"enable_peer_access requires two CUDA devices, got {:?} and {:?}",
self.location(),
other.location()
),
}
}

pub fn new_metal(ordinal: usize) -> Result<Self> {
Ok(Self::Metal(crate::MetalDevice::new(ordinal)?))
}
Expand Down