96 changes: 68 additions & 28 deletions candle-flash-attn/build.rs
@@ -5,44 +5,82 @@ use cudaforge::{KernelBuilder, Result};
use std::path::PathBuf;
const CUTLASS_COMMIT: &str = "7d49e6c7e2f8896c47f586706e67e1fb215529dc";

const KERNEL_FILES: [&str; 37] = [
// PR-FA-1 (FA v2.8.3 vendor) — kernel inventory:
// - 1 dispatcher (flash_api.cu)
// - 24 forward sm80 kernels: 6 head dims × {fp16, bf16} × {dense, causal}
// for hdim 32 / 64 / 96 / 128 / 192 / 256 (the v2.8.3 set)
// - 24 split-KV forward sm80 kernels: 6 head dims × {fp16, bf16} ×
// {dense, causal} for hdim 32 / 64 / 96 / 128 / 192 / 256, NEW in
// v2.8.3. Compiled-but-not-yet-dispatched in PR-FA-1 (`num_splits=1`
// forced in flash_api.cu); splitkv dispatch lands in PR-FA-2.
//
// **Dropped legacy head dims (160 / 224 / 512).** candle's prior vendored
// state carried forward kernels for these head dims, but Tri Dao removed
// them from upstream FA at some point — both the kernel files AND the
// matching `run_mha_fwd_hdim160/224/512` launch-template helpers are
// gone in v2.8.3. The candle-vendored .cu files for those dims rely on
// the missing helpers and won't compile against v2.8.3's launch
// template. Restoring legacy hdim support would require re-vendoring
// v2.0.1-era helpers and namespace-wrapping them — out of scope for
// PR-FA-1. If a downstream consumer needs hdim 160/224/512, file a
// follow-up issue and we'll address it separately.
const KERNEL_FILES: [&str; 49] = [
"kernels/flash_api.cu",
"kernels/flash_fwd_hdim128_fp16_sm80.cu",
"kernels/flash_fwd_hdim160_fp16_sm80.cu",
"kernels/flash_fwd_hdim192_fp16_sm80.cu",
"kernels/flash_fwd_hdim224_fp16_sm80.cu",
"kernels/flash_fwd_hdim256_fp16_sm80.cu",
"kernels/flash_fwd_hdim512_fp16_sm80.cu",
// Forward sm80 — v2.8.3-supported head dims (fp16 dense)
"kernels/flash_fwd_hdim32_fp16_sm80.cu",
"kernels/flash_fwd_hdim64_fp16_sm80.cu",
"kernels/flash_fwd_hdim96_fp16_sm80.cu",
"kernels/flash_fwd_hdim128_bf16_sm80.cu",
"kernels/flash_fwd_hdim160_bf16_sm80.cu",
"kernels/flash_fwd_hdim192_bf16_sm80.cu",
"kernels/flash_fwd_hdim224_bf16_sm80.cu",
"kernels/flash_fwd_hdim256_bf16_sm80.cu",
"kernels/flash_fwd_hdim512_bf16_sm80.cu",
"kernels/flash_fwd_hdim32_bf16_sm80.cu",
"kernels/flash_fwd_hdim64_bf16_sm80.cu",
"kernels/flash_fwd_hdim96_bf16_sm80.cu",
"kernels/flash_fwd_hdim128_fp16_causal_sm80.cu",
"kernels/flash_fwd_hdim160_fp16_causal_sm80.cu",
"kernels/flash_fwd_hdim192_fp16_causal_sm80.cu",
"kernels/flash_fwd_hdim224_fp16_causal_sm80.cu",
"kernels/flash_fwd_hdim256_fp16_causal_sm80.cu",
"kernels/flash_fwd_hdim512_fp16_causal_sm80.cu",
"kernels/flash_fwd_hdim128_fp16_sm80.cu",
"kernels/flash_fwd_hdim192_fp16_sm80.cu",
"kernels/flash_fwd_hdim256_fp16_sm80.cu",
// Forward sm80 — v2.8.3-supported head dims (fp16 causal)
"kernels/flash_fwd_hdim32_fp16_causal_sm80.cu",
"kernels/flash_fwd_hdim64_fp16_causal_sm80.cu",
"kernels/flash_fwd_hdim96_fp16_causal_sm80.cu",
"kernels/flash_fwd_hdim128_bf16_causal_sm80.cu",
"kernels/flash_fwd_hdim160_bf16_causal_sm80.cu",
"kernels/flash_fwd_hdim192_bf16_causal_sm80.cu",
"kernels/flash_fwd_hdim224_bf16_causal_sm80.cu",
"kernels/flash_fwd_hdim256_bf16_causal_sm80.cu",
"kernels/flash_fwd_hdim512_bf16_causal_sm80.cu",
"kernels/flash_fwd_hdim128_fp16_causal_sm80.cu",
"kernels/flash_fwd_hdim192_fp16_causal_sm80.cu",
"kernels/flash_fwd_hdim256_fp16_causal_sm80.cu",
// Forward sm80 — v2.8.3-supported head dims (bf16 dense)
"kernels/flash_fwd_hdim32_bf16_sm80.cu",
"kernels/flash_fwd_hdim64_bf16_sm80.cu",
"kernels/flash_fwd_hdim96_bf16_sm80.cu",
"kernels/flash_fwd_hdim128_bf16_sm80.cu",
"kernels/flash_fwd_hdim192_bf16_sm80.cu",
"kernels/flash_fwd_hdim256_bf16_sm80.cu",
// Forward sm80 — v2.8.3-supported head dims (bf16 causal)
"kernels/flash_fwd_hdim32_bf16_causal_sm80.cu",
"kernels/flash_fwd_hdim64_bf16_causal_sm80.cu",
"kernels/flash_fwd_hdim96_bf16_causal_sm80.cu",
"kernels/flash_fwd_hdim128_bf16_causal_sm80.cu",
"kernels/flash_fwd_hdim192_bf16_causal_sm80.cu",
"kernels/flash_fwd_hdim256_bf16_causal_sm80.cu",
// Split-KV forward sm80 — NEW in v2.8.3 (paged-KV / multi-split
// dispatch). Compiled-but-not-yet-invoked in PR-FA-1; PR-FA-2 wires
// the splitkv branch into flash_api.cu's `run_mha_fwd`.
"kernels/flash_fwd_split_hdim32_fp16_sm80.cu",
"kernels/flash_fwd_split_hdim64_fp16_sm80.cu",
"kernels/flash_fwd_split_hdim96_fp16_sm80.cu",
"kernels/flash_fwd_split_hdim128_fp16_sm80.cu",
"kernels/flash_fwd_split_hdim192_fp16_sm80.cu",
"kernels/flash_fwd_split_hdim256_fp16_sm80.cu",
"kernels/flash_fwd_split_hdim32_fp16_causal_sm80.cu",
"kernels/flash_fwd_split_hdim64_fp16_causal_sm80.cu",
"kernels/flash_fwd_split_hdim96_fp16_causal_sm80.cu",
"kernels/flash_fwd_split_hdim128_fp16_causal_sm80.cu",
"kernels/flash_fwd_split_hdim192_fp16_causal_sm80.cu",
"kernels/flash_fwd_split_hdim256_fp16_causal_sm80.cu",
"kernels/flash_fwd_split_hdim32_bf16_sm80.cu",
"kernels/flash_fwd_split_hdim64_bf16_sm80.cu",
"kernels/flash_fwd_split_hdim96_bf16_sm80.cu",
"kernels/flash_fwd_split_hdim128_bf16_sm80.cu",
"kernels/flash_fwd_split_hdim192_bf16_sm80.cu",
"kernels/flash_fwd_split_hdim256_bf16_sm80.cu",
"kernels/flash_fwd_split_hdim32_bf16_causal_sm80.cu",
"kernels/flash_fwd_split_hdim64_bf16_causal_sm80.cu",
"kernels/flash_fwd_split_hdim96_bf16_causal_sm80.cu",
"kernels/flash_fwd_split_hdim128_bf16_causal_sm80.cu",
"kernels/flash_fwd_split_hdim192_bf16_causal_sm80.cu",
"kernels/flash_fwd_split_hdim256_bf16_causal_sm80.cu",
];

fn main() -> Result<()> {
@@ -60,6 +98,8 @@ fn main() -> Result<()> {
println!("cargo::rerun-if-changed=kernels/block_info.h");
println!("cargo::rerun-if-changed=kernels/static_switch.h");
println!("cargo::rerun-if-changed=kernels/hardware_info.h");
println!("cargo::rerun-if-changed=kernels/namespace_config.h");
println!("cargo::rerun-if-changed=kernels/philox_unpack.cuh");
let out_dir = PathBuf::from(std::env::var("OUT_DIR").expect("OUT_DIR not set"));
let build_dir = match std::env::var("CANDLE_FLASH_ATTN_BUILD_DIR") {
Err(_) =>
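The inventory comment above pins the 49-entry KERNEL_FILES layout to a fixed cross product. A minimal Rust sketch of that cross product, purely illustrative and not part of build.rs (the helper name and the ordering are assumptions; the real array is hand-ordered with per-group comments):

// Hypothetical sanity check, not in the PR: rebuild the 49 kernel paths from
// the head-dim / dtype / causal cross product described in the inventory
// comment. Ordering differs from the hand-written KERNEL_FILES array.
fn expected_kernel_files() -> Vec<String> {
    let mut files = vec!["kernels/flash_api.cu".to_string()]; // 1 dispatcher
    for prefix in ["flash_fwd", "flash_fwd_split"] {           // dense fwd + split-KV fwd
        for dtype in ["fp16", "bf16"] {
            for causal in ["", "_causal"] {
                for hdim in [32, 64, 96, 128, 192, 256] {      // the v2.8.3 head-dim set
                    files.push(format!("kernels/{prefix}_hdim{hdim}_{dtype}{causal}_sm80.cu"));
                }
            }
        }
    }
    assert_eq!(files.len(), 49); // 1 + 24 + 24
    files
}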
5 changes: 3 additions & 2 deletions candle-flash-attn/kernels/alibi.h
@@ -1,13 +1,14 @@
#include <cmath>

#include "namespace_config.h"
#include <cute/tensor.hpp>

#include <cutlass/cutlass.h>
#include <cutlass/array.h>

#include "utils.h"

namespace flash {
namespace FLASH_NAMESPACE {

using namespace cute;

@@ -71,4 +72,4 @@ struct Alibi {

};

} // namespace flash
} // namespace FLASH_NAMESPACE
5 changes: 3 additions & 2 deletions candle-flash-attn/kernels/block_info.h
@@ -4,7 +4,8 @@

#pragma once

namespace flash {
#include "namespace_config.h"
namespace FLASH_NAMESPACE {

////////////////////////////////////////////////////////////////////////////////////////////////////

@@ -45,4 +46,4 @@ struct BlockInfo {

////////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace flash
} // namespace FLASH_NAMESPACE
9 changes: 5 additions & 4 deletions candle-flash-attn/kernels/dropout.h
@@ -4,10 +4,11 @@

#pragma once

#include "namespace_config.h"
#include "philox.cuh"
#include "utils.h"

namespace flash {
namespace FLASH_NAMESPACE {

struct Dropout {

@@ -26,7 +27,7 @@ struct Dropout {
__forceinline__ __device__ void apply_dropout(Tensor<Engine, Layout> &tensor_,
int block_row_start, int block_col_start, int block_row_stride) {
// convert shape from (4, MMA_M, MMA_N) to (8, MMA_M, MMA_N / 2)
Tensor tensor = make_tensor(tensor_.data(), flash::convert_layout_acc_dropout(tensor_.layout()));
Tensor tensor = make_tensor(tensor_.data(), FLASH_NAMESPACE::convert_layout_acc_dropout(tensor_.layout()));
using T = typename Engine::value_type;
auto encode_dropout = [](bool keep, T val) {
return keep ? val : (encode_dropout_in_sign_bit ? -val : T(0));
@@ -41,7 +42,7 @@
#pragma unroll
for (int n = 0; n < size<2>(tensor) / 2; ++n, ++rowcol.y) {
// if (cute::thread(32, 0)) { printf("m = %d, n = %d, row = %d, col = %d\n", m, n, int(rowcol.x), int(rowcol.y));}
uint4 random_uint4 = flash::philox(seed, reinterpret_cast<unsigned long long&>(rowcol), offset);
uint4 random_uint4 = FLASH_NAMESPACE::philox(seed, reinterpret_cast<unsigned long long&>(rowcol), offset);
// if (cute::thread0()) { printf("philox = %u, %d, %d, %d\n", random_uint4.x, random_uint4.y, random_uint4.z, random_uint4.w);}
uint8_t (&rnd_8)[16] = reinterpret_cast<uint8_t (&)[16]>(random_uint4);
// Special implementation for 16-bit types: we duplicate the threshold to the
@@ -91,4 +92,4 @@ struct Dropout {

};

} // namespace flash
} // namespace FLASH_NAMESPACE
13 changes: 9 additions & 4 deletions candle-flash-attn/kernels/flash.h
@@ -4,11 +4,14 @@

#pragma once

#include "namespace_config.h"

#include <cuda.h>
#include <vector>

// #include <ATen/cuda/CUDAGeneratorImpl.h> // For at::Generator and at::PhiloxCudaState
#include "philox_unpack.cuh" // candle: stub providing at::PhiloxCudaState (replaces the PyTorch at::cuda dep)

namespace FLASH_NAMESPACE {
constexpr int TOTAL_DIM = 0;
constexpr int H_DIM = 1;
constexpr int D_DIM = 2;
@@ -116,7 +119,7 @@ struct Flash_fwd_params : public Qkv_params {
float softcap;

// Random state.
// at::PhiloxCudaState philox_args;
at::PhiloxCudaState philox_args; // candle: stubbed type from philox_unpack.cuh — kernel computes-but-ignores when Is_dropout=false

// Pointer to the RNG seed (idx 0) and offset (idx 1).
uint64_t * rng_state;
@@ -184,6 +187,8 @@ struct Flash_bwd_params : public Flash_fwd_params {
////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename T, int Headdim, bool Is_causal> void run_mha_fwd_(Flash_fwd_params &params, cudaStream_t stream);
// template<typename T, int Headdim, bool Is_causal> void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream);
template<typename T, int Headdim, bool Is_causal> void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream);

template<typename T, int Headdim, bool Is_causal> void run_mha_bwd_(Flash_bwd_params &params, cudaStream_t stream);

// template<typename T, int Headdim, bool Is_causal> void run_mha_bwd_(Flash_bwd_params &params, cudaStream_t stream);
} // namespace FLASH_NAMESPACE
125 changes: 119 additions & 6 deletions candle-flash-attn/kernels/flash_api.cu
@@ -1,17 +1,99 @@
// candle-flash-attn host dispatch: thin extern "C" wrapper around
// Tri Dao's flash_fwd kernel templates.
//
// PR-FA-2 update: extend the dispatcher to mirror v2.8.3's
// `run_mha_fwd(params, stream, force_split_kernel)` shape — branch on
// `num_splits <= 1 && !force_split_kernel` to choose between the dense
// `run_mha_fwd_<>` and `run_mha_fwd_splitkv_dispatch<>` kernel
// templates. The FFI `extern "C" run_mha` exposes the new params
// (`num_splits`, `softmax_lseaccum_ptr`, `oaccum_ptr`,
// `force_split_kernel`) so Rust callers can drive splitkv. PR-FA-2
// keeps Rust-side defaults at `num_splits=1` and null accumulator
// pointers, so the dense path is taken and existing behavior is
// unchanged. PR-FA-3 wires the Rust-side `set_params_splitkv`
// equivalent (heuristic + accumulator buffer allocation).
//
// PR-FA-1 (already merged): vendored kernels were bumped from the
// post-Dec-2024 state to upstream v2.8.3 (commit 060c918, 2025-08-14).
// v2.8.3 wraps the kernel templates in `namespace flash`, so
// `run_mha_fwd_<>` and `Flash_fwd_params` live under `FLASH_NAMESPACE`
// (= `flash`).

#include <cstdio>
#include <cstdlib>

#include "kernels.h"
#include "kernel_helpers.h"
#include "namespace_config.h"
#include "flash_fwd_launch_template.h"

void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream) {
namespace FLASH_NAMESPACE {

// Suppress implicit instantiation of `run_mha_fwd_splitkv_dispatch<>` in this
// TU. Without these declarations cicc would expand all 24 tuples
// (2 dtypes × 6 hdims × 2 causal) of the splitkv dispatcher inline here, and
// each tuple instantiates ~142 kernel specialisations through the seven nested
// SWITCH macros in `run_flash_splitkv_fwd<>` — a single-TU compile that ran
// >30 minutes at ~18 GB RSS before being killed during PR-FA-2 development.
//
// The corresponding explicit instantiation definitions live in the per-hdim
// `flash_fwd_split_hdim*_*_sm80.cu` files (e.g.
// `template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 64, false>(...)`),
// which compile in parallel as 24 independent TUs (~30s each on this machine).
// The linker resolves the calls in this TU to those out-of-line definitions.
//
// Note: `run_mha_fwd_<>` (the dense-path counterpart) is forward-declared in
// `flash.h` without a primary template definition; cicc therefore never tries
// to implicitly instantiate it here, and no extern declarations are required.
extern template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 32, false>(Flash_fwd_params &params, cudaStream_t stream);
extern template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 32, true>(Flash_fwd_params &params, cudaStream_t stream);
extern template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 64, false>(Flash_fwd_params &params, cudaStream_t stream);
extern template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 64, true>(Flash_fwd_params &params, cudaStream_t stream);
extern template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 96, false>(Flash_fwd_params &params, cudaStream_t stream);
extern template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 96, true>(Flash_fwd_params &params, cudaStream_t stream);
extern template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 128, false>(Flash_fwd_params &params, cudaStream_t stream);
extern template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 128, true>(Flash_fwd_params &params, cudaStream_t stream);
extern template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 192, false>(Flash_fwd_params &params, cudaStream_t stream);
extern template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 192, true>(Flash_fwd_params &params, cudaStream_t stream);
extern template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 256, false>(Flash_fwd_params &params, cudaStream_t stream);
extern template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 256, true>(Flash_fwd_params &params, cudaStream_t stream);
extern template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 32, false>(Flash_fwd_params &params, cudaStream_t stream);
extern template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 32, true>(Flash_fwd_params &params, cudaStream_t stream);
extern template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 64, false>(Flash_fwd_params &params, cudaStream_t stream);
extern template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 64, true>(Flash_fwd_params &params, cudaStream_t stream);
extern template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 96, false>(Flash_fwd_params &params, cudaStream_t stream);
extern template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 96, true>(Flash_fwd_params &params, cudaStream_t stream);
extern template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 128, false>(Flash_fwd_params &params, cudaStream_t stream);
extern template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 128, true>(Flash_fwd_params &params, cudaStream_t stream);
extern template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 192, false>(Flash_fwd_params &params, cudaStream_t stream);
extern template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 192, true>(Flash_fwd_params &params, cudaStream_t stream);
extern template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 256, false>(Flash_fwd_params &params, cudaStream_t stream);
extern template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 256, true>(Flash_fwd_params &params, cudaStream_t stream);

// Templated dispatch wrapper. Mirrors v2.8.3's
// `flash_api.cpp::run_mha_fwd` — chooses between the dense
// `run_mha_fwd_<elem_type, kHeadDim, Is_causal>` specialisation and
// the `run_mha_fwd_splitkv_dispatch<elem_type, kHeadDim, Is_causal>`
// specialisation based on `params.num_splits` and the explicit
// `force_split_kernel` override (used upstream for paged-KV / cached-K
// paths; not yet plumbed in candle but exposed for FFI symmetry).
void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream,
bool force_split_kernel = false) {
FP16_SWITCH(!params.is_bf16, [&] {
HEADDIM_SWITCH(params.d, [&] {
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
run_mha_fwd_<elem_type, kHeadDim, Is_causal>(params, stream);
if (params.num_splits <= 1 && !force_split_kernel) {
run_mha_fwd_<elem_type, kHeadDim, Is_causal>(params, stream);
} else {
run_mha_fwd_splitkv_dispatch<elem_type, kHeadDim, Is_causal>(params, stream);
}
});
});
});
}

} // namespace FLASH_NAMESPACE

extern "C" void run_mha(
void *q_ptr,
void *k_ptr,
@@ -58,9 +140,22 @@ extern "C" void run_mha(
int window_size_left,
int window_size_right,

float softcap
float softcap,

// PR-FA-2: split-KV dispatch surface. `num_splits<=1 && !force_split_kernel`
// takes the dense path (existing behavior); `num_splits>1` or
// `force_split_kernel != 0` enters `run_mha_fwd_splitkv_dispatch<>`. The
// `_accum_ptr` buffers must be fp32 of shape
// `(num_splits, b, h, seqlen_q[, d_rounded])`; the caller (Rust side) is
// responsible for allocation. Defaults of `num_splits=1`,
// `softmax_lseaccum_ptr=nullptr`, `oaccum_ptr=nullptr`, `force_split_kernel=0`
// reproduce PR-FA-1 behavior exactly.
int num_splits,
void *softmax_lseaccum_ptr,
void *oaccum_ptr,
int force_split_kernel
) {
Flash_fwd_params params;
FLASH_NAMESPACE::Flash_fwd_params params;
// Reset the parameters
memset(&params, 0, sizeof(params));

@@ -128,9 +223,27 @@ extern "C" void run_mha(
params.window_size_right = window_size_right;

params.is_seqlens_k_cumulative = true;
params.num_splits = 1;
params.num_splits = num_splits;
params.softmax_lseaccum_ptr = softmax_lseaccum_ptr;
params.oaccum_ptr = oaccum_ptr;
params.unpadded_lse = unpadded_lse;

// Tripwire: candle-flash-attn does not support dropout. `philox_unpack.cuh`
// is a stubbed replacement (returns a fake seed/offset pair) so the dropout
// codepath inside `flash_fwd_kernel.h` compiles, but executing it would
// silently produce garbage. Dropout is currently impossible to reach because
// `params.p_dropout` is hard-set to 1.0 above and is not an FFI input —
// this check catches anyone re-introducing a dropout path without also
// wiring a real philox state. Unconditional (not `assert`) so the guard
// remains active in release builds.
if (params.p_dropout != 1.f) {
std::fprintf(stderr,
"candle-flash-attn: dropout is not supported "
"(philox_unpack.cuh is stubbed); got p_dropout=%f\n",
params.p_dropout);
std::abort();
}

cudaStream_t stream = 0; // Use the default stream.
run_mha_fwd(params, stream);
FLASH_NAMESPACE::run_mha_fwd(params, stream, force_split_kernel != 0);
}
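For the Rust caller side that PR-FA-3 will add, the only hard contract visible in this diff is the accumulator layout spelled out in the FFI comment above: fp32 buffers of shape (num_splits, b, h, seqlen_q) and (num_splits, b, h, seqlen_q, d_rounded). A small sketch of that sizing, under the assumption that the caller allocates flat device buffers (helper names are hypothetical; the actual PR-FA-3 heuristic and allocation code are not part of this diff):

// Hypothetical caller-side helper: element counts for the two fp32 split-KV
// accumulators. Shapes follow the FFI comment in flash_api.cu; everything
// else (names, how and where the buffers are allocated) is an assumption.
fn splitkv_accum_elems(
    num_splits: usize,
    b: usize,         // batch size
    h: usize,         // number of heads
    seqlen_q: usize,
    d_rounded: usize, // head dim as rounded by the kernel
) -> (usize, usize) {
    let lse_accum = num_splits * b * h * seqlen_q;           // -> softmax_lseaccum_ptr
    let o_accum = num_splits * b * h * seqlen_q * d_rounded; // -> oaccum_ptr
    (lse_accum, o_accum)
}

// PR-FA-2 keeps the dense path by passing num_splits = 1, null accumulator
// pointers, and force_split_kernel = 0 across the FFI boundary.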