Skip to content

Commit 4253bff

Browse files
authored
Fix some tests and bump version (#706)
* Fix some tests - Fix doctests - Refactor _zeroed_backing() - Use `mergewith(f)` instead of the deprecated `merge(f)` * Fix CI * Bump version
1 parent dad0d24 commit 4253bff

9 files changed

Lines changed: 41 additions & 51 deletions

File tree

.github/workflows/CI.yml

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -29,28 +29,21 @@ jobs:
2929
- "pre" # Pre-release/nightly
3030
os:
3131
- ubuntu-latest
32-
- macOS-13 # Intel
32+
- macOS-latest
3333
- windows-latest
3434
arch:
35-
- x64
36-
- x86
37-
exclude:
38-
# Test 32-bit only on Linux
39-
- os: macOS-13
40-
arch: x86
41-
- os: windows-latest
42-
arch: x86
35+
- default
4336
include:
44-
- os: macOS-latest # Apple Silicon
37+
- os: ubuntu-latest
4538
version: "1"
46-
arch: aarch64
39+
arch: x86
4740
steps:
4841
- uses: actions/checkout@v6
4942
- uses: julia-actions/setup-julia@v2
5043
with:
5144
version: ${{ matrix.version }}
5245
arch: ${{ matrix.arch }}
53-
- uses: julia-actions/cache@v2
46+
- uses: julia-actions/cache@v3
5447
- uses: julia-actions/julia-buildpkg@v1
5548
- uses: julia-actions/julia-runtest@v1
5649
- uses: julia-actions/julia-processcoverage@v1

.github/workflows/IntegrationTest.yml

Lines changed: 19 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -13,53 +13,54 @@ concurrency:
1313

1414
jobs:
1515
test:
16-
name: ${{ matrix.package.repo }}
16+
name: ${{ matrix.package.name }}
1717
runs-on: ${{ matrix.os }}
1818
strategy:
1919
fail-fast: false
2020
matrix:
2121
julia-version: [1]
2222
os: [ubuntu-latest]
2323
package:
24-
- {user: JuliaDiff, repo: ChainRules.jl}
25-
- {user: JuliaDiff, repo: ChainRulesTestUtils.jl}
26-
- {user: JuliaDiff, repo: ChainRulesOverloadGeneration.jl}
27-
- {user: JuliaMath, repo: SpecialFunctions.jl}
28-
- {user: invenia, repo: BlockDiagonals.jl}
29-
- {user: invenia, repo: PDMatsExtras.jl}
30-
- {user: chrisbrahms, repo: Hankel.jl}
31-
- {user: SciML, repo: DiffEqBase.jl}
32-
- {user: SciML, repo: DataInterpolations.jl}
33-
- {user: dfdx, repo: Yota.jl}
34-
- {user: JuliaStats, repo: StatsFuns.jl}
35-
- {user: JuliaStats, repo: LogExpFunctions.jl}
24+
- {user: JuliaDiff, name: ChainRules}
25+
- {user: JuliaDiff, name: ChainRulesTestUtils}
26+
- {user: JuliaDiff, name: ChainRulesOverloadGeneration}
27+
- {user: JuliaMath, name: SpecialFunctions}
28+
- {user: invenia, name: BlockDiagonals}
29+
- {user: invenia, name: PDMatsExtras}
30+
- {user: chrisbrahms, name: Hankel}
31+
- {user: SciML, name: DiffEqBase}
32+
- {user: SciML, name: DataInterpolations}
33+
- {user: dfdx, name: Yota}
34+
- {user: JuliaStats, name: StatsFuns}
35+
- {user: JuliaStats, name: LogExpFunctions}
3636
# Diffractor needs to run on Julia nightly
3737
include:
3838
- julia-version: nightly
3939
os: ubuntu-latest
40-
package: {user: JuliaDiff, repo: Diffractor.jl}
40+
package: {user: JuliaDiff, name: Diffractor}
4141

4242
steps:
4343
- uses: actions/checkout@v6
4444
- uses: julia-actions/setup-julia@v2
4545
with:
4646
version: ${{ matrix.julia-version }}
4747
arch: x64
48+
- uses: julia-actions/cache@v3
4849
- uses: julia-actions/julia-buildpkg@v1
4950
- name: Clone Downstream
5051
uses: actions/checkout@v6
5152
with:
52-
repository: ${{ matrix.package.user }}/${{ matrix.package.repo }}
53+
repository: ${{ matrix.package.user }}/${{ matrix.package.name }}.jl
5354
path: downstream
5455
- name: Load this and run the downstream tests
55-
shell: julia --project=downstream {0}
56+
shell: julia --project=@temp {0}
5657
run: |
5758
using Pkg
5859
try
5960
# force it to use this PR's version of the package
60-
Pkg.develop(PackageSpec(path=".")) # resolver may fail with main deps
61+
Pkg.develop([PackageSpec(path="."), PackageSpec(path="downstream")]) # resolver may fail with main deps
6162
Pkg.update()
62-
Pkg.test() # resolver may fail with test time deps
63+
Pkg.test("${{ matrix.package.name }}") # resolver may fail with test time deps
6364
catch err
6465
err isa Pkg.Resolve.ResolverError || rethrow()
6566
# If we can't resolve that means this is incompatible by SemVer and this is fine

.github/workflows/JuliaNightly.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ jobs:
2727
- uses: julia-actions/setup-julia@v2
2828
with:
2929
version: nightly
30-
- uses: julia-actions/cache@v2
30+
- uses: julia-actions/cache@v3
3131
- uses: julia-actions/julia-buildpkg@v1
3232
- uses: julia-actions/julia-runtest@v1
3333
- uses: julia-actions/julia-processcoverage@v1

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRulesCore"
22
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
3-
version = "1.26.0"
3+
version = "1.26.1"
44

55
[deps]
66
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"

src/rules.jl

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,10 @@ unary input, unary output scalar function:
1818
```jldoctest frule
1919
julia> dself = NoTangent();
2020
21-
julia> x = rand()
22-
0.8236475079774124
21+
julia> x = 1.23456;
2322
2423
julia> sinx, Δsinx = frule((dself, 1), sin, x)
25-
(0.7336293678134624, 0.6795498147167869)
24+
(0.9440031218347901, 0.3299365180851773)
2625
2726
julia> sinx == sin(x)
2827
true
@@ -51,7 +50,7 @@ that return a single output that is iterable, like a `Tuple`.
5150
So this is actually a [`Tangent`](@ref):
5251
```jldoctest frule
5352
julia> Δsincosx
54-
Tangent{Tuple{Float64, Float64}}(0.6795498147167869, -0.7336293678134624)
53+
Tangent{Tuple{Float64, Float64}}(0.3299365180851773, -0.9440031218347901)
5554
```
5655
5756
The optional [`RuleConfig`](@ref) option allows specifying frules only for AD systems that

src/tangent_arithmetic.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ function Base.:+(a::P, d::StructuralTangent{P}) where {P}
142142
return construct(P, net_backing)
143143
end
144144
end
145-
Base.:+(a::Dict, d::Tangent{P}) where {P} = merge(+, a, backing(d))
145+
Base.:+(a::Dict, d::Tangent{P}) where {P} = mergewith(+, a, backing(d))
146146
Base.:+(a::StructuralTangent{P}, b::P) where {P} = b + a
147147

148148
Base.:-(tangent::StructuralTangent{P}) where {P} = map(-, tangent)

src/tangent_types/structural_tangent.jl

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -208,16 +208,9 @@ function backing(x::T)::NamedTuple where {T}
208208
end
209209
end
210210

211-
"""
212-
_zeroed_backing(P)
213-
214-
Returns a NamedTuple with same fields as `P`, and all values `ZeroTangent()`.
215-
"""
216-
@generated function _zeroed_backing(::Type{P}) where {P}
217-
nil_base = ntuple(fieldcount(P)) do i
218-
(fieldname(P, i), ZeroTangent())
219-
end
220-
return (; nil_base...)
211+
# Returns a NamedTuple with same fields as `P`, and all values `ZeroTangent()`
212+
function _zeroed_backing(::Type{T}) where {T}
213+
return NamedTuple{fieldnames(T)}(ntuple(_ -> ZeroTangent(), fieldcount(T)))
221214
end
222215

223216
"""
@@ -299,7 +292,7 @@ function elementwise_add(a::NamedTuple{an}, b::NamedTuple{bn}) where {an,bn}
299292
end
300293
end
301294

302-
elementwise_add(a::Dict, b::Dict) = merge(+, a, b)
295+
elementwise_add(a::Dict, b::Dict) = mergewith(+, a, b)
303296

304297
struct PrimalAdditionFailedException{P} <: Exception
305298
primal::P

src/tangent_types/thunks.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -179,8 +179,7 @@ To evaluate the wrapped closure, call [`unthunk`](@ref) which is a no-op when th
179179
argument is not a `Thunk`.
180180
181181
```jldoctest
182-
julia> t = @thunk(3)
183-
Thunk(var"#4#5"())
182+
julia> t = @thunk(3);
184183
185184
julia> unthunk(t)
186185
3

test/projection.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -457,7 +457,12 @@ struct NoSuperType end
457457
@test 0 == @ballocated $pdiag(dx) setup = (dx = Diagonal(rand(10^3)))
458458

459459
psymm = ProjectTo(Symmetric(rand(10^3, 10^3)))
460-
@test_broken 0 == @ballocated $psymm(dx) setup = (dx = Symmetric(rand(10^3, 10^3))) # 64
460+
allocs = @ballocated $psymm(dx) setup = (dx = Symmetric(rand(10^3, 10^3))) # 64
461+
if VERSION > v"1.13"
462+
@test 0 == allocs
463+
else
464+
@test_broken 0 == allocs
465+
end
461466
end
462467

463468
@testset "#685" begin

0 commit comments

Comments
 (0)