Skip to content

Commit b0a0745

Browse files
committed
Fix some tests
- Fix doctests - Refactor _zeroed_backing() - Use `mergewith(f)` instead of the deprecated `merge(f)`
1 parent 004e43e commit b0a0745

5 files changed

Lines changed: 15 additions & 19 deletions

File tree

src/rules.jl

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,10 @@ unary input, unary output scalar function:
1818
```jldoctest frule
1919
julia> dself = NoTangent();
2020
21-
julia> x = rand()
22-
0.8236475079774124
21+
julia> x = 1.23456;
2322
2423
julia> sinx, Δsinx = frule((dself, 1), sin, x)
25-
(0.7336293678134624, 0.6795498147167869)
24+
(0.9440031218347901, 0.3299365180851773)
2625
2726
julia> sinx == sin(x)
2827
true
@@ -51,7 +50,7 @@ that return a single output that is iterable, like a `Tuple`.
5150
So this is actually a [`Tangent`](@ref):
5251
```jldoctest frule
5352
julia> Δsincosx
54-
Tangent{Tuple{Float64, Float64}}(0.6795498147167869, -0.7336293678134624)
53+
Tangent{Tuple{Float64, Float64}}(0.3299365180851773, -0.9440031218347901)
5554
```
5655
5756
The optional [`RuleConfig`](@ref) option allows specifying frules only for AD systems that

src/tangent_arithmetic.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ function Base.:+(a::P, d::StructuralTangent{P}) where {P}
142142
return construct(P, net_backing)
143143
end
144144
end
145-
Base.:+(a::Dict, d::Tangent{P}) where {P} = merge(+, a, backing(d))
145+
Base.:+(a::Dict, d::Tangent{P}) where {P} = mergewith(+, a, backing(d))
146146
Base.:+(a::StructuralTangent{P}, b::P) where {P} = b + a
147147

148148
Base.:-(tangent::StructuralTangent{P}) where {P} = map(-, tangent)

src/tangent_types/structural_tangent.jl

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -208,16 +208,9 @@ function backing(x::T)::NamedTuple where {T}
208208
end
209209
end
210210

211-
"""
212-
_zeroed_backing(P)
213-
214-
Returns a NamedTuple with same fields as `P`, and all values `ZeroTangent()`.
215-
"""
216-
@generated function _zeroed_backing(::Type{P}) where {P}
217-
nil_base = ntuple(fieldcount(P)) do i
218-
(fieldname(P, i), ZeroTangent())
219-
end
220-
return (; nil_base...)
211+
# Returns a NamedTuple with same fields as `P`, and all values `ZeroTangent()`
212+
function _zeroed_backing(::Type{T}) where {T}
213+
return NamedTuple{fieldnames(T)}(ntuple(_ -> ZeroTangent(), fieldcount(T)))
221214
end
222215

223216
"""
@@ -299,7 +292,7 @@ function elementwise_add(a::NamedTuple{an}, b::NamedTuple{bn}) where {an,bn}
299292
end
300293
end
301294

302-
elementwise_add(a::Dict, b::Dict) = merge(+, a, b)
295+
elementwise_add(a::Dict, b::Dict) = mergewith(+, a, b)
303296

304297
struct PrimalAdditionFailedException{P} <: Exception
305298
primal::P

src/tangent_types/thunks.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -179,8 +179,7 @@ To evaluate the wrapped closure, call [`unthunk`](@ref) which is a no-op when th
179179
argument is not a `Thunk`.
180180
181181
```jldoctest
182-
julia> t = @thunk(3)
183-
Thunk(var"#4#5"())
182+
julia> t = @thunk(3);
184183
185184
julia> unthunk(t)
186185
3

test/projection.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -457,7 +457,12 @@ struct NoSuperType end
457457
@test 0 == @ballocated $pdiag(dx) setup = (dx = Diagonal(rand(10^3)))
458458

459459
psymm = ProjectTo(Symmetric(rand(10^3, 10^3)))
460-
@test_broken 0 == @ballocated $psymm(dx) setup = (dx = Symmetric(rand(10^3, 10^3))) # 64
460+
allocs = @ballocated $psymm(dx) setup = (dx = Symmetric(rand(10^3, 10^3))) # 64
461+
if VERSION > v"1.13"
462+
@test 0 == allocs
463+
else
464+
@test_broken 0 == allocs
465+
end
461466
end
462467

463468
@testset "#685" begin

0 commit comments

Comments
 (0)