diff --git a/lib/EnzymeCore/src/EnzymeCore.jl b/lib/EnzymeCore/src/EnzymeCore.jl index 52d211dafc..37d96040f4 100644 --- a/lib/EnzymeCore/src/EnzymeCore.jl +++ b/lib/EnzymeCore/src/EnzymeCore.jl @@ -8,6 +8,7 @@ export DefaultABI, FFIABI, InlineABI, NonGenABI export BatchDuplicatedFunc export within_autodiff, ignore_derivatives export needs_primal +export ChunkStrategy, LargestChunk, FixedChunk, AutoChunk, pick_chunksize function batch_size end @@ -797,4 +798,71 @@ end Combined(mode::ReverseMode) = mode +""" + ChunkStrategy + +Abstract type gathering strategies for chunk size selection. + +# Subtypes + +- [`LargestChunk`](@ref) +- [`FixedChunk`](@ref) +- [`AutoChunk`](@ref) +""" +abstract type ChunkStrategy end + +""" + LargestChunk() + +Select chunk size equal to the number of elements, so that the corresponding array is processed in a single chunk. + +!!! tip + In the current Enzyme interface, this strategy is equivalent to setting `chunk = nothing`. +""" +struct LargestChunk <: ChunkStrategy end + +""" + FixedChunk{C}() + FixedChunk(C) # type-unstable + +Select chunk size equal to a fixed integer `C`. + +!!! tip + In the current Enzyme interface, this chunk strategy is equivalent to setting `chunk = Val(C)`. +""" +struct FixedChunk{C} <: ChunkStrategy end + +FixedChunk(C::Int) = FixedChunk{C}() + +""" + AutoChunk() + +Select chunk size automatically based on internal Enzyme-specific heuristics, which are subject to change. +""" +struct AutoChunk <: ChunkStrategy end + +const DEFAULT_CHUNK_SIZE = 16 + +""" + pick_chunksize(s::ChunkStrategy, n::Integer) + pick_chunksize(s::ChunkStrategy, a::AbstractArray) + +Compute the chunk size chosen by strategy `s` based on the integer `n` or the array `a` (`n` corresponds to the array's length) +Return a `Val{C}` object. + +- In forward-mode Jacobians, `a` would be the input array. +- In reverse-mode Jacobians, `a` would be the output array. + +!!! 
warning
+    For `LargestChunk` and `AutoChunk` strategies, this function is type-unstable.
+"""
+function pick_chunksize end
+
+pick_chunksize(::LargestChunk, n::Integer) = Val(n)
+pick_chunksize(::LargestChunk, a::AbstractArray) = Val(length(a)) # allows inference on static arrays
+
+pick_chunksize(::AutoChunk, n::Integer) = Val(min(DEFAULT_CHUNK_SIZE, n)) # TODO: improve
+pick_chunksize(s::AutoChunk, a::AbstractArray) = pick_chunksize(s, length(a))
+pick_chunksize(::FixedChunk{C}, ::Union{Integer,AbstractArray}) where {C} = Val{C}()
 end # module EnzymeCore
diff --git a/lib/EnzymeCore/test/chunk.jl b/lib/EnzymeCore/test/chunk.jl
new file mode 100644
index 0000000000..f4e56612a6
--- /dev/null
+++ b/lib/EnzymeCore/test/chunk.jl
@@ -0,0 +1,28 @@
+using Test
+using EnzymeCore
+
+@testset "LargestChunk" begin
+    @test pick_chunksize(LargestChunk(), 10) == Val(10)
+    @test pick_chunksize(LargestChunk(), ones(10)) == Val(10)
+    @test pick_chunksize(LargestChunk(), 100) == Val(100)
+    @test pick_chunksize(LargestChunk(), ones(100)) == Val(100)
+end
+
+@testset "FixedChunk" begin
+    @test FixedChunk(3) == FixedChunk{3}()
+    @test pick_chunksize(FixedChunk{3}(), 2) == Val(3)
+    @test pick_chunksize(FixedChunk{3}(), ones(2)) == Val(3)
+    @test pick_chunksize(FixedChunk{3}(), 10) == Val(3)
+    @test pick_chunksize(FixedChunk{3}(), ones(10)) == Val(3)
+    @test pick_chunksize(FixedChunk{3}(), 100) == Val(3)
+    @test pick_chunksize(FixedChunk{3}(), ones(100)) == Val(3)
+    @test pick_chunksize(FixedChunk{4}(), 100) == Val(4)
+    @test pick_chunksize(FixedChunk{4}(), ones(100)) == Val(4)
+end
+
+@testset "AutoChunk" begin
+    @test pick_chunksize(AutoChunk(), 10) == Val(10)
+    @test pick_chunksize(AutoChunk(), ones(10)) == Val(10)
+    @test pick_chunksize(AutoChunk(), 100) == Val(16)
+    @test pick_chunksize(AutoChunk(), ones(100)) == Val(16)
+end
diff --git a/lib/EnzymeCore/test/runtests.jl b/lib/EnzymeCore/test/runtests.jl
index c32747e0f0..bf2bcd2143
--- 
a/lib/EnzymeCore/test/runtests.jl
+++ b/lib/EnzymeCore/test/runtests.jl
@@ -36,12 +36,13 @@ using EnzymeCore
     @testset "Mode modification" begin
         include("mode_modification.jl")
     end
