Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions xllm/c_api/default.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ const XLLM_RequestParams XLLM_LLM_REQUEST_PARAMS_DEFAULT = {
.ttft_slo_ms = INT32_MAX,
.tpot_slo_ms = INT32_MAX,
.beam_width = 0,
.num_return_sequences = 0,
.top_logprobs = 0,
.top_k = -1,
.top_p = 1.0,
Expand Down Expand Up @@ -134,6 +135,7 @@ const XLLM_RequestParams XLLM_REC_REQUEST_PARAMS_DEFAULT = {
.ttft_slo_ms = INT32_MAX,
.tpot_slo_ms = INT32_MAX,
.beam_width = 128,
.num_return_sequences = 0,
.top_logprobs = 0,
.top_k = -1,
.top_p = 1.0,
Expand Down
2 changes: 2 additions & 0 deletions xllm/c_api/internal/helper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,8 @@ void transfer_request_params(InferenceType inference_type,
xllm_request_params->repetition_penalty =
final_request_params.repetition_penalty;
xllm_request_params->beam_width = final_request_params.beam_width;
xllm_request_params->num_return_sequences =
final_request_params.num_return_sequences;
xllm_request_params->top_logprobs = final_request_params.top_logprobs;
xllm_request_params->temperature = final_request_params.temperature;
xllm_request_params->request_id = final_request_params.request_id;
Expand Down
3 changes: 3 additions & 0 deletions xllm/c_api/types.h
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,9 @@ typedef struct XLLM_CAPI_EXPORT XLLM_RequestParams {
/** Beam search width (0 = disable beam search) */
uint32_t beam_width;

/** Final number of beam search results to return (0 = use beam_width) */
uint32_t num_return_sequences;

/** Number of top log probabilities to return */
int64_t top_logprobs;

Expand Down
2 changes: 2 additions & 0 deletions xllm/cc_api/internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@ RequestParams transfer_request_params(
xllm_request_params.stop = request_params.stop;
xllm_request_params.stop_token_ids = request_params.stop_token_ids;
xllm_request_params.beam_width = request_params.beam_width;
xllm_request_params.num_return_sequences =
request_params.num_return_sequences;
xllm_request_params.top_logprobs = request_params.top_logprobs;

return xllm_request_params;
Expand Down
2 changes: 2 additions & 0 deletions xllm/cc_api/types.h
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,8 @@ struct XLLM_CAPI_EXPORT XLLM_RequestParams {

int32_t beam_width = 0;

int32_t num_return_sequences = 0;

// Number of top log probabilities to return. default = 0.
int64_t top_logprobs = 0;

Expand Down
5 changes: 4 additions & 1 deletion xllm/core/distributed_runtime/rec_engine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -675,7 +675,10 @@ ForwardOutput RecEngine::OneRecEnginePipeline::step(
}

timer.reset();
batches[0].process_sample_output(decode_output.sample_output, false);
batches[0].process_sample_output(
decode_output.sample_output,
false,
/*force_requested_beam_result_size=*/i + 1 == kRecDecodeSteps);
COUNTER_ADD(rec_sampling_latency_microseconds,
timer.elapsed_microseconds());
}
Expand Down
1 change: 1 addition & 0 deletions xllm/core/distributed_runtime/rec_master.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -754,6 +754,7 @@ std::shared_ptr<Request> RecMaster::build_request_common(
sampling_param.top_logprobs = sp.top_logprobs;
sampling_param.is_embeddings = sp.is_embeddings;
sampling_param.beam_width = sp.beam_width;
sampling_param.num_return_sequences = sp.num_return_sequences;
if (best_of > sp.n) {
sampling_param.logprobs = true;
}
Expand Down
9 changes: 5 additions & 4 deletions xllm/core/framework/batch/batch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -619,7 +619,8 @@ void Batch::process_beam_sequence_group(const ForwardOutput& output) {
}

void Batch::process_sample_output(const SampleOutput& sample_output,
bool replace_fake_token) {
bool replace_fake_token,
bool force_requested_beam_result_size) {
if (sample_output.embeddings.defined()) {
const int64_t num_seqs = sample_output.embeddings.size(0);
int64_t output_idx = 0;
Expand Down Expand Up @@ -673,7 +674,7 @@ void Batch::process_sample_output(const SampleOutput& sample_output,
}

if (!FLAGS_enable_schedule_overlap || replace_fake_token) {
process_beam_search();
process_beam_search(force_requested_beam_result_size);
}
}

Expand Down Expand Up @@ -721,9 +722,9 @@ void Batch::append_token_for_sequence(Sequence* seq,
}
}

void Batch::process_beam_search() {
void Batch::process_beam_search(bool force_requested_result_size) {
Comment thread
DragonFive marked this conversation as resolved.
for (auto* sequence_group : sequence_groups_) {
sequence_group->process_beam_search();
sequence_group->process_beam_search(force_requested_result_size);
}
}

Expand Down
5 changes: 3 additions & 2 deletions xllm/core/framework/batch/batch.h
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,8 @@ class Batch {
// The boolean parameter `replace_fake_token` indicates
// whether the current stage is the second stage.
void process_sample_output(const SampleOutput& sample_output,
bool replace_fake_token);
bool replace_fake_token,
bool force_requested_beam_result_size = false);

void process_sample_output(const RawForwardOutput& raw_output,
bool replace_fake_token);
Expand Down Expand Up @@ -156,7 +157,7 @@ class Batch {
int token_idx,
bool replace_fake_token);

void process_beam_search();
void process_beam_search(bool force_requested_result_size = false);
bool has_partial_finished_beam_group() const;

std::unordered_map<uint32_t, uint32_t> cal_seq_exchange_index(
Expand Down
24 changes: 24 additions & 0 deletions xllm/core/framework/request/request_params.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,9 @@ RequestParams::RequestParams(const proto::CompletionRequest& request,
if (request.has_beam_width()) {
beam_width = request.beam_width();
}
if (request.has_num_return_sequences()) {
num_return_sequences = request.num_return_sequences();
}
if (request.has_add_special_tokens()) {
add_special_tokens = request.add_special_tokens();
} else {
Expand Down Expand Up @@ -414,6 +417,9 @@ void init_from_chat_request(RequestParams& params, const ChatRequest& request) {
if (request.has_beam_width()) {
params.beam_width = request.beam_width();
}
if (request.has_num_return_sequences()) {
params.num_return_sequences = request.num_return_sequences();
}

if (request.has_add_special_tokens()) {
params.add_special_tokens = request.add_special_tokens();
Expand Down Expand Up @@ -576,6 +582,24 @@ bool RequestParams::verify_params(OutputCallback callback) const {
return false;
}

if (num_return_sequences > 0) {
if (beam_width <= 0) {
CALLBACK_WITH_ERROR(StatusCode::INVALID_ARGUMENT,
"num_return_sequences requires beam_width > 0",
service_request_id,
source_xservice_addr);
return false;
}
if (num_return_sequences < beam_width) {
CALLBACK_WITH_ERROR(
StatusCode::INVALID_ARGUMENT,
"num_return_sequences must be greater than or equal to beam_width",
service_request_id,
source_xservice_addr);
return false;
}
}
Comment thread
DragonFive marked this conversation as resolved.

if (logprobs) {
if (echo) {
CALLBACK_WITH_ERROR(StatusCode::INVALID_ARGUMENT,
Expand Down
1 change: 1 addition & 0 deletions xllm/core/framework/request/request_params.h
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,7 @@ struct RequestParams {

// beam search
int32_t beam_width = 0;
int32_t num_return_sequences = 0;

bool add_special_tokens = false;

Expand Down
44 changes: 34 additions & 10 deletions xllm/core/framework/request/sequences_group.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ void SequencesGroup::generate_outputs_parallel(
}
}

void SequencesGroup::process_beam_search() {
void SequencesGroup::process_beam_search(bool force_requested_result_size) {
if (!check_beam_search()) {
return;
}
Expand All @@ -229,6 +229,12 @@ void SequencesGroup::process_beam_search() {
}

const size_t beam_width = sequence_params_.sampling_param->beam_width;
const int32_t requested_num_return_sequences =
sequence_params_.sampling_param->num_return_sequences > 0
? sequence_params_.sampling_param->num_return_sequences
: sequence_params_.sampling_param->beam_width;
const size_t requested_result_size =
static_cast<size_t>(requested_num_return_sequences);
const size_t topk =
std::max<size_t>(1, sequence_params_.sampling_param->top_logprobs);

Expand Down Expand Up @@ -261,7 +267,9 @@ void SequencesGroup::process_beam_search() {
build_source_info(sequences_[seq_index].get());
}

SimpleTopKOptimizerBeamCandidate topk_optimizer(beam_width);
const size_t target_result_size =
force_requested_result_size ? requested_result_size : beam_width;
SimpleTopKOptimizerBeamCandidate topk_optimizer(target_result_size);
auto add_self_candidate = [&](size_t seq_index, Sequence* seq) {
BeamCandidate candidate;
candidate.source_index = seq_index;
Expand Down Expand Up @@ -324,15 +332,22 @@ void SequencesGroup::process_beam_search() {
return;
}

const size_t result_size = std::min(beam_width, candidates.size());
const size_t result_size = std::min(target_result_size, candidates.size());
CHECK(!sequences_.empty());

std::vector<std::unique_ptr<Sequence>> replacement_sequences(result_size);
const size_t existing_size = sequences_.size();
const size_t existing_result_size = std::min(result_size, existing_size);
std::vector<std::unique_ptr<Sequence>> replacement_sequences(
existing_result_size);
std::vector<std::unique_ptr<Sequence>> tail_sequences;
if (result_size > existing_size) {
tail_sequences.reserve(result_size - existing_size);
}
for (size_t i = 0; i < result_size; ++i) {
const BeamCandidate& candidate = candidates[i];
const BeamSourceInfo& source_info = source_infos[candidate.source_index];
const bool need_replace =
i >= sequences_.size() || sequences_[i] == nullptr ||
i >= existing_size || sequences_[i] == nullptr ||
sequences_[i]->num_prompt_tokens() != source_info.suffix_start_idx ||
sequences_[i]->num_tokens() - source_info.suffix_start_idx !=
source_info.generated_token_ids.size();
Expand All @@ -342,19 +357,28 @@ void SequencesGroup::process_beam_search() {

CHECK_LT(candidate.source_index, sequences_.size());
CHECK(sequences_[candidate.source_index] != nullptr);
replacement_sequences[i] =
std::make_unique<Sequence>(*sequences_[candidate.source_index]);
if (i < existing_size) {
replacement_sequences[i] =
std::make_unique<Sequence>(*sequences_[candidate.source_index]);
} else {
tail_sequences.emplace_back(
std::make_unique<Sequence>(*sequences_[candidate.source_index]));
}
}

if (sequences_.size() < result_size) {
sequences_.resize(result_size);
if (existing_size < result_size) {
sequences_.reserve(result_size);
for (auto& tail_sequence : tail_sequences) {
sequences_.emplace_back(std::move(tail_sequence));
}
}

std::unordered_set<size_t> reused_src;
for (size_t i = 0; i < result_size; ++i) {
const BeamCandidate& candidate = candidates[i];
const BeamSourceInfo& source_info = source_infos[candidate.source_index];
if (replacement_sequences[i] != nullptr) {
if (i < replacement_sequences.size() &&
replacement_sequences[i] != nullptr) {
sequences_[i] = std::move(replacement_sequences[i]);
}
auto& next_seq = sequences_[i];
Expand Down
2 changes: 1 addition & 1 deletion xllm/core/framework/request/sequences_group.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ class SequencesGroup {
const Tokenizer& tokenizer,
ThreadPool* thread_pool = nullptr);

void process_beam_search();
void process_beam_search(bool force_requested_result_size = false);

bool check_beam_search() {
return sequence_params_.sampling_param->beam_width > 1;
Expand Down
1 change: 1 addition & 0 deletions xllm/core/framework/sampling/sampling_params.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ struct RequestSamplingParam {
bool do_sample = false;
bool is_embeddings = false;
int32_t beam_width = 0;
int32_t num_return_sequences = 0;
};

struct SamplingParameters {
Expand Down
5 changes: 4 additions & 1 deletion xllm/proto/chat.proto
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ message ChatMessage {
}


// Next Id: 30
// Next Id: 46
message ChatRequest {
// ID of the model to use. You can use the ListModels endpoint to list available models.
string model = 1;
Expand Down Expand Up @@ -129,6 +129,9 @@ message ChatRequest {

optional int32 beam_width = 33;

// Final number of beam results to return. Defaults to beam_width when unset.
optional int32 num_return_sequences = 45;

optional bool add_special_tokens = 34;

optional google.protobuf.Struct chat_template_kwargs = 35;
Expand Down
5 changes: 4 additions & 1 deletion xllm/proto/completion.proto
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ package xllm.proto;
import "common.proto";
import "rec.proto";

// Next ID: 26
// Next ID: 42
message CompletionRequest {
// ID of the model to use. (required)
// You can use the ListModels endpoint to list available models.
Expand Down Expand Up @@ -97,6 +97,9 @@ message CompletionRequest {

optional int32 beam_width = 29;

// Final number of beam results to return. Defaults to beam_width when unset.
optional int32 num_return_sequences = 41;

optional bool add_special_tokens = 30;
// tensor for rec embedding.
repeated InferInputTensor input_tensors = 31;
Expand Down
5 changes: 4 additions & 1 deletion xllm/proto/multimodal.proto
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ message MMChatMessage {
optional string tool_call_id = 4;
}

// Next Id: 27
// Next Id: 43
message MMChatRequest {

// ID of the model to use. You can use the ListModels endpoint to list available models.
Expand Down Expand Up @@ -145,6 +145,9 @@ message MMChatRequest {

optional int32 beam_width = 32;

// Final number of beam results to return. Defaults to beam_width when unset.
optional int32 num_return_sequences = 42;

optional bool add_special_tokens = 33;

optional google.protobuf.Struct chat_template_kwargs = 34;
Expand Down
2 changes: 2 additions & 0 deletions xllm/pybind/bind.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,8 @@ PYBIND11_MODULE(xllm_export, m) {
.def_readwrite("stop", &RequestParams::stop)
.def_readwrite("stop_token_ids", &RequestParams::stop_token_ids)
.def_readwrite("beam_width", &RequestParams::beam_width)
.def_readwrite("num_return_sequences",
Comment thread
DragonFive marked this conversation as resolved.
&RequestParams::num_return_sequences)
.def_readwrite("add_special_tokens", &RequestParams::add_special_tokens)
.def_readwrite("is_sample_request", &RequestParams::is_sample_request)
.def_readwrite("sample_slots", &RequestParams::sample_slots);
Expand Down
Loading