diff --git a/xllm/c_api/default.h b/xllm/c_api/default.h index 4a40dc3f7..fdde1fbea 100644 --- a/xllm/c_api/default.h +++ b/xllm/c_api/default.h @@ -72,6 +72,7 @@ const XLLM_RequestParams XLLM_LLM_REQUEST_PARAMS_DEFAULT = { .ttft_slo_ms = INT32_MAX, .tpot_slo_ms = INT32_MAX, .beam_width = 0, + .num_return_sequences = 0, .top_logprobs = 0, .top_k = -1, .top_p = 1.0, @@ -134,6 +135,7 @@ const XLLM_RequestParams XLLM_REC_REQUEST_PARAMS_DEFAULT = { .ttft_slo_ms = INT32_MAX, .tpot_slo_ms = INT32_MAX, .beam_width = 128, + .num_return_sequences = 0, .top_logprobs = 0, .top_k = -1, .top_p = 1.0, diff --git a/xllm/c_api/internal/helper.cpp b/xllm/c_api/internal/helper.cpp index d7f16b35e..0383a13b3 100644 --- a/xllm/c_api/internal/helper.cpp +++ b/xllm/c_api/internal/helper.cpp @@ -127,6 +127,8 @@ void transfer_request_params(InferenceType inference_type, xllm_request_params->repetition_penalty = final_request_params.repetition_penalty; xllm_request_params->beam_width = final_request_params.beam_width; + xllm_request_params->num_return_sequences = + final_request_params.num_return_sequences; xllm_request_params->top_logprobs = final_request_params.top_logprobs; xllm_request_params->temperature = final_request_params.temperature; xllm_request_params->request_id = final_request_params.request_id; diff --git a/xllm/c_api/types.h b/xllm/c_api/types.h index e95f1b125..8c6b5c00f 100644 --- a/xllm/c_api/types.h +++ b/xllm/c_api/types.h @@ -200,6 +200,9 @@ typedef struct XLLM_CAPI_EXPORT XLLM_RequestParams { /** Beam search width (0 = disable beam search) */ uint32_t beam_width; + /** Final number of beam search results to return (0 = use beam_width) */ + uint32_t num_return_sequences; + /** Number of top log probabilities to return */ int64_t top_logprobs; diff --git a/xllm/cc_api/internal.h b/xllm/cc_api/internal.h index 4bfb49e78..292cc1e5f 100644 --- a/xllm/cc_api/internal.h +++ b/xllm/cc_api/internal.h @@ -70,6 +70,8 @@ RequestParams transfer_request_params( 
xllm_request_params.stop = request_params.stop; xllm_request_params.stop_token_ids = request_params.stop_token_ids; xllm_request_params.beam_width = request_params.beam_width; + xllm_request_params.num_return_sequences = + request_params.num_return_sequences; xllm_request_params.top_logprobs = request_params.top_logprobs; return xllm_request_params; diff --git a/xllm/cc_api/types.h b/xllm/cc_api/types.h index ae62ec916..5fd1a1bf1 100644 --- a/xllm/cc_api/types.h +++ b/xllm/cc_api/types.h @@ -182,6 +182,8 @@ struct XLLM_CAPI_EXPORT XLLM_RequestParams { int32_t beam_width = 0; + int32_t num_return_sequences = 0; + // Number of top log probabilities to return. default = 0. int64_t top_logprobs = 0; diff --git a/xllm/core/distributed_runtime/rec_engine.cpp b/xllm/core/distributed_runtime/rec_engine.cpp index e9f02e989..861908462 100644 --- a/xllm/core/distributed_runtime/rec_engine.cpp +++ b/xllm/core/distributed_runtime/rec_engine.cpp @@ -675,7 +675,10 @@ ForwardOutput RecEngine::OneRecEnginePipeline::step( } timer.reset(); - batches[0].process_sample_output(decode_output.sample_output, false); + batches[0].process_sample_output( + decode_output.sample_output, + false, + /*force_requested_beam_result_size=*/i + 1 == kRecDecodeSteps); COUNTER_ADD(rec_sampling_latency_microseconds, timer.elapsed_microseconds()); } diff --git a/xllm/core/distributed_runtime/rec_master.cpp b/xllm/core/distributed_runtime/rec_master.cpp index fafb3ff16..3dc4ef02c 100644 --- a/xllm/core/distributed_runtime/rec_master.cpp +++ b/xllm/core/distributed_runtime/rec_master.cpp @@ -754,6 +754,7 @@ std::shared_ptr RecMaster::build_request_common( sampling_param.top_logprobs = sp.top_logprobs; sampling_param.is_embeddings = sp.is_embeddings; sampling_param.beam_width = sp.beam_width; + sampling_param.num_return_sequences = sp.num_return_sequences; if (best_of > sp.n) { sampling_param.logprobs = true; } diff --git a/xllm/core/framework/batch/batch.cpp b/xllm/core/framework/batch/batch.cpp index 
bda79bc0d..e2691374e 100644 --- a/xllm/core/framework/batch/batch.cpp +++ b/xllm/core/framework/batch/batch.cpp @@ -619,7 +619,8 @@ void Batch::process_beam_sequence_group(const ForwardOutput& output) { } void Batch::process_sample_output(const SampleOutput& sample_output, - bool replace_fake_token) { + bool replace_fake_token, + bool force_requested_beam_result_size) { if (sample_output.embeddings.defined()) { const int64_t num_seqs = sample_output.embeddings.size(0); int64_t output_idx = 0; @@ -673,7 +674,7 @@ void Batch::process_sample_output(const SampleOutput& sample_output, } if (!FLAGS_enable_schedule_overlap || replace_fake_token) { - process_beam_search(); + process_beam_search(force_requested_beam_result_size); } } @@ -721,9 +722,9 @@ void Batch::append_token_for_sequence(Sequence* seq, } } -void Batch::process_beam_search() { +void Batch::process_beam_search(bool force_requested_result_size) { for (auto* sequence_group : sequence_groups_) { - sequence_group->process_beam_search(); + sequence_group->process_beam_search(force_requested_result_size); } } diff --git a/xllm/core/framework/batch/batch.h b/xllm/core/framework/batch/batch.h index a14167b37..1843bb3c4 100644 --- a/xllm/core/framework/batch/batch.h +++ b/xllm/core/framework/batch/batch.h @@ -107,7 +107,8 @@ class Batch { // The boolean parameter `replace_fake_token` indicates // whether the current stage is the second stage. 
void process_sample_output(const SampleOutput& sample_output, - bool replace_fake_token); + bool replace_fake_token, + bool force_requested_beam_result_size = false); void process_sample_output(const RawForwardOutput& raw_output, bool replace_fake_token); @@ -156,7 +157,7 @@ class Batch { int token_idx, bool replace_fake_token); - void process_beam_search(); + void process_beam_search(bool force_requested_result_size = false); bool has_partial_finished_beam_group() const; std::unordered_map cal_seq_exchange_index( diff --git a/xllm/core/framework/request/request_params.cpp b/xllm/core/framework/request/request_params.cpp index 00f3e6be7..835bf8252 100644 --- a/xllm/core/framework/request/request_params.cpp +++ b/xllm/core/framework/request/request_params.cpp @@ -222,6 +222,9 @@ RequestParams::RequestParams(const proto::CompletionRequest& request, if (request.has_beam_width()) { beam_width = request.beam_width(); } + if (request.has_num_return_sequences()) { + num_return_sequences = request.num_return_sequences(); + } if (request.has_add_special_tokens()) { add_special_tokens = request.add_special_tokens(); } else { @@ -414,6 +417,9 @@ void init_from_chat_request(RequestParams& params, const ChatRequest& request) { if (request.has_beam_width()) { params.beam_width = request.beam_width(); } + if (request.has_num_return_sequences()) { + params.num_return_sequences = request.num_return_sequences(); + } if (request.has_add_special_tokens()) { params.add_special_tokens = request.add_special_tokens(); @@ -576,6 +582,24 @@ bool RequestParams::verify_params(OutputCallback callback) const { return false; } + if (num_return_sequences > 0) { + if (beam_width <= 0) { CALLBACK_WITH_ERROR(StatusCode::INVALID_ARGUMENT, "num_return_sequences requires beam_width > 0", service_request_id, source_xservice_addr); return false; } + if (num_return_sequences < beam_width) { + CALLBACK_WITH_ERROR( + StatusCode::INVALID_ARGUMENT, + "num_return_sequences must be greater than or "
"equal to beam_width", service_request_id, source_xservice_addr); return false; } } + if (logprobs) { if (echo) { CALLBACK_WITH_ERROR(StatusCode::INVALID_ARGUMENT, diff --git a/xllm/core/framework/request/request_params.h b/xllm/core/framework/request/request_params.h index 4b249f62d..189638c83 100644 --- a/xllm/core/framework/request/request_params.h +++ b/xllm/core/framework/request/request_params.h @@ -161,6 +161,7 @@ struct RequestParams { // beam search int32_t beam_width = 0; + int32_t num_return_sequences = 0; bool add_special_tokens = false; diff --git a/xllm/core/framework/request/sequences_group.cpp b/xllm/core/framework/request/sequences_group.cpp index cbdea80e3..21894187b 100644 --- a/xllm/core/framework/request/sequences_group.cpp +++ b/xllm/core/framework/request/sequences_group.cpp @@ -219,7 +219,7 @@ void SequencesGroup::generate_outputs_parallel( } } -void SequencesGroup::process_beam_search() { +void SequencesGroup::process_beam_search(bool force_requested_result_size) { if (!check_beam_search()) { return; } @@ -229,6 +229,12 @@ void SequencesGroup::process_beam_search() { } const size_t beam_width = sequence_params_.sampling_param->beam_width; + const int32_t requested_num_return_sequences = + sequence_params_.sampling_param->num_return_sequences > 0 + ? sequence_params_.sampling_param->num_return_sequences + : sequence_params_.sampling_param->beam_width; + const size_t requested_result_size = + static_cast<size_t>(requested_num_return_sequences); const size_t topk = std::max<size_t>(1, sequence_params_.sampling_param->top_logprobs); @@ -261,7 +267,9 @@ void SequencesGroup::process_beam_search() { build_source_info(sequences_[seq_index].get()); } - SimpleTopKOptimizer<BeamCandidate> topk_optimizer(beam_width); + const size_t target_result_size = + force_requested_result_size ? 
requested_result_size : beam_width; + SimpleTopKOptimizer<BeamCandidate> topk_optimizer(target_result_size); auto add_self_candidate = [&](size_t seq_index, Sequence* seq) { BeamCandidate candidate; candidate.source_index = seq_index; @@ -324,15 +332,22 @@ void SequencesGroup::process_beam_search() { return; } - const size_t result_size = std::min(beam_width, candidates.size()); + const size_t result_size = std::min(target_result_size, candidates.size()); CHECK(!sequences_.empty()); - std::vector<std::unique_ptr<Sequence>> replacement_sequences(result_size); + const size_t existing_size = sequences_.size(); + const size_t existing_result_size = std::min(result_size, existing_size); + std::vector<std::unique_ptr<Sequence>> replacement_sequences( + existing_result_size); + std::vector<std::unique_ptr<Sequence>> tail_sequences; + if (result_size > existing_size) { + tail_sequences.reserve(result_size - existing_size); + } for (size_t i = 0; i < result_size; ++i) { const BeamCandidate& candidate = candidates[i]; const BeamSourceInfo& source_info = source_infos[candidate.source_index]; const bool need_replace = - i >= sequences_.size() || sequences_[i] == nullptr || + i >= existing_size || sequences_[i] == nullptr || sequences_[i]->num_prompt_tokens() != source_info.suffix_start_idx || sequences_[i]->num_tokens() - source_info.suffix_start_idx != source_info.generated_token_ids.size(); @@ -342,19 +357,28 @@ void SequencesGroup::process_beam_search() { CHECK_LT(candidate.source_index, sequences_.size()); CHECK(sequences_[candidate.source_index] != nullptr); - replacement_sequences[i] = - std::make_unique<Sequence>(*sequences_[candidate.source_index]); + if (i < existing_size) { + replacement_sequences[i] = + std::make_unique<Sequence>(*sequences_[candidate.source_index]); + } else { + tail_sequences.emplace_back( + std::make_unique<Sequence>(*sequences_[candidate.source_index])); + } } - if (sequences_.size() < result_size) { - sequences_.resize(result_size); + if (existing_size < result_size) { + sequences_.reserve(result_size); + for (auto& tail_sequence : tail_sequences) { + 
sequences_.emplace_back(std::move(tail_sequence)); + } } std::unordered_set<size_t> reused_src; for (size_t i = 0; i < result_size; ++i) { const BeamCandidate& candidate = candidates[i]; const BeamSourceInfo& source_info = source_infos[candidate.source_index]; - if (replacement_sequences[i] != nullptr) { + if (i < replacement_sequences.size() && + replacement_sequences[i] != nullptr) { sequences_[i] = std::move(replacement_sequences[i]); } auto& next_seq = sequences_[i]; diff --git a/xllm/core/framework/request/sequences_group.h b/xllm/core/framework/request/sequences_group.h index c6227221a..934e7101e 100644 --- a/xllm/core/framework/request/sequences_group.h +++ b/xllm/core/framework/request/sequences_group.h @@ -48,7 +48,7 @@ class SequencesGroup { const Tokenizer& tokenizer, ThreadPool* thread_pool = nullptr); - void process_beam_search(); + void process_beam_search(bool force_requested_result_size = false); bool check_beam_search() { return sequence_params_.sampling_param->beam_width > 1; } diff --git a/xllm/core/framework/sampling/sampling_params.h b/xllm/core/framework/sampling/sampling_params.h index 0799e79f0..6fc0df13b 100644 --- a/xllm/core/framework/sampling/sampling_params.h +++ b/xllm/core/framework/sampling/sampling_params.h @@ -36,6 +36,7 @@ struct RequestSamplingParam { bool do_sample = false; bool is_embeddings = false; int32_t beam_width = 0; + int32_t num_return_sequences = 0; }; struct SamplingParameters { diff --git a/xllm/proto/chat.proto b/xllm/proto/chat.proto index 61b6167a5..b948365c6 100644 --- a/xllm/proto/chat.proto +++ b/xllm/proto/chat.proto @@ -26,7 +26,7 @@ message ChatMessage { } -// Next Id: 30 +// Next Id: 46 message ChatRequest { // ID of the model to use. You can use the ListModels endpoint to list available models. string model = 1; @@ -129,6 +129,9 @@ message ChatRequest { optional int32 beam_width = 33; + // Final number of beam results to return. Defaults to beam_width when unset. 
+ optional int32 num_return_sequences = 45; + optional bool add_special_tokens = 34; optional google.protobuf.Struct chat_template_kwargs = 35; diff --git a/xllm/proto/completion.proto b/xllm/proto/completion.proto index 3ed0a83cb..329fd8265 100644 --- a/xllm/proto/completion.proto +++ b/xllm/proto/completion.proto @@ -6,7 +6,7 @@ package xllm.proto; import "common.proto"; import "rec.proto"; -// Next ID: 26 +// Next ID: 42 message CompletionRequest { // ID of the model to use. (required) // You can use the ListModels endpoint to list available models. @@ -97,6 +97,9 @@ message CompletionRequest { optional int32 beam_width = 29; + // Final number of beam results to return. Defaults to beam_width when unset. + optional int32 num_return_sequences = 41; + optional bool add_special_tokens = 30; // tensor for rec embedding. repeated InferInputTensor input_tensors = 31; diff --git a/xllm/proto/multimodal.proto b/xllm/proto/multimodal.proto index e5c3f888a..724e28409 100644 --- a/xllm/proto/multimodal.proto +++ b/xllm/proto/multimodal.proto @@ -42,7 +42,7 @@ message MMChatMessage { optional string tool_call_id = 4; } -// Next Id: 27 +// Next Id: 43 message MMChatRequest { // ID of the model to use. You can use the ListModels endpoint to list available models. @@ -145,6 +145,9 @@ message MMChatRequest { optional int32 beam_width = 32; + // Final number of beam results to return. Defaults to beam_width when unset. 
+ optional int32 num_return_sequences = 42; + optional bool add_special_tokens = 33; optional google.protobuf.Struct chat_template_kwargs = 34; diff --git a/xllm/pybind/bind.cpp b/xllm/pybind/bind.cpp index e404f274d..0b22113e9 100644 --- a/xllm/pybind/bind.cpp +++ b/xllm/pybind/bind.cpp @@ -195,6 +195,8 @@ PYBIND11_MODULE(xllm_export, m) { .def_readwrite("stop", &RequestParams::stop) .def_readwrite("stop_token_ids", &RequestParams::stop_token_ids) .def_readwrite("beam_width", &RequestParams::beam_width) + .def_readwrite("num_return_sequences", + &RequestParams::num_return_sequences) .def_readwrite("add_special_tokens", &RequestParams::add_special_tokens) .def_readwrite("is_sample_request", &RequestParams::is_sample_request) .def_readwrite("sample_slots", &RequestParams::sample_slots);