Skip to content
Open
Show file tree
Hide file tree
Changes from 6 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions lib/EnzymeCore/src/EnzymeCore.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ export Forward, ForwardWithPrimal, Reverse, ReverseWithPrimal, ReverseSplitNoPri
export ReverseSplitModified, ReverseSplitWidth, ReverseHolomorphic, ReverseHolomorphicWithPrimal
export Const, Active, Duplicated, DuplicatedNoNeed, BatchDuplicated, BatchDuplicatedNoNeed, Annotation
export MixedDuplicated, BatchMixedDuplicated
export Seed, BatchSeed
export DefaultABI, FFIABI, InlineABI, NonGenABI
export BatchDuplicatedFunc
export within_autodiff
Expand Down Expand Up @@ -206,6 +207,24 @@ end
@inline batch_size(::BatchMixedDuplicated{T,N}) where {T,N} = N
@inline batch_size(::Type{BatchMixedDuplicated{T,N}}) where {T,N} = N

"""
Seed(dy)

Wrapper for a single adjoint to the return value in reverse mode.
"""
Comment thread
vchuravy marked this conversation as resolved.
Outdated
struct Seed{T}
dval::T
end

"""
BatchSeed(dys::NTuple)

Wrapper for a tuple of adjoints to the return value in reverse mode.
"""
struct BatchSeed{T, N}
dvals::NTuple{T, N}
end

"""
abstract type ABI

Expand Down
7 changes: 6 additions & 1 deletion src/Enzyme.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ import EnzymeCore:
DuplicatedNoNeed,
BatchDuplicated,
BatchDuplicatedNoNeed,
Seed,
BatchSeed,
ABI,
DefaultABI,
FFIABI,
Expand All @@ -52,14 +54,17 @@ import EnzymeCore:
clear_runtime_activity,
within_autodiff,
WithPrimal,
NoPrimal
NoPrimal,
Split
export Annotation,
Const,
Active,
Duplicated,
DuplicatedNoNeed,
BatchDuplicated,
BatchDuplicatedNoNeed,
Seed,
BatchSeed,
DefaultABI,
FFIABI,
InlineABI,
Expand Down
79 changes: 79 additions & 0 deletions src/sugar.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1159,3 +1159,82 @@ grad
)
return nothing
end

"""
autodiff(
rmode::Union{ReverseMode,ReverseModeSplit},
f::Annotation,
ReturnActivity::Type{<:Annotation},
dresult::Seed,
Comment thread
gdalle marked this conversation as resolved.
Outdated
annotated_args...
)

Call [`autodiff_thunk`](@ref) in split mode, execute the forward pass, increment output adjoint with `dresult`, then execute the reverse pass.

Useful for computing pullbacks / VJPs for functions whose output is not a scalar.
Comment thread
gdalle marked this conversation as resolved.
Outdated
"""
function autodiff(
rmode::Union{ReverseMode{ReturnPrimal}, ReverseModeSplit{ReturnPrimal}},
f::FA,
Comment thread
gdalle marked this conversation as resolved.
::Type{RA},
dresult::Seed,
args::Vararg{Annotation, N},
) where {ReturnPrimal, FA <: Annotation, RA <: Annotation, N}
if RA === Const
throw(ArgumentError("Return activity cannot be `Const`."))
end
forward, reverse = autodiff_thunk(Split(rmode), FA, RA, typeof.(args)...)
tape, result, shadow_result = forward(f, args...)
if RA <: Active
Comment thread
gdalle marked this conversation as resolved.
dinputs = only(reverse(f, args..., dresult.dval, tape))
else
Compiler.recursive_accumulate(shadow_result, dresult.dval)
dinputs = only(reverse(f, args..., tape))
end
if ReturnPrimal
return (dinputs, result)
else
return (dinputs,)
end
end

"""
autodiff(
rmode::Union{ReverseMode,ReverseModeSplit},
f::Annotation,
ReturnActivity::Type{<:Annotation},
dresults::BatchSeed,
annotated_args...
)

Call [`autodiff_thunk`](@ref) in split mode, execute the forward pass, increment each output adjoint with the corresponding element from `dresults`, then execute the reverse pass.

