ggml : process data in smaller chunks in CUDA ggml_top_k() implementation to reduce temporary buffers memory usage#24776
Conversation
…tion to reduce temporary buffers memory usage
| for (int64_t i = 0; i < nrows; i+= nrows_per_chunk) { | ||
| int64_t chunk_nrows = std::min(nrows_per_chunk, nrows - i); | ||
|
|
||
| ggml_cuda_pool_alloc<int> temp_dst_alloc(pool, ncols * chunk_nrows); |
There was a problem hiding this comment.
I'm not sure how the cuda pool works exactly - just wondering if we actually need to have this allocation inside the loop and not one time before it?
There was a problem hiding this comment.
I'm not sure how the cuda pool works exactly - just wondering if we actually need to have this allocation inside the loop and not one time before it?
@ggerganov Good point, will try it out.
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
|
@ORippler Do you envision any problems with organizing top-k processing the way I did it in this PR (added loop processing smaller input chunks, temporary buffer allocated outside the loop)? |
No general problems with doing it this way. Some remarks:
|
@ORippler OK, will put the number of chunk rows calculation into a common function and apply this to argsort as well. |
Overview
This PR reduces temporary buffers memory usage in CUDA backend
ggml_top_k()CUB implementation by processing input data in smaller chunks. Without this PR temporary buffers memory usage is 3 * input buffer size, allocated here:llama.cpp/ggml/src/ggml-cuda/top-k.cu
Line 79 in d5376cf
and here:
llama.cpp/ggml/src/ggml-cuda/argsort.cu
Lines 38 to 39 in d5376cf
With this PR memory usage for temporary buffers is only 3*min(input buffer size, 64MiB).
It also partially mitigates the problem of integer overflow in
ncols * nrowsproduct by lowering the amount of rows processed at once.Fixes #24718
Additional information
For example when running this test (not present originally, I added it):
without this PR memory usage in
nvidia-smigoes up to 12968MiB, while with this PR it goes up only to 3048MiB.Let's also compare the performance. Without this PR:
with this PR:
I ran
test-backend-opsTOP-K tests and they all passed. Test failing in #24718 also passed:Requirements