5 changes: 4 additions & 1 deletion Project.toml
@@ -46,6 +46,7 @@ FunctionWrappers = "069b7b12-0de2-55c6-9aab-29f3d0a68a2e"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"
NCCL = "3fe64909-d7a1-4096-9b7d-7a0f12cf0f6b"
Reactant = "3c362404-f566-11ee-1572-e11a4b42c853"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
SimpleChains = "de6bee2f-e2f4-4ec7-b6ed-219cc6f6e9e5"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
@@ -60,14 +61,15 @@ LuxMPIExt = "MPI"
LuxMPINCCLExt = ["CUDA", "MPI", "NCCL"]
LuxReverseDiffExt = ["FunctionWrappers", "ReverseDiff"]
LuxSimpleChainsExt = "SimpleChains"
LuxReactantExt = ["Enzyme", "Reactant"]
LuxTrackerExt = "Tracker"
LuxZygoteExt = "Zygote"

[compat]
ADTypes = "1.5"
Adapt = "4"
ArgCheck = "2.3"
ArrayInterface = "7.9"
ArrayInterface = "7.10"
CUDA = "5.3.2"
ChainRulesCore = "1.24"
Compat = "4.15"
@@ -96,6 +98,7 @@ NNlib = "0.9.21"
Optimisers = "0.3.3"
Preferences = "1.4.3"
Random = "1.10"
Reactant = "0.2"
Reexport = "1.2.2"
ReverseDiff = "1.15"
SIMDTypes = "0.1"
14 changes: 14 additions & 0 deletions ext/LuxReactantExt/LuxReactantExt.jl
@@ -0,0 +1,14 @@
module LuxReactantExt

using Enzyme: Enzyme, Active, Const, Duplicated
using Reactant: Reactant
using Static: Static, False
using Setfield: @set!

using Lux: Lux, ReactantBackend
using Lux.Training: TrainingBackendCache, TrainState
using LuxCore: LuxCore

include("training.jl")

end
77 changes: 77 additions & 0 deletions ext/LuxReactantExt/training.jl
@@ -0,0 +1,77 @@
function Lux.Training.single_train_step!(
backend::ReactantBackend, obj_fn::F, data, ts::TrainState) where {F}
data = Reactant.to_rarray(data)
ps = Reactant.to_rarray(ts.parameters)
st = Reactant.to_rarray(ts.states)
st_opt = Reactant.to_rarray(ts.optimizer_state)

compiled_inference = if backend.input_prototype !== nothing
Reactant.compile(LuxCore.apply,
(ts.model, Reactant.to_rarray(backend.input_prototype),
ps, LuxCore.testmode(st)))
else
nothing
end

compiled_grad_and_step! = Reactant.compile(
internal_grad_and_step!, (obj_fn, ts.model, ps, st, st_opt, data, ts.optimizer))

loss, st_updated, stats = compiled_grad_and_step!(
obj_fn, ts.model, ps, st, st_opt, data, ts.optimizer)

cache = TrainingBackendCache(backend, False(), nothing, (; compiled_grad_and_step!,
compiled_inference))
@set! ts.cache = cache
@set! ts.objective_function = obj_fn
@set! ts.parameters = ps
@set! ts.states = st_updated
@set! ts.optimizer_state = st_opt
@set! ts.step = ts.step + 1

return nothing, loss, stats, ts # TODO: Return the gradients
end

function Lux.Training.single_train_step!(::ReactantBackend, obj_fn::F, data,
ts::TrainState{<:TrainingBackendCache{<:ReactantBackend}, F}) where {F}
data = Reactant.to_rarray(data)

loss, st_updated, stats = ts.cache.extras.compiled_grad_and_step!(
obj_fn, ts.model, ts.parameters, ts.states, ts.optimizer_state, data, ts.optimizer)

@set! ts.objective_function = obj_fn
@set! ts.states = st_updated
@set! ts.step = ts.step + 1

return nothing, loss, stats, ts # TODO: Return the gradients
end

function internal_grad_and_step!(
obj_fn::F, model, ps, st, st_opt, data, optimizer) where {F}
dps = Lux.recursive_make_zero(ps)

_, (loss, st_updated, stats) = Enzyme.autodiff(
Enzyme.ReverseWithPrimal, obj_fn, Active, Const(model),
Duplicated(ps, dps), Const(st), Const(data))

Lux.simple_optimizers_apply!(optimizer, st_opt, ps, dps) # ps & st_opt are updated in-place

return loss, st_updated, stats
end

function (tstate::TrainState{<:TrainingBackendCache{<:ReactantBackend}})(data)
data_reactant = Reactant.to_rarray(data)
compiled_inference = if tstate.cache.extras.compiled_inference !== nothing
tstate.cache.extras.compiled_inference
else
@warn "Inference function not compiled before. This will trigger compilation on \
every inference call to `(::TrainState)(data)`. Please use \
`ReactantBackend(; input_prototype = data)` to compile the inference \
function on the first call to `single_train_step!` or \
`single_train_step`." maxlog=1
Reactant.compile(LuxCore.apply,
(tstate.model, data_reactant, tstate.parameters,
LuxCore.testmode(tstate.states)))
end
return compiled_inference(
tstate.model, data_reactant, tstate.parameters, LuxCore.testmode(tstate.states))
end
21 changes: 16 additions & 5 deletions src/Lux.jl
@@ -14,7 +14,12 @@ using GPUArraysCore: @allowscalar
using LossFunctions: LossFunctions
using Markdown: @doc_str
using NNlib: NNlib
using Optimisers: Optimisers, Leaf, Descent
using Preferences: load_preference, has_preference
using Random: Random, AbstractRNG
using Static: StaticBool, StaticInt, StaticSymbol, True, False, static, known, dynamic
using Reexport: Reexport, @reexport
@@ -46,12 +51,20 @@ include("extended_ops.jl")
# Training Helpers
include("helpers/training.jl")

# Compilers
include("compilers.jl")

# Experimental
include("contrib/contrib.jl")

# Pretty Printing
include("layers/display.jl")

# Transform to and from other frameworks
include("transform/types.jl")
include("transform/flux.jl")
include("transform/simplechains.jl")

# Layer Implementations
include("layers/basic.jl")
include("layers/containers.jl")
@@ -70,16 +83,12 @@ include("helpers/losses.jl")
include("helpers/recursive_ops.jl")
include("helpers/match_eltype.jl")
include("helpers/size_propagator.jl")
include("helpers/simple_optimizers.jl")

# AutoDiff
include("autodiff/api.jl")
include("autodiff/autodiff.jl")

# Transform to and from other frameworks
include("transform/types.jl")
include("transform/flux.jl")
include("transform/simplechains.jl")

# Distributed Training
include("distributed/backend.jl")
include("distributed/public_api.jl")
@@ -106,6 +115,8 @@ export jacobian_vector_product, vector_jacobian_product
export batched_jacobian
export AutoEnzyme, AutoForwardDiff, AutoReverseDiff, AutoTracker, AutoZygote

export ReactantBackend

export BinaryCrossEntropyLoss, BinaryFocalLoss, CrossEntropyLoss, DiceCoeffLoss, FocalLoss,
HingeLoss, HuberLoss, KLDivergenceLoss, L1Loss, L2Loss, MAELoss, MSELoss, MSLELoss,
PoissonLoss, SiameseContrastiveLoss, SquaredHingeLoss
29 changes: 29 additions & 0 deletions src/compilers.jl
@@ -0,0 +1,29 @@
abstract type AbstractCompilerBackend end

"""
ReactantBackend(; input_prototype = nothing)

Compile the Lux model and gradient computation to MLIR/XLA via `Reactant.jl`.

!!! tip "Newly Added Feature!"

This has been added to Lux very recently and is undergoing rapid development.
Currently, only a limited subset of Lux models can be compiled via `Reactant.jl`. If you
encounter any issues, please report them on the `Lux.jl` or `Reactant.jl` GitHub
repositories.

## Keyword Arguments

- `input_prototype`: Input data representative of the data that will be used for
inference. If this is provided, we will compile the inference function with
`Reactant.jl` on the first call to [`Lux.Experimental.single_train_step!`](@ref) or
[`Lux.Experimental.single_train_step`](@ref). If this is not provided, we will have to
recompile the inference function on every call to `(::TrainState)(data)` and this will
be prohibitively expensive.

See [`Lux.Experimental.single_train_step!`](@ref) or
[`Lux.Experimental.single_train_step`](@ref) for information on how to use this backend.
"""
@kwdef @concrete struct ReactantBackend <: AbstractCompilerBackend
input_prototype = nothing
end
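
For orientation, here is a minimal usage sketch of the backend added in this PR. It is not part of the diff: the model, data shapes, objective function, and `Descent(0.01f0)` learning rate are illustrative assumptions, and it presumes a working Reactant/XLA installation. Note that only `Descent` is handled by the in-place optimizer helpers added in `src/helpers/simple_optimizers.jl` below.

```julia
using Lux, Optimisers, Random, Enzyme, Reactant  # Enzyme + Reactant trigger LuxReactantExt

model = Chain(Dense(2 => 4, tanh), Dense(4 => 1))
ps, st = Lux.setup(Random.default_rng(), model)
x, y = rand(Float32, 2, 32), rand(Float32, 1, 32)

# Objective function with the `(model, ps, st, data) -> (loss, st, stats)` signature
function mse_objective(model, ps, st, (x, y))
    y_pred, st_new = model(x, ps, st)
    return MSELoss()(y_pred, y), st_new, (;)
end

ts = Lux.Training.TrainState(model, ps, st, Descent(0.01f0))
backend = ReactantBackend(; input_prototype = x)

# The first call compiles the gradient-and-update step (and, since `input_prototype`
# was provided, the inference function); both are cached inside the returned `ts`.
_, loss, stats, ts = Lux.Training.single_train_step!(backend, mse_objective, (x, y), ts)
```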
14 changes: 14 additions & 0 deletions src/helpers/simple_optimizers.jl
@@ -0,0 +1,14 @@
# These are meant to be used internally for compiling certain Lux optimization steps
function simple_optimizers_apply!(ps, gs, leaf::Leaf{<:Descent})
@. ps -= leaf.rule.eta * gs
end

for opt in (Descent,)
@eval function simple_optimizers_apply!(::$(opt), st_opt, ps, gs)
recursive_map(simple_optimizers_apply!, ps, gs, st_opt)
end
end

function simple_optimizers_apply!(opt, st_opt, ps, gs)
throw(ArgumentError("Optimizer $(typeof(opt)) not yet supported."))
end
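
As a hypothetical illustration (not part of the diff) of what the `Descent` path amounts to: the helper recurses over the parameter/gradient/optimizer-state trees and applies plain in-place SGD at each leaf. `Optimisers.setup` is used here only to produce the `Leaf` states the helper expects.

```julia
using Optimisers

ps = (; weight = randn(Float32, 4, 4))         # parameter tree
gs = (; weight = ones(Float32, 4, 4))          # matching gradient tree
st_opt = Optimisers.setup(Descent(0.1f0), ps)  # tree of `Optimisers.Leaf`

# `Lux.simple_optimizers_apply!(Descent(0.1f0), st_opt, ps, gs)` is equivalent,
# per leaf, to the following in-place update:
ps.weight .-= st_opt.weight.rule.eta .* gs.weight
```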
22 changes: 19 additions & 3 deletions src/helpers/training.jl
@@ -14,7 +14,7 @@ using LuxCore: LuxCore, AbstractLuxLayer
"""
TrainState

Training State containing:
## Training State containing:

- `model`: `Lux` model.
- `parameters`: Trainable Variables of the `model`.
@@ -23,7 +23,7 @@ Training State containing:
- `optimizer_state`: Optimizer State.
- `step`: Number of updates of the parameters made.

Internal fields:
## Internal fields:

- `cache`: Cached values. Implementations are free to use this for whatever they want.
- `objective_function`: Objective function might be cached.
@@ -32,6 +32,12 @@ Internal fields:

Constructing this object directly shouldn't be considered a stable API. Use the
version with the Optimisers API.

## Special Features

To run inference using the current parameters and states, simply call the `TrainState` with
the input data as `tstate(data)`. This automatically runs the model in `Lux.testmode`. Note,
however, that `tstate.states` will not be updated with the new state.
"""
@concrete struct TrainState
cache
@@ -65,6 +71,8 @@ function TrainState(model::AbstractLuxLayer, ps, st, optimizer::Optimisers.Abstr
return TrainState(nothing, nothing, model, ps, st, optimizer, st_opt, 0)
end

(ts::TrainState)(data) = ts.model(data, ts.parameters, Lux.testmode(ts.states))

@concrete struct TrainingBackendCache
backend
first_try <: StaticBool
@@ -237,12 +245,16 @@ Perform a single training step. Computes the gradients using [`compute_gradients
updates the parameters using [`apply_gradients!`](@ref). All backends supported via
[`compute_gradients`](@ref) are supported here.

## Additional Backends

- [`ReactantBackend`](@ref): Compiles the training loop to MLIR/XLA via `Reactant.jl`.

## Return

Returned values are the same as [`compute_gradients`](@ref). Note that despite the `!`,
only the parameters in `ts` are updated inplace. Users should be using the returned `ts`
object for further training steps, else there is no caching and performance will be
suboptimal (and absolutely terrible for backends like `AutoReactant`).
suboptimal (and absolutely terrible for backends like `ReactantBackend`).
"""
function single_train_step! end

@@ -253,6 +265,10 @@ Perform a single training step. Computes the gradients using [`compute_gradients
updates the parameters using [`apply_gradients`](@ref). All backends supported via
[`compute_gradients`](@ref) are supported here.

## Additional Backends

- [`ReactantBackend`](@ref): Compiles the training loop to MLIR/XLA via `Reactant.jl`.

In most cases you should use [`single_train_step!`](@ref) instead of this function.

## Return
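
To connect the caching note above with the new `tstate(data)` inference path, a short sketch continuing the earlier hypothetical setup (`backend`, `mse_objective`, `x`, and `y` are the same assumed names, not part of the PR):

```julia
function fit!(ts, backend, x, y; nsteps = 100)
    for _ in 1:nsteps
        # Reusing the returned `ts` keeps the Reactant-compiled functions cached.
        _, _, _, ts = Lux.Training.single_train_step!(backend, mse_objective, (x, y), ts)
    end
    return ts
end

ts = fit!(ts, backend, x, y)

# Inference with the current parameters/states: `Lux.testmode` is applied
# automatically, and `ts.states` is left unchanged by this call.
y_pred, _ = ts(x)
```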
3 changes: 2 additions & 1 deletion test/qa_tests.jl
@@ -12,7 +12,8 @@ end

@testitem "Explicit Imports: Quality Assurance" setup=[SharedTestSetup] tags=[:others] begin
# Load all trigger packages
import Lux, ComponentArrays, ReverseDiff, SimpleChains, Tracker, Zygote, Enzyme
import Lux, ComponentArrays, ReverseDiff, Flux, SimpleChains, Tracker, Zygote, Enzyme,
Reactant
using ExplicitImports

# Skip our own packages
Expand Down