-end
-
-@testset "within_autodiff" begin
-    @test !EnzymeCore.within_autodiff()
-end
-
-@testset "ignore_derivatives" begin
-    @test EnzymeCore.ignore_derivatives(3) == 3
+    @testset "Chunk strategy" begin
+        include("chunk.jl")
+    end
+    @testset "within_autodiff" begin
+        @test !EnzymeCore.within_autodiff()
+    end
+    @testset "ignore_derivatives" begin
+        @test EnzymeCore.ignore_derivatives(3) == 3
+    end
 end
diff --git a/src/Enzyme.jl b/src/Enzyme.jl
index 5494830e61..b790fae3e6
--- a/src/Enzyme.jl
+++ b/src/Enzyme.jl
@@ -111,8 +111,11 @@ export autodiff,
     make_zero!,
     remake_zero!
 
+import EnzymeCore: ChunkStrategy, LargestChunk, FixedChunk, AutoChunk, pick_chunksize
+
 export jacobian, gradient, gradient!, hvp, hvp!, hvp_and_gradient!
 export batch_size, onehot, chunkedonehot
+export LargestChunk, FixedChunk, AutoChunk
 
 using LinearAlgebra
 import SparseArrays
diff --git a/src/sugar.jl b/src/sugar.jl
index 778da68e48..319f937e8e
--- a/src/sugar.jl
+++ b/src/sugar.jl
@@ -415,6 +415,27 @@ gradient!(ReverseWithPrimal, dx, f, [2.0, 3.0])
     end
 end
 
