Skip to content
Open
68 changes: 68 additions & 0 deletions lib/EnzymeCore/src/EnzymeCore.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ export DefaultABI, FFIABI, InlineABI, NonGenABI
export BatchDuplicatedFunc
export within_autodiff, ignore_derivatives
export needs_primal
export ChunkStrategy, LargestChunk, FixedChunk, AutoChunk, pick_chunksize

function batch_size end

Expand Down Expand Up @@ -797,4 +798,71 @@ end

Combined(mode::ReverseMode) = mode

"""
    ChunkStrategy

Abstract type gathering strategies for chunk size selection.

# Subtypes

- [`LargestChunk`](@ref)
- [`FixedChunk`](@ref)
- [`AutoChunk`](@ref)
"""
abstract type ChunkStrategy end

"""
    LargestChunk()

Select chunk size equal to the number of elements, so that the corresponding array is processed in a single chunk.

!!! tip
    In the current Enzyme interface, this strategy is equivalent to setting `chunk = nothing`.
"""
struct LargestChunk <: ChunkStrategy end

"""
    FixedChunk{C}()
    FixedChunk(C) # type-unstable

Select chunk size equal to a fixed integer `C`.

!!! tip
    In the current Enzyme interface, this chunk strategy is equivalent to setting `chunk = Val(C)`.
"""
struct FixedChunk{C} <: ChunkStrategy end
Comment thread
gdalle marked this conversation as resolved.

# Convenience constructor lifting a runtime integer into the type parameter.
# Type-unstable: the concrete return type depends on the *value* of `C`.
# Any `Integer` is accepted and normalized to `Int`, so that for instance
# `FixedChunk(Int32(3))` and `FixedChunk(3)` both yield `FixedChunk{3}()`.
FixedChunk(C::Integer) = FixedChunk{Int(C)}()

"""
    AutoChunk()

Select chunk size automatically based on internal Enzyme-specific heuristics, which are subject to change.
"""
struct AutoChunk <: ChunkStrategy end
Comment thread
gdalle marked this conversation as resolved.

# Upper bound on the chunk size that `pick_chunksize` returns for the `AutoChunk` strategy
# (the result is the minimum of this constant and the number of elements).
const DEFAULT_CHUNK_SIZE = 16
Comment thread
gdalle marked this conversation as resolved.

"""
    pick_chunksize(s::ChunkStrategy, n::Integer)
    pick_chunksize(s::ChunkStrategy, a::AbstractArray)

Compute the chunk size chosen by strategy `s` based on the integer `n` or the array `a` (`n` corresponds to the array's length).
Return a `Val{C}` object.

- In forward-mode Jacobians, `a` would be the input array.
- In reverse-mode Jacobians, `a` would be the output array.

!!! warning
    For `LargestChunk` and `AutoChunk` strategies, this function is type-unstable.
"""
function pick_chunksize end

# `LargestChunk`: a single chunk covering every element.
function pick_chunksize(::LargestChunk, n::Integer)
    return Val(n)
end

function pick_chunksize(::LargestChunk, a::AbstractArray)
    return Val(length(a)) # allows inference on static arrays
end

# `AutoChunk`: the number of elements, capped at `DEFAULT_CHUNK_SIZE`.
function pick_chunksize(::AutoChunk, n::Integer)
    return Val(min(DEFAULT_CHUNK_SIZE, n)) # TODO: improve
end

function pick_chunksize(s::AutoChunk, a::AbstractArray)
    return pick_chunksize(s, length(a))
end

# `FixedChunk{C}`: always `C`; the integer or array argument is not inspected.
function pick_chunksize(::FixedChunk{C}, ::Union{Integer,AbstractArray}) where {C}
    return Val{C}()
end

end # module EnzymeCore
28 changes: 28 additions & 0 deletions lib/EnzymeCore/test/chunk.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
using Test
using EnzymeCore

@testset "LargestChunk" begin
    # The chunk must span the whole input, whether it is given as a length or as an array.
    for n in (10, 100)
        @test pick_chunksize(LargestChunk(), n) == Val(n)
        @test pick_chunksize(LargestChunk(), ones(n)) == Val(n)
    end
end

