Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions litgpt/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -586,6 +586,9 @@ def scaled_dot_product_attention(
if mask is None:
mask = torch.ones(q.size(2), q.size(2), dtype=q.dtype, device=q.device).triu(diagonal=1)
mask.masked_fill_(mask.bool(), torch.finfo(q.dtype).min)
elif mask.dtype == torch.bool:
# build_mask_cache returns a boolean mask (True=keep); convert to additive float mask
mask = torch.zeros_like(mask, dtype=q.dtype).masked_fill_(~mask, torch.finfo(q.dtype).min)
scores = scores + mask
scores = F.softmax(scores, dim=-1, dtype=torch.float).to(dtype=q.dtype)
y = scores @ v
Expand Down Expand Up @@ -773,6 +776,9 @@ def scaled_dot_product_attention(
if mask is None:
mask = torch.ones(q.size(2), q.size(2), dtype=q.dtype, device=q.device).triu(diagonal=1)
mask.masked_fill_(mask.bool(), torch.finfo(q.dtype).min)
elif mask.dtype == torch.bool:
# build_mask_cache returns a boolean mask (True=keep); convert to additive float mask
mask = torch.zeros_like(mask, dtype=q.dtype).masked_fill_(~mask, torch.finfo(q.dtype).min)
scores = scores + mask
scores = F.softmax(scores, dim=-1, dtype=torch.float).to(dtype=q.dtype)
y = scores @ v
Expand Down
37 changes: 37 additions & 0 deletions tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1809,3 +1809,40 @@ def test_sliding_window_kv_cache_prefill_exceeds_window():
k_out, v_out = cache(input_pos_safe, k_safe, v_safe)
assert k_out.shape == (batch_size, n_query_groups, safe_len, head_size)
assert v_out.shape == (batch_size, n_query_groups, safe_len, head_size)


@torch.inference_mode()
def test_attention_mask_bool_to_float_with_softcapping():
"""Boolean mask_cache must be converted to an additive float mask before being added to
softcapped attention scores (issue #1672). Without the fix, True/False (1/0) are added
instead of 0/-inf, breaking causal masking during KV-cache generation."""

dtype = torch.float32
batch_size, n_head, seq_len, head_size = 1, 2, 4, 8
config = Config(
n_layer=1,
n_head=n_head,
n_embd=n_head * head_size,
n_query_groups=n_head,
block_size=seq_len,
attention_logit_softcapping=50.0,
vocab_size=16,
)
model = GPT(config).to(dtype)
model.set_kv_cache(batch_size=batch_size, max_seq_length=seq_len)

# Confirm mask_cache is boolean (pre-condition of the bug)
assert model.mask_cache is not None
assert model.mask_cache.dtype == torch.bool

# Run prefill — all input positions at once
input_ids = torch.randint(0, config.vocab_size, (batch_size, seq_len))
input_pos = torch.arange(seq_len)
out_with_cache = model(input_ids, input_pos)

# Run without KV cache for reference
model.clear_kv_cache()
out_no_cache = model(input_ids)

# Outputs must be numerically close; if mask was bool-added (+0/+1) they would diverge
torch.testing.assert_close(out_with_cache, out_no_cache, rtol=1e-4, atol=1e-4)
Loading