feat: add npu_recurrent_gated_delta_rule and chunk_gated_delta_rule fusion operators for qwen3.5/qwen3-next. #1262
fems14 wants to merge 4 commits into jd-opensource:main
Conversation
Code Review
This pull request updates the Qwen3GatedDeltaNetBase implementation to utilize new NPU-specific kernels for chunked and recurrent gated-delta attention and adds cumulative sequence length calculation to the attention metadata. However, several critical issues need to be addressed: the initial state fetched from the SSM cache is being incorrectly zeroed out, which breaks statefulness; the recurrent state is transposed before being stored back in the cache, leading to a layout mismatch; and the removal of head repetition logic combined with the use of squeeze(0) on tensors could result in shape and dimension errors.
```cpp
// Todo: chunked-prefill/prefix-cache use initial_state
initial_state_tensor.fill_(0.0);
chunk_gated_delta_params.initial_state = initial_state_tensor;
```
The initial state fetched from the SSM cache is immediately overwritten with zeros by initial_state_tensor.fill_(0.0). This is a critical bug that breaks statefulness for sequences requiring a previous state (e.g., in chunked prefill or prefix caching scenarios). The initial_state_tensor should be used as-is after being indexed from the cache.
Suggested change:

```diff
  // Todo: chunked-prefill/prefix-cache use initial_state
- initial_state_tensor.fill_(0.0);
  chunk_gated_delta_params.initial_state = initial_state_tensor;
```
```diff
  auto [processed_q, processed_k, processed_v] = process_mixed_qkv(mixed_qkv);
- int64_t repeat_times = num_v_heads_ / num_k_heads_;
- if (repeat_times > 1) {
-   processed_q = processed_q.repeat_interleave(repeat_times, 2);
-   processed_k = processed_k.repeat_interleave(repeat_times, 2);
- }
  // Apply chunked or recurrent gated-delta attention and update caches.
```
The logic to repeat Q and K heads to match V heads (previously at lines 416-420) has been removed. In GQA/MQA architectures where num_k_heads_ != num_v_heads_, this will result in processed_q and processed_k having a different number of heads than processed_v. Unless the new fusion kernels (chunk_gated_delta_rule and npu_recurrent_gated_delta_rule) explicitly handle broadcasting for mismatched head counts internally, this change will cause shape mismatch errors or incorrect results.
Suggested change:

```cpp
auto [processed_q, processed_k, processed_v] = process_mixed_qkv(mixed_qkv);
int64_t repeat_times = num_v_heads_ / num_k_heads_;
if (repeat_times > 1) {
  processed_q = processed_q.repeat_interleave(repeat_times, 2);
  processed_k = processed_k.repeat_interleave(repeat_times, 2);
}
```
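The head-count mismatch is easy to see with a shape sketch (numpy stands in for the libtorch calls; `np.repeat` along an axis mirrors `repeat_interleave(repeat_times, 2)`; the dimensions are hypothetical GQA values):

```python
import numpy as np

B, T, KH, VH, D = 1, 8, 2, 4, 16        # 2 Q/K heads sharing 4 V heads
q = np.zeros((B, T, KH, D))
k = np.zeros((B, T, KH, D))
v = np.zeros((B, T, VH, D))

repeat_times = VH // KH                  # num_v_heads_ / num_k_heads_
if repeat_times > 1:
    # Expand Q/K heads so every V head has a matching Q/K head.
    q = np.repeat(q, repeat_times, axis=2)
    k = np.repeat(k, repeat_times, axis=2)

# Without this expansion, q/k would carry KH heads against v's VH heads.
assert q.shape[2] == k.shape[2] == v.shape[2] == VH
```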
```diff
  // Apply chunked or recurrent gated-delta attention and update caches.
  ssm_cache.index_put_({input_params.block_tables.select(1, 0)},
-                      last_recurrent_state.to(ssm_cache.dtype()));
+                      last_recurrent_state.transpose(-1, -2).to(ssm_cache.dtype()));
```
The recurrent state is being transposed (.transpose(-1, -2)) before being stored back into the ssm_cache. This changes the cache layout from [B, H, K, V] to [B, H, V, K], which contradicts the shape documentation on line 426 and the way the state is read on line 427. This inconsistency will lead to incorrect results in subsequent steps that expect the standard layout.
Suggested change:

```diff
  ssm_cache.index_put_({input_params.block_tables.select(1, 0)},
-                      last_recurrent_state.transpose(-1, -2).to(ssm_cache.dtype()));
+                      last_recurrent_state.to(ssm_cache.dtype()));
```
```cpp
beta.squeeze(0).contiguous(),
scale,
actual_seq_lengths,
ssm_state_indices,
c10::nullopt,
g.squeeze(0).contiguous(),
```
Using .squeeze(0) on beta and g will remove the batch dimension when the batch size is 1. If the NPU kernel npu_recurrent_gated_delta_rule expects a 2D tensor with a batch/token dimension (e.g., [Batch, Heads]), this will cause a dimension mismatch error during inference with batch size 1.
Suggested change:

```diff
- beta.squeeze(0).contiguous(),
+ beta.contiguous(),
  scale,
  actual_seq_lengths,
  ssm_state_indices,
  c10::nullopt,
- g.squeeze(0).contiguous(),
+ g.contiguous(),
```
Force-pushed from 54fd624 to 44ca51d
```cpp
torch::Tensor ssm_state_indices =
    attn_metadata.block_table.select(1, 0).contiguous();

// Todo: 使用 q_lens
```
No Chinese in comments.
```cpp
torch::Tensor cumsum_tensor =
    torch::cumsum(attn_metadata.q_seq_lens, 0).to(torch::kInt32);
auto zero = torch::zeros({1}, cumsum_tensor.options());
attn_metadata.q_cu_seq_lens = torch::cat({zero, cumsum_tensor}, 0);
```
This will not affect anything, because 1) no other model currently uses it, and 2) this initialization is guarded by the NPU macro.
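The cu_seqlens construction quoted above — cumsum with a prepended zero — can be sketched in numpy (illustrative lengths):

```python
import numpy as np

q_seq_lens = np.array([3, 5, 2], dtype=np.int32)

# Prepend a zero to the running sum so q_cu_seq_lens[i] is the start
# offset of sequence i and the last entry is the total token count.
q_cu_seq_lens = np.concatenate(([0], np.cumsum(q_seq_lens))).astype(np.int32)

assert q_cu_seq_lens.tolist() == [0, 3, 8, 10]
```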
```cpp
torch::Tensor actual_seq_lengths = attn_metadata.q_seq_lens.clone();
double scale = 1.0 / std::sqrt(static_cast<float>(processed_q.size(-1)));
core_attn_out = at_npu::native::custom_ops::npu_recurrent_gated_delta_rule(
```
Don't call this directly here. Expose the interface in one place and call it through that instead:
/export/home/dengyingxu1/projects/xllm/xllm/core/kernels/ops_api.h
```cpp
chunk_gated_delta_params.initial_state = initial_state_tensor;
chunk_gated_delta_params.output_final_state = true;
chunk_gated_delta_params.cu_seqlens =
    attn_metadata.q_cu_seq_lens.to(torch::kInt32);
```
Is the to(torch::kInt32) necessary?
```cpp
double scale = 1.0 / std::sqrt(static_cast<float>(processed_q.size(-1)));
core_attn_out = at_npu::native::custom_ops::npu_recurrent_gated_delta_rule(
    processed_q.reshape(
        {-1, processed_q.size(-2), processed_q.size(-1)}),
```
Would view work here? Does reshape reorganize (copy) non-contiguous tensors?
mixed_qkv may be non-contiguous; using reshape here handles the copy automatically.
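The reshape-vs-view distinction can be demonstrated with numpy's memory-sharing semantics, which mirror torch's here (torch's `reshape` copies when the input is non-contiguous, while `view` would raise):

```python
import numpy as np

x = np.arange(12).reshape(3, 4)
y = x.T                                   # transposed view: non-contiguous
assert not y.flags['C_CONTIGUOUS']

# Reshaping a non-contiguous array forces a copy into fresh memory.
z = y.reshape(-1)
assert not np.shares_memory(x, z)

# A contiguous array reshapes without copying (a view is returned).
w = x.reshape(-1)
assert np.shares_memory(x, w)
```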
```cpp
torch::Tensor actual_seq_lengths = attn_metadata.q_seq_lens.clone();
double scale = 1.0 / std::sqrt(static_cast<float>(processed_q.size(-1)));
core_attn_out = at_npu::native::custom_ops::npu_recurrent_gated_delta_rule(
```
Does the npu_recurrent_gated_delta_rule operator support fp32 now?
```cpp
processed_q = xllm::kernel::l2_norm(processed_q, 1e-6);
processed_k = xllm::kernel::l2_norm(processed_k, 1e-6);
torch::Tensor ssm_state_indices =
    attn_metadata.block_table.select(1, 0).contiguous();
```
Taking the first column of block_table with select + contiguous on every step is too inefficient.
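The cost the comment flags comes from the fact that a column slice is a strided view, so `.contiguous()` must gather it into a fresh buffer on every call. A numpy sketch of the same memory behavior (illustrative table shape):

```python
import numpy as np

block_table = np.arange(40, dtype=np.int32).reshape(10, 4)

# select(1, 0) equivalent: a strided column view, not a contiguous buffer
# (stride between elements is a whole row, not one element).
col = block_table[:, 0]
assert col.strides[0] != col.itemsize

# .contiguous() equivalent: a gather-style copy, paid on every step.
first_col = np.ascontiguousarray(col)
assert first_col.flags['C_CONTIGUOUS']
assert np.array_equal(first_col, col)
```

Hoisting this copy out of the per-step path (or storing the indices contiguously once) would avoid repeating it.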
…usion operators for qwen3.5/qwen3-next.
…rator
Implemented the recurrent_gated_delta_rule operator on the NPU, including:
1. Adding the operator interface declaration
2. Implementing the NPU backend computation logic
3. Updating the CMake build configuration
4. Integrating it into Qwen3GatedDeltaNetBase
qwen3.5/qwen3-next model: add npu_recurrent_gated_delta_rule and chunk_gated_delta_rule fusion operators
Depends on this operator PR being merged first: https://gitcode.com/xLLM-AI/torch_npu_ops/pull/12