-
Notifications
You must be signed in to change notification settings - Fork 97
Expand file tree
/
Copy pathsort.jl
More file actions
103 lines (86 loc) · 3.09 KB
/
sort.jl
File metadata and controls
103 lines (86 loc) · 3.09 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
#####
##### `sort`
#####
# Forward-mode rule for `partialsort`: the primal output and its tangent are read
# from the same position(s) of `xs` and `ẋs`, chosen by `partialsortperm`.
# `k` (and any sorting keywords) are forwarded unchanged to `partialsortperm`.
function frule((_, ẋs, _), ::typeof(partialsort), xs::AbstractVector, k; kw...)
    perm = partialsortperm(xs, k; kw...)
    y, ẏ = getindex(xs, perm), getindex(ẋs, perm)
    return y, ẏ
end
# Reverse-mode rule for `partialsort` with an integer or ordinal-range `k`.
#
# `inds` holds the position(s) of `xs` that `partialsort` selected, so the pullback
# adds the output cotangent back at exactly those positions. All other entries of
# the input cotangent stay zero. `k` only selects indices, so it gets `NoTangent()`.
function rrule(::typeof(partialsort), xs::AbstractVector, k::Union{Integer,OrdinalRange}; kwargs...)
    inds = partialsortperm(xs, k; kwargs...)
    ys = xs[inds]
    function partialsort_pullback(ȳ)
        # Materialize the cotangent once, as `sort_pullback` does, rather than
        # feeding a possibly-thunked value into the indexed `+=` below.
        Δys = unthunk(ȳ)
        # Accumulate `Δys` into an existing cotangent buffer `Δxs`, in place.
        function partialsort_add!(Δxs)
            Δxs[inds] += Δys
            return Δxs
        end
        # NOTE(review): the out-of-place buffer is `zero(xs)` and so has the eltype
        # of `xs`. An integer `xs` paired with a non-integer cotangent would hit a
        # conversion error here. Confirm whether integer inputs need support.
        Δxs = InplaceableThunk(partialsort_add!, @thunk(partialsort_add!(zero(xs))))
        return NoTangent(), Δxs, NoTangent()
    end
    return ys, partialsort_pullback
end
# Forward-mode rule for `sort`: apply the sorting permutation of `xs` to both the
# primal and its tangent `ẋs`. Sorting keywords are forwarded to `sortperm`.
function frule((_, ẋs), ::typeof(sort), xs::AbstractArray; kw...)
    perm = sortperm(xs; kw...)
    sorted_xs = xs[perm]
    sorted_ẋs = ẋs[perm]
    return sorted_xs, sorted_ẋs
end
# Reverse-mode rule for `sort`. The primal output is `xs[perm]`, where `perm` is the
# sorting permutation, so the pullback adds the output cotangent into positions
# `perm` of the input cotangent.
function rrule(::typeof(sort), xs::AbstractArray; kwargs...)
    perm = sortperm(xs; kwargs...)
    function sort_pullback(ȳ)
        dy = unthunk(ȳ)
        # In-place accumulation into a caller-supplied cotangent buffer.
        function sort_add!(dx)
            dx[perm] += dy
            return dx
        end
        # The out-of-place fallback starts from a zero array shaped like `dy`.
        dx_thunk = InplaceableThunk(sort_add!, @thunk(sort_add!(zero(dy))))
        return NoTangent(), dx_thunk
    end
    return xs[perm], sort_pullback
end
#####
##### `sortslices`
#####
# Forward-mode rule for `sortslices` along a single integer dimension `dims`.
# The slices of the primal and of its tangent `ẋ` are permuted by the same order.
# That order is computed by `sortperm` over the slices of `x`.
function frule((_, ẋ), ::typeof(sortslices), x::AbstractArray; dims::Integer, kw...)
    order = sortperm(collect(eachslice(x; dims=dims)); kw...)
    # Index with `order` along `dims` and take everything (`:`) along other dimensions.
    idx = ntuple(ndims(x)) do d
        d == dims ? order : (:)
    end
    return getindex(x, idx...), getindex(ẋ, idx...)
end
# Reverse-mode rule for `sortslices` along a single integer dimension `dims`.
# The forward pass permutes slices of `x` by `order`. The pullback sends the
# cotangent back through the same indexing via `∇getindex`.
function rrule(::typeof(sortslices), x::AbstractArray; dims::Integer, kw...)
    slices = collect(eachslice(x; dims=dims))
    order = sortperm(slices; kw...)
    # Index with `order` along `dims` and take everything (`:`) along other dimensions.
    idx = ntuple(ndims(x)) do d
        d == dims ? order : (:)
    end
    sortslices_pullback(dy) = (NoTangent(), ∇getindex(x, unthunk(dy), idx...))
    return x[idx...], sortslices_pullback
end
#####
##### `unique`
#####
# Reverse-mode rule for `unique` on numeric arrays, for `dims=:` or an integer `dims`
# (the only forms `unique` accepts).
#
# Each element (or slice along `dims`) of the output `y` corresponds to the first
# occurrence of that value in `x`, with equality decided by `isequal`. For example,
# `unique([0.0, -0.0, NaN, NaN])` keeps both zeros and a single `NaN`. The pullback
# sends each output cotangent to that first occurrence; later duplicates get zero.
function rrule(::typeof(unique), x::AbstractArray{<:Number}; dims=:)
    axes_x = axes(x)
    y = unique(x; dims=dims)
    function unique_pullback(dy_raw)
        dy = unthunk(dy_raw)
        project = ProjectTo(x)
        # Fast path: nothing was removed, so the cotangent only needs `x`'s shape.
        # This skips building the comparison matrix below.
        length(x) == length(y) && return (NoTangent(), project(reshape(dy, axes_x)))

        whole = dims isa Colon
        # Input candidates (elements or slices) and output entries, as vectors.
        xs = whole ? vec(x) : collect(eachslice(x; dims=dims))
        ys = whole ? y : collect(eachslice(y; dims=dims))

        # hits[i, j] is true when input candidate i equals output entry j.
        hits = isequal.(permutedims(ys), xs)
        # Keep only the first hit in each column, i.e. findfirst along dims=1.
        first_hit = hits .& (cumsum(hits; dims=1) .== 1)
        # `findall` walks column-major, giving one row index per output entry, in order.
        keep = map(I -> I[1], findall(first_hit))

        if whole
            # `keep` holds linear positions in `vec(x)`, so scatter into the vector
            # form and restore `x`'s shape afterwards. `∇getindex` allows second
            # derivatives.
            dx = reshape(∇getindex(vec(x), vec(dy), keep), axes_x)
        else
            inds = ntuple(d -> d == dims ? keep : (:), length(axes_x))
            dx = ∇getindex(x, dy, inds...)
        end
        return (NoTangent(), project(dx))
    end
    return y, unique_pullback
end