Skip to content

fix(qwen3_vl): build causal mask for prefill, not single-token decode (#3505)#3518

Open
0xSoftBoi wants to merge 1 commit into
huggingface:mainfrom
0xSoftBoi:fix/3505-qwen3-vl-attn-mask
Open

fix(qwen3_vl): build causal mask for prefill, not single-token decode (#3505)#3518
0xSoftBoi wants to merge 1 commit into
huggingface:mainfrom
0xSoftBoi:fix/3505-qwen3-vl-attn-mask

Conversation

@0xSoftBoi
Copy link
Copy Markdown

Summary

In Qwen3VLModel::forward, the gate that decides whether to build a causal attention mask is inverted:

let attention_mask = if seqlen <= 1 {
    Some(self.prepare_decoder_attention_mask(...))   // mask built when *not* needed
} else {
    None                                             // mask skipped when needed
};

The decoder causal mask is only useful when more than one token is being processed at once (prefill / multi-token prompt). For single-token autoregressive decode steps the KV cache already constrains attention and the mask can be skipped. The current code does the opposite.

Every other candle text model (mistral, starcoder2, helium, olmo2, phi3, deepseek2, csm, gemma4, …) follows the canonical pattern, e.g. mistral.rs:

let attention_mask = if seq_len <= 1 {
    None
} else {
    let mask = self.prepare_decoder_attention_mask(seq_len, seqlen_offset)?;
    Some(mask)
};

This PR aligns Qwen3 VL with that shape and adds a comment explaining the invariant so the inversion isn't reintroduced.

Test

  • cargo build -p candle-transformers — clean
  • cargo clippy -p candle-transformers --lib -- -D warnings — clean
  • cargo fmt --all -- --check — clean

Fixes #3505

…huggingface#3505)

The condition gating attention-mask construction in Qwen3VLModel::forward
was inverted: `if seqlen <= 1 { Some(mask) } else { None }`. This built
the causal mask only on single-token decode steps (where the KV cache
already constrains attention) and skipped it on multi-token prefill (where
the model needs the mask to attend correctly). The result was wrong
forward-pass behavior whenever a prompt of length > 1 was processed.

All other candle text models (mistral, starcoder2, helium, olmo2, phi3,
deepseek2, etc.) follow the canonical `if seq_len <= 1 { None }` else
`Some(mask)` shape. This patch aligns Qwen3 VL with that pattern.

Fixes huggingface#3505
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Qwen3 VL forward pass doesn't construct attention mask properly

1 participant