forked from JuliaDiff/DifferentiationInterface.jl
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathforward_onearg.jl
More file actions
92 lines (86 loc) · 2.59 KB
/
forward_onearg.jl
File metadata and controls
92 lines (86 loc) · 2.59 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
## Pushforward
# Preparation object for one-argument forward-mode pushforwards with Mooncake.
# Field order matters: it is constructed positionally in
# `DI.prepare_pushforward_nokwarg`.
struct MooncakeOneArgPushforwardPrep{SIG, TC, TDF, TCT} <: DI.PushforwardPrep{SIG}
    # Call signature recorded at preparation time (passed to `DI.check_prep` later).
    _sig::Val{SIG}
    # Derivative cache produced by `prepare_derivative_cache`.
    cache::TC
    # Value paired with `f` in each forward sweep, from `zero_tangent_or_primal`.
    df::TDF
    # One entry per context, from `zero_tangent_unwrap`.
    context_tangents::TCT
end
# Build the reusable preparation for Mooncake forward-mode pushforwards of a
# one-argument function. It holds:
# - the call signature,
# - a derivative cache for `f` at `x` with the unwrapped context values,
# - the value paired with `f`,
# - one zero-initialised entry per context.
function DI.prepare_pushforward_nokwarg(
    strict::Val,
    f::F,
    backend::AutoMooncakeForward,
    x,
    tx::NTuple,
    contexts::Vararg{DI.Context, C}
) where {F, C}
    # Signature (strict or not) stored so later calls can be validated.
    sig = DI.signature(f, backend, x, tx, contexts...; strict)
    mooncake_config = get_config(backend)
    # The cache is built from the plain context values, not their DI wrappers.
    context_values = map(DI.unwrap, contexts)
    cache = prepare_derivative_cache(f, x, context_values...; config = mooncake_config)
    return MooncakeOneArgPushforwardPrep(
        sig,
        cache,
        zero_tangent_or_primal(f, backend),
        map(zero_tangent_unwrap, contexts),
    )
end
# Compute the primal output and one pushforward per seed direction in `tx`.
# Each direction runs its own forward sweep through the prepared cache.
# - The returned value is the primal from the first sweep, copied out via
#   `_copy_output`.
# - Each tangent is turned into a freshly allocated primal-typed value via
#   `_to_primal_alloc`.
function DI.value_and_pushforward(
    f::F,
    prep::MooncakeOneArgPushforwardPrep,
    backend::AutoMooncakeForward,
    x::X,
    tx::NTuple,
    contexts::Vararg{DI.Context, C}
) where {F, C, X}
    DI.check_prep(f, prep, backend, x, tx, contexts...)
    # Forward sweep for a single seed `dx`; yields `(primal, tangent)`.
    function sweep(dx)
        out = value_and_derivative!!(
            prep.cache,
            (f, prep.df),
            (x, dx),
            map(first_unwrap, contexts, prep.context_tangents)...,
        )
        yval = first(out)
        return yval, _to_primal_alloc(yval, last(out))
    end
    results = map(sweep, tx)
    # NOTE(review): assumes `tx` is non-empty; indexing the first result fails otherwise.
    y = _copy_output(first(results[1]))
    ty = map(last, results)
    return y, ty
end
# Pushforward without the primal value: delegates to `DI.value_and_pushforward`
# and keeps only the tuple of tangents.
function DI.pushforward(
    f::F,
    prep::MooncakeOneArgPushforwardPrep,
    backend::AutoMooncakeForward,
    x,
    tx::NTuple,
    contexts::Vararg{DI.Context, C}
) where {F, C}
    DI.check_prep(f, prep, backend, x, tx, contexts...)
    _, ty = DI.value_and_pushforward(f, prep, backend, x, tx, contexts...)
    return ty
end
# In-place variant: computes the tangents out of place, then writes each one
# into the matching element of `ty` with `_to_primal!`. Returns `(y, ty)`.
# NOTE(review): the elements of `ty` presumably need to be mutable for the
# in-place write to take effect — confirm against `_to_primal!`.
function DI.value_and_pushforward!(
    f::F,
    ty::NTuple,
    prep::MooncakeOneArgPushforwardPrep,
    backend::AutoMooncakeForward,
    x,
    tx::NTuple,
    contexts::Vararg{DI.Context, C}
) where {F, C}
    DI.check_prep(f, prep, backend, x, tx, contexts...)
    y, fresh = DI.value_and_pushforward(f, prep, backend, x, tx, contexts...)
    for (dst, src) in zip(ty, fresh)
        _to_primal!(dst, src)
    end
    return y, ty
end
# In-place pushforward: fills `ty` through `DI.value_and_pushforward!`,
# discards the primal value, and returns `ty` itself.
function DI.pushforward!(
    f::F,
    ty::NTuple,
    prep::MooncakeOneArgPushforwardPrep,
    backend::AutoMooncakeForward,
    x,
    tx::NTuple,
    contexts::Vararg{DI.Context, C}
) where {F, C}
    DI.check_prep(f, prep, backend, x, tx, contexts...)
    _ = DI.value_and_pushforward!(f, ty, prep, backend, x, tx, contexts...)
    return ty
end