Skip to content

Commit 4e5a6de

Browse files
committed
bugfix: fix the accuracy error of NPU xattention.
1 parent 4fe61bd commit 4e5a6de

2 files changed

Lines changed: 22 additions & 24 deletions

File tree

third_party/xllm_ops

Submodule xllm_ops updated from d2236de to bca7c61

xllm/core/runtime/rec_worker_impl.cpp

Lines changed: 21 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -929,17 +929,8 @@ std::optional<ForwardOutput> RecWorkerImpl::LlmRecMultiRoundPipeline::step(
929929
<< ")";
930930

931931
#if defined(USE_NPU)
932-
if (round == 0) {
933-
// NPU beam_search_rec prefill only accepts top_k == 1. Flatten the
934-
// sampler's per-request top-k output into one candidate per beam so
935-
// the kernel can seed sequence_group without discarding beam diversity.
936-
top_tokens =
937-
sample_output.top_tokens.to(torch::kInt32).reshape({-1, 1});
938-
top_logprobs = sample_output.top_logprobs.reshape({-1, 1});
939-
} else {
940-
top_tokens = sample_output.top_tokens.to(torch::kInt32);
941-
top_logprobs = sample_output.top_logprobs;
942-
}
932+
top_tokens = sample_output.top_tokens.to(torch::kInt32);
933+
top_logprobs = sample_output.top_logprobs;
943934
#else
944935
top_tokens = sample_output.top_tokens.to(torch::kInt32)
945936
.reshape({-1, step_meta->beam_width});
@@ -997,18 +988,25 @@ void RecWorkerImpl::LlmRecMultiRoundPipeline::execute_beam_search(
997988
int32_t round,
998989
int32_t batch_size) {
999990
#if defined(USE_NPU)
1000-
xllm::kernel::npu::beam_search_rec(
1001-
/*logprobs=*/beam_tensors.acc_logprob,
1002-
/*top_tokens=*/top_tokens.to(torch::kInt32),
1003-
/*top_logprobs=*/top_logprobs,
1004-
/*sequence_group=*/beam_tensors.sequence_group,
1005-
/*current_step=*/static_cast<int64_t>(round),
1006-
/*out_token_ids=*/beam_tensors.out_token_ids,
1007-
/*out_token_index=*/beam_tensors.out_token_index,
1008-
/*out_log_probs=*/beam_tensors.out_log_probs,
1009-
/*out_beam_count_prefix_sums=*/
1010-
beam_tensors.out_beam_count_prefix_sums,
1011-
/*out_sequence=*/beam_tensors.out_seqgroup);
991+
if (round == 0) {
992+
beam_tensors.out_token_ids.copy_(top_tokens.reshape({-1, 1}));
993+
beam_tensors.out_log_probs.copy_(top_logprobs.reshape({-1, 1}));
994+
beam_tensors.out_seqgroup.select(/*dim=*/2, /*index=*/0)
995+
.copy_(top_tokens.reshape({-1}));
996+
} else {
997+
xllm::kernel::npu::beam_search_rec(
998+
/*logprobs=*/beam_tensors.acc_logprob,
999+
/*top_tokens=*/top_tokens,
1000+
/*top_logprobs=*/top_logprobs,
1001+
/*sequence_group=*/beam_tensors.sequence_group,
1002+
/*current_step=*/static_cast<int64_t>(round),
1003+
/*out_token_ids=*/beam_tensors.out_token_ids,
1004+
/*out_token_index=*/beam_tensors.out_token_index,
1005+
/*out_log_probs=*/beam_tensors.out_log_probs,
1006+
/*out_beam_count_prefix_sums=*/
1007+
beam_tensors.out_beam_count_prefix_sums,
1008+
/*out_sequence=*/beam_tensors.out_seqgroup);
1009+
}
10121010
#elif defined(USE_CUDA)
10131011
xllm::kernel::cuda::beam_search(beam_tensors.acc_logprob,
10141012
beam_tensors.sequence_group,

0 commit comments

Comments
 (0)