Skip to content

Commit fb7ea0a

Browse files
committed
Implement a working VJP function
1 parent 50f64a1 commit fb7ea0a

2 files changed

Lines changed: 30 additions & 21 deletions

File tree

ext/LuxReactantExt.jl

Lines changed: 26 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -67,28 +67,38 @@ end
6767
function Lux.__to_reactant_adaptor(
6868
to::Lux.ToReactantAdaptor{FST}, model::AbstractExplicitLayer,
6969
input_prototype, ps, st, eltype_adaptor) where {FST}
70+
output = first(model(input_prototype, ps, st))
71+
concrete_output = __make_concrete_array(output)
72+
7073
concrete_input = __make_concrete_array(input_prototype)
7174
cmodel = __make_concrete_array(model)
7275
cps = __make_concrete_array(ps)
7376
cst = __make_concrete_array(st)
7477

7578
csmodel = Lux.StatefulLuxLayer{FST}(cmodel, cps, cst)
7679

77-
fwd = Reactant.compile((m, x) -> m(x), (csmodel, concrete_input))
78-
79-
bwd = try
80-
enzyme_grad_fn = (m, x) -> begin
81-
dx = Enzyme.make_zero(x)
82-
dps = Enzyme.make_zero(m.ps)
83-
st = ifelse(FST, m.st, m.st_any)
84-
Enzyme.autodiff(
85-
Enzyme.Reverse, (m, x, ps, st) -> first(LuxCore.apply(m, x, ps, st)),
86-
Enzyme.Duplicated, Enzyme.Const(m.model), Enzyme.Duplicated(x, dx),
87-
Enzyme.Duplicated(m.ps, dps), Enzyme.Const(st))
88-
return (; ps=dps), dx
80+
fwd_fn = Reactant.compile((m, x) -> m(x), (csmodel, concrete_input))
81+
82+
function enzyme_vjp_fn(m, x, y, dy)
83+
dx = Enzyme.make_zero(x)
84+
dps = Enzyme.make_zero(m.ps)
85+
st_m = ifelse(FST, m.st, m.st_any)
86+
87+
function wrapper_fn!(y, model, x, ps, st)
88+
copyto!(y, first(LuxCore.apply(model, x, ps, st)))
89+
return nothing
8990
end
9091

91-
Reactant.compile(enzyme_grad_fn, (csmodel, concrete_input))
92+
Enzyme.autodiff(Enzyme.Reverse, wrapper_fn!, Enzyme.Const, Enzyme.Duplicated(y, dy),
93+
Enzyme.Const(m.model), Enzyme.Duplicated(x, dx),
94+
Enzyme.Duplicated(m.ps, dps), Enzyme.Const(st_m))
95+
return dx, dps
96+
end
97+
98+
vjp_fn = try
99+
concrete_output2 = __make_concrete_array(deepcopy(output))
100+
Reactant.compile(
101+
enzyme_vjp_fn, (csmodel, concrete_input, concrete_output, concrete_output2))
92102
catch err
93103
to.force_compile_backward && rethrow(err)
94104
@error """
@@ -101,11 +111,9 @@ function Lux.__to_reactant_adaptor(
101111
nothing
102112
end
103113

104-
# TODO: Add compiled types to the layer type information. That way we can check
105-
# if the model is being executed with the correct types.
106114
return Lux.ReactantLayer{FST, Lux.__recursive_eltype(input_prototype)}(
107-
to, input_prototype, concrete_input, cps, cst, model, cmodel, fwd,
108-
bwd, eltype_adaptor, fmapstructure(Lux.__size, input_prototype))
115+
to, input_prototype, concrete_input, cps, cst, model, cmodel, fwd_fn,
116+
vjp_fn, eltype_adaptor, fmapstructure(Lux.__size, input_prototype))
109117
end
110118

111119
# TODO: Currently we are maintaining 2 copies of the parameters, this is not ideal.
@@ -183,7 +191,7 @@ end
183191

184192
Lux.__apply_reactant(l, x, ps, st) = __graceful_type_mismatch_error(l, x, ps, st)
185193

186-
@inline Lux.__apply_reactant(l::Lux.ReactantLayer, csmodel, x) = l.fwd(csmodel, x)
194+
@inline Lux.__apply_reactant(l::Lux.ReactantLayer, csmodel, x) = l.fwd_fn(csmodel, x)
187195

188196
# Don't inline, else types don't get displayed in the stack trace
189197
function __graceful_type_mismatch_error(

src/layers/extension.jl

Lines changed: 4 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -245,7 +245,8 @@ end
245245

246246
# TODO: Add a ChainRules rrule that calls the `bwd` function, i.e. uses Enzyme for the
247247
# gradient computation
248-
@concrete struct ReactantLayer{FST, T, inType, inCType, psType, stType, F, B,
248+
# TODO: Inference won't work OOTB, we will have to compile that separately
249+
@concrete struct ReactantLayer{FST, T, inType, inCType, psType, stType,
249250
L <: AbstractExplicitLayer, AD <: ToReactantAdaptor} <: AbstractExplicitLayer
250251
adaptor::AD
251252
input_prototype::inType
@@ -254,8 +255,8 @@ end
254255
concrete_st::stType
255256
layer::L
256257
clayer
257-
fwd::F
258-
bwd::B
258+
fwd_fn
259+
vjp_fn
259260
eltype_adaptor
260261
input_structure
261262
end

0 commit comments

Comments (0)