6 changes: 5 additions & 1 deletion Project.toml
@@ -40,6 +40,7 @@ MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"
NCCL = "3fe64909-d7a1-4096-9b7d-7a0f12cf0f6b"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
Reactant = "3c362404-f566-11ee-1572-e11a4b42c853"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
SimpleChains = "de6bee2f-e2f4-4ec7-b6ed-219cc6f6e9e5"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
@@ -57,6 +58,7 @@ LuxMLUtilsExt = "MLUtils"
LuxMPIExt = "MPI"
LuxMPINCCLExt = ["CUDA", "MPI", "NCCL"]
LuxOptimisersExt = "Optimisers"
LuxReactantExt = ["Enzyme", "Reactant"]
LuxReverseDiffExt = "ReverseDiff"
LuxSimpleChainsExt = "SimpleChains"
LuxTrackerExt = "Tracker"
@@ -102,6 +104,7 @@ Pkg = "1.10"
PrecompileTools = "1.2"
Preferences = "1.4.3"
Random = "1.10"
Reactant = "0.1.1"
ReTestItems = "1.23.1"
Reexport = "1.2.2"
ReverseDiff = "1.15"
@@ -134,6 +137,7 @@ MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"
NCCL = "3fe64909-d7a1-4096-9b7d-7a0f12cf0f6b"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Reactant = "3c362404-f566-11ee-1572-e11a4b42c853"
ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
SimpleChains = "de6bee2f-e2f4-4ec7-b6ed-219cc6f6e9e5"
@@ -144,4 +148,4 @@ Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Aqua", "ComponentArrays", "DynamicExpressions", "Enzyme", "ExplicitImports", "FiniteDifferences", "Flux", "ForwardDiff", "Logging", "LuxAMDGPU", "LuxCUDA", "LuxTestUtils", "MLUtils", "MPI", "NCCL", "Optimisers", "Pkg", "ReTestItems", "ReverseDiff", "SimpleChains", "StableRNGs", "Statistics", "Test", "Tracker", "Zygote"]
test = ["Aqua", "ComponentArrays", "DynamicExpressions", "Enzyme", "ExplicitImports", "FiniteDifferences", "Flux", "ForwardDiff", "Logging", "LuxAMDGPU", "LuxCUDA", "LuxTestUtils", "MLUtils", "MPI", "NCCL", "Optimisers", "Pkg", "Reactant", "ReTestItems", "ReverseDiff", "SimpleChains", "StableRNGs", "Statistics", "Test", "Tracker", "Zygote"]
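The Project.toml changes above register Reactant as a weak dependency and declare the `LuxReactantExt` extension, which only activates when both `Enzyme` and `Reactant` are loaded next to Lux. A minimal sketch of how that looks from a user environment (only the package and extension names come from this diff; the check itself is standard Julia):

```julia
using Lux                 # LuxReactantExt is not loaded yet
using Enzyme, Reactant    # loading both weak dependencies triggers the extension

# `Base.get_extension` returns the extension module once it has been loaded
@assert Base.get_extension(Lux, :LuxReactantExt) !== nothing
```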
23 changes: 23 additions & 0 deletions ext/LuxReactantExt/LuxReactantExt.jl
@@ -0,0 +1,23 @@
module LuxReactantExt

using Adapt: adapt
using ArgCheck: @argcheck
using ConcreteStructs: @concrete
using Enzyme: Enzyme, Active, Const, Duplicated
using Functors: fmapstructure, fmap
using Random: AbstractRNG, Xoshiro
using Reactant: Reactant
using Lux: Lux, LuxEltypeAdaptor, AutoReactant
using LuxCore: LuxCore, AbstractExplicitLayer

include("utils.jl")

# compile just the model. This allows us to run part of the model in vanilla LLVM. Needed
# for cases where we can't currently compile via Reactant or where XLA is not great
# for the model.
include("layer.jl")

# compile the entire training loop
include("train.jl")

end
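Before the per-file details, a hypothetical usage sketch of what this extension enables. The adaptor name and the `force_allow_mixed_eltypes` keyword appear in the diff below; the exact constructor arguments and applying the adaptor via `adaptor(model)` are assumptions about the public API, not something this PR spells out:

```julia
using Lux, Reactant, Enzyme, Random

model = Chain(Dense(2 => 4, tanh), Dense(4 => 2))

# Assumed constructor: an input prototype fixes the shapes/eltypes Reactant traces with.
adaptor = Lux.ToReactantAdaptor(rand(Float32, 2, 3))
rmodel = adaptor(model)   # assumed to call Lux.__to_reactant_adaptor under the hood

ps, st = Lux.setup(Xoshiro(0), rmodel)
y, st_new = rmodel(rand(Float32, 2, 3), ps, st)   # runs the compiled forward pass
```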
258 changes: 258 additions & 0 deletions ext/LuxReactantExt/layer.jl
@@ -0,0 +1,258 @@
# Reactant doesn't handle mixed eltypes that well, so we first try to compile the model as
# a usual Julia function. However, if that fails, we type-cast to a common eltype and recompile.
# Note that this is a one-time operation, so it doesn't matter if this step is slow.
function Lux.__to_reactant_adaptor(
to::Lux.ToReactantAdaptor{FST}, model::AbstractExplicitLayer) where {FST}
input_prototype = to.input_prototype
input_eltype = Lux.__recursive_eltype(input_prototype)
ps, st = Lux.setup(LuxCore.replicate(to.rng), model)
ps = to.ps_transform(ps)
ps_eltype = Lux.__recursive_eltype(ps)
st_eltype = Lux.__recursive_eltype(st)

newT = promote_type(input_eltype, ps_eltype, st_eltype)
eltype_adaptor = nothing

if !to.force_allow_mixed_eltypes &&
any(x -> x != newT && x != Union{}, (input_eltype, ps_eltype, st_eltype))
try # Try compiling, but this might fail
return Lux.__to_reactant_adaptor(to, model, input_prototype, ps, st, nothing)
catch err
            @warn """
            Mixed eltypes detected. A compilation failure here is not unexpected; trying \
            to recompile with a common eltype.

            HINT: To force compilation with the mixed eltypes, set \
            `force_allow_mixed_eltypes=true` in the constructor of `ToReactantAdaptor`.

            If compilation succeeds, all inputs to the compiled model will be \
            automatically type-cast to the common eltype.\n
            """ exception=err input_eltype ps_eltype st_eltype common_eltype=newT
end

eltype_adaptor = LuxEltypeAdaptor{newT}()
input_prototype = adapt(eltype_adaptor, to.input_prototype)
ps = adapt(eltype_adaptor, ps)
st = adapt(eltype_adaptor, st)
end

return Lux.__to_reactant_adaptor(to, model, input_prototype, ps, st, eltype_adaptor)
end

function Lux.__to_reactant_adaptor(
to::Lux.ToReactantAdaptor{FST}, model::AbstractExplicitLayer,
input_prototype, ps, st, eltype_adaptor) where {FST}
output = first(model(input_prototype, ps, st))
concrete_output = Lux.__make_reactant_array(output)

concrete_input = Lux.__make_reactant_array(input_prototype)
cps = Lux.__make_reactant_array(ps)
cst = Lux.__make_reactant_array(st)

smodel = Lux.StatefulLuxLayer{FST}(model, cps, cst)
fwd_fn = to.skip_compile_forward ? nothing :
Reactant.compile((m, x) -> m(x), (smodel, concrete_input))

cst_test = Lux.__make_reactant_array(Lux.testmode(st))
smodel_test = Lux.StatefulLuxLayer{FST}(model, cps, cst_test)
inference_fn = to.skip_compile_inference ? nothing :
Reactant.compile((m, x) -> m(x), (smodel_test, concrete_input))

vjp_fn = if to.skip_compile_vjp
nothing
else
function enzyme_vjp_fn(m, x, y, dy)
dx = Enzyme.make_zero(x)
dps = Enzyme.make_zero(m.ps)
st_m = ifelse(FST, m.st, m.st_any)

function wrapper_fn!(y, model, x, ps, st)
copyto!(y, first(LuxCore.apply(model, x, ps, st)))
return nothing
end

Enzyme.autodiff(
Enzyme.Reverse, wrapper_fn!, Enzyme.Const, Enzyme.Duplicated(y, dy),
Enzyme.Const(m.model), Enzyme.Duplicated(x, dx),
Enzyme.Duplicated(m.ps, dps), Enzyme.Const(st_m))
return dx, dps
end

try
concrete_output2 = Lux.__make_reactant_array(deepcopy(output))
Reactant.compile(
enzyme_vjp_fn, (smodel, concrete_input, concrete_output, concrete_output2))
catch err
to.force_compile_backward && rethrow(err)
@error """
Enzyme failed to compile the backward pass. Differentiation will be disabled \
for this model.

HINT: To force compilation of the backward pass, set \
`force_compile_backward=true` in the constructor of `ToReactantAdaptor`.\n
""" exception=err
nothing
end
end

jvp_fn = if to.skip_compile_jvp
nothing
else # TODO: Implement JVP with Enzyme.Forward
throw(ArgumentError("JVPs are not implemented yet."))
end

return Lux.ReactantLayer{
FST, Lux.__recursive_eltype(input_prototype), typeof(input_prototype),
typeof(concrete_input), typeof(cst), typeof(cst_test)}(
to, cps, model, fwd_fn, inference_fn, vjp_fn, jvp_fn,
eltype_adaptor, fmapstructure(Lux.__size, input_prototype))
end

# TODO: Currently we are maintaining 2 copies of the parameters, which is not ideal.
# We could return the parameters and states from the layer itself, since we don't care
# about the values, only the types.
function LuxCore.initialparameters(rng::AbstractRNG, layer::Lux.ReactantLayer)
ps = layer.adaptor(LuxCore.initialparameters(rng, layer.layer))
layer.eltype_adaptor !== nothing && (ps = adapt(layer.eltype_adaptor, ps))
return Lux.__make_reactant_array(ps)
end

function LuxCore.initialstates(rng::AbstractRNG, layer::Lux.ReactantLayer)
st = LuxCore.initialstates(rng, layer.layer)
layer.eltype_adaptor !== nothing && (st = adapt(layer.eltype_adaptor, st))
return (; states=Lux.__make_reactant_array(st), training=Val(true))
end

function (l::Lux.ReactantLayer{FST, T})(x, ps, st::NamedTuple) where {FST, T}
l.eltype_adaptor !== nothing && (x = adapt(l.eltype_adaptor, x))

    # The XLARuntimeError from a structure mismatch is hard to read, so check and fail early if needed
@argcheck fmapstructure(Lux.__size, x) == l.input_structure

    # TODO: For non-array inputs we make the eltype uniform, which might not be
    # desirable. We should handle those cases with `fmap`.
if T != Lux.__recursive_eltype(x)
@warn """
`Reactant.compile` was called with input eltype $(T) but the current input eltype \
is $(Lux.__recursive_eltype(x)). This might lead to unexpected behavior.

We will convert the input to $(T) and continue. If you want to avoid this, please \
recompile the model with the correct input eltype.
""" maxlog=1
x = adapt(LuxEltypeAdaptor{T}(), x)
end

return Lux.__apply_reactant(l, x, ps, st)
end

@inline function Lux.__apply_reactant(l::Lux.ReactantLayer, x, ps, st)
y, st_ = Lux.__apply_reactant(l, x, ps, st.states, st.training)
return y, (; states=st_, training=st.training)
end

# This is the ideal case where all the types match correctly.
# Input type mismatches should not happen here; they should be handled before this function
# is called.
@inline function Lux.__apply_reactant(l::Lux.ReactantLayer{FST, T, inType}, x::inType,
ps, st, training) where {FST, T, inType}
return Lux.__apply_reactant(l, Lux.__make_reactant_array(x), ps, st, training)
end

@inline function Lux.__apply_reactant(
l::Lux.ReactantLayer{FST, T, inType, inCType, stType, stTestType, psType},
x::inCType, ps::psType, st::stType,
training) where {FST, T, inType, inCType, psType, stType, stTestType}
smodel = Lux.StatefulLuxLayer{FST}(l.layer, ps, st)
return (
Lux.__apply_reactant(l, smodel, x, training), ifelse(FST, smodel.st, smodel.st_any))
end

# Parameter type mismatch. This might be fairly common, so try to handle it gracefully.
@inline function Lux.__apply_reactant(
l::Lux.ReactantLayer{FST, T, inType, inCType, stType, stTestType, psType},
x::inCType, ps::psType2, st,
training) where {FST, T, inType, inCType, stType, stTestType, psType, psType2}
    @warn "Parameter type mismatch with the compiled Reactant function. This will lead to \
        performance regressions." maxlog=1

ps = __try_similar_structure(Lux.__named_tuple(ps), l.concrete_ps)
ps = l.adaptor(ps)
l.eltype_adaptor !== nothing && (ps = adapt(l.eltype_adaptor, ps))
ps = Lux.__make_reactant_array(ps)

if typeof(ps) != psType
@warn "Automatic type conversion failed for `ps`." original_ps_type=psType2
__graceful_type_mismatch_error(l, x, ps, st, training)
end

return Lux.__apply_reactant(l, Lux.__make_reactant_array(x), ps, st, training)
end

function Lux.__apply_reactant(l, x, ps, st, training)
return __graceful_type_mismatch_error(l, x, ps, st, training)
end

@inline function Lux.__apply_reactant(l::Lux.ReactantLayer, smodel, x, ::Val{true})
@argcheck l.fwd_fn !== nothing
return l.fwd_fn(smodel, x)
end

@inline function Lux.__apply_reactant(l::Lux.ReactantLayer, smodel, x, ::Val{false})
@argcheck l.inference_fn !== nothing
return l.inference_fn(smodel, x)
end

# Don't inline; otherwise the types don't get displayed in the stack trace
function __graceful_type_mismatch_error(
::Lux.ReactantLayer{FST, T, inType, inCType, stType, stTestType, psType},
x,
ps,
st,
::Val{training}) where {
FST, T, inType, inCType, psType, stType, stTestType, training}
#! format: off
input_type_mismatch_str = typeof(x) == inType || typeof(x) == inCType ? """
1. Input Types Matched.
""" : """
1. Input Type: $(typeof(x)).
Compiled Input Type: $(inType).
Compiled Concrete Input Type: $(inCType).
"""
#! format: on

ps_type_mismatch_str = typeof(ps) == psType ? """
2. Parameter Types Matched.
""" : """
2. Parameter Type: $(typeof(ps)).
Compiled Parameter Type: $(psType).
"""

st_type_mismatch_str = if training
typeof(st) == stType ? """
3. State Types Matched.
""" : """
3. State Type: $(typeof(st)).
Compiled State Type: $(stType).
"""
else
typeof(st) == stTestType ? """
3. State Types Matched.
""" : """
3. State Type: $(typeof(st)).
Compiled State Type: $(stTestType).
"""
end

throw(ArgumentError("""
    Model compiled types and input types don't match. We tried our best to convert the \
    types to the right ones, but failed. Ideally, the argument types should not be \
    modified after compilation. Possible fixes:

    1. Recompile the model with the correct input types.
    2. Open an issue on the Lux.jl repository so we can check whether the automatic \
    type conversion can be improved.

List of Type Mismatches:

$(input_type_mismatch_str) $(ps_type_mismatch_str) $(st_type_mismatch_str)"""))
end
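The compiled VJP above wraps a mutating forward pass and differentiates it with `Enzyme.autodiff`, seeding the output shadow. A minimal standalone sketch of that same `Reverse`/`Duplicated` pattern, independent of Lux and Reactant (it mirrors the `wrapper_fn!` call inside `enzyme_vjp_fn`):

```julia
using Enzyme

# In-place forward pass: write the result into `y`, mirroring `wrapper_fn!` above.
square!(y, x) = (y .= x .^ 2; nothing)

x  = [1.0, 2.0, 3.0]
dx = zero(x)        # gradient accumulator (the shadow of `x`)
y  = zero(x)
dy = ones(3)        # seed for the output shadow, i.e. dL/dy

Enzyme.autodiff(Enzyme.Reverse, square!, Enzyme.Const,
    Enzyme.Duplicated(y, dy), Enzyme.Duplicated(x, dx))

dx   # == 2 .* x, the VJP of `square!` with an all-ones seed
```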
55 changes: 55 additions & 0 deletions ext/LuxReactantExt/train.jl
@@ -0,0 +1,55 @@
# TODO: Add the in-place (iip) versions as well; metaprogram that part.

# Case III: Nothing is cached. First call to `single_train_step`
function Lux.Experimental.single_train_step(
        ad::AutoReactant, obj_fn::F, data, ts::Lux.Experimental.TrainState) where {F}
    dps = Lux.__recursive_make_zero(ts.parameters)

    data = Lux.__make_reactant_array(data)
    model = Lux.__make_reactant_array(ts.model)
    dps = Lux.__make_reactant_array(dps)
    ps = Lux.__make_reactant_array(ts.parameters)
    st = Lux.__make_reactant_array(ts.states)

    obj_fn_wrapper, st_updated, stats = Lux.Experimental.__wrap_objective_function(
        obj_fn, st)

    compiled_fn = Reactant.compile(__update_fn_wrapper, (obj_fn, model, ps, dps, st, data))

    # FIXME: Still work in progress -- cache `compiled_fn` on the TrainState and return
    # the correct values (loss, updated state, stats) instead of the raw compiled artifacts.
    return compiled_fn, (obj_fn, model, ps, dps, st, data)
end

function __update_fn_wrapper(obj_fn, model, ps, dps, st, data)
    _, loss = Enzyme.autodiff(Enzyme.ReverseWithPrimal, obj_fn, Active, Const(model),
        Duplicated(ps, dps), Const(st), Const(data))
    # TODO: also apply the gradients here (`Lux.Experimental.apply_gradients`)
    return loss
end
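`__update_fn_wrapper` uses `Enzyme.ReverseWithPrimal`, which accumulates gradients into the `Duplicated` shadows and also returns the primal value (the loss) from a single call. A minimal sketch of that pattern on a plain objective, outside the compiled training loop (the objective and data here are made up for illustration):

```julia
using Enzyme

loss(w, x) = sum(abs2, w .* x)   # scalar objective, playing the role of `obj_fn`

w  = [1.0, 2.0, 3.0]
dw = zero(w)                     # gradient buffer, playing the role of `dps`
x  = [0.5, 0.5, 0.5]

_, l = Enzyme.autodiff(Enzyme.ReverseWithPrimal, loss, Enzyme.Active,
    Enzyme.Duplicated(w, dw), Enzyme.Const(x))

# `l` is the primal loss value; `dw` now holds d(loss)/dw_i = 2 * w_i * x_i^2
```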
20 changes: 20 additions & 0 deletions ext/LuxReactantExt/utils.jl
@@ -0,0 +1,20 @@
@inline Lux.__make_reactant_array(x::Reactant.RArray) = x
@inline function Lux.__make_reactant_array(x::AbstractArray)
hasmethod(Reactant.ArrayToConcrete, Tuple{typeof(x)}) &&
return Reactant.ConcreteRArray(x)
return __make_tracer(x)
end
@inline Lux.__make_reactant_array(x) = __make_tracer(x)

@inline function __make_tracer(x)
return Reactant.make_tracer(IdDict(), x, (), Reactant.ArrayToConcrete, nothing)
end

@inline function __try_similar_structure(x::AbstractArray, y::NamedTuple{()})
length(x) == 0 && return y
throw(DimensionMismatch(lazy"Expected empty array, got $(size(x))."))
end
@inline function __try_similar_structure(x::AbstractArray, y::AbstractArray)
    return parent(x) !== x ? copy(x) : x # materialize views and other wrapped arrays
end
@inline __try_similar_structure(x, y) = fmap(__try_similar_structure, x, y)
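utils.jl converts plain arrays into `Reactant.ConcreteRArray`s before anything is traced or compiled. A minimal sketch combining that conversion with `Reactant.compile` (the same two entry points used throughout the extension), assuming this simple reduction traces cleanly:

```julia
using Reactant

x  = rand(Float32, 4, 4)
cx = Reactant.ConcreteRArray(x)   # move the data into an XLA-backed concrete array

f(a) = sum(abs2, a)
cf = Reactant.compile(f, (cx,))   # trace and compile `f` for this concrete input
cf(cx)                            # ≈ f(x), but executed through XLA
```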