diff --git a/candle-core/src/conv.rs b/candle-core/src/conv.rs index 115035ef1c..e4f1996229 100644 --- a/candle-core/src/conv.rs +++ b/candle-core/src/conv.rs @@ -192,6 +192,19 @@ impl Tensor { }; if groups == 1 { self.conv1d_single_group(kernel, &params) + } else if c_in_k == 1 && c_out == c_in && groups == c_in && stride == 1 && dilation == 1 { + // Depthwise conv1d (groups == in_channels, unit stride/dilation, no channel + // multiplier). The per-group decomposition below launches O(groups) tiny kernels + // (plus a `cat`), dominated by launch overhead on CUDA — see issue #3389, where + // depthwise conv layers were 54% of total inference time on an RTX 5090 because of + // ~6000 kernel launches for `groups=2048, k=3`. Depthwise conv is just a per-channel + // weighted sum over the kernel window, expressible as `k_size` slices each scaled by a + // per-channel scalar and accumulated: a fixed number of elementwise kernels (~2*k_size) + // regardless of the channel count, numerically equivalent to the existing kernels (up + // to floating-point summation order, just like the CPU/CUDA backends already differ). + // The stride/dilation==1 guard is only because candle has no strided-narrow primitive + // yet; other strides keep the old per-group path. + self.conv1d_depthwise(kernel, &params) } else { let blocks = self.chunk(groups, 1)?; let kernel = kernel.chunk(groups, 0)?; @@ -204,6 +217,37 @@ impl Tensor { } } + // Depthwise 1D convolution (stride == dilation == 1) with elementwise ops only, no per-group + // loop. `self`: (b_size, c, l_in); `kernel`: (c, 1, k_size); out: (b_size, c, l_out). + fn conv1d_depthwise(&self, kernel: &Self, params: &ParamsConv1D) -> Result<Self> { + let (b_size, c, _l_in) = self.dims3()?; + let l_out = params.l_out(); + // (b_size, c, l_in + 2*padding) + let padded = if params.padding == 0 { + self.clone() + } else { + self.pad_with_zeros(2, params.padding, params.padding)?
+ }; + // (c, 1, k_size) -> (1, c, k_size) so it broadcasts over the batch dimension. + let kernel = kernel.reshape((c, params.k_size))?.unsqueeze(0)?; + let mut out: Option<Self> = None; + for k in 0..params.k_size { + // out[b, c, t] += padded[b, c, t + k] * kernel[c, k] (stride == dilation == 1) + let slice = padded.narrow(2, k, l_out)?; // (b_size, c, l_out) + let w_k = kernel.narrow(2, k, 1)?; // (1, c, 1) + let term = slice.broadcast_mul(&w_k)?; + out = Some(match out { + None => term, + Some(acc) => (acc + term)?, + }); + } + match out { + Some(out) => Ok(out), + // k_size == 0 is rejected upstream by the kernel-shape checks; be defensive anyway. + None => Tensor::zeros((b_size, c, l_out), self.dtype(), self.device()), + } + } + fn conv_transpose1d_single_group( &self, kernel: &Self, @@ -328,6 +372,11 @@ impl Tensor { }; if groups == 1 { self.conv2d_single_group(kernel, &params) + } else if c_in_k == 1 && c_out == c_in && groups == c_in && stride == 1 && dilation == 1 { + // Depthwise conv2d — see `conv1d_with_algo` for the rationale (issue #3389). + // `k_h * k_w` slices, each scaled by a per-channel scalar and accumulated: a fixed + // number of elementwise kernels independent of the channel count. + self.conv2d_depthwise(kernel, &params) } else { let blocks = self.chunk(groups, 1)?; let kernel = kernel.chunk(groups, 0)?; @@ -340,6 +389,38 @@ impl Tensor { } } + // Depthwise 2D convolution (stride == dilation == 1) implemented with elementwise ops only. + // `self`: (b, c, h, w); `kernel`: (c, 1, k_h, k_w); out: (b, c, out_h, out_w). + fn conv2d_depthwise(&self, kernel: &Self, params: &ParamsConv2D) -> Result<Self> { + let (b_size, c, _i_h, _i_w) = self.dims4()?; + let (out_h, out_w) = (params.out_h(), params.out_w()); + let padded = if params.padding == 0 { + self.clone() + } else { + self.pad_with_zeros(2, params.padding, params.padding)? + .pad_with_zeros(3, params.padding, params.padding)?
+ }; + // (c, 1, k_h, k_w) -> (1, c, k_h, k_w) to broadcast over the batch dimension. + let kernel = kernel.reshape((c, params.k_h, params.k_w))?.unsqueeze(0)?; + let mut out: Option<Self> = None; + for kh in 0..params.k_h { + for kw in 0..params.k_w { + // out[b, c, i, j] += padded[b, c, i + kh, j + kw] * kernel[c, kh, kw] + let slice = padded.narrow(2, kh, out_h)?.narrow(3, kw, out_w)?; // (b, c, out_h, out_w) + let w = kernel.narrow(2, kh, 1)?.narrow(3, kw, 1)?; // (1, c, 1, 1) + let term = slice.broadcast_mul(&w)?; + out = Some(match out { + None => term, + Some(acc) => (acc + term)?, + }); + } + } + match out { + Some(out) => Ok(out), + None => Tensor::zeros((b_size, c, out_h, out_w), self.dtype(), self.device()), + } + } + /// Applies a 2D transposed convolution over the input tensor. pub fn conv_transpose2d( &self, diff --git a/candle-core/tests/conv_tests.rs b/candle-core/tests/conv_tests.rs index abae89aa30..996b251e47 100644 --- a/candle-core/tests/conv_tests.rs +++ b/candle-core/tests/conv_tests.rs @@ -970,3 +970,76 @@ test_device!( conv2d_c_eq_h_eq_w_gpu, conv2d_c_eq_h_eq_w_metal ); + +// Depthwise convolutions (groups == in_channels). The fast depthwise path +// (groups == c_in, c_in_k == 1, stride == dilation == 1) is exercised here and +// cross-checked against an *independent* code path: building the equivalent +// block-diagonal dense weight (c_out, c_in, k...) and running a plain +// `groups == 1` convolution (the im2col / cuDNN path). They must agree up to +// floating-point summation order. +fn conv1d_depthwise(dev: &Device) -> Result<()> { + let (b, c, l, k) = (2usize, 6usize, 13usize, 3usize); + let t = Tensor::randn(0f32, 1f32, (b, c, l), dev)?; + // Depthwise weight: (c_out=c, c_in_k=1, k). + let w = Tensor::randn(0f32, 1f32, (c, 1, k), dev)?; + // Equivalent block-diagonal dense weight (c_out=c, c_in=c, k), zero off the diagonal: + // w (c,1,k) -> (c,k,1); eye (c,c) -> (c,1,c); product -> (c,k,c) -> transpose -> (c,c,k).
+ let dense = { + let eye = Tensor::eye(c, t.dtype(), dev)?.unsqueeze(1)?; // (c, 1, c) + let wk = w.reshape((c, k))?.unsqueeze(2)?; // (c, k, 1) + wk.broadcast_mul(&eye)?.transpose(1, 2)?.contiguous()? // (c, c, k) + }; + for &padding in &[0usize, 1usize] { + let l_out = l + 2 * padding - (k - 1); + let fast = t.conv1d(&w, padding, 1, 1, c)?; + assert_eq!(fast.dims(), &[b, c, l_out]); + let reference = t.conv1d(&dense, padding, 1, 1, 1)?; + let diff: f32 = (fast - reference)? + .abs()? + .flatten_all()? + .max(0)? + .to_scalar()?; + assert!(diff < 1e-4, "depthwise conv1d mismatch: {diff}"); + } + Ok(()) +} + +fn conv2d_depthwise(dev: &Device) -> Result<()> { + let (b, c, h, w_, k) = (2usize, 5usize, 5usize, 7usize, 3usize); + let t = Tensor::randn(0f32, 1f32, (b, c, h, w_), dev)?; + // Depthwise weight: (c_out=c, c_in_k=1, k, k). + let weight = Tensor::randn(0f32, 1f32, (c, 1, k, k), dev)?; + // Equivalent block-diagonal dense weight (c_out=c, c_in=c, k, k), zero off the diagonal. + let dense = { + let eye = Tensor::eye(c, t.dtype(), dev)?.reshape((c, c, 1, 1))?; + let wk = weight.reshape((c, 1, k, k))?; + wk.broadcast_mul(&eye)?.contiguous()? // (c, c, k, k) + }; + for &padding in &[0usize, 1usize] { + let oh = h + 2 * padding - (k - 1); + let ow = w_ + 2 * padding - (k - 1); + let fast = t.conv2d(&weight, padding, 1, 1, c)?; + assert_eq!(fast.dims(), &[b, c, oh, ow]); + let reference = t.conv2d(&dense, padding, 1, 1, 1)?; + let diff: f32 = (fast - reference)? + .abs()? + .flatten_all()? + .max(0)? + .to_scalar()?; + assert!(diff < 1e-4, "depthwise conv2d mismatch: {diff}"); + } + Ok(()) +} + +test_device!( + conv1d_depthwise, + conv1d_depthwise_cpu, + conv1d_depthwise_gpu, + conv1d_depthwise_metal +); +test_device!( + conv2d_depthwise, + conv2d_depthwise_cpu, + conv2d_depthwise_gpu, + conv2d_depthwise_metal +);