
candle-flash-attn: update vendored Tri Dao FA kernels to v2.8.3 #3521

Open

toddwbucy wants to merge 1 commit into huggingface:main from toddwbucy:feat/candle-flash-attn-v2.8.3

Conversation

@toddwbucy


Bumps the vendored Flash Attention 2 kernel sources from the
post-Dec-2024 state (~v2.6.x with #2688 / #2689 / #2690 cherry-picks)
to upstream **v2.8.3** (Tri Dao commit 060c918, 2025-08-14).
Closes the kernel-vendoring portion of #3515.

Single squashed commit, ~700 net lines including the 24 new split-KV
kernel files; the bulk is mechanical kernel updates from upstream.

## What's in the bump

* All `kernels/*.h` and forward sm80 `*.cu` files replaced with
  v2.8.3 sources (path-remapped from upstream
  `csrc/flash_attn/src/` to candle's flat `kernels/`).
* 24 split-KV forward sm80 kernels added (new to the vendored set) — one per
  {fp16, bf16} × {dense, causal} × {hdim 32, 64, 96, 128, 192, 256}.
* `flash_api.cu` dispatcher mirrors v2.8.3's
  `run_mha_fwd(params, stream, force_split_kernel)` shape: branches
  between dense `run_mha_fwd_<>` and `run_mha_fwd_splitkv_dispatch<>`
  based on `params.num_splits` and `force_split_kernel` (see the
  sketch after this list).
* `extern "C" run_mha` FFI gains 4 params (`num_splits`,
  `softmax_lseaccum_ptr`, `oaccum_ptr`, `force_split_kernel`),
  mirrored in `src/ffi.rs`. Existing call sites pass `num_splits=1`,
  null accumulator pointers, and `force_split_kernel=0`; the
  dispatcher heuristic then decides whether to split.
* Rust-side `set_params_splitkv` equivalent: ports upstream
  `num_splits_heuristic` (efficiency-based search, ≥85% of peak
  occupancy, capped at 128; sketched under Testing below) and the
  accumulator allocation. Exposed as `#[doc(hidden)] pub` so
  integration tests can verify the dispatcher chose splitkv for a
  given shape.
* Dropped legacy hdim 160 / 224 / 512 forward kernels — Tri Dao
  removed their launch-template helpers in v2.x and the candle
  vendored .cu files for those dims no longer compile against v2.8.3.
* CUTLASS pin retained at `7d49e6c7`; v2.8.3's pinned commit
  `dc481792` is not required for the kernels we compile.
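
A minimal sketch of the dispatch shape referenced in the list above,
condensed from upstream v2.8.3's `run_mha_fwd` (the vendored
`flash_api.cu` adds candle-specific plumbing around it):

```cuda
// Condensed from upstream v2.8.3's run_mha_fwd; sketch only, not the vendored
// code verbatim. FP16_SWITCH / HEADDIM_SWITCH / BOOL_SWITCH are upstream's
// compile-time dispatch macros.
#include "flash.h"           // Flash_fwd_params, run_mha_fwd_<> declarations
#include "static_switch.h"   // FP16_SWITCH / HEADDIM_SWITCH / BOOL_SWITCH

void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream,
                 bool force_split_kernel = false) {
    FP16_SWITCH(!params.is_bf16, [&] {
        HEADDIM_SWITCH(params.d, [&] {
            BOOL_SWITCH(params.is_causal, Is_causal, [&] {
                if (params.num_splits <= 1 && !force_split_kernel) {
                    // Dense path: one forward kernel writes the output directly.
                    run_mha_fwd_<elem_type, kHeadDim, Is_causal>(params, stream);
                } else {
                    // Split-KV path: partials land in oaccum / softmax_lseaccum,
                    // then a combine kernel reduces across splits.
                    run_mha_fwd_splitkv_dispatch<elem_type, kHeadDim, Is_causal>(params, stream);
                }
            });
        });
    });
}
```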

## Stripping PyTorch

The dropout codepath in v2.8.3's `flash_fwd_kernel.h` references
`at::PhiloxCudaState` and `at::cuda::philox::unpack`. Provide a
candle-side stub in `kernels/philox_unpack.cuh` so the path
compiles without dragging in the torch C++ runtime. A runtime
tripwire at the dispatch site asserts `params.p_dropout == 1.f`
(i.e. dropout disabled) so no future caller can silently run the
dropout path against the stub.
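
For illustration, a stub of this shape satisfies both symbols (a
sketch, not the vendored file verbatim; it assumes the build passes
`--expt-relaxed-constexpr` as the FA kernels already require, and
upstream's real `unpack` reads the torch RNG state rather than
returning zeros):

```cuda
// kernels/philox_unpack.cuh (sketch). Supplies the two ATen symbols the
// v2.8.3 dropout codepath names, with no dependency on the torch runtime.
#pragma once

#include <cstdint>
#include <tuple>

namespace at {

struct PhiloxCudaState {
    uint64_t seed_ = 0;
    uint64_t offset_ = 0;
};

namespace cuda {
namespace philox {

// Upstream returns the RNG (seed, offset) pair; zeros are safe here because
// the dispatch-site tripwire (params.p_dropout == 1.f) keeps the dropout
// path unreachable.
__host__ __device__ __forceinline__ std::tuple<uint64_t, uint64_t>
unpack(const PhiloxCudaState &arg) {
    return std::make_tuple(arg.seed_, arg.offset_);
}

}  // namespace philox
}  // namespace cuda
}  // namespace at
```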

## Compile-time fix: extern template declarations

Adding the splitkv branch to `flash_api.cu` produced a runaway nvcc
compile (>30 minutes, killed at 17.9 GB cicc RSS during
development). Root cause: `flash_fwd_launch_template.h` defines the
primary template `run_mha_fwd_splitkv_dispatch<>`, whose body calls
`run_flash_splitkv_fwd<>` with **seven nested binary `*_SWITCH`
chains** — 128 kernel specialisations per (dtype, hdim, causal)
tuple, plus a 14-way combine-kernel chain. Implicitly instantiating
all 24 tuples in the dispatcher TU expanded ~3,400 kernel
specialisations through CUTLASS GEMM templates in a single nvcc
invocation.

Fixed by adding 24 `extern template` declarations in `flash_api.cu`
after the include. nvcc no longer instantiates
`run_mha_fwd_splitkv_dispatch<>` specialisations in this TU; the
linker resolves the calls to the explicit instantiation definitions
in the per-hdim `flash_fwd_split_hdim*_*_sm80.cu` files (which
compile in parallel as 24 independent ~30s TUs). The dense-path
counterpart `run_mha_fwd_<>` is forward-declared in `flash.h`
without a primary template definition, so it never had this problem.
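
Illustratively, the pattern looks like this (three of the 24
declarations; spelling follows upstream's split files and may differ
slightly in the vendored copy):

```cuda
// In flash_api.cu, after including flash_fwd_launch_template.h: tell nvcc the
// splitkv specialisations are instantiated elsewhere, so this TU compiles
// without expanding them. The full set covers {half_t, bfloat16_t} ×
// {causal, non-causal} × hdim {32, 64, 96, 128, 192, 256}.
extern template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 64, false>(Flash_fwd_params &params, cudaStream_t stream);
extern template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 64, true>(Flash_fwd_params &params, cudaStream_t stream);
extern template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 128, false>(Flash_fwd_params &params, cudaStream_t stream);
// ... 21 more ...

// In flash_fwd_split_hdim64_fp16_sm80.cu: the matching explicit instantiation
// definition that the linker resolves the call to.
template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 64, false>(Flash_fwd_params &params, cudaStream_t stream);
```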

## Testing

Build + test gate (cuda:0, sm_86, RTX A6000):

  $ cargo build           # clean: 49 kernels — 4m36s
  $ cargo test
    flash_attn_acausal          ok
    flash_attn_acausal_softcap  ok
    flash_attn_varlen           ok
    flash_attn_acausal_splitkv  ok   # NEW: exercises splitkv path

The new `flash_attn_acausal_splitkv` test uses a shape
(batch=1, heads=2, seqlen_q=8, seqlen_k=512, head_dim=64) chosen
so `num_splits_heuristic` returns ≥ 2 on any modern sm80+ GPU
(A6000 / 4090 / A100 / H100), and asserts max-abs diff < 5e-3
against an fp32 attention reference.
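
For reference, upstream's C++ `num_splits_heuristic`, which the Rust
port mirrors, has roughly this shape (condensed sketch, not the
vendored code):

```cuda
#include <algorithm>
#include <cmath>
#include <vector>

// Condensed from upstream flash-attention's num_splits_heuristic; the candle
// port is the Rust equivalent. Returns the smallest split count whose
// estimated SM utilisation ("efficiency") is >= 85% of the best achievable,
// capped at max_splits (128 at the call site).
inline int num_splits_heuristic(int batch_nheads_mblocks, int num_SMs,
                                int num_n_blocks, int max_splits) {
    // Enough independent blocks to nearly fill the GPU already: don't split.
    if (batch_nheads_mblocks >= 0.8f * num_SMs) { return 1; }
    max_splits = std::min({max_splits, num_SMs, num_n_blocks});
    auto ceildiv = [](int a, int b) { return (a + b - 1) / b; };
    // Only consider split counts that change the per-split number of KV
    // blocks; otherwise a larger split is pure overhead.
    auto eligible = [&](int s) {
        return s == 1 || ceildiv(num_n_blocks, s) != ceildiv(num_n_blocks, s - 1);
    };
    std::vector<float> eff(max_splits, 0.f);
    float best = 0.f;
    for (int s = 1; s <= max_splits; ++s) {
        if (!eligible(s)) { continue; }
        float waves = float(batch_nheads_mblocks * s) / num_SMs;
        eff[s - 1] = waves / std::ceil(waves);  // fraction of the last wave used
        best = std::max(best, eff[s - 1]);
    }
    for (int s = 1; s <= max_splits; ++s) {
        if (eligible(s) && eff[s - 1] >= 0.85f * best) { return s; }
    }
    return 1;
}
```

With the test shape, `batch_nheads_mblocks` is far below 0.8 × the SM
count of any sm80+ part, so the early-out never fires and the
efficiency search settles on ≥ 2 splits.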

## Public Rust API

`flash_attn`, `flash_attn_windowed`, `flash_attn_alibi`,
`flash_attn_varlen`, etc. — unchanged. The new FFI params are
internal; defaults reproduce pre-bump behavior exactly.

## Note on motivation

Originally filed under the framing "FA kernels are stale, bumping
fixes long-context divergence." The bump itself is structurally
fine and worth landing on its own merits (newer kernels, splitkv
support, no regressions on the tests that already passed), but
it is **correctness-neutral on long-context** by itself. The actual
long-context fix turned out to be in `qwen2.rs::RotaryEmbedding::new`
(filed separately); see the diagnostic update on #3515. This PR
is the "newer kernels + splitkv" half of the original ask.
@toddwbucy
Author

Cross-link: #3520 is the actual long-context correctness fix (5-line dtype change in qwen2.rs RoPE). Originally framed in #3515 as one issue; the diagnostic chain split them into two independent contributions. Either review order works.
