Skip to content

Commit ab4ed2e

Browse files
committed
update conv1d_update op
1 parent 7eff2db commit ab4ed2e

5 files changed

Lines changed: 38 additions & 25 deletions

File tree

xllm/core/distributed_runtime/llm_engine.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -613,9 +613,9 @@ bool LLMEngine::allocate_kv_cache(const Engine::KVCacheCapacity& kv_cache_cap) {
613613
if (enable_gdn_attention) {
614614
kv_cache_shape.emplace_back(std::vector<int64_t>{
615615
kv_cache_cap.num_linear_state_blocks,
616+
args_.linear_conv_kernel_dim() - 1,
616617
args_.linear_key_head_dim() * n_local_linear_k_heads_ * 2 +
617-
args_.linear_key_head_dim() * n_local_linear_v_heads_,
618-
args_.linear_conv_kernel_dim() - 1});
618+
args_.linear_key_head_dim() * n_local_linear_v_heads_});
619619
kv_cache_shape.emplace_back(
620620
std::vector<int64_t>{kv_cache_cap.num_linear_state_blocks,
621621
n_local_linear_v_heads_,

xllm/core/kernels/ops_api.cpp

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -921,19 +921,19 @@ torch::Tensor causal_conv1d_update(CausalConv1dUpdateParams& params) {
921921
CHECK(params.conv_state_indices.value().is_contiguous())
922922
<< "causal_conv1d_update: conv_state_indices must be contiguous.";
923923
}
924-
return npu::npu_causal_conv1d_update(params.x,
925-
params.conv_state,
926-
params.weight,
927-
params.activation,
928-
params.bias,
929-
params.cache_seqlens,
930-
params.conv_state_indices,
931-
params.num_accepted_tokens,
932-
params.query_start_loc,
933-
params.max_query_len,
934-
params.intermediate_conv_window,
935-
params.pad_slot_id,
936-
params.validate_data);
924+
return npu::npu_causal_conv1d_update_v2(params.x,
925+
params.conv_state,
926+
params.weight,
927+
params.activation,
928+
params.bias,
929+
params.conv_state_indices,
930+
params.query_start_loc,
931+
params.max_query_len,
932+
params.pad_slot_id,
933+
params.block_idx_last_scheduled_token,
934+
params.initial_state_idx,
935+
params.validate_data);
936+
937937
#else
938938
NOT_IMPLEMENTED();
939939
#endif

xllm/core/kernels/param.h

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1352,13 +1352,12 @@ struct CausalConv1dUpdateParams {
13521352
torch::Tensor weight;
13531353
bool activation = true;
13541354
std::optional<torch::Tensor> bias = std::nullopt;
1355-
std::optional<torch::Tensor> cache_seqlens = std::nullopt;
13561355
std::optional<torch::Tensor> conv_state_indices = std::nullopt;
1357-
std::optional<torch::Tensor> num_accepted_tokens = std::nullopt;
13581356
std::optional<torch::Tensor> query_start_loc = std::nullopt;
13591357
int32_t max_query_len = -1;
1360-
std::optional<torch::Tensor> intermediate_conv_window = std::nullopt;
13611358
int32_t pad_slot_id = -1;
1359+
std::optional<torch::Tensor> block_idx_last_scheduled_token;
1360+
std::optional<torch::Tensor> initial_state_idx;
13621361
bool validate_data = false;
13631362
};
13641363

xllm/core/layers/common/attention_metadata_builder.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,10 @@ AttentionMetadata build_attention_metadata(
114114
}
115115
if (params.q_seq_lens.defined()) {
116116
attn_metadata.q_seq_lens = params.q_seq_lens;
117+
torch::Tensor cumsum_tensor =
118+
torch::cumsum(attn_metadata.q_seq_lens, 0).to(torch::kInt32);
119+
auto zero = torch::zeros({1}, cumsum_tensor.options());
120+
attn_metadata.q_cu_seq_lens = torch::cat({zero, cumsum_tensor}, 0);
117121
}
118122
#endif
119123

xllm/core/layers/npu_torch/qwen3_gated_delta_net_base.cpp

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -360,15 +360,16 @@ torch::Tensor Qwen3GatedDeltaNetBaseImpl::forward(
360360
auto conv_weight = conv1d_->weight();
361361
auto linear_state_indices = get_linear_state_indices(input_params, device);
362362

363-
mixed_qkv = mixed_qkv.transpose(1, 2);
364363
if (attn_metadata.is_prefill) {
364+
mixed_qkv = mixed_qkv.transpose(1, 2);
365365
torch::Tensor conv_state =
366366
(seq_len < conv_kernel_size_ - 1)
367367
? torch::pad(mixed_qkv, {0, conv_kernel_size_ - 1 - seq_len})
368368
: (seq_len > conv_kernel_size_ - 1)
369369
? mixed_qkv.narrow(
370370
-1, seq_len - conv_kernel_size_ + 1, conv_kernel_size_ - 1)
371371
: mixed_qkv;
372+
conv_state = conv_state.transpose(1, 2).contiguous();
372373
conv_cache.index_put_({linear_state_indices},
373374
conv_state.to(conv_cache.dtype()));
374375
torch::Tensor bias;
@@ -383,12 +384,21 @@ torch::Tensor Qwen3GatedDeltaNetBaseImpl::forward(
383384
mixed_qkv = torch::silu(conv_output.slice(2, 0, seq_len));
384385

385386
} else {
386-
xllm::kernel::CausalConv1dUpdateParams params;
387-
params.x = mixed_qkv;
388-
params.conv_state = conv_cache;
389-
params.weight = conv_weight;
390-
params.conv_state_indices = linear_state_indices;
391-
mixed_qkv = xllm::kernel::causal_conv1d_update(params);
387+
xllm::kernel::CausalConv1dUpdateParams conv1d_params;
388+
conv1d_params.x = mixed_qkv.reshape({-1, mixed_qkv.size(-1)});
389+
conv1d_params.conv_state = conv_cache;
390+
conv1d_params.weight = conv_weight;
391+
conv1d_params.conv_state_indices = linear_state_indices;
392+
conv1d_params.block_idx_last_scheduled_token =
393+
std::optional<torch::Tensor>();
394+
conv1d_params.initial_state_idx = std::optional<torch::Tensor>();
395+
conv1d_params.query_start_loc = attn_metadata.q_cu_seq_lens;
396+
conv1d_params.max_query_len = attn_metadata.max_query_len;
397+
mixed_qkv = xllm::kernel::causal_conv1d_update(conv1d_params);
398+
// Restore 3D shape [batch_size, seq_len, dim]; the transpose below yields [batch_size, dim, seq_len]
399+
mixed_qkv =
400+
mixed_qkv.view({batch_size, -1, mixed_qkv.size(-1)}).contiguous();
401+
mixed_qkv = mixed_qkv.transpose(1, 2);
392402
}
393403

394404
// Compute gated delta net decay and beta terms.

0 commit comments

Comments
 (0)