diff --git a/xllm/core/runtime/rec_worker_impl.cpp b/xllm/core/runtime/rec_worker_impl.cpp index dfe450493..e8979f9d1 100644 --- a/xllm/core/runtime/rec_worker_impl.cpp +++ b/xllm/core/runtime/rec_worker_impl.cpp @@ -929,17 +929,8 @@ std::optional RecWorkerImpl::LlmRecMultiRoundPipeline::step( << ")"; #if defined(USE_NPU) - if (round == 0) { - // NPU beam_search_rec prefill only accepts top_k == 1. Flatten the - // sampler's per-request top-k output into one candidate per beam so - // the kernel can seed sequence_group without discarding beam diversity. - top_tokens = - sample_output.top_tokens.to(torch::kInt32).reshape({-1, 1}); - top_logprobs = sample_output.top_logprobs.reshape({-1, 1}); - } else { - top_tokens = sample_output.top_tokens.to(torch::kInt32); - top_logprobs = sample_output.top_logprobs; - } + top_tokens = sample_output.top_tokens.to(torch::kInt32); + top_logprobs = sample_output.top_logprobs; #else top_tokens = sample_output.top_tokens.to(torch::kInt32) .reshape({-1, step_meta->beam_width}); @@ -997,18 +988,25 @@ void RecWorkerImpl::LlmRecMultiRoundPipeline::execute_beam_search( int32_t round, int32_t batch_size) { #if defined(USE_NPU) - xllm::kernel::npu::beam_search_rec( - /*logprobs=*/beam_tensors.acc_logprob, - /*top_tokens=*/top_tokens.to(torch::kInt32), - /*top_logprobs=*/top_logprobs, - /*sequence_group=*/beam_tensors.sequence_group, - /*current_step=*/static_cast(round), - /*out_token_ids=*/beam_tensors.out_token_ids, - /*out_token_index=*/beam_tensors.out_token_index, - /*out_log_probs=*/beam_tensors.out_log_probs, - /*out_beam_count_prefix_sums=*/ - beam_tensors.out_beam_count_prefix_sums, - /*out_sequence=*/beam_tensors.out_seqgroup); + if (round == 0) { + beam_tensors.out_token_ids.copy_(top_tokens.reshape({-1, 1})); + beam_tensors.out_log_probs.copy_(top_logprobs.reshape({-1, 1})); + beam_tensors.out_seqgroup.select(/*dim=*/2, /*index=*/0) + .copy_(top_tokens.reshape({-1})); + } else { + xllm::kernel::npu::beam_search_rec( + /*logprobs=*/beam_tensors.acc_logprob, + /*top_tokens=*/top_tokens, + /*top_logprobs=*/top_logprobs, + /*sequence_group=*/beam_tensors.sequence_group, + /*current_step=*/static_cast(round), + /*out_token_ids=*/beam_tensors.out_token_ids, + /*out_token_index=*/beam_tensors.out_token_index, + /*out_log_probs=*/beam_tensors.out_log_probs, + /*out_beam_count_prefix_sums=*/ + beam_tensors.out_beam_count_prefix_sums, + /*out_sequence=*/beam_tensors.out_seqgroup); + } #elif defined(USE_CUDA) xllm::kernel::cuda::beam_search(beam_tensors.acc_logprob, beam_tensors.sequence_group,