@@ -929,17 +929,8 @@ std::optional<ForwardOutput> RecWorkerImpl::LlmRecMultiRoundPipeline::step(
929929 << " )" ;
930930
931931#if defined(USE_NPU)
932- if (round == 0 ) {
933- // NPU beam_search_rec prefill only accepts top_k == 1. Flatten the
934- // sampler's per-request top-k output into one candidate per beam so
935- // the kernel can seed sequence_group without discarding beam diversity.
936- top_tokens =
937- sample_output.top_tokens .to (torch::kInt32 ).reshape ({-1 , 1 });
938- top_logprobs = sample_output.top_logprobs .reshape ({-1 , 1 });
939- } else {
940- top_tokens = sample_output.top_tokens .to (torch::kInt32 );
941- top_logprobs = sample_output.top_logprobs ;
942- }
932+ top_tokens = sample_output.top_tokens .to (torch::kInt32 );
933+ top_logprobs = sample_output.top_logprobs ;
943934#else
944935 top_tokens = sample_output.top_tokens .to (torch::kInt32 )
945936 .reshape ({-1 , step_meta->beam_width });
@@ -997,18 +988,25 @@ void RecWorkerImpl::LlmRecMultiRoundPipeline::execute_beam_search(
997988 int32_t round,
998989 int32_t batch_size) {
999990#if defined(USE_NPU)
1000- xllm::kernel::npu::beam_search_rec (
1001- /* logprobs=*/ beam_tensors.acc_logprob ,
1002- /* top_tokens=*/ top_tokens.to (torch::kInt32 ),
1003- /* top_logprobs=*/ top_logprobs,
1004- /* sequence_group=*/ beam_tensors.sequence_group ,
1005- /* current_step=*/ static_cast <int64_t >(round),
1006- /* out_token_ids=*/ beam_tensors.out_token_ids ,
1007- /* out_token_index=*/ beam_tensors.out_token_index ,
1008- /* out_log_probs=*/ beam_tensors.out_log_probs ,
1009- /* out_beam_count_prefix_sums=*/
1010- beam_tensors.out_beam_count_prefix_sums ,
1011- /* out_sequence=*/ beam_tensors.out_seqgroup );
991+ if (round == 0 ) {
992+ beam_tensors.out_token_ids .copy_ (top_tokens.reshape ({-1 , 1 }));
993+ beam_tensors.out_log_probs .copy_ (top_logprobs.reshape ({-1 , 1 }));
994+ beam_tensors.out_seqgroup .select (/* dim=*/ 2 , /* index=*/ 0 )
995+ .copy_ (top_tokens.reshape ({-1 }));
996+ } else {
997+ xllm::kernel::npu::beam_search_rec (
998+ /* logprobs=*/ beam_tensors.acc_logprob ,
999+ /* top_tokens=*/ top_tokens,
1000+ /* top_logprobs=*/ top_logprobs,
1001+ /* sequence_group=*/ beam_tensors.sequence_group ,
1002+ /* current_step=*/ static_cast <int64_t >(round),
1003+ /* out_token_ids=*/ beam_tensors.out_token_ids ,
1004+ /* out_token_index=*/ beam_tensors.out_token_index ,
1005+ /* out_log_probs=*/ beam_tensors.out_log_probs ,
1006+ /* out_beam_count_prefix_sums=*/
1007+ beam_tensors.out_beam_count_prefix_sums ,
1008+ /* out_sequence=*/ beam_tensors.out_seqgroup );
1009+ }
10121010#elif defined(USE_CUDA)
10131011 xllm::kernel::cuda::beam_search (beam_tensors.acc_logprob ,
10141012 beam_tensors.sequence_group ,
0 commit comments