@testset "FixedChunk" begin
    @test FixedChunk(3) == FixedChunk{3}()
    # `pick_chunksize` for `FixedChunk{C}` is defined as `Val{C}()` without inspecting the
    # input, so a requested size larger than the input is returned unchanged (no error).
    @test pick_chunksize(FixedChunk{3}(), 2) == Val(3)
    @test pick_chunksize(FixedChunk{3}(), ones(2)) == Val(3)
    @test pick_chunksize(FixedChunk{3}(), 10) == Val(3)
    @test pick_chunksize(FixedChunk{3}(), ones(10)) == Val(3)
    @test pick_chunksize(FixedChunk{3}(), 100) == Val(3)
    @test pick_chunksize(FixedChunk{3}(), ones(100)) == Val(3)
    @test pick_chunksize(FixedChunk{4}(), 100) == Val(4)
    @test pick_chunksize(FixedChunk{4}(), ones(100)) == Val(4)
end

@testset "AutoChunk" begin
    # Inputs of size 10 keep their own size; inputs of size 100 are capped at 16.
    for (n, expected) in ((10, Val(10)), (100, Val(16)))
        @test pick_chunksize(AutoChunk(), n) == expected
        @test pick_chunksize(AutoChunk(), ones(n)) == expected
    end
end
17 changes: 9 additions & 8 deletions lib/EnzymeCore/test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,13 @@ using EnzymeCore
@testset "Mode modification" begin
include("mode_modification.jl")
end
end

@testset "within_autodiff" begin
@test !EnzymeCore.within_autodiff()
end

@testset "ignore_derivatives" begin
@test EnzymeCore.ignore_derivatives(3) == 3
@testset "Chunk strategy" begin
include("chunk.jl")
end
@testset "within_autodiff" begin
@test !EnzymeCore.within_autodiff()
end
@testset "ignore_derivatives" begin
@test EnzymeCore.ignore_derivatives(3) == 3
end
end
3 changes: 3 additions & 0 deletions src/Enzyme.jl
Original file line number Diff line number Diff line change
Expand Up @@ -111,8 +111,11 @@ export autodiff,
make_zero!,
remake_zero!

import EnzymeCore: ChunkStrategy, LargestChunk, FixedChunk, AutoChunk, pick_chunksize

export jacobian, gradient, gradient!, hvp, hvp!, hvp_and_gradient!
export batch_size, onehot, chunkedonehot
export LargestChunk, FixedChunk, AutoChunk

using LinearAlgebra
import SparseArrays
Expand Down
84 changes: 56 additions & 28 deletions src/sugar.jl
Original file line number Diff line number Diff line change
Expand Up @@ -415,6 +415,27 @@ gradient!(ReverseWithPrimal, dx, f, [2.0, 3.0])
end
end

# Types accepted for the `chunk` argument: a `ChunkStrategy`, or the legacy `nothing` / `Val(C)`
# forms, which `get_strategy` maps to a strategy type while emitting a deprecation warning.
const ExtendedChunkStrategy = Union{ChunkStrategy, Nothing, Val}

# Takes a strategy *type* and hands it back unchanged: generated functions only see the
# types of their arguments, so the strategy is handled at the type level.
function get_strategy(::Type{CS}) where {CS<:ChunkStrategy}
    return CS
end

# Legacy `chunk=nothing`: emit a deprecation warning and map to `LargestChunk`.
# The strategy *type* is returned (not an instance), like the other `get_strategy` methods:
# callers compare the result against types (e.g. `chunk_strategy == LargestChunk`) and
# instantiate it themselves with `chunk_strategy()`.
function get_strategy(::Type{Nothing})
    Base.depwarn(
        "The `chunk=nothing` configuration will be deprecated in a future release. Please use `chunk=LargestChunk()` instead.",
        :get_strategy,
    )
    return LargestChunk
end

# Legacy `chunk=Val(C)`: emit a deprecation warning and map to the strategy type
# `FixedChunk{C}`.
function get_strategy(::Type{Val{C}}) where {C}
    message = "The `chunk=Val(C)` configuration will be deprecated in a future release. Please use `chunk=FixedChunk{C}()` instead."
    Base.depwarn(message, :get_strategy)
    return FixedChunk{C}
end

@inline function chunkedonehot(x, ::Val{chunk}) where {chunk}
sz = length(x)
num = ((sz + chunk - 1) ÷ chunk)
Expand All @@ -428,11 +449,16 @@ end
return ((one(x),),)
end

