diff --git a/src/make_zero.jl b/src/make_zero.jl index 4f627581ea..b0566210cd 100644 --- a/src/make_zero.jl +++ b/src/make_zero.jl @@ -198,7 +198,7 @@ function make_zero_immutable!(prev::T, seen::S)::T where {T<:Tuple,S} end function make_zero_immutable!(prev::NamedTuple{a,b}, seen::S)::NamedTuple{a,b} where {a,b,S} - NamedTuple{a,b}(ntuple(Val(length(T.parameters))) do i + NamedTuple{a,b}(ntuple(Val(length(a))) do i Base.@_inline_meta make_zero_immutable!(prev[a[i]], seen) end) @@ -384,6 +384,13 @@ end push!(seen, prev) + # For make_zero!(NamedTuple) we want to recurse and zero out + # the storage + if !Base.ismutabletype(T) + make_zero_immutable!(prev, seen) + return nothing + end + for i = 1:nf if isdefined(prev, i) xi = getfield(prev, i) diff --git a/test/abi.jl b/test/abi.jl index 7a7917553f..ee81186874 100644 --- a/test/abi.jl +++ b/test/abi.jl @@ -1,5 +1,6 @@ using Enzyme using Test +using Random @testset "ABI & Calling convention" begin @@ -493,6 +494,21 @@ end @test dv.y ≈ 0.0 end +@testset "Make Zero!" begin + params = (; q = rand(64), p = rand(64)) + dparams = make_zero(params) + @test all(==(0), dparams.q) + @test all(==(0), dparams.p) + + rand!(dparams.q) + rand!(dparams.p) + + make_zero!(dparams) + @test all(==(0), dparams.q) + @test all(==(0), dparams.p) +end + + @testset "Type inference" begin x = ones(10) @inferred autodiff(Enzyme.Reverse, abssum, Duplicated(x,x))