Skip to content
Open
89 changes: 89 additions & 0 deletions lib/EnzymeCore/src/EnzymeCore.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ export DefaultABI, FFIABI, InlineABI, NonGenABI
export BatchDuplicatedFunc
export within_autodiff, ignore_derivatives
export needs_primal
export ChunkStrategy, SmallestChunk, LargestChunk, FixedChunk, AutoChunk, pick_chunksize

function batch_size end

Expand Down Expand Up @@ -797,4 +798,92 @@ end

Combined(mode::ReverseMode) = mode

"""
    ChunkStrategy

Abstract type gathering strategies for chunk size selection.

Concrete strategies are consumed by [`pick_chunksize`](@ref) to turn an array
(or its length) into a `Val`-wrapped chunk size.

# Subtypes

- [`SmallestChunk`](@ref)
- [`LargestChunk`](@ref)
- [`FixedChunk`](@ref)
- [`AutoChunk`](@ref)
"""
abstract type ChunkStrategy end

"""
    SmallestChunk()

Select chunk size equal to 1, so that the corresponding array is processed in as many chunks as it has elements.

!!! tip
    In the current Enzyme interface, this strategy is equivalent to setting `chunk = nothing`.
"""
struct SmallestChunk <: ChunkStrategy end

"""
    LargestChunk()

Select chunk size equal to the number of elements, so that the corresponding array is processed in a single chunk.

!!! warning
    Because the chunk size depends on the array length, [`pick_chunksize`](@ref) is
    type-unstable for this strategy, except on statically sized arrays.
"""
struct LargestChunk <: ChunkStrategy end

"""
    FixedChunk{C}()

Select chunk size equal to a fixed integer `C`.

!!! tip
    In the current Enzyme interface, this chunk strategy is equivalent to setting `chunk = Val(C)`.

!!! warning
    This chunk strategy will error if the corresponding array has length `< C`.
"""
struct FixedChunk{C} <: ChunkStrategy end

"""
    AutoChunk()

Select chunk size automatically based on internal Enzyme-specific heuristics, which are subject to change.
"""
struct AutoChunk <: ChunkStrategy end

# Upper bound on the chunk size picked by `AutoChunk`; the array length is used
# instead when it is smaller (see `pick_chunksize(::AutoChunk, n)`).
const DEFAULT_CHUNK_SIZE = 16

"""
    pick_chunksize(s::ChunkStrategy, n::Integer)
    pick_chunksize(s::ChunkStrategy, a::AbstractArray)

Compute the chunk size chosen by strategy `s` based on the integer `n` or the array `a` (`n` corresponds to the array's length).
Return a `Val{C}` object.

- In forward-mode gradients and Jacobians, `a` would be the input array.
- In reverse-mode Jacobians, `a` would be the output array.

!!! warning
    For `LargestChunk` and `AutoChunk` strategies, this function is type-unstable.
"""
function pick_chunksize end

# The smallest chunk strategy always yields a chunk size of 1, regardless of input.
pick_chunksize(::SmallestChunk, ::Union{Integer,AbstractArray}) = Val(1)

pick_chunksize(::LargestChunk, len::Integer) = Val(len)
# Dispatching on the array itself (rather than its length) allows inference on static arrays.
pick_chunksize(::LargestChunk, arr::AbstractArray) = Val(length(arr))

function pick_chunksize(::AutoChunk, n::Integer)
    # Cap the chunk size at the default; use the full length for short inputs.
    # TODO: improve heuristic
    return Val(min(DEFAULT_CHUNK_SIZE, n))
end
pick_chunksize(strategy::AutoChunk, arr::AbstractArray) = pick_chunksize(strategy, length(arr))

function pick_chunksize(strategy::FixedChunk{C}, input::Union{Integer,AbstractArray}) where {C}
    # Validate that the input is long enough before committing to chunk size C.
    check_length(strategy, input)
    return Val{C}()
end

# Throw if a fixed chunk size C exceeds the available length.
function check_length(::FixedChunk{C}, len::Integer) where {C}
    len < C && error("Chunk size $C is larger than length $len")
    return nothing
end
check_length(strategy::FixedChunk{C}, arr::AbstractArray) where {C} = check_length(strategy, length(arr))

end # module EnzymeCore
34 changes: 34 additions & 0 deletions lib/EnzymeCore/test/chunk.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
using Test
using EnzymeCore

