From 405aa42f770b5fb76bfe5e8d3f488bb0577df4fe Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 26 May 2024 20:03:46 -0700 Subject: [PATCH 01/15] Auto compile Lux models to reactant --- Project.toml | 3 +++ ext/LuxReactantExt.jl | 48 +++++++++++++++++++++++++++++++++++++++ src/Lux.jl | 2 ++ src/layers/extension.jl | 9 ++++++++ src/transform/reactant.jl | 13 +++++++++++ 5 files changed, 75 insertions(+) create mode 100644 ext/LuxReactantExt.jl create mode 100644 src/transform/reactant.jl diff --git a/Project.toml b/Project.toml index a588fa1ba3..8508bba2f7 100644 --- a/Project.toml +++ b/Project.toml @@ -40,6 +40,7 @@ MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195" NCCL = "3fe64909-d7a1-4096-9b7d-7a0f12cf0f6b" Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" +Reactant = "3c362404-f566-11ee-1572-e11a4b42c853" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" SimpleChains = "de6bee2f-e2f4-4ec7-b6ed-219cc6f6e9e5" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" @@ -57,6 +58,7 @@ LuxMLUtilsExt = "MLUtils" LuxMPIExt = "MPI" LuxMPINCCLExt = ["CUDA", "MPI", "NCCL"] LuxOptimisersExt = "Optimisers" +LuxReactantExt = "Reactant" LuxReverseDiffExt = "ReverseDiff" LuxSimpleChainsExt = "SimpleChains" LuxTrackerExt = "Tracker" @@ -102,6 +104,7 @@ Pkg = "1.10" PrecompileTools = "1.2" Preferences = "1.4.3" Random = "1.10" +Reactant = "0.1.1" ReTestItems = "1.23.1" Reexport = "1.2.2" ReverseDiff = "1.15" diff --git a/ext/LuxReactantExt.jl b/ext/LuxReactantExt.jl new file mode 100644 index 0000000000..1c39ce1d12 --- /dev/null +++ b/ext/LuxReactantExt.jl @@ -0,0 +1,48 @@ +module LuxReactantExt + +using ArgCheck: @argcheck +using Random: AbstractRNG, Xoshiro +using Reactant: Reactant +using Lux: Lux +using LuxCore: LuxCore, AbstractExplicitLayer + +@inline __make_concrete_array(x::Reactant.ConcreteRArray) = x +@inline __make_concrete_array(x::AbstractArray) = Reactant.ConcreteRArray(x) +@inline function __make_concrete_array(x) + return Reactant.make_tracer(IdDict(), x, (), Reactant.ArrayToConcrete, nothing) +end + +# FIXME: currently only `stateless_apply` is supported: https://github.com/EnzymeAD/Reactant.jl/issues/8 +function Lux.__to_reactant_adaptor(model::AbstractExplicitLayer, input_prototype) + concrete_input = __make_concrete_array(input_prototype) + cmodel = __make_concrete_array(model) + # We generate fake parameters and states to compile the model + ps = LuxCore.initialparameters(Xoshiro(123), model) + cps = __make_concrete_array(ps) + + st = LuxCore.initialstates(Xoshiro(123), model) + @argcheck st==LuxCore._getemptystate(model) "Currently only stateless models are supported." 
+ + fwd = Reactant.compile( + (m, x, ps) -> LuxCore.stateless_apply(m, x, ps), (cmodel, concrete_input, cps)) + + # TODO: conditionally compile the backward pass + + return Lux.ReactantLayer(model, cmodel, fwd, nothing) +end + +function LuxCore.initialparameters(rng::AbstractRNG, layer::Lux.ReactantLayer) + return __make_concrete_array(LuxCore.initialparameters(rng, layer.layer)) +end + +# FIXME: Change once https://github.com/EnzymeAD/Reactant.jl/pull/8 is fixed +function LuxCore.initialstates(::AbstractRNG, layer::Lux.ReactantLayer) + return NamedTuple() # __make_concrete_array(LuxCore.initialstates(rng, layer.layer)) +end + +# TODO: Add a type assert here to make it type stable +function (l::Lux.ReactantLayer)(x, ps, ::NamedTuple{()}) + return LuxCore.stateless_apply(l.clayer, __make_concrete_array(x), ps), NamedTuple() +end + +end diff --git a/src/Lux.jl b/src/Lux.jl index 92695f60e2..89cb5c21c8 100644 --- a/src/Lux.jl +++ b/src/Lux.jl @@ -76,6 +76,7 @@ include("helpers/nested_ad.jl") include("transform/types.jl") include("transform/flux.jl") include("transform/simplechains.jl") +include("transform/reactant.jl") # Distributed Training include("distributed/backend.jl") @@ -110,6 +111,7 @@ export f16, f32, f64 export transform export FromFluxAdaptor, FluxLayer export ToSimpleChainsAdaptor, SimpleChainsLayer +export ToReactantAdaptor export DynamicExpressionsLayer export MPIBackend, NCCLBackend, DistributedUtils diff --git a/src/layers/extension.jl b/src/layers/extension.jl index 398a44f55b..c2d82d764e 100644 --- a/src/layers/extension.jl +++ b/src/layers/extension.jl @@ -248,3 +248,12 @@ function CRC.rrule(::typeof(__apply_simple_chain), layer, x, ps, ::LuxCPUDevice) end return res, __∇apply_simple_chain end + +# TODO: Add a ChainRules rrule that calls the `bwd` function, i.e. 
uses Enzyme for the +# gradient computation +@concrete struct ReactantLayer{F, B, L <: AbstractExplicitLayer} <: AbstractExplicitLayer + layer::L + clayer + fwd::F + bwd::B +end diff --git a/src/transform/reactant.jl b/src/transform/reactant.jl new file mode 100644 index 0000000000..fb21336fde --- /dev/null +++ b/src/transform/reactant.jl @@ -0,0 +1,13 @@ +# TODO: Add options to compile the gradients directly using Enzyme.jl +@concrete struct ToReactantAdaptor <: AbstractFromLuxAdaptor + input_prototype +end + +function Adapt.adapt(to::ToReactantAdaptor, model::AbstractExplicitLayer) + if Base.get_extension(@__MODULE__, :LuxReactantExt) === nothing + error("`ToReactantAdaptor` requires `LuxReactantExt.jl` to be loaded.") + end + return __to_reactant_adaptor(model, to.input_prototype) +end + +function __to_reactant_adaptor end From e5446908bc32eaf306140a1c4a99d6ebf9a499ed Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 26 May 2024 20:30:10 -0700 Subject: [PATCH 02/15] Work around the returning states issue --- Project.toml | 2 +- ext/LuxReactantExt.jl | 22 +++++++++++----------- src/transform/reactant.jl | 1 - 3 files changed, 12 insertions(+), 13 deletions(-) diff --git a/Project.toml b/Project.toml index 8508bba2f7..0b74e662a3 100644 --- a/Project.toml +++ b/Project.toml @@ -58,7 +58,7 @@ LuxMLUtilsExt = "MLUtils" LuxMPIExt = "MPI" LuxMPINCCLExt = ["CUDA", "MPI", "NCCL"] LuxOptimisersExt = "Optimisers" -LuxReactantExt = "Reactant" +LuxReactantExt = ["Enzyme", "Reactant"] LuxReverseDiffExt = "ReverseDiff" LuxSimpleChainsExt = "SimpleChains" LuxTrackerExt = "Tracker" diff --git a/ext/LuxReactantExt.jl b/ext/LuxReactantExt.jl index 1c39ce1d12..8c5c37a379 100644 --- a/ext/LuxReactantExt.jl +++ b/ext/LuxReactantExt.jl @@ -1,6 +1,7 @@ module LuxReactantExt using ArgCheck: @argcheck +using Enzyme: Enzyme using Random: AbstractRNG, Xoshiro using Reactant: Reactant using Lux: Lux @@ -12,21 +13,20 @@ using LuxCore: LuxCore, AbstractExplicitLayer return Reactant.make_tracer(IdDict(), x, (), Reactant.ArrayToConcrete, nothing) end -# FIXME: currently only `stateless_apply` is supported: https://github.com/EnzymeAD/Reactant.jl/issues/8 function Lux.__to_reactant_adaptor(model::AbstractExplicitLayer, input_prototype) concrete_input = __make_concrete_array(input_prototype) cmodel = __make_concrete_array(model) + # We generate fake parameters and states to compile the model ps = LuxCore.initialparameters(Xoshiro(123), model) cps = __make_concrete_array(ps) st = LuxCore.initialstates(Xoshiro(123), model) - @argcheck st==LuxCore._getemptystate(model) "Currently only stateless models are supported." 
+ cst = __make_concrete_array(st) - fwd = Reactant.compile( - (m, x, ps) -> LuxCore.stateless_apply(m, x, ps), (cmodel, concrete_input, cps)) + csmodel = Lux.StatefulLuxLayer{false}(cmodel, cps, cst) - # TODO: conditionally compile the backward pass + fwd = Reactant.compile((m, x) -> m(x), (csmodel, concrete_input)) return Lux.ReactantLayer(model, cmodel, fwd, nothing) end @@ -35,14 +35,14 @@ function LuxCore.initialparameters(rng::AbstractRNG, layer::Lux.ReactantLayer) return __make_concrete_array(LuxCore.initialparameters(rng, layer.layer)) end -# FIXME: Change once https://github.com/EnzymeAD/Reactant.jl/pull/8 is fixed -function LuxCore.initialstates(::AbstractRNG, layer::Lux.ReactantLayer) - return NamedTuple() # __make_concrete_array(LuxCore.initialstates(rng, layer.layer)) +function LuxCore.initialstates(rng::AbstractRNG, layer::Lux.ReactantLayer) + return __make_concrete_array(LuxCore.initialstates(rng, layer.layer)) end -# TODO: Add a type assert here to make it type stable -function (l::Lux.ReactantLayer)(x, ps, ::NamedTuple{()}) - return LuxCore.stateless_apply(l.clayer, __make_concrete_array(x), ps), NamedTuple() +function (l::Lux.ReactantLayer)(x, ps, st::NamedTuple) + csmodel = Lux.StatefulLuxLayer{false}(l.clayer, ps, st) + y = l.fwd(csmodel, __make_concrete_array(x)) + return y, csmodel.st_any end end diff --git a/src/transform/reactant.jl b/src/transform/reactant.jl index fb21336fde..4b1ffdfb03 100644 --- a/src/transform/reactant.jl +++ b/src/transform/reactant.jl @@ -1,4 +1,3 @@ -# TODO: Add options to compile the gradients directly using Enzyme.jl @concrete struct ToReactantAdaptor <: AbstractFromLuxAdaptor input_prototype end From 018d7a0a8c59da78a9a84c541adf58e17a059632 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 26 May 2024 20:37:23 -0700 Subject: [PATCH 03/15] Make the states type stable --- ext/LuxReactantExt.jl | 15 ++++++++------- src/layers/extension.jl | 3 ++- src/transform/reactant.jl | 4 ++-- 3 files changed, 12 insertions(+), 10 deletions(-) diff --git a/ext/LuxReactantExt.jl b/ext/LuxReactantExt.jl index 8c5c37a379..a062e54622 100644 --- a/ext/LuxReactantExt.jl +++ b/ext/LuxReactantExt.jl @@ -13,8 +13,9 @@ using LuxCore: LuxCore, AbstractExplicitLayer return Reactant.make_tracer(IdDict(), x, (), Reactant.ArrayToConcrete, nothing) end -function Lux.__to_reactant_adaptor(model::AbstractExplicitLayer, input_prototype) - concrete_input = __make_concrete_array(input_prototype) +function Lux.__to_reactant_adaptor( + to::Lux.ToReactantAdaptor{FST}, model::AbstractExplicitLayer) where {FST} + concrete_input = __make_concrete_array(to.input_prototype) cmodel = __make_concrete_array(model) # We generate fake parameters and states to compile the model @@ -24,11 +25,11 @@ function Lux.__to_reactant_adaptor(model::AbstractExplicitLayer, input_prototype st = LuxCore.initialstates(Xoshiro(123), model) cst = __make_concrete_array(st) - csmodel = Lux.StatefulLuxLayer{false}(cmodel, cps, cst) + csmodel = Lux.StatefulLuxLayer{FST}(cmodel, cps, cst) fwd = Reactant.compile((m, x) -> m(x), (csmodel, concrete_input)) - return Lux.ReactantLayer(model, cmodel, fwd, nothing) + return Lux.ReactantLayer{FST}(model, cmodel, fwd, nothing) end function LuxCore.initialparameters(rng::AbstractRNG, layer::Lux.ReactantLayer) @@ -39,10 +40,10 @@ function LuxCore.initialstates(rng::AbstractRNG, layer::Lux.ReactantLayer) return __make_concrete_array(LuxCore.initialstates(rng, layer.layer)) end -function (l::Lux.ReactantLayer)(x, ps, st::NamedTuple) - csmodel = 
Lux.StatefulLuxLayer{false}(l.clayer, ps, st) +function (l::Lux.ReactantLayer{FST})(x, ps, st::NamedTuple) where {FST} + csmodel = Lux.StatefulLuxLayer{FST}(l.clayer, ps, st) y = l.fwd(csmodel, __make_concrete_array(x)) - return y, csmodel.st_any + return y, ifelse(FST, csmodel.st, csmodel.st_any) end end diff --git a/src/layers/extension.jl b/src/layers/extension.jl index c2d82d764e..967ba719f5 100644 --- a/src/layers/extension.jl +++ b/src/layers/extension.jl @@ -251,7 +251,8 @@ end # TODO: Add a ChainRules rrule that calls the `bwd` function, i.e. uses Enzyme for the # gradient computation -@concrete struct ReactantLayer{F, B, L <: AbstractExplicitLayer} <: AbstractExplicitLayer +@concrete struct ReactantLayer{FST, F, B, L <: AbstractExplicitLayer} <: + AbstractExplicitLayer layer::L clayer fwd::F diff --git a/src/transform/reactant.jl b/src/transform/reactant.jl index 4b1ffdfb03..b952b56278 100644 --- a/src/transform/reactant.jl +++ b/src/transform/reactant.jl @@ -1,4 +1,4 @@ -@concrete struct ToReactantAdaptor <: AbstractFromLuxAdaptor +@concrete struct ToReactantAdaptor{FST} <: AbstractFromLuxAdaptor input_prototype end @@ -6,7 +6,7 @@ function Adapt.adapt(to::ToReactantAdaptor, model::AbstractExplicitLayer) if Base.get_extension(@__MODULE__, :LuxReactantExt) === nothing error("`ToReactantAdaptor` requires `LuxReactantExt.jl` to be loaded.") end - return __to_reactant_adaptor(model, to.input_prototype) + return __to_reactant_adaptor(to, model) end function __to_reactant_adaptor end From 7ec54e71c921a082abab7071cdd898442016583a Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 26 May 2024 20:51:39 -0700 Subject: [PATCH 04/15] Try using Enzyme for the backward pass --- ext/LuxReactantExt.jl | 21 ++++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/ext/LuxReactantExt.jl b/ext/LuxReactantExt.jl index a062e54622..2d2df75ef6 100644 --- a/ext/LuxReactantExt.jl +++ b/ext/LuxReactantExt.jl @@ -29,7 +29,26 @@ function Lux.__to_reactant_adaptor( fwd = Reactant.compile((m, x) -> m(x), (csmodel, concrete_input)) - return Lux.ReactantLayer{FST}(model, cmodel, fwd, nothing) + bwd = try + enzyme_grad_fn = (m, x) -> begin + dx = Enzyme.make_zero(x) + dps = Enzyme.make_zero(m.ps) + st = ifelse(FST, m.st, m.st_any) + Enzyme.autodiff( + Enzyme.Reverse, (m, x, ps, st) -> first(LuxCore.apply(m, x, ps, st)), + Enzyme.Duplicated, Enzyme.Const(m), Enzyme.Duplicated(x, dx), + Enzyme.Duplicated(ps, dps), Enzyme.Const(st)) + return (; ps=dps), dx + end + + Reactant.compile(enzyme_grad_fn, (csmodel, concrete_input)) + catch err + @error "Enzyme failed to compile the backward pass. Differentiation will be \ + disabled for this model." exception=err + nothing + end + + return Lux.ReactantLayer{FST}(model, cmodel, fwd, bwd) end function LuxCore.initialparameters(rng::AbstractRNG, layer::Lux.ReactantLayer) From 2414d99d9e1cdefee95bd095a8e692cd4965ae53 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 27 May 2024 12:14:56 -0700 Subject: [PATCH 05/15] Add a force backward compile mode --- ext/LuxReactantExt.jl | 1 + src/transform/reactant.jl | 9 +++++++++ 2 files changed, 10 insertions(+) diff --git a/ext/LuxReactantExt.jl b/ext/LuxReactantExt.jl index 2d2df75ef6..6bc8a47728 100644 --- a/ext/LuxReactantExt.jl +++ b/ext/LuxReactantExt.jl @@ -43,6 +43,7 @@ function Lux.__to_reactant_adaptor( Reactant.compile(enzyme_grad_fn, (csmodel, concrete_input)) catch err + to.force_compile_backward && rethrow(err) @error "Enzyme failed to compile the backward pass. 
Differentiation will be \ disabled for this model." exception=err nothing diff --git a/src/transform/reactant.jl b/src/transform/reactant.jl index b952b56278..e1d7633da4 100644 --- a/src/transform/reactant.jl +++ b/src/transform/reactant.jl @@ -1,5 +1,14 @@ @concrete struct ToReactantAdaptor{FST} <: AbstractFromLuxAdaptor input_prototype + force_compile_backward::Bool +end + +function ToReactantAdaptor{FST}( + input_prototype; force_compile_backward::Bool=false) where {FST} + return ToReactantAdaptor{FST}(input_prototype, force_compile_backward) +end +function ToReactantAdaptor(args...; fixed_state_type::Val=Val(true), kwargs...) + return ToReactantAdaptor{__unwrap_val(fixed_state_type)}(args...; kwargs...) end function Adapt.adapt(to::ToReactantAdaptor, model::AbstractExplicitLayer) From a7bf7a66ef8d21c018663243cf5c9f8141dc7600 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 27 May 2024 12:47:51 -0700 Subject: [PATCH 06/15] Special handling for mixed eltypes in Reactant --- ext/LuxReactantExt.jl | 71 ++++++++++++++++++++++++++++++++------- src/layers/extension.jl | 1 + src/transform/reactant.jl | 8 +++-- src/utils.jl | 2 +- 4 files changed, 66 insertions(+), 16 deletions(-) diff --git a/ext/LuxReactantExt.jl b/ext/LuxReactantExt.jl index 6bc8a47728..f19ef4354b 100644 --- a/ext/LuxReactantExt.jl +++ b/ext/LuxReactantExt.jl @@ -1,10 +1,11 @@ module LuxReactantExt +using Adapt: adapt using ArgCheck: @argcheck using Enzyme: Enzyme using Random: AbstractRNG, Xoshiro using Reactant: Reactant -using Lux: Lux +using Lux: Lux, LuxEltypeAdaptor using LuxCore: LuxCore, AbstractExplicitLayer @inline __make_concrete_array(x::Reactant.ConcreteRArray) = x @@ -13,16 +14,52 @@ using LuxCore: LuxCore, AbstractExplicitLayer return Reactant.make_tracer(IdDict(), x, (), Reactant.ArrayToConcrete, nothing) end +# Reactant doesn't handle mixed eltypes that well, so we will first try to compile it as +# a usual julia function. However, if that fails, we will type cast and try to recompile. +# Note that this is only a one time operation so it doesn't matter if this step is too slow. function Lux.__to_reactant_adaptor( to::Lux.ToReactantAdaptor{FST}, model::AbstractExplicitLayer) where {FST} - concrete_input = __make_concrete_array(to.input_prototype) - cmodel = __make_concrete_array(model) + input_prototype = to.input_prototype + input_eltype = Lux.__recursive_eltype(input_prototype) + ps, st = Lux.setup(Xoshiro(123), model) # We generate fake parameters and states to compile the model + ps_eltype = Lux.__recursive_eltype(ps) + st_eltype = Lux.__recursive_eltype(st) - # We generate fake parameters and states to compile the model - ps = LuxCore.initialparameters(Xoshiro(123), model) - cps = __make_concrete_array(ps) + newT = promote_type(input_eltype, ps_eltype, st_eltype) + eltype_adaptor = LuxEltypeAdaptor{newT}() + + if !to.force_allow_mixed_eltypes && + any(x -> x != newT && x != Union{}, (input_eltype, ps_eltype, st_eltype)) + # Try compiling, but this might fail + try + return Lux.__to_reactant_adaptor(to, model, input_prototype, ps, st, nothing) + catch err + @warn """ + Mixed Eltypes detected. Failure is NOT unexpected. Trying to recompile with a \ + common eltype. + + HINT: To force compiling the mixed eltypes, set \ + `force_allow_mixed_eltypes=true` in the constructor of `ToReactantAdaptor`. 
+ + If compilation succeeds, all inputs to the compiled model will be \ + automatically type casted to the common eltype.\n + """ exception=err input_eltype ps_eltype st_eltype common_eltype=newT + end + + input_prototype = adapt(eltype_adaptor, to.input_prototype) + ps = adapt(eltype_adaptor, ps) + st = adapt(eltype_adaptor, st) + end + + return Lux.__to_reactant_adaptor(to, model, input_prototype, ps, st, eltype_adaptor) +end - st = LuxCore.initialstates(Xoshiro(123), model) +function Lux.__to_reactant_adaptor( + to::Lux.ToReactantAdaptor{FST}, model::AbstractExplicitLayer, + input_prototype, ps, st, eltype_adaptor) where {FST} + concrete_input = __make_concrete_array(input_prototype) + cmodel = __make_concrete_array(model) + cps = __make_concrete_array(ps) cst = __make_concrete_array(st) csmodel = Lux.StatefulLuxLayer{FST}(cmodel, cps, cst) @@ -44,24 +81,34 @@ function Lux.__to_reactant_adaptor( Reactant.compile(enzyme_grad_fn, (csmodel, concrete_input)) catch err to.force_compile_backward && rethrow(err) - @error "Enzyme failed to compile the backward pass. Differentiation will be \ - disabled for this model." exception=err + @error """ + Enzyme failed to compile the backward pass. Differentiation will be disabled for \ + this model. + + HINT: To force compilation of the backward pass, set `force_compile_backward=true` \ + in the constructor of `ToReactantAdaptor`.\n + """ exception=err nothing end - return Lux.ReactantLayer{FST}(model, cmodel, fwd, bwd) + return Lux.ReactantLayer{FST}(model, cmodel, fwd, bwd, eltype_adaptor) end function LuxCore.initialparameters(rng::AbstractRNG, layer::Lux.ReactantLayer) - return __make_concrete_array(LuxCore.initialparameters(rng, layer.layer)) + ps = LuxCore.initialparameters(rng, layer.layer) + layer.eltype_adaptor !== nothing && (ps = adapt(layer.eltype_adaptor, ps)) + return __make_concrete_array(ps) end function LuxCore.initialstates(rng::AbstractRNG, layer::Lux.ReactantLayer) - return __make_concrete_array(LuxCore.initialstates(rng, layer.layer)) + st = LuxCore.initialstates(rng, layer.layer) + layer.eltype_adaptor !== nothing && (st = adapt(layer.eltype_adaptor, st)) + return __make_concrete_array(st) end function (l::Lux.ReactantLayer{FST})(x, ps, st::NamedTuple) where {FST} csmodel = Lux.StatefulLuxLayer{FST}(l.clayer, ps, st) + l.eltype_adaptor !== nothing && (x = adapt(l.eltype_adaptor, x)) y = l.fwd(csmodel, __make_concrete_array(x)) return y, ifelse(FST, csmodel.st, csmodel.st_any) end diff --git a/src/layers/extension.jl b/src/layers/extension.jl index 967ba719f5..4c66d6ca46 100644 --- a/src/layers/extension.jl +++ b/src/layers/extension.jl @@ -257,4 +257,5 @@ end clayer fwd::F bwd::B + eltype_adaptor end diff --git a/src/transform/reactant.jl b/src/transform/reactant.jl index e1d7633da4..ec755850e5 100644 --- a/src/transform/reactant.jl +++ b/src/transform/reactant.jl @@ -1,11 +1,13 @@ @concrete struct ToReactantAdaptor{FST} <: AbstractFromLuxAdaptor input_prototype force_compile_backward::Bool + force_allow_mixed_eltypes::Bool end -function ToReactantAdaptor{FST}( - input_prototype; force_compile_backward::Bool=false) where {FST} - return ToReactantAdaptor{FST}(input_prototype, force_compile_backward) +function ToReactantAdaptor{FST}(input_prototype; force_compile_backward::Bool=false, + force_allow_mixed_eltypes::Bool=false) where {FST} + return ToReactantAdaptor{FST}( + input_prototype, force_compile_backward, force_allow_mixed_eltypes) end function ToReactantAdaptor(args...; fixed_state_type::Val=Val(true), kwargs...) 
return ToReactantAdaptor{__unwrap_val(fixed_state_type)}(args...; kwargs...) diff --git a/src/utils.jl b/src/utils.jl index 9772d724c5..f2ba54a743 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -267,7 +267,7 @@ end # recussive_eltype @inline __recursive_eltype(x::AbstractArray) = eltype(x) -@inline __recursive_eltype(x::Tuple) = promote_type(__recursice_eltype.(x)...) +@inline __recursive_eltype(x::Tuple) = promote_type(__recursive_eltype.(x)...) @inline __recursive_eltype(x::NamedTuple) = promote_type(__recursive_eltype.(values(x))...) @inline __recursive_eltype(::Nothing) = Bool @inline __recursive_eltype(x::Number) = eltype(x) From 7ea236f41eb875c2abdca67d37a9182a74c05796 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 27 May 2024 13:35:56 -0700 Subject: [PATCH 07/15] Add checks to validate matching structure --- ext/LuxReactantExt.jl | 34 ++++++++++++++++++++++++++++++---- src/layers/extension.jl | 3 ++- src/transform/reactant.jl | 1 + 3 files changed, 33 insertions(+), 5 deletions(-) diff --git a/ext/LuxReactantExt.jl b/ext/LuxReactantExt.jl index f19ef4354b..6914c136bb 100644 --- a/ext/LuxReactantExt.jl +++ b/ext/LuxReactantExt.jl @@ -3,6 +3,7 @@ module LuxReactantExt using Adapt: adapt using ArgCheck: @argcheck using Enzyme: Enzyme +using Functors: fmapstructure using Random: AbstractRNG, Xoshiro using Reactant: Reactant using Lux: Lux, LuxEltypeAdaptor @@ -26,7 +27,7 @@ function Lux.__to_reactant_adaptor( st_eltype = Lux.__recursive_eltype(st) newT = promote_type(input_eltype, ps_eltype, st_eltype) - eltype_adaptor = LuxEltypeAdaptor{newT}() + eltype_adaptor = nothing if !to.force_allow_mixed_eltypes && any(x -> x != newT && x != Union{}, (input_eltype, ps_eltype, st_eltype)) @@ -46,6 +47,7 @@ function Lux.__to_reactant_adaptor( """ exception=err input_eltype ps_eltype st_eltype common_eltype=newT end + eltype_adaptor = LuxEltypeAdaptor{newT}() input_prototype = adapt(eltype_adaptor, to.input_prototype) ps = adapt(eltype_adaptor, ps) st = adapt(eltype_adaptor, st) @@ -91,7 +93,8 @@ function Lux.__to_reactant_adaptor( nothing end - return Lux.ReactantLayer{FST}(model, cmodel, fwd, bwd, eltype_adaptor) + return Lux.ReactantLayer{FST, Lux.__recursive_eltype(input_prototype)}( + model, cmodel, fwd, bwd, eltype_adaptor, fmapstructure(Lux.__size, input_prototype)) end function LuxCore.initialparameters(rng::AbstractRNG, layer::Lux.ReactantLayer) @@ -106,11 +109,34 @@ function LuxCore.initialstates(rng::AbstractRNG, layer::Lux.ReactantLayer) return __make_concrete_array(st) end -function (l::Lux.ReactantLayer{FST})(x, ps, st::NamedTuple) where {FST} +function (l::Lux.ReactantLayer{FST, T})(x, ps, st::NamedTuple) where {FST, T} csmodel = Lux.StatefulLuxLayer{FST}(l.clayer, ps, st) l.eltype_adaptor !== nothing && (x = adapt(l.eltype_adaptor, x)) - y = l.fwd(csmodel, __make_concrete_array(x)) + + # XLARuntimeError is not great, so check and terminate early if needed + input_structure = fmapstructure(Lux.__size, x) + if l.input_structure != input_structure + throw(DimensionMismatch(lazy"Input structure mismatch. Expected $(l.input_structure), got $(input_structure).")) + end + + # TODO: For non array inputs this we make the eltype uniform which might not be + # desirable. We should handle those cases with `fmap` + if T != Lux.__recursive_eltype(x) + @warn """ + `Reactant.compile` was called with input eltype $(T) but the current input eltype \ + is $(Lux.__recursive_eltype(x)). This might lead to unexpected behavior. + + We will convert the input to $(T) and continue. 
If you want to avoid this, please \ + recompile the model with the correct input eltype. + """ maxlog=1 + x = adapt(LuxEltypeAdaptor{T}(), x) + end + + y = Lux.__apply_reactant(l, csmodel, x) return y, ifelse(FST, csmodel.st, csmodel.st_any) end +@inline Lux.__apply_reactant(l::Lux.ReactantLayer, csmodel, x) = l.fwd( + csmodel, __make_concrete_array(x)) + end diff --git a/src/layers/extension.jl b/src/layers/extension.jl index 4c66d6ca46..218c54abd9 100644 --- a/src/layers/extension.jl +++ b/src/layers/extension.jl @@ -251,11 +251,12 @@ end # TODO: Add a ChainRules rrule that calls the `bwd` function, i.e. uses Enzyme for the # gradient computation -@concrete struct ReactantLayer{FST, F, B, L <: AbstractExplicitLayer} <: +@concrete struct ReactantLayer{FST, T, F, B, L <: AbstractExplicitLayer} <: AbstractExplicitLayer layer::L clayer fwd::F bwd::B eltype_adaptor + input_structure end diff --git a/src/transform/reactant.jl b/src/transform/reactant.jl index ec755850e5..34e5bb3ddf 100644 --- a/src/transform/reactant.jl +++ b/src/transform/reactant.jl @@ -21,3 +21,4 @@ function Adapt.adapt(to::ToReactantAdaptor, model::AbstractExplicitLayer) end function __to_reactant_adaptor end +function __apply_reactant end From d186bf6eaa1c1533fe7863a0556384b28539bb90 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 27 May 2024 15:54:26 -0700 Subject: [PATCH 08/15] Try converting common parameter types to the compiled type --- ext/LuxReactantExt.jl | 116 +++++++++++++++++++++++++++++++++----- src/Lux.jl | 16 +++--- src/layers/extension.jl | 23 +++++++- src/transform/reactant.jl | 11 ++-- 4 files changed, 138 insertions(+), 28 deletions(-) diff --git a/ext/LuxReactantExt.jl b/ext/LuxReactantExt.jl index 6914c136bb..4cb4839a40 100644 --- a/ext/LuxReactantExt.jl +++ b/ext/LuxReactantExt.jl @@ -3,18 +3,26 @@ module LuxReactantExt using Adapt: adapt using ArgCheck: @argcheck using Enzyme: Enzyme -using Functors: fmapstructure +using Functors: fmapstructure, fmap +using Markdown: @md_str using Random: AbstractRNG, Xoshiro using Reactant: Reactant using Lux: Lux, LuxEltypeAdaptor using LuxCore: LuxCore, AbstractExplicitLayer -@inline __make_concrete_array(x::Reactant.ConcreteRArray) = x -@inline __make_concrete_array(x::AbstractArray) = Reactant.ConcreteRArray(x) @inline function __make_concrete_array(x) return Reactant.make_tracer(IdDict(), x, (), Reactant.ArrayToConcrete, nothing) end +@inline function __try_similar_structure(x::AbstractArray, y::NamedTuple{()}) + length(x) == 0 && return y + throw(DimensionMismatch(lazy"Expected empty array, got $(size(x)).")) +end +@inline function __try_similar_structure(x::AbstractArray, y::AbstractArray) + return parent(x) !== x ? copy(x) : x # unview arrays and such +end +@inline __try_similar_structure(x, y) = fmap(__try_similar_structure, x, y) + # Reactant doesn't handle mixed eltypes that well, so we will first try to compile it as # a usual julia function. However, if that fails, we will type cast and try to recompile. # Note that this is only a one time operation so it doesn't matter if this step is too slow. 
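# For example (hypothetical case: Float32 parameters and states paired with a Float64
# input prototype), `promote_type(Float64, Float32) == Float64`, so on the retry the
# input prototype, the parameters, and the states are all re-cast via
# `LuxEltypeAdaptor{Float64}()` before `Reactant.compile` is attempted again.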
@@ -22,7 +30,8 @@ function Lux.__to_reactant_adaptor( to::Lux.ToReactantAdaptor{FST}, model::AbstractExplicitLayer) where {FST} input_prototype = to.input_prototype input_eltype = Lux.__recursive_eltype(input_prototype) - ps, st = Lux.setup(Xoshiro(123), model) # We generate fake parameters and states to compile the model + ps, st = Lux.setup(LuxCore.replicate(to.rng), model) + ps = to.ps_transform(ps) ps_eltype = Lux.__recursive_eltype(ps) st_eltype = Lux.__recursive_eltype(st) @@ -31,8 +40,7 @@ function Lux.__to_reactant_adaptor( if !to.force_allow_mixed_eltypes && any(x -> x != newT && x != Union{}, (input_eltype, ps_eltype, st_eltype)) - # Try compiling, but this might fail - try + try # Try compiling, but this might fail return Lux.__to_reactant_adaptor(to, model, input_prototype, ps, st, nothing) catch err @warn """ @@ -93,12 +101,18 @@ function Lux.__to_reactant_adaptor( nothing end + # TODO: Add compiled types to the layer type information. That way we can check + # if the model is being executed with the correct types. return Lux.ReactantLayer{FST, Lux.__recursive_eltype(input_prototype)}( - model, cmodel, fwd, bwd, eltype_adaptor, fmapstructure(Lux.__size, input_prototype)) + to, input_prototype, concrete_input, cps, cst, model, cmodel, fwd, + bwd, eltype_adaptor, fmapstructure(Lux.__size, input_prototype)) end +# TODO: Currently we are maintaining 2 copies of the parameters, this is not ideal. +# We can return the parameters and states from the layer itself, since we don't care +# about the values, but just the type. function LuxCore.initialparameters(rng::AbstractRNG, layer::Lux.ReactantLayer) - ps = LuxCore.initialparameters(rng, layer.layer) + ps = layer.adaptor(LuxCore.initialparameters(rng, layer.layer)) layer.eltype_adaptor !== nothing && (ps = adapt(layer.eltype_adaptor, ps)) return __make_concrete_array(ps) end @@ -110,7 +124,6 @@ function LuxCore.initialstates(rng::AbstractRNG, layer::Lux.ReactantLayer) end function (l::Lux.ReactantLayer{FST, T})(x, ps, st::NamedTuple) where {FST, T} - csmodel = Lux.StatefulLuxLayer{FST}(l.clayer, ps, st) l.eltype_adaptor !== nothing && (x = adapt(l.eltype_adaptor, x)) # XLARuntimeError is not great, so check and terminate early if needed @@ -120,7 +133,7 @@ function (l::Lux.ReactantLayer{FST, T})(x, ps, st::NamedTuple) where {FST, T} end # TODO: For non array inputs this we make the eltype uniform which might not be - # desirable. We should handle those cases with `fmap` + # desirable. We should handle those cases with `fmap` if T != Lux.__recursive_eltype(x) @warn """ `Reactant.compile` was called with input eltype $(T) but the current input eltype \ @@ -132,11 +145,86 @@ function (l::Lux.ReactantLayer{FST, T})(x, ps, st::NamedTuple) where {FST, T} x = adapt(LuxEltypeAdaptor{T}(), x) end - y = Lux.__apply_reactant(l, csmodel, x) - return y, ifelse(FST, csmodel.st, csmodel.st_any) + return Lux.__apply_reactant(l, x, ps, st) +end + +# This is the ideal case where all the types match correctly. +# Input Type mispatches should not happen here, they should be handled before this function +# is called. +# If `st` mismatch happens then user really messed something up. can't do anything about it. 
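# Dispatch summary for the methods below: an exact match on the compiled input,
# parameter, and state types hits the compiled forward pass directly; a parameter-type
# mismatch triggers an automatic restructure / re-cast attempt; any other mismatch
# falls through to `__graceful_type_mismatch_error` with a detailed report.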
+@inline function Lux.__apply_reactant( + l::Lux.ReactantLayer{FST, T, inType}, x::inType, ps, st) where {FST, T, inType} + return Lux.__apply_reactant(l, __make_concrete_array(x), ps, st) +end + +@inline function Lux.__apply_reactant( + l::Lux.ReactantLayer{FST, T, inType, inCType, psType, stType}, x::inCType, + ps::psType, st::stType) where {FST, T, inType, inCType, psType, stType} + csmodel = Lux.StatefulLuxLayer{FST}(l.clayer, ps, st) + return Lux.__apply_reactant(l, csmodel, x), ifelse(FST, csmodel.st, csmodel.st_any) +end + +# Parameter type mismatch. This might be too common so try to handle it gracefully. +@inline function Lux.__apply_reactant( + l::Lux.ReactantLayer{FST, T, inType, inCType, psType, stType}, x::inCType, + ps::psType2, st::stType) where {FST, T, inType, inCType, psType, psType2, stType} + ps = __try_similar_structure(Lux.__named_tuple(ps), l.concrete_ps) + ps = l.adaptor(ps) + l.eltype_adaptor !== nothing && (ps = adapt(l.eltype_adaptor, ps)) + ps = __make_concrete_array(ps) + + if typeof(ps) != psType + @warn "Automatic type conversion failed for `ps`." original_ps_type=psType2 + __graceful_type_mismatch_error(l, x, ps, st) + end + + return Lux.__apply_reactant(l, __make_concrete_array(x), ps, st) end -@inline Lux.__apply_reactant(l::Lux.ReactantLayer, csmodel, x) = l.fwd( - csmodel, __make_concrete_array(x)) +Lux.__apply_reactant(l, x, ps, st) = __graceful_type_mismatch_error(l, x, ps, st) + +@inline Lux.__apply_reactant(l::Lux.ReactantLayer, csmodel, x) = l.fwd(csmodel, x) + +# Don't inline, else types don't get displayed in the stack trace +function __graceful_type_mismatch_error( + ::Lux.ReactantLayer{FST, T, inType, inCType, psType, stType}, + x, ps, st) where {FST, T, inType, inCType, psType, stType} + #! format: off + input_type_mismatch_str = typeof(x) == inType || typeof(x) == inCType ? """ + 1. Input Types Matched. + """ : """ + 1. Input Type: $(typeof(x)). + Compiled Input Type: $(inType). + Compiled Concrete Input Type: $(inCType). + """ + #! format: on + + ps_type_mismatch_str = typeof(ps) == psType ? """ + 2. Parameter Types Matched. + """ : """ + 2. Parameter Type: $(typeof(ps)). + Compiled Parameter Type: $(psType). + """ + + st_type_mismatch_str = typeof(st) == stType ? """ + 3. State Types Matched. + """ : """ + 3. State Type: $(typeof(st)). + Compiled State Type: $(stType). + """ + + throw(ArgumentError(""" + Model compiled types and input types don't match. We tried our best to convert the \ + types to the right ones, but we failed. Ideally the argument types should not be \ + modified after compilation. + + 1. Recompile the model with the correct input types. + 2. Open an issue on the Lux.jl repository, to check if we can ease out the automatic \ + type conversion. 
+ + List of Type Mismatches: + + $(input_type_mismatch_str) $(ps_type_mismatch_str) $(st_type_mismatch_str)""")) +end end diff --git a/src/Lux.jl b/src/Lux.jl index 89cb5c21c8..22616d898b 100644 --- a/src/Lux.jl +++ b/src/Lux.jl @@ -16,7 +16,7 @@ using PrecompileTools: @recompile_invalidations using Markdown: @doc_str using OhMyThreads: tmapreduce using Preferences: @load_preference - using Random: Random, AbstractRNG + using Random: Random, AbstractRNG, Xoshiro using Reexport: @reexport using LuxCore, LuxLib, LuxDeviceUtils, WeightInitializers @@ -48,6 +48,12 @@ const DISABLE_AUTOMATIC_NESTED_AD_SWITCH = @load_preference("DisableAutomaticNes # Utilities include("utils.jl") +# Transform to and from other frameworks +include("transform/types.jl") +include("transform/flux.jl") +include("transform/simplechains.jl") +include("transform/reactant.jl") + # Layer Implementations include("layers/basic.jl") include("layers/containers.jl") @@ -72,12 +78,6 @@ include("helpers/compact.jl") include("helpers/autodiff.jl") include("helpers/nested_ad.jl") -# Transform to and from other frameworks -include("transform/types.jl") -include("transform/flux.jl") -include("transform/simplechains.jl") -include("transform/reactant.jl") - # Distributed Training include("distributed/backend.jl") include("distributed/public_api.jl") @@ -111,7 +111,7 @@ export f16, f32, f64 export transform export FromFluxAdaptor, FluxLayer export ToSimpleChainsAdaptor, SimpleChainsLayer -export ToReactantAdaptor +export ToReactantAdaptor, ReactantLayer export DynamicExpressionsLayer export MPIBackend, NCCLBackend, DistributedUtils diff --git a/src/layers/extension.jl b/src/layers/extension.jl index 218c54abd9..1735bdd006 100644 --- a/src/layers/extension.jl +++ b/src/layers/extension.jl @@ -251,8 +251,13 @@ end # TODO: Add a ChainRules rrule that calls the `bwd` function, i.e. uses Enzyme for the # gradient computation -@concrete struct ReactantLayer{FST, T, F, B, L <: AbstractExplicitLayer} <: - AbstractExplicitLayer +@concrete struct ReactantLayer{FST, T, inType, inCType, psType, stType, F, B, + L <: AbstractExplicitLayer, AD <: ToReactantAdaptor} <: AbstractExplicitLayer + adaptor::AD + input_prototype::inType + concrete_input_prototype::inCType + concrete_ps::psType + concrete_st::stType layer::L clayer fwd::F @@ -260,3 +265,17 @@ end eltype_adaptor input_structure end + +function Base.show(io::IO, s::ReactantLayer{ST}) where {ST} + if get(io, :typeinfo, nothing) === nothing # e.g. top level in REPL + print(io, "ReactantLayer{$ST}(\n") + _big_show(io, s.layer, 4) + elseif !get(io, :compact, false) # e.g. 
printed inside a Vector, but not a Matrix + print(io, "ReactantLayer{$ST}(") + _layer_show(io, s.layer) + else + print(io, "ReactantLayer{$ST}(") + show(io, s.layer) + end + print(io, ")") +end diff --git a/src/transform/reactant.jl b/src/transform/reactant.jl index 34e5bb3ddf..fce9ec63e4 100644 --- a/src/transform/reactant.jl +++ b/src/transform/reactant.jl @@ -1,13 +1,16 @@ -@concrete struct ToReactantAdaptor{FST} <: AbstractFromLuxAdaptor +@concrete struct ToReactantAdaptor{FST, R <: AbstractRNG} <: AbstractFromLuxAdaptor input_prototype + ps_transform + rng::R force_compile_backward::Bool force_allow_mixed_eltypes::Bool end -function ToReactantAdaptor{FST}(input_prototype; force_compile_backward::Bool=false, +function ToReactantAdaptor{FST}(input_prototype; rng=Xoshiro(123), ps_transform=identity, + force_compile_backward::Bool=false, force_allow_mixed_eltypes::Bool=false) where {FST} - return ToReactantAdaptor{FST}( - input_prototype, force_compile_backward, force_allow_mixed_eltypes) + return ToReactantAdaptor{FST}(input_prototype, ps_transform, rng, + force_compile_backward, force_allow_mixed_eltypes) end function ToReactantAdaptor(args...; fixed_state_type::Val=Val(true), kwargs...) return ToReactantAdaptor{__unwrap_val(fixed_state_type)}(args...; kwargs...) From c63241892d48c17de73c1dd3ff4f80a4997836f1 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 27 May 2024 16:02:47 -0700 Subject: [PATCH 09/15] Accidentally used Boxed value --- ext/LuxReactantExt.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ext/LuxReactantExt.jl b/ext/LuxReactantExt.jl index 4cb4839a40..8d24879b80 100644 --- a/ext/LuxReactantExt.jl +++ b/ext/LuxReactantExt.jl @@ -83,8 +83,8 @@ function Lux.__to_reactant_adaptor( st = ifelse(FST, m.st, m.st_any) Enzyme.autodiff( Enzyme.Reverse, (m, x, ps, st) -> first(LuxCore.apply(m, x, ps, st)), - Enzyme.Duplicated, Enzyme.Const(m), Enzyme.Duplicated(x, dx), - Enzyme.Duplicated(ps, dps), Enzyme.Const(st)) + Enzyme.Duplicated, Enzyme.Const(m.model), Enzyme.Duplicated(x, dx), + Enzyme.Duplicated(m.ps, dps), Enzyme.Const(st)) return (; ps=dps), dx end From f5a3e58052b8bec3efb2f117f64a22cdea4b2d94 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 27 May 2024 16:31:53 -0700 Subject: [PATCH 10/15] Implement a working VJP function --- ext/LuxReactantExt.jl | 44 ++++++++++++++++++++++++----------------- src/layers/extension.jl | 7 ++++--- 2 files changed, 30 insertions(+), 21 deletions(-) diff --git a/ext/LuxReactantExt.jl b/ext/LuxReactantExt.jl index 8d24879b80..26a3d38cc1 100644 --- a/ext/LuxReactantExt.jl +++ b/ext/LuxReactantExt.jl @@ -67,6 +67,9 @@ end function Lux.__to_reactant_adaptor( to::Lux.ToReactantAdaptor{FST}, model::AbstractExplicitLayer, input_prototype, ps, st, eltype_adaptor) where {FST} + output = first(model(input_prototype, ps, st)) + concrete_output = __make_concrete_array(output) + concrete_input = __make_concrete_array(input_prototype) cmodel = __make_concrete_array(model) cps = __make_concrete_array(ps) @@ -74,21 +77,28 @@ function Lux.__to_reactant_adaptor( csmodel = Lux.StatefulLuxLayer{FST}(cmodel, cps, cst) - fwd = Reactant.compile((m, x) -> m(x), (csmodel, concrete_input)) - - bwd = try - enzyme_grad_fn = (m, x) -> begin - dx = Enzyme.make_zero(x) - dps = Enzyme.make_zero(m.ps) - st = ifelse(FST, m.st, m.st_any) - Enzyme.autodiff( - Enzyme.Reverse, (m, x, ps, st) -> first(LuxCore.apply(m, x, ps, st)), - Enzyme.Duplicated, Enzyme.Const(m.model), Enzyme.Duplicated(x, dx), - Enzyme.Duplicated(m.ps, dps), Enzyme.Const(st)) - 
return (; ps=dps), dx + fwd_fn = Reactant.compile((m, x) -> m(x), (csmodel, concrete_input)) + + function enzyme_vjp_fn(m, x, y, dy) + dx = Enzyme.make_zero(x) + dps = Enzyme.make_zero(m.ps) + st_m = ifelse(FST, m.st, m.st_any) + + function wrapper_fn!(y, model, x, ps, st) + copyto!(y, first(LuxCore.apply(model, x, ps, st))) + return nothing end - Reactant.compile(enzyme_grad_fn, (csmodel, concrete_input)) + Enzyme.autodiff(Enzyme.Reverse, wrapper_fn!, Enzyme.Const, Enzyme.Duplicated(y, dy), + Enzyme.Const(m.model), Enzyme.Duplicated(x, dx), + Enzyme.Duplicated(m.ps, dps), Enzyme.Const(st_m)) + return dx, dps + end + + vjp_fn = try + concrete_output2 = __make_concrete_array(deepcopy(output)) + Reactant.compile( + enzyme_vjp_fn, (csmodel, concrete_input, concrete_output, concrete_output2)) catch err to.force_compile_backward && rethrow(err) @error """ @@ -101,11 +111,9 @@ function Lux.__to_reactant_adaptor( nothing end - # TODO: Add compiled types to the layer type information. That way we can check - # if the model is being executed with the correct types. return Lux.ReactantLayer{FST, Lux.__recursive_eltype(input_prototype)}( - to, input_prototype, concrete_input, cps, cst, model, cmodel, fwd, - bwd, eltype_adaptor, fmapstructure(Lux.__size, input_prototype)) + to, input_prototype, concrete_input, cps, cst, model, cmodel, fwd_fn, + vjp_fn, eltype_adaptor, fmapstructure(Lux.__size, input_prototype)) end # TODO: Currently we are maintaining 2 copies of the parameters, this is not ideal. @@ -183,7 +191,7 @@ end Lux.__apply_reactant(l, x, ps, st) = __graceful_type_mismatch_error(l, x, ps, st) -@inline Lux.__apply_reactant(l::Lux.ReactantLayer, csmodel, x) = l.fwd(csmodel, x) +@inline Lux.__apply_reactant(l::Lux.ReactantLayer, csmodel, x) = l.fwd_fn(csmodel, x) # Don't inline, else types don't get displayed in the stack trace function __graceful_type_mismatch_error( diff --git a/src/layers/extension.jl b/src/layers/extension.jl index 1735bdd006..791671b30d 100644 --- a/src/layers/extension.jl +++ b/src/layers/extension.jl @@ -251,7 +251,8 @@ end # TODO: Add a ChainRules rrule that calls the `bwd` function, i.e. 
uses Enzyme for the # gradient computation -@concrete struct ReactantLayer{FST, T, inType, inCType, psType, stType, F, B, +# TODO: Inference won't work OOTB, we will have to compile that separately +@concrete struct ReactantLayer{FST, T, inType, inCType, psType, stType, L <: AbstractExplicitLayer, AD <: ToReactantAdaptor} <: AbstractExplicitLayer adaptor::AD input_prototype::inType @@ -260,8 +261,8 @@ end concrete_st::stType layer::L clayer - fwd::F - bwd::B + fwd_fn + vjp_fn eltype_adaptor input_structure end From 55d47679418781916f8fab25f9393b2be03d6ef1 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 29 May 2024 17:51:35 -0700 Subject: [PATCH 11/15] Move things around --- ext/LuxReactantExt/LuxReactantExt.jl | 20 +++++++++++++++++++ .../layer.jl} | 13 ------------ ext/LuxReactantExt/train.jl | 0 src/transform/reactant.jl | 11 ++++++++++ 4 files changed, 31 insertions(+), 13 deletions(-) create mode 100644 ext/LuxReactantExt/LuxReactantExt.jl rename ext/{LuxReactantExt.jl => LuxReactantExt/layer.jl} (96%) create mode 100644 ext/LuxReactantExt/train.jl diff --git a/ext/LuxReactantExt/LuxReactantExt.jl b/ext/LuxReactantExt/LuxReactantExt.jl new file mode 100644 index 0000000000..3aa93eb4bd --- /dev/null +++ b/ext/LuxReactantExt/LuxReactantExt.jl @@ -0,0 +1,20 @@ +module LuxReactantExt + +using Adapt: adapt +using ArgCheck: @argcheck +using Enzyme: Enzyme +using Functors: fmapstructure, fmap +using Markdown: @md_str +using Random: AbstractRNG, Xoshiro +using Reactant: Reactant +using Lux: Lux, LuxEltypeAdaptor +using LuxCore: LuxCore, AbstractExplicitLayer + +# compile just the model. This allows us to run part of the model in vanilla LLVM. Needed +# for cases where we can't currently compile via Reactant or where XLA is not great +# for the model. +include("layer.jl") + +include("train.jl") + +end diff --git a/ext/LuxReactantExt.jl b/ext/LuxReactantExt/layer.jl similarity index 96% rename from ext/LuxReactantExt.jl rename to ext/LuxReactantExt/layer.jl index 26a3d38cc1..6941a33824 100644 --- a/ext/LuxReactantExt.jl +++ b/ext/LuxReactantExt/layer.jl @@ -1,14 +1,3 @@ -module LuxReactantExt - -using Adapt: adapt -using ArgCheck: @argcheck -using Enzyme: Enzyme -using Functors: fmapstructure, fmap -using Markdown: @md_str -using Random: AbstractRNG, Xoshiro -using Reactant: Reactant -using Lux: Lux, LuxEltypeAdaptor -using LuxCore: LuxCore, AbstractExplicitLayer @inline function __make_concrete_array(x) return Reactant.make_tracer(IdDict(), x, (), Reactant.ArrayToConcrete, nothing) @@ -234,5 +223,3 @@ function __graceful_type_mismatch_error( $(input_type_mismatch_str) $(ps_type_mismatch_str) $(st_type_mismatch_str)""")) end - -end diff --git a/ext/LuxReactantExt/train.jl b/ext/LuxReactantExt/train.jl new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/transform/reactant.jl b/src/transform/reactant.jl index fce9ec63e4..88a500db6a 100644 --- a/src/transform/reactant.jl +++ b/src/transform/reactant.jl @@ -25,3 +25,14 @@ end function __to_reactant_adaptor end function __apply_reactant end + +""" + AutoReactant() + +Compile the training loop to MLIR/XLA via `Reactant.jl`. + +This has been added to Lux very recently and is under-going rapid development. Currently, +only a limited subset of Lux models can be compiled via `Reactant.jl`. If you encounter any +issues, please report them on the `Lux.jl` or `Reactant.jl` GitHub repository. 
+""" +struct AutoReactant end From db3949e444b3b4b45a1ac25f3a504f9e77d65ccb Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 29 May 2024 18:03:52 -0700 Subject: [PATCH 12/15] Training step function --- Project.toml | 3 +- ext/LuxReactantExt/LuxReactantExt.jl | 1 + src/Lux.jl | 1 + src/contrib/training.jl | 61 +++++++++++++++++++++------- 4 files changed, 51 insertions(+), 15 deletions(-) diff --git a/Project.toml b/Project.toml index 0b74e662a3..5608085f4f 100644 --- a/Project.toml +++ b/Project.toml @@ -137,6 +137,7 @@ MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195" NCCL = "3fe64909-d7a1-4096-9b7d-7a0f12cf0f6b" Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" +Reactant = "3c362404-f566-11ee-1572-e11a4b42c853" ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" SimpleChains = "de6bee2f-e2f4-4ec7-b6ed-219cc6f6e9e5" @@ -147,4 +148,4 @@ Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["Aqua", "ComponentArrays", "DynamicExpressions", "Enzyme", "ExplicitImports", "FiniteDifferences", "Flux", "ForwardDiff", "Logging", "LuxAMDGPU", "LuxCUDA", "LuxTestUtils", "MLUtils", "MPI", "NCCL", "Optimisers", "Pkg", "ReTestItems", "ReverseDiff", "SimpleChains", "StableRNGs", "Statistics", "Test", "Tracker", "Zygote"] +test = ["Aqua", "ComponentArrays", "DynamicExpressions", "Enzyme", "ExplicitImports", "FiniteDifferences", "Flux", "ForwardDiff", "Logging", "LuxAMDGPU", "LuxCUDA", "LuxTestUtils", "MLUtils", "MPI", "NCCL", "Optimisers", "Pkg", "Reactant", "ReTestItems", "ReverseDiff", "SimpleChains", "StableRNGs", "Statistics", "Test", "Tracker", "Zygote"] diff --git a/ext/LuxReactantExt/LuxReactantExt.jl b/ext/LuxReactantExt/LuxReactantExt.jl index 3aa93eb4bd..321da3a57e 100644 --- a/ext/LuxReactantExt/LuxReactantExt.jl +++ b/ext/LuxReactantExt/LuxReactantExt.jl @@ -15,6 +15,7 @@ using LuxCore: LuxCore, AbstractExplicitLayer # for the model. include("layer.jl") +# compile the entire training loop include("train.jl") end diff --git a/src/Lux.jl b/src/Lux.jl index 22616d898b..7e939b342d 100644 --- a/src/Lux.jl +++ b/src/Lux.jl @@ -105,6 +105,7 @@ export @compact, CompactLuxLayer export jacobian_vector_product, vector_jacobian_product export batched_jacobian export AutoForwardDiff, AutoReverseDiff, AutoTracker, AutoZygote +export AutoReactant export f16, f32, f64 diff --git a/src/contrib/training.jl b/src/contrib/training.jl index c0f285992c..ad5497698c 100644 --- a/src/contrib/training.jl +++ b/src/contrib/training.jl @@ -36,21 +36,24 @@ function Base.show(io::IO, ts::TrainState) print(io, "\n objective_function: ", nameof(typeof(ts.objective_function))) end -""" - apply_gradients(ts::TrainState, grads) - -Update the parameters stored in `ts` using the gradients `grads`. - +const APPLY_GRAD_DOCSTRING = """ ## Arguments - `ts`: [`TrainState`](@ref) object. - `grads`: Gradients of the loss function wrt `ts.params`. - - `update_inplace`: Whether to update the parameters inplace or not. ## Returns Updated [`TrainState`](@ref) object. """ + +""" + apply_gradients(ts::TrainState, grads) + +Update the parameters stored in `ts` using the gradients `grads`. + +$(APPLY_GRAD_DOCSTRING) +""" function apply_gradients end """ @@ -59,14 +62,7 @@ function apply_gradients end Update the parameters stored in `ts` using the gradients `grads`. This is an inplace version of [`apply_gradients`](@ref). -## Arguments - - - `ts`: [`TrainState`](@ref) object. 
- - `grads`: Gradients of the loss function wrt `ts.params`. - -## Returns - -Updated [`TrainState`](@ref) object. +$(APPLY_GRAD_DOCSTRING) """ function apply_gradients! end @@ -146,3 +142,40 @@ end return wrapped_objective_function, st_updated, stats end + +""" + single_train_step!(backend, obj_fn::F, data, ts::TrainState) + +Perform a single training step. Computes the gradients using [`compute_gradients`](@ref) and +updates the parameters using [`apply_gradients!`](@ref). All backends supported via +[`compute_gradients`](@ref) are supported here. + +## Additional Backends + + - [`AutoReactant`](@ref): Compiles the training loop to MLIR/XLA via `Reactant.jl`. +""" +function single_train_step! end + +""" + single_train_step(backend, obj_fn::F, data, ts::TrainState) + +Perform a single training step. Computes the gradients using [`compute_gradients`](@ref) and +updates the parameters using [`apply_gradients`](@ref). All backends supported via +[`compute_gradients`](@ref) are supported here. + +## Additional Backends + + - [`AutoReactant`](@ref): Compiles the training loop to MLIR/XLA via `Reactant.jl`. + +In most cases you should use [`single_train_step!`](@ref) instead of this function. +""" +function single_train_step end + +for inplace in ("!", "") + step, apply_fn = Symbol(:single_train_step, inplace), Symbol(:apply_gradients, inplace) + @eval function $step(backend, obj_fn::F, data, ts::TrainState) where {F} + grads, loss, stats, ts = compute_gradients(backend, obj_fn, data, ts) + $apply_fn(ts, grads) + return grads, loss, stats, ts + end +end From 7ec09d7fe9f3053b1418238ab83fdc64bb09556a Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 29 May 2024 19:59:26 -0700 Subject: [PATCH 13/15] Start impl for Reactant training loop --- ext/LuxReactantExt/LuxReactantExt.jl | 8 ++-- ext/LuxReactantExt/layer.jl | 14 ------- ext/LuxReactantExt/train.jl | 55 ++++++++++++++++++++++++++++ ext/LuxReactantExt/utils.jl | 12 ++++++ src/Lux.jl | 5 ++- src/contrib/training.jl | 28 +++++++++++--- 6 files changed, 97 insertions(+), 25 deletions(-) create mode 100644 ext/LuxReactantExt/utils.jl diff --git a/ext/LuxReactantExt/LuxReactantExt.jl b/ext/LuxReactantExt/LuxReactantExt.jl index 321da3a57e..dce6e8f849 100644 --- a/ext/LuxReactantExt/LuxReactantExt.jl +++ b/ext/LuxReactantExt/LuxReactantExt.jl @@ -2,14 +2,16 @@ module LuxReactantExt using Adapt: adapt using ArgCheck: @argcheck -using Enzyme: Enzyme +using ConcreteStructs: @concrete +using Enzyme: Enzyme, Active, Const, Duplicated using Functors: fmapstructure, fmap -using Markdown: @md_str using Random: AbstractRNG, Xoshiro using Reactant: Reactant -using Lux: Lux, LuxEltypeAdaptor +using Lux: Lux, LuxEltypeAdaptor, AutoReactant using LuxCore: LuxCore, AbstractExplicitLayer +include("utils.jl") + # compile just the model. This allows us to run part of the model in vanilla LLVM. Needed # for cases where we can't currently compile via Reactant or where XLA is not great # for the model. 
diff --git a/ext/LuxReactantExt/layer.jl b/ext/LuxReactantExt/layer.jl index 6941a33824..b506be96f4 100644 --- a/ext/LuxReactantExt/layer.jl +++ b/ext/LuxReactantExt/layer.jl @@ -1,17 +1,3 @@ - -@inline function __make_concrete_array(x) - return Reactant.make_tracer(IdDict(), x, (), Reactant.ArrayToConcrete, nothing) -end - -@inline function __try_similar_structure(x::AbstractArray, y::NamedTuple{()}) - length(x) == 0 && return y - throw(DimensionMismatch(lazy"Expected empty array, got $(size(x)).")) -end -@inline function __try_similar_structure(x::AbstractArray, y::AbstractArray) - return parent(x) !== x ? copy(x) : x # unview arrays and such -end -@inline __try_similar_structure(x, y) = fmap(__try_similar_structure, x, y) - # Reactant doesn't handle mixed eltypes that well, so we will first try to compile it as # a usual julia function. However, if that fails, we will type cast and try to recompile. # Note that this is only a one time operation so it doesn't matter if this step is too slow. diff --git a/ext/LuxReactantExt/train.jl b/ext/LuxReactantExt/train.jl index e69de29bb2..b7c19e270b 100644 --- a/ext/LuxReactantExt/train.jl +++ b/ext/LuxReactantExt/train.jl @@ -0,0 +1,55 @@ +# TODO: For the iip versions as well. Metaprogram that part + +# Case III: Nothing is cached. First call to `single_train_step` +function Lux.Experimental.single_train_step( + ad::AutoReactant, obj_fn::F, data, ts::Lux.Experimental.TrainState) where {F} + # ps = ts.parameters + dps = Lux.__recursive_make_zero(ts.parameters) + # st = ts.states + # model = ts.model + + data = __make_concrete_array(data) + model = __make_concrete_array(ts.model) + dps = __make_concrete_array(dps) + ps = __make_concrete_array(ts.parameters) + st = __make_concrete_array(ts.states) + + # @show + + # function reverse_fn_wrapper(obj_fn, model, ps, dps, st, data) + obj_fn_wrapper, st_updated, stats = Lux.Experimental.__wrap_objective_function( + obj_fn, st) + # st_, stats = nothing, (;) + # @show __update_fn_wrapper(obj_fn_wrapper, model, ps, dps, st, data) + + # function obj_fn_wrapper(obj_fn, model, ps, st, data) # Intentionally boxing + # y, st_, stats = obj_fn(model, ps, st, data) + # return y + # end + + # @show obj_fn_wrapper # (obj_fn, model, ps, st, data) + + # _, loss = Enzyme.autodiff(Enzyme.ReverseWithPrimal, obj_fn_wrapper, Active, + # Const(model), Duplicated(ps, dps), Const(st), Const(data)) + # loss = obj_fn_wrapper(obj_fn, model, ps, st, data) + + # return loss, st_new, stats # FIXME: Return the correct things + # return loss + # end + + # @show reverse_fn_wrapper # (obj_fn, model, ps, dps, st, data) + + # @show reverse_fn_wrapper(obj_fn, model, ts.parameters, + # Lux.__recursive_make_zero(ts.parameters), ts.states, data) + + compiled_fn = Reactant.compile(__update_fn_wrapper, (obj_fn, model, ps, dps, st, data)) + + # return compiled_fn, (obj_fn, model, ps, dps, st, data) +end + +function __update_fn_wrapper(obj_fn, model, ps, dps, st, data) + _, loss = Enzyme.autodiff(Enzyme.ReverseWithPrimal, obj_fn, Active, Const(model), + Duplicated(ps, dps), Const(st), Const(data)) + # Lux.Experimental.apply_gradients() + return loss +end diff --git a/ext/LuxReactantExt/utils.jl b/ext/LuxReactantExt/utils.jl new file mode 100644 index 0000000000..d595a42c34 --- /dev/null +++ b/ext/LuxReactantExt/utils.jl @@ -0,0 +1,12 @@ +@inline function __make_concrete_array(x) + return Reactant.make_tracer(IdDict(), x, (), Reactant.ArrayToConcrete, nothing) +end + +@inline function __try_similar_structure(x::AbstractArray, y::NamedTuple{()}) + 
length(x) == 0 && return y + throw(DimensionMismatch(lazy"Expected empty array, got $(size(x)).")) +end +@inline function __try_similar_structure(x::AbstractArray, y::AbstractArray) + return parent(x) !== x ? copy(x) : x # unview arrays and such +end +@inline __try_similar_structure(x, y) = fmap(__try_similar_structure, x, y) diff --git a/src/Lux.jl b/src/Lux.jl index 7e939b342d..e2cb135573 100644 --- a/src/Lux.jl +++ b/src/Lux.jl @@ -3,7 +3,8 @@ module Lux using PrecompileTools: @recompile_invalidations @recompile_invalidations begin - using ADTypes: AbstractADType, AutoForwardDiff, AutoReverseDiff, AutoTracker, AutoZygote + using ADTypes: AbstractADType, AutoEnzyme, AutoForwardDiff, AutoReverseDiff, + AutoTracker, AutoZygote using Adapt: Adapt, adapt using ArgCheck: @argcheck using ArrayInterface: ArrayInterface @@ -104,7 +105,7 @@ export @compact, CompactLuxLayer export jacobian_vector_product, vector_jacobian_product export batched_jacobian -export AutoForwardDiff, AutoReverseDiff, AutoTracker, AutoZygote +export AutoEnzyme, AutoForwardDiff, AutoReverseDiff, AutoTracker, AutoZygote export AutoReactant export f16, f32, f64 diff --git a/src/contrib/training.jl b/src/contrib/training.jl index ad5497698c..95caa68900 100644 --- a/src/contrib/training.jl +++ b/src/contrib/training.jl @@ -13,6 +13,11 @@ Internal fields: - `cache`: Cached values. Implementations are free to use this for whatever they want. - `objective_function`: Objective function might be cached. + +!!! warning + + Constructing this object directly shouldn't be considered a stable API. Use the + version with the Optimisers API. """ @concrete struct TrainState{C, F} cache::C @@ -27,8 +32,8 @@ end function Base.show(io::IO, ts::TrainState) println(io, "TrainState") println(io, " model: ", ts.model) - println(io, " parameters: ", Lux.parameterlength(ts.parameters)) - println(io, " states: ", Lux.statelength(ts.states)) + println(io, " # of parameters: ", Lux.parameterlength(ts.parameters)) + println(io, " # of states: ", Lux.statelength(ts.states)) println(io, " optimizer_state: ", ts.optimizer_state) print(io, " step: ", ts.step) ts.cache !== nothing && print(io, "\n cache: ", nameof(typeof(ts.cache))) @@ -152,7 +157,14 @@ updates the parameters using [`apply_gradients!`](@ref). All backends supported ## Additional Backends - - [`AutoReactant`](@ref): Compiles the training loop to MLIR/XLA via `Reactant.jl`. + - [`AutoReactant`](@ref): Compiles the training loop to MLIR/XLA via `Reactant.jl`. + +## Return + +Returned values are the same as [`compute_gradients`](@ref). Note that despite the `!`, +only the parameters in `ts` are updated inplace. Users should be using the returned `ts` +object for further training steps, else there is no caching and performance will be +suboptimal (and absolutely terrible for backends like `AutoReactant`). """ function single_train_step! end @@ -165,17 +177,21 @@ updates the parameters using [`apply_gradients`](@ref). All backends supported v ## Additional Backends - - [`AutoReactant`](@ref): Compiles the training loop to MLIR/XLA via `Reactant.jl`. + - [`AutoReactant`](@ref): Compiles the training loop to MLIR/XLA via `Reactant.jl`. In most cases you should use [`single_train_step!`](@ref) instead of this function. + +## Return + +Returned values are the same as [`compute_gradients`](@ref). 
""" function single_train_step end for inplace in ("!", "") step, apply_fn = Symbol(:single_train_step, inplace), Symbol(:apply_gradients, inplace) - @eval function $step(backend, obj_fn::F, data, ts::TrainState) where {F} + @eval function $(step)(backend, obj_fn::F, data, ts::TrainState) where {F} grads, loss, stats, ts = compute_gradients(backend, obj_fn, data, ts) - $apply_fn(ts, grads) + ts = $apply_fn(ts, grads) return grads, loss, stats, ts end end From 916e685a713ef58e88358785054480c5b58335ac Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 31 May 2024 22:53:29 -0700 Subject: [PATCH 14/15] More comprehensive compilation options --- ext/LuxReactantExt/layer.jl | 171 ++++++++++++++++++++++-------------- ext/LuxReactantExt/train.jl | 10 +-- ext/LuxReactantExt/utils.jl | 10 ++- src/layers/extension.jl | 16 ++-- src/transform/reactant.jl | 31 ++++++- 5 files changed, 156 insertions(+), 82 deletions(-) diff --git a/ext/LuxReactantExt/layer.jl b/ext/LuxReactantExt/layer.jl index b506be96f4..b45bf36acb 100644 --- a/ext/LuxReactantExt/layer.jl +++ b/ext/LuxReactantExt/layer.jl @@ -43,52 +43,67 @@ function Lux.__to_reactant_adaptor( to::Lux.ToReactantAdaptor{FST}, model::AbstractExplicitLayer, input_prototype, ps, st, eltype_adaptor) where {FST} output = first(model(input_prototype, ps, st)) - concrete_output = __make_concrete_array(output) + concrete_output = Lux.__make_reactant_array(output) - concrete_input = __make_concrete_array(input_prototype) - cmodel = __make_concrete_array(model) - cps = __make_concrete_array(ps) - cst = __make_concrete_array(st) + concrete_input = Lux.__make_reactant_array(input_prototype) + cps = Lux.__make_reactant_array(ps) + cst = Lux.__make_reactant_array(st) - csmodel = Lux.StatefulLuxLayer{FST}(cmodel, cps, cst) + smodel = Lux.StatefulLuxLayer{FST}(model, cps, cst) + fwd_fn = Reactant.compile((m, x) -> m(x), (smodel, concrete_input)) - fwd_fn = Reactant.compile((m, x) -> m(x), (csmodel, concrete_input)) + cst_test = Lux.__make_reactant_array(Lux.testmode(st)) + smodel_test = Lux.StatefulLuxLayer{FST}(model, cps, cst_test) + inference_fn = Reactant.compile((m, x) -> m(x), (smodel_test, concrete_input)) - function enzyme_vjp_fn(m, x, y, dy) - dx = Enzyme.make_zero(x) - dps = Enzyme.make_zero(m.ps) - st_m = ifelse(FST, m.st, m.st_any) - - function wrapper_fn!(y, model, x, ps, st) - copyto!(y, first(LuxCore.apply(model, x, ps, st))) - return nothing + vjp_fn = if to.skip_compile_vjp + nothing + else + function enzyme_vjp_fn(m, x, y, dy) + dx = Enzyme.make_zero(x) + dps = Enzyme.make_zero(m.ps) + st_m = ifelse(FST, m.st, m.st_any) + + function wrapper_fn!(y, model, x, ps, st) + copyto!(y, first(LuxCore.apply(model, x, ps, st))) + return nothing + end + + Enzyme.autodiff( + Enzyme.Reverse, wrapper_fn!, Enzyme.Const, Enzyme.Duplicated(y, dy), + Enzyme.Const(m.model), Enzyme.Duplicated(x, dx), + Enzyme.Duplicated(m.ps, dps), Enzyme.Const(st_m)) + return dx, dps end - Enzyme.autodiff(Enzyme.Reverse, wrapper_fn!, Enzyme.Const, Enzyme.Duplicated(y, dy), - Enzyme.Const(m.model), Enzyme.Duplicated(x, dx), - Enzyme.Duplicated(m.ps, dps), Enzyme.Const(st_m)) - return dx, dps + try + concrete_output2 = Lux.__make_reactant_array(deepcopy(output)) + Reactant.compile( + enzyme_vjp_fn, (smodel, concrete_input, concrete_output, concrete_output2)) + catch err + to.force_compile_backward && rethrow(err) + @error """ + Enzyme failed to compile the backward pass. Differentiation will be disabled \ + for this model. 
+ + HINT: To force compilation of the backward pass, set \ + `force_compile_backward=true` in the constructor of `ToReactantAdaptor`.\n + """ exception=err + nothing + end end - vjp_fn = try - concrete_output2 = __make_concrete_array(deepcopy(output)) - Reactant.compile( - enzyme_vjp_fn, (csmodel, concrete_input, concrete_output, concrete_output2)) - catch err - to.force_compile_backward && rethrow(err) - @error """ - Enzyme failed to compile the backward pass. Differentiation will be disabled for \ - this model. - - HINT: To force compilation of the backward pass, set `force_compile_backward=true` \ - in the constructor of `ToReactantAdaptor`.\n - """ exception=err + jvp_fn = if to.skip_compile_jvp nothing + else # TODO: Implement JVP with Enzyme.Forward + throw(ArgumentError("JVPs are not implemented yet.")) end - return Lux.ReactantLayer{FST, Lux.__recursive_eltype(input_prototype)}( - to, input_prototype, concrete_input, cps, cst, model, cmodel, fwd_fn, - vjp_fn, eltype_adaptor, fmapstructure(Lux.__size, input_prototype)) + return Lux.ReactantLayer{ + FST, Lux.__recursive_eltype(input_prototype), typeof(input_prototype), + typeof(concrete_input), typeof(cst), typeof(cst_test)}( + to, cps, model, fwd_fn, inference_fn, vjp_fn, jvp_fn, + eltype_adaptor, fmapstructure(Lux.__size, input_prototype)) end # TODO: Currently we are maintaining 2 copies of the parameters, this is not ideal. @@ -97,23 +112,20 @@ end function LuxCore.initialparameters(rng::AbstractRNG, layer::Lux.ReactantLayer) ps = layer.adaptor(LuxCore.initialparameters(rng, layer.layer)) layer.eltype_adaptor !== nothing && (ps = adapt(layer.eltype_adaptor, ps)) - return __make_concrete_array(ps) + return Lux.__make_reactant_array(ps) end function LuxCore.initialstates(rng::AbstractRNG, layer::Lux.ReactantLayer) st = LuxCore.initialstates(rng, layer.layer) layer.eltype_adaptor !== nothing && (st = adapt(layer.eltype_adaptor, st)) - return __make_concrete_array(st) + return (; states=Lux.__make_reactant_array(st), training=Val(true)) end function (l::Lux.ReactantLayer{FST, T})(x, ps, st::NamedTuple) where {FST, T} l.eltype_adaptor !== nothing && (x = adapt(l.eltype_adaptor, x)) # XLARuntimeError is not great, so check and terminate early if needed - input_structure = fmapstructure(Lux.__size, x) - if l.input_structure != input_structure - throw(DimensionMismatch(lazy"Input structure mismatch. Expected $(l.input_structure), got $(input_structure).")) - end + @argcheck fmapstructure(Lux.__size, x) == l.input_structure # TODO: For non array inputs this we make the eltype uniform which might not be # desirable. We should handle those cases with `fmap` @@ -131,47 +143,69 @@ function (l::Lux.ReactantLayer{FST, T})(x, ps, st::NamedTuple) where {FST, T} return Lux.__apply_reactant(l, x, ps, st) end +@inline function Lux.__apply_reactant(l::Lux.ReactantLayer, x, ps, st) + y, st_ = Lux.__apply_reactant(l, x, ps, st.states, st.training) + return y, (; states=st_, training=st.training) +end + # This is the ideal case where all the types match correctly. # Input Type mispatches should not happen here, they should be handled before this function # is called. -# If `st` mismatch happens then user really messed something up. can't do anything about it. 
-@inline function Lux.__apply_reactant( - l::Lux.ReactantLayer{FST, T, inType}, x::inType, ps, st) where {FST, T, inType} - return Lux.__apply_reactant(l, __make_concrete_array(x), ps, st) +@inline function Lux.__apply_reactant(l::Lux.ReactantLayer{FST, T, inType}, x::inType, + ps, st, training) where {FST, T, inType} + return Lux.__apply_reactant(l, Lux.__make_reactant_array(x), ps, st, training) end @inline function Lux.__apply_reactant( - l::Lux.ReactantLayer{FST, T, inType, inCType, psType, stType}, x::inCType, - ps::psType, st::stType) where {FST, T, inType, inCType, psType, stType} - csmodel = Lux.StatefulLuxLayer{FST}(l.clayer, ps, st) - return Lux.__apply_reactant(l, csmodel, x), ifelse(FST, csmodel.st, csmodel.st_any) + l::Lux.ReactantLayer{FST, T, inType, inCType, stType, stTestType, psType}, + x::inCType, ps::psType, st::stType, + training) where {FST, T, inType, inCType, psType, stType, stTestType} + smodel = Lux.StatefulLuxLayer{FST}(l.layer, ps, st) + return ( + Lux.__apply_reactant(l, smodel, x, training), ifelse(FST, smodel.st, smodel.st_any)) end # Parameter type mismatch. This might be too common so try to handle it gracefully. @inline function Lux.__apply_reactant( - l::Lux.ReactantLayer{FST, T, inType, inCType, psType, stType}, x::inCType, - ps::psType2, st::stType) where {FST, T, inType, inCType, psType, psType2, stType} + l::Lux.ReactantLayer{FST, T, inType, inCType, stType, stTestType, psType}, + x::inCType, ps::psType2, st, + training) where {FST, T, inType, inCType, stType, stTestType, psType, psType2} + @warn "Parameter Type Mismatch with compiled Reactant function. This will lead to \ + performance regressions" maxlog=1 + ps = __try_similar_structure(Lux.__named_tuple(ps), l.concrete_ps) ps = l.adaptor(ps) l.eltype_adaptor !== nothing && (ps = adapt(l.eltype_adaptor, ps)) - ps = __make_concrete_array(ps) + ps = Lux.__make_reactant_array(ps) if typeof(ps) != psType @warn "Automatic type conversion failed for `ps`." original_ps_type=psType2 - __graceful_type_mismatch_error(l, x, ps, st) + __graceful_type_mismatch_error(l, x, ps, st, training) end - return Lux.__apply_reactant(l, __make_concrete_array(x), ps, st) + return Lux.__apply_reactant(l, Lux.__make_reactant_array(x), ps, st, training) +end + +function Lux.__apply_reactant(l, x, ps, st, training) + return __graceful_type_mismatch_error(l, x, ps, st, training) end -Lux.__apply_reactant(l, x, ps, st) = __graceful_type_mismatch_error(l, x, ps, st) +@inline function Lux.__apply_reactant(l::Lux.ReactantLayer, smodel, x, ::Val{true}) + return l.fwd_fn(smodel, x) +end -@inline Lux.__apply_reactant(l::Lux.ReactantLayer, csmodel, x) = l.fwd_fn(csmodel, x) +@inline function Lux.__apply_reactant(l::Lux.ReactantLayer, smodel, x, ::Val{false}) + return l.inference_fn(smodel, x) +end # Don't inline, else types don't get displayed in the stack trace function __graceful_type_mismatch_error( - ::Lux.ReactantLayer{FST, T, inType, inCType, psType, stType}, - x, ps, st) where {FST, T, inType, inCType, psType, stType} + ::Lux.ReactantLayer{FST, T, inType, inCType, stType, stTestType, psType}, + x, + ps, + st, + ::Val{training}) where { + FST, T, inType, inCType, psType, stType, stTestType, training} #! format: off input_type_mismatch_str = typeof(x) == inType || typeof(x) == inCType ? """ 1. Input Types Matched. @@ -189,12 +223,21 @@ function __graceful_type_mismatch_error( Compiled Parameter Type: $(psType). """ - st_type_mismatch_str = typeof(st) == stType ? """ - 3. State Types Matched. - """ : """ - 3. 
State Type: $(typeof(st)). - Compiled State Type: $(stType). - """ + st_type_mismatch_str = if training + typeof(st) == stType ? """ + 3. State Types Matched. + """ : """ + 3. State Type: $(typeof(st)). + Compiled State Type: $(stType). + """ + else + typeof(st) == stTestType ? """ + 3. State Types Matched. + """ : """ + 3. State Type: $(typeof(st)). + Compiled State Type: $(stTestType). + """ + end throw(ArgumentError(""" Model compiled types and input types don't match. We tried our best to convert the \ diff --git a/ext/LuxReactantExt/train.jl b/ext/LuxReactantExt/train.jl index b7c19e270b..a6a2d758cf 100644 --- a/ext/LuxReactantExt/train.jl +++ b/ext/LuxReactantExt/train.jl @@ -8,11 +8,11 @@ function Lux.Experimental.single_train_step( # st = ts.states # model = ts.model - data = __make_concrete_array(data) - model = __make_concrete_array(ts.model) - dps = __make_concrete_array(dps) - ps = __make_concrete_array(ts.parameters) - st = __make_concrete_array(ts.states) + data = Lux.__make_reactant_array(data) + model = Lux.__make_reactant_array(ts.model) + dps = Lux.__make_reactant_array(dps) + ps = Lux.__make_reactant_array(ts.parameters) + st = Lux.__make_reactant_array(ts.states) # @show diff --git a/ext/LuxReactantExt/utils.jl b/ext/LuxReactantExt/utils.jl index d595a42c34..61f66ffec4 100644 --- a/ext/LuxReactantExt/utils.jl +++ b/ext/LuxReactantExt/utils.jl @@ -1,4 +1,12 @@ -@inline function __make_concrete_array(x) +@inline Lux.__make_reactant_array(x::Reactant.RArray) = x +@inline function Lux.__make_reactant_array(x::AbstractArray) + hasmethod(Reactant.ArrayToConcrete, Tuple{typeof(x)}) && + return Reactant.ConcreteRArray(x) + return __make_tracer(x) +end +@inline Lux.__make_reactant_array(x) = __make_tracer(x) + +@inline function __make_tracer(x) return Reactant.make_tracer(IdDict(), x, (), Reactant.ArrayToConcrete, nothing) end diff --git a/src/layers/extension.jl b/src/layers/extension.jl index 791671b30d..1f9d126030 100644 --- a/src/layers/extension.jl +++ b/src/layers/extension.jl @@ -240,9 +240,8 @@ end # Workaround for SimpleChains not being able to handle some input types function CRC.rrule(::typeof(__apply_simple_chain), layer, x, ps, ::LuxCPUDevice) res, pb = CRC.rrule(layer, x, ps) + # Safety measure to prevent errors from weird Array types that SimpleChains doesn't support __∇apply_simple_chain = @closure Δ -> begin - # Safety measure to prevent errors from weird Array types that SimpleChains doesn't - # support ∂layer, ∂x, ∂ps = pb(convert(Array, Δ)) return CRC.NoTangent(), ∂layer, ∂x, ∂ps, CRC.NoTangent() end @@ -251,18 +250,19 @@ end # TODO: Add a ChainRules rrule that calls the `bwd` function, i.e. 
uses Enzyme for the # gradient computation -# TODO: Inference won't work OOTB, we will have to compile that separately -@concrete struct ReactantLayer{FST, T, inType, inCType, psType, stType, +# TODO: Docstring +@concrete struct ReactantLayer{FST, T, inType, inCType, stType, stTestType, psType, L <: AbstractExplicitLayer, AD <: ToReactantAdaptor} <: AbstractExplicitLayer adaptor::AD - input_prototype::inType - concrete_input_prototype::inCType concrete_ps::psType - concrete_st::stType layer::L - clayer + + # Compiled Functions fwd_fn + inference_fn vjp_fn + jvp_fn + eltype_adaptor input_structure end diff --git a/src/transform/reactant.jl b/src/transform/reactant.jl index 88a500db6a..845ef75c06 100644 --- a/src/transform/reactant.jl +++ b/src/transform/reactant.jl @@ -1,16 +1,26 @@ @concrete struct ToReactantAdaptor{FST, R <: AbstractRNG} <: AbstractFromLuxAdaptor input_prototype + ps_transform rng::R - force_compile_backward::Bool + force_allow_mixed_eltypes::Bool + skip_compile_vjp::Bool + force_compile_vjp::Bool + skip_compile_jvp::Bool + force_compile_jvp::Bool end function ToReactantAdaptor{FST}(input_prototype; rng=Xoshiro(123), ps_transform=identity, - force_compile_backward::Bool=false, - force_allow_mixed_eltypes::Bool=false) where {FST} + force_allow_mixed_eltypes::Bool=false, force_compile_vjp::Bool=false, + skip_compile_vjp::Bool=false, force_compile_jvp::Bool=false, + skip_compile_jvp::Bool=true) where {FST} + skip_compile_vjp && @argcheck !force_compile_vjp + skip_compile_jvp && @argcheck !force_compile_jvp + return ToReactantAdaptor{FST}(input_prototype, ps_transform, rng, - force_compile_backward, force_allow_mixed_eltypes) + force_allow_mixed_eltypes, skip_compile_vjp, force_compile_vjp, + skip_compile_jvp, force_compile_jvp) end function ToReactantAdaptor(args...; fixed_state_type::Val=Val(true), kwargs...) return ToReactantAdaptor{__unwrap_val(fixed_state_type)}(args...; kwargs...) @@ -36,3 +46,16 @@ only a limited subset of Lux models can be compiled via `Reactant.jl`. If you en issues, please report them on the `Lux.jl` or `Reactant.jl` GitHub repository. """ struct AutoReactant end + +""" + __make_reactant_array(x) + +Converts `x` to a `Reactant.ConcreteRArray` if it is not already one. +""" +function __make_reactant_array end + +@inline function __make_reactant_array(nt::NamedTuple{names}) where {names} + return NamedTuple{names}(map(__make_reactant_array, values(nt))) +end +@inline __make_reactant_array(t::Tuple) = map(__make_reactant_array, t) +@inline __make_reactant_array(x::AbstractExplicitLayer) = x From 8ce97075e559c67488bfb11fb76c657fa489f386 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 31 May 2024 23:05:33 -0700 Subject: [PATCH 15/15] Allow skipping all compilation --- ext/LuxReactantExt/layer.jl | 8 ++++++-- src/transform/reactant.jl | 17 ++++++++++++----- src/utils.jl | 1 + 3 files changed, 19 insertions(+), 7 deletions(-) diff --git a/ext/LuxReactantExt/layer.jl b/ext/LuxReactantExt/layer.jl index b45bf36acb..c2eb3f07ab 100644 --- a/ext/LuxReactantExt/layer.jl +++ b/ext/LuxReactantExt/layer.jl @@ -50,11 +50,13 @@ function Lux.__to_reactant_adaptor( cst = Lux.__make_reactant_array(st) smodel = Lux.StatefulLuxLayer{FST}(model, cps, cst) - fwd_fn = Reactant.compile((m, x) -> m(x), (smodel, concrete_input)) + fwd_fn = to.skip_compile_forward ? 
nothing : + Reactant.compile((m, x) -> m(x), (smodel, concrete_input)) cst_test = Lux.__make_reactant_array(Lux.testmode(st)) smodel_test = Lux.StatefulLuxLayer{FST}(model, cps, cst_test) - inference_fn = Reactant.compile((m, x) -> m(x), (smodel_test, concrete_input)) + inference_fn = to.skip_compile_inference ? nothing : + Reactant.compile((m, x) -> m(x), (smodel_test, concrete_input)) vjp_fn = if to.skip_compile_vjp nothing @@ -191,10 +193,12 @@ function Lux.__apply_reactant(l, x, ps, st, training) end @inline function Lux.__apply_reactant(l::Lux.ReactantLayer, smodel, x, ::Val{true}) + @argcheck l.fwd_fn !== nothing return l.fwd_fn(smodel, x) end @inline function Lux.__apply_reactant(l::Lux.ReactantLayer, smodel, x, ::Val{false}) + @argcheck l.inference_fn !== nothing return l.inference_fn(smodel, x) end diff --git a/src/transform/reactant.jl b/src/transform/reactant.jl index 845ef75c06..47f3e009ca 100644 --- a/src/transform/reactant.jl +++ b/src/transform/reactant.jl @@ -5,6 +5,8 @@ rng::R force_allow_mixed_eltypes::Bool + skip_compile_forward::Bool + skip_compile_inference::Bool skip_compile_vjp::Bool force_compile_vjp::Bool skip_compile_jvp::Bool @@ -12,15 +14,20 @@ end function ToReactantAdaptor{FST}(input_prototype; rng=Xoshiro(123), ps_transform=identity, - force_allow_mixed_eltypes::Bool=false, force_compile_vjp::Bool=false, + force_allow_mixed_eltypes::Bool=false, skip_compile_forward::Bool=false, + skip_compile_inference::Bool=false, force_compile_vjp::Bool=false, skip_compile_vjp::Bool=false, force_compile_jvp::Bool=false, - skip_compile_jvp::Bool=true) where {FST} + skip_compile_jvp::Bool=true) where {FST} # TODO: change skip_compile_jvp to false skip_compile_vjp && @argcheck !force_compile_vjp skip_compile_jvp && @argcheck !force_compile_jvp - return ToReactantAdaptor{FST}(input_prototype, ps_transform, rng, - force_allow_mixed_eltypes, skip_compile_vjp, force_compile_vjp, - skip_compile_jvp, force_compile_jvp) + @argcheck any(!, + (skip_compile_forward, skip_compile_inference, skip_compile_vjp, skip_compile_jvp)) + + return ToReactantAdaptor{FST}( + input_prototype, ps_transform, rng, force_allow_mixed_eltypes, + skip_compile_forward, skip_compile_inference, skip_compile_vjp, + force_compile_vjp, skip_compile_jvp, force_compile_jvp) end function ToReactantAdaptor(args...; fixed_state_type::Val=Val(true), kwargs...) return ToReactantAdaptor{__unwrap_val(fixed_state_type)}(args...; kwargs...) diff --git a/src/utils.jl b/src/utils.jl index f2ba54a743..d52081aa3b 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -271,6 +271,7 @@ end @inline __recursive_eltype(x::NamedTuple) = promote_type(__recursive_eltype.(values(x))...) @inline __recursive_eltype(::Nothing) = Bool @inline __recursive_eltype(x::Number) = eltype(x) +@inline __recursive_eltype(::Val) = Bool @inline function __recursive_eltype(x) _eltype = Ref(Bool) function __internal_recursive_eltype(x)
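
For reference, the pieces introduced in this series are intended to compose roughly as follows. This is a minimal illustrative sketch, not code taken from the patch: it assumes the adaptor is applied to a model the same way as Lux's other adaptors (i.e. `adaptor(model)`), and the model, input shape, and flag values are made up. `Reactant` must be loaded so that the `LuxReactantExt` extension is active.

    using Lux, Random, Reactant

    model = Chain(Dense(2 => 4, tanh), Dense(4 => 1))

    # Compile only the forward and inference passes; skip VJP/JVP compilation
    # (the patch already falls back to `nothing` when Enzyme cannot compile the
    # backward pass and `force_compile_vjp` is not set).
    adaptor = ToReactantAdaptor(rand(Float32, 2, 3);
        skip_compile_vjp=true, skip_compile_jvp=true)

    rmodel = adaptor(model)                 # assumed entry point; wraps the model in a `ReactantLayer`
    ps, st = Lux.setup(Xoshiro(0), rmodel)  # `ps` are concrete Reactant arrays, `st` carries `training=Val(true)`
    y, st_ = rmodel(rand(Float32, 2, 3), ps, st)

Training through `Lux.Experimental.single_train_step(AutoReactant(), obj_fn, data, ts)` is sketched in ext/LuxReactantExt/train.jl, but as of this patch the compiled update function is only constructed and the step does not yet return the usual `(grads, loss, stats, ts)` tuple, so it should be treated as work in progress.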