@@ -187,88 +187,39 @@ end
187187 @test eltype (f32 (f64 (m))[1 ]. W) == Float32
188188end
189189
190- @testset " Zeros " begin
190+ @testset " Without bias " begin
191191 m = Dense (3 ,2 ; bias= false )
192- @test f64 (m). b === m. b === Zeros ()
193- @test f32 (m). b === m. b === Zeros ()
192+ @test f64 (m). b === m. b === false === Zeros () # Zeros() is deprecated
193+ @test f32 (m). b === m. b === false
194194
195195 @testset " Gradients for broadcasted $op with sizes $s " for op in (+ ,- ,* ), s in ((1 ,), (2 ,3 ))
196196 o = ones (s)
197197 z = zeros (s)
198- Z = Zeros ()
199198
200199 @testset " Explicit" begin
201200 gfun (args... ) = gradient ((x, y) -> sum (op .(x,y)), args... )
202201 g = gfun (o, z)
203- @test gfun (o, Z ) == (g[1 ], nothing )
202+ @test gfun (o, false ) == (g[1 ], nothing )
204203
205204 g = gfun (z, o)
206- @test gfun (Z , o) == (nothing , g[2 ])
205+ @test gfun (false , o) == (nothing , g[2 ])
207206 end
208207
209208 @testset " Implicit" begin
210209 gfun (args... ) = gradient (() -> sum (op .(args... )), params (collect (args)))
211210 g = gfun (o, z)
212211
213- gres = gfun (o, Z )
212+ gres = gfun (o, false )
214213 @test gres[o] == g[o]
215- @test Z ∉ gres. params
214+ @test false ∉ gres. params
215+ @test length (gres. params) == 1
216216
217217 g = gfun (z, o)
218- gres = gfun (Z, o)
219- @test gres[o] == g[o]
220- @test Z ∉ gres. params
221- end
222- end
223-
224- @testset " Gradients for broadcasted / with sizes $s " for s in ((1 ,), (2 ,3 ))
225- o = ones (s)
226- z = zeros (s)
227- Z = Zeros () # Only defined for 0-dim
228-
229- @testset " Explicit" begin
230- gfun (args... ) = gradient ((x, y) -> sum (x ./ y), args... )
231- g = gfun (z, o)
232- @test gfun (Z, o) == (nothing , g[2 ])
233- end
234-
235- @testset " Implicit" begin
236- gfun (x,y) = gradient (() -> sum (x ./ y), params ([x,y]))
237-
238- g = gfun (z, o)
239- gres = gfun (Z, o)
240- @test gres[o] == g[o]
241- @test Z ∉ gres. params
242- end
243- end
244-
245- @testset " Gradients for $op with sizes $s " for op in (+ ,- ), s in (tuple (), (1 ,), (2 ,3 ))
246- o = ones (s)
247- z = zeros (s)
248- Z = Zeros ()
249-
250-
251- @testset " Explicit" begin
252- gfun (args... ) = gradient ((x, y) -> sum (op (x,y)), args... )
253-
254- g = gfun (o, z)
255- @test gfun (o, Z) == (g[1 ], nothing )
256-
257- g = gfun (z, o)
258- @test gfun (Z, o) == (nothing , g[2 ])
259- end
260218
261- @testset " Implicit" begin
262- gfun (args... ) = gradient (() -> sum (op (args... )), params (collect (args)))
263- g = gfun (o, z)
264- gres = gfun (o, Z)
219+ gres = gfun (false , o)
265220 @test gres[o] == g[o]
266- @test Z ∉ gres. params
267-
268- g = gfun (z, o)
269- gres = gfun (Z, o)
270- @test gres[o] == g[o]
271- @test Z ∉ gres. params
221+ @test false ∉ gres. params
222+ @test length (gres. params) == 1
272223 end
273224 end
274225end
@@ -281,52 +232,53 @@ end
281232 @test stack (unstack (stacked_array, 1 ), 1 ) == stacked_array
282233end
283234
235+
284236@testset " Param remapping" begin
285- ls (dims... ) = reshape (collect (Float32, 1 : prod (dims)), dims... ) # accepts dims in reverse order to Dense
286- dl (nin, nout, bias ) = Dense (ls (nout, nin), bias (nout))
287- dm (bias ) = Chain (
288- dl (3 , 5 , bias ),
289- dl (5 , 4 , bias ),
290- dl (4 , 3 , bias )
237+ count32 (dims... ) = reshape (collect (Float32, 1 : prod (dims)), dims... ) # accepts dims in reverse order to Dense
238+ dl (nin, nout, bt ) = Dense (count32 (nout, nin), bt (nout)) # this accepts dims in same order as Dense
239+ densechain (bt ) = Chain (
240+ dl (3 , 5 , bt ),
241+ dl (5 , 4 , bt ),
242+ dl (4 , 3 , bt )
291243 )
244+ nobias (n) = false
292245
293- nobias (n) = Zeros ()
294- testdense (m, bt) = @testset " Check layer $i " for (i, (l1, l2)) in enumerate (zip (m, dm (bt)))
295- @test l1. W == l2. W
296- @test l1. b == l2. b
297- @test typeof (l1. b) === typeof (l2. b)
246+ testdense (m, bt) = @testset " Check layer $i " for (i, (l1, l2)) in enumerate (zip (m, densechain (bt)))
247+ @test l1. weight == l2. weight
248+ @test l1. bias == l2. bias
249+ @test typeof (l1. bias) === typeof (l2. bias)
298250 end
299251
300252 @testset " loadparams!" begin
301- import Flux: loadparams!
302253 pars (w, b) = [w, b]
303254 import Flux: loadparams!, Zeros
304255 pars (w, b:: Zeros ) = [w, Flux. zeros (size (w,1 ))]
305256 pars (l) = pars (l. W, l. b)
306257 pararray (m) = mapreduce (pars, vcat, m)
307258 weights (m) = mapreduce (l -> [l. W], vcat, m)
308- @testset " Bias type $bt " for bt in (Flux . zeros, nobias)
309- m = dm (bt)
259+ @testset " Bias type $bt " for bt in (zeros, nobias)
260+ m = densechain (bt)
310261 loadparams! (m, params (m))
311262 testdense (m, bt)
312263 end
313-
264+ #=
314265 @testset "$b1 to $b2" for (b1, b2, be) in (
315266 (Flux.zeros, ones, ones), # Load ones as bias to a model with zeros as bias -> model gets ones as bias
316267 (ones, nobias, Flux.zeros), # Load Zeros as bias to a model with ones as bias-> model gets zeros as bias
317268 (nobias, ones, nobias), # Load ones as bias to a model with Zeros as bias-> model bias does not change
318269 )
319- m1 = dm (b1)
320- m2 = dm (b2)
270+ m1 = densechain (b1)
271+ m2 = densechain (b2)
321272 loadparams!(m1, b1 == nobias ? weights(m2) : pararray(m2))
322273 testdense(m1, be)
323274 end
275+ =#
324276 end
325277
326278 @testset " destructure" begin
327279 import Flux: destructure
328280 @testset " Bias type $bt " for bt in (zeros, nobias)
329- m = dm (bt)
281+ m = densechain (bt)
330282 p, re = destructure (m)
331283 testdense (re (p), bt)
332284 end
0 commit comments