Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 72 additions & 21 deletions xllm/api_service/rec_completion_service_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,12 @@ limitations under the License.

namespace xllm {
namespace {
// One item selected for emission in the REC response. Pairs the decoded
// item id with the index of the SequenceOutput it came from (used to fetch
// matching logprobs) and, optionally, its extended info for the
// item_did/item_type tensors.
struct RecEmitRecord {
  // Index into req_output.outputs identifying the originating sequence.
  int32_t output_index = 0;
  // Decoded recommendation item id emitted into the INT64 output tensor.
  int64_t item_id = 0;
  // Extended info (did/type); empty when extended info was unavailable for
  // this item, in which case empty strings are emitted instead.
  std::optional<RecItemInfo> item_info;
};

void append_rec_logprobs(proto::InferTensorContents* logprobs_context,
const SequenceOutput& output,
int32_t expected_count) {
Expand Down Expand Up @@ -120,13 +126,65 @@ bool send_result_to_client_brpc_rec(std::shared_ptr<CompletionCall> call,

if (FLAGS_enable_convert_tokens_to_item) {
output_tensor->set_datatype(proto::DataType::INT64);
const int32_t output_count =
static_cast<int32_t>(req_output.outputs.size());
output_tensor->mutable_shape()->Add(output_count);
proto::InferOutputTensor* did_tensor = nullptr;
proto::InferOutputTensor* type_tensor = nullptr;
if (FLAGS_enable_extended_item_info) {
did_tensor = response.mutable_output_tensors()->Add();
did_tensor->set_name("item_did");
did_tensor->set_datatype(proto::DataType::STRING);

type_tensor = response.mutable_output_tensors()->Add();
type_tensor->set_name("item_type");
type_tensor->set_datatype(proto::DataType::STRING);
}

std::vector<RecEmitRecord> emitted_items;
emitted_items.reserve(req_output.outputs.size());
const int32_t total_threshold = FLAGS_total_conversion_threshold;
for (int32_t i = 0; i < static_cast<int32_t>(req_output.outputs.size());
++i) {
const auto& output = req_output.outputs[i];
if (!output.item_ids_list.empty()) {
const bool has_item_infos =
output.item_infos_list.size() == output.item_ids_list.size();
for (size_t item_idx = 0; item_idx < output.item_ids_list.size();
++item_idx) {
if (static_cast<int32_t>(emitted_items.size()) >= total_threshold) {
break;
}
std::optional<RecItemInfo> item_info;
if (has_item_infos) {
item_info = output.item_infos_list[item_idx];
}
RecEmitRecord emitted_item;
emitted_item.output_index = i;
emitted_item.item_id = output.item_ids_list[item_idx];
emitted_item.item_info = std::move(item_info);
emitted_items.emplace_back(std::move(emitted_item));
}
} else if (output.item_ids.has_value() &&
static_cast<int32_t>(emitted_items.size()) < total_threshold) {
RecEmitRecord emitted_item;
emitted_item.output_index = i;
emitted_item.item_id = output.item_ids.value();
emitted_item.item_info = output.item_info;
emitted_items.emplace_back(std::move(emitted_item));
}
if (static_cast<int32_t>(emitted_items.size()) >= total_threshold) {
break;
}
}

const int32_t emitted_count = static_cast<int32_t>(emitted_items.size());
output_tensor->mutable_shape()->Add(emitted_count);
if (logprobs_tensor != nullptr) {
logprobs_tensor->mutable_shape()->Add(output_count);
logprobs_tensor->mutable_shape()->Add(emitted_count);
logprobs_tensor->mutable_shape()->Add(logprob_width);
}
if (did_tensor != nullptr && type_tensor != nullptr) {
did_tensor->mutable_shape()->Add(emitted_count);
type_tensor->mutable_shape()->Add(emitted_count);
}

auto* output_context = output_tensor->mutable_contents();
auto* logprobs_context = logprobs_tensor == nullptr
Expand All @@ -138,23 +196,16 @@ bool send_result_to_client_brpc_rec(std::shared_ptr<CompletionCall> call,
logprobs_context, req_output.outputs[output_index], logprob_width);
}
};
int32_t total_count = 0;
const int32_t total_threshold = FLAGS_total_conversion_threshold;
for (int32_t i = 0; i < output_count; ++i) {
const auto& output = req_output.outputs[i];
if (!output.item_ids_list.empty()) {
for (const int64_t item_id : output.item_ids_list) {
if (total_count >= total_threshold) {
break;
}
output_context->mutable_int64_contents()->Add(item_id);
append_output_logprobs(i);
++total_count;
}
} else if (output.item_ids.has_value() && total_count < total_threshold) {
output_context->mutable_int64_contents()->Add(output.item_ids.value());
append_output_logprobs(i);
++total_count;
for (const RecEmitRecord& emitted_item : emitted_items) {
output_context->mutable_int64_contents()->Add(emitted_item.item_id);
append_output_logprobs(emitted_item.output_index);
if (did_tensor != nullptr && type_tensor != nullptr) {
did_tensor->mutable_contents()->add_bytes_contents(
emitted_item.item_info.has_value() ? emitted_item.item_info->did
: "");
type_tensor->mutable_contents()->add_bytes_contents(
emitted_item.item_info.has_value() ? emitted_item.item_info->type
: "");
}
Comment thread
DragonFive marked this conversation as resolved.
}
} else {
Expand Down
4 changes: 4 additions & 0 deletions xllm/core/common/global_flags.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -631,6 +631,10 @@ DEFINE_bool(enable_convert_tokens_to_item,
false,
"Enable token ids conversion to item id in REC/OneRec response.");

// Gates decoding of extended item info (did/type) and the emission of the
// extra "item_did"/"item_type" STRING output tensors in REC responses.
// Only takes effect together with --enable_convert_tokens_to_item.
DEFINE_bool(enable_extended_item_info,
            false,
            "Enable REC extended item info parsing and output tensors.");
Comment thread
DragonFive marked this conversation as resolved.

DEFINE_int64(dit_cache_start_steps,
5,
"The number of steps to skip at the start");
Expand Down
1 change: 1 addition & 0 deletions xllm/core/common/global_flags.h
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,7 @@ DECLARE_double(dit_cache_residual_diff_threshold);
DECLARE_bool(enable_constrained_decoding);
DECLARE_bool(enable_convert_tokens_to_item);
DECLARE_bool(enable_output_sku_logprobs);
DECLARE_bool(enable_extended_item_info);
DECLARE_int32(each_conversion_threshold);
DECLARE_int32(total_conversion_threshold);

Expand Down
1 change: 1 addition & 0 deletions xllm/core/common/help_formatter.h
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ const OptionCategory kBeamSearchOptions = {
const OptionCategory kRecOptions = {"REC OPTIONS",
{"enable_rec_fast_sampler",
"enable_convert_tokens_to_item",
"enable_extended_item_info",
"enable_output_sku_logprobs",
"each_conversion_threshold",
"total_conversion_threshold",
Expand Down
9 changes: 9 additions & 0 deletions xllm/core/common/types.h
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,15 @@ struct RawToken {
std::vector<float> embeddings; // hidden states
};

// Extended metadata for a single decoded recommendation item. Produced by
// the REC tokenizer when extended item info is enabled and surfaced to
// clients via the "item_did" / "item_type" output tensors.
struct RecItemInfo {
  // The decoded recommendation item id.
  int64_t item_id = 0;
  // The DID associated with the recommendation item.
  std::string did;
  // The business type associated with the recommendation item.
  std::string type;
};

Comment thread
DragonFive marked this conversation as resolved.
// Weight segment info for D2D transfer (supports non-contiguous allocation)
// Forward declaration needed by InstanceInfo
struct WeightSegment {
Expand Down
6 changes: 6 additions & 0 deletions xllm/core/framework/request/request_output.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,12 @@ struct SequenceOutput {
// decoded item ids for multi-item recommendation output.
std::vector<int64_t> item_ids_list;

// the decoded extended item info for constrained recommendation output.
std::optional<RecItemInfo> item_info;

// decoded extended item infos for multi-item recommendation output.
std::vector<RecItemInfo> item_infos_list;

// the reason the sequence finished.
std::optional<std::string> finish_reason;

Expand Down
66 changes: 57 additions & 9 deletions xllm/core/framework/request/sequence.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ limitations under the License.
#include "core/common/global_flags.h"
#include "core/common/metrics.h"
#include "core/framework/request/mm_data_visitor.h"
#include "core/framework/tokenizer/rec_tokenizer.h"
#include "core/framework/tokenizer/tokenizer.h"
#include "core/util/slice.h"
#include "core/util/tensor_helper.h"
Expand Down Expand Up @@ -70,6 +71,33 @@ std::vector<int64_t> normalize_rec_item_ids(const std::vector<int64_t>& raw_ids,

return item_ids;
}

std::vector<RecItemInfo> normalize_rec_item_infos(
const std::vector<RecItemInfo>& raw_item_infos,
size_t sequence_index) {
std::vector<RecItemInfo> item_infos;
item_infos.reserve(raw_item_infos.size());
std::unordered_set<int64_t> seen_item_ids;
for (const RecItemInfo& item_info : raw_item_infos) {
if (seen_item_ids.insert(item_info.item_id).second) {
item_infos.emplace_back(item_info);
}
}

const int32_t each_threshold = FLAGS_each_conversion_threshold;
if (each_threshold > 0 &&
static_cast<int32_t>(item_infos.size()) > each_threshold) {
uint32_t seed = FLAGS_random_seed >= 0
? static_cast<uint32_t>(FLAGS_random_seed) +
static_cast<uint32_t>(sequence_index)
: std::random_device{}();
std::mt19937 generator(seed);
std::shuffle(item_infos.begin(), item_infos.end(), generator);
item_infos.resize(each_threshold);
}

return item_infos;
}
} // namespace

const std::string Sequence::ENCODER_SPARSE_EMBEDDING_NAME = "sparse_embedding";
Expand Down Expand Up @@ -150,15 +178,35 @@ void Sequence::generate_onerec_output(const Slice<int32_t>& ids,
const size_t rec_token_size = static_cast<size_t>(REC_TOKEN_SIZE);
if (FLAGS_enable_convert_tokens_to_item &&
output.token_ids.size() == rec_token_size) {
std::vector<int64_t> item_ids;
const bool ok = tokenizer.decode(
Slice<int32_t>{output.token_ids.data(), output.token_ids.size()},
sequence_params_.skip_special_tokens,
&item_ids);
if (ok && !item_ids.empty()) {
output.item_ids_list = normalize_rec_item_ids(item_ids, index_);
if (!output.item_ids_list.empty()) {
output.item_ids = output.item_ids_list.front();
const Slice<int32_t> token_slice{output.token_ids.data(),
output.token_ids.size()};
if (FLAGS_enable_extended_item_info) {
const auto* rec_tokenizer = dynamic_cast<const RecTokenizer*>(&tokenizer);
if (rec_tokenizer != nullptr) {
std::vector<RecItemInfo> item_infos;
const bool ok =
rec_tokenizer->decode_item_infos(token_slice, &item_infos);
if (ok && !item_infos.empty()) {
output.item_infos_list = normalize_rec_item_infos(item_infos, index_);
output.item_ids_list.reserve(output.item_infos_list.size());
for (const RecItemInfo& item_info : output.item_infos_list) {
output.item_ids_list.emplace_back(item_info.item_id);
}
if (!output.item_infos_list.empty()) {
output.item_ids = output.item_ids_list.front();
output.item_info = output.item_infos_list.front();
}
}
}
Comment thread
DragonFive marked this conversation as resolved.
} else {
std::vector<int64_t> item_ids;
const bool ok = tokenizer.decode(
token_slice, sequence_params_.skip_special_tokens, &item_ids);
if (ok && !item_ids.empty()) {
output.item_ids_list = normalize_rec_item_ids(item_ids, index_);
if (!output.item_ids_list.empty()) {
output.item_ids = output.item_ids_list.front();
}
}
}
}
Expand Down
Loading
Loading