From 05a5e8645a6f64289d7f4da98695c153ad01aa6c Mon Sep 17 00:00:00 2001
From: yingxudeng
Date: Fri, 17 Apr 2026 01:58:17 +0800
Subject: [PATCH] perf: use fused gdn gating for qwen3.5 prefill

---
 .../ascend/kernels/fused_gdn_gating.py        | 10 +++++++++-
 .../npu/tilelang/fused_gdn_gating_wrapper.cpp |  5 +++--
 .../fused_gdn_gating_wrapper_test.cpp         |  6 ++++++
 .../npu_torch/qwen3_gated_delta_net_base.cpp  | 19 ++++++++++---------
 4 files changed, 28 insertions(+), 12 deletions(-)

diff --git a/xllm/compiler/tilelang/targets/ascend/kernels/fused_gdn_gating.py b/xllm/compiler/tilelang/targets/ascend/kernels/fused_gdn_gating.py
index 59e0d45c7..1ca56c5ee 100644
--- a/xllm/compiler/tilelang/targets/ascend/kernels/fused_gdn_gating.py
+++ b/xllm/compiler/tilelang/targets/ascend/kernels/fused_gdn_gating.py
@@ -26,7 +26,15 @@
 VECTOR_BYTES_PER_ITER = 256
 SUPPORTED_NUM_HEADS = (4, 6, 8, 12, 16, 24, 32, 48, 64, 128)
 MAX_VEC_CORE_NUM = detect_vec_core_num()
-BATCH_SIZE_SPECIALIZATIONS = tuple(range(2, 49, 2))
+
+
+def _build_batch_size_specializations() -> tuple[int, ...]:
+    small_dense = range(2, 49, 2)
+    large_sparse = range(64, DEFAULT_MAX_BATCH + 1, 64)
+    return tuple(sorted(set((*small_dense, *large_sparse))))
+
+
+BATCH_SIZE_SPECIALIZATIONS = _build_batch_size_specializations()
 
 
 def select_launch_block_num(*, num_batches: int, vec_core_num: int) -> int:
diff --git a/xllm/core/kernels/npu/tilelang/fused_gdn_gating_wrapper.cpp b/xllm/core/kernels/npu/tilelang/fused_gdn_gating_wrapper.cpp
index 07331ab97..3ad74bf68 100644
--- a/xllm/core/kernels/npu/tilelang/fused_gdn_gating_wrapper.cpp
+++ b/xllm/core/kernels/npu/tilelang/fused_gdn_gating_wrapper.cpp
@@ -193,8 +193,9 @@ void run_tilelang_fused_gdn_gating_chunk(const torch::Tensor& A_log,
   auto specialization = build_runtime_specialization(a);
   const auto* entry = find_fused_gdn_gating_kernel_entry(specialization);
 
-  // Expected fast path: compiled batch_size variants are dense [2, 4, ..., 48].
-  // If a value is missing, fall back to the nearest smaller batch_size.
+  // Small batch-size variants are dense [2, 4, ..., 48]. Larger long-prefill
+  // variants are compiled sparsely, and dispatch falls back to the nearest
+  // smaller batch_size when an exact variant is not available.
   if (entry == nullptr) {
     int32_t fallback_batch_size =
         specialization.batch_size - kBatchSpecializationStep;
diff --git a/xllm/core/kernels/npu/tilelang/fused_gdn_gating_wrapper_test.cpp b/xllm/core/kernels/npu/tilelang/fused_gdn_gating_wrapper_test.cpp
index f91d06d7a..971323b8d 100644
--- a/xllm/core/kernels/npu/tilelang/fused_gdn_gating_wrapper_test.cpp
+++ b/xllm/core/kernels/npu/tilelang/fused_gdn_gating_wrapper_test.cpp
@@ -144,6 +144,12 @@ TEST_F(TileLangFusedGdnGatingWrapperTest, MatchesTorchReference) {
         .num_heads = 32,
         .seed = 106,
     },
+    {
+        .name = "long_prefill_b3533_h32",
+        .num_batches = 3533,
+        .num_heads = 32,
+        .seed = 108,
+    },
     {
         .name = "custom_beta2_threshold0p5_b33_h64",
         .num_batches = 33,
diff --git a/xllm/core/layers/npu_torch/qwen3_gated_delta_net_base.cpp b/xllm/core/layers/npu_torch/qwen3_gated_delta_net_base.cpp
index cb877a272..694af0b81 100644
--- a/xllm/core/layers/npu_torch/qwen3_gated_delta_net_base.cpp
+++ b/xllm/core/layers/npu_torch/qwen3_gated_delta_net_base.cpp
@@ -393,15 +393,16 @@ torch::Tensor Qwen3GatedDeltaNetBaseImpl::forward(
 
   // Compute gated delta net decay and beta terms.
   if (attn_metadata.is_prefill) {
-    beta = torch::sigmoid(b);
-    torch::Tensor A_log_exp = A_log_.exp();
-    torch::Tensor a_float = a.to(torch::kFloat32);
-    torch::Tensor a_plus_dt = a_float + dt_bias_;
-    torch::Tensor softplus_out = torch::nn::functional::softplus(
-        a_plus_dt,
-        torch::nn::functional::SoftplusFuncOptions().beta(1.0).threshold(20.0));
-    g = -A_log_exp * softplus_out;
-    g = g.to(a.dtype()).contiguous();
+    xllm::kernel::FusedGdnGatingParams gdn_params;
+    gdn_params.A_log = A_log_;
+    gdn_params.a = a.contiguous().view({-1, a.size(-1)});
+    gdn_params.b = b.contiguous().view({-1, b.size(-1)});
+    gdn_params.dt_bias = dt_bias_;
+    gdn_params.beta = 1.0f;
+    gdn_params.threshold = 20.0f;
+    std::tie(g, beta) = xllm::kernel::fused_gdn_gating(gdn_params);
+    g = g.squeeze(0).contiguous().view({batch_size, seq_len, a.size(-1)});
+    beta = beta.squeeze(0).contiguous().view({batch_size, seq_len, b.size(-1)});
   } else {
     xllm::kernel::FusedGdnGatingParams gdn_params;
     gdn_params.A_log = A_log_;
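
Note on dispatch (sketch, not part of the patch): the new specialization table
keeps the dense [2, 4, ..., 48] variants and adds sparse multiples of 64 up to
DEFAULT_MAX_BATCH, so long-prefill batch sizes resolve to the nearest smaller
compiled variant. The sketch below illustrates that rule with a bisect lookup;
the wrapper itself steps down by kBatchSpecializationStep rather than
bisecting, and both DEFAULT_MAX_BATCH = 4096 and resolve_batch_specialization
are assumptions for illustration, not xllm definitions.

    import bisect

    DEFAULT_MAX_BATCH = 4096  # assumed value; defined elsewhere in fused_gdn_gating.py

    def _build_batch_size_specializations() -> tuple[int, ...]:
        small_dense = range(2, 49, 2)
        large_sparse = range(64, DEFAULT_MAX_BATCH + 1, 64)
        return tuple(sorted(set((*small_dense, *large_sparse))))

    BATCH_SIZE_SPECIALIZATIONS = _build_batch_size_specializations()

    def resolve_batch_specialization(num_batches: int) -> int:
        # Illustrative helper: exact hit if compiled, otherwise the nearest
        # smaller compiled batch_size, per the wrapper comment above.
        idx = bisect.bisect_right(BATCH_SIZE_SPECIALIZATIONS, num_batches)
        if idx == 0:
            raise ValueError(f"no compiled variant for batch size {num_batches}")
        return BATCH_SIZE_SPECIALIZATIONS[idx - 1]

    # The long-prefill test case above: 3533 resolves to 3520 (55 * 64).
    assert resolve_batch_specialization(3533) == 3520
    assert resolve_batch_specialization(48) == 48  # exact dense hit
    assert resolve_batch_specialization(50) == 48  # gap between 48 and 64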