diff --git a/src/kernels.jl b/src/kernels.jl index d5deb40..5c6b0ef 100644 --- a/src/kernels.jl +++ b/src/kernels.jl @@ -40,7 +40,7 @@ function transduce_impl(rf::F, init, arrays...) where {F} length(ys) == 1 && return @allowscalar ys[1] rf2 = AlwaysCombine(rf) while true - ys, = _transduce!(buf, rf2, CombineInit(), ys) + ys, = _transduce!(buf, rf2, init, ys) # @info "ys, = _transduce!(buf, rf2, ...)" Text(summary(ys)) # @info "ys, = _transduce!(buf, rf2, ...)" collect(ys) length(ys) == 1 && return @allowscalar ys[1] @@ -282,8 +282,6 @@ function transduce_kernel!( return end -struct CombineInit end - struct AlwaysCombine{I} <: AbstractReduction{I} inner::I end @@ -293,8 +291,7 @@ AlwaysCombine(rf::Transducers.R_{Map}) = AlwaysCombine(Transducers.inner(rf)) AlwaysCombine(rf::Transducers.BottomRF) = AlwaysCombine(Transducers.inner(rf)) =# -@inline Transducers.start(::AlwaysCombine, init::CombineInit) = init -@inline Transducers.next(::AlwaysCombine, ::CombineInit, input) = first(input) +@inline Transducers.start(rf::AlwaysCombine, init) = start(rf.inner, init) @inline Transducers.next(rf::F, acc, input) where {F<:AlwaysCombine} = _combine(rf.inner, acc, first(input)) @inline Transducers.complete(rf::F, result) where {F<:AlwaysCombine} = diff --git a/src/shfl.jl b/src/shfl.jl index 7ab8128..96111c2 100644 --- a/src/shfl.jl +++ b/src/shfl.jl @@ -44,10 +44,8 @@ function transduce_shfl_impl(rf::F, init, arrays...) where {F} # @info "ys, = transduce_shfl!(nothing, rf, ...)" collect(ys) length(ys) == 1 && return @allowscalar ys[1] rf2 = AlwaysCombine(rf) - combine_init = init # require type-stable init - @assert start(rf, init) === init while true - ys, = transduce_shfl!(buf, rf2, combine_init, ys) + ys, = transduce_shfl!(buf, rf2, init, ys) # @info "ys, = transduce_shfl!(buf, rf2, ...)" Text(summary(ys)) # @info "ys, = transduce_shfl!(buf, rf2, ...)" collect(ys) length(ys) == 1 && return @allowscalar ys[1]