-
Notifications
You must be signed in to change notification settings - Fork 32
Expand file tree
/
Copy pathonearg.jl
More file actions
80 lines (75 loc) · 2.74 KB
/
onearg.jl
File metadata and controls
80 lines (75 loc) · 2.74 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
# Preparation object for gradient operators used with an `AutoReactant` backend.
#
# Instances are built by `DI.prepare_gradient_nokwarg`, which stores the two buffers it
# obtains from `to_reac` together with the four operators it builds with `@compile`.
struct ReactantGradientPrep{
    SIG,XR,GR,CG,CG!,CVG,CVG!
} <: DI.GradientPrep{SIG}
    # Signature recorded at preparation time (result of `DI.signature`).
    _sig::Val{SIG}
    # Input buffer `to_reac(x)`; overwritten with `copyto!(xr, x)` before each call.
    xr::XR
    # Gradient buffer `to_reac(similar(x))` handed to the in-place compiled operators.
    gr::GR
    # Result of `@compile DI.gradient(...)`.
    compiled_gradient::CG
    # Result of `@compile DI.gradient!(...)`.
    compiled_gradient!::CG!
    # Result of `@compile DI.value_and_gradient(...)`.
    compiled_value_and_gradient::CVG
    # Result of `@compile DI.value_and_gradient!(...)`.
    compiled_value_and_gradient!::CVG!
end
# Prepare the gradient operators of `f` at `x` for an `AutoReactant` backend.
#
# `x` and an uninitialized `similar(x)` are passed through `to_reac` to obtain the input
# and gradient buffers, and every context is passed through `to_reac` as well. The four
# DI gradient operators are then built once with `@compile` for the wrapped backend
# `rebackend.mode` on those arguments. The buffers and compiled operators are returned
# together in a `ReactantGradientPrep`.
#
# NOTE(review): the signature is built from `(f, rebackend, x)` only; `contexts` are not
# forwarded to `DI.signature` — confirm this is intended.
function DI.prepare_gradient_nokwarg(
    strict::Val, f::F, rebackend::AutoReactant, x, contexts::Vararg{DI.Context,C}
) where {F,C}
    sig = DI.signature(f, rebackend, x; strict=strict)
    inner = rebackend.mode

    # Buffers and contexts that the compiled operators are built against.
    x_reac = to_reac(x)
    grad_reac = to_reac(similar(x))
    ctx_reac = map(to_reac, contexts)

    # One compilation per operator, out-of-place and in-place.
    grad_op = @compile DI.gradient(f, inner, x_reac, ctx_reac...)
    grad_op_inplace = @compile DI.gradient!(f, grad_reac, inner, x_reac, ctx_reac...)
    valgrad_op = @compile DI.value_and_gradient(f, inner, x_reac, ctx_reac...)
    valgrad_op_inplace = @compile DI.value_and_gradient!(
        f, grad_reac, inner, x_reac, ctx_reac...
    )

    return ReactantGradientPrep(
        sig, x_reac, grad_reac, grad_op, grad_op_inplace, valgrad_op, valgrad_op_inplace
    )
end
# Out-of-place gradient of `f` at `x` using a `ReactantGradientPrep`.
#
# After `DI.check_prep`, `x` is copied into the stored buffer `prep.xr`, the contexts
# are converted with `to_reac`, and `prep.compiled_gradient` is called with the wrapped
# backend `rebackend.mode`. Its result is returned as-is (no copy into another array).
#
# NOTE(review): `DI.check_prep` is called without `contexts` — confirm this is intended.
function DI.gradient(
    f::F,
    prep::ReactantGradientPrep,
    rebackend::AutoReactant,
    x,
    contexts::Vararg{DI.Context,C},
) where {F,C}
    DI.check_prep(f, prep, rebackend, x)
    inner = rebackend.mode
    x_reac = prep.xr
    compiled = prep.compiled_gradient
    # Refresh the buffer the compiled operator reads from.
    copyto!(x_reac, x)
    ctx_reac = map(to_reac, contexts)
    return compiled(f, inner, x_reac, ctx_reac...)
