@@ -399,46 +399,48 @@ end
399399# #### `kron`
400400# ####
401401
402- function frule ((_, Δx, Δy), :: typeof (kron), x:: AbstractVecOrMat{<:Number} , y:: AbstractVecOrMat{<:Number} )
403- return kron (x, y), kron (Δx, y) + kron (x, Δy)
404- end
402+ @static if VERSION ≥ v " 1.9.0"
403+ function frule ((_, Δx, Δy), :: typeof (kron), x:: AbstractVecOrMat{<:Number} , y:: AbstractVecOrMat{<:Number} )
404+ return kron (x, y), kron (Δx, y) + kron (x, Δy)
405+ end
405406
406- function rrule (:: typeof (kron), x:: AbstractVector{<:Number} , y:: AbstractVector{<:Number} )
407- function kron_pullback (z̄)
408- dz = reshape (unthunk (z̄), length (y), length (x))
409- x̄ = @thunk conj .(dz' * y)
410- ȳ = @thunk dz * conj .(x)
411- return NoTangent (), x̄, ȳ
407+ function rrule (:: typeof (kron), x:: AbstractVector{<:Number} , y:: AbstractVector{<:Number} )
408+ function kron_pullback (z̄)
409+ dz = reshape (unthunk (z̄), length (y), length (x))
410+ x̄ = @thunk conj .(dz' * y)
411+ ȳ = @thunk dz * conj .(x)
412+ return NoTangent (), x̄, ȳ
413+ end
414+ return kron (x, y), kron_pullback
412415 end
413- return kron (x, y), kron_pullback
414- end
415416
416- function rrule (:: typeof (kron), x:: AbstractMatrix{<:Number} , y:: AbstractVector{<:Number} )
417- function kron_pullback (z̄)
418- dz = reshape (unthunk (z̄), length (y), size (x)... )
419- x̄ = @thunk conj .(dot .(eachslice (dz; dims = (2 , 3 )), Ref (y)))
420- ȳ = @thunk conj .(dot .(eachslice (dz; dims = 1 ), Ref (x)))
421- return NoTangent (), x̄, ȳ
417+ function rrule (:: typeof (kron), x:: AbstractMatrix{<:Number} , y:: AbstractVector{<:Number} )
418+ function kron_pullback (z̄)
419+ dz = reshape (unthunk (z̄), length (y), size (x)... )
420+ x̄ = @thunk conj .(dot .(eachslice (dz; dims = (2 , 3 )), Ref (y)))
421+ ȳ = @thunk conj .(dot .(eachslice (dz; dims = 1 ), Ref (x)))
422+ return NoTangent (), x̄, ȳ
423+ end
424+ return kron (x, y), kron_pullback
422425 end
423- return kron (x, y), kron_pullback
424- end
425426
426- function rrule (:: typeof (kron), x:: AbstractVector{<:Number} , y:: AbstractMatrix{<:Number} )
427- function kron_pullback (z̄)
428- dz = reshape (unthunk (z̄), size (y, 1 ), length (x), size (y, 2 ))
429- x̄ = @thunk conj .(dot .(eachslice (dz; dims = 2 ), Ref (y)))
430- ȳ = @thunk conj .(dot .(eachslice (dz; dims = (1 , 3 )), Ref (x)))
431- return NoTangent (), x̄, ȳ
427+ function rrule (:: typeof (kron), x:: AbstractVector{<:Number} , y:: AbstractMatrix{<:Number} )
428+ function kron_pullback (z̄)
429+ dz = reshape (unthunk (z̄), size (y, 1 ), length (x), size (y, 2 ))
430+ x̄ = @thunk conj .(dot .(eachslice (dz; dims = 2 ), Ref (y)))
431+ ȳ = @thunk conj .(dot .(eachslice (dz; dims = (1 , 3 )), Ref (x)))
432+ return NoTangent (), x̄, ȳ
433+ end
434+ return kron (x, y), kron_pullback
432435 end
433- return kron (x, y), kron_pullback
434- end
435436
436- function rrule (:: typeof (kron), x:: AbstractMatrix{<:Number} , y:: AbstractMatrix{<:Number} )
437- function kron_pullback (z̄)
438- dz = reshape (unthunk (z̄), size (y, 1 ), size (x, 1 ), size (y, 2 ), size (x, 2 ))
439- x̄ = @thunk conj .(dot .(eachslice (dz, dims = (2 , 4 )), Ref (y)))
440- ȳ = @thunk conj .(dot .(eachslice (dz; dims = (1 , 3 )), Ref (x)))
441- return NoTangent (), x̄, ȳ
437+ function rrule (:: typeof (kron), x:: AbstractMatrix{<:Number} , y:: AbstractMatrix{<:Number} )
438+ function kron_pullback (z̄)
439+ dz = reshape (unthunk (z̄), size (y, 1 ), size (x, 1 ), size (y, 2 ), size (x, 2 ))
440+ x̄ = @thunk conj .(dot .(eachslice (dz, dims = (2 , 4 )), Ref (y)))
441+ ȳ = @thunk conj .(dot .(eachslice (dz; dims = (1 , 3 )), Ref (x)))
442+ return NoTangent (), x̄, ȳ
443+ end
444+ return kron (x, y), kron_pullback
442445 end
443- return kron (x, y), kron_pullback
444446end
0 commit comments