Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 73 additions & 0 deletions src/sugar.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1160,3 +1160,76 @@ grad
return nothing
end


"""
seeded_autodiff_thunk(
rmode::ReverseModeSplit,
dresult,
f,
ReturnActivity,
annotated_args...
)

Call [`autodiff_thunk`](@ref), execute the forward pass, increment output tangent with `dresult`, then execute the reverse pass.

Useful for computing pullbacks / VJPs for functions whose output is not a scalar.
Comment thread
gdalle marked this conversation as resolved.
Outdated
"""
function seeded_autodiff_thunk(
rmode::ReverseModeSplit{ReturnPrimal},
dresult,
f::FA,
::Type{RA},
args::Vararg{Annotation,N},
) where {ReturnPrimal,FA<:Annotation,RA<:Annotation,N}
forward, reverse = autodiff_thunk(rmode, FA, RA, typeof.(args)...)
tape, result, shadow_result = forward(f, args...)
if RA <: Active
Comment thread
gdalle marked this conversation as resolved.
dinputs = only(reverse(f, args..., dresult, tape))
Comment thread
gdalle marked this conversation as resolved.
Outdated
else
shadow_result .+= dresult # TODO: generalize beyond arrays
dinputs = only(reverse(f, args..., tape))
end
if ReturnPrimal
return (dinputs, result)
else
return (dinputs,)
end
end

"""
batch_seeded_autodiff_thunk(
rmode::ReverseModeSplit,
dresults::NTuple,
f,
ReturnActivity,
annotated_args...
)

Call [`autodiff_thunk`](@ref), execute the forward pass, increment each output tangent with the corresponding element from `dresults`, then execute the reverse pass.

Useful for computing pullbacks / VJPs for functions whose output is not a scalar.
"""
function batch_seeded_autodiff_thunk(
rmode::ReverseModeSplit{ReturnPrimal},
dresults::NTuple{B},
f::FA,
::Type{RA},
args::Vararg{Annotation,N},
) where {ReturnPrimal,B,FA<:Annotation,RA<:Annotation,N}
rmode_rightwidth = ReverseSplitWidth(rmode, Val(B))
forward, reverse = autodiff_thunk(rmode_rightwidth, FA, RA, typeof.(args)...)
tape, result, shadow_results = forward(f, args...)
if RA <: Active
dinputs = only(reverse(f, args..., dresults, tape))
else
foreach(shadow_results, dresults) do d0, d
d0 .+= d # TODO: generalize beyond arrays
end
dinputs = only(reverse(f, args..., tape))
end
if ReturnPrimal
return (dinputs, result)
else
return (dinputs,)
end
end
72 changes: 72 additions & 0 deletions test/sugar.jl
Original file line number Diff line number Diff line change
Expand Up @@ -650,3 +650,75 @@ end
# @show J_r_3(u, A, x)
# @show J_f_3(u, A, x)
end

using Enzyme: seeded_autodiff_thunk, batch_seeded_autodiff_thunk

@testset "seeded_autodiff_thunk" begin

f(x::Vector{Float64}, y::Float64) = sum(abs2, x) * y
g(x::Vector{Float64}, y::Float64) = [f(x, y)]

x = [1.0, 2.0, 3.0]
y = 4.0
dx = similar(x)
dresult = 5.0
dxs = (similar(x), similar(x))
dresults = (5.0, 7.0)

@testset "simple" begin
for mode in (ReverseSplitNoPrimal, ReverseSplitWithPrimal)
make_zero!(dx)
dinputs_and_maybe_result = seeded_autodiff_thunk(mode, dresult, Const(f), Active, Duplicated(x, dx), Active(y))
dinputs = first(dinputs_and_maybe_result)
@test isnothing(dinputs[1])
@test dinputs[2] == dresult * sum(abs2, x)
@test dx == dresult * 2x * y
if mode == ReverseSplitWithPrimal
@test last(dinputs_and_maybe_result) == f(x, y)
end
end

for mode in (ReverseSplitNoPrimal, ReverseSplitWithPrimal)
make_zero!(dx)
dinputs_and_maybe_result = seeded_autodiff_thunk(mode, [dresult], Const(g), Duplicated, Duplicated(x, dx), Active(y))
dinputs = first(dinputs_and_maybe_result)
@test isnothing(dinputs[1])
@test dinputs[2] == dresult * sum(abs2, x)
@test dx == dresult * 2x * y
if mode == ReverseSplitWithPrimal
@test last(dinputs_and_maybe_result) == g(x, y)
end
end
end

@testset "batch" begin
for mode in (ReverseSplitNoPrimal, ReverseSplitWithPrimal)
make_zero!(dxs)
dinputs_and_maybe_result = batch_seeded_autodiff_thunk(mode, dresults, Const(f), Active, BatchDuplicated(x, dxs), Active(y))
dinputs = first(dinputs_and_maybe_result)
@test isnothing(dinputs[1])
@test dinputs[2][1] == dresults[1] * sum(abs2, x)
@test dinputs[2][2] == dresults[2] * sum(abs2, x)
@test dxs[1] == dresults[1] * 2x * y
@test dxs[2] == dresults[2] * 2x * y
if mode == ReverseSplitWithPrimal
@test last(dinputs_and_maybe_result) == f(x, y)
end
end

for mode in (ReverseSplitNoPrimal, ReverseSplitWithPrimal)
make_zero!(dxs)
dinputs_and_maybe_result = batch_seeded_autodiff_thunk(mode, ([dresults[1]], [dresults[2]]), Const(g), BatchDuplicated, BatchDuplicated(x, dxs), Active(y))
dinputs = first(dinputs_and_maybe_result)
@test isnothing(dinputs[1])
@test dinputs[2][1] == dresults[1] * sum(abs2, x)
@test dinputs[2][2] == dresults[2] * sum(abs2, x)
@test dxs[1] == dresults[1] * 2x * y
@test dxs[2] == dresults[2] * 2x * y
if mode == ReverseSplitWithPrimal
@test last(dinputs_and_maybe_result) == g(x, y)
end
end
end

end