diff --git a/xllm/api_service/rec_completion_service_impl.cpp b/xllm/api_service/rec_completion_service_impl.cpp index 55a8b71ec..d5e0489e1 100644 --- a/xllm/api_service/rec_completion_service_impl.cpp +++ b/xllm/api_service/rec_completion_service_impl.cpp @@ -44,6 +44,12 @@ limitations under the License. namespace xllm { namespace { +struct RecEmitRecord { + int32_t output_index = 0; + int64_t item_id = 0; + std::optional item_info; +}; + void append_rec_logprobs(proto::InferTensorContents* logprobs_context, const SequenceOutput& output, int32_t expected_count) { @@ -120,13 +126,65 @@ bool send_result_to_client_brpc_rec(std::shared_ptr call, if (FLAGS_enable_convert_tokens_to_item) { output_tensor->set_datatype(proto::DataType::INT64); - const int32_t output_count = - static_cast(req_output.outputs.size()); - output_tensor->mutable_shape()->Add(output_count); + proto::InferOutputTensor* did_tensor = nullptr; + proto::InferOutputTensor* type_tensor = nullptr; + if (FLAGS_enable_extended_item_info) { + did_tensor = response.mutable_output_tensors()->Add(); + did_tensor->set_name("item_did"); + did_tensor->set_datatype(proto::DataType::STRING); + + type_tensor = response.mutable_output_tensors()->Add(); + type_tensor->set_name("item_type"); + type_tensor->set_datatype(proto::DataType::STRING); + } + + std::vector emitted_items; + emitted_items.reserve(req_output.outputs.size()); + const int32_t total_threshold = FLAGS_total_conversion_threshold; + for (int32_t i = 0; i < static_cast(req_output.outputs.size()); + ++i) { + const auto& output = req_output.outputs[i]; + if (!output.item_ids_list.empty()) { + const bool has_item_infos = + output.item_infos_list.size() == output.item_ids_list.size(); + for (size_t item_idx = 0; item_idx < output.item_ids_list.size(); + ++item_idx) { + if (static_cast(emitted_items.size()) >= total_threshold) { + break; + } + std::optional item_info; + if (has_item_infos) { + item_info = output.item_infos_list[item_idx]; + } + RecEmitRecord emitted_item; + emitted_item.output_index = i; + emitted_item.item_id = output.item_ids_list[item_idx]; + emitted_item.item_info = std::move(item_info); + emitted_items.emplace_back(std::move(emitted_item)); + } + } else if (output.item_ids.has_value() && + static_cast(emitted_items.size()) < total_threshold) { + RecEmitRecord emitted_item; + emitted_item.output_index = i; + emitted_item.item_id = output.item_ids.value(); + emitted_item.item_info = output.item_info; + emitted_items.emplace_back(std::move(emitted_item)); + } + if (static_cast(emitted_items.size()) >= total_threshold) { + break; + } + } + + const int32_t emitted_count = static_cast(emitted_items.size()); + output_tensor->mutable_shape()->Add(emitted_count); if (logprobs_tensor != nullptr) { - logprobs_tensor->mutable_shape()->Add(output_count); + logprobs_tensor->mutable_shape()->Add(emitted_count); logprobs_tensor->mutable_shape()->Add(logprob_width); } + if (did_tensor != nullptr && type_tensor != nullptr) { + did_tensor->mutable_shape()->Add(emitted_count); + type_tensor->mutable_shape()->Add(emitted_count); + } auto* output_context = output_tensor->mutable_contents(); auto* logprobs_context = logprobs_tensor == nullptr @@ -138,23 +196,16 @@ bool send_result_to_client_brpc_rec(std::shared_ptr call, logprobs_context, req_output.outputs[output_index], logprob_width); } }; - int32_t total_count = 0; - const int32_t total_threshold = FLAGS_total_conversion_threshold; - for (int32_t i = 0; i < output_count; ++i) { - const auto& output = req_output.outputs[i]; - if (!output.item_ids_list.empty()) { - for (const int64_t item_id : output.item_ids_list) { - if (total_count >= total_threshold) { - break; - } - output_context->mutable_int64_contents()->Add(item_id); - append_output_logprobs(i); - ++total_count; - } - } else if (output.item_ids.has_value() && total_count < total_threshold) { - output_context->mutable_int64_contents()->Add(output.item_ids.value()); - append_output_logprobs(i); - ++total_count; + for (const RecEmitRecord& emitted_item : emitted_items) { + output_context->mutable_int64_contents()->Add(emitted_item.item_id); + append_output_logprobs(emitted_item.output_index); + if (did_tensor != nullptr && type_tensor != nullptr) { + did_tensor->mutable_contents()->add_bytes_contents( + emitted_item.item_info.has_value() ? emitted_item.item_info->did + : ""); + type_tensor->mutable_contents()->add_bytes_contents( + emitted_item.item_info.has_value() ? emitted_item.item_info->type + : ""); } } } else { diff --git a/xllm/core/common/global_flags.cpp b/xllm/core/common/global_flags.cpp index fa24d4d9c..2d68ad796 100644 --- a/xllm/core/common/global_flags.cpp +++ b/xllm/core/common/global_flags.cpp @@ -631,6 +631,10 @@ DEFINE_bool(enable_convert_tokens_to_item, false, "Enable token ids conversion to item id in REC/OneRec response."); +DEFINE_bool(enable_extended_item_info, + false, + "Enable REC extended item info parsing and output tensors."); + DEFINE_int64(dit_cache_start_steps, 5, "The number of steps to skip at the start"); diff --git a/xllm/core/common/global_flags.h b/xllm/core/common/global_flags.h index 59412570d..460ea7f1a 100644 --- a/xllm/core/common/global_flags.h +++ b/xllm/core/common/global_flags.h @@ -296,6 +296,7 @@ DECLARE_double(dit_cache_residual_diff_threshold); DECLARE_bool(enable_constrained_decoding); DECLARE_bool(enable_convert_tokens_to_item); DECLARE_bool(enable_output_sku_logprobs); +DECLARE_bool(enable_extended_item_info); DECLARE_int32(each_conversion_threshold); DECLARE_int32(total_conversion_threshold); diff --git a/xllm/core/common/help_formatter.h b/xllm/core/common/help_formatter.h index 4719eee0a..13002adf1 100644 --- a/xllm/core/common/help_formatter.h +++ b/xllm/core/common/help_formatter.h @@ -100,6 +100,7 @@ const OptionCategory kBeamSearchOptions = { const OptionCategory kRecOptions = {"REC OPTIONS", {"enable_rec_fast_sampler", "enable_convert_tokens_to_item", + "enable_extended_item_info", "enable_output_sku_logprobs", "each_conversion_threshold", "total_conversion_threshold", diff --git a/xllm/core/common/types.h b/xllm/core/common/types.h index dadc09f31..e8163fa47 100644 --- a/xllm/core/common/types.h +++ b/xllm/core/common/types.h @@ -198,6 +198,15 @@ struct RawToken { std::vector embeddings; // hidden states }; +struct RecItemInfo { + // The decoded recommendation item id. + int64_t item_id = 0; + // The DID associated with the recommendation item. + std::string did; + // The business type associated with the recommendation item. + std::string type; +}; + // Weight segment info for D2D transfer (supports non-contiguous allocation) // Forward declaration needed by InstanceInfo struct WeightSegment { diff --git a/xllm/core/framework/request/request_output.h b/xllm/core/framework/request/request_output.h index cfcc8fe7c..3807d8133 100644 --- a/xllm/core/framework/request/request_output.h +++ b/xllm/core/framework/request/request_output.h @@ -77,6 +77,12 @@ struct SequenceOutput { // decoded item ids for multi-item recommendation output. std::vector item_ids_list; + // the decoded extended item info for constrained recommendation output. + std::optional item_info; + + // decoded extended item infos for multi-item recommendation output. + std::vector item_infos_list; + // the reason the sequence finished. std::optional finish_reason; diff --git a/xllm/core/framework/request/sequence.cpp b/xllm/core/framework/request/sequence.cpp index 06ec4985d..a2a3e6786 100644 --- a/xllm/core/framework/request/sequence.cpp +++ b/xllm/core/framework/request/sequence.cpp @@ -33,6 +33,7 @@ limitations under the License. #include "core/common/global_flags.h" #include "core/common/metrics.h" #include "core/framework/request/mm_data_visitor.h" +#include "core/framework/tokenizer/rec_tokenizer.h" #include "core/framework/tokenizer/tokenizer.h" #include "core/util/slice.h" #include "core/util/tensor_helper.h" @@ -70,6 +71,33 @@ std::vector normalize_rec_item_ids(const std::vector& raw_ids, return item_ids; } + +std::vector normalize_rec_item_infos( + const std::vector& raw_item_infos, + size_t sequence_index) { + std::vector item_infos; + item_infos.reserve(raw_item_infos.size()); + std::unordered_set seen_item_ids; + for (const RecItemInfo& item_info : raw_item_infos) { + if (seen_item_ids.insert(item_info.item_id).second) { + item_infos.emplace_back(item_info); + } + } + + const int32_t each_threshold = FLAGS_each_conversion_threshold; + if (each_threshold > 0 && + static_cast(item_infos.size()) > each_threshold) { + uint32_t seed = FLAGS_random_seed >= 0 + ? static_cast(FLAGS_random_seed) + + static_cast(sequence_index) + : std::random_device{}(); + std::mt19937 generator(seed); + std::shuffle(item_infos.begin(), item_infos.end(), generator); + item_infos.resize(each_threshold); + } + + return item_infos; +} } // namespace const std::string Sequence::ENCODER_SPARSE_EMBEDDING_NAME = "sparse_embedding"; @@ -150,15 +178,35 @@ void Sequence::generate_onerec_output(const Slice& ids, const size_t rec_token_size = static_cast(REC_TOKEN_SIZE); if (FLAGS_enable_convert_tokens_to_item && output.token_ids.size() == rec_token_size) { - std::vector item_ids; - const bool ok = tokenizer.decode( - Slice{output.token_ids.data(), output.token_ids.size()}, - sequence_params_.skip_special_tokens, - &item_ids); - if (ok && !item_ids.empty()) { - output.item_ids_list = normalize_rec_item_ids(item_ids, index_); - if (!output.item_ids_list.empty()) { - output.item_ids = output.item_ids_list.front(); + const Slice token_slice{output.token_ids.data(), + output.token_ids.size()}; + if (FLAGS_enable_extended_item_info) { + const auto* rec_tokenizer = dynamic_cast(&tokenizer); + if (rec_tokenizer != nullptr) { + std::vector item_infos; + const bool ok = + rec_tokenizer->decode_item_infos(token_slice, &item_infos); + if (ok && !item_infos.empty()) { + output.item_infos_list = normalize_rec_item_infos(item_infos, index_); + output.item_ids_list.reserve(output.item_infos_list.size()); + for (const RecItemInfo& item_info : output.item_infos_list) { + output.item_ids_list.emplace_back(item_info.item_id); + } + if (!output.item_infos_list.empty()) { + output.item_ids = output.item_ids_list.front(); + output.item_info = output.item_infos_list.front(); + } + } + } + } else { + std::vector item_ids; + const bool ok = tokenizer.decode( + token_slice, sequence_params_.skip_special_tokens, &item_ids); + if (ok && !item_ids.empty()) { + output.item_ids_list = normalize_rec_item_ids(item_ids, index_); + if (!output.item_ids_list.empty()) { + output.item_ids = output.item_ids_list.front(); + } } } } diff --git a/xllm/core/framework/state_dict/rec_vocab_dict.cpp b/xllm/core/framework/state_dict/rec_vocab_dict.cpp index 1d109406c..80c930e07 100644 --- a/xllm/core/framework/state_dict/rec_vocab_dict.cpp +++ b/xllm/core/framework/state_dict/rec_vocab_dict.cpp @@ -4,11 +4,15 @@ #include #include #include +#include #include "common/global_flags.h" #include "util/timer.h" namespace xllm { +namespace { +constexpr uint32_t kMaxExtendedFieldBytes = 1U << 20; +} bool RecVocabDict::initialize(const std::string& vocab_file) { if (initialized_) { @@ -31,55 +35,149 @@ bool RecVocabDict::initialize(const std::string& vocab_file) { return false; } - const size_t file_size = ifs.tellg(); + const std::streamoff file_end = ifs.tellg(); + if (file_end < 0) { + LOG(ERROR) << "Failed to read content data file size: " << vocab_file; + return false; + } + const size_t file_size = static_cast(file_end); ifs.seekg(0, std::ios::beg); - // Each line of content : 1 * int64_t(item id) + REC_TOKEN_SIZE * - // int32_t(token id); const size_t itemid_size = sizeof(int64_t); const size_t tokens_size = REC_TOKEN_SIZE * sizeof(int32_t); - const size_t line_size = tokens_size + itemid_size; - const size_t estimated_lines = (file_size + line_size - 1) / line_size; + const size_t min_line_size = itemid_size + tokens_size; + const size_t estimated_lines = + (file_size + min_line_size - 1) / min_line_size; - // 2 and 4 are only empirical values item_to_tokens_map_.reserve(estimated_lines); - tokens_to_items_map_.reserve(estimated_lines / 2); + tokens_to_item_infos_map_.reserve(estimated_lines / 2); prefix_tokens_to_next_tokens_map_.reserve(estimated_lines / 4); + auto clear_partial_state = [this]() { + item_to_tokens_map_.clear(); + tokens_to_item_infos_map_.clear(); + prefix_tokens_to_next_tokens_map_.clear(); + }; + auto fail_with_error = [&](const std::string& message) { + LOG(ERROR) << message; + clear_partial_state(); + return false; + }; + auto get_remaining_bytes = [&ifs, file_size]() -> size_t { + const std::streamoff current_pos = ifs.tellg(); + if (current_pos < 0) { + return 0; + } + const size_t current_offset = static_cast(current_pos); + return current_offset <= file_size ? file_size - current_offset : 0; + }; + auto validate_extended_field_length = [&](uint32_t field_length, + const char* field_name) { + if (field_length > kMaxExtendedFieldBytes) { + return fail_with_error(std::string("Field length for ") + field_name + + " exceeds limit in " + vocab_file); + } + if (static_cast(field_length) > get_remaining_bytes()) { + return fail_with_error(std::string("Field length for ") + field_name + + " exceeds remaining bytes in " + vocab_file); + } + return true; + }; int64_t item_id = 0; RecTokenTriple tokens; - while (ifs.read(reinterpret_cast(&item_id), itemid_size) && - ifs.read(reinterpret_cast(tokens.data()), tokens_size)) { - if (FLAGS_enable_constrained_decoding) { - for (int i = 0; i < tokens.size(); i++) { - std::vector prefix_tokens; + if (!FLAGS_enable_extended_item_info) { + const size_t line_size = tokens_size + itemid_size; + while (ifs.read(reinterpret_cast(&item_id), itemid_size) && + ifs.read(reinterpret_cast(tokens.data()), tokens_size)) { + if (FLAGS_enable_constrained_decoding) { + for (int32_t i = 0; i < tokens.size(); ++i) { + std::vector prefix_tokens; + for (int32_t j = 0; j < i; ++j) { + prefix_tokens.emplace_back(tokens[j]); + } + prefix_tokens_to_next_tokens_map_[prefix_tokens].insert(tokens[i]); + } + } + + item_to_tokens_map_[item_id] = tokens; + RecItemInfo item_info; + item_info.item_id = item_id; + tokens_to_item_infos_map_[tokens].emplace_back(std::move(item_info)); + } - for (int j = 0; j < i; j++) { - prefix_tokens.emplace_back(tokens[j]); + if (ifs.gcount() != 0 && + ifs.gcount() != static_cast(tokens_size)) { + return fail_with_error("Possibly containing incomplete lines : " + + vocab_file); + } + } else { + while (ifs.read(reinterpret_cast(&item_id), itemid_size)) { + uint32_t did_length = 0; + if (!ifs.read(reinterpret_cast(&did_length), sizeof(uint32_t))) { + return fail_with_error("Failed to read did length from " + vocab_file); + } + if (!validate_extended_field_length(did_length, "did")) { + return false; + } + + std::string did; + if (did_length > 0) { + did.resize(did_length); + if (!ifs.read(did.data(), did_length)) { + return fail_with_error("Failed to read did string from " + + vocab_file); } + } - prefix_tokens_to_next_tokens_map_[prefix_tokens].insert(tokens[i]); + uint32_t type_length = 0; + if (!ifs.read(reinterpret_cast(&type_length), sizeof(uint32_t))) { + return fail_with_error("Failed to read type length from " + vocab_file); + } + if (!validate_extended_field_length(type_length, "type")) { + return false; } - } - item_to_tokens_map_[item_id] = tokens; + std::string type; + if (type_length > 0) { + type.resize(type_length); + if (!ifs.read(type.data(), type_length)) { + return fail_with_error("Failed to read type string from " + + vocab_file); + } + } - tokens_to_items_map_[tokens].emplace_back(item_id); - } + if (!ifs.read(reinterpret_cast(tokens.data()), tokens_size)) { + return fail_with_error("Failed to read token ids from " + vocab_file); + } - if (ifs.gcount() != 0 && ifs.gcount() != line_size) { - LOG(ERROR) << "Possibly containing incomplete lines : " << vocab_file; - item_to_tokens_map_.clear(); - tokens_to_items_map_.clear(); - prefix_tokens_to_next_tokens_map_.clear(); - return false; + if (FLAGS_enable_constrained_decoding) { + for (int32_t i = 0; i < tokens.size(); ++i) { + std::vector prefix_tokens; + for (int32_t j = 0; j < i; ++j) { + prefix_tokens.emplace_back(tokens[j]); + } + prefix_tokens_to_next_tokens_map_[prefix_tokens].insert(tokens[i]); + } + } + + item_to_tokens_map_[item_id] = tokens; + RecItemInfo item_info; + item_info.item_id = item_id; + item_info.did = std::move(did); + item_info.type = std::move(type); + tokens_to_item_infos_map_[tokens].emplace_back(std::move(item_info)); + } + + if (ifs.gcount() > 0 || (!ifs.eof() && ifs.fail())) { + return fail_with_error("Failed while reading " + vocab_file); + } } initialized_ = true; LOG(INFO) << "Total line size:" << estimated_lines << ",parse tokens to item id map size: " - << tokens_to_items_map_.size() + << tokens_to_item_infos_map_.size() << ", parse item to tokens map size:" << item_to_tokens_map_.size() << ", parse prefix tokens to next tokens map size:" << prefix_tokens_to_next_tokens_map_.size() @@ -93,14 +191,32 @@ bool RecVocabDict::get_items_by_tokens(const RecTokenTriple& rec_token_triple, CHECK_EQ(initialized_, true); CHECK_NE(item_ids, nullptr); - auto iter = tokens_to_items_map_.find(rec_token_triple); - if (iter == tokens_to_items_map_.end()) { + std::vector item_infos; + if (!get_item_infos_by_tokens(rec_token_triple, &item_infos)) { return false; } - std::copy( - iter->second.begin(), iter->second.end(), std::back_inserter(*item_ids)); + item_ids->reserve(item_ids->size() + item_infos.size()); + for (const auto& item_info : item_infos) { + item_ids->emplace_back(item_info.item_id); + } + return true; +} + +bool RecVocabDict::get_item_infos_by_tokens( + const RecTokenTriple& rec_token_triple, + std::vector* item_infos) const { + CHECK_EQ(initialized_, true); + CHECK_NE(item_infos, nullptr); + + auto iter = tokens_to_item_infos_map_.find(rec_token_triple); + if (iter == tokens_to_item_infos_map_.end()) { + return false; + } + std::copy(iter->second.begin(), + iter->second.end(), + std::back_inserter(*item_infos)); return true; } @@ -136,4 +252,4 @@ RecVocabDict::get_next_tokens_by_prefix_tokens( return iter->second; } -} // namespace xllm \ No newline at end of file +} // namespace xllm diff --git a/xllm/core/framework/state_dict/rec_vocab_dict.h b/xllm/core/framework/state_dict/rec_vocab_dict.h index 1863e5be7..5c53fb03f 100644 --- a/xllm/core/framework/state_dict/rec_vocab_dict.h +++ b/xllm/core/framework/state_dict/rec_vocab_dict.h @@ -20,6 +20,7 @@ limitations under the License. #include #include #include +#include #include #include @@ -38,7 +39,7 @@ class RecVocabDict final { ~RecVocabDict() { initialized_ = false; item_to_tokens_map_.clear(); - tokens_to_items_map_.clear(); + tokens_to_item_infos_map_.clear(); prefix_tokens_to_next_tokens_map_.clear(); } @@ -59,6 +60,9 @@ class RecVocabDict final { bool get_items_by_tokens(const RecTokenTriple& rec_token_triple, std::vector* item_ids) const; + bool get_item_infos_by_tokens(const RecTokenTriple& rec_token_triple, + std::vector* item_infos) const; + /** * @brief Get the corresponding token ID triplet through a item id * @param item_ids, input item id @@ -88,13 +92,13 @@ class RecVocabDict final { // Check if initialization has been successful bool initialized_ = false; - // Convert token to item map, key: token id triplet, value: item id list, - // there is a token id triplet corresponding to multiple item IDs, and - // boost::hash will generate ordered triplet hash value + // Convert token to item map, key: token id triplet, value: item info list, + // there is a token id triplet corresponding to multiple items, and + // boost::hash will generate ordered triplet hash value. std::unordered_map, + std::vector, boost::hash> - tokens_to_items_map_; + tokens_to_item_infos_map_; // Convert item to tokens map, key: item id, value: token triplet, there is a // item id corresponding to a token id triplet @@ -107,4 +111,4 @@ class RecVocabDict final { boost::hash>> prefix_tokens_to_next_tokens_map_; }; -} // namespace xllm \ No newline at end of file +} // namespace xllm diff --git a/xllm/core/framework/tokenizer/rec_tokenizer.cpp b/xllm/core/framework/tokenizer/rec_tokenizer.cpp index d65943d47..3ce0e960f 100644 --- a/xllm/core/framework/tokenizer/rec_tokenizer.cpp +++ b/xllm/core/framework/tokenizer/rec_tokenizer.cpp @@ -38,6 +38,7 @@ bool RecTokenizer::encode(int64_t item_id, bool RecTokenizer::decode(const Slice& token_ids, bool skip_special_tokens, std::vector* item_ids) const { + (void)skip_special_tokens; CHECK_EQ(token_ids.size(), REC_TOKEN_SIZE); RecTokenTriple rec_token_triple; @@ -47,6 +48,18 @@ bool RecTokenizer::decode(const Slice& token_ids, ->get_items_by_tokens(rec_token_triple, item_ids); } +bool RecTokenizer::decode_item_infos( + const Slice& token_ids, + std::vector* item_infos) const { + CHECK_EQ(token_ids.size(), REC_TOKEN_SIZE); + + RecTokenTriple rec_token_triple; + std::copy(token_ids.begin(), token_ids.end(), rec_token_triple.begin()); + + return VersionSingleton::GetInstance(model_version_) + ->get_item_infos_by_tokens(rec_token_triple, item_infos); +} + size_t RecTokenizer::vocab_size() const { // currently, there is no voice size set in the tokenizer configuration. The // voice size can be obtained from the model args @@ -57,4 +70,4 @@ std::unique_ptr RecTokenizer::clone() const { return std::make_unique(dir_path_, args_); } -} // namespace xllm \ No newline at end of file +} // namespace xllm diff --git a/xllm/core/framework/tokenizer/rec_tokenizer.h b/xllm/core/framework/tokenizer/rec_tokenizer.h index 41b03f0bd..c7ddd5125 100644 --- a/xllm/core/framework/tokenizer/rec_tokenizer.h +++ b/xllm/core/framework/tokenizer/rec_tokenizer.h @@ -23,6 +23,7 @@ limitations under the License. #include #include +#include "core/common/types.h" #include "tokenizer.h" #include "tokenizer_args.h" #include "util/slice.h" @@ -41,6 +42,9 @@ class RecTokenizer : public Tokenizer { bool skip_special_tokens, std::vector* item_ids) const override; + bool decode_item_infos(const Slice& token_ids, + std::vector* item_infos) const; + size_t vocab_size() const override; std::unique_ptr clone() const override;