Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion third_party/xllm_atb_layers
Submodule xllm_atb_layers updated from d6aa21 to 96d3de
2 changes: 1 addition & 1 deletion xllm/core/framework/model/model_input_params.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ struct OneRecModelInputParams {
torch::Tensor cross_attn_kv_cu_seq_lens;
torch::Tensor cross_attn_new_cache_slots;
torch::Tensor cross_attn_block_tables;
std::vector<int> cross_attn_kv_cu_seq_lens_vec;
std::vector<int32_t> cross_attn_kv_cu_seq_lens_vec;

torch::Tensor encoder_token_ids;
torch::Tensor encoder_positions;
Expand Down
1 change: 1 addition & 0 deletions xllm/core/framework/tokenizer/rec_tokenizer.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ limitations under the License.
#include <string_view>
#include <vector>

#include "common/types.h"
#include "tokenizer.h"
#include "tokenizer_args.h"
#include "util/slice.h"
Expand Down
601 changes: 514 additions & 87 deletions xllm/core/layers/npu/npu_onerec_block_layer_impl.cpp

Large diffs are not rendered by default.

28 changes: 22 additions & 6 deletions xllm/core/layers/npu/npu_onerec_block_layer_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ class NpuOneRecBlockLayerImpl final : public BaseLayer {
KVCache& kv_cache,
ModelInputParams& input_params,
bool is_prefill,
bool is_first_prefill,
torch::Tensor* encoder_output = nullptr,
int32_t layer_id = 0);

Expand All @@ -94,6 +95,7 @@ class NpuOneRecBlockLayerImpl final : public BaseLayer {
KVCache& kv_cache,
ModelInputParams& input_params,
bool is_prefill,
bool is_first_prefill,
torch::Tensor* encoder_output = nullptr,
int32_t layer_id = 0,
const torch::Tensor& expert_array = torch::Tensor());
Expand All @@ -103,12 +105,16 @@ class NpuOneRecBlockLayerImpl final : public BaseLayer {

int64_t init_attn_mask();

int32_t setup_common_decoder_tensors(atb_speed::Model::Node& node,
torch::Tensor& x,
at::Tensor& attn_mask,
ModelInputParams& input_params,
torch::Tensor* encoder_output = nullptr,
int32_t start_tensor_idx = 0);
int32_t setup_common_decoder_tensors(
atb_speed::Model::Node& node,
torch::Tensor& x,
at::Tensor& attn_mask,
KVCache& kv_cache,
ModelInputParams& input_params,
const atb_speed::onerec::BlockLayerParam& param,
bool is_first_prefill,
torch::Tensor* encoder_output = nullptr,
int32_t start_tensor_idx = 0);

void resize_experts_weights(int32_t num_of_device_experts);
void process_expert_weights(const StateDict& state_dict,
Expand All @@ -129,17 +135,27 @@ class NpuOneRecBlockLayerImpl final : public BaseLayer {
std::string extract_endswith(const std::string& input);

atb_speed::Model::Node prefill_node_;
atb_speed::Model::Node prefill_node_atb_;
atb_speed::Model::Node decode_node_;
atb_speed::Model::Node decoder_prefill_only_decode_node_;
atb_speed::Model::Node decoder_prefill_only_decode_node_atb_;
std::string model_name_;
atb_speed::onerec::BlockLayerParam prefill_param_;
atb_speed::onerec::BlockLayerParam prefill_param_atb_;
atb_speed::onerec::BlockLayerParam decode_param_;
atb_speed::onerec::BlockLayerParam decoder_prefill_only_decode_param_;
atb_speed::onerec::BlockLayerParam decoder_prefill_only_decode_param_atb_;

atb::Tensor internal_tensors_;
atb::Tensor placeholder_;

at::Tensor encoder_output_contiguous_;
at::Tensor at_placeholder_;
std::vector<int32_t> placeholder_vec_;
torch::Tensor cross_k_cache_;
torch::Tensor cross_v_cache_;
torch::Tensor fallback_kv_seq_lens_tensor_;
std::vector<int32_t> seq_lens_vec_;
Comment thread
DragonFive marked this conversation as resolved.

int32_t device_id_ = 0;
bool is_decoder_ = false;
Expand Down
13 changes: 11 additions & 2 deletions xllm/models/rec/npu/onerec.h
Original file line number Diff line number Diff line change
Expand Up @@ -270,8 +270,17 @@ class OneRecForConditionalGenerationImpl final
void load_model(std::unique_ptr<ModelLoader> loader,
std::string prefix = "model.") override {
for (const auto& state_dict : loader->get_state_dicts()) {
StateDict model_state_dict = state_dict->get_dict_with_prefix(prefix);
if (model_state_dict.size() == 0) {
StateDict prefixed_state_dict =
state_dict->get_dict_with_prefix("module.module3.t5_model.");
StateDict model_state_dict =
prefixed_state_dict.size() > 0
? prefixed_state_dict
: state_dict->get_dict_with_prefix(prefix);
if (prefixed_state_dict.size() > 0) {
LOG(INFO) << "Detected temporary OneRec checkpoint prefix "
<< "`module.module3.t5_model.`; loading weights via the "
"compatibility path.";
} else if (model_state_dict.size() == 0) {
model_state_dict = *state_dict;
}
model_->load_state_dict(model_state_dict);
Expand Down
28 changes: 14 additions & 14 deletions xllm/models/rec/npu/onerec_npu_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -251,8 +251,10 @@ class OneRecStackImpl : public torch::nn::Module {
const bool is_decode_stage = is_decoder_ && !is_prefill;
torch::Tensor effective_attn_mask;
if (use_absolute_position_embedding_) {
const int64_t batch_size =
std::max<int64_t>(1, input_params.num_sequences);
effective_attn_mask =
create_moe_attention_mask(query_length, h, is_decoder_);
create_moe_attention_mask(query_length, h, is_decoder_, batch_size);
} else {
effective_attn_mask = compute_position_bias_mask(
query_length, key_length, h, is_decode_stage, input_params);
Expand Down Expand Up @@ -382,24 +384,22 @@ class OneRecStackImpl : public torch::nn::Module {

torch::Tensor create_moe_attention_mask(int64_t seq_length,
const torch::Tensor& h,
bool is_decoder) const {
bool is_decoder,
int64_t batch_size) const {
if (!is_decoder) {
return torch::ones({num_heads_, seq_length, seq_length}, h.options());
}

const float mask_value = -9984.0f;
auto upper_tri_mask =
batch_size = std::max<int64_t>(1, batch_size);
torch::Tensor mask =
torch::triu(torch::ones({seq_length, seq_length},
torch::dtype(h.dtype()).device(h.device())),
1);
auto expanded_mask = upper_tri_mask.unsqueeze(0).expand(
{num_heads_, seq_length, seq_length});
auto effective_attn_mask =
torch::zeros({num_heads_, seq_length, seq_length},
torch::dtype(h.dtype()).device(h.device()));
effective_attn_mask.masked_fill_(expanded_mask.to(torch::kBool),
mask_value);
return effective_attn_mask;
h.options().dtype(torch::kUInt8)),
1)
.unsqueeze(0)
.unsqueeze(0)
.expand({batch_size, 1, seq_length, seq_length})
.contiguous();
return mask;
}

torch::Tensor compute_position_bias_mask(
Expand Down
Loading