Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions xllm/c_api/default.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ const XLLM_RequestParams XLLM_LLM_REQUEST_PARAMS_DEFAULT = {
.ttft_slo_ms = INT32_MAX,
.tpot_slo_ms = INT32_MAX,
.beam_width = 0,
.num_return_sequences = 0,
.top_logprobs = 0,
.top_k = -1,
.top_p = 1.0,
Expand Down Expand Up @@ -134,6 +135,7 @@ const XLLM_RequestParams XLLM_REC_REQUEST_PARAMS_DEFAULT = {
.ttft_slo_ms = INT32_MAX,
.tpot_slo_ms = INT32_MAX,
.beam_width = 128,
.num_return_sequences = 0,
.top_logprobs = 0,
.top_k = -1,
.top_p = 1.0,
Expand Down
2 changes: 2 additions & 0 deletions xllm/c_api/internal/helper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,8 @@ void transfer_request_params(InferenceType inference_type,
xllm_request_params->repetition_penalty =
final_request_params.repetition_penalty;
xllm_request_params->beam_width = final_request_params.beam_width;
xllm_request_params->num_return_sequences =
final_request_params.num_return_sequences;
xllm_request_params->top_logprobs = final_request_params.top_logprobs;
xllm_request_params->temperature = final_request_params.temperature;
xllm_request_params->request_id = final_request_params.request_id;
Expand Down
4 changes: 2 additions & 2 deletions xllm/c_api/internal/rec.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,15 +54,15 @@ void reset_pipeline_runtime_toggles() {
FLAGS_enable_graph_mode_decode_no_padding = false;
FLAGS_enable_rec_prefill_only = false;
FLAGS_enable_constrained_decoding = false;
FLAGS_enable_topk_sorted = false;
FLAGS_enable_topk_sorted = true;
}

void apply_multi_round_pipeline_toggles() {
FLAGS_enable_rec_fast_sampler = true;
FLAGS_enable_prefill_piecewise_graph = true;
FLAGS_enable_xattention_one_stage = false;
FLAGS_enable_graph_mode_decode_no_padding = true;
FLAGS_enable_topk_sorted = false;
FLAGS_enable_topk_sorted = true;
}

void apply_onerec_pipeline_toggles(xllm::Options* options) {
Expand Down
3 changes: 3 additions & 0 deletions xllm/c_api/types.h
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,9 @@ typedef struct XLLM_CAPI_EXPORT XLLM_RequestParams {
/** Beam search width (0 = disable beam search) */
uint32_t beam_width;

/** Final number of beam search results to return (0 = use beam_width) */
uint32_t num_return_sequences;

/** Number of top log probabilities to return */
int64_t top_logprobs;

Expand Down
2 changes: 2 additions & 0 deletions xllm/cc_api/internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@ RequestParams transfer_request_params(
xllm_request_params.stop = request_params.stop;
xllm_request_params.stop_token_ids = request_params.stop_token_ids;
xllm_request_params.beam_width = request_params.beam_width;
xllm_request_params.num_return_sequences =
request_params.num_return_sequences;
xllm_request_params.top_logprobs = request_params.top_logprobs;

return xllm_request_params;
Expand Down
2 changes: 2 additions & 0 deletions xllm/cc_api/types.h
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,8 @@ struct XLLM_CAPI_EXPORT XLLM_RequestParams {

int32_t beam_width = 0;

int32_t num_return_sequences = 0;

// Number of top log probabilities to return. default = 0.
int64_t top_logprobs = 0;

Expand Down
1 change: 1 addition & 0 deletions xllm/core/distributed_runtime/llm_master.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,7 @@ std::shared_ptr<Request> LLMMaster::generate_request(
sampling_param.top_logprobs = sp.top_logprobs;
sampling_param.is_embeddings = sp.is_embeddings;
sampling_param.beam_width = sp.beam_width;
sampling_param.num_return_sequences = sp.num_return_sequences;
if (best_of > sp.n) {
// enable logprobs for best_of to generate sequence logprob
sampling_param.logprobs = true;
Expand Down
29 changes: 27 additions & 2 deletions xllm/core/distributed_runtime/rec_engine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -675,7 +675,10 @@ ForwardOutput RecEngine::OneRecEnginePipeline::step(
}

timer.reset();
batches[0].process_sample_output(decode_output.sample_output, false);
batches[0].process_sample_output(
decode_output.sample_output,
false,
/*force_requested_beam_result_size=*/i + 1 == kRecDecodeSteps);
COUNTER_ADD(rec_sampling_latency_microseconds,
timer.elapsed_microseconds());
}
Expand Down Expand Up @@ -953,8 +956,30 @@ ForwardOutput RecEngine::RecMultiRoundEnginePipeline::get_model_output(

// D2H transfer for beam_sequence_group (multi-round results)
auto& output = forward_output.value();
// TODO. uncomment this in next pr.
output.beam_sequence_group = safe_to(output.beam_sequence_group, torch::kCPU);
if (output.beam_base_logprobs.defined()) {
Comment thread
DragonFive marked this conversation as resolved.
Outdated
output.beam_base_logprobs = safe_to(output.beam_base_logprobs, torch::kCPU);
}
if (output.beam_source_sequence_group.defined()) {
output.beam_source_sequence_group =
safe_to(output.beam_source_sequence_group, torch::kCPU);
}
auto& sample_output = output.sample_output;
if (sample_output.top_tokens.defined()) {
sample_output.top_tokens = safe_to(sample_output.top_tokens, torch::kCPU);
}
if (sample_output.top_logprobs.defined()) {
sample_output.top_logprobs =
safe_to(sample_output.top_logprobs, torch::kCPU);
}
if (output.beam_search_output.src_seq_idxes.defined()) {
output.beam_search_output.src_seq_idxes =
safe_to(output.beam_search_output.src_seq_idxes, torch::kCPU);
}
if (output.beam_search_output.out_tokens.defined()) {
output.beam_search_output.out_tokens =
safe_to(output.beam_search_output.out_tokens, torch::kCPU);
}
if (output.beam_search_output.out_logprobs.defined()) {
output.beam_search_output.out_logprobs =
safe_to(output.beam_search_output.out_logprobs, torch::kCPU);
Expand Down
1 change: 1 addition & 0 deletions xllm/core/distributed_runtime/rec_master.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -754,6 +754,7 @@ std::shared_ptr<Request> RecMaster::build_request_common(
sampling_param.top_logprobs = sp.top_logprobs;
sampling_param.is_embeddings = sp.is_embeddings;
sampling_param.beam_width = sp.beam_width;
sampling_param.num_return_sequences = sp.num_return_sequences;
if (best_of > sp.n) {
sampling_param.logprobs = true;
}
Expand Down
140 changes: 136 additions & 4 deletions xllm/core/framework/batch/batch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ limitations under the License.
#include "common/global_flags.h"
#include "common/metrics.h"
#include "core/util/rec_model_utils.h"
#include "framework/batch/beam_search.h"
Comment thread
DragonFive marked this conversation as resolved.
Outdated
#include "framework/batch/mposition.h"
#include "framework/model/model_args.h"
#include "framework/model/model_input_params.h"
Expand Down Expand Up @@ -64,6 +65,127 @@ Token make_empty_logprob_placeholder(const Sequence& seq) {
return Token(placeholder_token_id);
}

// Number of beam results the request ultimately wants for a multi-round
// decode: the larger of beam_width and the resolved num_return_sequences.
// A null sampling param yields 0 (no expansion requested).
size_t resolve_multi_round_result_size(
    const RequestSamplingParam* sampling_param) {
  if (sampling_param == nullptr) {
    return 0;
  }
  const auto requested =
      std::max(sampling_param->beam_width,
               sampling_param->resolved_num_return_sequences());
  return static_cast<size_t>(requested);
}

// Expands one sequence group's final-round beam results past `beam_width`
// when the request asks for more returned sequences than the beam kept
// alive (requested_result_size > beam_width).
//
// For each of the group's `beam_width` surviving beams it combines the
// beam's accumulated base logprob with every per-step top-k candidate
// (token, logprob), keeps the best `requested_result_size` combinations,
// and materializes each as a token row: the beam's prefix tokens from
// `beam_source_sequence_group`, with the last round's token overridden by
// the candidate token.
//
// Outputs: `group_flat2d` receives one token row per result,
// `last_logprobs` the matching cumulative logprob, both overwritten.
// Returns true when expansion was performed; false (outputs untouched)
// when no expansion is needed or a required tensor is undefined/empty —
// the caller then falls back to the plain beam_width results.
bool expand_multi_round_beam_results(
    const ForwardOutput& output,
    size_t group_index,
    int32_t beam_width,
    int32_t total_rounds,
    size_t requested_result_size,
    std::vector<std::vector<int32_t>>* group_flat2d,
    std::vector<float>* last_logprobs) {
  CHECK(group_flat2d != nullptr);
  CHECK(last_logprobs != nullptr);
  // No expansion needed: the request fits within the live beam.
  if (requested_result_size <= static_cast<size_t>(beam_width)) {
    return false;
  }
  // All four tensors are required to rebuild candidates; otherwise let the
  // caller use the unexpanded results.
  if (!output.beam_base_logprobs.defined() ||
      !output.beam_source_sequence_group.defined() ||
      !output.sample_output.top_tokens.defined() ||
      !output.sample_output.top_logprobs.defined()) {
    return false;
  }
  // Shape contract (assumed layouts — confirm against the producer):
  //   beam_base_logprobs:         [num_groups * beam_width]
  //   beam_source_sequence_group: [num_groups, beam_width, total_rounds]
  //   top_tokens / top_logprobs:  [num_groups * beam_width, top_k]
  CHECK_EQ(output.beam_base_logprobs.dim(), 1)
      << "beam_base_logprobs must be 1-D, got "
      << output.beam_base_logprobs.sizes();
  CHECK_EQ(output.beam_source_sequence_group.dim(), 3)
      << "beam_source_sequence_group must be 3-D, got "
      << output.beam_source_sequence_group.sizes();
  CHECK_EQ(output.sample_output.top_tokens.dim(), 2)
      << "sample_output.top_tokens must be 2-D, got "
      << output.sample_output.top_tokens.sizes();
  CHECK_EQ(output.sample_output.top_logprobs.dim(), 2)
      << "sample_output.top_logprobs must be 2-D, got "
      << output.sample_output.top_logprobs.sizes();

  // Rows belonging to this group occupy [row_offset, row_offset+beam_width).
  const int64_t row_offset =
      static_cast<int64_t>(group_index) * static_cast<int64_t>(beam_width);
  CHECK_LE(row_offset + beam_width, output.beam_base_logprobs.size(0))
      << "beam_base_logprobs size is too small, row_offset=" << row_offset
      << ", beam_width=" << beam_width
      << ", size=" << output.beam_base_logprobs.size(0);
  CHECK_LE(row_offset + beam_width, output.sample_output.top_tokens.size(0))
      << "sample_output.top_tokens size is too small, row_offset=" << row_offset
      << ", beam_width=" << beam_width
      << ", size=" << output.sample_output.top_tokens.size(0);
  CHECK_LE(row_offset + beam_width, output.sample_output.top_logprobs.size(0))
      << "sample_output.top_logprobs size is too small, row_offset="
      << row_offset << ", beam_width=" << beam_width
      << ", size=" << output.sample_output.top_logprobs.size(0);

  // CPU accessors — tensors must already reside on CPU (the pipeline moves
  // them via safe_to(..., torch::kCPU) before this runs).
  const auto base_logprob_accessor =
      output.beam_base_logprobs.accessor<float, 1>();
  const auto prefix_group_accessor =
      output.beam_source_sequence_group.accessor<int32_t, 3>();
  const auto top_tokens_accessor =
      output.sample_output.top_tokens.accessor<int64_t, 2>();
  const auto top_logprobs_accessor =
      output.sample_output.top_logprobs.accessor<float, 2>();
  // Guard against mismatched top-k widths between tokens and logprobs.
  const int64_t top_count = std::min(output.sample_output.top_tokens.size(1),
                                     output.sample_output.top_logprobs.size(1));
  if (top_count <= 0) {
    return false;
  }

  // NOTE(review): name appears as-is in the diff; if this was meant to be a
  // template instantiation (SimpleTopKOptimizer<BeamCandidate>) the angle
  // brackets may have been lost in extraction — confirm the declared type.
  SimpleTopKOptimizerBeamCandidate topk_optimizer(requested_result_size);
  for (int32_t source_index = 0; source_index < beam_width; ++source_index) {
    const int64_t row_index = row_offset + source_index;
    const float base_logprob = base_logprob_accessor[row_index];
    for (int64_t top_index = 0; top_index < top_count; ++top_index) {
      const float candidate_logprob =
          base_logprob + top_logprobs_accessor[row_index][top_index];
      // Early break assumes top_logprobs are sorted descending within a row
      // (FLAGS_enable_topk_sorted path) — TODO confirm; otherwise candidates
      // after the first rejection would be skipped incorrectly.
      if (!topk_optimizer.worthInserting(candidate_logprob)) {
        break;
      }

      BeamCandidate candidate;
      candidate.source_index = static_cast<size_t>(source_index);
      candidate.logprob_sum = candidate_logprob;
      // The candidate replaces only the last round's token of its source beam.
      candidate.override_last_token = true;
      candidate.last_token_id =
          static_cast<int32_t>(top_tokens_accessor[row_index][top_index]);
      candidate.last_token_logprob =
          top_logprobs_accessor[row_index][top_index];
      topk_optimizer.insert(std::move(candidate));
    }
  }

  std::vector<BeamCandidate> candidates = topk_optimizer.getTopKSorted();
  if (candidates.empty()) {
    return false;
  }

  // Emit up to requested_result_size rows, best-first.
  const size_t result_size = std::min(requested_result_size, candidates.size());
  group_flat2d->clear();
  last_logprobs->clear();
  group_flat2d->reserve(result_size);
  last_logprobs->reserve(result_size);
  for (size_t i = 0; i < result_size; ++i) {
    const BeamCandidate& candidate = candidates[i];
    std::vector<int32_t> row_tokens;
    row_tokens.reserve(static_cast<size_t>(total_rounds));
    // Copy the source beam's per-round token prefix.
    for (int32_t round = 0; round < total_rounds; ++round) {
      row_tokens.push_back(
          prefix_group_accessor[group_index][candidate.source_index][round]);
    }
    // Swap in the candidate's final-round token.
    if (!row_tokens.empty() && candidate.override_last_token) {
      row_tokens.back() = candidate.last_token_id;
    }
    group_flat2d->emplace_back(std::move(row_tokens));
    last_logprobs->push_back(candidate.logprob_sum);
  }
  return true;
}

} // namespace

Batch::Batch(Sequence* sequence) { add(sequence); }
Expand Down Expand Up @@ -571,6 +693,8 @@ void Batch::process_beam_sequence_group(const ForwardOutput& output) {
if (beam_width <= 1) {
return;
}
const size_t requested_result_size =
resolve_multi_round_result_size(sequences[0]->sampling_param());
int32_t total_rounds = get_rec_multi_round_decode_rounds();
size_t num_groups = sequence_groups_.size();
if (num_groups == 0) {
Expand Down Expand Up @@ -610,6 +734,13 @@ void Batch::process_beam_sequence_group(const ForwardOutput& output) {
output.beam_search_output.out_logprobs[logprob_idx].item<float>());
}
}
expand_multi_round_beam_results(output,
Comment thread
DragonFive marked this conversation as resolved.
Outdated
g,
beam_width,
total_rounds,
requested_result_size,
&group_flat2d,
&last_logprobs);
// Access sequence from sequence_groups_ if available
Sequence* seq = sequence_groups_.empty()
? sequences[g]
Expand All @@ -619,7 +750,8 @@ void Batch::process_beam_sequence_group(const ForwardOutput& output) {
}

void Batch::process_sample_output(const SampleOutput& sample_output,
bool replace_fake_token) {
bool replace_fake_token,
bool force_requested_beam_result_size) {
if (sample_output.embeddings.defined()) {
const int64_t num_seqs = sample_output.embeddings.size(0);
int64_t output_idx = 0;
Expand Down Expand Up @@ -673,7 +805,7 @@ void Batch::process_sample_output(const SampleOutput& sample_output,
}

if (!FLAGS_enable_schedule_overlap || replace_fake_token) {
process_beam_search();
process_beam_search(force_requested_beam_result_size);
}
}

Expand Down Expand Up @@ -721,9 +853,9 @@ void Batch::append_token_for_sequence(Sequence* seq,
}
}

void Batch::process_beam_search() {
void Batch::process_beam_search(bool force_requested_result_size) {
Comment thread
DragonFive marked this conversation as resolved.
for (auto* sequence_group : sequence_groups_) {
sequence_group->process_beam_search();
sequence_group->process_beam_search(force_requested_result_size);
}
}

Expand Down
5 changes: 3 additions & 2 deletions xllm/core/framework/batch/batch.h
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,8 @@ class Batch {
// The boolean parameter `replace_fake_token` indicates
// whether the current stage is the second stage.
void process_sample_output(const SampleOutput& sample_output,
bool replace_fake_token);
bool replace_fake_token,
bool force_requested_beam_result_size = false);

void process_sample_output(const RawForwardOutput& raw_output,
bool replace_fake_token);
Expand Down Expand Up @@ -156,7 +157,7 @@ class Batch {
int token_idx,
bool replace_fake_token);

void process_beam_search();
void process_beam_search(bool force_requested_result_size = false);
bool has_partial_finished_beam_group() const;

std::unordered_map<uint32_t, uint32_t> cal_seq_exchange_index(
Expand Down
Loading
Loading