diff --git a/Project.toml b/Project.toml
index 52263c94f1..e848022771 100644
--- a/Project.toml
+++ b/Project.toml
@@ -46,6 +46,7 @@ FunctionWrappers = "069b7b12-0de2-55c6-9aab-29f3d0a68a2e"
 MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
 MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"
 NCCL = "3fe64909-d7a1-4096-9b7d-7a0f12cf0f6b"
+Reactant = "3c362404-f566-11ee-1572-e11a4b42c853"
 ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
 SimpleChains = "de6bee2f-e2f4-4ec7-b6ed-219cc6f6e9e5"
 Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
@@ -60,6 +61,7 @@
 LuxMPIExt = "MPI"
 LuxMPINCCLExt = ["CUDA", "MPI", "NCCL"]
 LuxReverseDiffExt = ["FunctionWrappers", "ReverseDiff"]
 LuxSimpleChainsExt = "SimpleChains"
+LuxReactantExt = ["Enzyme", "Reactant"]
 LuxTrackerExt = "Tracker"
 LuxZygoteExt = "Zygote"
@@ -67,7 +69,7 @@ LuxZygoteExt = "Zygote"
 ADTypes = "1.5"
 Adapt = "4"
 ArgCheck = "2.3"
-ArrayInterface = "7.9"
+ArrayInterface = "7.10"
 CUDA = "5.3.2"
 ChainRulesCore = "1.24"
 Compat = "4.15"
@@ -96,6 +98,7 @@ NNlib = "0.9.21"
 Optimisers = "0.3.3"
 Preferences = "1.4.3"
 Random = "1.10"
+Reactant = "0.2"
 Reexport = "1.2.2"
 ReverseDiff = "1.15"
 SIMDTypes = "0.1"
diff --git a/ext/LuxReactantExt/LuxReactantExt.jl b/ext/LuxReactantExt/LuxReactantExt.jl
new file mode 100644
index 0000000000..0a72884ef2
--- /dev/null
+++ b/ext/LuxReactantExt/LuxReactantExt.jl
@@ -0,0 +1,14 @@
+module LuxReactantExt
+
+using Enzyme: Enzyme, Active, Const, Duplicated
+using Reactant: Reactant
+using Static: Static, False
+using Setfield: @set!
+
+using Lux: Lux, ReactantBackend
+using Lux.Training: TrainingBackendCache, TrainState
+using LuxCore: LuxCore
+
+include("training.jl")
+
+end
diff --git a/ext/LuxReactantExt/training.jl b/ext/LuxReactantExt/training.jl
new file mode 100644
index 0000000000..1f4553fe53
--- /dev/null
+++ b/ext/LuxReactantExt/training.jl
@@ -0,0 +1,77 @@
+function Lux.Training.single_train_step!(
+        backend::ReactantBackend, obj_fn::F, data, ts::TrainState) where {F}
+    data = Reactant.to_rarray(data)
+    ps = Reactant.to_rarray(ts.parameters)
+    st = Reactant.to_rarray(ts.states)
+    st_opt = Reactant.to_rarray(ts.optimizer_state)
+
+    compiled_inference = if backend.input_prototype !== nothing
+        Reactant.compile(LuxCore.apply,
+            (ts.model, Reactant.to_rarray(backend.input_prototype),
+                ps, LuxCore.testmode(st)))
+    else
+        nothing
+    end
+
+    compiled_grad_and_step! = Reactant.compile(
+        internal_grad_and_step!, (obj_fn, ts.model, ps, st, st_opt, data, ts.optimizer))
+
+    loss, st_updated, stats = compiled_grad_and_step!(
+        obj_fn, ts.model, ps, st, st_opt, data, ts.optimizer)
+
+    cache = TrainingBackendCache(backend, False(), nothing, (; compiled_grad_and_step!,
+        compiled_inference))
+    @set! ts.cache = cache
+    @set! ts.objective_function = obj_fn
+    @set! ts.parameters = ps
+    @set! ts.states = st_updated
+    @set! ts.optimizer_state = st_opt
+    @set! ts.step = ts.step + 1
+
+    return nothing, loss, stats, ts # TODO: Return the gradients
+end
+
+function Lux.Training.single_train_step!(::ReactantBackend, obj_fn::F, data,
+        ts::TrainState{<:TrainingBackendCache{<:ReactantBackend}, F}) where {F}
+    data = Reactant.to_rarray(data)
+
+    loss, st_updated, stats = ts.cache.extras.compiled_grad_and_step!(
+        obj_fn, ts.model, ts.parameters, ts.states, ts.optimizer_state, data, ts.optimizer)
+
+    @set! ts.objective_function = obj_fn
+    @set! ts.states = st_updated
+    @set! ts.step = ts.step + 1
+
+    return nothing, loss, stats, ts # TODO: Return the gradients
+end
+
+function internal_grad_and_step!(
+        obj_fn::F, model, ps, st, st_opt, data, optimizer) where {F}
+    dps = Lux.recursive_make_zero(ps)
+
+    _, (loss, st_updated, stats) = Enzyme.autodiff(
+        Enzyme.ReverseWithPrimal, obj_fn, Active, Const(model),
+        Duplicated(ps, dps), Const(st), Const(data))
+
+    Lux.simple_optimizers_apply!(optimizer, st_opt, ps, dps) # ps & st_opt are updated in-place
+
+    return loss, st_updated, stats
+end
+
+function (tstate::TrainState{<:TrainingBackendCache{<:ReactantBackend}})(data)
+    data_reactant = Reactant.to_rarray(data)
+    compiled_inference = if tstate.cache.extras.compiled_inference !== nothing
+        tstate.cache.extras.compiled_inference
+    else
+        @warn "Inference function was not compiled beforehand. This will trigger \
+               compilation on every call to `(::TrainState)(data)`. Please use \
+               `ReactantBackend(; input_prototype = data)` to compile the inference \
+               function on the first call to `single_train_step!` or \
+               `single_train_step`." maxlog=1
+        Reactant.compile(LuxCore.apply,
+            (tstate.model, data_reactant, tstate.parameters,
+                LuxCore.testmode(tstate.states)))
+    end
+    return compiled_inference(
+        tstate.model, data_reactant, tstate.parameters, LuxCore.testmode(tstate.states))
+end
diff --git a/src/Lux.jl b/src/Lux.jl
index 972b8aa419..2a845cca41 100644
--- a/src/Lux.jl
+++ b/src/Lux.jl
@@ -14,7 +14,8 @@ using GPUArraysCore: @allowscalar
 using LossFunctions: LossFunctions
 using Markdown: @doc_str
 using NNlib: NNlib
-using Optimisers: Optimisers
+using Optimisers: Optimisers, Leaf, Descent
+using Preferences: load_preference, has_preference
 using Random: Random, AbstractRNG
 using Static: StaticBool, StaticInt, StaticSymbol, True, False, static, known, dynamic
 using Reexport: Reexport, @reexport
@@ -46,12 +47,20 @@ include("extended_ops.jl")
 # Training Helpers
 include("helpers/training.jl")
 
+# Compilers
+include("compilers.jl")
+
 # Experimental
 include("contrib/contrib.jl")
 
 # Pretty Printing
 include("layers/display.jl")
 
+# Transform to and from other frameworks
+include("transform/types.jl")
+include("transform/flux.jl")
+include("transform/simplechains.jl")
+
 # Layer Implementations
 include("layers/basic.jl")
 include("layers/containers.jl")
@@ -70,16 +79,12 @@ include("helpers/losses.jl")
 include("helpers/recursive_ops.jl")
 include("helpers/match_eltype.jl")
 include("helpers/size_propagator.jl")
+include("helpers/simple_optimizers.jl")
 
 # AutoDiff
 include("autodiff/api.jl")
 include("autodiff/autodiff.jl")
 
-# Transform to and from other frameworks
-include("transform/types.jl")
-include("transform/flux.jl")
-include("transform/simplechains.jl")
-
 # Distributed Training
 include("distributed/backend.jl")
 include("distributed/public_api.jl")
@@ -106,6 +111,8 @@ export jacobian_vector_product, vector_jacobian_product
 export batched_jacobian
 export AutoEnzyme, AutoForwardDiff, AutoReverseDiff, AutoTracker, AutoZygote
 
+export ReactantBackend
+
 export BinaryCrossEntropyLoss, BinaryFocalLoss, CrossEntropyLoss, DiceCoeffLoss, FocalLoss,
        HingeLoss, HuberLoss, KLDivergenceLoss, L1Loss, L2Loss, MAELoss, MSELoss, MSLELoss,
        PoissonLoss, SiameseContrastiveLoss, SquaredHingeLoss
diff --git a/src/compilers.jl b/src/compilers.jl
new file mode 100644
index 0000000000..57725f03c9
--- /dev/null
+++ b/src/compilers.jl
@@ -0,0 +1,29 @@
+abstract type AbstractCompilerBackend end
+
+"""
+    ReactantBackend(; input_prototype = nothing)
+
+Compile Lux models and gradient computations to MLIR/XLA via `Reactant.jl`.
+
+!!! tip "Newly Added Feature!"
+
+    This has been added to Lux very recently and is undergoing rapid development.
+    Currently, only a limited subset of Lux models can be compiled via `Reactant.jl`. If
+    you encounter any issues, please report them on the `Lux.jl` or `Reactant.jl` GitHub
+    repository.
+
+## Keyword Arguments
+
+  - `input_prototype`: Input data representative of the data that will be used for
+    inference. If this is provided, we will compile the inference function with
+    `Reactant.jl` on the first call to [`Lux.Experimental.single_train_step!`](@ref) or
+    [`Lux.Experimental.single_train_step`](@ref). If this is not provided, we will have to
+    recompile the inference function on every call to `(::TrainState)(data)` and this will
+    be prohibitively expensive.
+
+See [`Lux.Experimental.single_train_step!`](@ref) or
+[`Lux.Experimental.single_train_step`](@ref) for information on how to use this backend.
+"""
+@kwdef @concrete struct ReactantBackend <: AbstractCompilerBackend
+    input_prototype = nothing
+end
diff --git a/src/helpers/simple_optimizers.jl b/src/helpers/simple_optimizers.jl
new file mode 100644
index 0000000000..cd9a32b082
--- /dev/null
+++ b/src/helpers/simple_optimizers.jl
@@ -0,0 +1,14 @@
+# These are meant to be used internally for compiling certain Lux optimization routines
+function simple_optimizers_apply!(ps, gs, leaf::Leaf{<:Descent})
+    @. ps -= leaf.rule.eta * gs
+end
+
+for opt in (Descent,)
+    @eval function simple_optimizers_apply!(::$(opt), st_opt, ps, gs)
+        recursive_map(simple_optimizers_apply!, ps, gs, st_opt)
+    end
+end
+
+function simple_optimizers_apply!(opt, st_opt, ps, gs)
+    throw(ArgumentError("Optimizer $(typeof(opt)) not yet supported."))
+end
diff --git a/src/helpers/training.jl b/src/helpers/training.jl
index 51fdb1a48a..6066ad2583 100644
--- a/src/helpers/training.jl
+++ b/src/helpers/training.jl
@@ -14,7 +14,7 @@ using LuxCore: LuxCore, AbstractLuxLayer
 """
     TrainState
 
-Training State containing:
+## Training State containing:
 
   - `model`: `Lux` model.
   - `parameters`: Trainable Variables of the `model`.
@@ -23,7 +23,7 @@ Training State containing:
   - `optimizer_state`: Optimizer State.
   - `step`: Number of updates of the parameters made.
 
-Internal fields:
+## Internal fields:
 
   - `cache`: Cached values. Implementations are free to use this for whatever they want.
   - `objective_function`: Objective function might be cached.
@@ -32,6 +32,12 @@ Internal fields:
 
 Constructing this object directly shouldn't be considered a stable API. Use the version
 with the Optimisers API.
+
+## Special Features
+
+To run inference using the current parameters and states, simply call the `TrainState` with
+the input data as `tstate(data)`. This will automatically set `Lux.testmode`. However, note
+that `tstate.states` will not be updated with the new state.
 """
 @concrete struct TrainState
     cache
@@ -65,6 +71,8 @@ function TrainState(model::AbstractLuxLayer, ps, st, optimizer::Optimisers.Abstr
     return TrainState(nothing, nothing, model, ps, st, optimizer, st_opt, 0)
 end
 
+(ts::TrainState)(data) = ts.model(data, ts.parameters, Lux.testmode(ts.states))
+
 @concrete struct TrainingBackendCache
     backend
     first_try <: StaticBool
@@ -237,12 +245,16 @@ Perform a single training step. Computes the gradients using [`compute_gradients
 updates the parameters using [`apply_gradients!`](@ref). All backends supported via
 [`compute_gradients`](@ref) are supported here.
 
+## Additional Backends
+
+  - [`ReactantBackend`](@ref): Compiles the training loop to MLIR/XLA via `Reactant.jl`.
+
 ## Return
 
 Returned values are the same as [`compute_gradients`](@ref). Note that despite the `!`,
 only the parameters in `ts` are updated inplace. Users should be using the returned `ts`
 object for further training steps, else there is no caching and performance will be
-suboptimal (and absolutely terrible for backends like `AutoReactant`).
+suboptimal (and absolutely terrible for backends like `ReactantBackend`).
 """
 function single_train_step! end
 
@@ -253,6 +265,10 @@ Perform a single training step. Computes the gradients using [`compute_gradients
 updates the parameters using [`apply_gradients`](@ref). All backends supported via
 [`compute_gradients`](@ref) are supported here.
 
+## Additional Backends
+
+  - [`ReactantBackend`](@ref): Compiles the training loop to MLIR/XLA via `Reactant.jl`.
+
 In most cases you should use [`single_train_step!`](@ref) instead of this function.
 
 ## Return
diff --git a/test/qa_tests.jl b/test/qa_tests.jl
index 074f464b09..c1e64e45e7 100644
--- a/test/qa_tests.jl
+++ b/test/qa_tests.jl
@@ -12,7 +12,8 @@ end
 
 @testitem "Explicit Imports: Quality Assurance" setup=[SharedTestSetup] tags=[:others] begin
     # Load all trigger packages
-    import Lux, ComponentArrays, ReverseDiff, SimpleChains, Tracker, Zygote, Enzyme
+    import Lux, ComponentArrays, ReverseDiff, Flux, SimpleChains, Tracker, Zygote, Enzyme,
+           Reactant
     using ExplicitImports
 
     # Skip our own packages
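
Example usage of the new backend (an illustrative sketch, not part of the patch): the model, data, and loss below are made up, and `Descent` is used because it is the only optimizer rule `simple_optimizers_apply!` currently handles.

```julia
using Lux, Optimisers, Random
using Enzyme, Reactant  # loading both triggers the LuxReactantExt extension

model = Chain(Dense(2 => 32, tanh), Dense(32 => 1))
ps, st = Lux.setup(Random.default_rng(), model)

x = rand(Float32, 2, 32)
y = rand(Float32, 1, 32)

# `input_prototype` lets the first train step also compile the inference path up front.
backend = ReactantBackend(; input_prototype=x)
tstate = Lux.Training.TrainState(model, ps, st, Descent(0.01f0))

# The first call compiles the gradient-and-update function via Reactant and caches it in
# `tstate.cache`; subsequent calls with the returned `tstate` reuse the compiled function.
_, loss, stats, tstate = Lux.Training.single_train_step!(backend, MSELoss(), (x, y), tstate)

# Inference with the current parameters/states; reuses the compiled inference function.
y_pred, _ = tstate(x)
```

Note that with this backend `single_train_step!` currently returns `nothing` in place of the gradients (see the TODO above).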