@@ -403,74 +403,40 @@ function frule((_, Δx, Δy), ::typeof(kron), x::AbstractVecOrMat{<:Number}, y::
403403 return kron (x, y), kron (Δx, y) + kron (x, Δy)
404404end
405405
406- function rrule (:: typeof (kron), x:: AbstractVector , y:: AbstractVector )
406+ function rrule (:: typeof (kron), x:: AbstractVector{<:Number} , y:: AbstractVector{<:Number} )
407407 function kron_pullback (z̄)
408- x̄ = zero (x)
409- ȳ = zero (y)
410- m = firstindex (z̄)
411- @inbounds for i in eachindex (x)
412- xi = x[i]
413- for k in eachindex (y)
414- x̄[i] += y[k]' * z̄[m]
415- ȳ[k] += xi' * z̄[m]
416- m += 1
417- end
418- end
419- NoTangent (), x̄, ȳ
408+ dz = reshape (z̄, length (y), length (x))
409+ return NoTangent (), conj .(dz' * y), dz * conj .(x)
420410 end
421- kron (x, y), kron_pullback
411+ return kron (x, y), kron_pullback
422412end
423413
424- function rrule (:: typeof (kron), x:: AbstractMatrix , y:: AbstractVector )
414+ function rrule (:: typeof (kron), x:: AbstractMatrix{<:Number} , y:: AbstractVector{<:Number} )
425415 function kron_pullback (z̄)
426- x̄ = zero (x)
427- ȳ = zero (y)
428- m = firstindex (z̄)
429- @inbounds for j in axes (x,2 ), i in axes (x,1 )
430- xij = x[i,j]
431- for k in eachindex (y)
432- x̄[i, j] += y[k]' * z̄[m]
433- ȳ[k] += xij' * z̄[m]
434- m += 1
435- end
436- end
437- NoTangent (), x̄, ȳ
416+ dz = reshape (z̄, length (y), size (x)... )
417+ x̄ = Ref (y' ) .* eachslice (dz; dims = (2 , 3 ))
418+ ȳ = conj .(dot .(eachslice (dz; dims = 1 ), Ref (x)))
419+ return NoTangent (), x̄, ȳ
438420 end
439- kron (x, y), kron_pullback
421+ return kron (x, y), kron_pullback
440422end
441423
442- function rrule (:: typeof (kron), x:: AbstractVector , y:: AbstractMatrix )
424+ function rrule (:: typeof (kron), x:: AbstractVector{<:Number} , y:: AbstractMatrix{<:Number} )
443425 function kron_pullback (z̄)
444- x̄ = zero (x)
445- ȳ = zero (y)
446- m = firstindex (z̄)
447- @inbounds for l in axes (y,2 ), i in eachindex (x)
448- xi = x[i]
449- for k in axes (y,1 )
450- x̄[i] += y[k, l]' * z̄[m]
451- ȳ[k, l] += xi' * z̄[m]
452- m += 1
453- end
454- end
455- NoTangent (), x̄, ȳ
426+ dz = reshape (z̄, size (y, 1 ), length (x), size (y, 2 ))
427+ x̄ = conj .(dot .(eachslice (dz; dims = 2 ), Ref (y)))
428+ ȳ = Ref (x' ) .* eachslice (dz; dims = (1 , 3 ))
429+ return NoTangent (), x̄, ȳ
456430 end
457- kron (x, y), kron_pullback
431+ return kron (x, y), kron_pullback
458432end
459433
460- function rrule (:: typeof (kron), x:: AbstractMatrix , y:: AbstractMatrix )
434+ function rrule (:: typeof (kron), x:: AbstractMatrix{<:Number} , y:: AbstractMatrix{<:Number} )
461435 function kron_pullback (z̄)
462- x̄ = zero (x)
463- ȳ = zero (y)
464- m = firstindex (z̄)
465- @inbounds for l in axes (y,2 ), j in axes (x,2 ), i in axes (x,1 )
466- xij = x[i, j]
467- for k in axes (y,1 )
468- x̄[i, j] += y[k, l]' * z̄[m]
469- ȳ[k, l] += xij' * z̄[m]
470- m += 1
471- end
472- end
473- NoTangent (), x̄, ȳ
436+ dz = reshape (z̄, size (y, 1 ), size (x, 1 ), size (y, 2 ), size (x, 2 ))
437+ x̄ = conj .(dot .(eachslice (dz, dims = (2 , 4 )), Ref (y)))
438+ ȳ = dot .(eachslice (conj .(dz); dims = (1 , 3 )), Ref (conj .(x)))
439+ return NoTangent (), x̄, ȳ
474440 end
475- kron (x, y), kron_pullback
441+ return kron (x, y), kron_pullback
476442end
0 commit comments