2 changes: 1 addition & 1 deletion third_party/torch_npu_ops
Submodule torch_npu_ops updated from 9dc44e to 599eb0
4 changes: 2 additions & 2 deletions xllm/core/distributed_runtime/llm_engine.cpp
@@ -613,9 +613,9 @@ bool LLMEngine::allocate_kv_cache(const Engine::KVCacheCapacity& kv_cache_cap) {
   if (enable_gdn_attention) {
     kv_cache_shape.emplace_back(std::vector<int64_t>{
         kv_cache_cap.num_linear_state_blocks,
-        args_.linear_conv_kernel_dim() - 1,
         args_.linear_key_head_dim() * n_local_linear_k_heads_ * 2 +
-            args_.linear_key_head_dim() * n_local_linear_v_heads_});
+            args_.linear_key_head_dim() * n_local_linear_v_heads_,
+        args_.linear_conv_kernel_dim() - 1});
     kv_cache_shape.emplace_back(
         std::vector<int64_t>{kv_cache_cap.num_linear_state_blocks,
                              n_local_linear_v_heads_,
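The change above only reorders dimensions: the kernel-window axis (linear_conv_kernel_dim() - 1) moves from the second to the last position of the conv-state block shape, matching the layout the updated kernel below consumes. A minimal sketch of the new shape; every concrete number is an illustrative assumption, not taken from any real config:

    #include <cstdint>
    #include <iostream>
    #include <vector>

    // Sketch of the reordered conv-state block shape; the values below are
    // hypothetical stand-ins for args_ / n_local_* / kv_cache_cap values.
    int main() {
      const int64_t num_linear_state_blocks = 1024;  // assumed
      const int64_t linear_key_head_dim = 128;       // assumed
      const int64_t n_local_linear_k_heads = 2;      // assumed
      const int64_t n_local_linear_v_heads = 1;      // assumed
      const int64_t linear_conv_kernel_dim = 4;      // assumed

      // New layout: [blocks, packed qkv dim, kernel window].
      std::vector<int64_t> shape{
          num_linear_state_blocks,
          linear_key_head_dim * n_local_linear_k_heads * 2 +
              linear_key_head_dim * n_local_linear_v_heads,
          linear_conv_kernel_dim - 1};
      for (auto d : shape) std::cout << d << ' ';  // prints: 1024 640 3
    }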
26 changes: 13 additions & 13 deletions xllm/core/kernels/ops_api.cpp
@@ -921,19 +921,19 @@ torch::Tensor causal_conv1d_update(CausalConv1dUpdateParams& params) {
     CHECK(params.conv_state_indices.value().is_contiguous())
         << "causal_conv1d_update: conv_state_indices must be contiguous.";
   }
-  return npu::npu_causal_conv1d_update(params.x,
-                                       params.conv_state,
-                                       params.weight,
-                                       params.activation,
-                                       params.bias,
-                                       params.cache_seqlens,
-                                       params.conv_state_indices,
-                                       params.num_accepted_tokens,
-                                       params.query_start_loc,
-                                       params.max_query_len,
-                                       params.intermediate_conv_window,
-                                       params.pad_slot_id,
-                                       params.validate_data);
+  return npu::npu_causal_conv1d_update_v2(params.x,
+                                          params.conv_state,
+                                          params.weight,
+                                          params.activation,
+                                          params.bias,
+                                          params.conv_state_indices,
+                                          params.query_start_loc,
+                                          params.max_query_len,
+                                          params.pad_slot_id,
+                                          params.block_idx_last_scheduled_token,
+                                          params.initial_state_idx,
+                                          params.validate_data);
 
 #else
   NOT_IMPLEMENTED();
 #endif
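The v2 call drops cache_seqlens, num_accepted_tokens, and intermediate_conv_window and threads through the two new per-request state fields. For orientation, here is a hypothetical declaration inferred purely from the arguments passed above; the real prototype lives in the bumped torch_npu_ops submodule and may differ:

    namespace npu {
    // Inferred from the call site above; this exact signature is an assumption.
    torch::Tensor npu_causal_conv1d_update_v2(
        const torch::Tensor& x,
        torch::Tensor& conv_state,
        const torch::Tensor& weight,
        bool activation,
        const std::optional<torch::Tensor>& bias,
        const std::optional<torch::Tensor>& conv_state_indices,
        const std::optional<torch::Tensor>& query_start_loc,
        int32_t max_query_len,
        int32_t pad_slot_id,
        const std::optional<torch::Tensor>& block_idx_last_scheduled_token,
        const std::optional<torch::Tensor>& initial_state_idx,
        bool validate_data);
    }  // namespace npu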
5 changes: 2 additions & 3 deletions xllm/core/kernels/param.h
@@ -1352,13 +1352,12 @@ struct CausalConv1dUpdateParams {
   torch::Tensor weight;
   bool activation = true;
   std::optional<torch::Tensor> bias = std::nullopt;
-  std::optional<torch::Tensor> cache_seqlens = std::nullopt;
   std::optional<torch::Tensor> conv_state_indices = std::nullopt;
-  std::optional<torch::Tensor> num_accepted_tokens = std::nullopt;
   std::optional<torch::Tensor> query_start_loc = std::nullopt;
   int32_t max_query_len = -1;
-  std::optional<torch::Tensor> intermediate_conv_window = std::nullopt;
   int32_t pad_slot_id = -1;
+  std::optional<torch::Tensor> block_idx_last_scheduled_token;
+  std::optional<torch::Tensor> initial_state_idx;
   bool validate_data = false;
 };
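Unlike the surviving optional fields, the two new members carry no explicit = std::nullopt initializer; a default-constructed std::optional is already empty, so callers that leave them unset pass no tensor. A one-line sanity check of that C++ guarantee (standalone, with int standing in for torch::Tensor):

    #include <cassert>
    #include <optional>

    int main() {
      std::optional<int> initial_state_idx;    // mirrors the bare member above
      assert(!initial_state_idx.has_value());  // empty without "= std::nullopt"
    }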

4 changes: 4 additions & 0 deletions xllm/core/layers/common/attention_metadata_builder.cpp
@@ -114,6 +114,10 @@ AttentionMetadata build_attention_metadata(
   }
   if (params.q_seq_lens.defined()) {
     attn_metadata.q_seq_lens = params.q_seq_lens;
+    torch::Tensor cumsum_tensor =
+        torch::cumsum(attn_metadata.q_seq_lens, 0).to(torch::kInt32);
+    auto zero = torch::zeros({1}, cumsum_tensor.options());
+    attn_metadata.q_cu_seq_lens = torch::cat({zero, cumsum_tensor}, 0);
   }
 #endif
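The new q_cu_seq_lens is the exclusive prefix sum of the per-request query lengths: for q_seq_lens = [3, 1, 5] it is [0, 3, 4, 9], so request i owns token rows [q_cu_seq_lens[i], q_cu_seq_lens[i+1]). A standalone libtorch sketch of the same computation:

    #include <torch/torch.h>
    #include <iostream>

    int main() {
      auto q_seq_lens = torch::tensor({3, 1, 5}, torch::kInt32);
      auto cumsum = torch::cumsum(q_seq_lens, 0).to(torch::kInt32);
      auto zero = torch::zeros({1}, cumsum.options());
      // Prepend 0 to turn the inclusive cumsum into offsets: [0, 3, 4, 9].
      auto q_cu_seq_lens = torch::cat({zero, cumsum}, 0);
      std::cout << q_cu_seq_lens << '\n';
    }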

24 changes: 17 additions & 7 deletions xllm/core/layers/npu_torch/qwen3_gated_delta_net_base.cpp
@@ -360,15 +360,16 @@ torch::Tensor Qwen3GatedDeltaNetBaseImpl::forward(
   auto conv_weight = conv1d_->weight();
   auto linear_state_indices = get_linear_state_indices(input_params, device);
 
-  mixed_qkv = mixed_qkv.transpose(1, 2);
   if (attn_metadata.is_prefill) {
+    mixed_qkv = mixed_qkv.transpose(1, 2);
     torch::Tensor conv_state =
         (seq_len < conv_kernel_size_ - 1)
             ? torch::pad(mixed_qkv, {0, conv_kernel_size_ - 1 - seq_len})
             : (seq_len > conv_kernel_size_ - 1)
                   ? mixed_qkv.narrow(
                         -1, seq_len - conv_kernel_size_ + 1, conv_kernel_size_ - 1)
                   : mixed_qkv;
     conv_state = conv_state.transpose(1, 2).contiguous();
     conv_cache.index_put_({linear_state_indices},
                           conv_state.to(conv_cache.dtype()));
     torch::Tensor bias;
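The nested ternary above keeps exactly conv_kernel_size_ - 1 trailing time steps as the cached conv state: shorter prompts are right-padded with zeros, longer ones keep only the last window. A standalone check with assumed shapes (torch::constant_pad_nd is the stable equivalent of the torch::pad call above):

    #include <torch/torch.h>
    #include <iostream>

    int main() {
      const int64_t conv_kernel_size = 4;        // conv_kernel_size_ (assumed)
      auto mixed_qkv = torch::randn({1, 6, 2});  // [batch, dim, seq_len], seq_len = 2
      const int64_t seq_len = mixed_qkv.size(-1);
      torch::Tensor conv_state =
          (seq_len < conv_kernel_size - 1)
              ? torch::constant_pad_nd(mixed_qkv,
                                       {0, conv_kernel_size - 1 - seq_len})
              : (seq_len > conv_kernel_size - 1)
                    ? mixed_qkv.narrow(
                          -1, seq_len - conv_kernel_size + 1, conv_kernel_size - 1)
                    : mixed_qkv;
      std::cout << conv_state.sizes() << '\n';  // [1, 6, 3]: zero-padded window
    }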
@@ -383,12 +383,21 @@ torch::Tensor Qwen3GatedDeltaNetBaseImpl::forward(
     mixed_qkv = torch::silu(conv_output.slice(2, 0, seq_len));
 
   } else {
-    xllm::kernel::CausalConv1dUpdateParams params;
-    params.x = mixed_qkv;
-    params.conv_state = conv_cache;
-    params.weight = conv_weight;
-    params.conv_state_indices = linear_state_indices;
-    mixed_qkv = xllm::kernel::causal_conv1d_update(params);
+    xllm::kernel::CausalConv1dUpdateParams conv1d_params;
+    conv1d_params.x = mixed_qkv.reshape({-1, mixed_qkv.size(-1)});
+    conv1d_params.conv_state = conv_cache;
+    conv1d_params.weight = conv_weight;
+    conv1d_params.conv_state_indices = linear_state_indices;
+    conv1d_params.block_idx_last_scheduled_token =
+        std::optional<torch::Tensor>();
+    conv1d_params.initial_state_idx = std::optional<torch::Tensor>();
+    conv1d_params.query_start_loc = attn_metadata.q_cu_seq_lens;
+    conv1d_params.max_query_len = attn_metadata.max_query_len;
+    mixed_qkv = xllm::kernel::causal_conv1d_update(conv1d_params);
+    // Reshape back to 3D, then transpose to [batch_size, dim, seq_len].
+    mixed_qkv =
+        mixed_qkv.view({batch_size, -1, mixed_qkv.size(-1)}).contiguous();
+    mixed_qkv = mixed_qkv.transpose(1, 2);
   }
 
   // Compute gated delta net decay and beta terms.
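For the decode branch, the shape flow around the kernel call is worth spelling out. Assuming mixed_qkv enters as [batch_size, seq_len, dim], the concrete numbers below are illustrative (batch 8, one token per sequence, packed qkv dim 640):

    // mixed_qkv on entry : [batch_size, seq_len, dim]   e.g. [8, 1, 640]
    // reshape            : [batch_size * seq_len, dim]  e.g. [8, 640]  (2D kernel input)
    // kernel output      : [batch_size * seq_len, dim]  e.g. [8, 640]
    // view               : [batch_size, seq_len, dim]   e.g. [8, 1, 640]
    // transpose(1, 2)    : [batch_size, dim, seq_len]   e.g. [8, 640, 1]  (matches prefill)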