@@ -295,17 +295,21 @@ end
295295# ###
296296# ### merge
297297# ###
298-
299- function rrule (:: typeof (merge), nt1:: NamedTuple{F1} , nt2:: NamedTuple{F2} ) where {F1,F2}
300- y = merge (nt1, nt2)
301- function merge_pullback (dy)
302- dnt1 = Tangent {typeof(nt1)} (;
303- (f1 => (f1 in F2 ? ZeroTangent () : getproperty (dy, f1)) for f1 in F1). ..
304- )
305- dnt2 = Tangent {typeof(nt2)} (; (f2 => getproperty (dy, f2) for f2 in F2). .. )
298+ # need to work around inability to return closures from generated functions
299+ struct MergePullback{T1, T2}
300+ end
301+ (this:: MergePullback )(dy:: AbstractThunk ) = this (unthunk (dy))
302+ (:: MergePullback )(x:: AbstractZero ) = (NoTangent (), x, x)
303+ @generated function (:: MergePullback{T1,T2} )(dy:: Tangent ) where {F1,T1<: NamedTuple{F1} ,F2,T2<: NamedTuple{F2} }
304+ _getproperty_kwexpr (key) = :($ key = getproperty (dy, $ (Meta. quot (key))))
305+ quote
306+ dnt1 = Tangent {T1} (; $ (map (_getproperty_kwexpr, setdiff (F1, F2))... ))
307+ dnt2 = Tangent {T2} (; $ (map (_getproperty_kwexpr, F2)... ))
306308 return (NoTangent (), dnt1, dnt2)
307309 end
308- merge_pullback (dy:: AbstractThunk ) = merge_pullback (unthunk (dy))
309- merge_pullback (x:: AbstractZero ) = (NoTangent (), x, x)
310- return y, merge_pullback
310+ end
311+
312+ function rrule (:: typeof (merge), nt1:: T1 , nt2:: T2 ) where {T1<: NamedTuple , T2<: NamedTuple }
313+ y = merge (nt1, nt2)
314+ return y, MergePullback {T1,T2} ()
311315end
0 commit comments