diff --git a/ggml/src/ggml-cuda/argsort.cu b/ggml/src/ggml-cuda/argsort.cu index c4f08091e79a..33a38c23e87e 100644 --- a/ggml/src/ggml-cuda/argsort.cu +++ b/ggml/src/ggml-cuda/argsort.cu @@ -28,6 +28,20 @@ static __global__ void init_offsets(int * offsets, const int ncols, const int nr #endif // STRIDED_ITERATOR_AVAILABLE #ifdef GGML_CUDA_USE_CUB + +// returns the suggested maximum number of rows to process during one argsort_f32_i32_cuda_cub() call +int argsort_f32_i32_cuda_cub_chunk_nrows(const size_t nb01, const int64_t nrows) { + // perform argsort in chunks up to approximately this size (currently 64MB) + // to avoid excessive temporary buffers memory usage + const int chunk_bytes = 1 << 26; + + // calculate how many rows will fit in one chunk (must be at least one) + const int chunk_nrows = chunk_bytes > nb01 ? chunk_bytes / nb01 : 1; + + // limit the resulting amount to total nrows + return nrows < chunk_nrows ? nrows : chunk_nrows; +} + void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool, const float * x, int * dst, @@ -254,11 +268,22 @@ void ggml_cuda_op_argsort(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const size_t shared_mem = ncols_pad * sizeof(int); const size_t max_shared_mem = ggml_cuda_info().devices[ggml_cuda_get_device()].smpb; - if (shared_mem > max_shared_mem || ncols > 1024) { - ggml_cuda_pool & pool = ctx.pool(); - argsort_f32_i32_cuda_cub(pool, src0_d, (int *) dst_d, ncols, nrows, order, stream); - } else { - argsort_f32_i32_cuda_bitonic(src0_d, (int *) dst_d, ncols, nrows, order, stream); + // early return if we can use bitonic argsort + if (shared_mem <= max_shared_mem && ncols <= 1024) { + return argsort_f32_i32_cuda_bitonic(src0_d, (int *) dst_d, ncols, nrows, order, stream); + } + + const int chunk_nrows = argsort_f32_i32_cuda_cub_chunk_nrows(src0->nb[1], nrows); + + ggml_cuda_pool & pool = ctx.pool(); + + for (int64_t i = 0; i < nrows; i += chunk_nrows) { + int iter_nrows = chunk_nrows < nrows - i ? chunk_nrows : nrows - i; + + argsort_f32_i32_cuda_cub(pool, src0_d, (int *) dst_d, ncols, iter_nrows, order, stream); + + src0_d += ncols * iter_nrows; + dst_d += ncols * iter_nrows; } #else argsort_f32_i32_cuda_bitonic(src0_d, (int *) dst_d, ncols, nrows, order, stream); diff --git a/ggml/src/ggml-cuda/argsort.cuh b/ggml/src/ggml-cuda/argsort.cuh index 22b7306f2020..3abb6448a057 100644 --- a/ggml/src/ggml-cuda/argsort.cuh +++ b/ggml/src/ggml-cuda/argsort.cuh @@ -3,6 +3,7 @@ void ggml_cuda_op_argsort(ggml_backend_cuda_context & ctx, ggml_tensor * dst); #ifdef GGML_CUDA_USE_CUB +int argsort_f32_i32_cuda_cub_chunk_nrows(const size_t nb01, const int64_t nrows); void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool, const float * x, int * dst, diff --git a/ggml/src/ggml-cuda/top-k.cu b/ggml/src/ggml-cuda/top-k.cu index db1d39e2dc71..5e708e6c5ed4 100644 --- a/ggml/src/ggml-cuda/top-k.cu +++ b/ggml/src/ggml-cuda/top-k.cu @@ -75,17 +75,26 @@ void ggml_cuda_op_top_k(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const int ncols_pad = next_power_of_2(ncols); const size_t shared_mem = ncols_pad * sizeof(int); const size_t max_shared_mem = ggml_cuda_info().devices[ggml_cuda_get_device()].smpb; + const bool use_bitonic = shared_mem <= max_shared_mem && ncols <= 1024; + const int chunk_nrows = argsort_f32_i32_cuda_cub_chunk_nrows(src0->nb[1], nrows); - ggml_cuda_pool_alloc temp_dst_alloc(pool, ncols * nrows); + ggml_cuda_pool_alloc temp_dst_alloc(pool, ncols * chunk_nrows); int * tmp_dst = temp_dst_alloc.get(); - if (shared_mem > max_shared_mem || ncols > 1024) { - argsort_f32_i32_cuda_cub(pool, src0_d, tmp_dst, ncols, nrows, GGML_SORT_ORDER_DESC, stream); - } else { - argsort_f32_i32_cuda_bitonic(src0_d, tmp_dst, ncols, nrows, GGML_SORT_ORDER_DESC, stream); + for (int64_t i = 0; i < nrows; i += chunk_nrows) { + int iter_nrows = chunk_nrows < nrows - i ? chunk_nrows : nrows - i; + + if (use_bitonic) { + argsort_f32_i32_cuda_bitonic(src0_d, tmp_dst, ncols, iter_nrows, GGML_SORT_ORDER_DESC, stream); + } else { + argsort_f32_i32_cuda_cub(pool, src0_d, tmp_dst, ncols, iter_nrows, GGML_SORT_ORDER_DESC, stream); + } + CUDA_CHECK(cudaMemcpy2DAsync(dst_d, k * sizeof(int), tmp_dst, ncols * sizeof(int), k * sizeof(int), iter_nrows, + cudaMemcpyDeviceToDevice, stream)); + + src0_d += ncols * iter_nrows; + dst_d += k * iter_nrows; } - CUDA_CHECK(cudaMemcpy2DAsync(dst_d, k * sizeof(int), tmp_dst, ncols * sizeof(int), k * sizeof(int), nrows, - cudaMemcpyDeviceToDevice, stream)); #else // GGML_CUDA_USE_CUB ggml_cuda_pool_alloc temp_dst_alloc(pool, ncols * nrows); int * tmp_dst = temp_dst_alloc.get();