From 86a263b81975d13ce415af6b89bc10247fc3458b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stanis=C5=82aw=20Szymczyk?= Date: Tue, 23 Jun 2026 11:49:57 +0200 Subject: [PATCH] cuda : prevent integer truncation and overflow errors when using KQ mask strides in flash_attn_mask_to_KV_max kernel --- ggml/src/ggml-cuda/fattn-common.cuh | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/ggml/src/ggml-cuda/fattn-common.cuh b/ggml/src/ggml-cuda/fattn-common.cuh index 8dfa51ad1e8f..1cd44afa0a4a 100644 --- a/ggml/src/ggml-cuda/fattn-common.cuh +++ b/ggml/src/ggml-cuda/fattn-common.cuh @@ -664,7 +664,7 @@ constexpr __device__ dequantize_V_t get_dequantize_V() { template __launch_bounds__(FATTN_KQ_STRIDE/2, 1) static __global__ void flash_attn_mask_to_KV_max( - const half2 * __restrict__ mask, int * __restrict__ KV_max, const int ne30, const int s31, const int s33) { + const half2 * __restrict__ mask, int * __restrict__ KV_max, const int ne30, const size_t s31, const size_t s33) { const int ne31 = gridDim.x; const int tid = threadIdx.x; const int sequence = blockIdx.y; @@ -1089,8 +1089,8 @@ void launch_fattn( // Only worth the overhead if there is at lease one FATTN_KQ_STRIDE x FATTN_KQ_STRIDE square to be skipped or // multiple sequences of possibly different lengths. if (mask && K->ne[1] % FATTN_KQ_STRIDE == 0 && (Q->ne[1] >= 1024 || Q->ne[3] > 1)) { - const int s31 = mask->nb[1] / sizeof(half2); - const int s33 = mask->nb[3] / sizeof(half2); + const size_t s31 = mask->nb[1] / sizeof(half2); + const size_t s33 = mask->nb[3] / sizeof(half2); const dim3 blocks_num_KV_max(ntiles_x, Q->ne[3], 1); const dim3 block_dim_KV_max(FATTN_KQ_STRIDE/2, 1, 1);