@@ -403,9 +403,25 @@ function frule((_, Δx, Δy), ::typeof(kron), x, y)
403403 return kron (x, y), kron (Δx, y) + kron (x, Δy)
404404end
405405
406- function rrule (:: typeof (kron), x:: AbstractMatrix , y:: AbstractVector )
407- z = kron (x, y)
406+ function rrule (:: typeof (kron), x:: AbstractVector , y:: AbstractVector )
407+ function kron_pullback (z̄)
408+ x̄ = zero (x)
409+ ȳ = zero (y)
410+ m = firstindex (z̄)
411+ @inbounds for i in eachindex (x)
412+ xi = x[i]
413+ for k in eachindex (y)
414+ x̄[i] += y[k]' * z̄[m]
415+ ȳ[k] += xi' * z̄[m]
416+ m += 1
417+ end
418+ end
419+ NoTangent (), x̄, ȳ
420+ end
421+ kron (x, y), kron_pullback
422+ end
408423
424+ function rrule (:: typeof (kron), x:: AbstractMatrix , y:: AbstractVector )
409425 function kron_pullback (z̄)
410426 x̄ = zero (x)
411427 ȳ = zero (y)
@@ -414,18 +430,16 @@ function rrule(::typeof(kron), x::AbstractMatrix, y::AbstractVector)
414430 xij = x[i,j]
415431 for k in eachindex (y)
416432 x̄[i, j] += y[k]' * z̄[m]
417- ȳ[k] += xij * z̄[m]
433+ ȳ[k] += xij' * z̄[m]
418434 m += 1
419435 end
420436 end
421437 NoTangent (), x̄, ȳ
422438 end
423- z , kron_pullback
439+ kron (x, y) , kron_pullback
424440end
425441
426442function rrule (:: typeof (kron), x:: AbstractVector , y:: AbstractMatrix )
427- z = kron (x, y)
428-
429443 function kron_pullback (z̄)
430444 x̄ = zero (x)
431445 ȳ = zero (y)
@@ -434,11 +448,29 @@ function rrule(::typeof(kron), x::AbstractVector, y::AbstractMatrix)
434448 xi = x[i]
435449 for k in axes (y,1 )
436450 x̄[i] += y[k, l]' * z̄[m]
437- ȳ[k, l] += xi * z̄[m]
451+ ȳ[k, l] += xi' * z̄[m]
452+ m += 1
453+ end
454+ end
455+ NoTangent (), x̄, ȳ
456+ end
457+ kron (x, y), kron_pullback
458+ end
459+
460+ function rrule (:: typeof (kron), x:: AbstractMatrix , y:: AbstractMatrix )
461+ function kron_pullback (z̄)
462+ x̄ = zero (x)
463+ ȳ = zero (y)
464+ m = firstindex (z̄)
465+ @inbounds for l in axes (y,2 ), j in axes (x,2 ), i in axes (x,1 )
466+ xij = x[i, j]
467+ for k in axes (y,1 )
468+ x̄[i, j] += y[k, l]' * z̄[m]
469+ ȳ[k, l] += xij' * z̄[m]
438470 m += 1
439471 end
440472 end
441473 NoTangent (), x̄, ȳ
442474 end
443- z , kron_pullback
475+ kron (x, y) , kron_pullback
444476end
0 commit comments