feat: add npu_recurrent_gated_delta_rule and chunk_gated_delta_rule fusion operators for qwen3.5/qwen3-next.#1262

Open
fems14 wants to merge 4 commits intojd-opensource:mainfrom
fems14:delat_net

Conversation

@fems14
Contributor

@fems14 fems14 commented Apr 11, 2026

Adds the npu_recurrent_gated_delta_rule and chunk_gated_delta_rule fusion operators for the qwen3.5/qwen3-next models.
Depends on the operator PR being merged first: https://gitcode.com/xLLM-AI/torch_npu_ops/pull/12

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request updates the Qwen3GatedDeltaNetBase implementation to utilize new NPU-specific kernels for chunked and recurrent gated-delta attention and adds cumulative sequence length calculation to the attention metadata. However, several critical issues need to be addressed: the initial state fetched from the SSM cache is being incorrectly zeroed out, which breaks statefulness; the recurrent state is transposed before being stored back in the cache, leading to a layout mismatch; and the removal of head repetition logic combined with the use of squeeze(0) on tensors could result in shape and dimension errors.

Comment on lines +429 to +431
// Todo: chunked-prefill/prefix-cache use initial_state
initial_state_tensor.fill_(0.0);
chunk_gated_delta_params.initial_state = initial_state_tensor;
Contributor


critical

The initial state fetched from the SSM cache is immediately overwritten with zeros by initial_state_tensor.fill_(0.0). This is a critical bug that breaks statefulness for sequences requiring a previous state (e.g., in chunked prefill or prefix caching scenarios). The initial_state_tensor should be used as-is after being indexed from the cache.

Suggested change
// Todo: chunked-prefill/prefix-cache use initial_state
initial_state_tensor.fill_(0.0);
chunk_gated_delta_params.initial_state = initial_state_tensor;
// Todo: chunked-prefill/prefix-cache use initial_state
chunk_gated_delta_params.initial_state = initial_state_tensor;
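The bug the review describes can be reproduced with a minimal sketch (variable names here are illustrative stand-ins, not the PR's actual code): the state read from the SSM cache is wiped in place by `fill_(0.0)` before it reaches the kernel, so any previously accumulated recurrent state is lost.

```python
import torch

# Stand-in SSM cache with nonzero state for entry 1 (shapes illustrative).
ssm_cache = torch.arange(8.0).reshape(2, 2, 2)  # [num_blocks, K, V]
indices = torch.tensor([1])

initial_state = ssm_cache.index_select(0, indices).clone()
assert initial_state.abs().sum() > 0   # real history was fetched from the cache

initial_state.fill_(0.0)               # the bug: state zeroed in place
assert initial_state.abs().sum() == 0  # kernel now sees no history at all
```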

Comment on lines 416 to 417
auto [processed_q, processed_k, processed_v] = process_mixed_qkv(mixed_qkv);
int64_t repeat_times = num_v_heads_ / num_k_heads_;
if (repeat_times > 1) {
processed_q = processed_q.repeat_interleave(repeat_times, 2);
processed_k = processed_k.repeat_interleave(repeat_times, 2);
}
// Apply chunked or recurrent gated-delta attention and update caches.
Contributor


high

The logic to repeat Q and K heads to match V heads (previously at lines 416-420) has been removed. In GQA/MQA architectures where num_k_heads_ != num_v_heads_, this will result in processed_q and processed_k having a different number of heads than processed_v. Unless the new fusion kernels (chunk_gated_delta_rule and npu_recurrent_gated_delta_rule) explicitly handle broadcasting for mismatched head counts internally, this change will cause shape mismatch errors or incorrect results.

  auto [processed_q, processed_k, processed_v] = process_mixed_qkv(mixed_qkv);
  int64_t repeat_times = num_v_heads_ / num_k_heads_;
  if (repeat_times > 1) {
    processed_q = processed_q.repeat_interleave(repeat_times, 2);
    processed_k = processed_k.repeat_interleave(repeat_times, 2);
  }
  // Apply chunked or recurrent gated-delta attention and update caches.
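The head-repetition logic the review says was removed can be sketched as follows (dimensions and names are illustrative; the snippet assumes dim 2 is the head dimension, as in the quoted C++):

```python
import torch

# GQA setup: q/k carry num_k_heads heads, v carries num_v_heads heads.
num_k_heads, num_v_heads, head_dim = 2, 4, 8
q = torch.randn(1, 3, num_k_heads, head_dim)  # [B, T, H_k, D]
v = torch.randn(1, 3, num_v_heads, head_dim)  # [B, T, H_v, D]

# Repeat q/k heads so all tensors agree on the head count, as the
# removed lines 416-420 did before the fusion kernels were introduced.
repeat_times = num_v_heads // num_k_heads
if repeat_times > 1:
    q = q.repeat_interleave(repeat_times, dim=2)

assert q.size(2) == v.size(2)  # head counts now match
```

Without this expansion (or internal broadcasting in the kernel), the mismatched head counts would surface exactly as the shape errors the review predicts.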

Comment on lines +438 to +439
ssm_cache.index_put_({input_params.block_tables.select(1, 0)},
last_recurrent_state.to(ssm_cache.dtype()));
last_recurrent_state.transpose(-1, -2).to(ssm_cache.dtype()));
Contributor


high

The recurrent state is being transposed (.transpose(-1, -2)) before being stored back into the ssm_cache. This changes the cache layout from [B, H, K, V] to [B, H, V, K], which contradicts the shape documentation on line 426 and the way the state is read on line 427. This inconsistency will lead to incorrect results in subsequent steps that expect the standard layout.

Suggested change
ssm_cache.index_put_({input_params.block_tables.select(1, 0)},
last_recurrent_state.to(ssm_cache.dtype()));
last_recurrent_state.transpose(-1, -2).to(ssm_cache.dtype()));
ssm_cache.index_put_({input_params.block_tables.select(1, 0)},
last_recurrent_state.to(ssm_cache.dtype()));
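The layout mismatch is easy to demonstrate with a small sketch (shapes illustrative): transposing the last two dims turns a `[B, H, K, V]` state into `[B, H, V, K]`, so the next read from the cache interprets the axes swapped.

```python
import torch

B, H, K, V = 1, 2, 3, 5
state = torch.randn(B, H, K, V)       # documented cache layout [B, H, K, V]

stored = state.transpose(-1, -2)      # what the PR writes: [B, H, V, K]
assert stored.shape == (B, H, V, K)
assert stored.shape != state.shape    # subsequent reads expect [B, H, K, V]
```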

Comment on lines +456 to +461
beta.squeeze(0).contiguous(),
scale,
actual_seq_lengths,
ssm_state_indices,
c10::nullopt,
g.squeeze(0).contiguous(),
Contributor


high

Using .squeeze(0) on beta and g will remove the batch dimension when the batch size is 1. If the NPU kernel npu_recurrent_gated_delta_rule expects a 2D tensor with a batch/token dimension (e.g., [Batch, Heads]), this will cause a dimension mismatch error during inference with batch size 1.

Suggested change
beta.squeeze(0).contiguous(),
scale,
actual_seq_lengths,
ssm_state_indices,
c10::nullopt,
g.squeeze(0).contiguous(),
beta.contiguous(),
scale,
actual_seq_lengths,
ssm_state_indices,
c10::nullopt,
g.contiguous(),
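The `squeeze(0)` hazard the review flags can be shown directly (shapes illustrative): with batch size 1 the leading dimension is silently dropped, so the tensor rank depends on the batch size.

```python
import torch

beta = torch.randn(1, 8)             # [Batch=1, Heads]
assert beta.squeeze(0).dim() == 1    # batch dim lost: kernel sees [Heads]

beta2 = torch.randn(4, 8)            # [Batch=4, Heads]
assert beta2.squeeze(0).dim() == 2   # unchanged: rank now varies with batch size
```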

@yingxudeng yingxudeng changed the title qwe3.5/qwen3-next model add npu_recurrent_gated_delta_rule and chunk_gated_delta_rule fusion operater feat: add npu_recurrent_gated_delta_rule and chunk_gated_delta_rule fusion operators for qwen3.5/qwen3-next. Apr 11, 2026
@yingxudeng yingxudeng marked this pull request as draft April 11, 2026 10:35
@Vectorwh Vectorwh force-pushed the delat_net branch 2 times, most recently from 54fd624 to 44ca51d Compare April 14, 2026 12:26
torch::Tensor ssm_state_indices =
attn_metadata.block_table.select(1, 0).contiguous();

// Todo: 使用 q_lens
Collaborator


No Chinese in comments.

torch::Tensor cumsum_tensor =
torch::cumsum(attn_metadata.q_seq_lens, 0).to(torch::kInt32);
auto zero = torch::zeros({1}, cumsum_tensor.options());
attn_metadata.q_cu_seq_lens = torch::cat({zero, cumsum_tensor}, 0);
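The quoted snippet builds cumulative sequence lengths with a leading zero, the ragged-batch offset format fusion kernels typically expect. A minimal Python sketch of the same construction:

```python
import torch

# Per-request query lengths for a ragged batch (values illustrative).
q_seq_lens = torch.tensor([3, 5, 2])

cumsum = torch.cumsum(q_seq_lens, 0).to(torch.int32)
zero = torch.zeros(1, dtype=cumsum.dtype)
q_cu_seq_lens = torch.cat([zero, cumsum], 0)

# q_cu_seq_lens is [0, 3, 8, 10]: offsets delimiting each sequence's tokens.
```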
Collaborator


This value is being modified — will that affect other models?

Contributor


No, it won't: 1. no other model currently uses it, and 2. this initialization is guarded by the NPU macro.


torch::Tensor actual_seq_lengths = attn_metadata.q_seq_lens.clone();
double scale = 1.0 / std::sqrt(static_cast<float>(processed_q.size(-1)));
core_attn_out = at_npu::native::custom_ops::npu_recurrent_gated_delta_rule(
Collaborator


Don't call it directly here; expose the interface uniformly from ops_api.h and call it through that instead:
/export/home/dengyingxu1/projects/xllm/xllm/core/kernels/ops_api.h

chunk_gated_delta_params.initial_state = initial_state_tensor;
chunk_gated_delta_params.output_final_state = true;
chunk_gated_delta_params.cu_seqlens =
attn_metadata.q_cu_seq_lens.to(torch::kInt32);
Collaborator


Is the to(torch::kInt32) necessary?

Contributor


Not necessary; removed.

double scale = 1.0 / std::sqrt(static_cast<float>(processed_q.size(-1)));
core_attn_out = at_npu::native::custom_ops::npu_recurrent_gated_delta_rule(
processed_q.reshape(
{-1, processed_q.size(-2), processed_q.size(-1)}),
Collaborator


Would view work here? Does reshape recopy non-contiguous tensors?

Contributor


mixed_qkv may be non-contiguous; reshape is used here so the copy happens automatically when needed.
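The view-vs-reshape distinction discussed here can be verified with a small sketch: `view` refuses non-viewable non-contiguous layouts, while `reshape` falls back to a copy.

```python
import torch

x = torch.randn(2, 3, 4).transpose(0, 1)  # transpose makes x non-contiguous
assert not x.is_contiguous()

# view requires a compatible memory layout and raises here.
try:
    x.view(-1, 4)
    view_ok = True
except RuntimeError:
    view_ok = False
assert not view_ok

# reshape detects the incompatibility and copies instead.
y = x.reshape(-1, 4)
assert y.shape == (6, 4)
```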


torch::Tensor actual_seq_lengths = attn_metadata.q_seq_lens.clone();
double scale = 1.0 / std::sqrt(static_cast<float>(processed_q.size(-1)));
core_attn_out = at_npu::native::custom_ops::npu_recurrent_gated_delta_rule(
Collaborator


Does the npu_recurrent_gated_delta_rule operator support fp32 now?

processed_q = xllm::kernel::l2_norm(processed_q, 1e-6);
processed_k = xllm::kernel::l2_norm(processed_k, 1e-6);
torch::Tensor ssm_state_indices =
attn_metadata.block_table.select(1, 0).contiguous();
Collaborator


Selecting the first column of block_table and then calling contiguous on it is too inefficient.

wanghao and others added 3 commits April 20, 2026 10:42
…rator

Implemented the recurrent_gated_delta_rule operator on the NPU, including:

1. Adding operator interface declaration

2. Implementing NPU backend computation logic

3. Updating CMake build configuration

4. Integrating and using it in Qwen3GatedDeltaNetBase
@yingxudeng yingxudeng marked this pull request as ready for review April 20, 2026 12:05