Skip to content
Draft
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions DifferentiationInterface/Project.toml
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
name = "DifferentiationInterface"
uuid = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
authors = ["Guillaume Dalle", "Adrian Hill"]
version = "0.7.16"
authors = ["Guillaume Dalle", "Adrian Hill"]

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

[weakdeps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
Diffractor = "9f5e2b26-1114-432f-b630-d3fe2085c51c"
Expand Down Expand Up @@ -38,7 +39,7 @@ DifferentiationInterfaceFastDifferentiationExt = "FastDifferentiation"
DifferentiationInterfaceFiniteDiffExt = "FiniteDiff"
DifferentiationInterfaceFiniteDifferencesExt = "FiniteDifferences"
DifferentiationInterfaceForwardDiffExt = ["ForwardDiff", "DiffResults"]
DifferentiationInterfaceGPUArraysCoreExt = "GPUArraysCore"
DifferentiationInterfaceGPUArraysCoreExt = ["GPUArraysCore", "Adapt"]
DifferentiationInterfaceGTPSAExt = "GTPSA"
DifferentiationInterfaceMooncakeExt = "Mooncake"
DifferentiationInterfacePolyesterForwardDiffExt = [
Expand All @@ -56,6 +57,7 @@ DifferentiationInterfaceTrackerExt = "Tracker"
DifferentiationInterfaceZygoteExt = ["Zygote", "ForwardDiff"]

[compat]
Adapt = "4.5.0"
ADTypes = "1.18.0"
ChainRulesCore = "1.23.0"
DiffResults = "1.1.0"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@ using ChainRulesCore:
RuleConfig,
frule_via_ad,
rrule_via_ad,
unthunk
unthunk,
@not_implemented
import DifferentiationInterface as DI

# Accessor for the `ruleconfig` field stored in an `AutoChainRules` backend object.
function ruleconfig(backend::AutoChainRules)
    return backend.ruleconfig
end
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,22 @@
"""
    ChainRulesCore.rrule(dw::DI.DifferentiateWith{C}, x, contexts::Vararg{Any, C})

Reverse rule for calling `dw` on `x` together with `C` context arguments.

The primal `y = dw.f(x, contexts...)` is computed directly. Pullback preparation is done
once, at the primal point, and reused by the returned pullback, which differentiates `dw.f`
with respect to `x` only through `DI.pullback` and `dw.backend`. Each context is wrapped by
the matching entry of `dw.context_wrappers` before being handed to DifferentiationInterface.

# Returns
- `y`: the primal output.
- a pullback `dy -> (NoTangent(), dx, dc...)`, where every entry of `dc` is a
  `@not_implemented` marker because derivatives with respect to contexts are not provided.
"""
function ChainRulesCore.rrule(
    dw::DI.DifferentiateWith{C}, x, contexts::Vararg{Any, C}
) where {C}
    (; f, backend, context_wrappers) = dw
    y = f(x, contexts...)
    # Pair each raw context with its wrapper so DI knows how to treat it.
    wrapped_contexts = map(DI.call, context_wrappers, contexts)
    prep_same = DI.prepare_pullback_same_point_nokwarg(
        Val(false), f, backend, x, (y,), wrapped_contexts...
    )
    function diffwith_pullbackfunc(dy)
        # `DI.pullback` returns one tangent per seed; a single seed is given here.
        dx = only(DI.pullback(f, prep_same, backend, x, (dy,), wrapped_contexts...))
        dc = map(contexts) do c
            @not_implemented(
                """
                Derivatives with respect to context arguments are not implemented.
                """
            )
        end
        return (NoTangent(), dx, dc...)
    end
    return y, diffwith_pullbackfunc
end
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
# Call a context-free `DifferentiateWith` (type parameter `0`) on a single ForwardDiff
# `Dual` with tag `T` and `N` partials.
#
# The dual is split into a value (`myvalue`) and a set of tangents (`mypartials`), the
# wrapped function is pushed forward with the wrapped backend, and the primal output and
# its tangents are recombined with `make_dual`.
# NOTE(review): `myvalue`, `mypartials` and `make_dual` are defined outside this view;
# their exact return shapes should be confirmed there.
function (dw::DI.DifferentiateWith{0})(x::Dual{T, V, N}) where {T, V, N}
    (; f, backend) = dw
    xval = myvalue(T, x)
    tx = mypartials(T, Val(N), x)
    y, ty = DI.value_and_pushforward(f, backend, xval, tx)
    return make_dual(T, y, ty)
end

function (dw::DI.DifferentiateWith)(x::AbstractArray{Dual{T, V, N}}) where {T, V, N}
function (dw::DI.DifferentiateWith{0})(x::AbstractArray{Dual{T, V, N}}) where {T, V, N}
(; f, backend) = dw
xval = myvalue(T, x)
tx = mypartials(T, Val(N), x)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
module DifferentiationInterfaceGPUArraysCoreExt

using Adapt: adapt
import DifferentiationInterface as DI
using GPUArraysCore: @allowscalar, AbstractGPUArray

Expand All @@ -17,4 +18,10 @@ function DI.multibasis(a::AbstractGPUArray{T}, inds) where {T}
return b
end

# GPU-array method of `DI.arroftup_to_tupofarr`.
#
# `tx` is an array whose elements are `B`-tuples of numbers; `x` is a GPU array whose type
# is used as the adaptation target. The result is a `B`-tuple whose `b`-th entry is built by
# broadcasting `getindex(·, b)` over `tx` and passing the outcome through
# `adapt(typeof(x), ·)`.
# NOTE(review): presumably the `adapt` call makes each component the same kind of array
# as `x` — confirm against the Adapt.jl documentation.
function DI.arroftup_to_tupofarr(
    tx::AbstractArray{<:NTuple{B, <:Number}}, x::AbstractGPUArray{<:Number}
) where {B}
    target = typeof(x)
    return ntuple(Val(B)) do b
        component = getindex.(tx, b)
        adapt(target, component)
    end
end

end
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ using Mooncake:
value_and_pullback!!,
zero_dual,
zero_tangent,
zero_rdata,
rdata_type,
fdata,
rdata,
Expand All @@ -26,11 +27,13 @@ using Mooncake:
@is_primitive,
zero_fcodual,
MinimalCtx,
NoFData,
NoRData,
primal,
_copy_output,
_copy_to_output!!,
tangent_to_primal!!
tangent_to_primal!!,
increment!!

# Alias matching either Mooncake backend type, `AutoMooncake{C}` or
# `AutoMooncakeForward{C}`, with a shared type parameter `C`.
const AnyAutoMooncake{C} = Union{AutoMooncake{C}, AutoMooncakeForward{C}}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
@is_primitive MinimalCtx Tuple{DI.DifferentiateWith, <:Any}
# Primal types accepted for context arguments in the `rrule!!` methods of this file:
# a number or an array of numbers.
const NumberOrArray = Union{Number, AbstractArray{<:Number}}
# One `@is_primitive` registration per supported number of contexts: `DifferentiateWith{k}`
# is registered for a call with `k + 1` arguments (the input plus `k` contexts), k = 0..3.
# NOTE(review): presumably this makes Mooncake use the hand-written `rrule!!` methods for
# these signatures — confirm against the Mooncake documentation.
@is_primitive MinimalCtx Tuple{DI.DifferentiateWith{0}, Any}
@is_primitive MinimalCtx Tuple{DI.DifferentiateWith{1}, Any, Any}
@is_primitive MinimalCtx Tuple{DI.DifferentiateWith{2}, Any, Any, Any}
@is_primitive MinimalCtx Tuple{DI.DifferentiateWith{3}, Any, Any, Any, Any}
# TODO: generate more cases programmatically (calls with more than 3 contexts are not
# registered above)

Comment thread
gdalle marked this conversation as resolved.
Outdated
struct MooncakeDifferentiateWithError <: Exception
F::Type
Expand All @@ -12,72 +17,87 @@ end
# Print a single-line description of a `MooncakeDifferentiateWithError` to `io`, naming the
# function type (`e.F`), the input types (`e.X`) and the unsupported output type (`e.Y`).
# Only the reworded message is kept, so it is emitted exactly once; it is split into two
# `print` arguments purely to respect the line-length limit (the output is unchanged).
function Base.showerror(io::IO, e::MooncakeDifferentiateWithError)
    return print(
        io,
        "MooncakeDifferentiateWithError: For the function type `$(e.F)` and input types ",
        "`$(e.X)`, the output type `$(e.Y)` is currently not supported.",
    )
end

# Mooncake reverse rule for a `DI.DifferentiateWith` called on a scalar `x` followed by `C`
# context arguments (each a number or an array of numbers).
#
# The primal output is obtained by calling the wrapped function directly. The returned
# pullback differentiates it with respect to `x` only, via `DI.pullback` and the wrapped
# backend; contexts are forwarded to `DI.pullback` after being wrapped with
# `context_wrappers`. Context gradients are not computed: their slots are filled by
# `nanify_fdata_and_rdata!!`.
#
# Returns the output `CoDual` (with zeroed forward data) and a pullback producing
# `(NoRData(), rdata(dx), rc...)`.
#
# Throws `MooncakeDifferentiateWithError` when the primal output is neither a `Number` nor
# an `AbstractArray`.
function Mooncake.rrule!!(
    dw::CoDual{<:DI.DifferentiateWith{C}},
    x::CoDual{<:Number},
    contexts::Vararg{CoDual{<:NumberOrArray}, C}
) where {C}
    @assert tangent_type(typeof(dw)) == NoTangent
    primal_func = primal(dw)
    primal_x = primal(x)
    primal_contexts = map(primal, contexts)
    (; f, backend, context_wrappers) = primal_func
    y = zero_fcodual(f(primal_x, primal_contexts...))
    wrapped_primal_contexts = map(DI.call, context_wrappers, primal_contexts)

    # output is a vector, so we need to use the vector pullback;
    # the cotangent is read from the output's forward data `y.dx`, `dy` carries nothing
    function pullback_array!!(dy::NoRData)
        dx = only(
            DI.pullback(f, backend, primal_x, (y.dx,), wrapped_primal_contexts...)
        )
        # `dx` is already unwrapped by `only` above, so it is checked directly
        @assert rdata(dx) isa rdata_type(tangent_type(typeof(primal_x)))
        rc = nanify_fdata_and_rdata!!(contexts...)
        return (NoRData(), rdata(dx), rc...)
    end

    # output is a scalar, so we can use the scalar pullback
    function pullback_scalar!!(dy::Number)
        dx = only(DI.pullback(f, backend, primal_x, (dy,), wrapped_primal_contexts...))
        @assert rdata(dx) isa rdata_type(tangent_type(typeof(primal_x)))
        rc = nanify_fdata_and_rdata!!(contexts...)
        return (NoRData(), rdata(dx), rc...)
    end

    pullback = if primal(y) isa Number
        pullback_scalar!!
    elseif primal(y) isa AbstractArray
        pullback_array!!
    else
        throw(
            MooncakeDifferentiateWithError(
                primal_func, (primal_x, primal_contexts...), primal(y)
            ),
        )
    end

    return y, pullback
end

function Mooncake.rrule!!(
dw::CoDual{<:DI.DifferentiateWith}, x::CoDual{<:AbstractArray{<:Number}}
)
dw::CoDual{<:DI.DifferentiateWith{C}},
x::CoDual{<:AbstractArray{<:Number}},
contexts::Vararg{CoDual{<:NumberOrArray}, C}
) where {C}
@assert tangent_type(typeof(dw)) == NoTangent
primal_func = primal(dw)
primal_x = primal(x)
fdata_arg = x.dx
(; f, backend) = primal_func
y = zero_fcodual(f(primal_x))
primal_contexts = map(primal, contexts)
(; f, backend, context_wrappers) = primal_func
y = zero_fcodual(f(primal_x, primal_contexts...))
wrapped_primal_contexts = map(DI.call, context_wrappers, primal_contexts)

# output is a vector, so we need to use the vector pullback
function pullback_array!!(dy::NoRData)
tx = DI.pullback(f, backend, primal_x, (y.dx,))
@assert rdata(first(only(tx))) isa rdata_type(tangent_type(typeof(first(primal_x))))
fdata_arg .+= only(tx)
return NoRData(), dy
dx = DI.pullback(f, backend, primal_x, (y.dx,), wrapped_primal_contexts...) |> only
@assert rdata(dx) isa rdata_type(tangent_type(typeof(primal_x)))
x.dx .+= dx
rc = nanify_fdata_and_rdata!!(contexts...)
return (NoRData(), dy, rc...)
end

# output is a scalar, so we can use the scalar pullback
function pullback_scalar!!(dy::Number)
tx = DI.pullback(f, backend, primal_x, (dy,))
@assert rdata(first(only(tx))) isa rdata_type(tangent_type(typeof(first(primal_x))))
fdata_arg .+= only(tx)
return NoRData(), NoRData()
dx = DI.pullback(f, backend, primal_x, (dy,), wrapped_primal_contexts...) |> only
@assert rdata(dx) isa rdata_type(tangent_type(typeof(primal_x)))
x.dx .+= dx
rc = nanify_fdata_and_rdata!!(contexts...)
return (NoRData(), NoRData(), rc...)
end

pullback = if primal(y) isa Number
pullback_scalar!!
elseif primal(y) isa AbstractArray
pullback_array!!
else
throw(MooncakeDifferentiateWithError(primal_func, primal_x, primal(y)))
throw(MooncakeDifferentiateWithError(primal_func, (primal_x, primal_contexts...), primal(y)))
end

return y, pullback
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,19 @@ function zero_tangent_or_primal(x, backend::AnyAutoMooncake)
return zero_tangent(x)
end
end

# `nanify(x)` returns a value with the same structure as `x` in which every floating-point
# leaf is replaced by `NaN` of the same float type. Containers (arrays, tuples, named
# tuples) recurse element-wise through `nanify` itself.
nanify(x::AbstractFloat) = convert(typeof(x), NaN)
nanify(x::AbstractArray) = map(nanify, x)
nanify(x::Union{Tuple, NamedTuple}) = map(nanify, x)
Comment thread
gdalle marked this conversation as resolved.
Outdated
# The data-free markers hold no numeric content, so `nanify` maps each one to a marker of
# the same type.
function nanify(::NoFData)
    return NoFData()
end

function nanify(::NoRData)
    return NoRData()
end

# For every context `CoDual`:
#   1. take its tangent component and increment it (via `increment!!`) by the `nanify`-ed
#      version of itself;
#   2. build the zero reverse data of its primal and return the `nanify`-ed version.
# The returned tuple has one entry per context, in the same order.
# NOTE(review): the result of `increment!!` is discarded, so this presumably relies on the
# tangent component being updated in place where it holds data — confirm against Mooncake.
function nanify_fdata_and_rdata!!(contexts::Vararg{CoDual, C}) where {C}
    foreach(contexts) do ctx
        fd = tangent(ctx)
        increment!!(fd, nanify(fd))
    end
    return map(contexts) do ctx
        nanify(zero_rdata(primal(ctx)))
    end
end
16 changes: 8 additions & 8 deletions DifferentiationInterface/src/first_order/pullback.jl
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,7 @@ function _value_and_pullback_via_pushforward(
tx = map(ty) do dy
dot(a, dy)
end
return y, arroftup_to_tupofarr(tx)
return y, arroftup_to_tupofarr(tx, x)
end

function _value_and_pullback_via_pushforward(
Expand All @@ -348,7 +348,7 @@ function _value_and_pullback_via_pushforward(
tx = map(ty) do dy
real(dot(a, dy)) + im * real(dot(b, dy))
end
return y, arroftup_to_tupofarr(tx)
return y, arroftup_to_tupofarr(tx, x)
end

function _value_and_pullback_via_pushforward(
Expand All @@ -366,7 +366,7 @@ function _value_and_pullback_via_pushforward(
dot(a, dy)
end
end
return y, arroftup_to_tupofarr(tx)
return y, arroftup_to_tupofarr(tx, x)
end

function _value_and_pullback_via_pushforward(
Expand All @@ -387,7 +387,7 @@ function _value_and_pullback_via_pushforward(
real(dot(a, dy)) + im * real(dot(b, dy))
end
end
return y, arroftup_to_tupofarr(tx)
return y, arroftup_to_tupofarr(tx, x)
end

function value_and_pullback(
Expand Down Expand Up @@ -458,7 +458,7 @@ function _value_and_pullback_via_pushforward(
tx = map(ty) do dy
dot(a, dy)
end
return y, arroftup_to_tupofarr(tx)
return y, arroftup_to_tupofarr(tx, x)
end

function _value_and_pullback_via_pushforward(
Expand All @@ -477,7 +477,7 @@ function _value_and_pullback_via_pushforward(
tx = map(ty) do dy
real(dot(a, dy)) + im * real(dot(b, dy))
end
return y, arroftup_to_tupofarr(tx)
return y, arroftup_to_tupofarr(tx, x)
end

function _value_and_pullback_via_pushforward(
Expand All @@ -495,7 +495,7 @@ function _value_and_pullback_via_pushforward(
dot(a, dy)
end
end
return y, arroftup_to_tupofarr(tx)
return y, arroftup_to_tupofarr(tx, x)
end

function _value_and_pullback_via_pushforward(
Expand All @@ -518,7 +518,7 @@ function _value_and_pullback_via_pushforward(
real(dot(a, dy)) + im * real(dot(b, dy))
end
end
return y, arroftup_to_tupofarr(tx)
return y, arroftup_to_tupofarr(tx, x)
end

function value_and_pullback(
Expand Down
12 changes: 6 additions & 6 deletions DifferentiationInterface/src/first_order/pushforward.jl
Original file line number Diff line number Diff line change
Expand Up @@ -331,7 +331,7 @@ function _value_and_pushforward_via_pullback(
ty = map(tx) do dx
dot(a, dx)
end
return y, arroftup_to_tupofarr(ty)
return y, arroftup_to_tupofarr(ty, y)
end

function _value_and_pushforward_via_pullback(
Expand All @@ -348,7 +348,7 @@ function _value_and_pushforward_via_pullback(
ty = map(tx) do dx
real(dot(a, dx)) + im * real(dot(b, dx))
end
return y, arroftup_to_tupofarr(ty)
return y, arroftup_to_tupofarr(ty, y)
end

function _value_and_pushforward_via_pullback(
Expand All @@ -367,7 +367,7 @@ function _value_and_pushforward_via_pullback(
dot(a, dx)
end
end
return y, arroftup_to_tupofarr(ty)
return y, arroftup_to_tupofarr(ty, y)
end

function _value_and_pushforward_via_pullback(
Expand All @@ -387,7 +387,7 @@ function _value_and_pushforward_via_pullback(
real(dot(a, dx)) + im * real(dot(b, dx))
end
end
return y, arroftup_to_tupofarr(ty)
return y, arroftup_to_tupofarr(ty, y)
end

function value_and_pushforward(
Expand Down Expand Up @@ -460,7 +460,7 @@ function _value_and_pushforward_via_pullback(
dot(a, dx)
end
end
return y, arroftup_to_tupofarr(ty)
return y, arroftup_to_tupofarr(ty, y)
end

function _value_and_pushforward_via_pullback(
Expand All @@ -481,7 +481,7 @@ function _value_and_pushforward_via_pullback(
real(dot(a, dx)) + im * real(dot(b, dx))
end
end
return y, arroftup_to_tupofarr(ty)
return y, arroftup_to_tupofarr(ty, y)
end

function value_and_pushforward(
Expand Down
Loading
Loading