+const ExtendedChunkStrategy = Union{ChunkStrategy, Nothing, Val}
+
+# eats and returns a type because generated functions work on argument types
+get_strategy(chunk::Type{CS}) where {CS<:ChunkStrategy} = chunk
+
+function get_strategy(::Type{Nothing})
+    Base.depwarn(
+        "The `chunk=nothing` configuration will be deprecated in a future release. Please use `chunk=LargestChunk()` instead.",
+        :get_strategy,
+    )
+    return LargestChunk
+end
+
+function get_strategy(::Type{Val{C}}) where {C}
+    Base.depwarn(
+        "The `chunk=Val(C)` configuration will be deprecated in a future release. 
Please use `chunk=FixedChunk{C}()` instead.", + :get_strategy, + ) + return FixedChunk{C} +end + @inline function chunkedonehot(x, ::Val{chunk}) where {chunk} sz = length(x) num = ((sz + chunk - 1) ÷ chunk) @@ -428,11 +449,16 @@ end return ((one(x),),) end +@inline function chunkedonehot(x, strategy::ChunkStrategy) + return chunkedonehot(x, pick_chunksize(strategy, x)) +end + @inline tupleconcat(x) = x @inline tupleconcat(x, y) = (x..., y...) @inline tupleconcat(x, y, z...) = (x..., tupleconcat(y, z...)...) -@generated function create_shadows(chunk::ChunkTy, x::X, vargs::Vararg{Any,N}) where {ChunkTy, X, N} +@generated function create_shadows(chunk::ExtendedChunkStrategy, x::X, vargs::Vararg{Any,N}) where {X, N} + chunk_strategy = get_strategy(chunk) args = Union{Symbol,Expr}[:x] tys = Type[X] for i in 1:N @@ -446,7 +472,7 @@ end push!(exprs, :(nothing)) elseif ty <: AbstractFloat push!(exprs, :(nothing)) - elseif ChunkTy == Nothing || ChunkTy == Val{1} + elseif chunk_strategy == LargestChunk || chunk_strategy == FixedChunk{1} push!(exprs, :(onehot($arg))) else push!(exprs, :(chunkedonehot($arg, chunk))) @@ -502,10 +528,11 @@ end @inline specialize_output(output, input) = output """ - gradient(::ForwardMode, f, x; shadows=onehot(x), chunk=nothing) + gradient(::ForwardMode, f, x, args...; chunk=LargestChunk(), shadows=create_shadows(chunk, x, args...)) -Compute the gradient of an array-input function `f` using forward mode. The -optional keyword argument `shadow` is a vector of one-hot vectors of type `x` +Compute the gradient of an array-input function `f` using forward mode. +The optional keyword argument `chunk` denotes the chunk size to use: it can be any instance of [`EnzymeCore.ChunkStrategy`](@ref EnzymeCore.ChunkStrategy). +The optional keyword argument `shadow` is a vector of one-hot vectors of type `x` which are used to forward-propagate into the return. 
For performance reasons, this should be computed once, outside the call to `gradient`, rather than within this call. @@ -530,7 +557,7 @@ gradient(ForwardWithPrimal, f, [2.0, 3.0]) ``` ```jldoctest gradfwd -gradient(Forward, f, [2.0, 3.0]; chunk=Val(1)) +gradient(Forward, f, [2.0, 3.0]; chunk=FixedChunk{1}()) # output @@ -538,7 +565,7 @@ gradient(Forward, f, [2.0, 3.0]; chunk=Val(1)) ``` ```jldoctest gradfwd -gradient(ForwardWithPrimal, f, [2.0, 3.0]; chunk=Val(1)) +gradient(ForwardWithPrimal, f, [2.0, 3.0]; chunk=FixedChunk{1}()) # output (derivs = ([3.0, 2.0],), val = 6.0) @@ -587,9 +614,11 @@ gradient(Forward, mul, [2.0, 3.0], Const([2.7, 3.1])) f::F, x::ty_0, args::Vararg{Any,N}; - chunk::CS = nothing, + chunk::ExtendedChunkStrategy = LargestChunk(), shadows::ST = create_shadows(chunk, x, args...), -) where {F, ReturnPrimal,ABI,ErrIfFuncWritten,RuntimeActivity,StrongZero,CS,ST, ty_0, N} +) where {F, ReturnPrimal,ABI,ErrIfFuncWritten,RuntimeActivity,StrongZero,ST, ty_0, N} + + chunk_strategy = get_strategy(chunk) syms = Union{Symbol,Expr}[:x] shads = Union{Symbol,Expr}[:(shadows[1])] @@ -617,7 +646,7 @@ gradient(Forward, mul, [2.0, 3.0], Const([2.7, 3.1])) end end - if CS == Val{0} + if chunk_strategy == FixedChunk{0} return quote Base.@_inline_meta throw(ErrorException("Cannot differentiate with a batch size of 0")) @@ -659,7 +688,7 @@ gradient(Forward, mul, [2.0, 3.0], Const([2.7, 3.1])) :($resp[1]) elseif argnum == 0 vals[i] - elseif CS == Nothing + elseif chunk_strategy == LargestChunk dargs = Union{Symbol,Expr}[] for (j, arg2) in enumerate(syms) if i == j @@ -688,7 +717,7 @@ gradient(Forward, mul, [2.0, 3.0], Const([2.7, 3.1])) end :(values($resp[1])) - elseif CS == Val{1} + elseif chunk_strategy == FixedChunk{1} subderivatives = Union{Symbol,Expr}[] for an in 1:argnum dargs = Union{Symbol,Expr}[] @@ -788,7 +817,7 @@ end """ jacobian(::ForwardMode, args...; kwargs...) -Equivalent to gradient(::ForwardMode, args...; kwargs...) 
+Equivalent to `gradient(::ForwardMode, args...; kwargs...)`. """ @inline function jacobian(fm::ForwardMode, args...; kwargs...) gradient(fm, args...; kwargs...) @@ -798,10 +827,11 @@ end mode::ReverseMode{ReturnPrimal}, RT::RType, n_outs::OutType, - chunk::CT, + chunk::ExtendedChunkStrategy, f::F, xs::Vararg{Any, Nargs} -) where {ReturnPrimal,RType, F,Nargs,OutType,CT} +) where {ReturnPrimal,RType, F,Nargs,OutType} + chunk_strategy = get_strategy(chunk) fty = if f <: Enzyme.Annotation f.parameters[1] else @@ -891,7 +921,7 @@ end end end - if chunk == Val{0} + if chunk_strategy == FixedChunk{0} return quote throw(ErrorException("Cannot differentiate with a batch size of 0")) end @@ -912,11 +942,9 @@ end MDTys = Union{Expr,Symbol}[] MDTysLast = Union{Expr,Symbol}[] - chunksize = if chunk <: Val - chunk.parameters[1] - else - 1 - end + chunksize_val = pick_chunksize(chunk_strategy(), n_out_val) + chunksize = typeof(chunksize_val).parameters[1] + num = ((n_out_val + chunksize - 1) ÷ chunksize) last_size = if num * chunksize == n_out_val @@ -940,7 +968,7 @@ end else push!(exprs, Expr(:(=), mdi, :(Compiler.active_reg_nothrow($xti) == Compiler.ActiveState || Compiler.active_reg_nothrow($xti) == Compiler.MixedState))) - if chunk == Val{1} || chunk == Nothing + if chunk_strategy == FixedChunk{1} push!(MDTys, :($mdi ? MixedDuplicated{$xti} : Duplicated{$xti})) else push!(MDTys, :($mdi ? BatchMixedDuplicated{$xti, $chunksize} : BatchDuplicated{$xti, $chunksize})) @@ -1169,12 +1197,12 @@ end end """ - jacobian(::ReverseMode, f, x; n_outs=nothing, chunk=nothing) + jacobian(::ReverseMode, f, x; n_outs=nothing, chunk=LargestChunk()) jacobian(::ReverseMode, f, x) -Compute the jacobian of a array-output function `f` using (potentially vector) -reverse mode. The `chunk` argument optionally denotes the chunk size to use and -`n_outs` optionally denotes the shape of the array returned by `f` (e.g `size(f(x))`). 
+Compute the jacobian of a array-output function `f` using (potentially vector) reverse mode. +The optional keyword argument `chunk` denotes the chunk size to use: it can be any instance of [`EnzymeCore.ChunkStrategy`](@ref EnzymeCore.ChunkStrategy). +The optional keyword argument `n_outs` denotes the shape of the array returned by `f` (e.g `size(f(x))`). Example: @@ -1227,8 +1255,8 @@ this function will retun an AbstractArray of shape `size(output)` of values of t f::F, xs::Vararg{Any, Nargs}; n_outs::OutType = nothing, - chunk::CT = nothing, -) where {F,Nargs, OutType,CT} + chunk::ExtendedChunkStrategy = LargestChunk(), +) where {F,Nargs, OutType} fty = if f <: Enzyme.Annotation f.parameters[1] diff --git a/test/sugar.jl b/test/sugar.jl index 3e133248ff..00d13f18fb 100644 --- a/test/sugar.jl +++ b/test/sugar.jl @@ -666,3 +666,51 @@ end # @show J_r_3(u, A, x) # @show J_f_3(u, A, x) end + +fchunk1(x) = sum(sin, x) +fchunk2(x) = map(sin, x) + map(cos, reverse(x)) + +@testset "Chunking strategies" begin + @testset "ChunkedOneHot" begin + @test Enzyme.chunkedonehot(ones(30), Enzyme.LargestChunk()) isa Tuple{NTuple{30}} + @test Enzyme.chunkedonehot(ones(10), Enzyme.LargestChunk()) isa Tuple{NTuple{10}} + @test Enzyme.chunkedonehot(ones(30), Enzyme.LargestChunk()) isa Tuple{NTuple{30}} + @test Enzyme.chunkedonehot(ones(3), Enzyme.FixedChunk{1}()) isa Tuple{NTuple{1},NTuple{1},NTuple{1}} + @test Enzyme.chunkedonehot(ones(10), Enzyme.FixedChunk{4}()) isa Tuple{NTuple{4},NTuple{4},NTuple{2}} + @test Enzyme.chunkedonehot(ones(10), Enzyme.FixedChunk{5}()) isa Tuple{NTuple{5},NTuple{5}} + @test Enzyme.chunkedonehot(ones(10), Enzyme.AutoChunk()) isa Tuple{NTuple{10}} + @test Enzyme.chunkedonehot(ones(30), Enzyme.AutoChunk()) isa Tuple{NTuple{16}, NTuple{14}} + @test Enzyme.chunkedonehot(ones(30), Enzyme.AutoChunk()) isa Tuple{NTuple{16}, NTuple{14}} + @test Enzyme.chunkedonehot(ones(40), Enzyme.AutoChunk()) isa Tuple{NTuple{16}, NTuple{16}, NTuple{8}} + end + + strategies = 
[Enzyme.LargestChunk(), Enzyme.FixedChunk{1}(), Enzyme.FixedChunk{3}(), Enzyme.AutoChunk()] + + @testset "Forward gradient" begin + @testset for chunk in strategies + for n in (2, 10) + x = ones(n) + g = Enzyme.gradient(Enzyme.Forward, fchunk1, x) + @test g == Enzyme.gradient(Enzyme.Forward, fchunk1, x; chunk) + end + end + end + @testset "Forward Jacobian" begin + @testset for chunk in strategies + for n in (2, 10) + x = ones(n) + J = Enzyme.jacobian(Enzyme.Forward, fchunk2, x) + @test J == Enzyme.jacobian(Enzyme.Forward, fchunk2, x; chunk) + end + end + end + @testset "Reverse Jacobian" begin + @testset for chunk in strategies + for n in (2, 10) + x = ones(n) + J = Enzyme.jacobian(Enzyme.Forward, fchunk2, x) + @test J == Enzyme.jacobian(Enzyme.Reverse, fchunk2, x; chunk) + end + end + end +end