# candle-flash-attn: update vendored Tri Dao FA kernels to v2.8.3 (#3521)
Closes the kernel-vendoring portion of #3515.
## Summary
Bumps the vendored Flash Attention 2 kernel sources from the
post-Dec-2024 state (~v2.6.x with #2688 / #2689 / #2690 cherry-picks)
to upstream v2.8.3 (Tri Dao commit 060c918, 2025-08-14).
Single squashed commit. ~700 net lines including 24 new split-KV
kernel files; the bulk is mechanical kernel updates from upstream
v2.8.3.
## What's in the bump
* All `kernels/*.h` and forward sm80 `*.cu` files replaced with v2.8.3 sources, path-remapped from upstream `csrc/flash_attn/src/` to candle's flat `kernels/`.
* 24 split-KV forward sm80 kernels added (`flash_fwd_split_hdim*_*_sm80.cu`, new in v2.8.3), one per {fp16, bf16} × {dense, causal} × {hdim 32, 64, 96, 128, 192, 256}.
* `flash_api.cu` dispatcher mirrors v2.8.3's `run_mha_fwd(params, stream, force_split_kernel)` shape: branches between `run_mha_fwd_<>` (dense) and `run_mha_fwd_splitkv_dispatch<>` based on `params.num_splits` and `force_split_kernel` (sketched after this list).
* `extern "C" run_mha` FFI gains four params (`num_splits`, `softmax_lseaccum_ptr`, `oaccum_ptr`, `force_split_kernel`); `src/ffi.rs` mirrors. Existing call sites pass `num_splits=1`, null accumulator pointers, and `force_split_kernel=0`, so pre-bump behavior is preserved exactly until the dispatcher heuristic decides otherwise.
* Rust-side `set_params_splitkv` equivalent: ports upstream `num_splits_heuristic` (efficiency-based search, ≥85% of peak occupancy, capped at 128) and the accumulator-allocation logic; exposed as `#[doc(hidden)] pub` so integration tests can verify the dispatcher chose splitkv for a given shape. A sketch of the heuristic also follows this list.
* Dropped legacy hdim 160 / 224 / 512 forward kernels: Tri Dao removed their launch-template helpers in v2.x, and the previously vendored `.cu` files for those head dims no longer compile against v2.8.3's launch template.
* CUTLASS pin retained at `7d49e6c7`; v2.8.3's pinned `dc481792` is not required for the kernels we compile.
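For reference, a sketch of the dispatcher shape this mirrors, following upstream v2.8.3's `run_mha_fwd`; the `*_SWITCH` helper names come from upstream's `static_switch.h`, and the vendored `flash_api.cu` (which wraps this behind the `extern "C" run_mha` entry point) may differ in detail:

```cpp
// Sketch of the v2.8.3-style forward dispatcher (upstream shape; details in
// the vendored flash_api.cu may differ). The *_SWITCH macros turn runtime
// flags into compile-time template arguments.
void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream,
                 bool force_split_kernel = false) {
    FP16_SWITCH(!params.is_bf16, [&] {
        HEADDIM_SWITCH(params.d, [&] {
            BOOL_SWITCH(params.is_causal, Is_causal, [&] {
                if (params.num_splits <= 1 && !force_split_kernel) {
                    // Dense path: a single kernel walks the whole KV length.
                    run_mha_fwd_<elem_type, kHeadDim, Is_causal>(params, stream);
                } else {
                    // Split-KV path: each split writes a partial result
                    // (oaccum / softmax_lseaccum), then a combine kernel
                    // merges the partials.
                    run_mha_fwd_splitkv_dispatch<elem_type, kHeadDim, Is_causal>(params, stream);
                }
            });
        });
    });
}
```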
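And a sketch of the split-count heuristic the Rust side ports, following upstream's `num_splits_heuristic`; parameter names here are upstream's, and the PR's actual Rust port may differ in detail:

```cpp
// Sketch of the num_splits_heuristic logic (upstream shape; the PR's Rust
// port is not reproduced here). It picks the smallest split count whose SM
// occupancy ("efficiency") is within 85% of the best achievable, capped at
// max_splits (128 in this PR).
#include <algorithm>
#include <cmath>
#include <vector>

inline int num_splits_heuristic(int batch_nheads_mblocks, int num_SMs,
                                int num_n_blocks, int max_splits) {
    // Enough blocks to nearly fill the GPU already: don't split.
    if (batch_nheads_mblocks >= 0.8f * num_SMs) { return 1; }
    max_splits = std::min({max_splits, num_SMs, num_n_blocks});
    auto ceildiv = [](int a, int b) { return (a + b - 1) / b; };
    // Splitting only helps if it actually changes the per-split work size.
    auto is_split_eligible = [&](int s) {
        return s == 1 || ceildiv(num_n_blocks, s) != ceildiv(num_n_blocks, s - 1);
    };
    float max_efficiency = 0.f;
    std::vector<float> efficiency(max_splits, 0.f);
    for (int s = 1; s <= max_splits; ++s) {
        if (!is_split_eligible(s)) continue;
        float n_waves = float(batch_nheads_mblocks * s) / num_SMs;
        efficiency[s - 1] = n_waves / std::ceil(n_waves);
        max_efficiency = std::max(max_efficiency, efficiency[s - 1]);
    }
    for (int s = 1; s <= max_splits; ++s) {
        if (is_split_eligible(s) && efficiency[s - 1] >= 0.85f * max_efficiency) {
            return s;
        }
    }
    return 1;
}
```

A decode-like shape with very few (batch × heads × m-blocks) blocks falls far below the `0.8 × num_SMs` threshold, so the heuristic chooses several splits to keep more SMs busy; that is the situation the new splitkv test constructs.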
## PyTorch dependency stub
The dropout codepath in v2.8.3's `flash_fwd_kernel.h` references `at::PhiloxCudaState` and `at::cuda::philox::unpack`. A candle-side stub in `kernels/philox_unpack.cuh` lets that path compile without dragging the torch C++ runtime in. A tripwire at the dispatch site asserts `params.p_dropout == 1.f` to prevent any future caller from silently triggering the (broken-by-stub) dropout codepath.
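A minimal sketch of what such a stub can look like (hypothetical contents; the PR's actual `kernels/philox_unpack.cuh` may differ), just enough to satisfy the two symbols the dropout path references:

```cpp
// kernels/philox_unpack.cuh (hypothetical sketch). The dropout path is never
// taken because the dispatcher asserts params.p_dropout == 1.f, so these
// definitions only need to make that path compile.
#pragma once

#include <cstdint>
#include <tuple>

namespace at {

// Minimal stand-in for torch's at::PhiloxCudaState: just a seed/offset pair.
struct PhiloxCudaState {
    uint64_t seed_ = 0;
    uint64_t offset_ = 0;
};

namespace cuda {
namespace philox {

// Matches the call shape the kernel uses; with dropout disabled the values
// returned here are never consumed.
__host__ __device__ __forceinline__ std::tuple<uint64_t, uint64_t>
unpack(const PhiloxCudaState &arg) {
    return std::make_tuple(arg.seed_, arg.offset_);
}

} // namespace philox
} // namespace cuda
} // namespace at
```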
## Compile-time fix: extern template declarations
The naive splitkv-aware `flash_api.cu` produced a runaway nvcc compile (30+ minutes, killed at ~17.9 GB cicc RSS) during development. Root cause: `flash_fwd_launch_template.h` defines the primary template `run_mha_fwd_splitkv_dispatch<>`, whose body expands through seven nested binary `*_SWITCH` chains (128 kernel specialisations per (dtype, hdim, causal) tuple), plus a 14-way combine-kernel chain. Implicitly instantiating all 24 tuples in the dispatcher TU therefore expanded roughly 24 × 128 + 24 × 14 ≈ 3,400 kernel specialisations through CUTLASS GEMM templates in a single nvcc invocation.
Fixed by adding 24 `extern template` declarations in `flash_api.cu` after the include. nvcc no longer instantiates the splitkv specialisations in this TU; the linker resolves them to the explicit instantiation definitions in the per-hdim `flash_fwd_split_hdim*_*_sm80.cu` files, which compile in parallel as 24 independent ~30s TUs. The dense-path counterpart `run_mha_fwd_<>` was already forward-declared in `flash.h` without a primary template definition, so it never had this problem.
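For illustration, a sketch of the declarations (the template parameter list follows upstream's `run_mha_fwd_splitkv_dispatch<T, Headdim, Is_causal>`; the PR's actual list spells out all 24 tuples):

```cpp
// flash_api.cu (sketch): after including the launch template, declare the
// splitkv dispatch specialisations as "instantiated elsewhere" so this TU
// never expands their *_SWITCH chains. The definitions live in the per-hdim
// flash_fwd_split_hdim*_*_sm80.cu files.
#include "flash_fwd_launch_template.h"

extern template void run_mha_fwd_splitkv_dispatch<cutlass::half_t,     64, false>(Flash_fwd_params &params, cudaStream_t stream);
extern template void run_mha_fwd_splitkv_dispatch<cutlass::half_t,     64, true >(Flash_fwd_params &params, cudaStream_t stream);
extern template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 64, false>(Flash_fwd_params &params, cudaStream_t stream);
extern template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 64, true >(Flash_fwd_params &params, cudaStream_t stream);
// ...repeated for hdim 32, 96, 128, 192, 256 (24 declarations in total).
```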
## Testing
Build + test gate on cuda:0 (sm_86, RTX A6000):

* `cargo build`: clean, 49 kernels, 4m36s
* `cargo test`: `flash_attn_acausal`, `flash_attn_acausal_softcap`, and `flash_attn_varlen` pass as before; the new `flash_attn_acausal_splitkv` passes and exercises the splitkv path
The new `flash_attn_acausal_splitkv` test uses a shape (batch=1, heads=2, seqlen_q=8, seqlen_k=512, head_dim=64) chosen so `num_splits_heuristic` returns ≥ 2 on any modern sm80+ GPU (A6000 / 4090 / A100 / H100), and asserts a max-abs diff < 5e-3 against an fp32 attention reference.
## Public Rust API
`flash_attn`, `flash_attn_windowed`, `flash_attn_alibi`, `flash_attn_varlen`, etc. are unchanged. The new FFI params are internal; their defaults reproduce pre-bump behavior exactly.

## Note on framing
Originally filed under the hypothesis "FA kernels are stale, bumping fixes long-context divergence." The bump is structurally fine and worth landing on its own merits (newer kernels, splitkv support, no regressions), but it is correctness-neutral on long-context by itself. The actual long-context fix turned out to be in `candle-transformers/src/models/qwen2.rs::RotaryEmbedding::new` (filed as a separate PR; link to follow if maintainers want to review them together). See the diagnostic update on #3515 for the full chain.
This PR is the "newer kernels + splitkv" half of the original ask.