416416function rrule (:: typeof (kron), x:: AbstractMatrix{<:Number} , y:: AbstractVector{<:Number} )
417417 function kron_pullback (z̄)
418418 dz = reshape (unthunk (z̄), length (y), size (x)... )
419- x̄ = @thunk Ref (y ' ) .* eachslice (dz; dims = (2 , 3 ))
419+ x̄ = @thunk conj .( dot .( eachslice (dz; dims = (2 , 3 )), Ref (y) ))
420420 ȳ = @thunk conj .(dot .(eachslice (dz; dims = 1 ), Ref (x)))
421421 return NoTangent (), x̄, ȳ
422422 end
@@ -427,7 +427,7 @@ function rrule(::typeof(kron), x::AbstractVector{<:Number}, y::AbstractMatrix{<:
427427 function kron_pullback (z̄)
428428 dz = reshape (unthunk (z̄), size (y, 1 ), length (x), size (y, 2 ))
429429 x̄ = @thunk conj .(dot .(eachslice (dz; dims = 2 ), Ref (y)))
430- ȳ = @thunk Ref (x ' ) .* eachslice (dz; dims = (1 , 3 ))
430+ ȳ = @thunk conj .( dot .( eachslice (dz; dims = (1 , 3 )), Ref (x) ))
431431 return NoTangent (), x̄, ȳ
432432 end
433433 return kron (x, y), kron_pullback
@@ -437,7 +437,7 @@ function rrule(::typeof(kron), x::AbstractMatrix{<:Number}, y::AbstractMatrix{<:
437437 function kron_pullback (z̄)
438438 dz = reshape (unthunk (z̄), size (y, 1 ), size (x, 1 ), size (y, 2 ), size (x, 2 ))
439439 x̄ = @thunk conj .(dot .(eachslice (dz, dims = (2 , 4 )), Ref (y)))
440- ȳ = @thunk dot .(eachslice (conj .(dz) ; dims = (1 , 3 )), Ref ( conj . (x)))
440+ ȳ = @thunk conj .( dot .(eachslice (dz ; dims = (1 , 3 )), Ref (x)))
441441 return NoTangent (), x̄, ȳ
442442 end
443443 return kron (x, y), kron_pullback
0 commit comments