@@ -405,37 +405,39 @@ end
405405
406406function rrule (:: typeof (kron), x:: AbstractVector{<:Number} , y:: AbstractVector{<:Number} )
407407 function kron_pullback (z̄)
408- dz = reshape (z̄, length (y), length (x))
409- return NoTangent (), conj .(dz' * y), dz * conj .(x)
408+ dz = reshape (unthunk (z̄), length (y), length (x))
409+ x̄ = @thunk conj .(dz' * y)
410+ ȳ = @thunk dz * conj .(x)
411+ return NoTangent (), x̄, ȳ
410412 end
411413 return kron (x, y), kron_pullback
412414end
413415
414416function rrule (:: typeof (kron), x:: AbstractMatrix{<:Number} , y:: AbstractVector{<:Number} )
415417 function kron_pullback (z̄)
416- dz = reshape (z̄ , length (y), size (x)... )
417- x̄ = Ref (y' ) .* eachslice (dz; dims = (2 , 3 ))
418- ȳ = conj .(dot .(eachslice (dz; dims = 1 ), Ref (x)))
418+ dz = reshape (unthunk (z̄) , length (y), size (x)... )
419+ x̄ = @thunk Ref (y' ) .* eachslice (dz; dims = (2 , 3 ))
420+ ȳ = @thunk conj .(dot .(eachslice (dz; dims = 1 ), Ref (x)))
419421 return NoTangent (), x̄, ȳ
420422 end
421423 return kron (x, y), kron_pullback
422424end
423425
424426function rrule (:: typeof (kron), x:: AbstractVector{<:Number} , y:: AbstractMatrix{<:Number} )
425427 function kron_pullback (z̄)
426- dz = reshape (z̄ , size (y, 1 ), length (x), size (y, 2 ))
427- x̄ = conj .(dot .(eachslice (dz; dims = 2 ), Ref (y)))
428- ȳ = Ref (x' ) .* eachslice (dz; dims = (1 , 3 ))
428+ dz = reshape (unthunk (z̄) , size (y, 1 ), length (x), size (y, 2 ))
429+ x̄ = @thunk conj .(dot .(eachslice (dz; dims = 2 ), Ref (y)))
430+ ȳ = @thunk Ref (x' ) .* eachslice (dz; dims = (1 , 3 ))
429431 return NoTangent (), x̄, ȳ
430432 end
431433 return kron (x, y), kron_pullback
432434end
433435
434436function rrule (:: typeof (kron), x:: AbstractMatrix{<:Number} , y:: AbstractMatrix{<:Number} )
435437 function kron_pullback (z̄)
436- dz = reshape (z̄ , size (y, 1 ), size (x, 1 ), size (y, 2 ), size (x, 2 ))
437- x̄ = conj .(dot .(eachslice (dz, dims = (2 , 4 )), Ref (y)))
438- ȳ = dot .(eachslice (conj .(dz); dims = (1 , 3 )), Ref (conj .(x)))
438+ dz = reshape (unthunk (z̄) , size (y, 1 ), size (x, 1 ), size (y, 2 ), size (x, 2 ))
439+ x̄ = @thunk conj .(dot .(eachslice (dz, dims = (2 , 4 )), Ref (y)))
440+ ȳ = @thunk dot .(eachslice (conj .(dz); dims = (1 , 3 )), Ref (conj .(x)))
439441 return NoTangent (), x̄, ȳ
440442 end
441443 return kron (x, y), kron_pullback
0 commit comments