diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml
index d211220eab..33565d6c20 100644
--- a/.github/workflows/CI.yml
+++ b/.github/workflows/CI.yml
@@ -42,6 +42,7 @@ jobs:
           - "recurrent_layers"
          - "eltype_match"
          - "fluxcompat"
+          - "reactant"
        include:
          - version: "1.10"
            os: macos-latest
diff --git a/.github/workflows/CIPreRelease.yml b/.github/workflows/CIPreRelease.yml
index 6b378dc39d..d0a958f543 100644
--- a/.github/workflows/CIPreRelease.yml
+++ b/.github/workflows/CIPreRelease.yml
@@ -32,16 +32,17 @@ jobs:
         os:
          - ubuntu-latest
        test_group:
-          - "core_layers"
-          - "contrib"
-          - "helpers"
-          - "distributed"
-          - "normalize_layers"
-          - "others"
-          - "autodiff"
-          - "recurrent_layers"
-          - "eltype_match"
-          - "fluxcompat"
+          # - "core_layers"
+          # - "contrib"
+          # - "helpers"
+          # - "distributed"
+          # - "normalize_layers"
+          # - "others"
+          # - "autodiff"
+          # - "recurrent_layers"
+          # - "eltype_match"
+          # - "fluxcompat"
+          - "reactant"
    steps:
      - uses: actions/checkout@v4
      - uses: julia-actions/setup-julia@v2
diff --git a/Project.toml b/Project.toml
index 00435a5164..720bcf9fb1 100644
--- a/Project.toml
+++ b/Project.toml
@@ -46,6 +46,7 @@ LossFunctions = "30fc2ffe-d236-52d8-8643-a9d8f7c094a7"
 MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
 MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"
 NCCL = "3fe64909-d7a1-4096-9b7d-7a0f12cf0f6b"
+Reactant = "3c362404-f566-11ee-1572-e11a4b42c853"
 ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
 SimpleChains = "de6bee2f-e2f4-4ec7-b6ed-219cc6f6e9e5"
 Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
@@ -59,6 +60,7 @@ LuxLossFunctionsExt = "LossFunctions"
 LuxMLUtilsExt = "MLUtils"
 LuxMPIExt = "MPI"
 LuxMPINCCLExt = ["CUDA", "MPI", "NCCL"]
+LuxReactantExt = ["Enzyme", "Reactant"]
 LuxReverseDiffExt = ["FunctionWrappers", "ReverseDiff"]
 LuxSimpleChainsExt = "SimpleChains"
 LuxTrackerExt = "Tracker"
@@ -68,7 +70,7 @@ LuxZygoteExt = "Zygote"
 ADTypes = "1.8.1"
 Adapt = "4"
 ArgCheck = "2.3"
-ArrayInterface = "7.9"
+ArrayInterface = "7.10"
 CUDA = "5.3.2"
 ChainRulesCore = "1.24"
 Compat = "4.15"
@@ -87,7 +89,7 @@ LinearAlgebra = "1.10"
 LossFunctions = "0.11.1"
 LuxCore = "1"
 LuxLib = "1.3"
-MLDataDevices = "1.1"
+MLDataDevices = "1.2"
 MLUtils = "0.4.4"
 MPI = "0.20.19"
 MacroTools = "0.5.13"
@@ -97,6 +99,7 @@ NNlib = "0.9.24"
 Optimisers = "0.3.3"
 Preferences = "1.4.3"
 Random = "1.10"
+Reactant = "0.2.3"
 Reexport = "1.2.2"
 ReverseDiff = "1.15"
 SIMDTypes = "0.1"
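
Reactant support ships as a package extension: `Project.toml` lists `Reactant` among the weak dependencies and maps `LuxReactantExt` to `["Enzyme", "Reactant"]`, so the extension only activates once both packages are loaded alongside Lux. A quick illustrative session showing when the extension becomes available (`Base.get_extension` is standard Julia; the session itself is not part of this patch):

using Lux                                              # weak deps not loaded yet
Base.get_extension(Lux, :LuxReactantExt) === nothing   # true: extension inactive

using Enzyme, Reactant                                 # both triggers now present
Base.get_extension(Lux, :LuxReactantExt) isa Module    # true: LuxReactantExt is loaded
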
diff --git a/ext/LuxEnzymeExt/training.jl b/ext/LuxEnzymeExt/training.jl
index 410b9f11ef..3718379bf1 100644
--- a/ext/LuxEnzymeExt/training.jl
+++ b/ext/LuxEnzymeExt/training.jl
@@ -1,4 +1,4 @@
-function Lux.Training.compute_gradients(
+function Lux.Training.compute_gradients_impl(
         ad::AutoEnzyme, obj_fn::F, data, ts::TrainState) where {F}
     dps = Lux.recursive_make_zero(ts.parameters)

@@ -20,9 +20,8 @@ end

 const AUTODIFF_CACHE_TYPE = TrainingBackendCache{
     <:AutoEnzyme, False, PS, <:NamedTuple{(:obj_fn, :st_wrap, :stats_wrap)}} where {PS}

-function Lux.Training.compute_gradients(
+function Lux.Training.compute_gradients_impl(
         ::AutoEnzyme, obj_fn::F, data, ts::TrainState{<:AUTODIFF_CACHE_TYPE, F}) where {F}
-    # dps = Lux.recursive_make_zero!!(ts.cache.dparameters)
     Enzyme.make_zero!(ts.cache.dparameters)
     dps = ts.cache.dparameters
@@ -36,7 +35,7 @@ function Lux.Training.compute_gradients(
     return dps, loss, ts.cache.extras.stats_wrap[], ts
 end

-function Lux.Training.compute_gradients(ad::AutoEnzyme, obj_fn::F, data,
+function Lux.Training.compute_gradients_impl(ad::AutoEnzyme, obj_fn::F, data,
         ts::TrainState{<:TrainingBackendCache{<:AutoEnzyme, False}}) where {F}
     @warn "Detected calls to `compute_gradients(::AutoEnzyme, ...)` with objective \
            function that is changing across function calls. This can lead to the \
@@ -56,7 +55,7 @@ end
 const AUTODIFF_THUNK_CACHE_TYPE = TrainingBackendCache{
     <:AutoEnzyme, False, PS, <:NamedTuple{(:forward, :reverse)}} where {PS}

-function Lux.Training.compute_gradients(::AutoEnzyme, obj_fn::F, data,
+function Lux.Training.compute_gradients_impl(::AutoEnzyme, obj_fn::F, data,
         ts::TrainState{<:AUTODIFF_THUNK_CACHE_TYPE, F}) where {F}
     dps = Lux.recursive_make_zero!!(ts.cache.dparameters)
     params = Duplicated(ts.parameters, dps)
diff --git a/ext/LuxReactantExt/LuxReactantExt.jl b/ext/LuxReactantExt/LuxReactantExt.jl
new file mode 100644
index 0000000000..ce0e0cd062
--- /dev/null
+++ b/ext/LuxReactantExt/LuxReactantExt.jl
@@ -0,0 +1,14 @@
+module LuxReactantExt
+
+using Enzyme: Enzyme, Const, Duplicated, Active
+using Optimisers: Optimisers
+using Reactant: Reactant, @compile, TracedRArray
+using Setfield: @set!
+using Static: False
+
+using Lux: Lux, LuxOps, Training
+using Lux.Training: TrainingBackendCache, ReactantBackend
+
+include("training.jl")
+
+end
diff --git a/ext/LuxReactantExt/training.jl b/ext/LuxReactantExt/training.jl
new file mode 100644
index 0000000000..182ca9c86d
--- /dev/null
+++ b/ext/LuxReactantExt/training.jl
@@ -0,0 +1,92 @@
+function Lux.Training.compute_gradients_impl(
+        backend::ReactantBackend, objective_function::F,
+        data, ts::Training.TrainState) where {F}
+    compiled_gradient_function = @compile compute_gradients_internal(
+        objective_function, ts.model, data, ts.parameters, ts.states)
+
+    grads, loss, stats, st = compiled_gradient_function(
+        objective_function, ts.model, data, ts.parameters, ts.states)
+
+    cache = TrainingBackendCache(backend, False(), nothing, (; compiled_gradient_function))
+    @set! ts.cache = cache
+    @set! ts.objective_function = objective_function
+    @set! ts.states = st
+    return grads, loss, stats, ts
+end
+
+function Lux.Training.compute_gradients_impl(::ReactantBackend, obj_fn::F, data,
+        ts::Training.TrainState{<:TrainingBackendCache{ReactantBackend}, F}) where {F}
+    grads, loss, stats, st = ts.cache.extras.compiled_gradient_function(
+        obj_fn, ts.model, data, ts.parameters, ts.states)
+    @set! ts.states = st
+    return grads, loss, stats, ts
+end
+
+function compute_gradients_internal(objective_function::F, model, data, ps, st) where {F}
+    dps = Enzyme.make_zero(ps)
+    _, (loss, stₙ, stats) = Enzyme.autodiff(
+        Enzyme.ReverseWithPrimal, Const(objective_function), Active, Const(model),
+        Duplicated(ps, dps), Const(st), Const(data))
+    return dps, loss, stats, stₙ
+end
+
+for inplace in ("!", "")
+    fname = Symbol(:single_train_step_impl, inplace)
+    internal_fn = Symbol(:compute_gradients_internal_and_step, inplace)
+
+    @eval function Lux.Training.$(fname)(backend::ReactantBackend, objective_function::F,
+            data, ts::Training.TrainState) where {F}
+        compiled_grad_and_step_function = @compile $(internal_fn)(
+            objective_function, ts.model, data, ts.parameters, ts.states,
+            ts.optimizer_state)
+
+        grads, ps, loss, stats, st, opt_state = compiled_grad_and_step_function(
+            objective_function, ts.model, data, ts.parameters, ts.states,
+            ts.optimizer_state)
+
+        cache = TrainingBackendCache(
+            backend, False(), nothing, (; compiled_grad_and_step_function))
+        @set! ts.cache = cache
+        @set! ts.objective_function = objective_function
+        @set! ts.states = st
+        @set! ts.parameters = ps
+        @set! ts.optimizer_state = opt_state
+        @set! ts.step = ts.step + 1
+
+        return grads, loss, stats, ts
+    end
+
+    @eval function Lux.Training.$(fname)(::ReactantBackend, obj_fn::F, data,
+            ts::Training.TrainState{<:TrainingBackendCache{ReactantBackend}, F}) where {F}
+        grads, ps, loss, stats, st, opt_state = ts.cache.extras.compiled_grad_and_step_function(
+            obj_fn, ts.model, data, ts.parameters, ts.states, ts.optimizer_state)
+
+        @set! ts.states = st
+        @set! ts.parameters = ps
+        @set! ts.optimizer_state = opt_state
+        @set! ts.step = ts.step + 1
+
+        return grads, loss, stats, ts
+    end
+end
+
+function compute_gradients_internal_and_step(objective_function::F, model, data, ps,
+        st, opt_state) where {F}
+    dps = Enzyme.make_zero(ps)
+    _, (loss, stₙ, stats) = Enzyme.autodiff(
+        Enzyme.ReverseWithPrimal, Const(objective_function), Active, Const(model),
+        Duplicated(ps, dps), Const(st), Const(data))
+    opt_state, ps = Optimisers.update(opt_state, ps, dps)
+    return dps, ps, loss, stats, stₙ, opt_state
+end
+
+function compute_gradients_internal_and_step!(objective_function::F, model, data, ps,
+        st, opt_state) where {F}
+    dps = Enzyme.make_zero(ps)
+    _, (loss, stₙ, stats) = Enzyme.autodiff(
+        Enzyme.ReverseWithPrimal, Const(objective_function), Active, Const(model),
+        Duplicated(ps, dps), Const(st), Const(data))
+    # XXX: Inplace updates not actually inplace
+    opt_state, ps = Optimisers.update!(opt_state, ps, dps)
+    return dps, ps, loss, stats, stₙ, opt_state
+end
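
The extension above compiles the Enzyme gradient (and, for `single_train_step`, the optimizer update as well) once with `@compile` and stashes the compiled callable inside the `TrainingBackendCache`; later calls with the same objective function reuse it. A usage sketch distilled from `test/reactant/training_tests.jl` further down — the model, data, and learning rate are placeholders, not part of the patch:

using Lux, Enzyme, Reactant, Optimisers, Random
using ADTypes: AutoEnzyme

model = Chain(Dense(2 => 32, gelu), Dense(32 => 2))
xdev = xla_device()                        # XLA device from MLDataDevices (re-exported by Lux)
ps, st = Lux.setup(Random.default_rng(), model) |> xdev

ts = Training.TrainState(model, ps, st, Adam(0.01f0))
x, y = randn(Float32, 2, 32) |> xdev, randn(Float32, 2, 32) |> xdev

# First call: AutoEnzyme gets rewritten to the internal ReactantBackend (the parameters
# live on XLA), the step is traced and compiled, and the compiled function is cached.
_, loss1, _, ts = Training.single_train_step!(AutoEnzyme(), MSELoss(), (x, y), ts)

# Subsequent calls with the same objective function dispatch to the cached compiled step.
_, loss2, _, ts = Training.single_train_step!(AutoEnzyme(), MSELoss(), (x, y), ts)
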
diff --git a/ext/LuxReverseDiffExt/training.jl b/ext/LuxReverseDiffExt/training.jl
index 33bf01eb99..6deaf63788 100644
--- a/ext/LuxReverseDiffExt/training.jl
+++ b/ext/LuxReverseDiffExt/training.jl
@@ -1,5 +1,5 @@
 # Uncompiled ReverseDiff
-function Lux.Training.compute_gradients(
+function Lux.Training.compute_gradients_impl(
         ad::AutoReverseDiff{false}, obj_fn::F, data, ts::TrainState) where {F}
     @set! ts.cache = TrainingBackendCache(
         ad, True(), Lux.recursive_make_zero(ts.parameters), nothing)
@@ -7,7 +7,7 @@ function Lux.Training.compute_gradients(
     return Lux.Training.compute_gradients(ad, obj_fn, data, ts)
 end

-function Lux.Training.compute_gradients(::AutoReverseDiff{false}, obj_fn::F, data,
+function Lux.Training.compute_gradients_impl(::AutoReverseDiff{false}, obj_fn::F, data,
         ts::TrainState{<:TrainingBackendCache{AutoReverseDiff{false}}}) where {F}
     dparams = Training.dparameters(ts.cache)
     tape = ReverseDiff.InstructionTape()
@@ -24,7 +24,7 @@ function Lux.Training.compute_gradients(::AutoReverseDiff{false}, obj_fn::F, dat
 end
 # Compiled ReverseDiff
-function Lux.Training.compute_gradients(
+function Lux.Training.compute_gradients_impl(
         ad::AutoReverseDiff{true}, obj_fn::F, data, ts::TrainState) where {F}
     @set! ts.cache = TrainingBackendCache(
         ad, True(), Lux.recursive_make_zero(ts.parameters),
@@ -35,7 +35,7 @@ function Lux.Training.compute_gradients(
 end
 ## Tape hasn't been compiled yet / Function mismatch so recompile
-function Lux.Training.compute_gradients(ad::AutoReverseDiff{true}, obj_fn::F, data,
+function Lux.Training.compute_gradients_impl(ad::AutoReverseDiff{true}, obj_fn::F, data,
         ts::TrainState{<:TrainingBackendCache{AutoReverseDiff{true}}}) where {F}
     if LuxCore.statelength(ts.states) != 0
         throw(ArgumentError("AutoReverseDiff(; compile=true) is not supported for Lux \
@@ -82,7 +82,7 @@ function Lux.Training.compute_gradients(ad::AutoReverseDiff{true}, obj_fn::F, da
     return dparams, ReverseDiff.value(loss), NamedTuple(), ts
 end

-function Lux.Training.compute_gradients(::AutoReverseDiff{true}, obj_fn::F, data,
+function Lux.Training.compute_gradients_impl(::AutoReverseDiff{true}, obj_fn::F, data,
         ts::TrainState{<:TrainingBackendCache{AutoReverseDiff{true}}, F}) where {F}
     (; ps_cache, data_cache, output) = ts.cache.extras
diff --git a/ext/LuxTrackerExt/training.jl b/ext/LuxTrackerExt/training.jl
index 0e0880b416..982d708f94 100644
--- a/ext/LuxTrackerExt/training.jl
+++ b/ext/LuxTrackerExt/training.jl
@@ -1,4 +1,4 @@
-function Lux.Training.compute_gradients(::AutoTracker, obj_fn::F, data,
+function Lux.Training.compute_gradients_impl(::AutoTracker, obj_fn::F, data,
         ts::TrainState{<:TrainingBackendCache{AutoTracker}}) where {F}
     dps = Training.dparameters(ts.cache)
     ps_tracked = construct_tracked_params(ts.parameters, dps)
@@ -13,7 +13,7 @@ function Lux.Training.compute_gradients(::AutoTracker, obj_fn::F, data,
     return dps, loss.data, stats, ts
 end

-function Lux.Training.compute_gradients(
+function Lux.Training.compute_gradients_impl(
         ad::AutoTracker, obj_fn::F, data, ts::TrainState) where {F}
     grads = Lux.recursive_make_zero(ts.parameters)
     cache = TrainingBackendCache(ad, True(), grads, nothing)
diff --git a/ext/LuxZygoteExt/training.jl b/ext/LuxZygoteExt/training.jl
index 3832800cd3..83d999afa0 100644
--- a/ext/LuxZygoteExt/training.jl
+++ b/ext/LuxZygoteExt/training.jl
@@ -1,4 +1,4 @@
-function Lux.Training.compute_gradients(
+function Lux.Training.compute_gradients_impl(
         ::AutoZygote, objective_function::F, data, ts::Lux.Training.TrainState) where {F}
     (loss, st, stats), back = Zygote.pullback(
         objective_function, ts.model, ts.parameters, ts.states, data)
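
Across all of these AD backends only the overloaded function changes: the public entry point remains `Training.compute_gradients`, which (per the `src/helpers/training.jl` hunks below) now wraps the AD type for the current device and forwards to `compute_gradients_impl`. Existing user code is unaffected; a short sketch with the Zygote backend on CPU, purely for illustration (model, data, and optimizer are placeholders):

using Lux, Zygote, Optimisers, Random
using ADTypes: AutoZygote

model = Dense(4 => 4, tanh)
ps, st = Lux.setup(Random.default_rng(), model)
ts = Training.TrainState(model, ps, st, Descent(0.1f0))
data = (randn(Float32, 4, 8), randn(Float32, 4, 8))

# Same call as before the rename; dispatch now flows through
# compute_gradients -> maybe_wrap_adtype -> compute_gradients_impl(::AutoZygote, ...).
grads, loss, stats, ts = Training.compute_gradients(AutoZygote(), MSELoss(), data, ts)
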
diff --git a/src/helpers/losses.jl b/src/helpers/losses.jl
index 1f38c36eab..1af021e4cd 100644
--- a/src/helpers/losses.jl
+++ b/src/helpers/losses.jl
@@ -120,7 +120,8 @@ function huber_loss(x::T1, y::T2, δ::T3) where {T1, T2, T3}
     T = promote_type(T1, T2, T3)
     diff = x - y
     abs_diff = abs(diff)
-    return ifelse(abs_diff ≤ δ, T(0.5) * abs2(diff), δ * (abs_diff - T(0.5) * δ))
+    return ifelse(
+        abs_diff ≤ δ, convert(T, 0.5) * abs2(diff), δ * (abs_diff - convert(T, 0.5) * δ))
 end
 has_custom_derivative(::typeof(huber_loss)) = true
 function derivative(::typeof(huber_loss), x::T, y::T2, δ::T3) where {T, T2, T3}
@@ -148,7 +149,7 @@ function derivative(::typeof(l2_hinge_loss), x::T1, y::T2) where {T1, T2}
 end

 function siamese_contrastive_loss(x::T1, y::T2, margin=true) where {T1, T2}
-    return (true - y) * x^2 + y * max(promote_type(T1, T2)(false), margin - x)^2
+    return (true - y) * x^2 + y * max(convert(promote_type(T1, T2), false), margin - x)^2
 end

 poisson_loss(x::T1, y::T2, ϵ) where {T1, T2} = x - xlogy(y, x + get_ϵ(T1, ϵ))
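
The `huber_loss` and `siamese_contrastive_loss` changes swap constructor-style calls like `T(0.5)` for `convert(T, 0.5)`, which also works when the promoted type is a wrapped/traced element type; the numerics are unchanged. A quick worked check of the two Huber branches, restating the formula above in a standalone helper (not the internal Lux function):

huber(x, y, δ) = abs(x - y) ≤ δ ? 0.5 * abs2(x - y) : δ * (abs(x - y) - 0.5 * δ)

huber(1.2, 1.0, 1.0)  # |diff| = 0.2 ≤ δ  ->  0.5 * 0.2^2       = 0.02
huber(3.0, 1.0, 1.0)  # |diff| = 2.0 > δ  ->  1.0 * (2.0 - 0.5) = 1.5
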
diff --git a/src/helpers/training.jl b/src/helpers/training.jl
index 51fdb1a48a..c0e6644ffe 100644
--- a/src/helpers/training.jl
+++ b/src/helpers/training.jl
@@ -10,6 +10,7 @@ using Static: StaticBool, Static, False, True

 using ..Lux: Lux
 using LuxCore: LuxCore, AbstractLuxLayer
+using MLDataDevices: XLADevice, get_device_type, get_device, cpu_device

 """
     TrainState
@@ -61,7 +62,13 @@ Constructor for [`TrainState`](@ref).
 [`TrainState`](@ref) object.
 """

 function TrainState(model::AbstractLuxLayer, ps, st, optimizer::Optimisers.AbstractRule)
-    st_opt = Optimisers.setup(optimizer, ps)
+    dev = get_device(ps)
+    st_opt = if dev isa XLADevice
+        ps_cpu = ps |> cpu_device()
+        Optimisers.setup(optimizer, ps_cpu) |> dev
+    else
+        Optimisers.setup(optimizer, ps)
+    end
     return TrainState(nothing, nothing, model, ps, st, optimizer, st_opt, 0)
 end
@@ -96,6 +103,8 @@ function Base.show(io::IO, ::MIME"text/plain", ts::TrainState)
     print(io, "\n    objective_function: ", nameof(typeof(ts.objective_function)))
 end

+struct ReactantBackend end
+
 const APPLY_GRAD_DOCSTRING = """

 ## Arguments
@@ -183,7 +192,20 @@ A 4-Tuple containing:
     returned in step `i + 1` might be aliased by the old gradients. If you want to
     prevent this, simply use `copy(grads)` or `deepcopy(grads)` to make a copy of the
     gradients.
 """
-function compute_gradients(ad::AbstractADType, ::F, _, ::TrainState) where {F}
+function compute_gradients(ad, obj_fn::F, data, ts::TrainState) where {F}
+    dev_type = get_device_type((ts.parameters, ts.states))
+    return compute_gradients_impl(maybe_wrap_adtype(ad, dev_type), obj_fn, data, ts)
+end
+
+maybe_wrap_adtype(backend::ReactantBackend, _) = backend
+maybe_wrap_adtype(ad::AbstractADType, _) = ad
+function maybe_wrap_adtype(ad::AbstractADType, ::Type{XLADevice})
+    ad isa AutoEnzyme && return ReactantBackend()
+    throw(ArgumentError("Computing gradients for models on XLA is supported only with \
+                         Enzyme.jl (`AutoEnzyme`)."))
+end
+
+function compute_gradients_impl(ad, ::F, _, ts::TrainState) where {F}
     return check_if_compute_gradients_implemented(ad)
 end
@@ -192,6 +214,10 @@ function check_if_compute_gradients_implemented(::T) where {T <: AbstractADType}
                         yet!"))
 end

+function check_if_compute_gradients_implemented(::ReactantBackend)
+    throw(ArgumentError("Load `Reactant` with `using Reactant` before using this function!"))
+end
+
 for package in (:Zygote, :Tracker, :ReverseDiff, :Enzyme)
     adtype = Symbol(:Auto, package)
     msg = "Load `$(package)` with `using $(package)`/`import $(package)` before using this \
@@ -244,7 +270,10 @@ only the parameters in `ts` are updated inplace. Users should be using the retur
 object for further training steps, else there is no caching and performance will be
 suboptimal (and absolutely terrible for backends like `AutoReactant`).
 """
-function single_train_step! end
+function single_train_step!(backend, obj_fn::F, data, ts::TrainState) where {F}
+    backend = maybe_wrap_adtype(backend, get_device_type((ts.parameters, ts.states)))
+    return single_train_step_impl!(backend, obj_fn, data, ts)
+end

 """
     single_train_step(backend, obj_fn::F, data, ts::TrainState)
@@ -259,10 +288,14 @@ In most cases you should use [`single_train_step!`](@ref) instead of this functi

 Returned values are the same as [`compute_gradients`](@ref).
 """
-function single_train_step end
+function single_train_step(backend, obj_fn::F, data, ts::TrainState) where {F}
+    backend = maybe_wrap_adtype(backend, get_device_type((ts.parameters, ts.states)))
+    return single_train_step_impl(backend, obj_fn, data, ts)
+end

 for inplace in ("!", "")
-    step, apply_fn = Symbol(:single_train_step, inplace), Symbol(:apply_gradients, inplace)
+    step = Symbol(:single_train_step_impl, inplace)
+    apply_fn = Symbol(:apply_gradients, inplace)
     @eval function $(step)(backend, obj_fn::F, data, ts::TrainState) where {F}
         grads, loss, stats, ts = compute_gradients(backend, obj_fn, data, ts)
         ts = $(apply_fn)(ts, grads)
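
The net effect of `maybe_wrap_adtype` is that the Reactant path is opt-in by device rather than by a new AD type: when the parameters/states live on an `XLADevice`, `AutoEnzyme` is silently upgraded to the internal `ReactantBackend`, while every other AD type is rejected. Roughly, using the (unexported) functions introduced in this hunk and assuming the `CPUDevice`/`XLADevice` types from MLDataDevices are in scope — illustrative only:

# On CPU/GPU devices the AD type passes through untouched:
Lux.Training.maybe_wrap_adtype(AutoZygote(), CPUDevice)   # -> AutoZygote()

# On XLA only Enzyme is accepted, and it gets wrapped:
Lux.Training.maybe_wrap_adtype(AutoEnzyme(), XLADevice)   # -> Lux.Training.ReactantBackend()
Lux.Training.maybe_wrap_adtype(AutoZygote(), XLADevice)   # -> throws ArgumentError
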
""" -function single_train_step end +function single_train_step(backend, obj_fn::F, data, ts::TrainState) where {F} + backend = maybe_wrap_adtype(backend, get_device_type((ts.parameters, ts.states))) + return single_train_step_impl(backend, obj_fn, data, ts) +end for inplace in ("!", "") - step, apply_fn = Symbol(:single_train_step, inplace), Symbol(:apply_gradients, inplace) + step = Symbol(:single_train_step_impl, inplace) + apply_fn = Symbol(:apply_gradients, inplace) @eval function $(step)(backend, obj_fn::F, data, ts::TrainState) where {F} grads, loss, stats, ts = compute_gradients(backend, obj_fn, data, ts) ts = $(apply_fn)(ts, grads) diff --git a/src/utils.jl b/src/utils.jl index 1e2929dc7e..8de8408ca0 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -86,7 +86,7 @@ function pairs(x) return Base.pairs(x) end -@concrete struct Fix3 +@concrete struct Fix3 <: Function f x end diff --git a/test/reactant/loss_tests.jl b/test/reactant/loss_tests.jl new file mode 100644 index 0000000000..a98bf1a716 --- /dev/null +++ b/test/reactant/loss_tests.jl @@ -0,0 +1,240 @@ +@testitem "Compiled Loss Functions" tags=[:reactant] setup=[SharedTestSetup] begin + using Reactant, Lux, OneHotArrays + + rng = StableRNG(123) + + @testset "$(mode)" for (mode, atype, dev, ongpu) in MODES + if mode == "amdgpu" + @warn "Skipping AMDGPU tests for Reactant" + continue + end + + if ongpu + Reactant.set_default_backend("gpu") + else + Reactant.set_default_backend("cpu") + end + + @testset "xlogx & xlogy" begin + x = rand(rng, 10) + y = rand(rng, 10) + x_ra = Reactant.to_rarray(x) + y_ra = Reactant.to_rarray(y) + + fn1(x) = LuxOps.xlogx.(x) + fn2(x, y) = LuxOps.xlogy.(x, y) + + fn1_compiled = @compile fn1(x_ra) + @test fn1(x) ≈ fn1_compiled(x_ra) + + fn2_compiled = @compile fn2(x_ra, y_ra) + @test fn2(x, y) ≈ fn2_compiled(x_ra, y_ra) + end + + @testset "Regression Loss" begin + y = [1.0, 1.0, 0.0, 0.0] + ŷ = [0.9, 0.1, 0.1, 0.9] + + y_ra = Reactant.to_rarray(y) + ŷ_ra = Reactant.to_rarray(ŷ) + + @testset for loss in ("MSE", "MAE", "Huber") + loss_mean = eval(Symbol(loss * "Loss"))() + loss_sum = eval(Symbol(loss * "Loss"))(; agg=sum) + loss_sum2 = eval(Symbol(loss * "Loss"))(; agg=(args...) 
diff --git a/test/reactant/loss_tests.jl b/test/reactant/loss_tests.jl
new file mode 100644
index 0000000000..a98bf1a716
--- /dev/null
+++ b/test/reactant/loss_tests.jl
@@ -0,0 +1,240 @@
+@testitem "Compiled Loss Functions" tags=[:reactant] setup=[SharedTestSetup] begin
+    using Reactant, Lux, OneHotArrays
+
+    rng = StableRNG(123)
+
+    @testset "$(mode)" for (mode, atype, dev, ongpu) in MODES
+        if mode == "amdgpu"
+            @warn "Skipping AMDGPU tests for Reactant"
+            continue
+        end
+
+        if ongpu
+            Reactant.set_default_backend("gpu")
+        else
+            Reactant.set_default_backend("cpu")
+        end
+
+        @testset "xlogx & xlogy" begin
+            x = rand(rng, 10)
+            y = rand(rng, 10)
+            x_ra = Reactant.to_rarray(x)
+            y_ra = Reactant.to_rarray(y)
+
+            fn1(x) = LuxOps.xlogx.(x)
+            fn2(x, y) = LuxOps.xlogy.(x, y)
+
+            fn1_compiled = @compile fn1(x_ra)
+            @test fn1(x) ≈ fn1_compiled(x_ra)
+
+            fn2_compiled = @compile fn2(x_ra, y_ra)
+            @test fn2(x, y) ≈ fn2_compiled(x_ra, y_ra)
+        end
+
+        @testset "Regression Loss" begin
+            y = [1.0, 1.0, 0.0, 0.0]
+            ŷ = [0.9, 0.1, 0.1, 0.9]
+
+            y_ra = Reactant.to_rarray(y)
+            ŷ_ra = Reactant.to_rarray(ŷ)
+
+            @testset for loss in ("MSE", "MAE", "Huber")
+                loss_mean = eval(Symbol(loss * "Loss"))()
+                loss_sum = eval(Symbol(loss * "Loss"))(; agg=sum)
+                loss_sum2 = eval(Symbol(loss * "Loss"))(; agg=(args...) -> sum(args...))
+
+                loss_mean_compiled = @compile loss_mean(ŷ_ra, y_ra)
+                @test loss_mean(ŷ, y) ≈ loss_mean_compiled(ŷ_ra, y_ra)
+
+                loss_sum_compiled = @compile loss_sum(ŷ_ra, y_ra)
+                @test loss_sum(ŷ, y) ≈ loss_sum_compiled(ŷ_ra, y_ra)
+
+                loss_sum2_compiled = @compile loss_sum2(ŷ_ra, y_ra)
+                @test loss_sum2(ŷ, y) ≈ loss_sum2_compiled(ŷ_ra, y_ra)
+            end
+
+            @testset "MSLE" begin
+                y = [123.0, 456.0, 789.0]
+                ŷ = [345.0, 332.0, 789.0]
+
+                y_ra = Reactant.to_rarray(y)
+                ŷ_ra = Reactant.to_rarray(ŷ)
+
+                loss_msle = MSLELoss()
+                loss_msle_compiled = @compile loss_msle(ŷ_ra, y_ra)
+                @test loss_msle(ŷ, y) ≈ loss_msle_compiled(ŷ_ra, y_ra)
+            end
+        end
+
+        @testset "Classification Loss" begin
+            y = onehotbatch([1, 1, 0, 0], 0:1) |> Array
+            ŷ = [0.1 0.9; 0.9 0.1; 0.9 0.1; 0.1 0.9]' |> Array
+
+            y_ra = Reactant.to_rarray(y)
+            ŷ_ra = Reactant.to_rarray(ŷ)
+
+            @testset "CrossEntropyLoss" begin
+                celoss = CrossEntropyLoss()
+                celoss_compiled = @compile celoss(ŷ_ra, y_ra)
+                @test celoss(ŷ, y) ≈ celoss_compiled(ŷ_ra, y_ra)
+
+                celoss_ls = CrossEntropyLoss(; label_smoothing=0.1)
+                celoss_ls_compiled = @compile celoss_ls(ŷ_ra, y_ra)
+                @test celoss_ls(ŷ, y) ≈ celoss_ls_compiled(ŷ_ra, y_ra)
+
+                celoss_lp = CrossEntropyLoss(; logits=Val(true))
+                celoss_lp_compiled = @compile celoss_lp(log.(ŷ_ra), y_ra)
+                @test celoss_lp(log.(ŷ), y) ≈ celoss_lp_compiled(log.(ŷ_ra), y_ra)
+
+                celoss_lp_ls = CrossEntropyLoss(; logits=Val(true), label_smoothing=0.1)
+                celoss_lp_ls_compiled = @compile celoss_lp_ls(log.(ŷ_ra), y_ra)
+                @test celoss_lp_ls(log.(ŷ), y) ≈ celoss_lp_ls_compiled(log.(ŷ_ra), y_ra)
+            end
+
+            @testset "Binary CrossEntropyLoss" begin
+                bceloss = BinaryCrossEntropyLoss()
+                bceloss_compiled = @compile bceloss(ŷ_ra, y_ra)
+                @test bceloss(ŷ, y) ≈ bceloss_compiled(ŷ_ra, y_ra)
+
+                bceloss_ls = BinaryCrossEntropyLoss(; label_smoothing=0.1)
+                bceloss_ls_compiled = @compile bceloss_ls(ŷ_ra, y_ra)
+                @test bceloss_ls(ŷ, y) ≈ bceloss_ls_compiled(ŷ_ra, y_ra)
+
+                bceloss_lp = BinaryCrossEntropyLoss(; logits=Val(true))
+                bceloss_lp_compiled = @compile bceloss_lp(log.(ŷ_ra), y_ra)
+                @test bceloss_lp(log.(ŷ), y) ≈ bceloss_lp_compiled(log.(ŷ_ra), y_ra)
+
+                bceloss_lp_ls = BinaryCrossEntropyLoss(;
+                    logits=Val(true), label_smoothing=0.1)
+                bceloss_lp_ls_compiled = @compile bceloss_lp_ls(log.(ŷ_ra), y_ra)
+                @test bceloss_lp_ls(log.(ŷ), y) ≈ bceloss_lp_ls_compiled(log.(ŷ_ra), y_ra)
+            end
+
+            @testset "BinaryFocalLoss" begin
+                y = [0 1 0
+                     1 0 1]
+                ŷ = [0.268941 0.5 0.268941
+                     0.731059 0.5 0.731059]
+
+                y_ra = Reactant.to_rarray(y)
+                ŷ_ra = Reactant.to_rarray(ŷ)
+
+                bfl = BinaryFocalLoss()
+                bfl_compiled = @compile bfl(ŷ_ra, y_ra)
+                @test bfl(ŷ, y) ≈ bfl_compiled(ŷ_ra, y_ra)
+            end
+
+            @testset "FocalLoss" begin
+                y = [1 0 0 0 1
+                     0 1 0 1 0
+                     0 0 1 0 0]
+                ŷ = softmax(reshape(-7:7, 3, 5) .* 1.0f0) |> Array
+
+                y_ra = Reactant.to_rarray(y)
+                ŷ_ra = Reactant.to_rarray(ŷ)
+
+                fl = FocalLoss()
+                fl_compiled = @compile fl(ŷ_ra, y_ra)
+                @test fl(ŷ, y) ≈ fl_compiled(ŷ_ra, y_ra)
+            end
+        end
+
+        @testset "Other Losses" begin
+            @testset "KLDivergenceLoss" begin
+                y = [1.0 2.0 3.0]
+                ŷ = [4.0 5.0 6.0]
+
+                y_ra = Reactant.to_rarray(y)
+                ŷ_ra = Reactant.to_rarray(ŷ)
+
+                kldl = KLDivergenceLoss()
+                kldl_compiled = @compile kldl(ŷ_ra, y_ra)
+                @test kldl(ŷ, y) ≈ kldl_compiled(ŷ_ra, y_ra)
+            end
+
+            @testset "HingeLoss" begin
+                y = [1.0, 2.0, 3.0, 4.0]
+                ŷ = [5.0, 6.0, 7.0, 8.0]
+
+                y_ra = Reactant.to_rarray(y)
+                ŷ_ra = Reactant.to_rarray(ŷ)
+
+                hl = HingeLoss()
+                hl_compiled = @compile hl(ŷ_ra, y_ra)
+                @test hl(ŷ, y) ≈ hl_compiled(ŷ_ra, y_ra)
+
+                hl = HingeLoss(; agg=mean)
+                hl_compiled = @compile hl(ŷ_ra, y_ra)
+                @test hl(ŷ, y) ≈ hl_compiled(ŷ_ra, y_ra)
+            end
+
+            @testset "SquaredHingeLoss" begin
+                y = [1.0, 2.0, 3.0, 4.0]
+                ŷ = [5.0, 6.0, 7.0, 8.0]
+
+                y_ra = Reactant.to_rarray(y)
+                ŷ_ra = Reactant.to_rarray(ŷ)
+
+                hl = SquaredHingeLoss()
+                hl_compiled = @compile hl(ŷ_ra, y_ra)
+                @test hl(ŷ, y) ≈ hl_compiled(ŷ_ra, y_ra)
+
+                hl = SquaredHingeLoss(; agg=mean)
+                hl_compiled = @compile hl(ŷ_ra, y_ra)
+                @test hl(ŷ, y) ≈ hl_compiled(ŷ_ra, y_ra)
+            end
+
+            @testset "PoissonLoss" begin
+                y = [0.1, 0.2, 0.3]
+                ŷ = [0.4, 0.5, 0.6]
+
+                y_ra = Reactant.to_rarray(y)
+                ŷ_ra = Reactant.to_rarray(ŷ)
+
+                pl = PoissonLoss()
+                pl_compiled = @compile pl(ŷ_ra, y_ra)
+                @test pl(ŷ, y) ≈ pl_compiled(ŷ_ra, y_ra)
+
+                pl = PoissonLoss(; agg=mean)
+                pl_compiled = @compile pl(ŷ_ra, y_ra)
+                @test pl(ŷ, y) ≈ pl_compiled(ŷ_ra, y_ra)
+            end
+
+            @testset "DiceCoeffLoss" begin
+                y = [1.0, 0.5, 0.3, 2.4]
+                ŷ = [0.0, 1.4, 0.5, 1.2]
+
+                y_ra = Reactant.to_rarray(y)
+                ŷ_ra = Reactant.to_rarray(ŷ)
+
+                dl = DiceCoeffLoss()
+                dl_compiled = @compile dl(ŷ_ra, y_ra)
+                @test dl(ŷ, y) ≈ dl_compiled(ŷ_ra, y_ra)
+
+                dl = DiceCoeffLoss(; agg=mean)
+                dl_compiled = @compile dl(ŷ_ra, y_ra)
+                @test dl(ŷ, y) ≈ dl_compiled(ŷ_ra, y_ra)
+            end
+
+            @testset "Siamese Contrastive Loss" begin
+                y = [1.0 0.0
+                     0.0 0.0
+                     0.0 1.0]
+                ŷ = [0.4 0.2
+                     0.5 0.5
+                     0.1 0.3]
+
+                y_ra = Reactant.to_rarray(y)
+                ŷ_ra = Reactant.to_rarray(ŷ)
+
+                sl = SiameseContrastiveLoss()
+                sl_compiled = @compile sl(ŷ_ra, y_ra)
+                @test sl(ŷ, y) ≈ sl_compiled(ŷ_ra, y_ra)
+
+                sl = SiameseContrastiveLoss(; agg=mean)
+                sl_compiled = @compile sl(ŷ_ra, y_ra)
+                @test sl(ŷ, y) ≈ sl_compiled(ŷ_ra, y_ra)
+            end
+        end
+    end
+end
diff --git a/test/reactant/training_tests.jl b/test/reactant/training_tests.jl
new file mode 100644
index 0000000000..b3b27969c3
--- /dev/null
+++ b/test/reactant/training_tests.jl
@@ -0,0 +1,62 @@
+@testitem "Reactant: Training API" tags=[:reactant] setup=[SharedTestSetup] begin
+    using Reactant, Optimisers
+
+    @testset "$(mode)" for (mode, atype, dev, ongpu) in MODES
+        if mode == "amdgpu"
+            @warn "Skipping AMDGPU tests for Reactant"
+            continue
+        end
+
+        if ongpu
+            Reactant.set_default_backend("gpu")
+        else
+            Reactant.set_default_backend("cpu")
+        end
+
+        xdev = xla_device(; force=true)
+
+        @testset "MLP Training: $(version)" for version in (:iip, :oop)
+            model = Chain(
+                Dense(2 => 32, gelu),
+                Dense(32 => 32, gelu),
+                Dense(32 => 2)
+            )
+            ps, st = Lux.setup(StableRNG(1234), model) |> xdev
+
+            x_ra = randn(Float32, 2, 32) |> xdev
+
+            inference_fn = @compile model(x_ra, ps, Lux.testmode(st))
+
+            x = [rand(Float32, 2, 32) for _ in 1:32]
+            y = [xᵢ .^ 2 for xᵢ in x]
+
+            dataloader = DeviceIterator(xdev, zip(x, y))
+
+            total_initial_loss = mapreduce(+, dataloader) do (xᵢ, yᵢ)
+                ŷᵢ, _ = inference_fn(xᵢ, ps, Lux.testmode(st))
+                return MSELoss()(ŷᵢ, yᵢ)
+            end
+
+            train_state = Training.TrainState(model, ps, st, Adam(0.01f0))
+
+            for epoch in 1:100, (xᵢ, yᵢ) in dataloader
+                grads, loss, stats, train_state = if version === :iip
+                    Training.single_train_step!(
+                        AutoEnzyme(), MSELoss(), (xᵢ, yᵢ), train_state)
+                elseif version === :oop
+                    Training.single_train_step(
+                        AutoEnzyme(), MSELoss(), (xᵢ, yᵢ), train_state)
+                else
+                    error("Invalid version: $(version)")
+                end
+            end
+
+            total_final_loss = mapreduce(+, dataloader) do (xᵢ, yᵢ)
+                ŷᵢ, _ = inference_fn(xᵢ, train_state.parameters, Lux.testmode(st))
+                return MSELoss()(ŷᵢ, yᵢ)
+            end
+
+            @test total_final_loss < 100 * total_initial_loss
+        end
+    end
+end
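
The training test leans on two MLDataDevices pieces: `xla_device()` to move parameters/states onto the Reactant/XLA backend, and `DeviceIterator` to move each batch as it is consumed. A minimal sketch of just that data-movement layer (shapes and batch contents are placeholders, not part of the patch):

using MLDataDevices, Reactant, Random   # xla_device and DeviceIterator come from MLDataDevices

xdev = xla_device(; force=true)         # errors if no functional XLA backend is available

batches = [(rand(Float32, 2, 32), rand(Float32, 2, 32)) for _ in 1:4]
loader = DeviceIterator(xdev, batches)  # each batch is converted when iterated

for (x, y) in loader
    # x and y are now Reactant arrays living on the XLA device
end
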
diff --git a/test/runtests.jl b/test/runtests.jl
index 679e0ae002..6d311c8aad 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -8,6 +8,8 @@ const ALL_LUX_TEST_GROUPS = [
     "core_layers", "contrib", "helpers", "distributed", "normalize_layers",
     "others", "autodiff", "recurrent_layers", "fluxcompat"]

+Sys.iswindows() || push!(ALL_LUX_TEST_GROUPS, "reactant")
+
 INPUT_TEST_GROUP = lowercase(get(ENV, "LUX_TEST_GROUP", "all"))
 const LUX_TEST_GROUP = if startswith("!", INPUT_TEST_GROUP[1])
     exclude_group = lowercase.(split(INPUT_TEST_GROUP[2:end], ","))
@@ -26,6 +28,12 @@ if ("all" in LUX_TEST_GROUP || "distributed" in LUX_TEST_GROUP)
 end
 ("all" in LUX_TEST_GROUP || "fluxcompat" in LUX_TEST_GROUP) &&
     push!(EXTRA_PKGS, Pkg.PackageSpec("Flux"))
+
+if !Sys.iswindows()
+    ("all" in LUX_TEST_GROUP || "reactant" in LUX_TEST_GROUP) &&
+        push!(EXTRA_PKGS, Pkg.PackageSpec("Reactant"))
+end
+
 (BACKEND_GROUP == "all" || BACKEND_GROUP == "cuda") &&
     push!(EXTRA_PKGS, Pkg.PackageSpec("LuxCUDA"))
 (BACKEND_GROUP == "all" || BACKEND_GROUP == "amdgpu") &&
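
With these changes the Reactant tests run as their own `reactant` test group (skipped on Windows). Locally, one way to run just that group is to set the `LUX_TEST_GROUP` environment variable that `runtests.jl` reads before invoking the test suite; the snippet below is illustrative, and any equivalent way of setting the variable works:

using Pkg
withenv("LUX_TEST_GROUP" => "reactant") do
    Pkg.test("Lux")   # the env var is inherited by the test process and selects the group
end
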