Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 81 additions & 0 deletions candle-core/src/conv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,19 @@ impl Tensor {
};
if groups == 1 {
self.conv1d_single_group(kernel, &params)
} else if c_in_k == 1 && c_out == c_in && groups == c_in && stride == 1 && dilation == 1 {
// Depthwise conv1d (groups == in_channels, unit stride/dilation, no channel
// multiplier). The per-group decomposition below launches O(groups) tiny kernels
// (plus a `cat`), dominated by launch overhead on CUDA — see issue #3389, where
// depthwise conv layers were 54% of total inference time on an RTX 5090 because of
// ~6000 kernel launches for `groups=2048, k=3`. Depthwise conv is just a per-channel
// weighted sum over the kernel window, expressible as `k_size` slices each scaled by a
// per-channel scalar and accumulated: a fixed number of elementwise kernels (~2*k_size)
// regardless of the channel count, numerically equivalent to the existing kernels (up
// to floating-point summation order, just like the CPU/CUDA backends already differ).
// The stride/dilation==1 guard is only because candle has no strided-narrow primitive
// yet; other strides keep the old per-group path.
self.conv1d_depthwise(kernel, &params)
} else {
let blocks = self.chunk(groups, 1)?;
let kernel = kernel.chunk(groups, 0)?;
Expand All @@ -204,6 +217,37 @@ impl Tensor {
}
}

// Depthwise 1D convolution (stride == dilation == 1), expressed purely with
// elementwise kernels instead of a per-group loop.
// `self`: (b_size, c, l_in); `kernel`: (c, 1, k_size); out: (b_size, c, l_out).
fn conv1d_depthwise(&self, kernel: &Self, params: &ParamsConv1D) -> Result<Self> {
    let (b_size, c, _l_in) = self.dims3()?;
    let l_out = params.l_out();
    // Zero-pad the length dimension up front so every kernel tap below is a
    // plain `narrow`: (b_size, c, l_in + 2*padding).
    let padded = match params.padding {
        0 => self.clone(),
        p => self.pad_with_zeros(2, p, p)?,
    };
    // (c, 1, k_size) -> (1, c, k_size) so the weights broadcast over the batch axis.
    let weights = kernel.reshape((c, params.k_size))?.unsqueeze(0)?;
    let mut acc: Option<Tensor> = None;
    for tap in 0..params.k_size {
        // acc[b, ch, t] += padded[b, ch, t + tap] * weights[ch, tap]
        let window = padded.narrow(2, tap, l_out)?; // (b_size, c, l_out)
        let scale = weights.narrow(2, tap, 1)?; // (1, c, 1)
        let contrib = window.broadcast_mul(&scale)?;
        acc = Some(match acc {
            Some(prev) => (prev + contrib)?,
            None => contrib,
        });
    }
    // k_size == 0 is rejected upstream by the kernel-shape checks; zeros is a
    // purely defensive fallback.
    acc.map_or_else(
        || Tensor::zeros((b_size, c, l_out), self.dtype(), self.device()),
        Ok,
    )
}