@testset "SmallestChunk" begin
    # Chunk size is always 1, whatever the input length.
    for n in (10, 100)
        @test pick_chunksize(SmallestChunk(), n) == Val(1)
        @test pick_chunksize(SmallestChunk(), ones(n)) == Val(1)
    end
end

@testset "LargestChunk" begin
    # Chunk size always equals the input length.
    for n in (10, 100)
        @test pick_chunksize(LargestChunk(), n) == Val(n)
        @test pick_chunksize(LargestChunk(), ones(n)) == Val(n)
    end
end

@testset "FixedChunk" begin
    # A fixed chunk size larger than the input length must throw.
    @test_throws ErrorException pick_chunksize(FixedChunk{3}(), 2)
    @test_throws ErrorException pick_chunksize(FixedChunk{3}(), ones(2))
    # Otherwise the fixed size is returned verbatim.
    for n in (10, 100)
        @test pick_chunksize(FixedChunk{3}(), n) == Val(3)
        @test pick_chunksize(FixedChunk{3}(), ones(n)) == Val(3)
    end
    @test pick_chunksize(FixedChunk{4}(), 100) == Val(4)
    @test pick_chunksize(FixedChunk{4}(), ones(100)) == Val(4)
end

@testset "AutoChunk" begin
    # Below the default chunk size (16) the full length is used...
    @test pick_chunksize(AutoChunk(), 10) == Val(10)
    @test pick_chunksize(AutoChunk(), ones(10)) == Val(10)
    # ...above it, the size saturates at the default.
    @test pick_chunksize(AutoChunk(), 100) == Val(16)
    @test pick_chunksize(AutoChunk(), ones(100)) == Val(16)
end
17 changes: 9 additions & 8 deletions lib/EnzymeCore/test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,13 @@ using EnzymeCore
@testset "Mode modification" begin
include("mode_modification.jl")
end
end

@testset "within_autodiff" begin
@test !EnzymeCore.within_autodiff()
end

@testset "ignore_derivatives" begin
@test EnzymeCore.ignore_derivatives(3) == 3
@testset "Chunk strategy" begin
include("chunk.jl")
end
@testset "within_autodiff" begin
@test !EnzymeCore.within_autodiff()
end
@testset "ignore_derivatives" begin
@test EnzymeCore.ignore_derivatives(3) == 3
end
end
3 changes: 3 additions & 0 deletions src/Enzyme.jl
Original file line number Diff line number Diff line change
Expand Up @@ -111,8 +111,11 @@ export autodiff,
make_zero!,
remake_zero!

import EnzymeCore: ChunkStrategy, SmallestChunk, LargestChunk, FixedChunk, AutoChunk, pick_chunksize

export jacobian, gradient, gradient!, hvp, hvp!, hvp_and_gradient!
export batch_size, onehot, chunkedonehot
export SmallestChunk, LargestChunk, FixedChunk, AutoChunk

using LinearAlgebra
import SparseArrays
Expand Down
84 changes: 56 additions & 28 deletions src/sugar.jl
Original file line number Diff line number Diff line change
Expand Up @@ -415,6 +415,27 @@ gradient!(ReverseWithPrimal, dx, f, [2.0, 3.0])
end
end

# Accepts both the new `ChunkStrategy` types and the legacy `nothing` / `Val(C)` spellings.
const ExtendedChunkStrategy = Union{ChunkStrategy, Nothing, Val}

# These helpers take and return *types*, because generated functions operate on
# argument types rather than values.
get_strategy(::Type{CS}) where {CS<:ChunkStrategy} = CS

# Legacy spelling `chunk = Val(C)` maps to the fixed-size strategy.
function get_strategy(::Type{Val{C}}) where {C}
    Base.depwarn(
        "The `chunk=Val(C)` configuration will be deprecated in a future release. Please use `chunk=FixedChunk{C}()` instead.",
        :get_strategy,
    )
    return FixedChunk{C}
end

# Legacy spelling `chunk = nothing` maps to the smallest-chunk strategy.
function get_strategy(::Type{Nothing})
    Base.depwarn(
        "The `chunk=nothing` configuration will be deprecated in a future release. Please use `chunk=SmallestChunk()` instead.",
        :get_strategy,
    )
    return SmallestChunk
end

@inline function chunkedonehot(x, ::Val{chunk}) where {chunk}
sz = length(x)
num = ((sz + chunk - 1) ÷ chunk)
Expand All @@ -428,11 +449,16 @@ end
return ((one(x),),)
end

