6767function Lux. __to_reactant_adaptor (
6868 to:: Lux.ToReactantAdaptor{FST} , model:: AbstractExplicitLayer ,
6969 input_prototype, ps, st, eltype_adaptor) where {FST}
70+ output = first (model (input_prototype, ps, st))
71+ concrete_output = __make_concrete_array (output)
72+
7073 concrete_input = __make_concrete_array (input_prototype)
7174 cmodel = __make_concrete_array (model)
7275 cps = __make_concrete_array (ps)
7376 cst = __make_concrete_array (st)
7477
7578 csmodel = Lux. StatefulLuxLayer {FST} (cmodel, cps, cst)
7679
77- fwd = Reactant. compile ((m, x) -> m (x), (csmodel, concrete_input))
78-
79- bwd = try
80- enzyme_grad_fn = (m, x) -> begin
81- dx = Enzyme. make_zero (x)
82- dps = Enzyme. make_zero (m. ps)
83- st = ifelse (FST, m. st, m. st_any)
84- Enzyme. autodiff (
85- Enzyme. Reverse, (m, x, ps, st) -> first (LuxCore. apply (m, x, ps, st)),
86- Enzyme. Duplicated, Enzyme. Const (m. model), Enzyme. Duplicated (x, dx),
87- Enzyme. Duplicated (m. ps, dps), Enzyme. Const (st))
88- return (; ps= dps), dx
80+ fwd_fn = Reactant. compile ((m, x) -> m (x), (csmodel, concrete_input))
81+
82+ function enzyme_vjp_fn (m, x, y, dy)
83+ dx = Enzyme. make_zero (x)
84+ dps = Enzyme. make_zero (m. ps)
85+ st_m = ifelse (FST, m. st, m. st_any)
86+
87+ function wrapper_fn! (y, model, x, ps, st)
88+ copyto! (y, first (LuxCore. apply (model, x, ps, st)))
89+ return nothing
8990 end
9091
91- Reactant. compile (enzyme_grad_fn, (csmodel, concrete_input))
92+ Enzyme. autodiff (Enzyme. Reverse, wrapper_fn!, Enzyme. Const, Enzyme. Duplicated (y, dy),
93+ Enzyme. Const (m. model), Enzyme. Duplicated (x, dx),
94+ Enzyme. Duplicated (m. ps, dps), Enzyme. Const (st_m))
95+ return dx, dps
96+ end
97+
98+ vjp_fn = try
99+ concrete_output2 = __make_concrete_array (deepcopy (output))
100+ Reactant. compile (
101+ enzyme_vjp_fn, (csmodel, concrete_input, concrete_output, concrete_output2))
92102 catch err
93103 to. force_compile_backward && rethrow (err)
94104 @error """
@@ -101,11 +111,9 @@ function Lux.__to_reactant_adaptor(
101111 nothing
102112 end
103113
104- # TODO : Add compiled types to the layer type information. That way we can check
105- # if the model is being executed with the correct types.
106114 return Lux. ReactantLayer {FST, Lux.__recursive_eltype(input_prototype)} (
107- to, input_prototype, concrete_input, cps, cst, model, cmodel, fwd ,
108- bwd , eltype_adaptor, fmapstructure (Lux. __size, input_prototype))
115+ to, input_prototype, concrete_input, cps, cst, model, cmodel, fwd_fn ,
116+ vjp_fn , eltype_adaptor, fmapstructure (Lux. __size, input_prototype))
109117end
110118
111119# TODO : Currently we are maintaining 2 copies of the parameters, this is not ideal.
183191
184192Lux. __apply_reactant (l, x, ps, st) = __graceful_type_mismatch_error (l, x, ps, st)
185193
186- @inline Lux. __apply_reactant (l:: Lux.ReactantLayer , csmodel, x) = l. fwd (csmodel, x)
194+ @inline Lux. __apply_reactant (l:: Lux.ReactantLayer , csmodel, x) = l. fwd_fn (csmodel, x)
187195
188196# Don't inline, else types don't get displayed in the stack trace
189197function __graceful_type_mismatch_error (
0 commit comments