diff --git a/candle-flash-attn/build.rs b/candle-flash-attn/build.rs index 8152477cb9..26981b3f11 100644 --- a/candle-flash-attn/build.rs +++ b/candle-flash-attn/build.rs @@ -5,44 +5,82 @@ use cudaforge::{KernelBuilder, Result}; use std::path::PathBuf; const CUTLASS_COMMIT: &str = "7d49e6c7e2f8896c47f586706e67e1fb215529dc"; -const KERNEL_FILES: [&str; 37] = [ +// PR-FA-1 (FA v2.8.3 vendor) — kernel inventory: +// - 1 dispatcher (flash_api.cu) +// - 24 forward sm80 kernels: 6 head dims × {fp16, bf16} × {dense, causal} +// for hdim 32 / 64 / 96 / 128 / 192 / 256 (the v2.8.3 set) +// - 24 split-KV forward sm80 kernels: 6 head dims × {fp16, bf16} × +// {dense, causal} for hdim 32 / 64 / 96 / 128 / 192 / 256, NEW in +// v2.8.3. Compiled-but-not-yet-dispatched in PR-FA-1 (`num_splits=1` +// forced in flash_api.cu); splitkv dispatch lands in PR-FA-2. +// +// **Dropped legacy head dims (160 / 224 / 512).** candle's prior vendored +// state carried forward kernels for these head dims, but Tri Dao removed +// them from upstream FA at some point — both the kernel files AND the +// matching `run_mha_fwd_hdim160/224/512` launch-template helpers are +// gone in v2.8.3. The candle-vendored .cu files for those dims rely on +// the missing helpers and won't compile against v2.8.3's launch +// template. Restoring legacy hdim support would require re-vendoring +// v2.0.1-era helpers and namespace-wrapping them — out of scope for +// PR-FA-1. If a downstream consumer needs hdim 160/224/512, file a +// follow-up issue and we'll address it separately. 
+const KERNEL_FILES: [&str; 49] = [ "kernels/flash_api.cu", - "kernels/flash_fwd_hdim128_fp16_sm80.cu", - "kernels/flash_fwd_hdim160_fp16_sm80.cu", - "kernels/flash_fwd_hdim192_fp16_sm80.cu", - "kernels/flash_fwd_hdim224_fp16_sm80.cu", - "kernels/flash_fwd_hdim256_fp16_sm80.cu", - "kernels/flash_fwd_hdim512_fp16_sm80.cu", + // Forward sm80 — v2.8.3-supported head dims (fp16 dense) "kernels/flash_fwd_hdim32_fp16_sm80.cu", "kernels/flash_fwd_hdim64_fp16_sm80.cu", "kernels/flash_fwd_hdim96_fp16_sm80.cu", - "kernels/flash_fwd_hdim128_bf16_sm80.cu", - "kernels/flash_fwd_hdim160_bf16_sm80.cu", - "kernels/flash_fwd_hdim192_bf16_sm80.cu", - "kernels/flash_fwd_hdim224_bf16_sm80.cu", - "kernels/flash_fwd_hdim256_bf16_sm80.cu", - "kernels/flash_fwd_hdim512_bf16_sm80.cu", - "kernels/flash_fwd_hdim32_bf16_sm80.cu", - "kernels/flash_fwd_hdim64_bf16_sm80.cu", - "kernels/flash_fwd_hdim96_bf16_sm80.cu", - "kernels/flash_fwd_hdim128_fp16_causal_sm80.cu", - "kernels/flash_fwd_hdim160_fp16_causal_sm80.cu", - "kernels/flash_fwd_hdim192_fp16_causal_sm80.cu", - "kernels/flash_fwd_hdim224_fp16_causal_sm80.cu", - "kernels/flash_fwd_hdim256_fp16_causal_sm80.cu", - "kernels/flash_fwd_hdim512_fp16_causal_sm80.cu", + "kernels/flash_fwd_hdim128_fp16_sm80.cu", + "kernels/flash_fwd_hdim192_fp16_sm80.cu", + "kernels/flash_fwd_hdim256_fp16_sm80.cu", + // Forward sm80 — v2.8.3-supported head dims (fp16 causal) "kernels/flash_fwd_hdim32_fp16_causal_sm80.cu", "kernels/flash_fwd_hdim64_fp16_causal_sm80.cu", "kernels/flash_fwd_hdim96_fp16_causal_sm80.cu", - "kernels/flash_fwd_hdim128_bf16_causal_sm80.cu", - "kernels/flash_fwd_hdim160_bf16_causal_sm80.cu", - "kernels/flash_fwd_hdim192_bf16_causal_sm80.cu", - "kernels/flash_fwd_hdim224_bf16_causal_sm80.cu", - "kernels/flash_fwd_hdim256_bf16_causal_sm80.cu", - "kernels/flash_fwd_hdim512_bf16_causal_sm80.cu", + "kernels/flash_fwd_hdim128_fp16_causal_sm80.cu", + "kernels/flash_fwd_hdim192_fp16_causal_sm80.cu", + 
"kernels/flash_fwd_hdim256_fp16_causal_sm80.cu", + // Forward sm80 — v2.8.3-supported head dims (bf16 dense) + "kernels/flash_fwd_hdim32_bf16_sm80.cu", + "kernels/flash_fwd_hdim64_bf16_sm80.cu", + "kernels/flash_fwd_hdim96_bf16_sm80.cu", + "kernels/flash_fwd_hdim128_bf16_sm80.cu", + "kernels/flash_fwd_hdim192_bf16_sm80.cu", + "kernels/flash_fwd_hdim256_bf16_sm80.cu", + // Forward sm80 — v2.8.3-supported head dims (bf16 causal) "kernels/flash_fwd_hdim32_bf16_causal_sm80.cu", "kernels/flash_fwd_hdim64_bf16_causal_sm80.cu", "kernels/flash_fwd_hdim96_bf16_causal_sm80.cu", + "kernels/flash_fwd_hdim128_bf16_causal_sm80.cu", + "kernels/flash_fwd_hdim192_bf16_causal_sm80.cu", + "kernels/flash_fwd_hdim256_bf16_causal_sm80.cu", + // Split-KV forward sm80 — NEW in v2.8.3 (paged-KV / multi-split + // dispatch). Compiled-but-not-yet-invoked in PR-FA-1; PR-FA-2 wires + // the splitkv branch into flash_api.cu's `run_mha_fwd`. + "kernels/flash_fwd_split_hdim32_fp16_sm80.cu", + "kernels/flash_fwd_split_hdim64_fp16_sm80.cu", + "kernels/flash_fwd_split_hdim96_fp16_sm80.cu", + "kernels/flash_fwd_split_hdim128_fp16_sm80.cu", + "kernels/flash_fwd_split_hdim192_fp16_sm80.cu", + "kernels/flash_fwd_split_hdim256_fp16_sm80.cu", + "kernels/flash_fwd_split_hdim32_fp16_causal_sm80.cu", + "kernels/flash_fwd_split_hdim64_fp16_causal_sm80.cu", + "kernels/flash_fwd_split_hdim96_fp16_causal_sm80.cu", + "kernels/flash_fwd_split_hdim128_fp16_causal_sm80.cu", + "kernels/flash_fwd_split_hdim192_fp16_causal_sm80.cu", + "kernels/flash_fwd_split_hdim256_fp16_causal_sm80.cu", + "kernels/flash_fwd_split_hdim32_bf16_sm80.cu", + "kernels/flash_fwd_split_hdim64_bf16_sm80.cu", + "kernels/flash_fwd_split_hdim96_bf16_sm80.cu", + "kernels/flash_fwd_split_hdim128_bf16_sm80.cu", + "kernels/flash_fwd_split_hdim192_bf16_sm80.cu", + "kernels/flash_fwd_split_hdim256_bf16_sm80.cu", + "kernels/flash_fwd_split_hdim32_bf16_causal_sm80.cu", + "kernels/flash_fwd_split_hdim64_bf16_causal_sm80.cu", + 
"kernels/flash_fwd_split_hdim96_bf16_causal_sm80.cu", + "kernels/flash_fwd_split_hdim128_bf16_causal_sm80.cu", + "kernels/flash_fwd_split_hdim192_bf16_causal_sm80.cu", + "kernels/flash_fwd_split_hdim256_bf16_causal_sm80.cu", ]; fn main() -> Result<()> { @@ -60,6 +98,8 @@ fn main() -> Result<()> { println!("cargo::rerun-if-changed=kernels/block_info.h"); println!("cargo::rerun-if-changed=kernels/static_switch.h"); println!("cargo::rerun-if-changed=kernels/hardware_info.h"); + println!("cargo::rerun-if-changed=kernels/namespace_config.h"); + println!("cargo::rerun-if-changed=kernels/philox_unpack.cuh"); let out_dir = PathBuf::from(std::env::var("OUT_DIR").expect("OUT_DIR not set")); let build_dir = match std::env::var("CANDLE_FLASH_ATTN_BUILD_DIR") { Err(_) => diff --git a/candle-flash-attn/kernels/alibi.h b/candle-flash-attn/kernels/alibi.h index e714233e7e..a65a5b3790 100644 --- a/candle-flash-attn/kernels/alibi.h +++ b/candle-flash-attn/kernels/alibi.h @@ -1,5 +1,6 @@ #include +#include "namespace_config.h" #include #include @@ -7,7 +8,7 @@ #include "utils.h" -namespace flash { +namespace FLASH_NAMESPACE { using namespace cute; @@ -71,4 +72,4 @@ struct Alibi { }; -} // namespace flash +} // namespace FLASH_NAMESPACE diff --git a/candle-flash-attn/kernels/block_info.h b/candle-flash-attn/kernels/block_info.h index cf60d653c3..9c8baff754 100644 --- a/candle-flash-attn/kernels/block_info.h +++ b/candle-flash-attn/kernels/block_info.h @@ -4,7 +4,8 @@ #pragma once -namespace flash { +#include "namespace_config.h" +namespace FLASH_NAMESPACE { //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -45,4 +46,4 @@ struct BlockInfo { //////////////////////////////////////////////////////////////////////////////////////////////////// -} // namespace flash +} // namespace FLASH_NAMESPACE diff --git a/candle-flash-attn/kernels/dropout.h b/candle-flash-attn/kernels/dropout.h index 4882f97d93..9077b79913 100644 --- 
a/candle-flash-attn/kernels/dropout.h +++ b/candle-flash-attn/kernels/dropout.h @@ -4,10 +4,11 @@ #pragma once +#include "namespace_config.h" #include "philox.cuh" #include "utils.h" -namespace flash { +namespace FLASH_NAMESPACE { struct Dropout { @@ -26,7 +27,7 @@ struct Dropout { __forceinline__ __device__ void apply_dropout(Tensor &tensor_, int block_row_start, int block_col_start, int block_row_stride) { // convert shape from (4, MMA_M, MMA_N) to (8, MMA_M, MMA_N / 2) - Tensor tensor = make_tensor(tensor_.data(), flash::convert_layout_acc_dropout(tensor_.layout())); + Tensor tensor = make_tensor(tensor_.data(), FLASH_NAMESPACE::convert_layout_acc_dropout(tensor_.layout())); using T = typename Engine::value_type; auto encode_dropout = [](bool keep, T val) { return keep ? val : (encode_dropout_in_sign_bit ? -val : T(0)); @@ -41,7 +42,7 @@ struct Dropout { #pragma unroll for (int n = 0; n < size<2>(tensor) / 2; ++n, ++rowcol.y) { // if (cute::thread(32, 0)) { printf("m = %d, n = %d, row = %d, col = %d\n", m, n, int(rowcol.x), int(rowcol.y));} - uint4 random_uint4 = flash::philox(seed, reinterpret_cast(rowcol), offset); + uint4 random_uint4 = FLASH_NAMESPACE::philox(seed, reinterpret_cast(rowcol), offset); // if (cute::thread0()) { printf("philox = %u, %d, %d, %d\n", random_uint4.x, random_uint4.y, random_uint4.z, random_uint4.w);} uint8_t (&rnd_8)[16] = reinterpret_cast(random_uint4); // Special implementation for 16-bit types: we duplicate the threshold to the @@ -91,4 +92,4 @@ struct Dropout { }; -} // namespace flash +} // namespace FLASH_NAMESPACE diff --git a/candle-flash-attn/kernels/flash.h b/candle-flash-attn/kernels/flash.h index f21e4d6205..1405f17776 100644 --- a/candle-flash-attn/kernels/flash.h +++ b/candle-flash-attn/kernels/flash.h @@ -4,11 +4,14 @@ #pragma once +#include "namespace_config.h" + #include #include -// #include // For at::Generator and at::PhiloxCudaState +#include "philox_unpack.cuh" // candle: stub providing at::PhiloxCudaState 
(replaces the PyTorch at::cuda dep) +namespace FLASH_NAMESPACE { constexpr int TOTAL_DIM = 0; constexpr int H_DIM = 1; constexpr int D_DIM = 2; @@ -116,7 +119,7 @@ struct Flash_fwd_params : public Qkv_params { float softcap; // Random state. - // at::PhiloxCudaState philox_args; + at::PhiloxCudaState philox_args; // candle: stubbed type from philox_unpack.cuh — kernel computes-but-ignores when Is_dropout=false // Pointer to the RNG seed (idx 0) and offset (idx 1). uint64_t * rng_state; @@ -184,6 +187,8 @@ struct Flash_bwd_params : public Flash_fwd_params { //////////////////////////////////////////////////////////////////////////////////////////////////// template void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream); -// template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +template void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream); -// template void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream); +} // namespace FLASH_NAMESPACE diff --git a/candle-flash-attn/kernels/flash_api.cu b/candle-flash-attn/kernels/flash_api.cu index d172bef842..fc82c7aa74 100644 --- a/candle-flash-attn/kernels/flash_api.cu +++ b/candle-flash-attn/kernels/flash_api.cu @@ -1,17 +1,99 @@ +// candle-flash-attn host dispatch: thin extern "C" wrapper around +// Tri Dao's flash_fwd kernel templates. +// +// PR-FA-2 update: extend the dispatcher to mirror v2.8.3's +// `run_mha_fwd(params, stream, force_split_kernel)` shape — branch on +// `num_splits <= 1 && !force_split_kernel` to choose between the dense +// `run_mha_fwd_<>` and `run_mha_fwd_splitkv_dispatch<>` kernel +// templates. The FFI `extern "C" run_mha` exposes the new params +// (`num_splits`, `softmax_lseaccum_ptr`, `oaccum_ptr`, +// `force_split_kernel`) so Rust callers can drive splitkv. 
PR-FA-2 +// keeps Rust-side defaults at `num_splits=1` and null accumulator +// pointers, so the dense path is taken and existing behavior is +// unchanged. PR-FA-3 wires the Rust-side `set_params_splitkv` +// equivalent (heuristic + accumulator buffer allocation). +// +// PR-FA-1 (already merged): vendored kernels were bumped from the +// post-Dec-2024 state to upstream v2.8.3 (commit 060c918, 2025-08-14). +// v2.8.3 wraps the kernel templates in `namespace flash`, so +// `run_mha_fwd_<>` and `Flash_fwd_params` live under `FLASH_NAMESPACE` +// (= `flash`). + +#include +#include + #include "kernels.h" #include "kernel_helpers.h" +#include "namespace_config.h" #include "flash_fwd_launch_template.h" -void run_mha_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { +namespace FLASH_NAMESPACE { + +// Suppress implicit instantiation of `run_mha_fwd_splitkv_dispatch<>` in this +// TU. Without these declarations cicc would expand all 24 tuples +// (2 dtypes × 6 hdims × 2 causal) of the splitkv dispatcher inline here, and +// each tuple instantiates ~142 kernel specialisations through the seven nested +// SWITCH macros in `run_flash_splitkv_fwd<>` — a single-TU compile that ran +// >30 minutes at ~18 GB RSS before being killed during PR-FA-2 development. +// +// The corresponding explicit instantiation definitions live in the per-hdim +// `flash_fwd_split_hdim*_*_sm80.cu` files (e.g. +// `template void run_mha_fwd_splitkv_dispatch(...)`), +// which compile in parallel as 24 independent TUs (~30s each on this machine). +// The linker resolves the calls in this TU to those out-of-line definitions. +// +// Note: `run_mha_fwd_<>` (the dense-path counterpart) is forward-declared in +// `flash.h` without a primary template definition; cicc therefore never tries +// to implicitly instantiate it here, and no extern declarations are required. 
+extern template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); +extern template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); +extern template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); +extern template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); +extern template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); +extern template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); +extern template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); +extern template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); +extern template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); +extern template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); +extern template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); +extern template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); +extern template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); +extern template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); +extern template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); +extern template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); +extern template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); +extern template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); +extern template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); +extern template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); +extern template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); 
+extern template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); +extern template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); +extern template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +// Templated dispatch wrapper. Mirrors v2.8.3's +// `flash_api.cpp::run_mha_fwd` — chooses between the dense +// `run_mha_fwd_` specialisation and +// the `run_mha_fwd_splitkv_dispatch` +// specialisation based on `params.num_splits` and the explicit +// `force_split_kernel` override (used upstream for paged-KV / cached-K +// paths; not yet plumbed in candle but exposed for FFI symmetry). +void run_mha_fwd(Flash_fwd_params ¶ms, cudaStream_t stream, + bool force_split_kernel = false) { FP16_SWITCH(!params.is_bf16, [&] { HEADDIM_SWITCH(params.d, [&] { BOOL_SWITCH(params.is_causal, Is_causal, [&] { - run_mha_fwd_(params, stream); + if (params.num_splits <= 1 && !force_split_kernel) { + run_mha_fwd_(params, stream); + } else { + run_mha_fwd_splitkv_dispatch(params, stream); + } }); }); }); } +} // namespace FLASH_NAMESPACE + extern "C" void run_mha( void *q_ptr, void *k_ptr, @@ -58,9 +140,22 @@ extern "C" void run_mha( int window_size_left, int window_size_right, - float softcap + float softcap, + + // PR-FA-2: split-KV dispatch surface. `num_splits<=1 && !force_split_kernel` + // takes the dense path (existing behavior); `num_splits>1` or + // `force_split_kernel != 0` enters `run_mha_fwd_splitkv_dispatch<>`. The + // `_accum_ptr` buffers must be fp32 of shape + // `(num_splits, b, h, seqlen_q[, d_rounded])`; the caller (Rust side) is + // responsible for allocation. Defaults of `num_splits=1`, + // `softmax_lseaccum_ptr=nullptr`, `oaccum_ptr=nullptr`, `force_split_kernel=0` + // reproduce PR-FA-1 behavior exactly. 
+ int num_splits, + void *softmax_lseaccum_ptr, + void *oaccum_ptr, + int force_split_kernel ) { - Flash_fwd_params params; + FLASH_NAMESPACE::Flash_fwd_params params; // Reset the parameters memset(¶ms, 0, sizeof(params)); @@ -128,9 +223,27 @@ extern "C" void run_mha( params.window_size_right = window_size_right; params.is_seqlens_k_cumulative = true; - params.num_splits = 1; + params.num_splits = num_splits; + params.softmax_lseaccum_ptr = softmax_lseaccum_ptr; + params.oaccum_ptr = oaccum_ptr; params.unpadded_lse = unpadded_lse; + // Tripwire: candle-flash-attn does not support dropout. `philox_unpack.cuh` + // is a stubbed replacement (returns a fake seed/offset pair) so the dropout + // codepath inside `flash_fwd_kernel.h` compiles, but executing it would + // silently produce garbage. Dropout is currently impossible to reach because + // `params.p_dropout` is hard-set to 1.0 above and is not an FFI input — + // this check catches anyone re-introducing a dropout path without also + // wiring a real philox state. Unconditional (not `assert`) so the guard + // remains active in release builds. + if (params.p_dropout != 1.f) { + std::fprintf(stderr, + "candle-flash-attn: dropout is not supported " + "(philox_unpack.cuh is stubbed); got p_dropout=%f\n", + params.p_dropout); + std::abort(); + } + cudaStream_t stream = 0; // Use the default stream. - run_mha_fwd(params, stream); + FLASH_NAMESPACE::run_mha_fwd(params, stream, force_split_kernel != 0); } diff --git a/candle-flash-attn/kernels/flash_fwd_hdim128_bf16_causal_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim128_bf16_causal_sm80.cu index 9383c10249..baca4777bf 100644 --- a/candle-flash-attn/kernels/flash_fwd_hdim128_bf16_causal_sm80.cu +++ b/candle-flash-attn/kernels/flash_fwd_hdim128_bf16_causal_sm80.cu @@ -1,10 +1,14 @@ // Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. 
See "generate_kernels.py" - +#include "namespace_config.h" #include "flash_fwd_launch_template.h" +namespace FLASH_NAMESPACE { + template<> void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { run_mha_fwd_hdim128(params, stream); } + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/candle-flash-attn/kernels/flash_fwd_hdim128_bf16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim128_bf16_sm80.cu index f03abda486..230059eab3 100644 --- a/candle-flash-attn/kernels/flash_fwd_hdim128_bf16_sm80.cu +++ b/candle-flash-attn/kernels/flash_fwd_hdim128_bf16_sm80.cu @@ -1,10 +1,14 @@ // Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" - +#include "namespace_config.h" #include "flash_fwd_launch_template.h" +namespace FLASH_NAMESPACE { + template<> void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { run_mha_fwd_hdim128(params, stream); } + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/candle-flash-attn/kernels/flash_fwd_hdim128_fp16_causal_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim128_fp16_causal_sm80.cu index c616628c87..ab3ce5500d 100644 --- a/candle-flash-attn/kernels/flash_fwd_hdim128_fp16_causal_sm80.cu +++ b/candle-flash-attn/kernels/flash_fwd_hdim128_fp16_causal_sm80.cu @@ -1,10 +1,14 @@ // Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. 
See "generate_kernels.py" - +#include "namespace_config.h" #include "flash_fwd_launch_template.h" +namespace FLASH_NAMESPACE { + template<> void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { run_mha_fwd_hdim128(params, stream); } + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/candle-flash-attn/kernels/flash_fwd_hdim128_fp16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim128_fp16_sm80.cu index 4ff6b9fbfb..0303285449 100644 --- a/candle-flash-attn/kernels/flash_fwd_hdim128_fp16_sm80.cu +++ b/candle-flash-attn/kernels/flash_fwd_hdim128_fp16_sm80.cu @@ -1,10 +1,14 @@ // Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" - +#include "namespace_config.h" #include "flash_fwd_launch_template.h" +namespace FLASH_NAMESPACE { + template<> void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { run_mha_fwd_hdim128(params, stream); } + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/candle-flash-attn/kernels/flash_fwd_hdim160_bf16_causal_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim160_bf16_causal_sm80.cu deleted file mode 100644 index d6d4371bfb..0000000000 --- a/candle-flash-attn/kernels/flash_fwd_hdim160_bf16_causal_sm80.cu +++ /dev/null @@ -1,10 +0,0 @@ -// Copyright (c) 2024, Tri Dao. -// Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. 
See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -template<> -void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { - run_mha_fwd_hdim160(params, stream); -} diff --git a/candle-flash-attn/kernels/flash_fwd_hdim160_bf16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim160_bf16_sm80.cu deleted file mode 100644 index 5af68ac38f..0000000000 --- a/candle-flash-attn/kernels/flash_fwd_hdim160_bf16_sm80.cu +++ /dev/null @@ -1,10 +0,0 @@ -// Copyright (c) 2024, Tri Dao. -// Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -template<> -void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { - run_mha_fwd_hdim160(params, stream); -} diff --git a/candle-flash-attn/kernels/flash_fwd_hdim192_bf16_causal_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim192_bf16_causal_sm80.cu index 077d25d091..c04b7b9e08 100644 --- a/candle-flash-attn/kernels/flash_fwd_hdim192_bf16_causal_sm80.cu +++ b/candle-flash-attn/kernels/flash_fwd_hdim192_bf16_causal_sm80.cu @@ -1,10 +1,14 @@ // Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" - +#include "namespace_config.h" #include "flash_fwd_launch_template.h" +namespace FLASH_NAMESPACE { + template<> void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { run_mha_fwd_hdim192(params, stream); } + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/candle-flash-attn/kernels/flash_fwd_hdim192_bf16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim192_bf16_sm80.cu index ea5f265fe3..72468c3829 100644 --- a/candle-flash-attn/kernels/flash_fwd_hdim192_bf16_sm80.cu +++ b/candle-flash-attn/kernels/flash_fwd_hdim192_bf16_sm80.cu @@ -1,10 +1,14 @@ // Copyright (c) 2024, Tri Dao. 
// Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" - +#include "namespace_config.h" #include "flash_fwd_launch_template.h" +namespace FLASH_NAMESPACE { + template<> void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { run_mha_fwd_hdim192(params, stream); } + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/candle-flash-attn/kernels/flash_fwd_hdim192_fp16_causal_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim192_fp16_causal_sm80.cu index a4a7bc2422..89ffeb8e1d 100644 --- a/candle-flash-attn/kernels/flash_fwd_hdim192_fp16_causal_sm80.cu +++ b/candle-flash-attn/kernels/flash_fwd_hdim192_fp16_causal_sm80.cu @@ -1,10 +1,14 @@ // Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" - +#include "namespace_config.h" #include "flash_fwd_launch_template.h" +namespace FLASH_NAMESPACE { + template<> void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { run_mha_fwd_hdim192(params, stream); } + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/candle-flash-attn/kernels/flash_fwd_hdim192_fp16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim192_fp16_sm80.cu index c30c4a14fe..729660e8e9 100644 --- a/candle-flash-attn/kernels/flash_fwd_hdim192_fp16_sm80.cu +++ b/candle-flash-attn/kernels/flash_fwd_hdim192_fp16_sm80.cu @@ -1,10 +1,14 @@ // Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. 
See "generate_kernels.py" - +#include "namespace_config.h" #include "flash_fwd_launch_template.h" +namespace FLASH_NAMESPACE { + template<> void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { run_mha_fwd_hdim192(params, stream); } + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/candle-flash-attn/kernels/flash_fwd_hdim224_bf16_causal_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim224_bf16_causal_sm80.cu deleted file mode 100644 index db69f21cdf..0000000000 --- a/candle-flash-attn/kernels/flash_fwd_hdim224_bf16_causal_sm80.cu +++ /dev/null @@ -1,10 +0,0 @@ -// Copyright (c) 2024, Tri Dao. -// Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -template<> -void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { - run_mha_fwd_hdim224(params, stream); -} diff --git a/candle-flash-attn/kernels/flash_fwd_hdim224_bf16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim224_bf16_sm80.cu deleted file mode 100644 index 9a11724b2b..0000000000 --- a/candle-flash-attn/kernels/flash_fwd_hdim224_bf16_sm80.cu +++ /dev/null @@ -1,10 +0,0 @@ -// Copyright (c) 2024, Tri Dao. -// Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -template<> -void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { - run_mha_fwd_hdim224(params, stream); -} diff --git a/candle-flash-attn/kernels/flash_fwd_hdim224_fp16_causal_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim224_fp16_causal_sm80.cu deleted file mode 100644 index d02edae078..0000000000 --- a/candle-flash-attn/kernels/flash_fwd_hdim224_fp16_causal_sm80.cu +++ /dev/null @@ -1,10 +0,0 @@ -// Copyright (c) 2024, Tri Dao. -// Splitting the different head dimensions to different files to speed up compilation. 
-// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -template<> -void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { - run_mha_fwd_hdim224(params, stream); -} diff --git a/candle-flash-attn/kernels/flash_fwd_hdim256_bf16_causal_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim256_bf16_causal_sm80.cu index f84e978c91..c200847041 100644 --- a/candle-flash-attn/kernels/flash_fwd_hdim256_bf16_causal_sm80.cu +++ b/candle-flash-attn/kernels/flash_fwd_hdim256_bf16_causal_sm80.cu @@ -1,10 +1,14 @@ // Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" - +#include "namespace_config.h" #include "flash_fwd_launch_template.h" +namespace FLASH_NAMESPACE { + template<> void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { run_mha_fwd_hdim256(params, stream); } + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/candle-flash-attn/kernels/flash_fwd_hdim256_bf16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim256_bf16_sm80.cu index c52f0417b9..2f8ee2496b 100644 --- a/candle-flash-attn/kernels/flash_fwd_hdim256_bf16_sm80.cu +++ b/candle-flash-attn/kernels/flash_fwd_hdim256_bf16_sm80.cu @@ -1,10 +1,14 @@ // Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. 
See "generate_kernels.py" - +#include "namespace_config.h" #include "flash_fwd_launch_template.h" +namespace FLASH_NAMESPACE { + template<> void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { run_mha_fwd_hdim256(params, stream); } + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/candle-flash-attn/kernels/flash_fwd_hdim256_fp16_causal_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim256_fp16_causal_sm80.cu index f96f7edc67..fe8e550835 100644 --- a/candle-flash-attn/kernels/flash_fwd_hdim256_fp16_causal_sm80.cu +++ b/candle-flash-attn/kernels/flash_fwd_hdim256_fp16_causal_sm80.cu @@ -1,10 +1,14 @@ // Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" - +#include "namespace_config.h" #include "flash_fwd_launch_template.h" +namespace FLASH_NAMESPACE { + template<> void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { run_mha_fwd_hdim256(params, stream); } + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/candle-flash-attn/kernels/flash_fwd_hdim256_fp16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim256_fp16_sm80.cu index 9c7c6b93d8..6a99c93a49 100644 --- a/candle-flash-attn/kernels/flash_fwd_hdim256_fp16_sm80.cu +++ b/candle-flash-attn/kernels/flash_fwd_hdim256_fp16_sm80.cu @@ -1,10 +1,14 @@ // Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. 
See "generate_kernels.py" - +#include "namespace_config.h" #include "flash_fwd_launch_template.h" +namespace FLASH_NAMESPACE { + template<> void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { run_mha_fwd_hdim256(params, stream); } + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/candle-flash-attn/kernels/flash_fwd_hdim32_bf16_causal_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim32_bf16_causal_sm80.cu index e21d0408ca..8cb6578f54 100644 --- a/candle-flash-attn/kernels/flash_fwd_hdim32_bf16_causal_sm80.cu +++ b/candle-flash-attn/kernels/flash_fwd_hdim32_bf16_causal_sm80.cu @@ -1,10 +1,14 @@ // Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" - +#include "namespace_config.h" #include "flash_fwd_launch_template.h" +namespace FLASH_NAMESPACE { + template<> void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { run_mha_fwd_hdim32(params, stream); } + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/candle-flash-attn/kernels/flash_fwd_hdim32_bf16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim32_bf16_sm80.cu index f377a5b8fa..ba7adb8f4f 100644 --- a/candle-flash-attn/kernels/flash_fwd_hdim32_bf16_sm80.cu +++ b/candle-flash-attn/kernels/flash_fwd_hdim32_bf16_sm80.cu @@ -1,10 +1,14 @@ // Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. 
See "generate_kernels.py" - +#include "namespace_config.h" #include "flash_fwd_launch_template.h" +namespace FLASH_NAMESPACE { + template<> void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { run_mha_fwd_hdim32(params, stream); } + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/candle-flash-attn/kernels/flash_fwd_hdim32_fp16_causal_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim32_fp16_causal_sm80.cu index 74e4d66ae9..b0cb6844d4 100644 --- a/candle-flash-attn/kernels/flash_fwd_hdim32_fp16_causal_sm80.cu +++ b/candle-flash-attn/kernels/flash_fwd_hdim32_fp16_causal_sm80.cu @@ -1,10 +1,14 @@ // Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" - +#include "namespace_config.h" #include "flash_fwd_launch_template.h" +namespace FLASH_NAMESPACE { + template<> void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { run_mha_fwd_hdim32(params, stream); } + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/candle-flash-attn/kernels/flash_fwd_hdim32_fp16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim32_fp16_sm80.cu index e85db18e39..cfbcf6c526 100644 --- a/candle-flash-attn/kernels/flash_fwd_hdim32_fp16_sm80.cu +++ b/candle-flash-attn/kernels/flash_fwd_hdim32_fp16_sm80.cu @@ -1,10 +1,14 @@ // Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. 
See "generate_kernels.py" - +#include "namespace_config.h" #include "flash_fwd_launch_template.h" +namespace FLASH_NAMESPACE { + template<> void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { run_mha_fwd_hdim32(params, stream); } + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/candle-flash-attn/kernels/flash_fwd_hdim512_bf16_causal_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim512_bf16_causal_sm80.cu deleted file mode 100644 index 0888d7eab9..0000000000 --- a/candle-flash-attn/kernels/flash_fwd_hdim512_bf16_causal_sm80.cu +++ /dev/null @@ -1,10 +0,0 @@ -// Copyright (c) 2024, Tri Dao. -// Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -template<> -void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { - run_mha_fwd_hdim512(params, stream); -} diff --git a/candle-flash-attn/kernels/flash_fwd_hdim512_bf16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim512_bf16_sm80.cu deleted file mode 100644 index 9626295e86..0000000000 --- a/candle-flash-attn/kernels/flash_fwd_hdim512_bf16_sm80.cu +++ /dev/null @@ -1,10 +0,0 @@ -// Copyright (c) 2024, Tri Dao. -// Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -template<> -void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { - run_mha_fwd_hdim512(params, stream); -} diff --git a/candle-flash-attn/kernels/flash_fwd_hdim512_fp16_causal_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim512_fp16_causal_sm80.cu deleted file mode 100644 index 1ece8f2a92..0000000000 --- a/candle-flash-attn/kernels/flash_fwd_hdim512_fp16_causal_sm80.cu +++ /dev/null @@ -1,10 +0,0 @@ -// Copyright (c) 2024, Tri Dao. -// Splitting the different head dimensions to different files to speed up compilation. 
-// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -template<> -void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { - run_mha_fwd_hdim512(params, stream); -} diff --git a/candle-flash-attn/kernels/flash_fwd_hdim64_bf16_causal_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim64_bf16_causal_sm80.cu index 9297e8bb68..7d703bb454 100644 --- a/candle-flash-attn/kernels/flash_fwd_hdim64_bf16_causal_sm80.cu +++ b/candle-flash-attn/kernels/flash_fwd_hdim64_bf16_causal_sm80.cu @@ -1,10 +1,14 @@ // Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" - +#include "namespace_config.h" #include "flash_fwd_launch_template.h" +namespace FLASH_NAMESPACE { + template<> void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { run_mha_fwd_hdim64(params, stream); } + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/candle-flash-attn/kernels/flash_fwd_hdim64_bf16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim64_bf16_sm80.cu index 8364b1e7ee..e0ab7cb59b 100644 --- a/candle-flash-attn/kernels/flash_fwd_hdim64_bf16_sm80.cu +++ b/candle-flash-attn/kernels/flash_fwd_hdim64_bf16_sm80.cu @@ -1,10 +1,14 @@ // Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. 
See "generate_kernels.py" - +#include "namespace_config.h" #include "flash_fwd_launch_template.h" +namespace FLASH_NAMESPACE { + template<> void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { run_mha_fwd_hdim64(params, stream); } + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/candle-flash-attn/kernels/flash_fwd_hdim64_fp16_causal_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim64_fp16_causal_sm80.cu index 1c6ed7ef02..d4c869f7e7 100644 --- a/candle-flash-attn/kernels/flash_fwd_hdim64_fp16_causal_sm80.cu +++ b/candle-flash-attn/kernels/flash_fwd_hdim64_fp16_causal_sm80.cu @@ -1,10 +1,14 @@ // Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" - +#include "namespace_config.h" #include "flash_fwd_launch_template.h" +namespace FLASH_NAMESPACE { + template<> void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { run_mha_fwd_hdim64(params, stream); } + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/candle-flash-attn/kernels/flash_fwd_hdim64_fp16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim64_fp16_sm80.cu index 3c87573ba2..fa2863579a 100644 --- a/candle-flash-attn/kernels/flash_fwd_hdim64_fp16_sm80.cu +++ b/candle-flash-attn/kernels/flash_fwd_hdim64_fp16_sm80.cu @@ -1,10 +1,14 @@ // Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. 
See "generate_kernels.py" - +#include "namespace_config.h" #include "flash_fwd_launch_template.h" +namespace FLASH_NAMESPACE { + template<> void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { run_mha_fwd_hdim64(params, stream); } + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/candle-flash-attn/kernels/flash_fwd_hdim96_bf16_causal_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim96_bf16_causal_sm80.cu index 49fae856a5..d7c626cd97 100644 --- a/candle-flash-attn/kernels/flash_fwd_hdim96_bf16_causal_sm80.cu +++ b/candle-flash-attn/kernels/flash_fwd_hdim96_bf16_causal_sm80.cu @@ -1,10 +1,14 @@ // Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" - +#include "namespace_config.h" #include "flash_fwd_launch_template.h" +namespace FLASH_NAMESPACE { + template<> void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { run_mha_fwd_hdim96(params, stream); } + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/candle-flash-attn/kernels/flash_fwd_hdim96_bf16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim96_bf16_sm80.cu index c5af1cf634..9b2d5f7a09 100644 --- a/candle-flash-attn/kernels/flash_fwd_hdim96_bf16_sm80.cu +++ b/candle-flash-attn/kernels/flash_fwd_hdim96_bf16_sm80.cu @@ -1,10 +1,14 @@ // Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. 
See "generate_kernels.py" - +#include "namespace_config.h" #include "flash_fwd_launch_template.h" +namespace FLASH_NAMESPACE { + template<> void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { run_mha_fwd_hdim96(params, stream); } + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/candle-flash-attn/kernels/flash_fwd_hdim96_fp16_causal_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim96_fp16_causal_sm80.cu index b0d6c9928e..9eb4fdd6c7 100644 --- a/candle-flash-attn/kernels/flash_fwd_hdim96_fp16_causal_sm80.cu +++ b/candle-flash-attn/kernels/flash_fwd_hdim96_fp16_causal_sm80.cu @@ -1,10 +1,14 @@ // Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" - +#include "namespace_config.h" #include "flash_fwd_launch_template.h" +namespace FLASH_NAMESPACE { + template<> void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { run_mha_fwd_hdim96(params, stream); } + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/candle-flash-attn/kernels/flash_fwd_hdim96_fp16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim96_fp16_sm80.cu index c97aa33f8b..0a1ea1a938 100644 --- a/candle-flash-attn/kernels/flash_fwd_hdim96_fp16_sm80.cu +++ b/candle-flash-attn/kernels/flash_fwd_hdim96_fp16_sm80.cu @@ -1,10 +1,14 @@ // Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. 
See "generate_kernels.py" - +#include "namespace_config.h" #include "flash_fwd_launch_template.h" +namespace FLASH_NAMESPACE { + template<> void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { run_mha_fwd_hdim96(params, stream); } + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/candle-flash-attn/kernels/flash_fwd_kernel.h b/candle-flash-attn/kernels/flash_fwd_kernel.h index b6b26d5207..d492c87b5c 100644 --- a/candle-flash-attn/kernels/flash_fwd_kernel.h +++ b/candle-flash-attn/kernels/flash_fwd_kernel.h @@ -4,7 +4,8 @@ #pragma once -// #include "philox_unpack.cuh" // For at::cuda::philox::unpack +#include "namespace_config.h" +#include "philox_unpack.cuh" // For at::cuda::philox::unpack #include @@ -20,7 +21,7 @@ #include "dropout.h" #include "rotary.h" -namespace flash { +namespace FLASH_NAMESPACE { using namespace cute; @@ -65,9 +66,8 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi constexpr int kHeadDim = Kernel_traits::kHeadDim; constexpr int kNWarps = Kernel_traits::kNWarps; - auto seed_offset = std::make_tuple(0ull, 0ull); - // auto seed_offset = at::cuda::philox::unpack(params.philox_args); - flash::Dropout dropout(std::get<0>(seed_offset), std::get<1>(seed_offset), params.p_dropout_in_uint8_t, + auto seed_offset = at::cuda::philox::unpack(params.philox_args); + FLASH_NAMESPACE::Dropout dropout(std::get<0>(seed_offset), std::get<1>(seed_offset), params.p_dropout_in_uint8_t, bidb, bidh, tidx, params.h); // Save seed and offset for backward, before any early exiting. 
Otherwise the 0-th thread block might @@ -116,7 +116,7 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; } } // Clear_OOB_K must be false since we don't want to write zeros to gmem - flash::copy( + FLASH_NAMESPACE::copy( gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM ); #pragma unroll @@ -247,7 +247,7 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi // Prologue // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs - flash::copy(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ, + FLASH_NAMESPACE::copy(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ, binfo.actual_seqlen_q - m_block * kBlockM); if (Kernel_traits::Is_Q_in_regs) { cute::cp_async_fence(); } @@ -256,7 +256,7 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi // // if (cute::thread0()) { print(sQNoSwizzle); } if (Kernel_traits::Share_Q_K_smem) { - flash::cp_async_wait<0>(); + FLASH_NAMESPACE::cp_async_wait<0>(); __syncthreads(); Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ); CUTE_STATIC_ASSERT_V(size<1>(tSsQ) == size<1>(tSrQ_copy_view)); // M @@ -266,14 +266,14 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi int n_block = n_block_max - 1; // We don't need to clear the sK smem tiles since we'll mask out the scores anyway. 
- flash::copy(gmem_tiled_copy_QKV, tKgK(_, _, _, n_block), tKsK, tKVcKV, tKVpKV, + FLASH_NAMESPACE::copy(gmem_tiled_copy_QKV, tKgK(_, _, _, n_block), tKsK, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN); cute::cp_async_fence(); // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z < 2) { print(tKgK); } // __syncthreads(); if (Kernel_traits::Is_Q_in_regs && !Kernel_traits::Share_Q_K_smem) { - flash::cp_async_wait<1>(); + FLASH_NAMESPACE::cp_async_wait<1>(); __syncthreads(); Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ); CUTE_STATIC_ASSERT_V(size<1>(tSsQ) == size<1>(tSrQ_copy_view)); // M @@ -282,10 +282,10 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi clear(acc_o); - flash::Softmax<2 * size<1>(acc_o)> softmax; + FLASH_NAMESPACE::Softmax<2 * size<1>(acc_o)> softmax; const float alibi_slope = !Has_alibi || params.alibi_slopes_ptr == nullptr ? 0.0f : reinterpret_cast(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax; - flash::Mask mask(binfo.actual_seqlen_k, binfo.actual_seqlen_q, params.window_size_left, params.window_size_right, alibi_slope); + FLASH_NAMESPACE::Mask mask(binfo.actual_seqlen_k, binfo.actual_seqlen_q, params.window_size_left, params.window_size_right, alibi_slope); // For performance reason, we separate out two kinds of iterations: // those that need masking on S, and those that don't. 
@@ -302,37 +302,37 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi for (int masking_step = 0; masking_step < n_masking_steps; ++masking_step, --n_block) { Tensor acc_s = partition_fragment_C(tiled_mma, Shape, Int>{}); // (MMA=4, MMA_M, MMA_N) clear(acc_s); - flash::cp_async_wait<0>(); + FLASH_NAMESPACE::cp_async_wait<0>(); __syncthreads(); // Advance gV if (masking_step > 0) { - flash::copy(gmem_tiled_copy_QKV, tVgV(_, _, _, n_block), tVsV, tKVcKV, tKVpKV); + FLASH_NAMESPACE::copy(gmem_tiled_copy_QKV, tVgV(_, _, _, n_block), tVsV, tKVcKV, tKVpKV); } else { // Clear the smem tiles to account for predicated off loads - flash::copy( + FLASH_NAMESPACE::copy( gmem_tiled_copy_QKV, tVgV(_, _, _, n_block), tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN ); } cute::cp_async_fence(); - flash::gemm( + FLASH_NAMESPACE::gemm( acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K, smem_thr_copy_Q, smem_thr_copy_K ); // if (cute::thread0()) { print(acc_s); } if constexpr (Is_softcap){ - flash::apply_softcap(acc_s, params.softcap); + FLASH_NAMESPACE::apply_softcap(acc_s, params.softcap); } mask.template apply_mask( acc_s, n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16 ); - flash::cp_async_wait<0>(); + FLASH_NAMESPACE::cp_async_wait<0>(); __syncthreads(); if (n_block > n_block_min) { - flash::copy(gmem_tiled_copy_QKV, tKgK(_, _, _, n_block - 1), tKsK, tKVcKV, tKVpKV); + FLASH_NAMESPACE::copy(gmem_tiled_copy_QKV, tKgK(_, _, _, n_block - 1), tKsK, tKVcKV, tKVpKV); // This cp_async_fence needs to be in the if block, otherwise the synchronization // isn't right and we get race conditions. 
cute::cp_async_fence(); @@ -344,7 +344,7 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi : softmax.template softmax_rescale_o(acc_s, acc_o, params.scale_softmax_log2); // Convert acc_s from fp32 to fp16/bf16 - Tensor rP = flash::convert_type(acc_s); + Tensor rP = FLASH_NAMESPACE::convert_type(acc_s); int block_row_idx = m_block * (kBlockM / 16) + tidx / 32; int block_col_idx = n_block * (kBlockN / 32); if (Return_softmax) { @@ -362,9 +362,9 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2) // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8. - Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_acc_Aregs(rP.layout())); + Tensor tOrP = make_tensor(rP.data(), FLASH_NAMESPACE::convert_layout_acc_Aregs(rP.layout())); // if (cute::thread0()) { print(tOrP); } - flash::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); + FLASH_NAMESPACE::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); // if (cute::thread0()) { print(scores); } // This check is at the end of the loop since we always have at least 1 iteration @@ -378,23 +378,23 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi for (; n_block >= n_block_min; --n_block) { Tensor acc_s = partition_fragment_C(tiled_mma, Shape, Int>{}); // (MMA=4, MMA_M, MMA_N) clear(acc_s); - flash::cp_async_wait<0>(); + FLASH_NAMESPACE::cp_async_wait<0>(); __syncthreads(); - flash::copy(gmem_tiled_copy_QKV, tVgV(_, _, _, n_block), tVsV, tKVcKV, tKVpKV); + FLASH_NAMESPACE::copy(gmem_tiled_copy_QKV, tVgV(_, _, _, n_block), tVsV, tKVcKV, tKVpKV); cute::cp_async_fence(); - flash::gemm( + FLASH_NAMESPACE::gemm( acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K, smem_thr_copy_Q, smem_thr_copy_K ); if constexpr (Is_softcap){ - flash::apply_softcap(acc_s, params.softcap); + 
FLASH_NAMESPACE::apply_softcap(acc_s, params.softcap); } - flash::cp_async_wait<0>(); + FLASH_NAMESPACE::cp_async_wait<0>(); __syncthreads(); if (n_block > n_block_min) { - flash::copy(gmem_tiled_copy_QKV, tKgK(_, _, _, n_block - 1), tKsK, tKVcKV, tKVpKV); + FLASH_NAMESPACE::copy(gmem_tiled_copy_QKV, tKgK(_, _, _, n_block - 1), tKsK, tKVcKV, tKVpKV); // This cp_async_fence needs to be in the if block, otherwise the synchronization // isn't right and we get race conditions. cute::cp_async_fence(); @@ -406,7 +406,7 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi softmax.template softmax_rescale_o(acc_s, acc_o, params.scale_softmax_log2); - Tensor rP = flash::convert_type(acc_s); + Tensor rP = FLASH_NAMESPACE::convert_type(acc_s); int block_row_idx = m_block * (kBlockM / 16) + tidx / 32; int block_col_idx = n_block * (kBlockN / 32); if (Return_softmax) { @@ -424,8 +424,8 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2) // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8. 
- Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_acc_Aregs(rP.layout())); - flash::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); + Tensor tOrP = make_tensor(rP.data(), FLASH_NAMESPACE::convert_layout_acc_Aregs(rP.layout())); + FLASH_NAMESPACE::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); } // Epilogue @@ -433,7 +433,7 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi Tensor lse = softmax.template normalize_softmax_lse(acc_o, params.scale_softmax, params.rp_dropout); // Convert acc_o from fp32 to fp16/bf16 - Tensor rO = flash::convert_type(acc_o); + Tensor rO = FLASH_NAMESPACE::convert_type(acc_o); Tensor sO = make_tensor(sQ.data(), typename Kernel_traits::SmemLayoutO{}); // (SMEM_M,SMEM_N) // Partition sO to match the accumulator partitioning auto smem_tiled_copy_O = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma); @@ -488,7 +488,7 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; } } // Clear_OOB_K must be false since we don't want to write zeros to gmem - flash::copy( + FLASH_NAMESPACE::copy( gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM ); } @@ -564,7 +564,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; } } // Clear_OOB_K must be false since we don't want to write zeros to gmem - flash::copy( + FLASH_NAMESPACE::copy( gmem_tiled_copy_Oaccum, tOrOaccum, tOgOaccum, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM ); #pragma unroll @@ -731,18 +731,18 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons auto tKgK_data = tKgK.data(); auto tVgV_data = tVgV.data(); for (int n_block = n_block_max - 1; n_block >= n_block_copy_min; 
n_block--) { - flash::copy_w_min_idx( + FLASH_NAMESPACE::copy_w_min_idx( tVgVnew, tVgV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN ); tVgVnew.data() = tVgVnew.data() + (-int(kBlockN * params.vnew_row_stride)); if (params.rotary_dim == 0) { - flash::copy_w_min_idx( + FLASH_NAMESPACE::copy_w_min_idx( tKgKnew, tKgK, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN ); } else { if (params.is_rotary_interleaved) { // Don't clear OOB_K because we're writing to global memory - flash::copy_rotary_interleaved( + FLASH_NAMESPACE::copy_rotary_interleaved( tKgKnew, tKgK, tRgCos, tRgSin, tKVcKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN, params.d, params.rotary_dim ); @@ -750,7 +750,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons tRgSin.data() = tRgSin.data() + (-int(kBlockN * params.rotary_dim / 2)); } else { // Don't clear OOB_K because we're writing to global memory - flash::copy_rotary_contiguous( + FLASH_NAMESPACE::copy_rotary_contiguous( tKgKnew, tKgK, tRgCosCont, tRgSinCont, tKVcKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN, params.d, params.rotary_dim ); @@ -785,7 +785,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons // Read Q from gmem to smem, optionally apply rotary embedding. if (!Append_KV || params.rotary_dim == 0) { // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs - flash::copy(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ, + FLASH_NAMESPACE::copy(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ, binfo.actual_seqlen_q - m_block * kBlockM); } else { const index_t row_offset_cossin = (binfo.seqlen_k_cache + (params.leftpad_k == nullptr ? 0 : params.leftpad_k[bidb]) + (Is_causal || Is_local ? 
m_block * kBlockM : 0)) * (params.rotary_dim / 2); @@ -808,12 +808,12 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons Tensor tRgCosCont = gmem_thr_copy_rotary_cont.partition_S(gCosCont); Tensor tRgSinCont = gmem_thr_copy_rotary_cont.partition_S(gSinCont); if (params.is_rotary_interleaved) { - flash::copy_rotary_interleaved( + FLASH_NAMESPACE::copy_rotary_interleaved( tQgQ, tQsQ, tRgCos, tRgSin, tQcQ, binfo.actual_seqlen_q - m_block * kBlockM, 0, params.d, params.rotary_dim ); } else { - flash::copy_rotary_contiguous( + FLASH_NAMESPACE::copy_rotary_contiguous( tQgQ, tQsQ, tRgCosCont, tRgSinCont, tQcQ, binfo.actual_seqlen_q - m_block * kBlockM, 0, params.d, params.rotary_dim ); @@ -822,21 +822,21 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons int n_block = n_block_max - 1; // We don't need to clear the sK smem tiles since we'll mask out the scores anyway. - flash::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV, + FLASH_NAMESPACE::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN); cute::cp_async_fence(); - // flash::cp_async_wait<0>(); + // FLASH_NAMESPACE::cp_async_wait<0>(); // __syncthreads(); // if (tidx == 0 && blockIdx.y == 0 && blockIdx.z == 0) { print(tKsK); } // __syncthreads(); clear(acc_o); - flash::Softmax<2 * size<1>(acc_o)> softmax; + FLASH_NAMESPACE::Softmax<2 * size<1>(acc_o)> softmax; const float alibi_slope = !Has_alibi ? 0.0f : reinterpret_cast(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax; - flash::Mask mask(binfo.actual_seqlen_k, binfo.actual_seqlen_q, params.window_size_left, params.window_size_right, alibi_slope); + FLASH_NAMESPACE::Mask mask(binfo.actual_seqlen_k, binfo.actual_seqlen_q, params.window_size_left, params.window_size_right, alibi_slope); // For performance reason, we separate out two kinds of iterations: // those that need masking on S, and those that don't. 
@@ -853,7 +853,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons for (int masking_step = 0; masking_step < n_masking_steps; ++masking_step, --n_block) { Tensor acc_s = partition_fragment_C(tiled_mma, Shape, Int>{}); // (MMA=4, MMA_M, MMA_N) clear(acc_s); - flash::cp_async_wait<0>(); + FLASH_NAMESPACE::cp_async_wait<0>(); __syncthreads(); // Advance gV @@ -867,22 +867,22 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons const int block_table_offset_next = n_block * kBlockN - block_table_idx_next * params.page_block_size; tVgV.data() = tVgV.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.v_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.v_row_stride; } - flash::copy(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV); + FLASH_NAMESPACE::copy(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV); } else { // Clear the smem tiles to account for predicated off loads - flash::copy( + FLASH_NAMESPACE::copy( gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN ); } cute::cp_async_fence(); - flash::gemm( + FLASH_NAMESPACE::gemm( acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K, smem_thr_copy_Q, smem_thr_copy_K ); // if (cute::thread0()) { print(acc_s); } if constexpr (Is_softcap){ - flash::apply_softcap(acc_s, params.softcap); + FLASH_NAMESPACE::apply_softcap(acc_s, params.softcap); } @@ -890,7 +890,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons acc_s, n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16 ); - flash::cp_async_wait<0>(); + FLASH_NAMESPACE::cp_async_wait<0>(); __syncthreads(); // if (tidx == 0 && blockIdx.y == 0 && blockIdx.z == 0) { print(tVsV); } // __syncthreads(); @@ -906,7 +906,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons const int block_table_offset_next 
=(n_block - 1) * kBlockN - block_table_idx_next * params.page_block_size; tKgK.data() = tKgK.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.k_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.k_row_stride; } - flash::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV); + FLASH_NAMESPACE::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV); // This cp_async_fence needs to be in the if block, otherwise the synchronization // isn't right and we get race conditions. cute::cp_async_fence(); @@ -919,12 +919,12 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons // if (cute::thread0()) { print(scores_max); print(scores_sum); print(scores); } // Convert acc_s from fp32 to fp16/bf16 - Tensor rP = flash::convert_type(acc_s); + Tensor rP = FLASH_NAMESPACE::convert_type(acc_s); // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2) // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8. 
- Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_acc_Aregs(rP.layout())); + Tensor tOrP = make_tensor(rP.data(), FLASH_NAMESPACE::convert_layout_acc_Aregs(rP.layout())); - flash::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); + FLASH_NAMESPACE::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); // This check is at the end of the loop since we always have at least 1 iteration if (n_masking_steps > 1 && n_block <= n_block_min) { @@ -937,7 +937,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons for (; n_block >= n_block_min; --n_block) { Tensor acc_s = partition_fragment_C(tiled_mma, Shape, Int>{}); // (MMA=4, MMA_M, MMA_N) clear(acc_s); - flash::cp_async_wait<0>(); + FLASH_NAMESPACE::cp_async_wait<0>(); __syncthreads(); // Advance gV if (block_table == nullptr) { @@ -949,18 +949,18 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons const int block_table_offset_next = n_block * kBlockN - block_table_idx_next * params.page_block_size; tVgV.data() = tVgV.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.v_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.v_row_stride; } - flash::copy(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV); + FLASH_NAMESPACE::copy(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV); cute::cp_async_fence(); - flash::gemm( + FLASH_NAMESPACE::gemm( acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K, smem_thr_copy_Q, smem_thr_copy_K ); if constexpr (Is_softcap){ - flash::apply_softcap(acc_s, params.softcap); + FLASH_NAMESPACE::apply_softcap(acc_s, params.softcap); } - flash::cp_async_wait<0>(); + FLASH_NAMESPACE::cp_async_wait<0>(); __syncthreads(); if (n_block > n_block_min) { // Advance gK @@ -973,7 +973,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons const int block_table_offset_next = 
(n_block - 1) * kBlockN - block_table_idx_next * params.page_block_size; tKgK.data() = tKgK.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.k_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.k_row_stride; } - flash::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV); + FLASH_NAMESPACE::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV); // This cp_async_fence needs to be in the if block, otherwise the synchronization // isn't right and we get race conditions. cute::cp_async_fence(); @@ -984,12 +984,12 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons ); softmax.template softmax_rescale_o(acc_s, acc_o, params.scale_softmax_log2); - Tensor rP = flash::convert_type(acc_s); + Tensor rP = FLASH_NAMESPACE::convert_type(acc_s); // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2) // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8. - Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_acc_Aregs(rP.layout())); + Tensor tOrP = make_tensor(rP.data(), FLASH_NAMESPACE::convert_layout_acc_Aregs(rP.layout())); - flash::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); + FLASH_NAMESPACE::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); } // Epilogue @@ -1006,7 +1006,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons >; auto smem_tiled_copy_Oaccum = make_tiled_copy_C(SmemTiledCopyO{}, tiled_mma); auto smem_thr_copy_Oaccum = smem_tiled_copy_Oaccum.get_thread_slice(tidx); - Tensor rO = flash::convert_type(acc_o); + Tensor rO = FLASH_NAMESPACE::convert_type(acc_o); Tensor taccOrOaccum = smem_thr_copy_Oaccum.retile_S(rO); // ((Atom,AtomNum), MMA_M, MMA_N) Tensor taccOsOaccum = smem_thr_copy_Oaccum.partition_D(sOaccum); // ((Atom,AtomNum),PIPE_M,PIPE_N) @@ -1065,7 +1065,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons for 
(int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; } } // Clear_OOB_K must be false since we don't want to write zeros to gmem - flash::copy( + FLASH_NAMESPACE::copy( gmem_tiled_copy_Oaccum, tOrOaccum, tOgOaccum, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM ); } @@ -1088,7 +1088,7 @@ inline __device__ void compute_attn(const Params ¶ms) { // the attention matrix. This way, as long as we have the batch, head, and the location of // the 16 x 32 block within the attention matrix, we can generate the exact same dropout pattern. - flash::compute_attn_1rowblock(params, bidb, bidh, m_block); + FLASH_NAMESPACE::compute_attn_1rowblock(params, bidb, bidh, m_block); } //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -1102,7 +1102,7 @@ inline __device__ void compute_attn_splitkv(const Params ¶ms) { const int bidh = Split ? blockIdx.z - bidb * params.h : blockIdx.z; const int n_split_idx = Split ? blockIdx.y : 0; const int num_n_splits = Split ? 
gridDim.y : 1; - flash::compute_attn_1rowblock_splitkv(params, bidb, bidh, m_block, n_split_idx, num_n_splits); + FLASH_NAMESPACE::compute_attn_1rowblock_splitkv(params, bidb, bidh, m_block, n_split_idx, num_n_splits); } //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -1243,7 +1243,7 @@ inline __device__ void combine_attn_seqk_parallel(const Params ¶ms) { } // Load Oaccum in then scale and accumulate to O for (int split = 0; split < params.num_splits; ++split) { - flash::copy( + FLASH_NAMESPACE::copy( gmem_tiled_copy_Oaccum, tOgOaccum, tOrOaccum, tOcOaccum, tOpOaccum, params.b * params.h * params.seqlen_q - bidx * kBlockM ); #pragma unroll @@ -1263,7 +1263,7 @@ inline __device__ void combine_attn_seqk_parallel(const Params ¶ms) { } // if (cute::thread0()) { print_tensor(tOrO); } - Tensor rO = flash::convert_type(tOrO); + Tensor rO = FLASH_NAMESPACE::convert_type(tOrO); // Write to gO #pragma unroll for (int m = 0; m < size<1>(rO); ++m) { @@ -1291,4 +1291,4 @@ inline __device__ void combine_attn_seqk_parallel(const Params ¶ms) { } } -} // namespace flash +} // namespace FLASH_NAMESPACE diff --git a/candle-flash-attn/kernels/flash_fwd_launch_template.h b/candle-flash-attn/kernels/flash_fwd_launch_template.h index f8ac4fc4e8..b275c4ef0b 100644 --- a/candle-flash-attn/kernels/flash_fwd_launch_template.h +++ b/candle-flash-attn/kernels/flash_fwd_launch_template.h @@ -3,14 +3,16 @@ ******************************************************************************/ #pragma once -// #include // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK +#include "namespace_config.h" +#include "error.h" // candle: shim providing C10_CUDA_CHECK / C10_CUDA_KERNEL_LAUNCH_CHECK without c10 -#include "error.h" #include "static_switch.h" #include "hardware_info.h" #include "flash.h" #include "flash_fwd_kernel.h" +namespace FLASH_NAMESPACE { + // Determine if the architecture supports FLASH and define a macro to handle parameter 
modifiers #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 #define ARCH_SUPPORTS_FLASH @@ -30,7 +32,7 @@ __global__ void kernelName(KERNEL_PARAM_MODIFIER const Flash_fwd_params params) DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_kernel, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Return_softmax) { #if defined(ARCH_SUPPORTS_FLASH) static_assert(!(Is_causal && Is_local)); // Enforce constraints - flash::compute_attn(params); + FLASH_NAMESPACE::compute_attn(params); #else FLASH_UNSUPPORTED_ARCH #endif @@ -38,7 +40,7 @@ DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_kernel, bool Is_dropout, bool Is_causal, b DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_splitkv_kernel, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Split, bool Append_KV) { #if defined(ARCH_SUPPORTS_FLASH) - flash::compute_attn_splitkv(params); + FLASH_NAMESPACE::compute_attn_splitkv(params); #else FLASH_UNSUPPORTED_ARCH #endif @@ -46,7 +48,7 @@ DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_splitkv_kernel, bool Is_causal, bool Is_lo DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_splitkv_combine_kernel, int kBlockM, int Log_max_splits, bool Is_even_K) { static_assert(Log_max_splits >= 1); - flash::combine_attn_seqk_parallel(params); + FLASH_NAMESPACE::combine_attn_seqk_parallel(params); } template @@ -74,7 +76,7 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { // If return_softmax, set IsEvenMNConst to false to reduce number of templates // If head dim > 128, set IsEvenMNConst to false to reduce number of templates // If Is_local, set Is_causal to false - auto kernel = &flash_fwd_kernel; + auto kernel = &flash_fwd_kernel; // auto kernel = &flash_fwd_kernel; // printf("IsEvenMNConst = %d, IsEvenKConst = %d, Is_local = %d, Is_causal = %d, ReturnSoftmaxConst = %d, Is_dropout = %d\n", int(IsEvenMNConst), int(IsEvenKConst), int(Is_local), int(Is_causal), int(ReturnSoftmaxConst), int(Is_dropout)); // 
auto kernel = &flash_fwd_kernel; @@ -115,7 +117,7 @@ void run_flash_splitkv_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { // If Append_KV, then we must have seqlen_offsets, which means cu_seqlens_k != nullptr. // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates. // If Is_local, set Is_causal to false - auto kernel = &flash_fwd_splitkv_kernel; + auto kernel = &flash_fwd_splitkv_kernel; // auto kernel = &flash_fwd_splitkv_kernel; // auto kernel = &flash_fwd_splitkv_kernel; if (smem_size >= 48 * 1024) { @@ -163,8 +165,7 @@ void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream) constexpr static int kBlockM = 64; // Fixed for all head dimensions // TD [2023-08-28]: nvcc segfaults for headdim 96 with block size 64 x 256, // and for headdim 192 with block size 64 x 128. - // Also for headdim 160 with block size 64 x 128 after the rotary addition. - constexpr static int kBlockN = Headdim <= 64 ? 256 : (Headdim <= 128 ? 128 : (Headdim <= 256 ? 64 : 32)); + constexpr static int kBlockN = Headdim <= 64 ? 256 : (Headdim <= 128 ? 128 : 64); run_flash_splitkv_fwd, Is_causal>(params, stream); } @@ -196,12 +197,6 @@ void run_mha_fwd_hdim64(Flash_fwd_params ¶ms, cudaStream_t stream) { }); } -inline bool cuda_is_sm8x() { - // dprops = at::cuda::getCurrentDeviceProperties(); - // return dprops->major == 8 && dprops->minor > 0; - return false; -} - template void run_mha_fwd_hdim96(Flash_fwd_params ¶ms, cudaStream_t stream) { constexpr static int Headdim = 96; @@ -261,34 +256,6 @@ void run_mha_fwd_hdim128(Flash_fwd_params ¶ms, cudaStream_t stream) { }); } -template -void run_mha_fwd_hdim160(Flash_fwd_params ¶ms, cudaStream_t stream) { - constexpr static int Headdim = 160; - auto [cc_major, cc_minor] = get_compute_capability(get_current_device()); - bool is_sm8x = cc_major == 8 && cc_minor > 0; - DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { - // For A100, H100, 128 x 32 is the fastest. 
- // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square), - // and 128 x 64 with 8 warps is the fastest for non-causal. - if (is_sm8x) { - if constexpr(!Is_causal) { - run_flash_fwd, Is_dropout, Is_causal>(params, stream); - } else { - run_flash_fwd, Is_dropout, Is_causal>(params, stream); - } - } else { - run_flash_fwd, Is_dropout, Is_causal>(params, stream); - } - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); - // run_flash_fwd>(params, stream); - // run_flash_fwd>(params, stream); - // run_flash_fwd>(params, stream); - // run_flash_fwd>(params, stream); - // run_flash_fwd>(params, stream); - }); -} - template void run_mha_fwd_hdim192(Flash_fwd_params ¶ms, cudaStream_t stream) { constexpr static int Headdim = 192; @@ -306,33 +273,6 @@ void run_mha_fwd_hdim192(Flash_fwd_params ¶ms, cudaStream_t stream) { }); } -template -void run_mha_fwd_hdim224(Flash_fwd_params ¶ms, cudaStream_t stream) { - constexpr static int Headdim = 224; - int device; - cudaGetDevice(&device); - int max_smem_per_block; - cudaError status_ = cudaDeviceGetAttribute( - &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device); - if (status_ != cudaSuccess) { - C10_CUDA_CHECK(status_); - } - // printf("max_smem_per_block = %d\n", max_smem_per_block); - DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { - if (max_smem_per_block >= 2 * Headdim * (128 + 2 * 64)) { // 112 KB - run_flash_fwd, Is_dropout, Is_causal>(params, stream); - } else { - run_flash_fwd, Is_dropout, Is_causal>(params, stream); - } - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); - // We can't do 128 x 32 with 8 warps because with headdim 224, kBlockKSmem = 32. - // If we have N = 32, there are only 1024 elements to load at once, where each load - // is 8 elements. This means we can only use 128 threads and not 256 threads. 
- // run_flash_fwd, Is_dropout, Is_causal>(params, stream); - }); -} - template void run_mha_fwd_hdim256(Flash_fwd_params ¶ms, cudaStream_t stream) { constexpr static int Headdim = 256; @@ -361,25 +301,4 @@ void run_mha_fwd_hdim256(Flash_fwd_params ¶ms, cudaStream_t stream) { // run_flash_fwd, Is_dropout, Is_causal>(params, stream); }); } - -template -void run_mha_fwd_hdim512(Flash_fwd_params ¶ms, cudaStream_t stream) { - constexpr static int Headdim = 512; - int device; - cudaGetDevice(&device); - int max_smem_per_block; - cudaError status_ = cudaDeviceGetAttribute( - &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device); - if (status_ != cudaSuccess) { - C10_CUDA_CHECK(status_); - } - DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { - // For A100 (164KB max smem), use 64 x 32 with 4 warps (128KB smem). - // For sm86/sm89 (100KB max smem), use 32 x 32 with 4 warps (96KB smem). - if (max_smem_per_block >= 2 * Headdim * (64 + 2 * 32)) { // 128 KB - run_flash_fwd, Is_dropout, Is_causal>(params, stream); - } else { - run_flash_fwd, Is_dropout, Is_causal>(params, stream); - } - }); -} +} // namespace FLASH_NAMESPACE diff --git a/candle-flash-attn/kernels/flash_fwd_split_hdim128_bf16_causal_sm80.cu b/candle-flash-attn/kernels/flash_fwd_split_hdim128_bf16_causal_sm80.cu new file mode 100644 index 0000000000..40559c640b --- /dev/null +++ b/candle-flash-attn/kernels/flash_fwd_split_hdim128_bf16_causal_sm80.cu @@ -0,0 +1,11 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" +#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/candle-flash-attn/kernels/flash_fwd_split_hdim128_bf16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_split_hdim128_bf16_sm80.cu new file mode 100644 index 0000000000..48500b8f13 --- /dev/null +++ b/candle-flash-attn/kernels/flash_fwd_split_hdim128_bf16_sm80.cu @@ -0,0 +1,11 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" +#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/candle-flash-attn/kernels/flash_fwd_hdim160_fp16_causal_sm80.cu b/candle-flash-attn/kernels/flash_fwd_split_hdim128_fp16_causal_sm80.cu similarity index 51% rename from candle-flash-attn/kernels/flash_fwd_hdim160_fp16_causal_sm80.cu rename to candle-flash-attn/kernels/flash_fwd_split_hdim128_fp16_causal_sm80.cu index 1ef511a6b7..355902924d 100644 --- a/candle-flash-attn/kernels/flash_fwd_hdim160_fp16_causal_sm80.cu +++ b/candle-flash-attn/kernels/flash_fwd_split_hdim128_fp16_causal_sm80.cu @@ -1,10 +1,11 @@ // Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. 
See "generate_kernels.py" - +#include "namespace_config.h" #include "flash_fwd_launch_template.h" -template<> -void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { - run_mha_fwd_hdim160(params, stream); -} +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/candle-flash-attn/kernels/flash_fwd_split_hdim128_fp16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_split_hdim128_fp16_sm80.cu new file mode 100644 index 0000000000..6aa638de82 --- /dev/null +++ b/candle-flash-attn/kernels/flash_fwd_split_hdim128_fp16_sm80.cu @@ -0,0 +1,11 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" +#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/candle-flash-attn/kernels/flash_fwd_split_hdim192_bf16_causal_sm80.cu b/candle-flash-attn/kernels/flash_fwd_split_hdim192_bf16_causal_sm80.cu new file mode 100644 index 0000000000..979deee411 --- /dev/null +++ b/candle-flash-attn/kernels/flash_fwd_split_hdim192_bf16_causal_sm80.cu @@ -0,0 +1,11 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" +#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/candle-flash-attn/kernels/flash_fwd_split_hdim192_bf16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_split_hdim192_bf16_sm80.cu new file mode 100644 index 0000000000..236365e4ff --- /dev/null +++ b/candle-flash-attn/kernels/flash_fwd_split_hdim192_bf16_sm80.cu @@ -0,0 +1,11 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" +#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/candle-flash-attn/kernels/flash_fwd_split_hdim192_fp16_causal_sm80.cu b/candle-flash-attn/kernels/flash_fwd_split_hdim192_fp16_causal_sm80.cu new file mode 100644 index 0000000000..9c4420fa81 --- /dev/null +++ b/candle-flash-attn/kernels/flash_fwd_split_hdim192_fp16_causal_sm80.cu @@ -0,0 +1,11 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" +#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/candle-flash-attn/kernels/flash_fwd_split_hdim192_fp16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_split_hdim192_fp16_sm80.cu new file mode 100644 index 0000000000..872f5ced87 --- /dev/null +++ b/candle-flash-attn/kernels/flash_fwd_split_hdim192_fp16_sm80.cu @@ -0,0 +1,11 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" +#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/candle-flash-attn/kernels/flash_fwd_split_hdim256_bf16_causal_sm80.cu b/candle-flash-attn/kernels/flash_fwd_split_hdim256_bf16_causal_sm80.cu new file mode 100644 index 0000000000..8fee9f57bd --- /dev/null +++ b/candle-flash-attn/kernels/flash_fwd_split_hdim256_bf16_causal_sm80.cu @@ -0,0 +1,11 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" +#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/candle-flash-attn/kernels/flash_fwd_split_hdim256_bf16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_split_hdim256_bf16_sm80.cu new file mode 100644 index 0000000000..6adcb1bf2f --- /dev/null +++ b/candle-flash-attn/kernels/flash_fwd_split_hdim256_bf16_sm80.cu @@ -0,0 +1,11 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" +#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/candle-flash-attn/kernels/flash_fwd_split_hdim256_fp16_causal_sm80.cu b/candle-flash-attn/kernels/flash_fwd_split_hdim256_fp16_causal_sm80.cu new file mode 100644 index 0000000000..df05869f7a --- /dev/null +++ b/candle-flash-attn/kernels/flash_fwd_split_hdim256_fp16_causal_sm80.cu @@ -0,0 +1,11 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" +#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/candle-flash-attn/kernels/flash_fwd_split_hdim256_fp16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_split_hdim256_fp16_sm80.cu new file mode 100644 index 0000000000..51bd8e4d7a --- /dev/null +++ b/candle-flash-attn/kernels/flash_fwd_split_hdim256_fp16_sm80.cu @@ -0,0 +1,11 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" +#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/candle-flash-attn/kernels/flash_fwd_split_hdim32_bf16_causal_sm80.cu b/candle-flash-attn/kernels/flash_fwd_split_hdim32_bf16_causal_sm80.cu new file mode 100644 index 0000000000..fa340d6f06 --- /dev/null +++ b/candle-flash-attn/kernels/flash_fwd_split_hdim32_bf16_causal_sm80.cu @@ -0,0 +1,11 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" +#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/candle-flash-attn/kernels/flash_fwd_split_hdim32_bf16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_split_hdim32_bf16_sm80.cu new file mode 100644 index 0000000000..0f2adec7a2 --- /dev/null +++ b/candle-flash-attn/kernels/flash_fwd_split_hdim32_bf16_sm80.cu @@ -0,0 +1,11 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" +#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/candle-flash-attn/kernels/flash_fwd_hdim224_fp16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_split_hdim32_fp16_causal_sm80.cu similarity index 51% rename from candle-flash-attn/kernels/flash_fwd_hdim224_fp16_sm80.cu rename to candle-flash-attn/kernels/flash_fwd_split_hdim32_fp16_causal_sm80.cu index 28150ed0ad..345551033c 100644 --- a/candle-flash-attn/kernels/flash_fwd_hdim224_fp16_sm80.cu +++ b/candle-flash-attn/kernels/flash_fwd_split_hdim32_fp16_causal_sm80.cu @@ -1,10 +1,11 @@ // Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. 
See "generate_kernels.py" - +#include "namespace_config.h" #include "flash_fwd_launch_template.h" -template<> -void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { - run_mha_fwd_hdim224(params, stream); -} +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/candle-flash-attn/kernels/flash_fwd_split_hdim32_fp16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_split_hdim32_fp16_sm80.cu new file mode 100644 index 0000000000..ec9523de08 --- /dev/null +++ b/candle-flash-attn/kernels/flash_fwd_split_hdim32_fp16_sm80.cu @@ -0,0 +1,11 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" +#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/candle-flash-attn/kernels/flash_fwd_split_hdim64_bf16_causal_sm80.cu b/candle-flash-attn/kernels/flash_fwd_split_hdim64_bf16_causal_sm80.cu new file mode 100644 index 0000000000..750c69fcce --- /dev/null +++ b/candle-flash-attn/kernels/flash_fwd_split_hdim64_bf16_causal_sm80.cu @@ -0,0 +1,11 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" +#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/candle-flash-attn/kernels/flash_fwd_split_hdim64_bf16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_split_hdim64_bf16_sm80.cu new file mode 100644 index 0000000000..a1b26d84f4 --- /dev/null +++ b/candle-flash-attn/kernels/flash_fwd_split_hdim64_bf16_sm80.cu @@ -0,0 +1,11 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" +#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/candle-flash-attn/kernels/flash_fwd_hdim512_fp16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_split_hdim64_fp16_causal_sm80.cu similarity index 51% rename from candle-flash-attn/kernels/flash_fwd_hdim512_fp16_sm80.cu rename to candle-flash-attn/kernels/flash_fwd_split_hdim64_fp16_causal_sm80.cu index f328be93e2..3061167100 100644 --- a/candle-flash-attn/kernels/flash_fwd_hdim512_fp16_sm80.cu +++ b/candle-flash-attn/kernels/flash_fwd_split_hdim64_fp16_causal_sm80.cu @@ -1,10 +1,11 @@ // Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. 
See "generate_kernels.py" - +#include "namespace_config.h" #include "flash_fwd_launch_template.h" -template<> -void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { - run_mha_fwd_hdim512(params, stream); -} +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/candle-flash-attn/kernels/flash_fwd_split_hdim64_fp16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_split_hdim64_fp16_sm80.cu new file mode 100644 index 0000000000..aeda6bfdd2 --- /dev/null +++ b/candle-flash-attn/kernels/flash_fwd_split_hdim64_fp16_sm80.cu @@ -0,0 +1,11 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" +#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/candle-flash-attn/kernels/flash_fwd_split_hdim96_bf16_causal_sm80.cu b/candle-flash-attn/kernels/flash_fwd_split_hdim96_bf16_causal_sm80.cu new file mode 100644 index 0000000000..d55eb40391 --- /dev/null +++ b/candle-flash-attn/kernels/flash_fwd_split_hdim96_bf16_causal_sm80.cu @@ -0,0 +1,11 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" +#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/candle-flash-attn/kernels/flash_fwd_split_hdim96_bf16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_split_hdim96_bf16_sm80.cu new file mode 100644 index 0000000000..a139c0743a --- /dev/null +++ b/candle-flash-attn/kernels/flash_fwd_split_hdim96_bf16_sm80.cu @@ -0,0 +1,11 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" +#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/candle-flash-attn/kernels/flash_fwd_hdim160_fp16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_split_hdim96_fp16_causal_sm80.cu similarity index 51% rename from candle-flash-attn/kernels/flash_fwd_hdim160_fp16_sm80.cu rename to candle-flash-attn/kernels/flash_fwd_split_hdim96_fp16_causal_sm80.cu index 96abfbd8a1..8e66343237 100644 --- a/candle-flash-attn/kernels/flash_fwd_hdim160_fp16_sm80.cu +++ b/candle-flash-attn/kernels/flash_fwd_split_hdim96_fp16_causal_sm80.cu @@ -1,10 +1,11 @@ // Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. 
See "generate_kernels.py" - +#include "namespace_config.h" #include "flash_fwd_launch_template.h" -template<> -void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { - run_mha_fwd_hdim160(params, stream); -} +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/candle-flash-attn/kernels/flash_fwd_split_hdim96_fp16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_split_hdim96_fp16_sm80.cu new file mode 100644 index 0000000000..2a874bf607 --- /dev/null +++ b/candle-flash-attn/kernels/flash_fwd_split_hdim96_fp16_sm80.cu @@ -0,0 +1,11 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" +#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/candle-flash-attn/kernels/hardware_info.h b/candle-flash-attn/kernels/hardware_info.h index d5c48d3517..a4643a5d8a 100644 --- a/candle-flash-attn/kernels/hardware_info.h +++ b/candle-flash-attn/kernels/hardware_info.h @@ -4,8 +4,9 @@ #pragma once -#include #include +#include +#include #if !defined(__CUDACC_RTC__) #include "cuda_runtime.h" diff --git a/candle-flash-attn/kernels/kernel_traits.h b/candle-flash-attn/kernels/kernel_traits.h index 8db1dfcd04..8c0897488d 100644 --- a/candle-flash-attn/kernels/kernel_traits.h +++ b/candle-flash-attn/kernels/kernel_traits.h @@ -158,7 +158,7 @@ struct Flash_fwd_kernel_traits : public Base { Layout>{})); // Val layout, 8 vals per load }; -// Is_V_in_regs is an option to reduce smem usage, but will increase register pressure. +// Is_V_in_regs is an option to reduce smem usage, but will increase register pressue. 
// No_double_buffer is another option to reduce smem usage, but will slow things down. template -namespace flash { +namespace FLASH_NAMESPACE { using namespace cute; @@ -137,8 +138,8 @@ struct Mask { // if (cute::thread0()) { printf("Has_alibi = %d, Causal_mask=%d, Is_local=%d, Is_even_MN = %d, Need_masking = %d\n", Has_alibi, Causal_mask, Is_local, Is_even_MN, Need_masking); } if constexpr (Need_masking) { // Reshape tensor_ from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) - Tensor tensor = make_tensor(tensor_.data(), flash::convert_layout_acc_rowcol(tensor_.layout())); - // Do we need both row and column indices, or just column indices? + Tensor tensor = make_tensor(tensor_.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(tensor_.layout())); + // Do we need both row and column indices, or just column indices? static constexpr bool Col_idx_only = !(Has_alibi && !Is_causal) && !Is_local && !Causal_mask; const int lane_id = threadIdx.x % 32; const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2; @@ -210,4 +211,4 @@ struct Mask { }; -} // namespace flash +} // namespace FLASH_NAMESPACE diff --git a/candle-flash-attn/kernels/namespace_config.h b/candle-flash-attn/kernels/namespace_config.h new file mode 100644 index 0000000000..a6fad57b15 --- /dev/null +++ b/candle-flash-attn/kernels/namespace_config.h @@ -0,0 +1,67 @@ +/** + * @file flash_namespace_config.h + * @brief Configuration file for Flash namespace management and isolation + * + * This header provides configuration macros for managing the Flash namespace + * across a codebase. It allows for flexible namespace naming and provides + * utilities for namespace declaration and scoping. + * + * Usage Examples: + * + * 1. Basic namespace wrapping: + * @code + * BEGIN_FLASH_NAMESPACE + * class FlashDevice { + * // Implementation + * }; + * END_FLASH_NAMESPACE + * @endcode + * + * 2.
Accessing types within the namespace: + * @code + * FLASH_NAMESPACE_ALIAS(FlashDevice) device; + * @endcode + * + * 3. Defining content within namespace scope: + * @code + * FLASH_NAMESPACE_SCOPE( + * struct Configuration { + * uint32_t size; + * bool enabled; + * }; + * ) + * @endcode + * + * 4. Custom namespace name: + * @code + * #define FLASH_NAMESPACE custom_flash + * #include "flash_namespace_config.h" + * @endcode + * + * Configuration: + * - The default namespace is 'flash' if FLASH_NAMESPACE is not defined + * - Define FLASH_NAMESPACE before including this header to customize the + * namespace name + * + * Best Practices: + * - Include this header in all files that need access to the Flash namespace + * + */ +#pragma once + +#ifndef FLASH_NAMESPACE_CONFIG_H +#define FLASH_NAMESPACE_CONFIG_H + +// Set default namespace to flash +#ifndef FLASH_NAMESPACE +#define FLASH_NAMESPACE flash +#endif + +#define FLASH_NAMESPACE_ALIAS(name) FLASH_NAMESPACE::name + +#define FLASH_NAMESPACE_SCOPE(content) \ + namespace FLASH_NAMESPACE { \ + content \ + } + +#endif // FLASH_NAMESPACE_CONFIG_H diff --git a/candle-flash-attn/kernels/philox.cuh b/candle-flash-attn/kernels/philox.cuh index cd7e4d2fae..5205f4542e 100644 --- a/candle-flash-attn/kernels/philox.cuh +++ b/candle-flash-attn/kernels/philox.cuh @@ -2,7 +2,9 @@ #pragma once // Philox CUDA. -namespace flash { +#include "namespace_config.h" + +namespace FLASH_NAMESPACE { struct ull2 { unsigned long long x; @@ -48,4 +50,4 @@ __forceinline__ __device__ uint4 philox(unsigned long long seed, return output; } -} // namespace flash +} // namespace FLASH_NAMESPACE diff --git a/candle-flash-attn/kernels/philox_unpack.cuh b/candle-flash-attn/kernels/philox_unpack.cuh new file mode 100644 index 0000000000..204438e5d2 --- /dev/null +++ b/candle-flash-attn/kernels/philox_unpack.cuh @@ -0,0 +1,32 @@ +// candle: stub replacement for the original PyTorch-dependent header. 
+// +// Upstream Tri Dao's `philox_unpack.cuh` includes PyTorch's ATen Philox header +// to provide +// `at::cuda::philox::unpack(at::PhiloxCudaState)` for the dropout +// path in `flash_fwd_kernel.h`. candle-flash-attn doesn't link +// against PyTorch, and inference-only callers don't exercise the +// dropout codepath, so we stub `unpack()` to return a dummy +// (seed, offset) pair. The result is computed-but-unused by the +// kernel when `Is_dropout=false` (the caller path through +// `LOCAL_SWITCH(... Is_dropout && !Is_softcap, ...)`); the +// ostensible Dropout object built from these dummy values is +// dead code under that compile-time branch. +// +// `at::PhiloxCudaState` is also stubbed as an empty struct so +// `flash.h`'s `Flash_fwd_params::philox_args` field can stay in +// the layout — matching the field count Tri Dao expects without +// dragging PyTorch in. + +#pragma once + +#include <cstdint> +#include <tuple> + +namespace at { + struct PhiloxCudaState {}; + namespace cuda { namespace philox { + inline __host__ __device__ std::tuple<uint64_t, uint64_t> unpack(PhiloxCudaState const&) { + return {0ull, 0ull}; + } + }} +} diff --git a/candle-flash-attn/kernels/rotary.h b/candle-flash-attn/kernels/rotary.h index 7f1614ad24..dbae24c626 100644 --- a/candle-flash-attn/kernels/rotary.h +++ b/candle-flash-attn/kernels/rotary.h @@ -6,11 +6,12 @@ #include +#include "namespace_config.h" #include "utils.h" //////////////////////////////////////////////////////////////////////////////////////////////////// -namespace flash { +namespace FLASH_NAMESPACE { using namespace cute; @@ -149,4 +150,4 @@ __forceinline__ __device__ void copy_rotary_contiguous(Tensor //////////////////////////////////////////////////////////////////////////////////////////////////// -} // namespace flash +} // namespace FLASH_NAMESPACE diff --git a/candle-flash-attn/kernels/softmax.h b/candle-flash-attn/kernels/softmax.h index ebf1b09798..01589adedb 100644 --- a/candle-flash-attn/kernels/softmax.h +++ b/candle-flash-attn/kernels/softmax.h @@ -10,10 +10,11 @@
#include +#include "namespace_config.h" #include "philox.cuh" #include "utils.h" -namespace flash { +namespace FLASH_NAMESPACE { using namespace cute; @@ -135,18 +136,18 @@ struct Softmax { template __forceinline__ __device__ void softmax_rescale_o(Tensor0 &acc_s, Tensor1 &acc_o, float softmax_scale_log2) { // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) - Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout())); + Tensor scores = make_tensor(acc_s.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(acc_s.layout())); static_assert(decltype(size<0>(scores))::value == kNRows); if (Is_first) { - flash::template reduce_max(scores, row_max); - flash::scale_apply_exp2(scores, row_max, softmax_scale_log2); - flash::reduce_sum(scores, row_sum); + FLASH_NAMESPACE::template reduce_max(scores, row_max); + FLASH_NAMESPACE::scale_apply_exp2(scores, row_max, softmax_scale_log2); + FLASH_NAMESPACE::reduce_sum(scores, row_sum); } else { Tensor scores_max_prev = make_fragment_like(row_max); cute::copy(row_max, scores_max_prev); - flash::template reduce_max(scores, row_max); + FLASH_NAMESPACE::template reduce_max(scores, row_max); // Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K)) - Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout())); + Tensor acc_o_rowcol = make_tensor(acc_o.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(acc_o.layout())); static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows); #pragma unroll for (int mi = 0; mi < size(row_max); ++mi) { @@ -158,10 +159,10 @@ struct Softmax { #pragma unroll for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scores_scale; } } - flash::scale_apply_exp2(scores, row_max, softmax_scale_log2); + FLASH_NAMESPACE::scale_apply_exp2(scores, row_max, softmax_scale_log2); // We don't do the reduce across threads here since we don't need to use the row_sum. 
// We do that reduce at the end when we need to normalize the softmax. - flash::reduce_sum(scores, row_sum); + FLASH_NAMESPACE::reduce_sum(scores, row_sum); } }; @@ -170,7 +171,7 @@ struct Softmax { SumOp sum_op; quad_allreduce_(row_sum, row_sum, sum_op); TensorT lse = make_fragment_like(row_sum); - Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout())); + Tensor acc_o_rowcol = make_tensor(acc_o.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(acc_o.layout())); static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows); #pragma unroll for (int mi = 0; mi < size<0>(acc_o_rowcol); ++mi) { @@ -185,4 +186,4 @@ struct Softmax { }; }; -} // namespace flash +} // namespace FLASH_NAMESPACE diff --git a/candle-flash-attn/kernels/static_switch.h b/candle-flash-attn/kernels/static_switch.h index 4c88f0aea0..70d14daf69 100644 --- a/candle-flash-attn/kernels/static_switch.h +++ b/candle-flash-attn/kernels/static_switch.h @@ -101,20 +101,11 @@ } else if (HEADDIM <= 128) { \ constexpr static int kHeadDim = 128; \ return __VA_ARGS__(); \ - } else if (HEADDIM <= 160) { \ - constexpr static int kHeadDim = 160; \ - return __VA_ARGS__(); \ } else if (HEADDIM <= 192) { \ constexpr static int kHeadDim = 192; \ return __VA_ARGS__(); \ - } else if (HEADDIM <= 224) { \ - constexpr static int kHeadDim = 224; \ - return __VA_ARGS__(); \ } else if (HEADDIM <= 256) { \ constexpr static int kHeadDim = 256; \ return __VA_ARGS__(); \ - } else if (HEADDIM <= 512) { \ - constexpr static int kHeadDim = 512; \ - return __VA_ARGS__(); \ } \ }() diff --git a/candle-flash-attn/kernels/utils.h b/candle-flash-attn/kernels/utils.h index b7408ec444..a7729aede5 100644 --- a/candle-flash-attn/kernels/utils.h +++ b/candle-flash-attn/kernels/utils.h @@ -21,9 +21,11 @@ #include #include +#include "namespace_config.h" + //////////////////////////////////////////////////////////////////////////////////////////////////// -namespace flash { +namespace FLASH_NAMESPACE 
{ //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -268,8 +270,8 @@ __forceinline__ __device__ auto convert_type_relu(Tensor const & } Tensor out = make_tensor(make_rmem_ptr(out_uint32.data()), tensor.layout()); #else - Tensor out = flash::convert_type(tensor); - flash::relu_(out); + Tensor out = FLASH_NAMESPACE::convert_type(tensor); + FLASH_NAMESPACE::relu_(out); #endif return out; } @@ -408,4 +410,4 @@ __forceinline__ __device__ void calculate_dtanh(Tensor &src_te //////////////////////////////////////////////////////////////////////////////////////////////////// -} // namespace flash +} // namespace FLASH_NAMESPACE diff --git a/candle-flash-attn/src/ffi.rs b/candle-flash-attn/src/ffi.rs index 78d3a98677..379ba21a36 100644 --- a/candle-flash-attn/src/ffi.rs +++ b/candle-flash-attn/src/ffi.rs @@ -48,6 +48,17 @@ extern "C" { window_size_right: c_int, softcap: f32, + + // PR-FA-2: split-KV dispatch surface. With `num_splits <= 1` and + // `force_split_kernel == 0` the dispatcher takes the dense path + // (existing behavior). The accumulator buffers (`softmax_lseaccum_ptr`, + // `oaccum_ptr`) must be fp32 and live on the same device — only + // dereferenced when `num_splits > 1` or `force_split_kernel != 0`. + // Allocation is the caller's responsibility (lib.rs / PR-FA-3). + num_splits: c_int, + softmax_lseaccum_ptr: *const c_void, + oaccum_ptr: *const c_void, + force_split_kernel: c_int, ); } diff --git a/candle-flash-attn/src/lib.rs b/candle-flash-attn/src/lib.rs index ce419460e1..c794564776 100644 --- a/candle-flash-attn/src/lib.rs +++ b/candle-flash-attn/src/lib.rs @@ -5,6 +5,150 @@ use candle::cuda_backend::cudarc::driver::DevicePtr; use candle::{CpuStorage, DType, Layout, Result, Shape, Tensor}; use half::{bf16, f16}; +/// Splitkv tile size along the K dimension. +/// +/// This MUST match the `kBlockN` baked into `run_mha_fwd_splitkv_dispatch<>` +/// in `kernels/flash_fwd_launch_template.h:168`. 
If upstream changes the +/// per-hdim block size, this table needs the same change or `num_splits` +/// will be miscomputed. +/// +/// `#[doc(hidden)] pub` so integration tests can predict which split count +/// the dispatcher will pick for a given shape; not part of the supported API. +#[doc(hidden)] +pub fn splitkv_block_n(head_size: usize) -> usize { + if head_size <= 64 { + 256 + } else if head_size <= 128 { + 128 + } else { + 64 + } +} + +/// Port of upstream `flash_api.cpp::num_splits_heuristic` (Tri Dao FA v2.8.3, +/// commit 060c918). Picks the smallest split count whose efficiency reaches +/// ≥85% of the peak, capped at `max_splits`. +/// +/// Inputs `batch_nheads_mblocks` and `num_sms` should already be doubled if +/// the splitkv kernel uses 128 threads/block (as upstream does). See the +/// `num_sm * 2` factor in `set_params_splitkv`. +/// +/// `#[doc(hidden)] pub` so integration tests can verify the dispatcher +/// actually entered the splitkv path for a given shape (i.e. catch silent +/// fallback to dense if the heuristic regresses); not part of the supported +/// API and may change without notice. +#[doc(hidden)] +pub fn num_splits_heuristic( + batch_nheads_mblocks: usize, + num_sms: usize, + num_n_blocks: usize, + max_splits: usize, +) -> usize { + // If we already nearly fill the SMs, splitkv adds overhead without help. + // Equivalent to upstream's `>= 0.8 * num_SMs` but in integer arithmetic. + if batch_nheads_mblocks.saturating_mul(5) >= num_sms.saturating_mul(4) { + return 1; + } + let max_splits = max_splits.min(num_sms).min(num_n_blocks); + if max_splits == 0 { + return 1; + } + let ceildiv = |a: usize, b: usize| (a + b - 1) / b; + // Eligibility check: skip split counts that produce the same per-split + // block layout as a smaller count (e.g. 12 splits across 64 blocks + // collapses to 11). Upstream lambda from flash_api.cpp:275-277. 
+    let is_eligible = |n: usize| -> bool {
+        n == 1 || ceildiv(num_n_blocks, n) != ceildiv(num_n_blocks, n - 1)
+    };
+    let mut max_eff = 0.0f32;
+    let mut effs = Vec::with_capacity(max_splits);
+    for n in 1..=max_splits {
+        if !is_eligible(n) {
+            effs.push(0.0);
+            continue;
+        }
+        let n_waves = (batch_nheads_mblocks * n) as f32 / num_sms as f32;
+        let eff = n_waves / n_waves.ceil();
+        if eff > max_eff {
+            max_eff = eff;
+        }
+        effs.push(eff);
+    }
+    for n in 1..=max_splits {
+        if !is_eligible(n) {
+            continue;
+        }
+        if effs[n - 1] >= 0.85 * max_eff {
+            return n;
+        }
+    }
+    1
+}
+
+/// Rust port of upstream `flash_api.cpp::set_params_splitkv` (Tri Dao FA
+/// v2.8.3). Decides whether to enter the splitkv dispatch path and, if so,
+/// allocates the two fp32 accumulator buffers the kernel writes into.
+///
+/// Returns `(num_splits, accumulators)`. When `num_splits == 1` the dense
+/// path is taken and `accumulators` is `None`. The returned `CudaSlice`s
+/// must outlive the `ffi::run_mha` call (their `device_ptr` guards are
+/// taken at the FFI call site).
+///
+/// Accumulator layouts match `flash_fwd_splitkv_combine_kernel`:
+///   `softmax_lse_accum`: (num_splits, batch_size, num_heads, max_seqlen_q)
+///   `out_accum`: (num_splits, batch_size, num_heads, max_seqlen_q, head_size_rounded)
+fn set_params_splitkv(
+    dev: &candle::CudaDevice,
+    batch_size: usize,
+    num_heads: usize,
+    head_size: usize,
+    head_size_rounded: usize,
+    max_seqlen_q: usize,
+    max_seqlen_k: usize,
+) -> Result<(
+    i32,
+    Option<(
+        candle::cuda_backend::cudarc::driver::CudaSlice<f32>,
+        candle::cuda_backend::cudarc::driver::CudaSlice<f32>,
+    )>,
+)> {
+    let block_n = splitkv_block_n(head_size);
+    let num_n_blocks = (max_seqlen_k + block_n - 1) / block_n;
+    // Splitkv dispatcher fixes kBlockM = 64 (flash_fwd_launch_template.h:165).
+    let num_m_blocks = (max_seqlen_q + 64 - 1) / 64;
+
+    // Upstream multiplies num_SMs by 2 because the splitkv kernel uses 128
+    // threads per block (so each SM can in principle host two CTAs).
+    let num_sm = dev
+        .cuda_stream()
+        .context()
+        .attribute(
+            candle::cuda_backend::cudarc::driver::sys::CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT,
+        )
+        .map_err(|e| {
+            candle::Error::Msg(format!("cuDeviceGetAttribute(MULTIPROCESSOR_COUNT): {e}"))
+        })? as usize;
+
+    let num_splits = num_splits_heuristic(
+        batch_size * num_heads * num_m_blocks,
+        num_sm * 2,
+        num_n_blocks,
+        128,
+    );
+    if num_splits > 128 {
+        candle::bail!("flash-attn splitkv: num_splits > 128 not supported (got {num_splits})");
+    }
+    if num_splits <= 1 {
+        return Ok((num_splits as i32, None));
+    }
+
+    let lse_n = num_splits * batch_size * num_heads * max_seqlen_q;
+    let out_n = lse_n * head_size_rounded;
+    let lse_accum = unsafe { dev.alloc::<f32>(lse_n)? };
+    let out_accum = unsafe { dev.alloc::<f32>(out_n)? };
+    Ok((num_splits as i32, Some((lse_accum, out_accum))))
+}
+
 pub struct FlashAttn {
     pub softmax_scale: f32,
     pub alibi_slopes: Option<Tensor>,
@@ -30,7 +174,7 @@ impl FlashAttn {
         v_l: &Layout,
         is_bf16: bool,
     ) -> Result<(candle::CudaStorage, Shape)> {
-        // https://github.com/Dao-AILab/flash-attention/blob/b252072409e69c25f2b9d473cc534e49b24decd2/csrc/flash_attn/flash_api.cpp#L187
+        // https://github.com/Dao-AILab/flash-attention/blob/060c9188beec3a8b62b33a3bfa6d5d2d44975fab/csrc/flash_attn/flash_api.cpp (v2.8.3, vendored 2026-05-06 per HF#3515)
         let dev = q.device();
         let out_shape = q_l.shape().clone();
         let out_l = Layout::contiguous(&out_shape);
@@ -144,6 +288,21 @@ impl FlashAttn {
         let dst = unsafe { dev.alloc::<T>(elem_count)? };
         let softmax_lse = dev.alloc_zeros::<f32>(b_sz * 128 * num_heads * seqlen_q)?;
 
+        // PR-FA-3: decide between dense (`num_splits == 1`) and splitkv
+        // (`num_splits > 1`) dispatch, allocating accumulator buffers when
+        // splitkv is selected.
For shapes that nearly fill the SMs (large + // batch × heads × m_blocks vs num_SMs) the heuristic returns 1 and + // behavior matches PR-FA-2. + let (num_splits, splitkv_buffers) = set_params_splitkv( + dev, + b_sz, + num_heads, + head_size, + head_size_rounded, + seqlen_q, + seqlen_k, + )?; + let is_bf16 = if is_bf16 { 1 } else { 0 }; // Causal is the special case where window_size_right == 0 and window_size_left < 0. @@ -166,6 +325,28 @@ impl FlashAttn { let (v_ptr, _guard) = v.device_ptr(&stream); let (dst_ptr, _guard) = dst.device_ptr(&stream); let (softmax_lse_ptr, _guard) = softmax_lse.device_ptr(&stream); + // Splitkv accumulator pointers + guards. Bound here so the guards + // live for the FFI call; null when num_splits == 1. + let lseaccum_ptr; + let oaccum_ptr; + let _splitkv_lse_guard; + let _splitkv_out_guard; + match &splitkv_buffers { + Some((lse_acc, out_acc)) => { + let (l, lg) = lse_acc.device_ptr(&stream); + let (o, og) = out_acc.device_ptr(&stream); + lseaccum_ptr = l as *const core::ffi::c_void; + oaccum_ptr = o as *const core::ffi::c_void; + _splitkv_lse_guard = Some(lg); + _splitkv_out_guard = Some(og); + } + None => { + lseaccum_ptr = std::ptr::null(); + oaccum_ptr = std::ptr::null(); + _splitkv_lse_guard = None; + _splitkv_out_guard = None; + } + } ffi::run_mha( q_ptr as *const core::ffi::c_void, k_ptr as *const core::ffi::c_void, @@ -204,6 +385,15 @@ impl FlashAttn { /* window_size_left */ window_size_left, /* window_size_right */ window_size_right, /* softcap */ self.softcap.unwrap_or(0f32), + // PR-FA-3 splitkv params. The accumulator pointers and their + // device_ptr guards are bound at the top of this unsafe block + // (see _splitkv_*_guard). When splitkv_buffers is None, + // num_splits == 1 and the kernel never reads through these + // null pointers (dense path). 
+            /* num_splits */ num_splits,
+            /* softmax_lseaccum_ptr */ lseaccum_ptr,
+            /* oaccum_ptr */ oaccum_ptr,
+            /* force_split_kernel */ 0,
         )
     }
 
@@ -462,7 +652,7 @@ impl FlashAttnVarLen {
         v_l: &Layout,
         is_bf16: bool,
     ) -> Result<(candle::CudaStorage, Shape)> {
-        // https://github.com/Dao-AILab/flash-attention/blob/184b992dcb2a0890adaa19eb9b541c3e4f9d2a08/csrc/flash_attn/flash_api.cpp#L327
+        // https://github.com/Dao-AILab/flash-attention/blob/060c9188beec3a8b62b33a3bfa6d5d2d44975fab/csrc/flash_attn/flash_api.cpp (v2.8.3, vendored 2026-05-06 per HF#3515)
         let dev = q.device();
         let out_shape = q_l.shape().clone();
         let out_l = Layout::contiguous(&out_shape);
@@ -607,6 +797,20 @@ impl FlashAttnVarLen {
         let dst = unsafe { dev.alloc::<T>(elem_count)? };
         let softmax_lse = dev.alloc_zeros::<f32>(num_heads * total_q)?;
 
+        // PR-FA-3: splitkv decision + accumulator allocation. Mirrors the
+        // dense path; uses `max_seqlen_q` / `max_seqlen_k` for the heuristic
+        // input since varlen kernels still tile based on the per-batch
+        // maximum.
+        let (num_splits, splitkv_buffers) = set_params_splitkv(
+            dev,
+            batch_size,
+            num_heads,
+            head_size,
+            head_size_rounded,
+            self.max_seqlen_q,
+            self.max_seqlen_k,
+        )?;
+
         let is_bf16 = if is_bf16 { 1 } else { 0 };
 
         // Causal is the special case where window_size_right == 0 and window_size_left < 0.
@@ -631,6 +835,26 @@ impl FlashAttnVarLen { let (softmax_lse_ptr, _guard) = softmax_lse.device_ptr(&stream); let (seqlens_q_ptr, _guard) = seqlens_q.device_ptr(&stream); let (seqlens_k_ptr, _guard) = seqlens_k.device_ptr(&stream); + let lseaccum_ptr; + let oaccum_ptr; + let _splitkv_lse_guard; + let _splitkv_out_guard; + match &splitkv_buffers { + Some((lse_acc, out_acc)) => { + let (l, lg) = lse_acc.device_ptr(&stream); + let (o, og) = out_acc.device_ptr(&stream); + lseaccum_ptr = l as *const core::ffi::c_void; + oaccum_ptr = o as *const core::ffi::c_void; + _splitkv_lse_guard = Some(lg); + _splitkv_out_guard = Some(og); + } + None => { + lseaccum_ptr = std::ptr::null(); + oaccum_ptr = std::ptr::null(); + _splitkv_lse_guard = None; + _splitkv_out_guard = None; + } + } ffi::run_mha( q_ptr as *const core::ffi::c_void, k_ptr as *const core::ffi::c_void, @@ -669,6 +893,11 @@ impl FlashAttnVarLen { /* window_size_left */ window_size_left, /* window_size_right */ window_size_right, /* softcap */ self.softcap.unwrap_or(0.0), + // PR-FA-3 splitkv params (see splitkv_buffers above). 
+ /* num_splits */ num_splits, + /* softmax_lseaccum_ptr */ lseaccum_ptr, + /* oaccum_ptr */ oaccum_ptr, + /* force_split_kernel */ 0, ) } diff --git a/candle-flash-attn/tests/flash_attn_tests.rs b/candle-flash-attn/tests/flash_attn_tests.rs index e305861146..8ea7cbeae2 100644 --- a/candle-flash-attn/tests/flash_attn_tests.rs +++ b/candle-flash-attn/tests/flash_attn_tests.rs @@ -141,6 +141,90 @@ fn flash_attn_acausal_softcap() -> Result<()> { Ok(()) } +#[test] +fn flash_attn_acausal_splitkv() -> Result<()> { + // Shape designed to enter the splitkv dispatch path on any modern CUDA + // GPU (sm80+) per the heuristic ported from upstream FA v2.8.3 in + // PR-FA-3: batch * heads * m_blocks = 1 * 2 * 1 = 2 leaves headroom + // under the 0.8 * num_SMs short-circuit on A6000 / 4090 / A100, and + // seqlen_k = 512 with head_dim = 64 produces num_n_blocks = 2 (split + // tile is 256 in K for hdim <= 64) — so the heuristic returns >= 2 + // and the kernel takes the splitkv path. If the splitkv combine + // kernel is wrong, this test diverges from the fp32 attention + // reference. + let device = Device::new_cuda(0)?; + let (b, h, sq, sk, d) = (1usize, 2, 8, 512, 64); + let scale = 1.0f32 / (d as f32).sqrt(); + + // Provenance check: assert the dispatcher actually picks splitkv for this + // shape on this device, so the test fails (instead of silently passing + // via the dense path) if the heuristic ever regresses. Mirrors the + // computation done inside `set_params_splitkv` in src/lib.rs. + { + let cuda_dev = device.as_cuda_device()?; + let num_sm = cuda_dev + .cuda_stream() + .context() + .attribute( + candle::cuda_backend::cudarc::driver::sys::CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT, + ) + .map_err(|e| anyhow::anyhow!("cuDeviceGetAttribute(MULTIPROCESSOR_COUNT): {e}"))? 
+            as usize;
+        let block_n = candle_flash_attn::splitkv_block_n(d);
+        let num_n_blocks = (sk + block_n - 1) / block_n;
+        let num_m_blocks = (sq + 64 - 1) / 64;
+        let num_splits = candle_flash_attn::num_splits_heuristic(
+            b * h * num_m_blocks,
+            num_sm * 2,
+            num_n_blocks,
+            128,
+        );
+        assert!(
+            num_splits > 1,
+            "expected splitkv path for shape (b={b}, h={h}, sq={sq}, sk={sk}, d={d}) on a {num_sm}-SM device, but heuristic chose num_splits={num_splits}",
+        );
+    }
+
+    // Flash-attn input layout is (batch, seq, heads, head_dim).
+    let q = (Tensor::arange(0u32, (b * sq * h * d) as u32, &device)?
+        .to_dtype(DType::F16)?
+        .reshape((b, sq, h, d))?
+        / 1024.)?;
+    let k = (Tensor::arange(0u32, (b * sk * h * d) as u32, &device)?
+        .to_dtype(DType::F16)?
+        .reshape((b, sk, h, d))?
+        / 4096.)?;
+    let v = (Tensor::arange(0u32, (b * sk * h * d) as u32, &device)?
+        .to_dtype(DType::F16)?
+        .reshape((b, sk, h, d))?
+        / 8192.)?;
+
+    // Reference attention: collapse (batch, heads) into a single batch axis
+    // so `fa_acausal`'s rank-3 matmul matches per-head, then unflatten.
+    let ys_ref = {
+        let qref = q.transpose(1, 2)?.contiguous()?.reshape((b * h, sq, d))?;
+        let kref = k.transpose(1, 2)?.contiguous()?.reshape((b * h, sk, d))?;
+        let vref = v.transpose(1, 2)?.contiguous()?.reshape((b * h, sk, d))?;
+        fa_acausal(&qref, &kref, &vref, scale)?
+            .reshape((b, h, sq, d))?
+            .transpose(1, 2)?
+            .contiguous()?
+    };
+
+    let ys = candle_flash_attn::flash_attn(&q, &k, &v, scale, false)?;
+
+    let ys = ys.to_dtype(DType::F32)?;
+    let ys_ref = ys_ref.to_dtype(DType::F32)?;
+    let diff = ys.sub(&ys_ref)?.abs()?.flatten_all()?.max(0)?;
+    assert_eq!(ys.dims(), &[b, sq, h, d]);
+    let diff_v = diff.to_vec0::<f32>()?;
+    assert!(
+        diff_v < 5e-3,
+        "splitkv vs fa_acausal max abs diff = {diff_v} (expected < 5e-3)"
+    );
+    Ok(())
+}
+
 #[test]
 fn flash_attn_varlen() -> Result<()> {
     let device = Device::new_cuda(0)?;