# Resolve the strategy to a `Val` chunk size for `x`, then dispatch to the `Val` method.
@inline chunkedonehot(x, strategy::ChunkStrategy) = chunkedonehot(x, pick_chunksize(strategy, x))

# Concatenate any number of tuples into one flat tuple (left fold).
@inline tupleconcat(t) = t
@inline tupleconcat(a, b) = (a..., b...)
@inline tupleconcat(a, b, rest...) = tupleconcat((a..., b...), rest...)

@generated function create_shadows(chunk::ChunkTy, x::X, vargs::Vararg{Any,N}) where {ChunkTy, X, N}
@generated function create_shadows(chunk::ExtendedChunkStrategy, x::X, vargs::Vararg{Any,N}) where {X, N}
chunk_strategy = get_strategy(chunk)
args = Union{Symbol,Expr}[:x]
tys = Type[X]
for i in 1:N
Expand All @@ -446,7 +472,7 @@ end
push!(exprs, :(nothing))
elseif ty <: AbstractFloat
push!(exprs, :(nothing))
elseif ChunkTy == Nothing || ChunkTy == Val{1}
elseif chunk_strategy == SmallestChunk || chunk_strategy == FixedChunk{1}
push!(exprs, :(onehot($arg)))
else
push!(exprs, :(chunkedonehot($arg, chunk)))
Expand Down Expand Up @@ -502,10 +528,11 @@ end
# Generic fallback: return the computed output unchanged, ignoring the input.
@inline function specialize_output(output, input)
    return output
end

"""
gradient(::ForwardMode, f, x; shadows=onehot(x), chunk=nothing)
gradient(::ForwardMode, f, x, args...; chunk=SmallestChunk(), shadows=create_shadows(chunk, x, args...))

Compute the gradient of an array-input function `f` using forward mode. The
optional keyword argument `shadow` is a vector of one-hot vectors of type `x`
Compute the gradient of an array-input function `f` using forward mode.
The optional keyword argument `chunk` denotes the chunk size to use: it can be any instance of [`EnzymeCore.ChunkStrategy`](@ref EnzymeCore.ChunkStrategy).
The optional keyword argument `shadows` is a vector of one-hot vectors of type `x`
which are used to forward-propagate into the return. For performance reasons,
this should be computed once, outside the call to `gradient`, rather than
within this call.
Expand All @@ -530,15 +557,15 @@ gradient(ForwardWithPrimal, f, [2.0, 3.0])
```

```jldoctest gradfwd
gradient(Forward, f, [2.0, 3.0]; chunk=Val(1))
gradient(Forward, f, [2.0, 3.0]; chunk=FixedChunk{1}())

# output

([3.0, 2.0],)
```

```jldoctest gradfwd
gradient(ForwardWithPrimal, f, [2.0, 3.0]; chunk=Val(1))
gradient(ForwardWithPrimal, f, [2.0, 3.0]; chunk=FixedChunk{1}())

# output
(derivs = ([3.0, 2.0],), val = 6.0)
Expand Down Expand Up @@ -587,9 +614,11 @@ gradient(Forward, mul, [2.0, 3.0], Const([2.7, 3.1]))
f::F,
x::ty_0,
args::Vararg{Any,N};
chunk::CS = nothing,
chunk::ExtendedChunkStrategy = SmallestChunk(),
shadows::ST = create_shadows(chunk, x, args...),
) where {F, ReturnPrimal,ABI,ErrIfFuncWritten,RuntimeActivity,StrongZero,CS,ST, ty_0, N}
) where {F, ReturnPrimal,ABI,ErrIfFuncWritten,RuntimeActivity,StrongZero,ST, ty_0, N}

chunk_strategy = get_strategy(chunk)

syms = Union{Symbol,Expr}[:x]
shads = Union{Symbol,Expr}[:(shadows[1])]
Expand Down Expand Up @@ -617,7 +646,7 @@ gradient(Forward, mul, [2.0, 3.0], Const([2.7, 3.1]))
end
end

if CS == Val{0}
if chunk_strategy == FixedChunk{0}
return quote
Base.@_inline_meta
throw(ErrorException("Cannot differentiate with a batch size of 0"))
Expand Down Expand Up @@ -659,7 +688,7 @@ gradient(Forward, mul, [2.0, 3.0], Const([2.7, 3.1]))
:($resp[1])
elseif argnum == 0
vals[i]
elseif CS == Nothing
elseif chunk_strategy == SmallestChunk
dargs = Union{Symbol,Expr}[]
for (j, arg2) in enumerate(syms)
if i == j
Expand Down Expand Up @@ -688,7 +717,7 @@ gradient(Forward, mul, [2.0, 3.0], Const([2.7, 3.1]))
end

:(values($resp[1]))
elseif CS == Val{1}
elseif chunk_strategy == FixedChunk{1}
subderivatives = Union{Symbol,Expr}[]
for an in 1:argnum
dargs = Union{Symbol,Expr}[]
Expand Down Expand Up @@ -788,7 +817,7 @@ end
"""
jacobian(::ForwardMode, args...; kwargs...)

