Skip to content

Commit 71ed346

Browse files
committed
feat(npu): update chunk_gated_delta_rule and recurrent_gated_delta_rule operations on NPU
The main changes include: 1. Add the implementation file npu_recurrent_gated_delta_rule.cpp; 2. Add function declarations in npu_ops_api.h; 3. Add a generic interface in ops_api.h and ops_api.cpp; 4. Update CMakeLists.txt to include the new source file; 5. Integrate the new operations in qwen3_gated_delta_net_base.cpp; 6. Update the submodule versions.
1 parent 681f9c7 commit 71ed346

9 files changed

Lines changed: 326 additions & 20 deletions

File tree

third_party/torch_npu_ops

Submodule torch_npu_ops updated from 599eb03 to eaaddb9

third_party/xllm_ops

Submodule xllm_ops updated from 96a5909 to d3a0acf

xllm/core/kernels/npu/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ cc_library(
2020
npu_moe_init_routing_v2.cpp
2121
npu_moe_token_unpermute.cpp
2222
rope.cpp
23+
npu_recurrent_gated_delta_rule.cpp
2324
DEPS
2425
:torch_npu_kernels
2526
:tilelang_kernels

xllm/core/kernels/npu/npu_ops_api.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,4 +141,16 @@ std::pair<torch::Tensor, torch::Tensor> apply_npu_partial_rotary_embedding(
141141
const torch::Tensor& cos_sin_cache,
142142
bool is_neox_style);
143143

144+
// Single recurrent (decode-style) step of the gated delta rule on NPU.
// `query`/`key`/`value` are the per-step projections; the result tensor has
// the same shape/dtype as `value`. `state` is passed non-const — presumably
// the recurrent SSM state updated in place by the kernel (confirm against the
// aclnn kernel contract). The optional tensors (beta, gates g/gk, sequence
// metadata) are forwarded to the kernel only when provided and defined.
// NOTE(review): `scale` is optional here — verify the implementation handles
// an absent scale.
torch::Tensor npu_recurrent_gated_delta_rule(
    const torch::Tensor& query,
    const torch::Tensor& key,
    const torch::Tensor& value,
    torch::Tensor& state,
    const std::optional<torch::Tensor>& beta,
    const std::optional<double> scale,
    const std::optional<torch::Tensor>& actual_seq_lengths,
    const std::optional<torch::Tensor>& ssm_state_indices,
    const std::optional<torch::Tensor>& num_accepted_tokens,
    const std::optional<torch::Tensor>& g,
    const std::optional<torch::Tensor>& gk);
144156
} // namespace xllm::kernel::npu
Lines changed: 165 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,165 @@
1+
/* Copyright 2026 The xLLM Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
https://github.com/jd-opensource/xllm/blob/main/LICENSE
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#include <c10/core/Device.h>
#include <glog/logging.h>
#include <torch/torch.h>
#include <torch_npu/csrc/libs/init_npu.h>
#include <torch_npu/torch_npu.h>

#include <cmath>

#include <nlohmann/json.hpp>
#ifdef TORCH_HIGHER_THAN_PTA6
#include <torch_npu/csrc/framework/OpCommand.h>
#else
#include <torch_npu/csrc/aten/NPUNativeFunctions.h>
#include <torch_npu/csrc/framework/utils/OpPreparation.h>
#endif

#include "acl/acl.h"
#include "aclnn_recurrent_gated_delta_rule.h"
#include "core/common/macros.h"
#include "core/kernels/npu/utils.h"
#include "npu_ops_api.h"
35+
36+
namespace xllm::kernel::npu {
37+
38+
// Runs a single recurrent (decode) step of the gated delta rule on NPU by
// dispatching to the two-phase aclnnRecurrentGatedDeltaRule kernel
// (GetWorkspaceSize, then execute on the current NPU stream).
//
// Required inputs:
//   query/key/value — per-step projections; the output is allocated via
//                     at::empty_like(value), so it matches value's
//                     shape/dtype.
//   state           — recurrent SSM state; passed non-const to the kernel,
//                     which presumably updates it in place (confirm against
//                     the aclnn kernel contract).
// Optional inputs (forwarded only when present and defined; otherwise the
// kernel receives nullptr, meaning "not supplied"):
//   beta, actual_seq_lengths, ssm_state_indices, num_accepted_tokens, g, gk.
//   scale — attention scaling factor; defaults to 1/sqrt(head_dim) when
//           absent (see note below).
// Returns the kernel's output tensor.
torch::Tensor npu_recurrent_gated_delta_rule(
    const torch::Tensor& query,
    const torch::Tensor& key,
    const torch::Tensor& value,
    torch::Tensor& state,
    const std::optional<torch::Tensor>& beta,
    const std::optional<double> scale,
    const std::optional<torch::Tensor>& actual_seq_lengths,
    const std::optional<torch::Tensor>& ssm_state_indices,
    const std::optional<torch::Tensor>& num_accepted_tokens,
    const std::optional<torch::Tensor>& g,
    const std::optional<torch::Tensor>& gk) {
  // Validate required inputs before touching the ACL runtime.
  check_tensor(query, "query", "recurrent_gated_delta_rule");
  check_tensor(key, "key", "recurrent_gated_delta_rule");
  check_tensor(value, "value", "recurrent_gated_delta_rule");
  check_tensor(state, "state", "recurrent_gated_delta_rule");

  aclTensor* query_ids = nullptr;
  aclTensor* key_ids = nullptr;
  aclTensor* value_ids = nullptr;
  aclTensor* state_ids = nullptr;
  aclTensor* beta_ids = nullptr;
  aclTensor* actual_seq_lengths_ids = nullptr;
  aclTensor* ssm_state_indices_ids = nullptr;
  aclTensor* num_accepted_tokens_ids = nullptr;
  aclTensor* g_ids = nullptr;
  aclTensor* gk_ids = nullptr;
  aclTensor* out_ids = nullptr;

  const int32_t device_id = query.device().index();
  aclrtStream stream = c10_npu::getCurrentNPUStream(device_id).stream();

  create_acltensor(&query_ids, query);
  create_acltensor(&key_ids, key);
  create_acltensor(&value_ids, value);
  create_acltensor(&state_ids, state);

  // Materialize an aclTensor only for optional inputs that are both present
  // and defined; the handle stays nullptr otherwise.
  const auto create_optional = [](aclTensor** dst,
                                  const std::optional<torch::Tensor>& src) {
    if (src.has_value() && src.value().defined()) {
      create_acltensor(dst, src.value());
    }
  };
  create_optional(&beta_ids, beta);
  create_optional(&actual_seq_lengths_ids, actual_seq_lengths);
  create_optional(&ssm_state_indices_ids, ssm_state_indices);
  create_optional(&num_accepted_tokens_ids, num_accepted_tokens);
  create_optional(&g_ids, g);
  create_optional(&gk_ids, gk);

  at::Tensor out_result = at::empty_like(value);
  create_acltensor(&out_ids, out_result);

  // Fix: the original unconditionally called scale.value(), which throws
  // std::bad_optional_access whenever the caller omits scale even though the
  // signature advertises it as optional. Default to the conventional
  // 1/sqrt(head_dim) attention scaling instead.
  // NOTE(review): confirm this default matches the aclnn kernel's
  // expectation for an unscaled call.
  const double default_scale =
      1.0 / std::sqrt(static_cast<double>(query.size(-1)));
  const float scale_value = static_cast<float>(scale.value_or(default_scale));

  uint64_t workspace_size = 0;
  aclOpExecutor* executor = nullptr;

  CHECK_ACL_SUCCESS(
      aclnnRecurrentGatedDeltaRuleGetWorkspaceSize(query_ids,
                                                   key_ids,
                                                   value_ids,
                                                   beta_ids,
                                                   state_ids,
                                                   actual_seq_lengths_ids,
                                                   ssm_state_indices_ids,
                                                   g_ids,
                                                   gk_ids,
                                                   num_accepted_tokens_ids,
                                                   scale_value,
                                                   out_ids,
                                                   &workspace_size,
                                                   &executor),
      "recurrent_gated_delta_rule: failed to get workspace size");

  void* workspace_addr = nullptr;
  if (workspace_size > 0) {
    CHECK_ACL_SUCCESS(
        aclrtMalloc(&workspace_addr, workspace_size, ACL_MEM_MALLOC_HUGE_FIRST),
        "recurrent_gated_delta_rule: failed to allocate workspace");
  }

  CHECK_ACL_SUCCESS(aclnnRecurrentGatedDeltaRule(
                        workspace_addr, workspace_size, executor, stream),
                    "recurrent_gated_delta_rule: failed to perform recurrent "
                    "gated delta rule");

  // Release every aclTensor handle we created; the guard skips optional
  // handles that were never materialized.
  for (aclTensor* handle : {query_ids,
                            key_ids,
                            value_ids,
                            state_ids,
                            out_ids,
                            beta_ids,
                            actual_seq_lengths_ids,
                            ssm_state_indices_ids,
                            num_accepted_tokens_ids,
                            g_ids,
                            gk_ids}) {
    if (handle != nullptr) {
      aclDestroyTensor(handle);
    }
  }

  if (workspace_size > 0) {
    CHECK_ACL_SUCCESS(aclrtFree(workspace_addr),
                      "recurrent_gated_delta_rule: failed to free workspace");
  }

  return out_result;
}
164+
165+
} // namespace xllm::kernel::npu

xllm/core/kernels/ops_api.cpp

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -756,6 +756,14 @@ void fused_indexer_k(FusedIndexerKParams& params) {
756756
#endif
757757
}
758758

759+
// L2-normalizes `x` along its last dimension via the NPU backend kernel.
// `eps` is forwarded to the kernel — presumably a divide-by-zero guard
// (confirm against npu_l2norm_last_dim). Non-NPU builds hit
// NOT_IMPLEMENTED().
torch::Tensor l2_norm(torch::Tensor& x, double eps) {
#if defined(USE_NPU)
  return npu::npu_l2norm_last_dim(x, eps);
#else
  NOT_IMPLEMENTED();
#endif
}
766+
759767
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
760768
moe_init_routing_v2(MoeInitRoutingV2Params& params) {
761769
#if defined(USE_NPU)
@@ -991,4 +999,52 @@ void gemma_rms_norm(GemmaRMSNormParams& params) {
991999
NOT_IMPLEMENTED();
9921000
#endif
9931001
}
1002+
1003+
// Backend dispatch for the chunked (prefill) gated delta rule; field
// semantics live on ChunkGatedDeltaRuleParams. Returns the pair produced by
// the NPU kernel — presumably {output, final_state}; confirm against
// npu_chunk_gated_delta_rule. Non-NPU builds hit NOT_IMPLEMENTED().
// NOTE(review): params.chunk_size is not forwarded here — verify the NPU
// kernel fixes or derives its chunk size internally.
std::pair<torch::Tensor, torch::Tensor> chunk_gated_delta_rule(
    ChunkGatedDeltaRuleParams& params) {
#if defined(USE_NPU)
  return npu::npu_chunk_gated_delta_rule(params.q,
                                         params.k,
                                         params.v,
                                         params.g,
                                         params.beta,
                                         params.scale,
                                         params.initial_state,
                                         params.output_final_state,
                                         params.cu_seqlens,
                                         params.head_first,
                                         params.use_qk_l2norm_in_kernel);
#else
  NOT_IMPLEMENTED();
#endif
}
1021+
1022+
// Backend dispatch for the single-step (decode) gated delta rule. All
// arguments are forwarded unchanged to npu::npu_recurrent_gated_delta_rule;
// `state` is non-const and presumably updated in place by the kernel
// (confirm against the NPU implementation). Non-NPU builds hit
// NOT_IMPLEMENTED().
torch::Tensor recurrent_gated_delta_rule(
    const torch::Tensor& query,
    const torch::Tensor& key,
    const torch::Tensor& value,
    torch::Tensor& state,
    const std::optional<torch::Tensor>& beta,
    const std::optional<double> scale,
    const std::optional<torch::Tensor>& actual_seq_lengths,
    const std::optional<torch::Tensor>& ssm_state_indices,
    const std::optional<torch::Tensor>& num_accepted_tokens,
    const std::optional<torch::Tensor>& g,
    const std::optional<torch::Tensor>& gk) {
#if defined(USE_NPU)
  return npu::npu_recurrent_gated_delta_rule(query,
                                             key,
                                             value,
                                             state,
                                             beta,
                                             scale,
                                             actual_seq_lengths,
                                             ssm_state_indices,
                                             num_accepted_tokens,
                                             g,
                                             gk);
#else
  NOT_IMPLEMENTED();
#endif
}
9941050
} // namespace xllm::kernel

xllm/core/kernels/ops_api.h

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,9 @@ void fused_indexer_q(FusedIndexerQParams& params);
9494

9595
void fused_indexer_k(FusedIndexerKParams& params);
9696

97+
// L2 normalization of `x` along its last dimension; `eps` is forwarded to
// the backend kernel (presumably a divide-by-zero guard — confirm).
// NPU-only: other backends hit NOT_IMPLEMENTED().
torch::Tensor l2_norm(torch::Tensor& x, double eps = 1e-6);
99+
97100
// TODO: NPU moe_init_routing_v2 is equivalent to moe_gen_idx + moe_expand_input
98101
// (and token_count/cusum outputs) on other backends.
99102
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
@@ -142,5 +145,19 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
142145
fused_qkvzba_split_reshape_cat(FusedQkvzbaSplitReshapeParams& params);
143146

144147
void gemma_rms_norm(GemmaRMSNormParams& params);
145-
148+
// Chunk-parallel (prefill) form of the gated delta rule; see
// ChunkGatedDeltaRuleParams for field shapes and semantics. Returns the
// pair produced by the backend kernel. NPU-only: other backends hit
// NOT_IMPLEMENTED().
std::pair<torch::Tensor, torch::Tensor> chunk_gated_delta_rule(
    ChunkGatedDeltaRuleParams& params);
150+
151+
// Single recurrent (decode) step of the gated delta rule. `state` is
// non-const — presumably updated in place by the backend kernel (confirm).
// Optional tensors are forwarded only when present and defined; `scale` is
// an optional attention scaling factor. Returns a tensor shaped like
// `value`. NPU-only: other backends hit NOT_IMPLEMENTED().
torch::Tensor recurrent_gated_delta_rule(
    const torch::Tensor& query,
    const torch::Tensor& key,
    const torch::Tensor& value,
    torch::Tensor& state,
    const std::optional<torch::Tensor>& beta,
    const std::optional<double> scale,
    const std::optional<torch::Tensor>& actual_seq_lengths,
    const std::optional<torch::Tensor>& ssm_state_indices,
    const std::optional<torch::Tensor>& num_accepted_tokens,
    const std::optional<torch::Tensor>& g,
    const std::optional<torch::Tensor>& gk);
146163
} // namespace xllm::kernel

xllm/core/kernels/param.h

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1398,4 +1398,32 @@ struct GemmaRMSNormParams {
13981398
torch::Tensor rstd_out;
13991399
torch::Tensor norm_out;
14001400
};
1401+
1402+
// Parameters for chunk_gated_delta_rule (chunk-parallel / prefill form).
// Shape legend (per the field comments below): B = batch, T = tokens,
// Hqk = query/key heads, H = value heads, K = qk head dim, V = value head
// dim, N = number of sequences.
struct ChunkGatedDeltaRuleParams {
  // Query tensor. Shape: [B, T, Hqk, K]. Dtype: bfloat16.
  torch::Tensor q;
  // Key tensor. Shape: [B, T, Hqk, K]. Dtype: bfloat16.
  torch::Tensor k;
  // Value tensor. Shape: [B, T, H, V]. Dtype: bfloat16.
  torch::Tensor v;
  // Gating tensor. Shape: [B, T, H]. Dtype: float32 or bfloat16.
  torch::Tensor g;
  // Beta tensor. Shape: [B, T, H]. Dtype: float32 or bfloat16.
  torch::Tensor beta;
  // Optional scale factor for attention. Default: K^(-0.5).
  std::optional<float> scale = std::nullopt;
  // Optional initial state tensor. Shape: [N, H, K, V]. Dtype: bfloat16.
  std::optional<torch::Tensor> initial_state = std::nullopt;
  // Whether to output the final state.
  bool output_final_state = false;
  // Chunk size for processing. Default: 64.
  // NOTE(review): not forwarded by the ops_api NPU dispatch — confirm the
  // kernel derives its own chunk size.
  int64_t chunk_size = 64;
  // Optional cumulative sequence lengths. Shape: [num_sequences + 1]. Dtype:
  // int32.
  std::optional<torch::Tensor> cu_seqlens = std::nullopt;
  // Whether input is head-first format. Default: false (batch-first).
  bool head_first = false;
  // Whether to apply L2 norm to q and k inside the kernel. Default: false.
  bool use_qk_l2norm_in_kernel = false;
};
14011429
} // namespace xllm::kernel

0 commit comments

Comments
 (0)