@@ -360,15 +360,16 @@ torch::Tensor Qwen3GatedDeltaNetBaseImpl::forward(
360360 auto conv_weight = conv1d_->weight ();
361361 auto linear_state_indices = get_linear_state_indices (input_params, device);
362362
363- mixed_qkv = mixed_qkv.transpose (1 , 2 );
364363 if (attn_metadata.is_prefill ) {
364+ mixed_qkv = mixed_qkv.transpose (1 , 2 );
365365 torch::Tensor conv_state =
366366 (seq_len < conv_kernel_size_ - 1 )
367367 ? torch::pad (mixed_qkv, {0 , conv_kernel_size_ - 1 - seq_len})
368368 : (seq_len > conv_kernel_size_ - 1 )
369369 ? mixed_qkv.narrow (
370370 -1 , seq_len - conv_kernel_size_ + 1 , conv_kernel_size_ - 1 )
371371 : mixed_qkv;
372+ conv_state = conv_state.transpose (1 , 2 ).contiguous ();
372373 conv_cache.index_put_ ({linear_state_indices},
373374 conv_state.to (conv_cache.dtype ()));
374375 torch::Tensor bias;
@@ -383,12 +384,20 @@ torch::Tensor Qwen3GatedDeltaNetBaseImpl::forward(
383384 mixed_qkv = torch::silu (conv_output.slice (2 , 0 , seq_len));
384385
385386 } else {
386- xllm::kernel::CausalConv1dUpdateParams params;
387- params.x = mixed_qkv;
388- params.conv_state = conv_cache;
389- params.weight = conv_weight;
390- params.conv_state_indices = linear_state_indices;
391- mixed_qkv = xllm::kernel::causal_conv1d_update (params);
387+ xllm::kernel::CausalConv1dUpdateParams conv1d_params;
388+ conv1d_params.x = mixed_qkv.reshape ({-1 , mixed_qkv.size (-1 )});
389+ conv1d_params.conv_state = conv_cache;
390+ conv1d_params.weight = conv_weight;
391+ conv1d_params.conv_state_indices = linear_state_indices;
392+ conv1d_params.block_idx_last_scheduled_token =
393+ std::optional<torch::Tensor>();
394+ conv1d_params.initial_state_idx = std::optional<torch::Tensor>();
395+ conv1d_params.query_start_loc = attn_metadata.q_cu_seq_lens ;
396+ conv1d_params.max_query_len = attn_metadata.max_query_len ;
397+ mixed_qkv = xllm::kernel::causal_conv1d_update (conv1d_params);
398+ // Reshape back to 3D [batch_size, dim, seq_len]
399+ mixed_qkv = mixed_qkv.view ({batch_size, -1 , mixed_qkv.size (-1 )}).contiguous ();
400+ mixed_qkv = mixed_qkv.transpose (1 , 2 );
392401 }
393402
394403 // Compute gated delta net decay and beta terms.
0 commit comments