@@ -6,6 +6,10 @@ _paddims(x::Tuple, y::Tuple) = (x..., y[(end - (length(y) - length(x) - 1)):end]
# Normalise a `stride`/`pad`/`dilation` specification to an `n`-tuple:
# a tuple passes through unchanged, a scalar is repeated `n` times
# (`n` is typically passed as `Val(N-2)`, the number of spatial dims).
expand(n, spec::Tuple) = spec
expand(n, spec::Integer) = ntuple(_ -> spec, n)
88
# Reshape a vector bias so it broadcasts over a convolution's output:
# one singleton axis per spatial dimension (as many as `c.stride` has
# entries), then the channel axis, then a singleton batch axis.
# A non-vector bias (e.g. `false`, meaning "no bias") passes through as-is.
function conv_reshape_bias(c)
    b = c.bias
    b isa AbstractVector || return b
    return reshape(b, ntuple(_ -> 1, length(c.stride))..., :, 1)
end
12+
913"""
1014 SamePad()
1115
96100
97101"""
    Conv(weight::AbstractArray, bias, [activation; stride, pad, dilation])

Constructs a convolutional layer with the given weight and bias.
Accepts the same keywords (and has the same defaults) as the `Conv((4,4), 3=>7, relu)`
method.
@@ -117,7 +121,7 @@ julia> params(c1) |> length
1171212
118122```
119123"""
120- function Conv (w:: AbstractArray{T,N} , b:: Union{Bool, Zeros, AbstractVector{T}} , σ = identity;
124+ function Conv (w:: AbstractArray{T,N} , b:: Union{Bool,AbstractVector{T}} , σ = identity;
121125 stride = 1 , pad = 0 , dilation = 1 ) where {T,N}
122126 stride = expand (Val (N- 2 ), stride)
123127 dilation = expand (Val (N- 2 ), dilation)
@@ -152,9 +156,8 @@ convfilter(filter::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer};
# Forward pass of a `Conv` layer: convolve `x` with the layer's weights,
# add the broadcast-shaped bias, and map the activation over the result.
function (c::Conv)(x::AbstractArray)
    # TODO: breaks gpu broadcast :(
    # ndims(x) == ndims(c.weight)-1 && return squeezebatch(c(reshape(x, size(x)..., 1)))
    dims = DenseConvDims(x, c.weight; stride = c.stride, padding = c.pad, dilation = c.dilation)
    return (c.σ).(conv(x, c.weight, dims) .+ conv_reshape_bias(c))
end
159162
160163function Base. show (io:: IO , l:: Conv )
@@ -207,16 +210,16 @@ end
207210
"""
    ConvTranspose(weight::AbstractArray, bias, [activation; stride, pad, dilation])

Constructs a layer with the given weight and bias arrays.
Accepts the same keywords as the `ConvTranspose((4,4), 3=>7, relu)` method.
"""
function ConvTranspose(w::AbstractArray{T,N}, b::Union{Bool, AbstractVector{T}}, σ = identity;
                       stride = 1, pad = 0, dilation = 1) where {T,N}
    # The first N-2 axes of the weight array are the spatial dimensions.
    stride = expand(Val(N - 2), stride)
    dilation = expand(Val(N - 2), dilation)
    pad = calc_padding(ConvTranspose, pad, size(w)[1:N-2], dilation, stride)
    # `b` may be `true`/`false` (create/omit a trainable bias) or an
    # explicit vector; size(w, N-1) is the bias length expected here.
    bias = create_bias(b, zeros, size(w, N - 1))
    return ConvTranspose(σ, w, bias, stride, pad, dilation)
end
222225
248251
# Forward pass of a `ConvTranspose` layer, implemented via `∇conv_data`
# (the gradient of `conv` w.r.t. its data argument), plus bias and activation.
function (c::ConvTranspose)(x::AbstractArray)
    # ndims(x) == ndims(c.weight)-1 && return squeezebatch(c(reshape(x, size(x)..., 1)))
    dims = conv_transpose_dims(c, x)
    return (c.σ).(∇conv_data(x, c.weight, dims) .+ conv_reshape_bias(c))
end
255257
256258function Base. show (io:: IO , l:: ConvTranspose )
@@ -304,11 +306,11 @@ end
304306
305307"""
    DepthwiseConv(weight::AbstractArray, bias, [activation; stride, pad, dilation])

Constructs a layer with the given weight and bias arrays.
Accepts the same keywords as the `DepthwiseConv((4,4), 3=>6, relu)` method.
310312"""
311- function DepthwiseConv (w:: AbstractArray{T,N} , b:: Union{Bool, Zeros, AbstractVector{T}} , σ = identity;
313+ function DepthwiseConv (w:: AbstractArray{T,N} , b:: Union{Bool,AbstractVector{T}} , σ = identity;
312314 stride = 1 , pad = 0 , dilation = 1 ) where {T,N}
313315 stride = expand (Val (N- 2 ), stride)
314316 dilation = expand (Val (N- 2 ), dilation)
@@ -341,9 +343,8 @@ depthwiseconvfilter(filter::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer};
341343 init = glorot_uniform) where N = init (filter... , div (ch[2 ], ch[1 ]), ch[1 ])
342344
# Forward pass of a `DepthwiseConv` layer: depthwise-convolve `x`,
# add the broadcast-shaped bias, and apply the activation elementwise.
function (c::DepthwiseConv)(x)
    dims = DepthwiseConvDims(x, c.weight; stride = c.stride, padding = c.pad, dilation = c.dilation)
    return (c.σ).(depthwiseconv(x, c.weight, dims) .+ conv_reshape_bias(c))
end
348349
349350function Base. show (io:: IO , l:: DepthwiseConv )
@@ -392,11 +393,11 @@ end
392393
393394"""
    CrossCor(weight::AbstractArray, bias, [activation; stride, pad, dilation])

Constructs a layer with the given weight and bias arrays.
Accepts the same keywords as the `CrossCor((4,4), 3=>7, relu)` method.
398399"""
399- function CrossCor (w:: AbstractArray{T,N} , b:: Union{Bool, Zeros, AbstractVector{T}} , σ = identity;
400+ function CrossCor (w:: AbstractArray{T,N} , b:: Union{Bool,AbstractVector{T}} = true , σ = identity;
400401 stride = 1 , pad = 0 , dilation = 1 ) where {T,N}
401402 stride = expand (Val (N- 2 ), stride)
402403 dilation = expand (Val (N- 2 ), dilation)
# Forward pass of a `CrossCor` layer: cross-correlate `x` with the weights,
# add the broadcast-shaped bias, and apply the activation elementwise.
function (c::CrossCor)(x::AbstractArray)
    # TODO: breaks gpu broadcast :(
    # ndims(x) == ndims(c.weight)-1 && return squeezebatch(c(reshape(x, size(x)..., 1)))
    dims = DenseConvDims(x, c.weight; stride = c.stride, padding = c.pad, dilation = c.dilation)
    return (c.σ).(crosscor(x, c.weight, dims) .+ conv_reshape_bias(c))
end
429429
430430function Base. show (io:: IO , l:: CrossCor )
0 commit comments