
candle-core: add Device::enable_peer_access for cross-CUDA P2P transfers#3525

Open
toddwbucy wants to merge 1 commit into huggingface:main from toddwbucy:feat/peer-access-upstream

Conversation

@toddwbucy

Fixes #3524.

Adds explicit `Device::enable_peer_access(&other)` so GPU-direct cross-card tensor operations (`Tensor::to_device(&other_cuda_device)`) succeed instead of erroring with `CUDA_ERROR_INVALID_CONTEXT` on first use.

What

  • `CudaDevice::enable_peer_access(&self, other: &Self) -> Result<()>` — calls `cuCtxEnablePeerAccess` in both directions (`self ←→ other`), bound to the appropriate context for each direction.
  • `Device::enable_peer_access(&self, other: &Self) -> Result<()>` — public surface, cuda-only (`#[cfg(feature = "cuda")]`); errors clearly if either side isn't a CUDA device.
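The contract above can be modeled in a small standalone sketch. This is a mock `Device` enum, not candle's real type; the real implementation delegates to `cuCtxEnablePeerAccess` through cudarc, which this sketch only marks with a comment:

```rust
// Behavioral sketch of the surface described above -- a mock `Device`,
// NOT candle's real type. Shows the contract: CUDA-only, clear error
// otherwise, and same-ordinal pairs are an Ok(()) no-op.
#[derive(Clone, Copy, PartialEq, Debug)]
enum Device {
    Cpu,
    Cuda(usize), // CUDA ordinal
}

impl Device {
    fn enable_peer_access(&self, other: &Self) -> Result<(), String> {
        match (self, other) {
            (Device::Cuda(a), Device::Cuda(b)) => {
                if a == b {
                    return Ok(()); // same context: nothing to enable
                }
                // Real implementation: cuCtxEnablePeerAccess in both
                // directions, bound to the right context for each.
                Ok(())
            }
            _ => Err("enable_peer_access requires two CUDA devices".to_string()),
        }
    }
}

fn main() {
    let gpu0 = Device::Cuda(0);
    let gpu1 = Device::Cuda(1);
    assert!(gpu0.enable_peer_access(&gpu1).is_ok());
    assert!(gpu0.enable_peer_access(&gpu0).is_ok()); // same-ordinal no-op
    assert!(gpu0.enable_peer_access(&Device::Cpu).is_err()); // clear error
    println!("ok");
}
```

With the real API, the intended pattern is the same shape: a one-time `encoder_device.enable_peer_access(&decoder_device)?` before the first cross-card `Tensor::to_device`.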

Idempotent

  • Same-ordinal pairs are a no-op, returning `Ok(())` — "peer" access within a single context is meaningless.
  • Repeat calls between the same context pair are safe — the driver's `CUDA_ERROR_PEER_ACCESS_ALREADY_ENABLED` (704) is folded into `Ok(())` inside the helper. So callers can put `enable_peer_access` on a hot path without worrying about state tracking.
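A minimal sketch of that fold, using raw driver status codes as stand-ins for cudarc's `DriverError` (704 is `CUDA_ERROR_PEER_ACCESS_ALREADY_ENABLED` per the bullet above; 201 is `CUDA_ERROR_INVALID_CONTEXT`):

```rust
// Sketch of the idempotency fold: "already enabled" is treated as success
// so repeat calls are safe on a hot path. Raw u32 status codes stand in
// for cudarc's DriverError here.
const CUDA_SUCCESS: u32 = 0;
const CUDA_ERROR_PEER_ACCESS_ALREADY_ENABLED: u32 = 704;

fn fold_peer_access_status(status: u32) -> Result<(), u32> {
    match status {
        CUDA_SUCCESS | CUDA_ERROR_PEER_ACCESS_ALREADY_ENABLED => Ok(()),
        err => Err(err), // anything else propagates to the caller
    }
}

fn main() {
    assert_eq!(fold_peer_access_status(0), Ok(()));
    assert_eq!(fold_peer_access_status(704), Ok(())); // repeat call: folded
    assert_eq!(fold_peer_access_status(201), Err(201)); // real failure propagates
    println!("ok");
}
```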

Why explicit, not auto-enable in `BackendDevice::new`

Per #3524 I considered (A) explicit opt-in vs (B) opportunistic auto-enable in `BackendDevice::new`. Going with (A) here:

  • Peer access is a real driver-state mutation with costs (UVA coordination on platforms that have it, no-op rejection on heterogeneous configs). Surfacing the opt-in keeps those visible.
  • (B) requires global state — a registry of all already-constructed `CudaDevice`s — which is a much larger surgery and creates surprising side effects (every `Device::new_cuda(N)` would mutate driver state for every other extant CUDA device).
  • (A) is forward-compatible with (B) — if you later want auto-enable, `BackendDevice::new` can just call this same helper.

Happy to switch to (B) instead if maintainers prefer that direction; the underlying `enable_peer_access_one_way` helper is the same either way.

Errors

Returns the underlying `DriverError` if the device pair doesn't support peer access (as in some heterogeneous or IOMMU-isolated configurations) or if either context is in a terminal state. Callers who want to probe support before attempting can call `cuDeviceCanAccessPeer` separately; it is not wrapped here.

Validation

  • `cargo build -p candle-core --features cuda` — clean
  • Tested in a downstream project's topology bench (2× A6000 + NVLink, also tested over PCIe with NVLink bridge physically removed) where `Tensor::to_device(&other_cuda_device)` previously errored with `INVALID_CONTEXT`. With this patch + a one-time `encoder_device.enable_peer_access(&decoder_device)?` call before the first transfer, the cross-card path succeeds and routes through NVLink (or PCIe P2P).

Adjacent

This is the third small fix from our Jina V4 + multi-GPU work. The other two are #3520 (qwen2 RoPE fp32 cos/sin tables) and #3521 (FA v2.8.3 vendored kernel bump). All three were independently discovered while bringing a Jina V4 embedder up under candle on a 2-A6000 rig; happy to provide more context on any of them.

`Tensor::to_device(&other_cuda_device)` dispatches to
`CudaStorage::transfer_to_device` → `cudarc::CudaStream::clone_dtod` →
`memcpy_peer_async` when the two devices have different contexts. But
`memcpy_peer_async` requires `cuCtxEnablePeerAccess` to have been
called between the two contexts first; without it the driver rejects
with `CUDA_ERROR_INVALID_CONTEXT`. Result: any GPU-direct cross-
`CudaDevice` transfer fails on first use.

Add explicit opt-in at the public `Device::enable_peer_access(&other)`
API:

  - `CudaDevice::enable_peer_access(&self, other: &Self)` calls
    `cuCtxEnablePeerAccess` in both directions (self ←→ other), bound
    to the appropriate context for each direction.
  - `Device::enable_peer_access(&self, other: &Self)` (cuda-only)
    delegates to `CudaDevice` when both are CUDA, errors otherwise.

Idempotent in two senses:

  - Same-ordinal pairs are a no-op (`Ok(())`).
  - Repeat calls between the same context pair are safe; the driver's
    `CUDA_ERROR_PEER_ACCESS_ALREADY_ENABLED` is folded into `Ok(())`
    inside the helper.

Operators must call this before doing cross-card transfers; we do
not auto-enable in `BackendDevice::new` because that requires global
state tracking of all already-constructed CudaDevices, which would
be a much larger surgery.

Tested: discovered during a topology bench (encoder GPU0 ↔ decoder
GPU1, NVLink-bridged) where `Tensor::to_device` errored with
INVALID_CONTEXT. With this patch + a one-time
`encoder_device.enable_peer_access(&decoder_device)?` call before
the first transfer, the cross-card path succeeds and routes through
the NVLink bridge.

Builds: `cargo build -p candle-core --features cuda` — clean.

Closes #3524: Cross-CUDA `Tensor::to_device` fails with `CUDA_ERROR_INVALID_CONTEXT` (no peer-access enable)

2 participants