# Strategy-based entry point: resolve `strategy` to a `Val` chunk size for `x` with
# `pick_chunksize`, then forward to the `chunkedonehot` method taking that `Val`.
@inline chunkedonehot(x, strategy::ChunkStrategy) =
    chunkedonehot(x, pick_chunksize(strategy, x))

# Splat the elements of all arguments, in order, into one flat tuple.
# A single argument is returned as is.
@inline function tupleconcat(only)
    return only
end

@inline function tupleconcat(head, next)
    return (head..., next...)
end

@inline function tupleconcat(head, next, rest...)
    # Flatten everything after `head` first, then prepend the elements of `head`.
    tail = tupleconcat(next, rest...)
    return (head..., tail...)
end

@generated function create_shadows(chunk::ChunkTy, x::X, vargs::Vararg{Any,N}) where {ChunkTy, X, N}
@generated function create_shadows(chunk::ExtendedChunkStrategy, x::X, vargs::Vararg{Any,N}) where {X, N}
chunk_strategy = get_strategy(chunk)
args = Union{Symbol,Expr}[:x]
tys = Type[X]
for i in 1:N
Expand All @@ -446,7 +472,7 @@ end
push!(exprs, :(nothing))
elseif ty <: AbstractFloat
push!(exprs, :(nothing))
elseif ChunkTy == Nothing || ChunkTy == Val{1}
elseif chunk_strategy == LargestChunk || chunk_strategy == FixedChunk{1}
push!(exprs, :(onehot($arg)))
else
push!(exprs, :(chunkedonehot($arg, chunk)))
Expand Down Expand Up @@ -502,10 +528,11 @@ end
@inline specialize_output(output, input) = output

"""
gradient(::ForwardMode, f, x; shadows=onehot(x), chunk=nothing)
gradient(::ForwardMode, f, x, args...; chunk=LargestChunk(), shadows=create_shadows(chunk, x, args...))

Compute the gradient of an array-input function `f` using forward mode. The
optional keyword argument `shadow` is a vector of one-hot vectors of type `x`
Compute the gradient of an array-input function `f` using forward mode.
The optional keyword argument `chunk` denotes the chunk size to use: it can be any instance of [`EnzymeCore.ChunkStrategy`](@ref EnzymeCore.ChunkStrategy).
The optional keyword argument `shadows` is a vector of one-hot vectors of type `x`
which are used to forward-propagate into the return. For performance reasons,
this should be computed once, outside the call to `gradient`, rather than
within this call.
Expand All @@ -530,15 +557,15 @@ gradient(ForwardWithPrimal, f, [2.0, 3.0])
```

```jldoctest gradfwd
gradient(Forward, f, [2.0, 3.0]; chunk=Val(1))
gradient(Forward, f, [2.0, 3.0]; chunk=FixedChunk{1}())

# output

([3.0, 2.0],)
```

```jldoctest gradfwd
gradient(ForwardWithPrimal, f, [2.0, 3.0]; chunk=Val(1))
gradient(ForwardWithPrimal, f, [2.0, 3.0]; chunk=FixedChunk{1}())

# output
(derivs = ([3.0, 2.0],), val = 6.0)
Expand Down Expand Up @@ -587,9 +614,11 @@ gradient(Forward, mul, [2.0, 3.0], Const([2.7, 3.1]))
f::F,
x::ty_0,
args::Vararg{Any,N};
chunk::CS = nothing,
chunk::ExtendedChunkStrategy = LargestChunk(),
shadows::ST = create_shadows(chunk, x, args...),
) where {F, ReturnPrimal,ABI,ErrIfFuncWritten,RuntimeActivity,StrongZero,CS,ST, ty_0, N}
) where {F, ReturnPrimal,ABI,ErrIfFuncWritten,RuntimeActivity,StrongZero,ST, ty_0, N}

chunk_strategy = get_strategy(chunk)

syms = Union{Symbol,Expr}[:x]
shads = Union{Symbol,Expr}[:(shadows[1])]
Expand Down Expand Up @@ -617,7 +646,7 @@ gradient(Forward, mul, [2.0, 3.0], Const([2.7, 3.1]))
end
end

if CS == Val{0}
if chunk_strategy == FixedChunk{0}
return quote
Base.@_inline_meta
throw(ErrorException("Cannot differentiate with a batch size of 0"))
Expand Down Expand Up @@ -659,7 +688,7 @@ gradient(Forward, mul, [2.0, 3.0], Const([2.7, 3.1]))
:($resp[1])
elseif argnum == 0
vals[i]
elseif CS == Nothing
elseif chunk_strategy == LargestChunk
dargs = Union{Symbol,Expr}[]
for (j, arg2) in enumerate(syms)
if i == j
Expand Down Expand Up @@ -688,7 +717,7 @@ gradient(Forward, mul, [2.0, 3.0], Const([2.7, 3.1]))
end

:(values($resp[1]))
elseif CS == Val{1}
elseif chunk_strategy == FixedChunk{1}
subderivatives = Union{Symbol,Expr}[]
for an in 1:argnum
dargs = Union{Symbol,Expr}[]
Expand Down Expand Up @@ -788,7 +817,7 @@ end
"""
jacobian(::ForwardMode, args...; kwargs...)

