diff --git a/third_party/torch_npu_ops b/third_party/torch_npu_ops
index 9dc44e054..599eb0334 160000
--- a/third_party/torch_npu_ops
+++ b/third_party/torch_npu_ops
@@ -1 +1 @@
-Subproject commit 9dc44e054e62a5afc778491674ec60d2298a7a1b
+Subproject commit 599eb033413ec249e0d614796f0bcfedc5191253
diff --git a/xllm/core/distributed_runtime/llm_engine.cpp b/xllm/core/distributed_runtime/llm_engine.cpp
index f71da268a..5764a9f8c 100644
--- a/xllm/core/distributed_runtime/llm_engine.cpp
+++ b/xllm/core/distributed_runtime/llm_engine.cpp
@@ -613,9 +613,9 @@ bool LLMEngine::allocate_kv_cache(const Engine::KVCacheCapacity& kv_cache_cap) {
   if (enable_gdn_attention) {
     kv_cache_shape.emplace_back(std::vector<int64_t>{
         kv_cache_cap.num_linear_state_blocks,
+        args_.linear_conv_kernel_dim() - 1,
         args_.linear_key_head_dim() * n_local_linear_k_heads_ * 2 +
-            args_.linear_key_head_dim() * n_local_linear_v_heads_,
-        args_.linear_conv_kernel_dim() - 1});
+            args_.linear_key_head_dim() * n_local_linear_v_heads_});
     kv_cache_shape.emplace_back(
         std::vector<int64_t>{kv_cache_cap.num_linear_state_blocks,
                              n_local_linear_v_heads_,
diff --git a/xllm/core/kernels/ops_api.cpp b/xllm/core/kernels/ops_api.cpp
index d4ea329ff..529e7a6af 100644
--- a/xllm/core/kernels/ops_api.cpp
+++ b/xllm/core/kernels/ops_api.cpp
@@ -921,19 +921,19 @@ torch::Tensor causal_conv1d_update(CausalConv1dUpdateParams& params) {
     CHECK(params.conv_state_indices.value().is_contiguous())
         << "causal_conv1d_update: conv_state_indices must be contiguous.";
   }
-  return npu::npu_causal_conv1d_update(params.x,
-                                       params.conv_state,
-                                       params.weight,
-                                       params.activation,
-                                       params.bias,
-                                       params.cache_seqlens,
-                                       params.conv_state_indices,
-                                       params.num_accepted_tokens,
-                                       params.query_start_loc,
-                                       params.max_query_len,
-                                       params.intermediate_conv_window,
-                                       params.pad_slot_id,
-                                       params.validate_data);
+  return npu::npu_causal_conv1d_update_v2(params.x,
+                                          params.conv_state,
+                                          params.weight,
+                                          params.activation,
+                                          params.bias,
+                                          params.conv_state_indices,
+                                          params.query_start_loc,
+                                          params.max_query_len,
+                                          params.pad_slot_id,
+                                          params.block_idx_last_scheduled_token,
+                                          params.initial_state_idx,
+                                          params.validate_data);
+
 #else
   NOT_IMPLEMENTED();
 #endif
diff --git a/xllm/core/kernels/param.h b/xllm/core/kernels/param.h
index 6b76b20e6..fe60f20bd 100644
--- a/xllm/core/kernels/param.h
+++ b/xllm/core/kernels/param.h
@@ -1352,13 +1352,12 @@ struct CausalConv1dUpdateParams {
   torch::Tensor weight;
   bool activation = true;
   std::optional<torch::Tensor> bias = std::nullopt;
-  std::optional<torch::Tensor> cache_seqlens = std::nullopt;
   std::optional<torch::Tensor> conv_state_indices = std::nullopt;
-  std::optional<torch::Tensor> num_accepted_tokens = std::nullopt;
   std::optional<torch::Tensor> query_start_loc = std::nullopt;
   int32_t max_query_len = -1;
-  std::optional<torch::Tensor> intermediate_conv_window = std::nullopt;
   int32_t pad_slot_id = -1;
+  std::optional<torch::Tensor> block_idx_last_scheduled_token;
+  std::optional<torch::Tensor> initial_state_idx;
   bool validate_data = false;
 };
 
diff --git a/xllm/core/layers/common/attention_metadata_builder.cpp b/xllm/core/layers/common/attention_metadata_builder.cpp
index 399d3eccf..fc74993ab 100644
--- a/xllm/core/layers/common/attention_metadata_builder.cpp
+++ b/xllm/core/layers/common/attention_metadata_builder.cpp
@@ -114,6 +114,10 @@ AttentionMetadata build_attention_metadata(
   }
   if (params.q_seq_lens.defined()) {
     attn_metadata.q_seq_lens = params.q_seq_lens;
+    torch::Tensor cumsum_tensor =
+        torch::cumsum(attn_metadata.q_seq_lens, 0).to(torch::kInt32);
+    auto zero = torch::zeros({1}, cumsum_tensor.options());
+    attn_metadata.q_cu_seq_lens = torch::cat({zero, cumsum_tensor}, 0);
   }
 #endif
 
diff --git a/xllm/core/layers/npu_torch/qwen3_gated_delta_net_base.cpp b/xllm/core/layers/npu_torch/qwen3_gated_delta_net_base.cpp
index cb877a272..1bf7190e2 100644
--- a/xllm/core/layers/npu_torch/qwen3_gated_delta_net_base.cpp
+++ b/xllm/core/layers/npu_torch/qwen3_gated_delta_net_base.cpp
@@ -360,8 +360,8 @@ torch::Tensor Qwen3GatedDeltaNetBaseImpl::forward(
   auto conv_weight = conv1d_->weight();
   auto linear_state_indices =
       get_linear_state_indices(input_params, device);
-  mixed_qkv = mixed_qkv.transpose(1, 2);
   if (attn_metadata.is_prefill) {
+    mixed_qkv = mixed_qkv.transpose(1, 2);
     torch::Tensor conv_state =
         (seq_len < conv_kernel_size_ - 1)
             ? torch::pad(mixed_qkv, {0, conv_kernel_size_ - 1 - seq_len})
@@ -369,6 +369,7 @@
               ? mixed_qkv.narrow(
                     -1, seq_len - conv_kernel_size_ + 1, conv_kernel_size_ - 1)
               : mixed_qkv;
+    conv_state = conv_state.transpose(1, 2).contiguous();
     conv_cache.index_put_({linear_state_indices},
                          conv_state.to(conv_cache.dtype()));
     torch::Tensor bias;
@@ -383,12 +384,21 @@
     mixed_qkv =
         torch::silu(conv_output.slice(2, 0, seq_len));
   } else {
-    xllm::kernel::CausalConv1dUpdateParams params;
-    params.x = mixed_qkv;
-    params.conv_state = conv_cache;
-    params.weight = conv_weight;
-    params.conv_state_indices = linear_state_indices;
-    mixed_qkv = xllm::kernel::causal_conv1d_update(params);
+    xllm::kernel::CausalConv1dUpdateParams conv1d_params;
+    conv1d_params.x = mixed_qkv.reshape({-1, mixed_qkv.size(-1)});
+    conv1d_params.conv_state = conv_cache;
+    conv1d_params.weight = conv_weight;
+    conv1d_params.conv_state_indices = linear_state_indices;
+    conv1d_params.block_idx_last_scheduled_token =
+        std::optional<torch::Tensor>();
+    conv1d_params.initial_state_idx = std::optional<torch::Tensor>();
+    conv1d_params.query_start_loc = attn_metadata.q_cu_seq_lens;
+    conv1d_params.max_query_len = attn_metadata.max_query_len;
+    mixed_qkv = xllm::kernel::causal_conv1d_update(conv1d_params);
+    // Reshape back to 3D [batch_size, dim, seq_len]
+    mixed_qkv =
+        mixed_qkv.view({batch_size, -1, mixed_qkv.size(-1)}).contiguous();
+    mixed_qkv = mixed_qkv.transpose(1, 2);
   }
 
   // Compute gated delta net decay and beta terms.