We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent fff05c2 commit 649e797Copy full SHA for 649e797
1 file changed
src/rulesets/LinearAlgebra/dense.jl
@@ -405,10 +405,12 @@ end
405
end
406
407
function rrule(::typeof(kron), x::AbstractVector{<:Number}, y::AbstractVector{<:Number})
408
+ project_x = ProjectTo(x)
409
+ project_y = ProjectTo(y)
410
function kron_pullback(z̄)
411
dz = reshape(unthunk(z̄), length(y), length(x))
- x̄ = @thunk conj.(dz' * y)
- ȳ = @thunk dz * conj.(x)
412
+ x̄ = @thunk(project_x(conj.(dz' * y)))
413
+ ȳ = @thunk(project_y(dz * conj.(x)))
414
return NoTangent(), x̄, ȳ
415
416
return kron(x, y), kron_pullback
0 commit comments