cuda : prevent integer truncation and overflow errors when using KQ mask strides in flash_attn_mask_to_KV_max kernel#24945
Conversation
…ask strides in flash_attn_mask_to_KV_max kernel
|
Can we get a test-case for this? |
@ORippler If you mean Anyway, without this PR: With this PR: So the test still fails because of ERR being too high, but there are no crashes. |
Overview
This PR prevents integer truncation and overflow errors in
flash_attn_mask_to_KV_maxkernel by changing type ofs31ands33frominttosize_t.Fixes #24912
Additional information
When large KQ masks are used (for example in models with long context lengths like 1M tokens with large ubatch size used to speed up prompt processing)
mask->nb[3] / sizeof(half2)can exceedinttype value range resulting ins33being interpreted as negative or smaller than expected due to value truncation. There's also another problem withjt*ncols1*s31multiplication resulting in integer overflow when bothjt*ncols1ands31are large, sos31type was also changed frominttosize_t.Requirements