Equivalent to gradient(::ForwardMode, args...; kwargs...)
Equivalent to `gradient(::ForwardMode, args...; kwargs...)`.
"""
@inline function jacobian(fm::ForwardMode, args...; kwargs...)
gradient(fm, args...; kwargs...)
Expand All @@ -798,10 +827,11 @@ end
mode::ReverseMode{ReturnPrimal},
RT::RType,
n_outs::OutType,
chunk::CT,
chunk::ExtendedChunkStrategy,
f::F,
xs::Vararg{Any, Nargs}
) where {ReturnPrimal,RType, F,Nargs,OutType,CT}
) where {ReturnPrimal,RType, F,Nargs,OutType}
chunk_strategy = get_strategy(chunk)
fty = if f <: Enzyme.Annotation
f.parameters[1]
else
Expand Down Expand Up @@ -891,7 +921,7 @@ end
end
end

if chunk == Val{0}
if chunk_strategy == FixedChunk{0}
return quote
throw(ErrorException("Cannot differentiate with a batch size of 0"))
end
Expand All @@ -912,11 +942,9 @@ end
MDTys = Union{Expr,Symbol}[]
MDTysLast = Union{Expr,Symbol}[]

chunksize = if chunk <: Val
chunk.parameters[1]
else
1
end
chunksize_val = pick_chunksize(chunk_strategy(), n_out_val)
chunksize = typeof(chunksize_val).parameters[1]

num = ((n_out_val + chunksize - 1) ÷ chunksize)

last_size = if num * chunksize == n_out_val
Expand All @@ -940,7 +968,7 @@ end
else
push!(exprs, Expr(:(=), mdi, :(Compiler.active_reg_nothrow($xti) == Compiler.ActiveState || Compiler.active_reg_nothrow($xti) == Compiler.MixedState)))

if chunk == Val{1} || chunk == Nothing
if chunk_strategy == FixedChunk{1}
push!(MDTys, :($mdi ? MixedDuplicated{$xti} : Duplicated{$xti}))
else
push!(MDTys, :($mdi ? BatchMixedDuplicated{$xti, $chunksize} : BatchDuplicated{$xti, $chunksize}))
Expand Down Expand Up @@ -1169,12 +1197,12 @@ end
end

"""
jacobian(::ReverseMode, f, x; n_outs=nothing, chunk=nothing)
jacobian(::ReverseMode, f, x; n_outs=nothing, chunk=LargestChunk())
jacobian(::ReverseMode, f, x)

Compute the jacobian of a array-output function `f` using (potentially vector)
reverse mode. The `chunk` argument optionally denotes the chunk size to use and
`n_outs` optionally denotes the shape of the array returned by `f` (e.g `size(f(x))`).
Compute the jacobian of an array-output function `f` using (potentially vector) reverse mode.
The optional keyword argument `chunk` denotes the chunk size to use: it can be any instance of [`EnzymeCore.ChunkStrategy`](@ref EnzymeCore.ChunkStrategy).
The optional keyword argument `n_outs` denotes the shape of the array returned by `f` (e.g. `size(f(x))`).

Example:

Expand Down Expand Up @@ -1227,8 +1255,8 @@ this function will retun an AbstractArray of shape `size(output)` of values of t
f::F,
xs::Vararg{Any, Nargs};
n_outs::OutType = nothing,
chunk::CT = nothing,
) where {F,Nargs, OutType,CT}
chunk::ExtendedChunkStrategy = LargestChunk(),
) where {F,Nargs, OutType}

fty = if f <: Enzyme.Annotation
f.parameters[1]
Expand Down
Loading