22 changes: 22 additions & 0 deletions xllm/api_service/rec_completion_service_impl.cpp
@@ -128,6 +128,20 @@ bool send_result_to_client_brpc_rec(std::shared_ptr<CompletionCall> call,
logprobs_tensor->mutable_shape()->Add(logprob_width);
}

decltype(output_tensor) did_tensor = nullptr;
decltype(output_tensor) type_tensor = nullptr;
if (FLAGS_enable_extended_item_info) {
did_tensor = response.mutable_output_tensors()->Add();
did_tensor->set_name("item_did");
did_tensor->set_datatype(proto::DataType::STRING);
did_tensor->mutable_shape()->Add(output_count);

type_tensor = response.mutable_output_tensors()->Add();
type_tensor->set_name("item_type");
type_tensor->set_datatype(proto::DataType::STRING);
type_tensor->mutable_shape()->Add(output_count);
}

auto* output_context = output_tensor->mutable_contents();
auto* logprobs_context = logprobs_tensor == nullptr
? nullptr
@@ -156,6 +170,14 @@ bool send_result_to_client_brpc_rec(std::shared_ptr<CompletionCall> call,
append_output_logprobs(i);
++total_count;
}

if (did_tensor != nullptr && type_tensor != nullptr) {
const auto& item_info = output.item_info;
did_tensor->mutable_contents()->add_bytes_contents(
item_info.has_value() ? item_info->did : "");
type_tensor->mutable_contents()->add_bytes_contents(
item_info.has_value() ? item_info->type : "");
}
}
} else {
output_tensor->set_datatype(proto::DataType::INT32);
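For context, a minimal client-side sketch of consuming the two new tensors. The helper and the response/tensor type names (proto::CompletionResponse, proto::Tensor) are placeholders; the sketch only assumes the generated protobuf accessors that mirror the setters used above (name(), contents(), bytes_contents()), and that an empty string marks an output without decoded item info.

// Hypothetical helper: find an output tensor by name in a REC response.
const proto::Tensor* find_tensor(const proto::CompletionResponse& response,
                                 const std::string& name) {
  for (const auto& tensor : response.output_tensors()) {
    if (tensor.name() == name) {
      return &tensor;
    }
  }
  return nullptr;
}

void log_item_info(const proto::CompletionResponse& response) {
  const auto* did_tensor = find_tensor(response, "item_did");
  const auto* type_tensor = find_tensor(response, "item_type");
  if (did_tensor == nullptr || type_tensor == nullptr) {
    return;  // server ran without --enable_extended_item_info
  }
  // Both tensors have shape [output_count]; entry i belongs to output i.
  for (int i = 0; i < did_tensor->contents().bytes_contents_size(); ++i) {
    LOG(INFO) << "output " << i
              << " did=" << did_tensor->contents().bytes_contents(i)
              << " type=" << type_tensor->contents().bytes_contents(i);
  }
}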
4 changes: 4 additions & 0 deletions xllm/core/common/global_flags.cpp
@@ -631,6 +631,10 @@ DEFINE_bool(enable_convert_tokens_to_item,
false,
"Enable token ids conversion to item id in REC/OneRec response.");

DEFINE_bool(enable_extended_item_info,
false,
"Enable REC extended item info parsing and output tensors.");

DEFINE_int64(dit_cache_start_steps,
5,
"The number of steps to skip at the start");
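The new flag follows the existing gflags convention, so it is switched on at launch. Note that generate_onerec_output (below) only consults FLAGS_enable_extended_item_info inside the FLAGS_enable_convert_tokens_to_item branch, so both flags must be enabled for the extended info to reach the response; e.g. (server binary name assumed):

./xllm_server --enable_convert_tokens_to_item=true --enable_extended_item_info=true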
1 change: 1 addition & 0 deletions xllm/core/common/global_flags.h
@@ -296,6 +296,7 @@ DECLARE_double(dit_cache_residual_diff_threshold);
DECLARE_bool(enable_constrained_decoding);
DECLARE_bool(enable_convert_tokens_to_item);
DECLARE_bool(enable_output_sku_logprobs);
DECLARE_bool(enable_extended_item_info);
DECLARE_int32(each_conversion_threshold);
DECLARE_int32(total_conversion_threshold);

6 changes: 6 additions & 0 deletions xllm/core/common/types.h
@@ -198,6 +198,12 @@ struct RawToken {
std::vector<float> embeddings; // hidden states
};

struct RecItemInfo {
int64_t item_id = 0;
std::string did;
std::string type;
};

// Weight segment info for D2D transfer (supports non-contiguous allocation)
// Forward declaration needed by InstanceInfo
struct WeightSegment {
3 changes: 3 additions & 0 deletions xllm/core/framework/request/request_output.h
@@ -77,6 +77,9 @@ struct SequenceOutput {
// decoded item ids for multi-item recommendation output.
std::vector<int64_t> item_ids_list;

// the decoded extended item info for constrained recommendation output.
std::optional<RecItemInfo> item_info;

// the reason the sequence finished.
std::optional<std::string> finish_reason;

32 changes: 23 additions & 9 deletions xllm/core/framework/request/sequence.cpp
@@ -33,6 +33,7 @@ limitations under the License.
#include "core/common/global_flags.h"
#include "core/common/metrics.h"
#include "core/framework/request/mm_data_visitor.h"
#include "core/framework/tokenizer/rec_tokenizer.h"
#include "core/framework/tokenizer/tokenizer.h"
#include "core/util/slice.h"
#include "core/util/tensor_helper.h"
@@ -150,15 +151,28 @@ void Sequence::generate_onerec_output(const Slice<int32_t>& ids,
const size_t rec_token_size = static_cast<size_t>(REC_TOKEN_SIZE);
if (FLAGS_enable_convert_tokens_to_item &&
output.token_ids.size() == rec_token_size) {
std::vector<int64_t> item_ids;
const bool ok = tokenizer.decode(
Slice<int32_t>{output.token_ids.data(), output.token_ids.size()},
sequence_params_.skip_special_tokens,
&item_ids);
if (ok && !item_ids.empty()) {
output.item_ids_list = normalize_rec_item_ids(item_ids, index_);
if (!output.item_ids_list.empty()) {
output.item_ids = output.item_ids_list.front();
const Slice<int32_t> token_slice{output.token_ids.data(),
output.token_ids.size()};
if (FLAGS_enable_extended_item_info) {
const auto* rec_tokenizer = dynamic_cast<const RecTokenizer*>(&tokenizer);
if (rec_tokenizer != nullptr) {
std::vector<RecItemInfo> item_infos;
const bool ok =
rec_tokenizer->decode_item_infos(token_slice, &item_infos);
if (ok && !item_infos.empty()) {
output.item_ids = item_infos.front().item_id;
output.item_info = item_infos.front();
}
}
} else {
std::vector<int64_t> item_ids;
const bool ok = tokenizer.decode(
token_slice, sequence_params_.skip_special_tokens, &item_ids);
if (ok && !item_ids.empty()) {
output.item_ids_list = normalize_rec_item_ids(item_ids, index_);
if (!output.item_ids_list.empty()) {
output.item_ids = output.item_ids_list.front();
}
}
}
}
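RecTokenizer::decode_item_infos itself is not part of this diff; the following is a plausible sketch of its contract, assuming it simply delegates to the new RecVocabDict::get_item_infos_by_tokens below (the vocab_dict_ member and the triple-building loop are assumptions, not the actual implementation):

bool RecTokenizer::decode_item_infos(
    const Slice<int32_t>& token_ids,
    std::vector<RecItemInfo>* item_infos) const {
  // A complete item is exactly one REC_TOKEN_SIZE-wide token triple.
  if (token_ids.size() != static_cast<size_t>(REC_TOKEN_SIZE)) {
    return false;
  }
  RecTokenTriple triple;
  for (size_t i = 0; i < triple.size(); ++i) {
    triple[i] = token_ids[i];
  }
  // One triple can map to several items; the caller above keeps the first.
  return vocab_dict_->get_item_infos_by_tokens(triple, item_infos);
}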
153 changes: 123 additions & 30 deletions xllm/core/framework/state_dict/rec_vocab_dict.cpp
@@ -4,6 +4,7 @@
#include <array>
#include <filesystem>
#include <fstream>
#include <string>

#include "common/global_flags.h"
#include "util/timer.h"
@@ -34,52 +35,126 @@ bool RecVocabDict::initialize(const std::string& vocab_file) {
const size_t file_size = ifs.tellg();
ifs.seekg(0, std::ios::beg);

// Each line of content : 1 * int64_t(item id) + REC_TOKEN_SIZE *
// int32_t(token id);
const size_t itemid_size = sizeof(int64_t);
const size_t tokens_size = REC_TOKEN_SIZE * sizeof(int32_t);
const size_t line_size = tokens_size + itemid_size;
const size_t estimated_lines = (file_size + line_size - 1) / line_size;
const size_t min_line_size = itemid_size + tokens_size;
const size_t estimated_lines =
(file_size + min_line_size - 1) / min_line_size;

// 2 and 4 are only empirical values
item_to_tokens_map_.reserve(estimated_lines);
tokens_to_items_map_.reserve(estimated_lines / 2);
tokens_to_item_infos_map_.reserve(estimated_lines / 2);
prefix_tokens_to_next_tokens_map_.reserve(estimated_lines / 4);

int64_t item_id = 0;
RecTokenTriple tokens;

while (ifs.read(reinterpret_cast<char*>(&item_id), itemid_size) &&
ifs.read(reinterpret_cast<char*>(tokens.data()), tokens_size)) {
if (FLAGS_enable_constrained_decoding) {
for (int i = 0; i < tokens.size(); i++) {
std::vector<int32_t> prefix_tokens;
if (!FLAGS_enable_extended_item_info) {
const size_t line_size = tokens_size + itemid_size;
while (ifs.read(reinterpret_cast<char*>(&item_id), itemid_size) &&
ifs.read(reinterpret_cast<char*>(tokens.data()), tokens_size)) {
if (FLAGS_enable_constrained_decoding) {
for (int32_t i = 0; i < tokens.size(); ++i) {
std::vector<int32_t> prefix_tokens;
for (int32_t j = 0; j < i; ++j) {
prefix_tokens.emplace_back(tokens[j]);
}
prefix_tokens_to_next_tokens_map_[prefix_tokens].insert(tokens[i]);
}
}

item_to_tokens_map_[item_id] = tokens;
RecItemInfo item_info;
item_info.item_id = item_id;
tokens_to_item_infos_map_[tokens].emplace_back(std::move(item_info));
}

for (int j = 0; j < i; j++) {
prefix_tokens.emplace_back(tokens[j]);
if (ifs.gcount() != 0 &&
ifs.gcount() != static_cast<std::streamsize>(tokens_size)) {
LOG(ERROR) << "Possibly containing incomplete lines : " << vocab_file;
item_to_tokens_map_.clear();
tokens_to_item_infos_map_.clear();
prefix_tokens_to_next_tokens_map_.clear();
return false;
}
} else {
while (ifs.read(reinterpret_cast<char*>(&item_id), itemid_size)) {
uint32_t did_length = 0;
if (!ifs.read(reinterpret_cast<char*>(&did_length), sizeof(uint32_t))) {
break;
}

std::string did;
if (did_length > 0) {
did.resize(did_length);
if (!ifs.read(did.data(), did_length)) {
LOG(ERROR) << "Failed to read did string from " << vocab_file;
item_to_tokens_map_.clear();
tokens_to_item_infos_map_.clear();
prefix_tokens_to_next_tokens_map_.clear();
return false;
}
}

prefix_tokens_to_next_tokens_map_[prefix_tokens].insert(tokens[i]);
uint32_t type_length = 0;
if (!ifs.read(reinterpret_cast<char*>(&type_length), sizeof(uint32_t))) {
LOG(ERROR) << "Failed to read type length from " << vocab_file;
item_to_tokens_map_.clear();
tokens_to_item_infos_map_.clear();
prefix_tokens_to_next_tokens_map_.clear();
return false;
}
}

item_to_tokens_map_[item_id] = tokens;
std::string type;
if (type_length > 0) {
type.resize(type_length);
if (!ifs.read(type.data(), type_length)) {
LOG(ERROR) << "Failed to read type string from " << vocab_file;
item_to_tokens_map_.clear();
tokens_to_item_infos_map_.clear();
prefix_tokens_to_next_tokens_map_.clear();
return false;
}
}

tokens_to_items_map_[tokens].emplace_back(item_id);
}
if (!ifs.read(reinterpret_cast<char*>(tokens.data()), tokens_size)) {
LOG(ERROR) << "Failed to read token ids from " << vocab_file;
item_to_tokens_map_.clear();
tokens_to_item_infos_map_.clear();
prefix_tokens_to_next_tokens_map_.clear();
return false;
}

if (ifs.gcount() != 0 && ifs.gcount() != line_size) {
LOG(ERROR) << "Possibly containing incomplete lines : " << vocab_file;
item_to_tokens_map_.clear();
tokens_to_items_map_.clear();
prefix_tokens_to_next_tokens_map_.clear();
return false;
if (FLAGS_enable_constrained_decoding) {
for (int32_t i = 0; i < tokens.size(); ++i) {
std::vector<int32_t> prefix_tokens;
for (int32_t j = 0; j < i; ++j) {
prefix_tokens.emplace_back(tokens[j]);
}
prefix_tokens_to_next_tokens_map_[prefix_tokens].insert(tokens[i]);
}
}

item_to_tokens_map_[item_id] = tokens;
RecItemInfo item_info;
item_info.item_id = item_id;
item_info.did = std::move(did);
item_info.type = std::move(type);
tokens_to_item_infos_map_[tokens].emplace_back(std::move(item_info));
}

if (!ifs.eof() && ifs.fail()) {
LOG(ERROR) << "Failed while reading " << vocab_file;
item_to_tokens_map_.clear();
tokens_to_item_infos_map_.clear();
prefix_tokens_to_next_tokens_map_.clear();
return false;
}
}

initialized_ = true;
LOG(INFO) << "Total line size:" << estimated_lines
<< ",parse tokens to item id map size: "
<< tokens_to_items_map_.size()
<< tokens_to_item_infos_map_.size()
<< ", parse item to tokens map size:" << item_to_tokens_map_.size()
<< ", parse prefix tokens to next tokens map size:"
<< prefix_tokens_to_next_tokens_map_.size()
@@ -93,14 +168,32 @@ bool RecVocabDict::get_items_by_tokens(const RecTokenTriple& rec_token_triple,
CHECK_EQ(initialized_, true);
CHECK_NE(item_ids, nullptr);

auto iter = tokens_to_items_map_.find(rec_token_triple);
if (iter == tokens_to_items_map_.end()) {
std::vector<RecItemInfo> item_infos;
if (!get_item_infos_by_tokens(rec_token_triple, &item_infos)) {
return false;
}

std::copy(
iter->second.begin(), iter->second.end(), std::back_inserter(*item_ids));
item_ids->reserve(item_ids->size() + item_infos.size());
for (const auto& item_info : item_infos) {
item_ids->emplace_back(item_info.item_id);
}
return true;
}

bool RecVocabDict::get_item_infos_by_tokens(
const RecTokenTriple& rec_token_triple,
std::vector<RecItemInfo>* item_infos) const {
CHECK_EQ(initialized_, true);
CHECK_NE(item_infos, nullptr);

auto iter = tokens_to_item_infos_map_.find(rec_token_triple);
if (iter == tokens_to_item_infos_map_.end()) {
return false;
}

std::copy(iter->second.begin(),
iter->second.end(),
std::back_inserter(*item_infos));
return true;
}

@@ -136,4 +229,4 @@ RecVocabDict::get_next_tokens_by_prefix_tokens(
return iter->second;
}

} // namespace xllm
} // namespace xllm
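The extended vocab layout is defined only implicitly by the reads above. For reference, a minimal writer sketch under the inferred record layout: a native-endian int64 item id, a uint32-length-prefixed did string, a uint32-length-prefixed type string, then REC_TOKEN_SIZE int32 token ids (kRecTokenSize = 3 is an assumed value for REC_TOKEN_SIZE):

#include <cstdint>
#include <fstream>
#include <string>
#include <vector>

// Assumed to match the token triple width used by RecVocabDict.
constexpr size_t kRecTokenSize = 3;

struct VocabRecord {
  int64_t item_id;
  std::string did;
  std::string type;
  std::vector<int32_t> tokens;  // exactly kRecTokenSize entries
};

// Writes records in the layout RecVocabDict::initialize() expects when
// FLAGS_enable_extended_item_info is set. Uses native endianness, matching
// the reinterpret_cast-based reads above.
bool write_extended_vocab(const std::string& path,
                          const std::vector<VocabRecord>& records) {
  std::ofstream ofs(path, std::ios::binary);
  if (!ofs) return false;
  for (const auto& rec : records) {
    ofs.write(reinterpret_cast<const char*>(&rec.item_id), sizeof(int64_t));
    const uint32_t did_len = static_cast<uint32_t>(rec.did.size());
    ofs.write(reinterpret_cast<const char*>(&did_len), sizeof(uint32_t));
    ofs.write(rec.did.data(), did_len);
    const uint32_t type_len = static_cast<uint32_t>(rec.type.size());
    ofs.write(reinterpret_cast<const char*>(&type_len), sizeof(uint32_t));
    ofs.write(rec.type.data(), type_len);
    ofs.write(reinterpret_cast<const char*>(rec.tokens.data()),
              kRecTokenSize * sizeof(int32_t));
  }
  return static_cast<bool>(ofs);
}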
18 changes: 11 additions & 7 deletions xllm/core/framework/state_dict/rec_vocab_dict.h
@@ -20,6 +20,7 @@ limitations under the License.
#include <cstdint>
#include <optional>
#include <set>
#include <string>
#include <unordered_map>
#include <vector>

@@ -38,7 +39,7 @@ class RecVocabDict final {
~RecVocabDict() {
initialized_ = false;
item_to_tokens_map_.clear();
tokens_to_items_map_.clear();
tokens_to_item_infos_map_.clear();
prefix_tokens_to_next_tokens_map_.clear();
}

@@ -59,6 +60,9 @@ bool get_items_by_tokens(const RecTokenTriple& rec_token_triple,
bool get_items_by_tokens(const RecTokenTriple& rec_token_triple,
std::vector<int64_t>* item_ids) const;

bool get_item_infos_by_tokens(const RecTokenTriple& rec_token_triple,
std::vector<RecItemInfo>* item_infos) const;

/**
* @brief Get the corresponding token ID triplet through an item id
* @param item_ids, input item id
@@ -88,13 +92,13 @@
// Check if initialization has been successful
bool initialized_ = false;

// Convert token to item map, key: token id triplet, value: item id list,
// there is a token id triplet corresponding to multiple item IDs, and
// boost::hash<RecTokenTriple> will generate ordered triplet hash value
// Convert token to item map, key: token id triplet, value: item info list;
// a single token id triplet can correspond to multiple items, and
// boost::hash<RecTokenTriple> generates an order-sensitive triplet hash.
std::unordered_map<RecTokenTriple,
std::vector<int64_t>,
std::vector<RecItemInfo>,
boost::hash<RecTokenTriple>>
tokens_to_items_map_;
tokens_to_item_infos_map_;

// Convert item to tokens map, key: item id, value: token triplet; each
// item id corresponds to exactly one token id triplet
Expand All @@ -107,4 +111,4 @@ class RecVocabDict final {
boost::hash<std::vector<int32_t>>>
prefix_tokens_to_next_tokens_map_;
};
} // namespace xllm
} // namespace xllm
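A short usage sketch for the new lookup, assuming RecVocabDict can be instantiated directly and that RecTokenTriple supports aggregate initialization (the vocab path and token ids are made up):

void lookup_example() {
  RecVocabDict dict;
  if (!dict.initialize("/path/to/extended_vocab.bin")) {  // hypothetical path
    return;
  }
  const RecTokenTriple triple = {101, 202, 303};  // made-up token ids
  std::vector<RecItemInfo> infos;
  if (dict.get_item_infos_by_tokens(triple, &infos)) {
    for (const auto& info : infos) {
      LOG(INFO) << "item_id=" << info.item_id << " did=" << info.did
                << " type=" << info.type;
    }
  }
}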