Merged
23 commits
3e1f5a6
feat: compile training loop automatically using reactant
avik-pal Oct 4, 2024
cecc9f6
refactor: add a level of indirection for the train_step
avik-pal Oct 4, 2024
99e39e4
feat: directly compile step + grad function
avik-pal Oct 4, 2024
81d5e43
fix: make note of current issue with inplace update
avik-pal Oct 4, 2024
4a2bd5d
chore: bump minimum reactant version
avik-pal Oct 4, 2024
1029ed8
test: setup specific reactant test group
avik-pal Oct 4, 2024
7d400f9
ci: temporarily disable other tests (drop me)
avik-pal Oct 4, 2024
396f9eb
test: fix installation of Reactant
avik-pal Oct 4, 2024
d20064b
test: start adding loss function tests
avik-pal Oct 5, 2024
0701270
fix: xlogx and xlogy now work with Reactant scalars
avik-pal Oct 5, 2024
9b0a3b0
feat: support regression losses + tests
avik-pal Oct 5, 2024
fa21661
test: classification losses
avik-pal Oct 5, 2024
e5743cc
fix: more specialization
avik-pal Oct 5, 2024
8ef032f
fix: support all loss functions
avik-pal Oct 5, 2024
c261d5b
chore: comments
avik-pal Oct 5, 2024
09cc32a
fix: bump reactant version
avik-pal Oct 6, 2024
128ff2f
test: don't run reactant tests on windows
avik-pal Oct 8, 2024
fcbf03a
test: temporarily disable more tests
avik-pal Oct 8, 2024
7c842fd
fix: reactant GPU support
avik-pal Oct 8, 2024
d958283
fix: remove old LossFunctions.jl dispatches
avik-pal Oct 9, 2024
38ed312
test: try using MSELoss directly
avik-pal Oct 9, 2024
5481cdc
ci: reactivate all tests
avik-pal Oct 9, 2024
77f1048
ci(windows): don't test Reactant on windows
avik-pal Oct 9, 2024
1 change: 1 addition & 0 deletions .github/workflows/CI.yml
@@ -42,6 +42,7 @@ jobs:
- "recurrent_layers"
- "eltype_match"
- "fluxcompat"
- "reactant"
include:
- version: "1.10"
os: macos-latest
21 changes: 11 additions & 10 deletions .github/workflows/CIPreRelease.yml
@@ -32,16 +32,17 @@ jobs:
os:
- ubuntu-latest
test_group:
- "core_layers"
- "contrib"
- "helpers"
- "distributed"
- "normalize_layers"
- "others"
- "autodiff"
- "recurrent_layers"
- "eltype_match"
- "fluxcompat"
# - "core_layers"
# - "contrib"
# - "helpers"
# - "distributed"
# - "normalize_layers"
# - "others"
# - "autodiff"
# - "recurrent_layers"
# - "eltype_match"
# - "fluxcompat"
- "reactant"
steps:
- uses: actions/checkout@v4
- uses: julia-actions/setup-julia@v2
7 changes: 5 additions & 2 deletions Project.toml
@@ -46,6 +46,7 @@ LossFunctions = "30fc2ffe-d236-52d8-8643-a9d8f7c094a7"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"
NCCL = "3fe64909-d7a1-4096-9b7d-7a0f12cf0f6b"
Reactant = "3c362404-f566-11ee-1572-e11a4b42c853"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
SimpleChains = "de6bee2f-e2f4-4ec7-b6ed-219cc6f6e9e5"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
@@ -59,6 +60,7 @@ LuxLossFunctionsExt = "LossFunctions"
LuxMLUtilsExt = "MLUtils"
LuxMPIExt = "MPI"
LuxMPINCCLExt = ["CUDA", "MPI", "NCCL"]
LuxReactantExt = ["Enzyme", "Reactant"]
LuxReverseDiffExt = ["FunctionWrappers", "ReverseDiff"]
LuxSimpleChainsExt = "SimpleChains"
LuxTrackerExt = "Tracker"
@@ -68,7 +70,7 @@ LuxZygoteExt = "Zygote"
ADTypes = "1.8.1"
Adapt = "4"
ArgCheck = "2.3"
ArrayInterface = "7.9"
ArrayInterface = "7.10"
CUDA = "5.3.2"
ChainRulesCore = "1.24"
Compat = "4.15"
@@ -87,7 +89,7 @@ LinearAlgebra = "1.10"
LossFunctions = "0.11.1"
LuxCore = "1"
LuxLib = "1.3"
MLDataDevices = "1.1"
MLDataDevices = "1.2"
MLUtils = "0.4.4"
MPI = "0.20.19"
MacroTools = "0.5.13"
@@ -97,6 +99,7 @@ NNlib = "0.9.24"
Optimisers = "0.3.3"
Preferences = "1.4.3"
Random = "1.10"
Reactant = "0.2.3"
Reexport = "1.2.2"
ReverseDiff = "1.15"
SIMDTypes = "0.1"
9 changes: 4 additions & 5 deletions ext/LuxEnzymeExt/training.jl
@@ -1,4 +1,4 @@
function Lux.Training.compute_gradients(
function Lux.Training.compute_gradients_impl(
ad::AutoEnzyme, obj_fn::F, data, ts::TrainState) where {F}
dps = Lux.recursive_make_zero(ts.parameters)

@@ -20,9 +20,8 @@ end
const AUTODIFF_CACHE_TYPE = TrainingBackendCache{
<:AutoEnzyme, False, PS, <:NamedTuple{(:obj_fn, :st_wrap, :stats_wrap)}} where {PS}

function Lux.Training.compute_gradients(
function Lux.Training.compute_gradients_impl(
::AutoEnzyme, obj_fn::F, data, ts::TrainState{<:AUTODIFF_CACHE_TYPE, F}) where {F}
# dps = Lux.recursive_make_zero!!(ts.cache.dparameters)
Enzyme.make_zero!(ts.cache.dparameters)
dps = ts.cache.dparameters

@@ -36,7 +35,7 @@ function Lux.Training.compute_gradients(
return dps, loss, ts.cache.extras.stats_wrap[], ts
end

function Lux.Training.compute_gradients(ad::AutoEnzyme, obj_fn::F, data,
function Lux.Training.compute_gradients_impl(ad::AutoEnzyme, obj_fn::F, data,
ts::TrainState{<:TrainingBackendCache{<:AutoEnzyme, False}}) where {F}
@warn "Detected calls to `compute_gradients(::AutoEnzyme, ...)` with objective \
function that is changing across function calls. This can lead to the \
@@ -56,7 +55,7 @@ end
const AUTODIFF_THUNK_CACHE_TYPE = TrainingBackendCache{
<:AutoEnzyme, False, PS, <:NamedTuple{(:forward, :reverse)}} where {PS}

function Lux.Training.compute_gradients(::AutoEnzyme, obj_fn::F, data,
function Lux.Training.compute_gradients_impl(::AutoEnzyme, obj_fn::F, data,
ts::TrainState{<:AUTODIFF_THUNK_CACHE_TYPE, F}) where {F}
dps = Lux.recursive_make_zero!!(ts.cache.dparameters)
params = Duplicated(ts.parameters, dps)
14 changes: 14 additions & 0 deletions ext/LuxReactantExt/LuxReactantExt.jl
@@ -0,0 +1,14 @@
module LuxReactantExt

using Enzyme: Enzyme, Const, Duplicated, Active
using Optimisers: Optimisers
using Reactant: Reactant, @compile, TracedRArray
using Setfield: @set!
using Static: False

using Lux: Lux, LuxOps, Training
using Lux.Training: TrainingBackendCache, ReactantBackend

include("training.jl")

end
92 changes: 92 additions & 0 deletions ext/LuxReactantExt/training.jl
@@ -0,0 +1,92 @@
function Lux.Training.compute_gradients_impl(
backend::ReactantBackend, objective_function::F,
data, ts::Training.TrainState) where {F}
compiled_gradient_function = @compile compute_gradients_internal(
objective_function, ts.model, data, ts.parameters, ts.states)

grads, loss, stats, st = compiled_gradient_function(
objective_function, ts.model, data, ts.parameters, ts.states)

cache = TrainingBackendCache(backend, False(), nothing, (; compiled_gradient_function))
@set! ts.cache = cache
@set! ts.objective_function = objective_function
@set! ts.states = st
return grads, loss, stats, ts
end

function Lux.Training.compute_gradients_impl(::ReactantBackend, obj_fn::F, data,
ts::Training.TrainState{<:TrainingBackendCache{ReactantBackend}, F}) where {F}
grads, loss, stats, st = ts.cache.extras.compiled_gradient_function(
obj_fn, ts.model, data, ts.parameters, ts.states)
@set! ts.states = st
return grads, loss, stats, ts
end

function compute_gradients_internal(objective_function::F, model, data, ps, st) where {F}
dps = Enzyme.make_zero(ps)
_, (loss, stₙ, stats) = Enzyme.autodiff(
Enzyme.ReverseWithPrimal, Const(objective_function), Active, Const(model),
Duplicated(ps, dps), Const(st), Const(data))
return dps, loss, stats, stₙ
end

for inplace in ("!", "")
fname = Symbol(:single_train_step_impl, inplace)
internal_fn = Symbol(:compute_gradients_internal_and_step, inplace)

@eval function Lux.Training.$(fname)(backend::ReactantBackend, objective_function::F,
data, ts::Training.TrainState) where {F}
compiled_grad_and_step_function = @compile $(internal_fn)(
objective_function, ts.model, data, ts.parameters, ts.states,
ts.optimizer_state)

grads, ps, loss, stats, st, opt_state = compiled_grad_and_step_function(
objective_function, ts.model, data, ts.parameters, ts.states,
ts.optimizer_state)

cache = TrainingBackendCache(
backend, False(), nothing, (; compiled_grad_and_step_function))
@set! ts.cache = cache
@set! ts.objective_function = objective_function
@set! ts.states = st
@set! ts.parameters = ps
@set! ts.optimizer_state = opt_state
@set! ts.step = ts.step + 1

return grads, loss, stats, ts
end

@eval function Lux.Training.$(fname)(::ReactantBackend, obj_fn::F, data,
ts::Training.TrainState{<:TrainingBackendCache{ReactantBackend}, F}) where {F}
grads, ps, loss, stats, st, opt_state = ts.cache.extras.compiled_grad_and_step_function(
obj_fn, ts.model, data, ts.parameters, ts.states, ts.optimizer_state)

@set! ts.states = st
@set! ts.parameters = ps
@set! ts.optimizer_state = opt_state
@set! ts.step = ts.step + 1

return grads, loss, stats, ts
end
end

function compute_gradients_internal_and_step(objective_function::F, model, data, ps,
st, opt_state) where {F}
dps = Enzyme.make_zero(ps)
_, (loss, stₙ, stats) = Enzyme.autodiff(
Enzyme.ReverseWithPrimal, Const(objective_function), Active, Const(model),
Duplicated(ps, dps), Const(st), Const(data))
opt_state, ps = Optimisers.update(opt_state, ps, dps)
return dps, ps, loss, stats, stₙ, opt_state
end

function compute_gradients_internal_and_step!(objective_function::F, model, data, ps,
st, opt_state) where {F}
dps = Enzyme.make_zero(ps)
_, (loss, stₙ, stats) = Enzyme.autodiff(
Enzyme.ReverseWithPrimal, Const(objective_function), Active, Const(model),
Duplicated(ps, dps), Const(st), Const(data))
# XXX: Inplace updates not actually inplace
opt_state, ps = Optimisers.update!(opt_state, ps, dps)
return dps, ps, loss, stats, stₙ, opt_state
end
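
For orientation, the following is a minimal end-to-end usage sketch of the compiled training path added in this file. It is illustrative only: the toy model, the data, and the `xla_device()` helper are assumptions made for the example (the device constructor lives in MLDataDevices and may be named differently), and `AutoEnzyme`/`MSELoss` are assumed to be re-exported by Lux.

# Hedged usage sketch of the Reactant-compiled training loop (not part of this diff).
using Lux, Reactant, Enzyme, Optimisers, Random  # loading Reactant + Enzyme activates LuxReactantExt

dev = xla_device()  # hypothetical XLADevice constructor from MLDataDevices
model = Chain(Dense(2 => 8, tanh), Dense(8 => 2))
ps, st = Lux.setup(Random.default_rng(), model) |> dev
ts = Lux.Training.TrainState(model, ps, st, Adam(0.01f0))

x = rand(Float32, 2, 32) |> dev
y = rand(Float32, 2, 32) |> dev

# The first call traces and compiles the gradient + optimizer step and caches the
# compiled function in `ts.cache`; later calls with the same objective reuse it.
_, loss1, _, ts = Lux.Training.single_train_step!(AutoEnzyme(), MSELoss(), (x, y), ts)
_, loss2, _, ts = Lux.Training.single_train_step!(AutoEnzyme(), MSELoss(), (x, y), ts)

On an XLA device, `AutoEnzyme()` is rewritten internally to `ReactantBackend` (see `maybe_wrap_adtype` in `src/helpers/training.jl` below), so the methods in this file are the ones that actually run.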
10 changes: 5 additions & 5 deletions ext/LuxReverseDiffExt/training.jl
@@ -1,13 +1,13 @@
# Uncompiled ReverseDiff
function Lux.Training.compute_gradients(
function Lux.Training.compute_gradients_impl(
ad::AutoReverseDiff{false}, obj_fn::F, data, ts::TrainState) where {F}
@set! ts.cache = TrainingBackendCache(
ad, True(), Lux.recursive_make_zero(ts.parameters), nothing)
@set! ts.objective_function = obj_fn
return Lux.Training.compute_gradients(ad, obj_fn, data, ts)
end

function Lux.Training.compute_gradients(::AutoReverseDiff{false}, obj_fn::F, data,
function Lux.Training.compute_gradients_impl(::AutoReverseDiff{false}, obj_fn::F, data,
ts::TrainState{<:TrainingBackendCache{AutoReverseDiff{false}}}) where {F}
dparams = Training.dparameters(ts.cache)
tape = ReverseDiff.InstructionTape()
@@ -24,7 +24,7 @@ function Lux.Training.compute_gradients(::AutoReverseDiff{false}, obj_fn::F, dat
end

# Compiled ReverseDiff
function Lux.Training.compute_gradients(
function Lux.Training.compute_gradients_impl(
ad::AutoReverseDiff{true}, obj_fn::F, data, ts::TrainState) where {F}
@set! ts.cache = TrainingBackendCache(
ad, True(), Lux.recursive_make_zero(ts.parameters),
@@ -35,7 +35,7 @@ function Lux.Training.compute_gradients(
end

## Tape hasn't been compiled yet / Function mismatch so recompile
function Lux.Training.compute_gradients(ad::AutoReverseDiff{true}, obj_fn::F, data,
function Lux.Training.compute_gradients_impl(ad::AutoReverseDiff{true}, obj_fn::F, data,
ts::TrainState{<:TrainingBackendCache{AutoReverseDiff{true}}}) where {F}
if LuxCore.statelength(ts.states) != 0
throw(ArgumentError("AutoReverseDiff(; compile=true) is not supported for Lux \
@@ -82,7 +82,7 @@ function Lux.Training.compute_gradients(ad::AutoReverseDiff{true}, obj_fn::F, da
return dparams, ReverseDiff.value(loss), NamedTuple(), ts
end

function Lux.Training.compute_gradients(::AutoReverseDiff{true}, obj_fn::F, data,
function Lux.Training.compute_gradients_impl(::AutoReverseDiff{true}, obj_fn::F, data,
ts::TrainState{<:TrainingBackendCache{AutoReverseDiff{true}}, F}) where {F}
(; ps_cache, data_cache, output) = ts.cache.extras

4 changes: 2 additions & 2 deletions ext/LuxTrackerExt/training.jl
@@ -1,4 +1,4 @@
function Lux.Training.compute_gradients(::AutoTracker, obj_fn::F, data,
function Lux.Training.compute_gradients_impl(::AutoTracker, obj_fn::F, data,
ts::TrainState{<:TrainingBackendCache{AutoTracker}}) where {F}
dps = Training.dparameters(ts.cache)
ps_tracked = construct_tracked_params(ts.parameters, dps)
@@ -13,7 +13,7 @@ function Lux.Training.compute_gradients(::AutoTracker, obj_fn::F, data,
return dps, loss.data, stats, ts
end

function Lux.Training.compute_gradients(
function Lux.Training.compute_gradients_impl(
ad::AutoTracker, obj_fn::F, data, ts::TrainState) where {F}
grads = Lux.recursive_make_zero(ts.parameters)
cache = TrainingBackendCache(ad, True(), grads, nothing)
2 changes: 1 addition & 1 deletion ext/LuxZygoteExt/training.jl
@@ -1,4 +1,4 @@
function Lux.Training.compute_gradients(
function Lux.Training.compute_gradients_impl(
::AutoZygote, objective_function::F, data, ts::Lux.Training.TrainState) where {F}
(loss, st, stats), back = Zygote.pullback(
objective_function, ts.model, ts.parameters, ts.states, data)
5 changes: 3 additions & 2 deletions src/helpers/losses.jl
@@ -120,7 +120,8 @@ function huber_loss(x::T1, y::T2, δ::T3) where {T1, T2, T3}
T = promote_type(T1, T2, T3)
diff = x - y
abs_diff = abs(diff)
return ifelse(abs_diff ≤ δ, T(0.5) * abs2(diff), δ * (abs_diff - T(0.5) * δ))
return ifelse(
abs_diff ≤ δ, convert(T, 0.5) * abs2(diff), δ * (abs_diff - convert(T, 0.5) * δ))
end
has_custom_derivative(::typeof(huber_loss)) = true
function derivative(::typeof(huber_loss), x::T, y::T2, δ::T3) where {T, T2, T3}
@@ -148,7 +149,7 @@ function derivative(::typeof(l2_hinge_loss), x::T1, y::T2) where {T1, T2}
end

function siamese_contrastive_loss(x::T1, y::T2, margin=true) where {T1, T2}
return (true - y) * x^2 + y * max(promote_type(T1, T2)(false), margin - x)^2
return (true - y) * x^2 + y * max(convert(promote_type(T1, T2), false), margin - x)^2
end

poisson_loss(x::T1, y::T2, ϵ) where {T1, T2} = x - xlogy(y, x + get_ϵ(T1, ϵ))
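
For reference, the elementwise Huber loss touched above can be written as a standalone function as sketched below. This is a simplified illustration (the default `δ = 1` is an assumption for the example, not the library's signature); the `convert(T, 0.5)` form mirrors the change above and avoids calling the constructor of `T` on a literal, which is presumably what makes it friendlier to Reactant's traced scalar types.

# Simplified standalone sketch of the elementwise Huber loss (illustrative only).
function huber_elementwise(x::T1, y::T2, δ::T3=1) where {T1, T2, T3}
    T = promote_type(T1, T2, T3)
    abs_diff = abs(x - y)
    # Quadratic for small residuals, linear in the tails.
    return ifelse(abs_diff ≤ δ, convert(T, 0.5) * abs2(x - y),
        δ * (abs_diff - convert(T, 0.5) * δ))
end

huber_elementwise(0.5f0, 0.0f0, 1.0f0)  # 0.125f0 (quadratic branch)
huber_elementwise(2.0f0, 0.0f0, 1.0f0)  # 1.5f0   (linear branch)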
43 changes: 38 additions & 5 deletions src/helpers/training.jl
@@ -10,6 +10,7 @@ using Static: StaticBool, Static, False, True

using ..Lux: Lux
using LuxCore: LuxCore, AbstractLuxLayer
using MLDataDevices: XLADevice, get_device_type, get_device, cpu_device

"""
TrainState
@@ -61,7 +62,13 @@ Constructor for [`TrainState`](@ref).
[`TrainState`](@ref) object.
"""
function TrainState(model::AbstractLuxLayer, ps, st, optimizer::Optimisers.AbstractRule)
st_opt = Optimisers.setup(optimizer, ps)
dev = get_device(ps)
st_opt = if dev isa XLADevice
ps_cpu = ps |> cpu_device()
Optimisers.setup(optimizer, ps_cpu) |> dev
else
Optimisers.setup(optimizer, ps)
end
return TrainState(nothing, nothing, model, ps, st, optimizer, st_opt, 0)
end

@@ -96,6 +103,8 @@ function Base.show(io::IO, ::MIME"text/plain", ts::TrainState)
print(io, "\n objective_function: ", nameof(typeof(ts.objective_function)))
end

struct ReactantBackend end

const APPLY_GRAD_DOCSTRING = """
## Arguments

@@ -183,7 +192,20 @@ A 4-Tuple containing:
returned in step `i + 1` might be aliased by the old gradients. If you want to prevent
this, simply use `copy(grads)` or `deepcopy(grads)` to make a copy of the gradients.
"""
function compute_gradients(ad::AbstractADType, ::F, _, ::TrainState) where {F}
function compute_gradients(ad, obj_fn::F, data, ts::TrainState) where {F}
dev_type = get_device_type((ts.parameters, ts.states))
return compute_gradients_impl(maybe_wrap_adtype(ad, dev_type), obj_fn, data, ts)
end

maybe_wrap_adtype(backend::ReactantBackend, _) = backend
maybe_wrap_adtype(ad::AbstractADType, _) = ad
function maybe_wrap_adtype(ad::AbstractADType, ::Type{XLADevice})
ad isa AutoEnzyme && return ReactantBackend()
throw(ArgumentError("Computing gradients for models on XLA is supported only with \
Enzyme.jl (`AutoEnzyme`)."))
end

function compute_gradients_impl(ad, ::F, _, ts::TrainState) where {F}
return check_if_compute_gradients_implemented(ad)
end

@@ -192,6 +214,10 @@ function check_if_compute_gradients_implemented(::T) where {T <: AbstractADType}
yet!"))
end

function check_if_compute_gradients_implemented(::ReactantBackend)
throw(ArgumentError("Load `Reactant` with `using Reactant` before using this function!"))
end

for package in (:Zygote, :Tracker, :ReverseDiff, :Enzyme)
adtype = Symbol(:Auto, package)
msg = "Load `$(package)` with `using $(package)`/`import $(package)` before using this \
@@ -244,7 +270,10 @@ only the parameters in `ts` are updated inplace. Users should be using the retur
object for further training steps, else there is no caching and performance will be
suboptimal (and absolutely terrible for backends like `AutoReactant`).
"""
function single_train_step! end
function single_train_step!(backend, obj_fn::F, data, ts::TrainState) where {F}
backend = maybe_wrap_adtype(backend, get_device_type((ts.parameters, ts.states)))
return single_train_step_impl!(backend, obj_fn, data, ts)
end

"""
single_train_step(backend, obj_fn::F, data, ts::TrainState)
@@ -259,10 +288,14 @@ In most cases you should use [`single_train_step!`](@ref) instead of this functi

Returned values are the same as [`compute_gradients`](@ref).
"""
function single_train_step end
function single_train_step(backend, obj_fn::F, data, ts::TrainState) where {F}
backend = maybe_wrap_adtype(backend, get_device_type((ts.parameters, ts.states)))
return single_train_step_impl(backend, obj_fn, data, ts)
end

for inplace in ("!", "")
step, apply_fn = Symbol(:single_train_step, inplace), Symbol(:apply_gradients, inplace)
step = Symbol(:single_train_step_impl, inplace)
apply_fn = Symbol(:apply_gradients, inplace)
@eval function $(step)(backend, obj_fn::F, data, ts::TrainState) where {F}
grads, loss, stats, ts = compute_gradients(backend, obj_fn, data, ts)
ts = $(apply_fn)(ts, grads)
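
The net effect of the indirection above is that user-facing code does not change: `compute_gradients` and `single_train_step!`/`single_train_step` now only normalize the backend (rewriting `AutoEnzyme` to the internal `ReactantBackend` on XLA devices via `maybe_wrap_adtype`) and forward to the `*_impl` methods that the extensions overload. A minimal CPU sketch with Zygote, for contrast with the Reactant example earlier (the toy model and data are assumptions; `AutoZygote` and `MSELoss` are assumed to be re-exported by Lux):

# Hedged sketch: the unchanged public training API on CPU with Zygote.
using Lux, Zygote, Optimisers, Random

model = Dense(4 => 2)
ps, st = Lux.setup(Random.default_rng(), model)
ts = Lux.Training.TrainState(model, ps, st, Descent(0.01f0))
x, y = rand(Float32, 4, 16), rand(Float32, 2, 16)

# `maybe_wrap_adtype` is a no-op for CPU arrays, so this dispatches through
# `single_train_step_impl!` and eventually reaches `compute_gradients_impl(::AutoZygote, ...)`.
grads, loss, stats, ts = Lux.Training.single_train_step!(AutoZygote(), MSELoss(), (x, y), ts)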