Skip to content

Commit ee013f2

Browse files
committed
Fix some tests
- Fix doctests - Move _zeroed_backing() to a different file so it's defined before being used (otherwise causes failures on 1.13+) - Use `mergewith(f)` instead of the deprecated `merge(f)`
1 parent 004e43e commit ee013f2

6 files changed

Lines changed: 25 additions & 21 deletions

File tree

src/rules.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,11 @@ unary input, unary output scalar function:
1818
```jldoctest frule
1919
julia> dself = NoTangent();
2020
21-
julia> x = rand()
22-
0.8236475079774124
21+
julia> x = 1.23456
22+
1.23456
2323
2424
julia> sinx, Δsinx = frule((dself, 1), sin, x)
25-
(0.7336293678134624, 0.6795498147167869)
25+
(0.9440031218347901, 0.3299365180851773)
2626
2727
julia> sinx == sin(x)
2828
true
@@ -51,7 +51,7 @@ that return a single output that is iterable, like a `Tuple`.
5151
So this is actually a [`Tangent`](@ref):
5252
```jldoctest frule
5353
julia> Δsincosx
54-
Tangent{Tuple{Float64, Float64}}(0.6795498147167869, -0.7336293678134624)
54+
Tangent{Tuple{Float64, Float64}}(0.3299365180851773, -0.9440031218347901)
5555
```
5656
5757
The optional [`RuleConfig`](@ref) option allows specifying frules only for AD systems that

src/tangent_arithmetic.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ function Base.:+(a::P, d::StructuralTangent{P}) where {P}
142142
return construct(P, net_backing)
143143
end
144144
end
145-
Base.:+(a::Dict, d::Tangent{P}) where {P} = merge(+, a, backing(d))
145+
Base.:+(a::Dict, d::Tangent{P}) where {P} = mergewith(+, a, backing(d))
146146
Base.:+(a::StructuralTangent{P}, b::P) where {P} = b + a
147147

148148
Base.:-(tangent::StructuralTangent{P}) where {P} = map(-, tangent)

src/tangent_types/abstract_zero.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -201,3 +201,15 @@ zero_tangent(::Core.Compiler.AbstractInterpreter) = NoTangent()
201201
zero_tangent(::Core.Compiler.InstructionStream) = NoTangent()
202202
zero_tangent(::Core.CodeInfo) = NoTangent()
203203
zero_tangent(::Core.MethodInstance) = NoTangent()
204+
205+
"""
206+
_zeroed_backing(P)
207+
208+
Returns a NamedTuple with same fields as `P`, and all values `ZeroTangent()`.
209+
"""
210+
@generated function _zeroed_backing(::Type{P}) where {P}
211+
nil_base = ntuple(fieldcount(P)) do i
212+
(fieldname(P, i), ZeroTangent())
213+
end
214+
return (; nil_base...)
215+
end

src/tangent_types/structural_tangent.jl

Lines changed: 1 addition & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -208,18 +208,6 @@ function backing(x::T)::NamedTuple where {T}
208208
end
209209
end
210210

211-
"""
212-
_zeroed_backing(P)
213-
214-
Returns a NamedTuple with same fields as `P`, and all values `ZeroTangent()`.
215-
"""
216-
@generated function _zeroed_backing(::Type{P}) where {P}
217-
nil_base = ntuple(fieldcount(P)) do i
218-
(fieldname(P, i), ZeroTangent())
219-
end
220-
return (; nil_base...)
221-
end
222-
223211
"""
224212
construct(::Type{T}, fields::[NamedTuple|Tuple])
225213
@@ -299,7 +287,7 @@ function elementwise_add(a::NamedTuple{an}, b::NamedTuple{bn}) where {an,bn}
299287
end
300288
end
301289

302-
elementwise_add(a::Dict, b::Dict) = merge(+, a, b)
290+
elementwise_add(a::Dict, b::Dict) = mergewith(+, a, b)
303291

304292
struct PrimalAdditionFailedException{P} <: Exception
305293
primal::P

src/tangent_types/thunks.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -179,8 +179,7 @@ To evaluate the wrapped closure, call [`unthunk`](@ref) which is a no-op when th
179179
argument is not a `Thunk`.
180180
181181
```jldoctest
182-
julia> t = @thunk(3)
183-
Thunk(var"#4#5"())
182+
julia> t = @thunk(3);
184183
185184
julia> unthunk(t)
186185
3

test/projection.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -457,7 +457,12 @@ struct NoSuperType end
457457
@test 0 == @ballocated $pdiag(dx) setup = (dx = Diagonal(rand(10^3)))
458458

459459
psymm = ProjectTo(Symmetric(rand(10^3, 10^3)))
460-
@test_broken 0 == @ballocated $psymm(dx) setup = (dx = Symmetric(rand(10^3, 10^3))) # 64
460+
allocs = @ballocated $psymm(dx) setup = (dx = Symmetric(rand(10^3, 10^3))) # 64
461+
if VERSION > v"1.13"
462+
@test 0 == allocs
463+
else
464+
@test_broken 0 == allocs
465+
end
461466
end
462467

463468
@testset "#685" begin

0 commit comments

Comments
 (0)