From 9d8ff053abbeb0706ad3e84e4adb646c60529f6f Mon Sep 17 00:00:00 2001
From: Carlo Lucibello
Date: Sat, 17 Jan 2026 21:45:18 +0100
Subject: [PATCH 1/6] generic train!

---
 ext/FluxEnzymeExt.jl | 82 +++++++++++++++++---------------------
 src/Flux.jl          |  1 +
 src/gradient.jl      |  3 +-
 src/train.jl         | 80 +++++++++++++++++-------------------
 4 files changed, 65 insertions(+), 101 deletions(-)

diff --git a/ext/FluxEnzymeExt.jl b/ext/FluxEnzymeExt.jl
index 0969c854d8..37ece51278 100644
--- a/ext/FluxEnzymeExt.jl
+++ b/ext/FluxEnzymeExt.jl
@@ -1,14 +1,12 @@
 module FluxEnzymeExt

 using Flux
-import Flux.Train: _enzyme_train!
 import Optimisers
 import Functors
 import Enzyme
 using Enzyme: EnzymeCore, EnzymeRules, Active, Const, Duplicated, autodiff, ReverseWithPrimal, DuplicatedNoNeed
 using Enzyme: autodiff_thunk, Reverse, ReverseSplitWithPrimal
-using ProgressLogging: @withprogress, @logprogress

 EnzymeRules.inactive(::typeof(Flux.Losses._check_sizes), args...) = true

@@ -28,13 +26,13 @@ _trymake_duplicated(x) = EnzymeCore.Duplicated(x, EnzymeCore.make_zero(x))

 function _enzyme_gradient(f, args::Union{Const, Duplicated}...; zero::Bool=true)
-  for x in args
-    zero && x isa Duplicated && EnzymeCore.remake_zero!(x.dval)
-    _check_mutable(x)
-  end
-  ad = Enzyme.set_runtime_activity(Reverse)
-  Enzyme.autodiff(ad, Const(f), Active, args...)
-  return map(_grad_or_nothing, args)
+    for x in args
+        zero && x isa Duplicated && EnzymeCore.remake_zero!(x.dval)
+        _check_mutable(x)
+    end
+    ad = Enzyme.set_runtime_activity(Reverse)
+    Enzyme.autodiff(ad, Const(f), Active, args...)
+    return map(_grad_or_nothing, args)
 end

 _check_mutable(x::Const) = nothing

@@ -48,30 +46,30 @@ _grad_or_nothing(::Const) = nothing
 _grad_or_nothing(x) = Optimisers.isnumeric(x) ? x : nothing

 function _enzyme_withgradient(f, args::Union{Const, Duplicated}...; zero::Bool=true)
-  for x in args
-    zero && x isa Duplicated && EnzymeCore.remake_zero!(x.dval)
-    _check_mutable(x)
-  end
-
-  # In order to support auxillary outputs, we try different ways.
-
-  ## Take I, doesn't allow for aux at all.
-  ad = Enzyme.set_runtime_activity(ReverseWithPrimal)
-  _, result = Enzyme.autodiff(ReverseWithPrimal, Const(f), Active, args...)
-
-  ## Take II, using split mode.
-  ## This fails with RNNs https://github.com/EnzymeAD/Enzyme.jl/issues/2897
-  # forward, reverse = autodiff_thunk(ReverseSplitWithPrimal, Const{typeof(f)}, Active, map(typeof, args)...)
-  # tape, result, shadow_result = forward(Const(f), args...)
-  # reverse(Const(f), args..., _sensitivity(result), tape)
-
-  ## Take III, it may be more efficient to have the function write the loss into Ref(0.0)?
-  ## This doesn't work with Reactant
-  # dup_loss = DuplicatedNoNeed(Ref(0f0), Ref(1f0))
-  # ad = Enzyme.set_runtime_activity(ReverseWithPrimal)
-  # _, result = autodiff(ad, Const(_ref_loss!), Const, dup_loss, Const(f), args...)
-
-  return (; val = result, grad = map(_grad_or_nothing, args))
+    for x in args
+        zero && x isa Duplicated && EnzymeCore.remake_zero!(x.dval)
+        _check_mutable(x)
+    end
+
+    # In order to support auxiliary outputs, we try different ways.
+
+    ## Take I, doesn't allow for aux at all.
+    ad = Enzyme.set_runtime_activity(ReverseWithPrimal)
+    _, result = Enzyme.autodiff(ReverseWithPrimal, Const(f), Active, args...)
+
+    ## Take II, using split mode.
+    ## This fails with RNNs https://github.com/EnzymeAD/Enzyme.jl/issues/2897
+    # forward, reverse = autodiff_thunk(ReverseSplitWithPrimal, Const{typeof(f)}, Active, map(typeof, args)...)
+    # tape, result, shadow_result = forward(Const(f), args...)
+    # reverse(Const(f), args..., _sensitivity(result), tape)
+
+    ## Take III, it may be more efficient to have the function write the loss into Ref(0.0)?
+    ## This doesn't work with Reactant
+    # dup_loss = DuplicatedNoNeed(Ref(0f0), Ref(1f0))
+    # ad = Enzyme.set_runtime_activity(ReverseWithPrimal)
+    # _, result = autodiff(ad, Const(_ref_loss!), Const, dup_loss, Const(f), args...)
+
+    return (; val = result, grad = map(_grad_or_nothing, args))
 end

 ## for Take II above
@@ -94,22 +92,4 @@ end
 # or else a Tuple or NamedTuple whose first element is a real number.""")


-### Flux.Train, for train!
-
-function _enzyme_train!(loss, model::Duplicated, data, opt; cb = nothing)
-  isnothing(cb) || error("""train! does not support callback functions.
-                            For more control use a loop with `gradient` and `update!`.""")
-  @withprogress for (i,d) in enumerate(data)
-    d_splat = d isa Tuple ? d : (d,)
-    l, gs = Flux.withgradient(loss, AutoEnzyme(), model, map(Const, d_splat)...)
-    if !isfinite(l)
-      throw(DomainError(lazy"Loss is $l on data item $i, stopping training"))
-    end
-    opt, model2 = Optimisers.update!(opt, model.val, model.dval)
-    model = Duplicated(model2, model.dval)
-
-    @logprogress Base.haslength(data) ? i/length(data) : nothing
-  end
-end
-
 end # FluxEnzymeExt
diff --git a/src/Flux.jl b/src/Flux.jl
index 657d9b9a1c..3cb6fc8963 100644
--- a/src/Flux.jl
+++ b/src/Flux.jl
@@ -26,6 +26,7 @@ using Zygote.ForwardDiff: value
 using EnzymeCore: EnzymeCore

 @reexport using ADTypes # AutoZygote, AutoMooncake, etc...
+using ADTypes: AbstractADType

 @reexport using MLDataDevices: MLDataDevices, supported_gpu_backends, reset_gpu_device!,
                                default_device_rng,
diff --git a/src/gradient.jl b/src/gradient.jl
index 8ee4d1eebf..66c37b318c 100644
--- a/src/gradient.jl
+++ b/src/gradient.jl
@@ -48,13 +48,12 @@ julia> Flux.gradient(f, AutoMooncake(), [1.0, 2.0, 3.0])
 ([2.0, 2.0, 2.0],)
 ```
 """
-function gradient(f, adtype::ADTypes.AbstractADType, args...)
+function gradient(f, adtype::AbstractADType, args...)
     error("AD backend has to be loaded to use `gradient(f, AutoXXX(), args...)`. Make sure to `using` the corresponding package, e.g. `using Mooncake` for `AutoMooncake()`. Supported backends are $SUPPORTED_AD_BACKENDS.")
 end
-

 # Default gradient using Zygote
 function gradient(f, args...; zero::Bool=true)
     for a in args
diff --git a/src/train.jl b/src/train.jl
index dd2f2284e3..e00f688f43 100644
--- a/src/train.jl
+++ b/src/train.jl
@@ -6,14 +6,11 @@ using Functors: fmap, fmapstructure
 using ..Flux: Flux
 using ProgressLogging: @progress, @withprogress, @logprogress
-using Zygote: Zygote
+using EnzymeCore: Duplicated
+using ADTypes: AbstractADType

 export setup, train!

-using ProgressLogging: @progress, @withprogress, @logprogress
-using Zygote: Zygote
-using EnzymeCore: Duplicated
-
 """
     opt_state = setup(rule, model)

@@ -49,7 +46,7 @@ function setup(rule::Optimisers.AbstractRule, model)
   state = Optimisers.setup(rule, model)
   # This check only needs foreach; using fmap caused https://github.com/FluxML/Flux.jl/issues/2144
   fmapstructure(model, exclude = Optimisers.isnumeric) do x
-    Optimisers.maywrite(x) || error("""model must be fully mutable for `train!` to work, got `x::$(typeof(x))`.
+    Optimisers.maywrite(x) || error("""model must be fully mutable for `train!` to work, got `x::$(typeof(x))`.
       If `x .+= dx` is in fact ok, define `Optimisers.maywrite(::$(typeof(x))) = true`""")
   end
   return state
 end

@@ -63,15 +60,17 @@
 Special method for use with Enzyme.jl, ignores the stored gradient.
 """
 setup(rule::Optimisers.AbstractRule, model::Duplicated) = setup(rule, model.val)

 """
-    train!(loss, model, data, opt_state)
+    train!(loss, [adtype,] model, data, opt_state)

 Uses a `loss` function and training `data` to improve the `model`'s parameters
 according to a particular optimisation rule encoded in `opt_state`.
+
 Iterates through `data` once, evaluating for each `d in data` either
 `loss(model, d...)` if `d isa Tuple`, or else `loss(model, d)` for other `d`.

-If `model` is an Enzyme.Duplicated and `Enzyme.jl` is loaded, gradients will be computed with Enzyme,
-otherwise they will be computed with Zygote.
+The optional argument `adtype` selects an automatic differentiation engine among the ones supported by
+[`gradient`](@ref). If no `adtype` is given, then Zygote is used by default, unless `model` is of type `Duplicated` from Enzyme.jl,
+in which case Enzyme is used.

 For example, with these definitions...
 ```
@@ -108,62 +107,47 @@ It adds only a few features to the loop above:
 * Callback functions are not supported. (But any code can be included in the above `for` loop.)
 """
-function train!(loss, model, data, opt; cb = nothing)
-  isnothing(cb) || error("""train! does not support callback functions.
-                            For more control use a loop with `gradient` and `update!`.""")
-  @withprogress for (i,d) in enumerate(data)
-    d_splat = d isa Tuple ? d : (d,)
+function train!(loss, adtype::AbstractADType, model, data, opt; cb = nothing)
+    isnothing(cb) || error("""train! does not support callback functions.
+        For more control use a loop with `gradient` and `update!`.""")
+    @withprogress for (i,d) in enumerate(data)
+        d_splat = d isa Tuple ? d : (d,)

-    l, gs = Zygote.withgradient(m -> loss(m, d_splat...), model)
+        l, gs = Flux.withgradient(m -> loss(m, d_splat...), adtype, model)

-    if !isfinite(l)
-      throw(DomainError(lazy"Loss is $l on data item $i, stopping training"))
-    end
+        if !isfinite(l)
+            throw(DomainError(lazy"Loss is $l on data item $i, stopping training"))
+        end

-    opt, model = Optimisers.update!(opt, model, gs[1])
+        opt, model = Optimisers.update!(opt, model, gs[1])

-    @logprogress Base.haslength(data) ? i/length(data) : nothing
-  end
+        @logprogress Base.haslength(data) ? i/length(data) : nothing
+    end
 end

 # This method let you use Optimisers.Descent() without setup, when there is no state
 function train!(loss, model, data, rule::Optimisers.AbstractRule; cb = nothing)
-  train!(loss, model, data, _rule_to_state(model, rule); cb)
+    return train!(loss, model, data, _rule_to_state(model, rule); cb)
 end

 function _rule_to_state(model, rule::Optimisers.AbstractRule)
-  state = setup(rule, model)
-  @gensym warn_id
-  name = typeof(rule).name.name
-  fmap(state, exclude = x -> x isa Optimisers.Leaf) do leaf
-    leaf.state isa Nothing || @warn """Optimiser $name has state which will be discarded after `train!` finishes.
-      Please run `opt = Flux.setup($name(), model)` and pass this `opt` to `train!`.""" leaf maxlog=1 _id=warn_id
-    leaf
-  end
-  state
+    state = setup(rule, model)
+    @gensym warn_id
+    name = typeof(rule).name.name
+    fmap(state, exclude = x -> x isa Optimisers.Leaf) do leaf
+        leaf.state isa Nothing || @warn """Optimiser $name has state which will be discarded after `train!` finishes.
+            Please run `opt = Flux.setup($name(), model)` and pass this `opt` to `train!`.""" leaf maxlog=1 _id=warn_id
+        leaf
+    end
+    return state
 end

-"""
-    train!(loss, Duplicated(model), data, opt_state)
-
-This method uses Enzyme.jl instead of Zygote.jl to compute the gradients,
-but is otherwise the same as `train!(loss, model, data, opt_state)`.
-
-Only available when Enzyme is loaded.
-
-!!! compat "New"
-    This method was added in Flux 0.13.9.
-
-"""
-train!(loss, model::Duplicated, data, opt; cb = nothing) = _enzyme_train!(loss, model, data, opt; cb = nothing)
-
-# FluxEnzymeExt defines more specific _enzyme_train!(loss, model::Duplicated, data, opt; cb)
-_enzyme_train!(loss, model, data, opt; cb = nothing) = throw(ArgumentError("The method `train!(loss, Duplicated(model), data, opt_state)` is only available when Enzyme.jl is loaded"))
+train!(loss, model::Duplicated, data, opt; cb = nothing) = train!(loss, AutoEnzyme(), model, data, opt; cb)

 # This method let you use Optimisers.Descent() without setup, when there is no state
 function train!(loss, model::Duplicated, data, rule::Optimisers.AbstractRule; cb=nothing)
-  train!(loss, model, data, _rule_to_state(model, rule); cb)
+    return train!(loss, model, data, _rule_to_state(model, rule); cb)
 end

 end # module Train

From 56d608cc860ff4c9250c9d9d3141475957f29eb8 Mon Sep 17 00:00:00 2001
From: Carlo Lucibello
Date: Sat, 17 Jan 2026 22:51:54 +0100
Subject: [PATCH 2/6] fix

---
 src/train.jl | 1 +
 1 file changed, 1 insertion(+)

diff --git a/src/train.jl b/src/train.jl
index e00f688f43..6cc0b2f84f 100644
--- a/src/train.jl
+++ b/src/train.jl
@@ -125,6 +125,7 @@ function train!(loss, adtype::AbstractADType, model, data, opt; cb = nothing)
     end
 end

+train!(loss, model, data, opt; cb = nothing) = train!(loss, AutoZygote(), model, data, opt; cb)

 # This method let you use Optimisers.Descent() without setup, when there is no state
 function train!(loss, model, data, rule::Optimisers.AbstractRule; cb = nothing)

From 301f8c2ab2856e4c1c52688d53d62cacb93e47fd Mon Sep 17 00:00:00 2001
From: Carlo Lucibello
Date: Sat, 17 Jan 2026 23:42:28 +0100
Subject: [PATCH 3/6] fix

---
 src/train.jl | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/train.jl b/src/train.jl
index 6cc0b2f84f..9bc09a664f 100644
--- a/src/train.jl
+++ b/src/train.jl
@@ -7,7 +7,7 @@ using ..Flux: Flux
 using ProgressLogging: @progress, @withprogress, @logprogress
 using EnzymeCore: Duplicated
-using ADTypes: AbstractADType
+using ADTypes: AbstractADType, AutoEnzyme, AutoZygote

 export setup, train!

From cdfb0ec5f64062c4c6a2683aafabe2936db2f161 Mon Sep 17 00:00:00 2001
From: Carlo Lucibello
Date: Sun, 18 Jan 2026 09:01:46 +0100
Subject: [PATCH 4/6] fix

---
 src/train.jl | 10 +++++++++-
 1 file changed, 9 insertions(+), 1 deletion(-)

diff --git a/src/train.jl b/src/train.jl
index 9bc09a664f..5ab9e6e11a 100644
--- a/src/train.jl
+++ b/src/train.jl
@@ -119,12 +119,20 @@ function train!(loss, adtype::AbstractADType, model, data, opt; cb = nothing)
             throw(DomainError(lazy"Loss is $l on data item $i, stopping training"))
         end

-        opt, model = Optimisers.update!(opt, model, gs[1])
+        opt, model = _update!(opt, model, gs[1])

         @logprogress Base.haslength(data) ? i/length(data) : nothing
     end
 end

+_update!(opt_state, model, grads) = Optimisers.update!(opt_state, model, grads)
+
+function _update!(opt_state, model::Duplicated, grad)
+    opt_state, model2 = Optimisers.update!(opt_state, model.val, grad)
+    return opt_state, Duplicated(model2, model.dval)
+end
+
+
 train!(loss, model, data, opt; cb = nothing) = train!(loss, AutoZygote(), model, data, opt; cb)

 # This method let you use Optimisers.Descent() without setup, when there is no state

From 9a0c5346a48a4488016f91067480033436a21ef2 Mon Sep 17 00:00:00 2001
From: Carlo Lucibello
Date: Sun, 18 Jan 2026 09:34:33 +0100
Subject: [PATCH 5/6] fix docs

---
 docs/src/reference/training/reference.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/docs/src/reference/training/reference.md b/docs/src/reference/training/reference.md
index 6c825709e5..256e8ba0c8 100644
--- a/docs/src/reference/training/reference.md
+++ b/docs/src/reference/training/reference.md
@@ -18,7 +18,7 @@ The available optimization rules are listed the [optimisation rules](@ref man-op

 ```@docs
 Flux.Train.setup
-Flux.Train.train!(loss, model, data, state)
+Flux.Train.train!
 Optimisers.update
 Optimisers.update!
 Optimisers.setup

From 3112bd3b6b649f8a2c17ee4bd9b61829dda225e8 Mon Sep 17 00:00:00 2001
From: Carlo Lucibello
Date: Sun, 18 Jan 2026 10:17:33 +0100
Subject: [PATCH 6/6] fix docs

---
 docs/src/reference/training/reference.md | 5 -----
 1 file changed, 5 deletions(-)

diff --git a/docs/src/reference/training/reference.md b/docs/src/reference/training/reference.md
index 256e8ba0c8..d8e2eaef4c 100644
--- a/docs/src/reference/training/reference.md
+++ b/docs/src/reference/training/reference.md
@@ -36,11 +36,6 @@ julia> opt_state = Flux.setup(Adam(0), model);

 julia> Flux.train!((m,x,y) -> sum(abs2, m(x) .- y), dup_model, [(x1, y1)], opt_state)
 ```

-```@docs
-Flux.train!(loss, model::Flux.EnzymeCore.Duplicated, data, opt)
-```
-
-
 ## Optimisation Modifiers

 The state returned by `setup` can be modified to temporarily prevent training of
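
For reference, a minimal usage sketch of the `train!` signature this series introduces, assuming the patches above are applied. The toy model, data, and loss names below are invented purely for illustration; `AutoZygote()` comes from ADTypes, which Flux re-exports.

```julia
using Flux  # Zygote is a Flux dependency and remains the default backend

# Made-up toy data and model, just to show the call signature.
X = rand(Float32, 2, 32)
Y = rand(Float32, 1, 32)
data = [(X, Y)]

model = Dense(2 => 1)
opt_state = Flux.setup(Adam(1e-3), model)

loss(m, x, y) = Flux.mse(m(x), y)

# Existing call: no adtype given, so AutoZygote() is used (patch 2/6).
Flux.train!(loss, model, data, opt_state)

# New generic method: pick the AD backend explicitly via an ADTypes object.
Flux.train!(loss, AutoZygote(), model, data, opt_state)
```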
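And a sketch of the Enzyme path, under the assumption that Enzyme.jl is loaded: per this series a `Duplicated` model dispatches to `AutoEnzyme()`, and the `_update!` helper from patch 4/6 re-wraps the updated parameters so the same shadow storage is carried along between data items.

```julia
using Flux
import Enzyme

X = rand(Float32, 2, 32)
Y = rand(Float32, 1, 32)
data = [(X, Y)]

loss(m, x, y) = Flux.mse(m(x), y)

model = Dense(2 => 1)
# Pair the model with zeroed shadow storage for Enzyme to write gradients into.
dup_model = Enzyme.Duplicated(model, Enzyme.make_zero(model))

# setup has a special method for Duplicated that ignores the shadow half.
opt_state = Flux.setup(Adam(1e-3), dup_model)

# Equivalent to Flux.train!(loss, AutoEnzyme(), dup_model, data, opt_state).
Flux.train!(loss, dup_model, data, opt_state)
```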