Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion third_party/torch_npu_ops
Submodule torch_npu_ops updated from 599eb0 to eaaddb
2 changes: 1 addition & 1 deletion third_party/xllm_ops
Submodule xllm_ops updated from 96a590 to d3a0ac
1 change: 1 addition & 0 deletions xllm/core/kernels/npu/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ cc_library(
npu_moe_init_routing_v2.cpp
npu_moe_token_unpermute.cpp
rope.cpp
npu_recurrent_gated_delta_rule.cpp
DEPS
:torch_npu_kernels
:tilelang_kernels
Expand Down
12 changes: 12 additions & 0 deletions xllm/core/kernels/npu/npu_ops_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -141,4 +141,16 @@ std::pair<torch::Tensor, torch::Tensor> apply_npu_partial_rotary_embedding(
const torch::Tensor& cos_sin_cache,
bool is_neox_style);

torch::Tensor npu_recurrent_gated_delta_rule(
const torch::Tensor& query,
const torch::Tensor& key,
const torch::Tensor& value,
torch::Tensor& state,
const std::optional<torch::Tensor>& beta,
const std::optional<double> scale,
const std::optional<torch::Tensor>& actual_seq_lengths,
const std::optional<torch::Tensor>& ssm_state_indices,
const std::optional<torch::Tensor>& num_accepted_tokens,
const std::optional<torch::Tensor>& g,
const std::optional<torch::Tensor>& gk);
} // namespace xllm::kernel::npu
165 changes: 165 additions & 0 deletions xllm/core/kernels/npu/npu_recurrent_gated_delta_rule.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
/* Copyright 2026 The xLLM Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

https://github.com/jd-opensource/xllm/blob/main/LICENSE

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <c10/core/Device.h>
#include <glog/logging.h>
#include <torch/torch.h>
#include <torch_npu/csrc/libs/init_npu.h>
#include <torch_npu/torch_npu.h>

#include <nlohmann/json.hpp>
#ifdef TORCH_HIGHER_THAN_PTA6
#include <torch_npu/csrc/framework/OpCommand.h>
#else
#include <torch_npu/csrc/aten/NPUNativeFunctions.h>
#include <torch_npu/csrc/framework/utils/OpPreparation.h>
#endif

#include "acl/acl.h"
#include "aclnn_recurrent_gated_delta_rule.h"
#include "core/common/macros.h"
#include "core/kernels/npu/utils.h"
#include "npu_ops_api.h"

namespace xllm::kernel::npu {

torch::Tensor npu_recurrent_gated_delta_rule(
const torch::Tensor& query,
const torch::Tensor& key,
const torch::Tensor& value,
torch::Tensor& state,
const std::optional<torch::Tensor>& beta,
const std::optional<double> scale,
const std::optional<torch::Tensor>& actual_seq_lengths,
const std::optional<torch::Tensor>& ssm_state_indices,
const std::optional<torch::Tensor>& num_accepted_tokens,
const std::optional<torch::Tensor>& g,
const std::optional<torch::Tensor>& gk) {
check_tensor(query, "query", "recurrent_gated_delta_rule");
check_tensor(key, "key", "recurrent_gated_delta_rule");
check_tensor(value, "value", "recurrent_gated_delta_rule");
check_tensor(state, "state", "recurrent_gated_delta_rule");

aclTensor* query_ids = nullptr;
aclTensor* key_ids = nullptr;
aclTensor* value_ids = nullptr;
aclTensor* state_ids = nullptr;
aclTensor* beta_ids = nullptr;
aclTensor* actual_seq_lengths_ids = nullptr;
aclTensor* ssm_state_indices_ids = nullptr;
aclTensor* num_accepted_tokens_ids = nullptr;
aclTensor* g_ids = nullptr;
aclTensor* gk_ids = nullptr;
aclTensor* out_ids = nullptr;

int32_t device_id = query.device().index();
aclrtStream stream = c10_npu::getCurrentNPUStream(device_id).stream();

create_acltensor(&query_ids, query);
create_acltensor(&key_ids, key);
create_acltensor(&value_ids, value);
create_acltensor(&state_ids, state);

if (beta.has_value() && beta.value().defined()) {
create_acltensor(&beta_ids, beta.value());
}
if (actual_seq_lengths.has_value() && actual_seq_lengths.value().defined()) {
create_acltensor(&actual_seq_lengths_ids, actual_seq_lengths.value());
}
if (ssm_state_indices.has_value() && ssm_state_indices.value().defined()) {
create_acltensor(&ssm_state_indices_ids, ssm_state_indices.value());
}
if (num_accepted_tokens.has_value() &&
num_accepted_tokens.value().defined()) {
create_acltensor(&num_accepted_tokens_ids, num_accepted_tokens.value());
}
if (g.has_value() && g.value().defined()) {
create_acltensor(&g_ids, g.value());
}
if (gk.has_value() && gk.value().defined()) {
create_acltensor(&gk_ids, gk.value());
}

at::Tensor out_result = at::empty_like(value);
create_acltensor(&out_ids, out_result);

float scale_value = static_cast<float>(scale.value());

uint64_t workspace_size = 0;
aclOpExecutor* executor = nullptr;

CHECK_ACL_SUCCESS(
aclnnRecurrentGatedDeltaRuleGetWorkspaceSize(query_ids,
key_ids,
value_ids,
beta_ids,
state_ids,
actual_seq_lengths_ids,
ssm_state_indices_ids,
g_ids,
gk_ids,
num_accepted_tokens_ids,
scale_value,
out_ids,
&workspace_size,
&executor),
"recurrent_gated_delta_rule: failed to get workspace size");

void* workspace_addr = nullptr;
if (workspace_size > 0) {
CHECK_ACL_SUCCESS(
aclrtMalloc(&workspace_addr, workspace_size, ACL_MEM_MALLOC_HUGE_FIRST),
"recurrent_gated_delta_rule: failed to allocate workspace");
}

CHECK_ACL_SUCCESS(aclnnRecurrentGatedDeltaRule(
workspace_addr, workspace_size, executor, stream),
"recurrent_gated_delta_rule: failed to perform recurrent "
"gated delta rule");

aclDestroyTensor(query_ids);
aclDestroyTensor(key_ids);
aclDestroyTensor(value_ids);
aclDestroyTensor(state_ids);
aclDestroyTensor(out_ids);

if (beta_ids != nullptr) {
aclDestroyTensor(beta_ids);
}
if (actual_seq_lengths_ids != nullptr) {
aclDestroyTensor(actual_seq_lengths_ids);
}
if (ssm_state_indices_ids != nullptr) {
aclDestroyTensor(ssm_state_indices_ids);
}
if (num_accepted_tokens_ids != nullptr) {
aclDestroyTensor(num_accepted_tokens_ids);
}
if (g_ids != nullptr) {
aclDestroyTensor(g_ids);
}
if (gk_ids != nullptr) {
aclDestroyTensor(gk_ids);
}

if (workspace_size > 0) {
CHECK_ACL_SUCCESS(aclrtFree(workspace_addr),
"recurrent_gated_delta_rule: failed to free workspace");
}

return out_result;
}

} // namespace xllm::kernel::npu
55 changes: 55 additions & 0 deletions xllm/core/kernels/ops_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -756,6 +756,14 @@ void fused_indexer_k(FusedIndexerKParams& params) {
#endif
}

torch::Tensor l2_norm(torch::Tensor& x, double eps) {
#if defined(USE_NPU)
return npu::npu_l2norm_last_dim(x, eps);
#else
NOT_IMPLEMENTED();
#endif
}

std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
moe_init_routing_v2(MoeInitRoutingV2Params& params) {
#if defined(USE_NPU)
Expand Down Expand Up @@ -1022,4 +1030,51 @@ torch::Tensor build_split_qkv_rmsnorm_mrope_gather_pattern(
#endif
}

std::pair<torch::Tensor, torch::Tensor> chunk_gated_delta_rule(
ChunkGatedDeltaRuleParams& params) {
#if defined(USE_NPU)
return npu::npu_chunk_gated_delta_rule(params.q,
params.k,
params.v,
params.g,
params.beta,
params.scale,
params.initial_state,
params.output_final_state,
params.cu_seqlens,
params.head_first,
params.use_qk_l2norm_in_kernel);
#else
NOT_IMPLEMENTED();
#endif
}

torch::Tensor recurrent_gated_delta_rule(
const torch::Tensor& query,
const torch::Tensor& key,
const torch::Tensor& value,
torch::Tensor& state,
const std::optional<torch::Tensor>& beta,
const std::optional<double> scale,
const std::optional<torch::Tensor>& actual_seq_lengths,
const std::optional<torch::Tensor>& ssm_state_indices,
const std::optional<torch::Tensor>& num_accepted_tokens,
const std::optional<torch::Tensor>& g,
const std::optional<torch::Tensor>& gk) {
#if defined(USE_NPU)
return npu::npu_recurrent_gated_delta_rule(query,
key,
value,
state,
beta,
scale,
actual_seq_lengths,
ssm_state_indices,
num_accepted_tokens,
g,
gk);
#else
NOT_IMPLEMENTED();
#endif
}
} // namespace xllm::kernel
18 changes: 18 additions & 0 deletions xllm/core/kernels/ops_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,9 @@ void fused_indexer_q(FusedIndexerQParams& params);

void fused_indexer_k(FusedIndexerKParams& params);

// L2 normalization along the last dimension
torch::Tensor l2_norm(torch::Tensor& x, double eps = 1e-6);

// TODO: NPU moe_init_routing_v2 is equivalent to moe_gen_idx + moe_expand_input
// (and token_count/cusum outputs) on other backends.
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
Expand Down Expand Up @@ -152,4 +155,19 @@ torch::Tensor build_split_qkv_rmsnorm_mrope_gather_pattern(
bool is_interleaved,
const torch::Device& device);

std::pair<torch::Tensor, torch::Tensor> chunk_gated_delta_rule(
ChunkGatedDeltaRuleParams& params);

torch::Tensor recurrent_gated_delta_rule(
const torch::Tensor& query,
const torch::Tensor& key,
const torch::Tensor& value,
torch::Tensor& state,
const std::optional<torch::Tensor>& beta,
const std::optional<double> scale,
const std::optional<torch::Tensor>& actual_seq_lengths,
const std::optional<torch::Tensor>& ssm_state_indices,
const std::optional<torch::Tensor>& num_accepted_tokens,
const std::optional<torch::Tensor>& g,
const std::optional<torch::Tensor>& gk);
} // namespace xllm::kernel
27 changes: 27 additions & 0 deletions xllm/core/kernels/param.h
Original file line number Diff line number Diff line change
Expand Up @@ -1411,4 +1411,31 @@ struct SplitQkvRmsnormMropeParams {
int64_t head_size;
};

struct ChunkGatedDeltaRuleParams {
// Query tensor. Shape: [B, T, Hqk, K]. Dtype: bfloat16.
torch::Tensor q;
// Key tensor. Shape: [B, T, Hqk, K]. Dtype: bfloat16.
torch::Tensor k;
// Value tensor. Shape: [B, T, H, V]. Dtype: bfloat16.
torch::Tensor v;
// Gating tensor. Shape: [B, T, H]. Dtype: float32 or bfloat16.
torch::Tensor g;
// Beta tensor. Shape: [B, T, H]. Dtype: float32 or bfloat16.
torch::Tensor beta;
// Optional scale factor for attention. Default: K^(-0.5).
std::optional<float> scale = std::nullopt;
// Optional initial state tensor. Shape: [N, H, K, V]. Dtype: bfloat16.
std::optional<torch::Tensor> initial_state = std::nullopt;
// Whether to output the final state.
bool output_final_state = false;
// Chunk size for processing. Default: 64.
int64_t chunk_size = 64;
// Optional cumulative sequence lengths. Shape: [num_sequences + 1]. Dtype:
// int32.
std::optional<torch::Tensor> cu_seqlens = std::nullopt;
// Whether input is head-first format. Default: false (batch-first).
bool head_first = false;
// Whether to apply L2 norm to q and k inside the kernel. Default: false.
bool use_qk_l2norm_in_kernel = false;
};
} // namespace xllm::kernel
Loading
Loading