fn conv_transpose1d_single_group(
&self,
kernel: &Self,
Expand Down Expand Up @@ -328,6 +372,11 @@ impl Tensor {
};
if groups == 1 {
self.conv2d_single_group(kernel, &params)
} else if c_in_k == 1 && c_out == c_in && groups == c_in && stride == 1 && dilation == 1 {
// Depthwise conv2d — see `conv1d_with_algo` for the rationale (issue #3389).
// `k_h * k_w` slices, each scaled by a per-channel scalar and accumulated: a fixed
// number of elementwise kernels independent of the channel count.
self.conv2d_depthwise(kernel, &params)
} else {
let blocks = self.chunk(groups, 1)?;
let kernel = kernel.chunk(groups, 0)?;
Expand All @@ -340,6 +389,38 @@ impl Tensor {
}
}

// Depthwise 2D convolution (stride == dilation == 1), expressed purely with
// elementwise kernels instead of a per-group loop.
// `self`: (b, c, h, w); `kernel`: (c, 1, k_h, k_w); out: (b, c, out_h, out_w).
fn conv2d_depthwise(&self, kernel: &Self, params: &ParamsConv2D) -> Result<Self> {
    let (b_size, c, _i_h, _i_w) = self.dims4()?;
    let (out_h, out_w) = (params.out_h(), params.out_w());
    // Zero-pad both spatial dimensions up front so every tap below is a plain
    // pair of `narrow`s.
    let padded = match params.padding {
        0 => self.clone(),
        p => self.pad_with_zeros(2, p, p)?.pad_with_zeros(3, p, p)?,
    };
    // (c, 1, k_h, k_w) -> (1, c, k_h, k_w) so the weights broadcast over the batch axis.
    let weights = kernel.reshape((c, params.k_h, params.k_w))?.unsqueeze(0)?;
    let mut acc: Option<Tensor> = None;
    for kh in 0..params.k_h {
        for kw in 0..params.k_w {
            // acc[b, ch, i, j] += padded[b, ch, i + kh, j + kw] * weights[ch, kh, kw]
            let window = padded.narrow(2, kh, out_h)?.narrow(3, kw, out_w)?; // (b, c, out_h, out_w)
            let scale = weights.narrow(2, kh, 1)?.narrow(3, kw, 1)?; // (1, c, 1, 1)
            let contrib = window.broadcast_mul(&scale)?;
            acc = Some(match acc {
                Some(prev) => (prev + contrib)?,
                None => contrib,
            });
        }
    }
    // An empty kernel (k_h * k_w == 0) is rejected upstream by the shape
    // checks; zeros is a purely defensive fallback.
    acc.map_or_else(
        || Tensor::zeros((b_size, c, out_h, out_w), self.dtype(), self.device()),
        Ok,
    )
}

/// Applies a 2D transposed convolution over the input tensor.
pub fn conv_transpose2d(
&self,
Expand Down
73 changes: 73 additions & 0 deletions candle-core/tests/conv_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -970,3 +970,76 @@ test_device!(
conv2d_c_eq_h_eq_w_gpu,
conv2d_c_eq_h_eq_w_metal
);

// Depthwise convolutions (groups == in_channels). The fast depthwise path
// (groups == c_in, c_in_k == 1, stride == dilation == 1) is exercised here and
// cross-checked against an *independent* code path: the depthwise weight is
// expanded into the equivalent block-diagonal dense weight (c_out, c_in, k...)
// and run through a plain `groups == 1` convolution (the im2col / cuDNN path).
// The two must agree up to floating-point summation order.
fn conv1d_depthwise(dev: &Device) -> Result<()> {
    let (b, c, l, k) = (2usize, 6usize, 13usize, 3usize);
    let input = Tensor::randn(0f32, 1f32, (b, c, l), dev)?;
    // Depthwise weight: (c_out=c, c_in_k=1, k).
    let w = Tensor::randn(0f32, 1f32, (c, 1, k), dev)?;
    // Equivalent block-diagonal dense weight (c_out=c, c_in=c, k), zero off the
    // diagonal: w (c,1,k) -> (c,k,1); eye (c,c) -> (c,1,c); product -> (c,k,c);
    // transpose -> (c,c,k).
    let eye = Tensor::eye(c, input.dtype(), dev)?.unsqueeze(1)?; // (c, 1, c)
    let wk = w.reshape((c, k))?.unsqueeze(2)?; // (c, k, 1)
    let dense = wk.broadcast_mul(&eye)?.transpose(1, 2)?.contiguous()?; // (c, c, k)
    for &padding in &[0usize, 1usize] {
        let l_out = l + 2 * padding - (k - 1);
        let fast = input.conv1d(&w, padding, 1, 1, c)?;
        assert_eq!(fast.dims(), &[b, c, l_out]);
        let reference = input.conv1d(&dense, padding, 1, 1, 1)?;
        // Max absolute deviation between the two paths.
        let diff: f32 = (fast - reference)?
            .abs()?
            .flatten_all()?
            .max(0)?
            .to_scalar()?;
        assert!(diff < 1e-4, "depthwise conv1d mismatch: {diff}");
    }
    Ok(())
}

// Same cross-check for 2D: the fast depthwise path vs. an equivalent dense
// `groups == 1` convolution built from a block-diagonal weight.
fn conv2d_depthwise(dev: &Device) -> Result<()> {
    let (b, c, h, w_, k) = (2usize, 5usize, 5usize, 7usize, 3usize);
    let t = Tensor::randn(0f32, 1f32, (b, c, h, w_), dev)?;
    // Depthwise weight: (c_out=c, c_in_k=1, k, k).
    let weight = Tensor::randn(0f32, 1f32, (c, 1, k, k), dev)?;
    // Equivalent block-diagonal dense weight (c_out=c, c_in=c, k, k), zero off
    // the diagonal: `weight` is already (c, 1, k, k), so it broadcasts against
    // the identity reshaped to (c, c, 1, 1) directly — no reshape needed.
    let dense = {
        let eye = Tensor::eye(c, t.dtype(), dev)?.reshape((c, c, 1, 1))?;
        weight.broadcast_mul(&eye)?.contiguous()? // (c, c, k, k)
    };
    for &padding in &[0usize, 1usize] {
        let oh = h + 2 * padding - (k - 1);
        let ow = w_ + 2 * padding - (k - 1);
        let fast = t.conv2d(&weight, padding, 1, 1, c)?;
        assert_eq!(fast.dims(), &[b, c, oh, ow]);
        let reference = t.conv2d(&dense, padding, 1, 1, 1)?;
        // Max absolute deviation between the two paths.
        let diff: f32 = (fast - reference)?
            .abs()?
            .flatten_all()?
            .max(0)?
            .to_scalar()?;
        assert!(diff < 1e-4, "depthwise conv2d mismatch: {diff}");
    }
    Ok(())
}

// Register the depthwise cross-checks on every backend the `test_device!`
// macro targets, so the fast path is compared against the dense reference on
// each of them.
test_device!(
    conv1d_depthwise,
    conv1d_depthwise_cpu,
    conv1d_depthwise_gpu,
    conv1d_depthwise_metal
);
test_device!(
    conv2d_depthwise,
    conv2d_depthwise_cpu,
    conv2d_depthwise_gpu,
    conv2d_depthwise_metal
);