Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions lib/EnzymeCore/src/EnzymeCore.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ export Forward, ForwardWithPrimal, Reverse, ReverseWithPrimal, ReverseSplitNoPri
export ReverseSplitModified, ReverseSplitWidth, ReverseHolomorphic, ReverseHolomorphicWithPrimal
export Const, Active, Duplicated, DuplicatedNoNeed, BatchDuplicated, BatchDuplicatedNoNeed, Annotation
export MixedDuplicated, BatchMixedDuplicated
export Seed, BatchSeed
export DefaultABI, FFIABI, InlineABI, NonGenABI
export BatchDuplicatedFunc
export within_autodiff
Expand Down Expand Up @@ -206,6 +207,24 @@ end
# Batch width (number of derivative components) of a BatchMixedDuplicated
# value or type; `N` is the second type parameter by convention.
@inline batch_size(::BatchMixedDuplicated{T,N}) where {T,N} = N
@inline batch_size(::Type{BatchMixedDuplicated{T,N}}) where {T,N} = N

"""
Seed(dy)

Wrapper for a single adjoint to the return value in reverse mode.
"""
Comment thread
vchuravy marked this conversation as resolved.
Outdated
struct Seed{T}
    # Adjoint (cotangent) used to seed the return value in reverse mode;
    # consumed by the seeded `autodiff(::ReverseModeSplit, ...)` methods.
    dval::T
end

"""
BatchSeed(dys::NTuple)

Wrapper for a tuple of adjoints to the return value in reverse mode.
"""
struct BatchSeed{T,N}
    # NOTE(review): `NTuple{T,N}` reverses the conventional `NTuple{N,T}`
    # (count, eltype) order, so here the FIRST parameter `T` is the number of
    # seeds and `N` is the element type — the opposite of
    # `BatchDuplicated{T,N}`, where `N` is the width. The seeded batch
    # `autodiff` method relies on the first parameter being the width
    # (`BatchSeed{B}` → `Val(B)`), so swapping to `NTuple{N,T}` would require
    # updating that method in lockstep. Confirm the inversion is intentional.
    dvals::NTuple{T,N}
end

"""
abstract type ABI

Expand Down
4 changes: 4 additions & 0 deletions src/Enzyme.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ import EnzymeCore:
DuplicatedNoNeed,
BatchDuplicated,
BatchDuplicatedNoNeed,
Seed,
BatchSeed,
ABI,
DefaultABI,
FFIABI,
Expand All @@ -60,6 +62,8 @@ export Annotation,
DuplicatedNoNeed,
BatchDuplicated,
BatchDuplicatedNoNeed,
Seed,
BatchSeed,
DefaultABI,
FFIABI,
InlineABI,
Expand Down
73 changes: 73 additions & 0 deletions src/sugar.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1160,3 +1160,76 @@ grad
return nothing
end


"""
    autodiff(
        rmode::ReverseModeSplit,
        f::Annotation,
        ReturnActivity::Type{<:Annotation},
        dresult::Seed,
        annotated_args...
    )

Call [`autodiff_thunk`](@ref), execute the forward pass, increment the output
adjoint with `dresult`, then execute the reverse pass.

Useful for computing pullbacks / VJPs for functions whose output is not a scalar.
"""
function autodiff(
    rmode::ReverseModeSplit{ReturnPrimal},
    f::FA,
    ::Type{RA},
    dresult::Seed,
    args::Vararg{Annotation,N},
) where {ReturnPrimal,FA<:Annotation,RA<:Annotation,N}
    # Build the split-mode forward/reverse thunks for the requested activities.
    fwd_thunk, rev_thunk = autodiff_thunk(rmode, FA, RA, typeof.(args)...)
    tape, primal, shadow = fwd_thunk(f, args...)
    local dinputs
    if RA <: Active
        # Active return: the seed is handed directly to the reverse pass.
        dinputs = only(rev_thunk(f, args..., dresult.dval, tape))
    else
        # Duplicated-style return: accumulate the seed into the shadow in place.
        shadow .+= dresult.dval # TODO: generalize beyond arrays
        dinputs = only(rev_thunk(f, args..., tape))
    end
    return ReturnPrimal ? (dinputs, primal) : (dinputs,)
end

"""
    autodiff(
        rmode::ReverseModeSplit,
        f::Annotation,
        ReturnActivity::Type{<:Annotation},
        dresults::BatchSeed,
        annotated_args...
    )

Call [`autodiff_thunk`](@ref), execute the forward pass, increment each output
adjoint with the corresponding element from `dresults`, then execute the
reverse pass.

