diff --git a/third_party/torch_npu_ops b/third_party/torch_npu_ops index 599eb0334..eaaddb96f 160000 --- a/third_party/torch_npu_ops +++ b/third_party/torch_npu_ops @@ -1 +1 @@ -Subproject commit 599eb033413ec249e0d614796f0bcfedc5191253 +Subproject commit eaaddb96f036d90125ac01850ba7580e1e908a9e diff --git a/third_party/xllm_ops b/third_party/xllm_ops index 96a590903..d3a0acf6d 160000 --- a/third_party/xllm_ops +++ b/third_party/xllm_ops @@ -1 +1 @@ -Subproject commit 96a590903bd8e4a73131a59b25240d11634cb0ed +Subproject commit d3a0acf6dc56efce92a1ecc98723a9259d9a68a6 diff --git a/xllm/core/kernels/npu/CMakeLists.txt b/xllm/core/kernels/npu/CMakeLists.txt index c37e0ca1d..00c6d4201 100644 --- a/xllm/core/kernels/npu/CMakeLists.txt +++ b/xllm/core/kernels/npu/CMakeLists.txt @@ -20,6 +20,7 @@ cc_library( npu_moe_init_routing_v2.cpp npu_moe_token_unpermute.cpp rope.cpp + npu_recurrent_gated_delta_rule.cpp DEPS :torch_npu_kernels :tilelang_kernels diff --git a/xllm/core/kernels/npu/npu_ops_api.h b/xllm/core/kernels/npu/npu_ops_api.h index 6d85ca0df..c9156c0bf 100644 --- a/xllm/core/kernels/npu/npu_ops_api.h +++ b/xllm/core/kernels/npu/npu_ops_api.h @@ -141,4 +141,16 @@ std::pair apply_npu_partial_rotary_embedding( const torch::Tensor& cos_sin_cache, bool is_neox_style); +torch::Tensor npu_recurrent_gated_delta_rule( + const torch::Tensor& query, + const torch::Tensor& key, + const torch::Tensor& value, + torch::Tensor& state, + const std::optional& beta, + const std::optional scale, + const std::optional& actual_seq_lengths, + const std::optional& ssm_state_indices, + const std::optional& num_accepted_tokens, + const std::optional& g, + const std::optional& gk); } // namespace xllm::kernel::npu diff --git a/xllm/core/kernels/npu/npu_recurrent_gated_delta_rule.cpp b/xllm/core/kernels/npu/npu_recurrent_gated_delta_rule.cpp new file mode 100644 index 000000000..66562c635 --- /dev/null +++ b/xllm/core/kernels/npu/npu_recurrent_gated_delta_rule.cpp @@ -0,0 +1,165 @@ +/* 
Copyright 2026 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include +#include +#include + +#include +#ifdef TORCH_HIGHER_THAN_PTA6 +#include +#else +#include +#include +#endif + +#include "acl/acl.h" +#include "aclnn_recurrent_gated_delta_rule.h" +#include "core/common/macros.h" +#include "core/kernels/npu/utils.h" +#include "npu_ops_api.h" + +namespace xllm::kernel::npu { + +torch::Tensor npu_recurrent_gated_delta_rule( + const torch::Tensor& query, + const torch::Tensor& key, + const torch::Tensor& value, + torch::Tensor& state, + const std::optional& beta, + const std::optional scale, + const std::optional& actual_seq_lengths, + const std::optional& ssm_state_indices, + const std::optional& num_accepted_tokens, + const std::optional& g, + const std::optional& gk) { + check_tensor(query, "query", "recurrent_gated_delta_rule"); + check_tensor(key, "key", "recurrent_gated_delta_rule"); + check_tensor(value, "value", "recurrent_gated_delta_rule"); + check_tensor(state, "state", "recurrent_gated_delta_rule"); + + aclTensor* query_ids = nullptr; + aclTensor* key_ids = nullptr; + aclTensor* value_ids = nullptr; + aclTensor* state_ids = nullptr; + aclTensor* beta_ids = nullptr; + aclTensor* actual_seq_lengths_ids = nullptr; + aclTensor* ssm_state_indices_ids = nullptr; + aclTensor* num_accepted_tokens_ids = nullptr; + aclTensor* g_ids 
= nullptr; + aclTensor* gk_ids = nullptr; + aclTensor* out_ids = nullptr; + + int32_t device_id = query.device().index(); + aclrtStream stream = c10_npu::getCurrentNPUStream(device_id).stream(); + + create_acltensor(&query_ids, query); + create_acltensor(&key_ids, key); + create_acltensor(&value_ids, value); + create_acltensor(&state_ids, state); + + if (beta.has_value() && beta.value().defined()) { + create_acltensor(&beta_ids, beta.value()); + } + if (actual_seq_lengths.has_value() && actual_seq_lengths.value().defined()) { + create_acltensor(&actual_seq_lengths_ids, actual_seq_lengths.value()); + } + if (ssm_state_indices.has_value() && ssm_state_indices.value().defined()) { + create_acltensor(&ssm_state_indices_ids, ssm_state_indices.value()); + } + if (num_accepted_tokens.has_value() && + num_accepted_tokens.value().defined()) { + create_acltensor(&num_accepted_tokens_ids, num_accepted_tokens.value()); + } + if (g.has_value() && g.value().defined()) { + create_acltensor(&g_ids, g.value()); + } + if (gk.has_value() && gk.value().defined()) { + create_acltensor(&gk_ids, gk.value()); + } + + at::Tensor out_result = at::empty_like(value); + create_acltensor(&out_ids, out_result); + + float scale_value = static_cast<float>(scale.value()); + + uint64_t workspace_size = 0; + aclOpExecutor* executor = nullptr; + + CHECK_ACL_SUCCESS( + aclnnRecurrentGatedDeltaRuleGetWorkspaceSize(query_ids, + key_ids, + value_ids, + beta_ids, + state_ids, + actual_seq_lengths_ids, + ssm_state_indices_ids, + g_ids, + gk_ids, + num_accepted_tokens_ids, + scale_value, + out_ids, + &workspace_size, + &executor), + "recurrent_gated_delta_rule: failed to get workspace size"); + + void* workspace_addr = nullptr; + if (workspace_size > 0) { + CHECK_ACL_SUCCESS( + aclrtMalloc(&workspace_addr, workspace_size, ACL_MEM_MALLOC_HUGE_FIRST), + "recurrent_gated_delta_rule: failed to allocate workspace"); + } + + CHECK_ACL_SUCCESS(aclnnRecurrentGatedDeltaRule( + workspace_addr, workspace_size, executor, 
stream), + "recurrent_gated_delta_rule: failed to perform recurrent " + "gated delta rule"); + + aclDestroyTensor(query_ids); + aclDestroyTensor(key_ids); + aclDestroyTensor(value_ids); + aclDestroyTensor(state_ids); + aclDestroyTensor(out_ids); + + if (beta_ids != nullptr) { + aclDestroyTensor(beta_ids); + } + if (actual_seq_lengths_ids != nullptr) { + aclDestroyTensor(actual_seq_lengths_ids); + } + if (ssm_state_indices_ids != nullptr) { + aclDestroyTensor(ssm_state_indices_ids); + } + if (num_accepted_tokens_ids != nullptr) { + aclDestroyTensor(num_accepted_tokens_ids); + } + if (g_ids != nullptr) { + aclDestroyTensor(g_ids); + } + if (gk_ids != nullptr) { + aclDestroyTensor(gk_ids); + } + + if (workspace_size > 0) { + CHECK_ACL_SUCCESS(aclrtFree(workspace_addr), + "recurrent_gated_delta_rule: failed to free workspace"); + } + + return out_result; +} + +} // namespace xllm::kernel::npu \ No newline at end of file diff --git a/xllm/core/kernels/ops_api.cpp b/xllm/core/kernels/ops_api.cpp index 008082cbf..788a488cc 100644 --- a/xllm/core/kernels/ops_api.cpp +++ b/xllm/core/kernels/ops_api.cpp @@ -756,6 +756,14 @@ void fused_indexer_k(FusedIndexerKParams& params) { #endif } +torch::Tensor l2_norm(torch::Tensor& x, double eps) { +#if defined(USE_NPU) + return npu::npu_l2norm_last_dim(x, eps); +#else + NOT_IMPLEMENTED(); +#endif +} + std::tuple moe_init_routing_v2(MoeInitRoutingV2Params& params) { #if defined(USE_NPU) @@ -1022,4 +1030,51 @@ torch::Tensor build_split_qkv_rmsnorm_mrope_gather_pattern( #endif } +std::pair chunk_gated_delta_rule( + ChunkGatedDeltaRuleParams& params) { +#if defined(USE_NPU) + return npu::npu_chunk_gated_delta_rule(params.q, + params.k, + params.v, + params.g, + params.beta, + params.scale, + params.initial_state, + params.output_final_state, + params.cu_seqlens, + params.head_first, + params.use_qk_l2norm_in_kernel); +#else + NOT_IMPLEMENTED(); +#endif +} + +torch::Tensor recurrent_gated_delta_rule( + const torch::Tensor& query, + const 
torch::Tensor& key, + const torch::Tensor& value, + torch::Tensor& state, + const std::optional<torch::Tensor>& beta, + const std::optional<double> scale, + const std::optional<torch::Tensor>& actual_seq_lengths, + const std::optional<torch::Tensor>& ssm_state_indices, + const std::optional<torch::Tensor>& num_accepted_tokens, + const std::optional<torch::Tensor>& g, + const std::optional<torch::Tensor>& gk) { +#if defined(USE_NPU) + return npu::npu_recurrent_gated_delta_rule(query, + key, + value, + state, + beta, + scale, + actual_seq_lengths, + ssm_state_indices, + num_accepted_tokens, + g, + gk); +#else + NOT_IMPLEMENTED(); +#endif +} } // namespace xllm::kernel diff --git a/xllm/core/kernels/ops_api.h b/xllm/core/kernels/ops_api.h index dcb28f27a..58c8fa161 100644 --- a/xllm/core/kernels/ops_api.h +++ b/xllm/core/kernels/ops_api.h @@ -94,6 +94,9 @@ void fused_indexer_q(FusedIndexerQParams& params); void fused_indexer_k(FusedIndexerKParams& params); +// L2 normalization along the last dimension +torch::Tensor l2_norm(torch::Tensor& x, double eps = 1e-6); + // TODO: NPU moe_init_routing_v2 is equivalent to moe_gen_idx + moe_expand_input // (and token_count/cusum outputs) on other backends. 
std::tuple @@ -152,4 +155,19 @@ torch::Tensor build_split_qkv_rmsnorm_mrope_gather_pattern( bool is_interleaved, const torch::Device& device); +std::pair chunk_gated_delta_rule( + ChunkGatedDeltaRuleParams& params); + +torch::Tensor recurrent_gated_delta_rule( + const torch::Tensor& query, + const torch::Tensor& key, + const torch::Tensor& value, + torch::Tensor& state, + const std::optional& beta, + const std::optional scale, + const std::optional& actual_seq_lengths, + const std::optional& ssm_state_indices, + const std::optional& num_accepted_tokens, + const std::optional& g, + const std::optional& gk); } // namespace xllm::kernel diff --git a/xllm/core/kernels/param.h b/xllm/core/kernels/param.h index 486d4f574..9c96c8373 100644 --- a/xllm/core/kernels/param.h +++ b/xllm/core/kernels/param.h @@ -1411,4 +1411,31 @@ struct SplitQkvRmsnormMropeParams { int64_t head_size; }; +struct ChunkGatedDeltaRuleParams { + // Query tensor. Shape: [B, T, Hqk, K]. Dtype: bfloat16. + torch::Tensor q; + // Key tensor. Shape: [B, T, Hqk, K]. Dtype: bfloat16. + torch::Tensor k; + // Value tensor. Shape: [B, T, H, V]. Dtype: bfloat16. + torch::Tensor v; + // Gating tensor. Shape: [B, T, H]. Dtype: float32 or bfloat16. + torch::Tensor g; + // Beta tensor. Shape: [B, T, H]. Dtype: float32 or bfloat16. + torch::Tensor beta; + // Optional scale factor for attention. Default: K^(-0.5). + std::optional scale = std::nullopt; + // Optional initial state tensor. Shape: [N, H, K, V]. Dtype: bfloat16. + std::optional initial_state = std::nullopt; + // Whether to output the final state. + bool output_final_state = false; + // Chunk size for processing. Default: 64. + int64_t chunk_size = 64; + // Optional cumulative sequence lengths. Shape: [num_sequences + 1]. Dtype: + // int32. + std::optional cu_seqlens = std::nullopt; + // Whether input is head-first format. Default: false (batch-first). + bool head_first = false; + // Whether to apply L2 norm to q and k inside the kernel. Default: false. 
+ bool use_qk_l2norm_in_kernel = false; +}; } // namespace xllm::kernel diff --git a/xllm/core/layers/npu_torch/qwen3_gated_delta_net_base.cpp b/xllm/core/layers/npu_torch/qwen3_gated_delta_net_base.cpp index c0c4fd699..cd2873778 100644 --- a/xllm/core/layers/npu_torch/qwen3_gated_delta_net_base.cpp +++ b/xllm/core/layers/npu_torch/qwen3_gated_delta_net_base.cpp @@ -422,29 +422,56 @@ torch::Tensor Qwen3GatedDeltaNetBaseImpl::forward( gdn_params.beta = 1.0f; gdn_params.threshold = 20.0f; std::tie(g, beta) = xllm::kernel::fused_gdn_gating(gdn_params); - g = g.permute({1, 0, 2}).contiguous(); - beta = beta.permute({1, 0, 2}).contiguous(); } auto [processed_q, processed_k, processed_v] = process_mixed_qkv(mixed_qkv); - int64_t repeat_times = num_v_heads_ / num_k_heads_; - if (repeat_times > 1) { - processed_q = processed_q.repeat_interleave(repeat_times, 2); - processed_k = processed_k.repeat_interleave(repeat_times, 2); - } // Apply chunked or recurrent gated-delta attention and update caches. 
if (attn_metadata.is_prefill) { + xllm::kernel::ChunkGatedDeltaRuleParams chunk_gated_delta_params; + chunk_gated_delta_params.q = processed_q; + chunk_gated_delta_params.k = processed_k; + chunk_gated_delta_params.v = processed_v; + chunk_gated_delta_params.g = g; + chunk_gated_delta_params.beta = beta; + // Get initial state from ssm_cache for sequences with previous state + // Shape: [batch_size, num_heads, head_k_dim, head_v_dim] + torch::Tensor initial_state_tensor = + torch::index_select(ssm_cache, 0, linear_state_indices); + // Todo: chunked-prefill/prefix-cache use initial_state + initial_state_tensor.fill_(0.0); + chunk_gated_delta_params.initial_state = initial_state_tensor; + chunk_gated_delta_params.output_final_state = true; + chunk_gated_delta_params.cu_seqlens = attn_metadata.q_cu_seq_lens; + chunk_gated_delta_params.head_first = false; + chunk_gated_delta_params.use_qk_l2norm_in_kernel = true; std::tie(core_attn_out, last_recurrent_state) = - torch_chunk_gated_delta_rule( - processed_q, processed_k, processed_v, g, beta); - ssm_cache.index_put_({linear_state_indices}, - last_recurrent_state.to(ssm_cache.dtype())); + xllm::kernel::chunk_gated_delta_rule(chunk_gated_delta_params); + ssm_cache.index_put_( + {linear_state_indices}, + last_recurrent_state.transpose(-1, -2).to(ssm_cache.dtype())); } else { - auto ssm_state = torch::index_select(ssm_cache, 0, linear_state_indices); - std::tie(core_attn_out, last_recurrent_state) = - torch_recurrent_gated_delta_rule( - processed_q, processed_k, processed_v, g, beta, ssm_state); - ssm_cache.index_put_({linear_state_indices}, - last_recurrent_state.to(ssm_cache.dtype())); + processed_q = xllm::kernel::l2_norm(processed_q, 1e-6); + processed_k = xllm::kernel::l2_norm(processed_k, 1e-6); + auto zero = torch::zeros({1}, attn_metadata.q_seq_lens.options()); + torch::Tensor actual_seq_lengths = + torch::cat({zero, attn_metadata.q_seq_lens}, 0); + double scale = 1.0 / std::sqrt(static_cast(processed_q.size(-1))); + 
core_attn_out = xllm::kernel::recurrent_gated_delta_rule( + processed_q.reshape( + {-1, processed_q.size(-2), processed_q.size(-1)}), + processed_k.reshape( + {-1, processed_k.size(-2), processed_k.size(-1)}), + processed_v.reshape( + {-1, processed_v.size(-2), processed_v.size(-1)}), + ssm_cache, + beta.squeeze(0).contiguous(), + scale, + actual_seq_lengths, + linear_state_indices, + c10::nullopt, + g.squeeze(0).contiguous(), + c10::nullopt) + .unsqueeze(0) + .contiguous(); } auto z_reshaped = z.view({-1, z.size(-1)});