Skip to content

Commit d38d905

Browse files
authored
fix: make DI compatible with latest Mooncake friendly tangents (#1001)
* fix: make DI compatible with latest Mooncake friendly tangents * Disable tests * Bump * Test barrier * Changelog
1 parent 1b5d91c commit d38d905

9 files changed

Lines changed: 51 additions & 12 deletions

File tree

.github/workflows/Test.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,7 @@ jobs:
101101
- ForwardDiff
102102
- GTPSA
103103
- Mooncake
104+
- Mooncake-old
104105
- PolyesterForwardDiff
105106
- ReverseDiff
106107
- SparsityDetector

DifferentiationInterface/CHANGELOG.md

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,15 @@ All notable changes to this project will be documented in this file.
55
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
66
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
77

8-
## [Unreleased](https://github.com/JuliaDiff/DifferentiationInterface.jl/compare/DifferentiationInterface-v0.7.16...main)
8+
## [Unreleased](https://github.com/JuliaDiff/DifferentiationInterface.jl/compare/DifferentiationInterface-v0.7.17...main)
9+
10+
## [0.7.17](https://github.com/JuliaDiff/DifferentiationInterface.jl/compare/DifferentiationInterface-v0.7.16...DifferentiationInterface-v0.7.17)
11+
12+
### Fixed
13+
14+
- Make DI compatible with latest Mooncake friendly tangents ([#1001](https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/1001))
15+
- Add docstrings to the result anlysis methods for sparse matrix preparations ([#984](https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/984))
16+
- Make wrong-mode pushforward/pullback return the correct array type ([#974](https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/974))
917

1018
## [0.7.16](https://github.com/JuliaDiff/DifferentiationInterface.jl/compare/DifferentiationInterface-v0.7.15...DifferentiationInterface-v0.7.16)
1119

DifferentiationInterface/Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "DifferentiationInterface"
22
uuid = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
3-
version = "0.7.16"
3+
version = "0.7.17"
44
authors = ["Guillaume Dalle", "Adrian Hill"]
55

66
[deps]
@@ -71,7 +71,7 @@ ForwardDiff = "0.10.36,1"
7171
GPUArraysCore = "0.2"
7272
GTPSA = "1.4.0"
7373
LinearAlgebra = "1"
74-
Mooncake = "0.5.1 - 0.5.24"
74+
Mooncake = "0.5.1"
7575
PolyesterForwardDiff = "0.1.2"
7676
ReverseDiff = "1.15.1"
7777
SparseArrays = "1"

DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/DifferentiationInterfaceMooncakeExt.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,14 +29,15 @@ using Mooncake:
2929
NoRData,
3030
primal,
3131
_copy_output,
32-
_copy_to_output!!,
33-
tangent_to_primal!!
32+
_copy_to_output!!
3433

3534
const AnyAutoMooncake{C} = Union{AutoMooncake{C}, AutoMooncakeForward{C}}
3635

3736
DI.check_available(::AnyAutoMooncake{C}) where {C} = true
3837
DI.inner_preparation_behavior(::AutoMooncakeForward) = DI.PrepareInnerSimple()
3938

39+
@inline new_friendly_tangents() = isdefined(Mooncake, :FriendlyTangentCache)
40+
4041
include("utils.jl")
4142
include("onearg.jl")
4243
include("twoarg.jl")

DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/utils.jl

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,19 @@ function call_and_return(f!::F, y, x, contexts...) where {F}
99
return y
1010
end
1111

12+
function adaptive_tangent_to_primal!!(primal, tangent)
13+
@static if new_friendly_tangents()
14+
# TODO: optimize performance by allocating cache during prep
15+
return Mooncake.tangent_to_friendly!!(primal, tangent)
16+
else
17+
return Mooncake.tangent_to_primal!!(primal, tangent)
18+
end
19+
end
20+
1221
function zero_tangent_or_primal(x, backend::AnyAutoMooncake)
1322
if get_config(backend).friendly_tangents
1423
# zero(x) but safer
15-
return tangent_to_primal!!(_copy_output(x), zero_tangent(x))
24+
return adaptive_tangent_to_primal!!(_copy_output(x), zero_tangent(x))
1625
else
1726
return zero_tangent(x)
1827
end
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
[deps]
2+
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
3+
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
4+
DifferentiationInterfaceTest = "a82114a7-5aa3-49a8-9643-716bb13727a3"
5+
ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7"
6+
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
7+
SparseConnectivityTracer = "9f842d2f-2579-4b1d-911e-f412cf18a3f5"
8+
SparseMatrixColorings = "0a514795-09f3-496d-8182-132a7b665d35"
9+
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
10+
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
11+
12+
[compat]
13+
Mooncake = "<0.5.25"
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
../Mooncake/test.jl

DifferentiationInterface/test/Back/Mooncake/Project.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,6 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
1111

1212
[sources]
1313
DifferentiationInterface = { path = "../../.." }
14+
15+
[compat]
16+
Mooncake = ">=0.5.25"

DifferentiationInterface/test/Back/Mooncake/test.jl

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -74,9 +74,12 @@ test_differentiation(
7474
@test grad.B == ps.A
7575
end
7676

77-
test_differentiation(
78-
backends[3:4],
79-
nomatrix(static_scenarios());
80-
logging = LOGGING,
81-
excluded = SECOND_ORDER
82-
)
77+
# see https://github.com/JuliaDiff/DifferentiationInterface.jl/issues/986
78+
if pkgversion(Mooncake) < v"0.5.25"
79+
test_differentiation(
80+
backends[3:4],
81+
nomatrix(static_scenarios());
82+
logging = LOGGING,
83+
excluded = SECOND_ORDER
84+
)
85+
end

0 commit comments

Comments
 (0)