Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions xllm/core/distributed_runtime/llm_engine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -613,9 +613,9 @@ bool LLMEngine::allocate_kv_cache(const Engine::KVCacheCapacity& kv_cache_cap) {
if (enable_gdn_attention) {
kv_cache_shape.emplace_back(std::vector<int64_t>{
kv_cache_cap.num_linear_state_blocks,
args_.linear_conv_kernel_dim() - 1,
Comment thread
maojunx99 marked this conversation as resolved.
args_.linear_key_head_dim() * n_local_linear_k_heads_ * 2 +
args_.linear_key_head_dim() * n_local_linear_v_heads_,
args_.linear_conv_kernel_dim() - 1});
args_.linear_key_head_dim() * n_local_linear_v_heads_});
kv_cache_shape.emplace_back(
std::vector<int64_t>{kv_cache_cap.num_linear_state_blocks,
n_local_linear_v_heads_,
Expand Down
26 changes: 13 additions & 13 deletions xllm/core/kernels/ops_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -921,19 +921,19 @@ torch::Tensor causal_conv1d_update(CausalConv1dUpdateParams& params) {
CHECK(params.conv_state_indices.value().is_contiguous())
<< "causal_conv1d_update: conv_state_indices must be contiguous.";
}
return npu::npu_causal_conv1d_update(params.x,
params.conv_state,
params.weight,
params.activation,
params.bias,
params.cache_seqlens,
params.conv_state_indices,
params.num_accepted_tokens,
params.query_start_loc,
params.max_query_len,
params.intermediate_conv_window,
params.pad_slot_id,
params.validate_data);
return npu::npu_causal_conv1d_update_v2(params.x,
params.conv_state,
params.weight,
params.activation,
params.bias,
params.conv_state_indices,
params.query_start_loc,
params.max_query_len,
params.pad_slot_id,
params.block_idx_last_scheduled_token,
params.initial_state_idx,
params.validate_data);

#else
NOT_IMPLEMENTED();
#endif
Expand Down
5 changes: 2 additions & 3 deletions xllm/core/kernels/param.h
Original file line number Diff line number Diff line change
Expand Up @@ -1352,13 +1352,12 @@ struct CausalConv1dUpdateParams {
torch::Tensor weight;
bool activation = true;
std::optional<torch::Tensor> bias = std::nullopt;
std::optional<torch::Tensor> cache_seqlens = std::nullopt;
std::optional<torch::Tensor> conv_state_indices = std::nullopt;
std::optional<torch::Tensor> num_accepted_tokens = std::nullopt;
std::optional<torch::Tensor> query_start_loc = std::nullopt;
int32_t max_query_len = -1;
std::optional<torch::Tensor> intermediate_conv_window = std::nullopt;
int32_t pad_slot_id = -1;
std::optional<torch::Tensor> block_idx_last_scheduled_token;
std::optional<torch::Tensor> initial_state_idx;
bool validate_data = false;
};

Expand Down
4 changes: 4 additions & 0 deletions xllm/core/layers/common/attention_metadata_builder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,10 @@ AttentionMetadata build_attention_metadata(
}
if (params.q_seq_lens.defined()) {
attn_metadata.q_seq_lens = params.q_seq_lens;
torch::Tensor cumsum_tensor =
torch::cumsum(attn_metadata.q_seq_lens, 0).to(torch::kInt32);
auto zero = torch::zeros({1}, cumsum_tensor.options());
attn_metadata.q_cu_seq_lens = torch::cat({zero, cumsum_tensor}, 0);
}
#endif

Expand Down
24 changes: 17 additions & 7 deletions xllm/core/layers/npu_torch/qwen3_gated_delta_net_base.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -360,15 +360,16 @@ torch::Tensor Qwen3GatedDeltaNetBaseImpl::forward(
auto conv_weight = conv1d_->weight();
auto linear_state_indices = get_linear_state_indices(input_params, device);

mixed_qkv = mixed_qkv.transpose(1, 2);
if (attn_metadata.is_prefill) {
mixed_qkv = mixed_qkv.transpose(1, 2);
torch::Tensor conv_state =
(seq_len < conv_kernel_size_ - 1)
? torch::pad(mixed_qkv, {0, conv_kernel_size_ - 1 - seq_len})
: (seq_len > conv_kernel_size_ - 1)
? mixed_qkv.narrow(
-1, seq_len - conv_kernel_size_ + 1, conv_kernel_size_ - 1)
: mixed_qkv;
Comment thread
yingxudeng marked this conversation as resolved.
conv_state = conv_state.transpose(1, 2).contiguous();
conv_cache.index_put_({linear_state_indices},
conv_state.to(conv_cache.dtype()));
torch::Tensor bias;
Expand All @@ -383,12 +384,21 @@ torch::Tensor Qwen3GatedDeltaNetBaseImpl::forward(
mixed_qkv = torch::silu(conv_output.slice(2, 0, seq_len));
Comment thread
zhang-minchao marked this conversation as resolved.

} else {
xllm::kernel::CausalConv1dUpdateParams params;
params.x = mixed_qkv;
params.conv_state = conv_cache;
params.weight = conv_weight;
params.conv_state_indices = linear_state_indices;
mixed_qkv = xllm::kernel::causal_conv1d_update(params);
xllm::kernel::CausalConv1dUpdateParams conv1d_params;
conv1d_params.x = mixed_qkv.reshape({-1, mixed_qkv.size(-1)});
conv1d_params.conv_state = conv_cache;
conv1d_params.weight = conv_weight;
conv1d_params.conv_state_indices = linear_state_indices;
conv1d_params.block_idx_last_scheduled_token =
std::optional<torch::Tensor>();
conv1d_params.initial_state_idx = std::optional<torch::Tensor>();
conv1d_params.query_start_loc = attn_metadata.q_cu_seq_lens;
conv1d_params.max_query_len = attn_metadata.max_query_len;
mixed_qkv = xllm::kernel::causal_conv1d_update(conv1d_params);
// Reshape back to 3D [batch_size, dim, seq_len]
mixed_qkv =
mixed_qkv.view({batch_size, -1, mixed_qkv.size(-1)}).contiguous();
mixed_qkv = mixed_qkv.transpose(1, 2);
}

// Compute gated delta net decay and beta terms.
Expand Down
Loading