From 276ae2590154401ca5b1dc9d982a75465b68d525 Mon Sep 17 00:00:00 2001 From: Varun Nuthalapati Date: Tue, 14 Apr 2026 08:17:27 -0700 Subject: [PATCH 1/2] fix(model): convert bool mask_cache to float additive mask for softcapping MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit When KV cache is active, build_mask_cache() returns a torch.bool tensor (True=keep). In scaled_dot_product_attention the bool mask was added directly to scores, contributing 0 or 1 instead of 0 or -inf, which breaks causal masking for models that use attention_logit_softcapping (e.g. Gemma 2). Add an elif branch that converts the boolean mask to an additive float mask (True→0.0, False→-inf) before the scores addition. The fix is applied to both CausalSelfAttention and MultiheadLatentAttention. Fixes #1672 --- litgpt/model.py | 6 ++++++ tests/test_model.py | 38 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 44 insertions(+) diff --git a/litgpt/model.py b/litgpt/model.py index 541860ab5b..acc5a13c96 100644 --- a/litgpt/model.py +++ b/litgpt/model.py @@ -586,6 +586,9 @@ def scaled_dot_product_attention( if mask is None: mask = torch.ones(q.size(2), q.size(2), dtype=q.dtype, device=q.device).triu(diagonal=1) mask.masked_fill_(mask.bool(), torch.finfo(q.dtype).min) + elif mask.dtype == torch.bool: + # build_mask_cache returns a boolean mask (True=keep); convert to additive float mask + mask = torch.zeros_like(mask, dtype=q.dtype).masked_fill_(~mask, torch.finfo(q.dtype).min) scores = scores + mask scores = F.softmax(scores, dim=-1, dtype=torch.float).to(dtype=q.dtype) y = scores @ v @@ -773,6 +776,9 @@ def scaled_dot_product_attention( if mask is None: mask = torch.ones(q.size(2), q.size(2), dtype=q.dtype, device=q.device).triu(diagonal=1) mask.masked_fill_(mask.bool(), torch.finfo(q.dtype).min) + elif mask.dtype == torch.bool: + # build_mask_cache returns a boolean mask (True=keep); convert to additive float mask + mask = torch.zeros_like(mask, dtype=q.dtype).masked_fill_(~mask, torch.finfo(q.dtype).min) scores = scores + mask scores = F.softmax(scores, dim=-1, dtype=torch.float).to(dtype=q.dtype) y = scores @ v diff --git a/tests/test_model.py b/tests/test_model.py index 8d0cf21d5e..b727cf13a0 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -1809,3 +1809,41 @@ def test_sliding_window_kv_cache_prefill_exceeds_window(): k_out, v_out = cache(input_pos_safe, k_safe, v_safe) assert k_out.shape == (batch_size, n_query_groups, safe_len, head_size) assert v_out.shape == (batch_size, n_query_groups, safe_len, head_size) + + +@torch.inference_mode() +def test_attention_mask_bool_to_float_with_softcapping(): + """Boolean mask_cache must be converted to an additive float mask before being added to + softcapped attention scores (issue #1672). Without the fix, True/False (1/0) are added + instead of 0/-inf, breaking causal masking during KV-cache generation.""" + from litgpt.model import build_mask_cache + + dtype = torch.float32 + batch_size, n_head, seq_len, head_size = 1, 2, 4, 8 + config = Config( + n_layer=1, + n_head=n_head, + n_embd=n_head * head_size, + n_query_groups=n_head, + block_size=seq_len, + attention_logit_softcapping=50.0, + vocab_size=16, + ) + model = GPT(config).to(dtype) + model.set_kv_cache(batch_size=batch_size, max_seq_length=seq_len) + + # Confirm mask_cache is boolean (pre-condition of the bug) + assert model.mask_cache is not None + assert model.mask_cache.dtype == torch.bool + + # Run prefill — all input positions at once + input_ids = torch.randint(0, config.vocab_size, (batch_size, seq_len)) + input_pos = torch.arange(seq_len) + out_with_cache = model(input_ids, input_pos) + + # Run without KV cache for reference + model.clear_kv_cache() + out_no_cache = model(input_ids) + + # Outputs must be numerically close; if mask was bool-added (+0/+1) they would diverge + torch.testing.assert_close(out_with_cache, out_no_cache, rtol=1e-4, atol=1e-4) From 68f47a2f2f5b1b1a28388990803044fd5a5c3a1b Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 14 Apr 2026 15:17:57 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/test_model.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_model.py b/tests/test_model.py index b727cf13a0..8b81588630 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -1816,7 +1816,6 @@ def test_attention_mask_bool_to_float_with_softcapping(): """Boolean mask_cache must be converted to an additive float mask before being added to softcapped attention scores (issue #1672). Without the fix, True/False (1/0) are added instead of 0/-inf, breaking causal masking during KV-cache generation.""" - from litgpt.model import build_mask_cache dtype = torch.float32 batch_size, n_head, seq_len, head_size = 1, 2, 4, 8