Equivalent to gradient(::ForwardMode, args...; kwargs...)
Equivalent to `gradient(::ForwardMode, args...; kwargs...)`.
"""
@inline function jacobian(fm::ForwardMode, args...; kwargs...)
gradient(fm, args...; kwargs...)
Expand All @@ -798,10 +827,11 @@ end
mode::ReverseMode{ReturnPrimal},
RT::RType,
n_outs::OutType,
chunk::CT,
chunk::ExtendedChunkStrategy,
f::F,
xs::Vararg{Any, Nargs}
) where {ReturnPrimal,RType, F,Nargs,OutType,CT}
) where {ReturnPrimal,RType, F,Nargs,OutType}
chunk_strategy = get_strategy(chunk)
fty = if f <: Enzyme.Annotation
f.parameters[1]
else
Expand Down Expand Up @@ -891,7 +921,7 @@ end
end
end

if chunk == Val{0}
if chunk_strategy == FixedChunk{0}
return quote
throw(ErrorException("Cannot differentiate with a batch size of 0"))
end
Expand All @@ -912,11 +942,9 @@ end
MDTys = Union{Expr,Symbol}[]
MDTysLast = Union{Expr,Symbol}[]

chunksize = if chunk <: Val
chunk.parameters[1]
else
1
end
chunksize_val = pick_chunksize(chunk_strategy(), n_out_val)
chunksize = typeof(chunksize_val).parameters[1]

num = ((n_out_val + chunksize - 1) ÷ chunksize)

last_size = if num * chunksize == n_out_val
Expand All @@ -940,7 +968,7 @@ end
else
push!(exprs, Expr(:(=), mdi, :(Compiler.active_reg_nothrow($xti) == Compiler.ActiveState || Compiler.active_reg_nothrow($xti) == Compiler.MixedState)))

if chunk == Val{1} || chunk == Nothing
if chunk_strategy == SmallestChunk || chunk_strategy == FixedChunk{1}
push!(MDTys, :($mdi ? MixedDuplicated{$xti} : Duplicated{$xti}))
else
push!(MDTys, :($mdi ? BatchMixedDuplicated{$xti, $chunksize} : BatchDuplicated{$xti, $chunksize}))
Expand Down Expand Up @@ -1169,12 +1197,12 @@ end
end

"""
jacobian(::ReverseMode, f, x; n_outs=nothing, chunk=nothing)
jacobian(::ReverseMode, f, x; n_outs=nothing, chunk=SmallestChunk())
jacobian(::ReverseMode, f, x)

Compute the jacobian of a array-output function `f` using (potentially vector)
reverse mode. The `chunk` argument optionally denotes the chunk size to use and
`n_outs` optionally denotes the shape of the array returned by `f` (e.g `size(f(x))`).
Compute the Jacobian of an array-output function `f` using (potentially vector) reverse mode.
The optional keyword argument `chunk` denotes the chunk size to use: it can be any instance of [`EnzymeCore.ChunkStrategy`](@ref EnzymeCore.ChunkStrategy).
The optional keyword argument `n_outs` denotes the shape of the array returned by `f` (e.g `size(f(x))`).

Example:

Expand Down Expand Up @@ -1227,8 +1255,8 @@ this function will retun an AbstractArray of shape `size(output)` of values of t
f::F,
xs::Vararg{Any, Nargs};
n_outs::OutType = nothing,
chunk::CT = nothing,
) where {F,Nargs, OutType,CT}
chunk::ExtendedChunkStrategy = SmallestChunk(),
) where {F,Nargs, OutType}

fty = if f <: Enzyme.Annotation
f.parameters[1]
Expand Down
Loading