end
# Out-of-place value and gradient of `f` at `x` using a `ReactantGradientPrep`.
#
# After `DI.check_prep`, `x` is copied into the stored buffer `prep.xr`, the contexts
# are converted with `to_reac`, and `prep.compiled_value_and_gradient` is called with
# the wrapped backend `rebackend.mode`. The first two outputs of that call are returned
# as the tuple `(value, gradient)`, without any further copy.
#
# NOTE(review): `DI.check_prep` is called without `contexts` — confirm this is intended.
function DI.value_and_gradient(
    f::F,
    prep::ReactantGradientPrep,
    rebackend::AutoReactant,
    x,
    contexts::Vararg{DI.Context,C},
) where {F,C}
    DI.check_prep(f, prep, rebackend, x)
    inner = rebackend.mode
    x_reac = prep.xr
    compiled = prep.compiled_value_and_gradient
    # Refresh the buffer the compiled operator reads from.
    copyto!(x_reac, x)
    ctx_reac = map(to_reac, contexts)
    y_reac, grad_reac = compiled(f, inner, x_reac, ctx_reac...)
    return (y_reac, grad_reac)
end
# In-place gradient of `f` at `x`, written into `grad`, using a `ReactantGradientPrep`.
#
# After `DI.check_prep`, `x` is copied into the stored buffer `prep.xr`, the contexts
# are converted with `to_reac`, and `prep.compiled_gradient!` is called on the stored
# gradient buffer `prep.gr`. The return value of that call is ignored; the contents of
# `prep.gr` are then copied into `grad`, and the result of that `copyto!` is returned.
#
# NOTE(review): `DI.check_prep` is called without `contexts` — confirm this is intended.
function DI.gradient!(
    f::F,
    grad,
    prep::ReactantGradientPrep,
    rebackend::AutoReactant,
    x,
    contexts::Vararg{DI.Context,C},
) where {F,C}
    DI.check_prep(f, prep, rebackend, x)
    inner = rebackend.mode
    x_reac = prep.xr
    grad_buffer = prep.gr
    compiled = prep.compiled_gradient!
    # Refresh the buffer the compiled operator reads from.
    copyto!(x_reac, x)
    ctx_reac = map(to_reac, contexts)
    # Relies on the compiled operator filling `grad_buffer` in place.
    compiled(f, grad_buffer, inner, x_reac, ctx_reac...)
    return copyto!(grad, grad_buffer)
end
# In-place value and gradient of `f` at `x` using a `ReactantGradientPrep`; the
# gradient is written into `grad`.
#
# After `DI.check_prep`, `x` is copied into the stored buffer `prep.xr`, the contexts
# are converted with `to_reac`, and `prep.compiled_value_and_gradient!` is called on
# the stored gradient buffer `prep.gr`. The gradient *returned by that call* (its second
# output) is what gets copied into `grad`. The function returns the value from the
# compiled call together with the result of that `copyto!`.
#
# NOTE(review): `DI.check_prep` is called without `contexts` — confirm this is intended.
function DI.value_and_gradient!(
    f::F,
    grad,
    prep::ReactantGradientPrep,
    rebackend::AutoReactant,
    x,
    contexts::Vararg{DI.Context,C},
) where {F,C}
    DI.check_prep(f, prep, rebackend, x)
    inner = rebackend.mode
    x_reac = prep.xr
    grad_buffer = prep.gr
    compiled = prep.compiled_value_and_gradient!
    # Refresh the buffer the compiled operator reads from.
    copyto!(x_reac, x)
    ctx_reac = map(to_reac, contexts)
    y_reac, grad_out = compiled(f, grad_buffer, inner, x_reac, ctx_reac...)
    return (y_reac, copyto!(grad, grad_out))
end