Useful for computing pullbacks / VJPs for functions whose output is not a scalar.
"""
function autodiff(
rmode::Union{ReverseMode{ReturnPrimal}, ReverseModeSplit{ReturnPrimal}},
f::FA,
::Type{RA},
dresults::BatchSeed{B},
args::Vararg{Annotation, N},
) where {ReturnPrimal, B, FA <: Annotation, RA <: Annotation, N}
if RA === Const
throw(ArgumentError("Return activity cannot be `Const`."))
end
rmode_rightwidth = ReverseSplitWidth(Split(rmode), Val(B))
forward, reverse = autodiff_thunk(rmode_rightwidth, FA, RA, typeof.(args)...)
tape, result, shadow_results = forward(f, args...)
if RA <: Active
dinputs = only(reverse(f, args..., dresults.dvals, tape))
else
foreach(shadow_results, dresults.dvals) do d0, d
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@vchuravy knows better than me here, but I don't think we should use foreach to ensure type stability, instead explicitly using either a generated function, or ntuple to go through?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you have an example in mind where foreach would be unstable?

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

foreach for tuples is fine. It boils down to ntuple

Compiler.recursive_accumulate(d0, d)
end
dinputs = only(reverse(f, args..., tape))
end
if ReturnPrimal
return (dinputs, result)
else
return (dinputs,)
end
end
74 changes: 74 additions & 0 deletions test/sugar.jl
Original file line number Diff line number Diff line change
Expand Up @@ -650,3 +650,77 @@ end
# @show J_r_3(u, A, x)
# @show J_f_3(u, A, x)
end

@testset "Seeded reverse autodiff" begin

f(x::Vector{Float64}, y::Float64) = sum(abs2, x) * y
g(x::Vector{Float64}, y::Float64) = [f(x, y)]

x = [1.0, 2.0, 3.0]
y = 4.0
dx = similar(x)
dresult = 5.0
dxs = (similar(x), similar(x))
dresults = (5.0, 7.0)

@testset "simple" begin
@test_throws ArgumentError autodiff(Reverse, Const(f), Const, Seed(dresult), Duplicated(x, dx), Active(y))

for mode in (Reverse, ReverseWithPrimal, ReverseSplitNoPrimal, ReverseSplitWithPrimal)
make_zero!(dx)
dinputs_and_maybe_result = autodiff(mode, Const(f), Active, Seed(dresult), Duplicated(x, dx), Active(y))
dinputs = first(dinputs_and_maybe_result)
@test isnothing(dinputs[1])
@test dinputs[2] == dresult * sum(abs2, x)
@test dx == dresult * 2x * y
if Enzyme.Split(mode) == ReverseSplitWithPrimal
@test last(dinputs_and_maybe_result) == f(x, y)
end
end

for mode in (Reverse, ReverseWithPrimal, ReverseSplitNoPrimal, ReverseSplitWithPrimal)
make_zero!(dx)
dinputs_and_maybe_result = autodiff(mode, Const(g), Duplicated, Seed([dresult]), Duplicated(x, dx), Active(y))
dinputs = first(dinputs_and_maybe_result)
@test isnothing(dinputs[1])
@test dinputs[2] == dresult * sum(abs2, x)
@test dx == dresult * 2x * y
if Enzyme.Split(mode) == ReverseSplitWithPrimal
@test last(dinputs_and_maybe_result) == g(x, y)
end
end
end

@testset "batch" begin
@test_throws ArgumentError autodiff(Reverse, Const(f), Const, BatchSeed(dresults), BatchDuplicated(x, dxs), Active(y))

for mode in (Reverse, ReverseWithPrimal, ReverseSplitNoPrimal, ReverseSplitWithPrimal)
make_zero!(dxs)
dinputs_and_maybe_result = autodiff(mode, Const(f), Active, BatchSeed(dresults), BatchDuplicated(x, dxs), Active(y))
dinputs = first(dinputs_and_maybe_result)
@test isnothing(dinputs[1])
@test dinputs[2][1] == dresults[1] * sum(abs2, x)
@test dinputs[2][2] == dresults[2] * sum(abs2, x)
@test dxs[1] == dresults[1] * 2x * y
@test dxs[2] == dresults[2] * 2x * y
if Enzyme.Split(mode) == ReverseSplitWithPrimal
@test last(dinputs_and_maybe_result) == f(x, y)
end
end

for mode in (Reverse, ReverseWithPrimal, ReverseSplitNoPrimal, ReverseSplitWithPrimal)
make_zero!(dxs)
dinputs_and_maybe_result = autodiff(mode, Const(g), BatchDuplicated, BatchSeed(([dresults[1]], [dresults[2]])), BatchDuplicated(x, dxs), Active(y))
dinputs = first(dinputs_and_maybe_result)
@test isnothing(dinputs[1])
@test dinputs[2][1] == dresults[1] * sum(abs2, x)
@test dinputs[2][2] == dresults[2] * sum(abs2, x)
@test dxs[1] == dresults[1] * 2x * y
@test dxs[2] == dresults[2] * 2x * y
if Enzyme.Split(mode) == ReverseSplitWithPrimal
@test last(dinputs_and_maybe_result) == g(x, y)
end
end
end

end