-
Notifications
You must be signed in to change notification settings - Fork 32
Expand file tree
/
Copy pathtest.jl
More file actions
82 lines (70 loc) · 2.02 KB
/
test.jl
File metadata and controls
82 lines (70 loc) · 2.02 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
include("../../testutils.jl")
using DifferentiationInterface, DifferentiationInterfaceTest
using Mooncake: Mooncake
using Test
using ExplicitImports
# Verify that DifferentiationInterface does not rely on implicitly imported names.
# `check_no_implicit_imports` throws on a violation and returns `nothing` otherwise,
# so wrapping it in `@test` records the check as a regular test result instead of
# running it as a bare statement.
@test check_no_implicit_imports(DifferentiationInterface) === nothing
"""
    nomatrix(scens)

Return the elements of `scens` whose `x` and `y` fields are both not an
`AbstractMatrix`.
"""
function nomatrix(scens)
    # A scenario is dropped as soon as either its input or its output is a matrix.
    hasmatrix(scen) = scen.x isa AbstractMatrix || scen.y isa AbstractMatrix
    return filter(!hasmatrix, scens)
end
# Mooncake backends under test. Order matters: the tests in this file select
# backends by position. Positions 1-2 are `AutoMooncake` and `AutoMooncakeForward`
# with their default configuration; positions 3-4 are the same two backends with a
# `Mooncake.Config` that sets `friendly_tangents = true`.
backends = let
    # Build a fresh config for each backend that needs one.
    friendly() = Mooncake.Config(; friendly_tangents = true)
    [
        AutoMooncake(),
        AutoMooncakeForward(),
        AutoMooncake(; config = friendly()),
        AutoMooncakeForward(; config = friendly()),
    ]
end
# Each backend must pass both the `check_available` and the `check_inplace` checks.
foreach(backends) do b
    @test check_available(b)
    @test check_inplace(b)
end
# Backends at positions 3-4 on the full default scenario set, skipping the
# operators listed in `SECOND_ORDER`.
let friendly = backends[3:4], scens = default_scenarios()
    test_differentiation(friendly, scens; logging = LOGGING, excluded = SECOND_ORDER)
end
# Backends at positions 3-4 on the scenario variants selected by the keywords below
# (normal scenarios off; constantified and cachified on; `use_tuples = true`), with
# matrix-valued scenarios removed by `nomatrix`. Operators in `SECOND_ORDER` are
# skipped.
let friendly = backends[3:4]
    variants = default_scenarios(;
        include_normal = false,
        include_constantified = true,
        include_cachified = true,
        use_tuples = true,
    )
    test_differentiation(
        friendly, nomatrix(variants); logging = LOGGING, excluded = SECOND_ORDER
    )
end
# Backends at positions 1-2 on the default scenarios with matrix-valued scenarios
# removed by `nomatrix`, skipping the operators listed in `SECOND_ORDER`.
let plain = backends[1:2], scens = nomatrix(default_scenarios())
    test_differentiation(plain, scens; logging = LOGGING, excluded = SECOND_ORDER)
end
# Operators excluded from the second-order test run.
# On Julia 1.11 only :hessian is tested (:hvp and :second_derivative are excluded in
# addition to the first-order operators) because of an opaque closure bug. This is
# potentially the same issue as discussed in
# https://github.com/chalk-lab/MistyClosures.jl/pull/12#issue-3278662295
EXCLUDED = @static if VERSION < v"1.11-" || VERSION > v"1.12-"
    [FIRST_ORDER...]
else
    [FIRST_ORDER..., :hvp, :second_derivative]
end
# Second-order differentiation, forward-over-reverse: `AutoMooncakeForward` is the
# first argument of `SecondOrder` and `AutoMooncake` the second.
let fwd_over_rev = SecondOrder(AutoMooncakeForward(), AutoMooncake())
    scens = nomatrix(default_scenarios())
    test_differentiation([fwd_over_rev], scens; logging = LOGGING, excluded = EXCLUDED)
end
@testset "NamedTuples" begin
    # Gradient with respect to a NamedTuple input. For f(ps) = sum(ps.A .* ps.B),
    # the derivative with respect to `A` is `B`, and with respect to `B` is `A`.
    ps = (; A = rand(5), B = rand(5))
    myfun(ps) = sum(ps.A .* ps.B)
    grad = gradient(myfun, backends[1], ps)
    # Compare floating-point results approximately rather than bit-for-bit, so the
    # test does not depend on the exact arithmetic the backend performs.
    @test grad.A ≈ ps.B
    @test grad.B ≈ ps.A
end
# Backends at positions 3-4 on the static scenarios with matrix-valued scenarios
# removed by `nomatrix`, skipping the operators listed in `SECOND_ORDER`.
let friendly = backends[3:4], scens = nomatrix(static_scenarios())
    test_differentiation(friendly, scens; excluded = SECOND_ORDER, logging = LOGGING)
end