Useful for computing pullbacks / VJPs for functions whose output is not a scalar.
"""
function autodiff(
    rmode::ReverseModeSplit{ReturnPrimal},
    f::FA,
    ::Type{RA},
    dresults::BatchSeed{B},
    args::Vararg{Annotation,N},
) where {ReturnPrimal,B,FA<:Annotation,RA<:Annotation,N}
    # Rebuild the split mode so its width matches the number of seeds.
    widened_mode = ReverseSplitWidth(rmode, Val(B))
    fwd_thunk, rev_thunk = autodiff_thunk(widened_mode, FA, RA, typeof.(args)...)
    tape, primal, shadows = fwd_thunk(f, args...)
    local dinputs
    if RA <: Active
        # Active return: the whole seed tuple is handed to the reverse pass.
        dinputs = only(rev_thunk(f, args..., dresults.dvals, tape))
    else
        # Accumulate each seed into its corresponding shadow; `foreach` over
        # tuples lowers to `ntuple`, so this remains type stable.
        foreach(shadows, dresults.dvals) do shadow, seed
            shadow .+= seed # TODO: generalize beyond arrays
        end
        dinputs = only(rev_thunk(f, args..., tape))
    end
    return ReturnPrimal ? (dinputs, primal) : (dinputs,)
end
70 changes: 70 additions & 0 deletions test/sugar.jl
Original file line number Diff line number Diff line change
Expand Up @@ -650,3 +650,73 @@ end
# @show J_r_3(u, A, x)
# @show J_f_3(u, A, x)
end

@testset "Seeded reverse autodiff" begin

    # Scalar-valued and vector-valued test functions sharing one core computation.
    f(x::Vector{Float64}, y::Float64) = sum(abs2, x) * y
    g(x::Vector{Float64}, y::Float64) = [f(x, y)]

    x = [1.0, 2.0, 3.0]
    y = 4.0
    dx = similar(x)
    dresult = 5.0
    dxs = (similar(x), similar(x))
    dresults = (5.0, 7.0)

    @testset "simple" begin
        # Active (scalar) return seeded by a plain Seed.
        for smode in (ReverseSplitNoPrimal, ReverseSplitWithPrimal)
            make_zero!(dx)
            ret = autodiff(smode, Const(f), Active, Seed(dresult), Duplicated(x, dx), Active(y))
            grads = ret[1]
            @test isnothing(grads[1])
            @test grads[2] == dresult * sum(abs2, x)
            @test dx == dresult * 2x * y
            if smode == ReverseSplitWithPrimal
                @test ret[end] == f(x, y)
            end
        end

        # Duplicated (array) return seeded by a one-element adjoint vector.
        for smode in (ReverseSplitNoPrimal, ReverseSplitWithPrimal)
            make_zero!(dx)
            ret = autodiff(smode, Const(g), Duplicated, Seed([dresult]), Duplicated(x, dx), Active(y))
            grads = ret[1]
            @test isnothing(grads[1])
            @test grads[2] == dresult * sum(abs2, x)
            @test dx == dresult * 2x * y
            if smode == ReverseSplitWithPrimal
                @test ret[end] == g(x, y)
            end
        end
    end

    @testset "batch" begin
        # Active return with a batch of two seeds.
        for smode in (ReverseSplitNoPrimal, ReverseSplitWithPrimal)
            make_zero!(dxs)
            ret = autodiff(smode, Const(f), Active, BatchSeed(dresults), BatchDuplicated(x, dxs), Active(y))
            grads = ret[1]
            @test isnothing(grads[1])
            @test grads[2][1] == dresults[1] * sum(abs2, x)
            @test grads[2][2] == dresults[2] * sum(abs2, x)
            @test dxs[1] == dresults[1] * 2x * y
            @test dxs[2] == dresults[2] * 2x * y
            if smode == ReverseSplitWithPrimal
                @test ret[end] == f(x, y)
            end
        end

        # BatchDuplicated (array) return with per-batch adjoint vectors.
        for smode in (ReverseSplitNoPrimal, ReverseSplitWithPrimal)
            make_zero!(dxs)
            ret = autodiff(smode, Const(g), BatchDuplicated, BatchSeed(([dresults[1]], [dresults[2]])), BatchDuplicated(x, dxs), Active(y))
            grads = ret[1]
            @test isnothing(grads[1])
            @test grads[2][1] == dresults[1] * sum(abs2, x)
            @test grads[2][2] == dresults[2] * sum(abs2, x)
            @test dxs[1] == dresults[1] * 2x * y
            @test dxs[2] == dresults[2] * 2x * y
            if smode == ReverseSplitWithPrimal
                @test ret[end] == g(x, y)
            end
        end
    end

end